前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >API调用ComfyUI模板高效文生图

API调用ComfyUI模板高效文生图

原创
作者头像
平常心
修改2024-08-10 15:27:11
3300
修改2024-08-10 15:27:11
举报
文章被收录于专栏:个人总结系列

一、基础环境

环境搭建请参考《ComfyUI搭建文生图》,并在 ComfyUI 设置中开启 Dev Mode(开发者模式);开启后即可通过 “Save (API Format)” 导出下文脚本所需的 API 格式工作流文件(workflow_api.json)。

ComfyUI API

二、本地化运行脚本编写

代码语言:python
复制
# -- utf-8 ---
# https://www.bilibili.com/read/cv33202530/
# https://www.wehelpwin.com/article/5317
import json
import websocket
import uuid
import urllib.request
import urllib.parse
import random


# 显示图片
def show_gif(fname):
    """Return an IPython HTML object that renders the image file *fname* inline.

    The file is embedded as a base64 ``data:`` URI. The MIME type is guessed
    from the file extension, so the ``.png`` files written by this script are
    labelled ``image/png`` instead of ``image/gif``. Unknown or non-image
    extensions fall back to ``image/gif`` as before.

    Requires IPython. The returned object only renders when it is displayed,
    e.g. passed to ``IPython.display.display`` or returned as the last value
    of a notebook cell.
    """
    import base64
    import mimetypes
    from IPython import display

    mime, _ = mimetypes.guess_type(fname)
    if mime is None or not mime.startswith('image/'):
        mime = 'image/gif'
    with open(fname, 'rb') as fd:
        b64 = base64.b64encode(fd.read()).decode('ascii')
    return display.HTML(f'<img src="data:{mime};base64,{b64}" />')

# 向服务器队列发送提示词
def queue_prompt(textPrompt):
    """Submit an API-format workflow to ComfyUI's ``/prompt`` endpoint.

    Args:
        textPrompt: the workflow graph (a dict) to queue. Despite the name,
            this is the whole prompt graph, not a text string.

    Returns:
        The decoded JSON response. Callers read its ``'prompt_id'``.

    Reads the module-level ``server_address`` and ``client_id`` (set in the
    ``__main__`` block). The same ``client_id`` is used as ``clientId`` when
    ``generate_clip`` opens the WebSocket, which links the queued job to that
    socket.
    """
    p = {"prompt": textPrompt, "client_id": client_id}
    data = json.dumps(p).encode('utf-8')
    req = urllib.request.Request(
        "http://{}/prompt".format(server_address),
        data=data,
        headers={"Content-Type": "application/json"},
    )
    # Close the HTTP response deterministically instead of leaking it.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
    
# 获取生成图片
def get_image(fileName, subFolder, folder_type):
    """Download one output file from ComfyUI's ``/view`` endpoint.

    Args:
        fileName: file name as reported in the prompt history.
        subFolder: sub-folder as reported in the prompt history.
        folder_type: the history entry's ``type`` field.

    Returns:
        The raw response body (bytes).

    Uses the module-level ``server_address``.
    """
    query = urllib.parse.urlencode(
        {"filename": fileName, "subfolder": subFolder, "type": folder_type}
    )
    view_url = f"http://{server_address}/view?{query}"
    with urllib.request.urlopen(view_url) as resp:
        payload = resp.read()
    return payload

# 获取历史记录
def get_history(prompt_id):
    """Fetch ComfyUI's history for *prompt_id* and return the decoded JSON.

    The result is keyed by prompt id; ``get_images`` indexes it with
    ``[prompt_id]``. Uses the module-level ``server_address``.
    """
    history_url = f"http://{server_address}/history/{prompt_id}"
    with urllib.request.urlopen(history_url) as resp:
        raw = resp.read()
    return json.loads(raw)

# 获取图片,监听WebSocket消息
def get_images(ws, prompt):
    """Queue *prompt*, wait for ComfyUI to finish it, and download its outputs.

    Args:
        ws: a connected ``websocket.WebSocket`` opened with this module's
            ``client_id``, so it receives this job's status messages.
        prompt: API-format workflow graph to queue.

    Returns:
        dict mapping output node id -> list of raw file bytes. Files listed
        under a node's ``images`` come first, then those under ``videos``.
        Nodes without any files are omitted.
    """
    prompt_id = queue_prompt(prompt)['prompt_id']
    print('prompt: {}'.format(prompt))
    print('prompt_id:{}'.format(prompt_id))

    while True:
        out = ws.recv()
        # Only text frames carry JSON status messages; binary frames
        # (presumably previews) are skipped.
        if not isinstance(out, str):
            continue
        message = json.loads(out)
        if message['type'] == 'executing':
            data = message['data']
            # An "executing" message with node None for our prompt marks the end.
            if data['node'] is None and data['prompt_id'] == prompt_id:
                print('执行完成')
                break

    history = get_history(prompt_id)[prompt_id]
    print(history)

    output_images = {}
    # Single pass over the outputs. The old nested loop over the same dict
    # re-downloaded every file once per output node.
    for node_id, node_output in history['outputs'].items():
        files = []
        # Collect images and videos together, so a node that has both keeps
        # both instead of the videos replacing the images.
        for key in ('images', 'videos'):
            for item in node_output.get(key, []):
                files.append(get_image(item['filename'], item['subfolder'], item['type']))
        if files:
            output_images[node_id] = files
    print('获取图片完成:{}'.format(output_images))
    return output_images

# 解析comfyUI 工作流并获取图片
def _load_workflow(prompt, seed, workflowfile):
    """Read an API-format workflow JSON file and fill in the prompt text and seed.

    Node ``"6"`` is assumed to be the text-encode node whose ``text`` input
    receives *prompt* (TODO confirm against the exported workflow). When
    *seed* is not None, every node input named ``seed`` that holds a literal
    value is set to *seed*. Inputs that are links (``[node_id, slot]`` lists)
    are left alone so the graph wiring is not broken.
    """
    with open(workflowfile, 'r', encoding="utf-8") as workflow_file:
        prompt_data = json.load(workflow_file)
    prompt_data["6"]["inputs"]["text"] = prompt
    if seed is not None:
        # Apply the caller's seed so the seed printed and used in output
        # filenames is the one the sampler actually used.
        for node in prompt_data.values():
            inputs = node.get("inputs") if isinstance(node, dict) else None
            if isinstance(inputs, dict) and "seed" in inputs and not isinstance(inputs["seed"], list):
                inputs["seed"] = seed
    return prompt_data


def parse_worflow(ws, prompt, seed, workflowfile):
    """Load *workflowfile*, apply *prompt* and *seed*, run it, and return its outputs.

    Returns the ``{node_id: [bytes, ...]}`` mapping produced by ``get_images``.
    """
    print('workflowfile:{}'.format(workflowfile))
    # The file is closed before queuing instead of staying open for the
    # whole generation.
    prompt_data = _load_workflow(prompt, seed, workflowfile)
    return get_images(ws, prompt_data)

# 生成图像并显示
def generate_clip(prompt, seed, workflowfile, idx, output_dir='/mnt/d/aigc_result'):
    """Run the workflow for *prompt* and save every returned file to disk.

    Args:
        prompt: text placed into the workflow's prompt node.
        seed: sampler seed; also used in the output filenames.
        workflowfile: path to the API-format workflow JSON.
        idx: caller-chosen index used as the filename prefix.
        output_dir: directory the files are written to (default unchanged).

    Files are named ``{idx}_{seed}_{timestamp}.png``. When a run returns
    several files, the second and later ones get a ``_{n}`` suffix so files
    saved within the same second no longer overwrite each other.
    NOTE(review): video outputs are also written with a ``.png`` extension.
    """
    from datetime import datetime

    print('seed:'+str(seed))
    ws = websocket.WebSocket()
    ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
    try:
        images = parse_worflow(ws, prompt, seed, workflowfile)
    finally:
        # Always release the socket, even if generation fails.
        ws.close()

    # Format the current time as YYYYMMDDHHMMSS, once per run.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    seq = 0
    for node_id in images:
        for image_data in images[node_id]:
            suffix = '_{}'.format(seq) if seq else ''
            image_path = "{}/{}_{}_{}{}.png".format(output_dir, idx, seed, timestamp, suffix)
            print('GIF_LOCATION:'+image_path)
            with open(image_path, "wb") as binary_file:
                binary_file.write(image_data)
            # Preview only after the file is closed (fully flushed). The
            # preview is optional: without IPython the saved file is enough.
            try:
                show_gif(image_path)
            except ImportError:
                pass
            print("{} DONE!!!".format(image_path))
            seq += 1
                

if __name__ == "__main__":
    # Script configuration. server_address and client_id are read as module
    # globals by queue_prompt, get_image, get_history and generate_clip, so
    # they must be assigned here before generate_clip runs.
    workflowfile = '/mnt/d/code/aigc/workflow_api.json'
    COMFYUI_ENDPOINT = 'localhost:8188'
    server_address = COMFYUI_ENDPOINT
    # Sent as client_id when queueing and as clientId on the WebSocket URL,
    # which links the job to this client's socket.
    client_id = str(uuid.uuid4())
    seed = 15465856
    prompt = 'Leopards hunt on the grassland'
    generate_clip(prompt, seed, workflowfile, 1)

三、生产部署

代码语言:python
代码运行次数:0
复制
1、ComfyUI 源代码保持不变;
2、新建一个类似 main_v2.py 的文件,采用 Flask 或 FastAPI 方式实现服务代码(可参考 ComfyUI 的 server.py),并直接引用 ComfyUI 模块中的方法,例如:

# -*- coding: utf-8 -*-
"""
To start this server, run

$ waitress-serve --port=1230 --call agc_server:create_app

or specify the init sd model:

$ COMMANDLINE_ARGS="--ckpt models/Stable-diffusion/novelaifinal-pruned.ckpt" waitress-serve --port=1230 --call agc_server:create_app
"""
import datetime
import logging
import os
from flask import Flask, request,jsonify
from flask_cors import CORS
import json
# from sdutils import parse_image_data, encode_pil_to_base64
import cv2 as cv
from PIL import Image
import shutil 
import numpy as np
from process_base import *


# Flask application object; run directly via app.run in the __main__ block,
# or handed to waitress through create_app().
app = Flask(__name__)
# Enable CORS for all routes; supports_credentials=True also allows browsers
# to send credentialed (cookie/auth) cross-origin requests.
CORS(app, supports_credentials=True)

# print(f"aisticker_server: {datetime.datetime.now()}")
# logging.basicConfig(
#     filename=None,
#     level=logging.INFO,
#     format='%(asctime)s.%(msecs)03d:%(levelname)s:%(message)s',
#     datefmt = '%m/%d/%Y %H:%M:%S')
# logging.info('aisticker_server level: info')
# logging.debug('aisticker_server level: debug')


def update_params_lora(params, loras):
    """Chain ``LoraLoader`` nodes after the checkpoint loader in ``params["prompt"]``.

    Each entry of *loras* becomes the ``inputs`` of a new ``LoraLoader`` node.
    Keys other than ``model``/``clip`` are passed through unchanged; which keys
    are required is not checked here. The new nodes get ids ``"50000"``,
    ``"50001"``, ... and are wired in sequence: the first takes model (slot 0)
    and clip (slot 1) from the ``CheckpointLoaderSimple`` node, and each later
    one takes them from its predecessor. Finally, every ``CLIPTextEncode``
    node's clip input and every ``KSampler`` node's model input are re-pointed
    at the last LoRA node.

    Args:
        params: dict whose ``"prompt"`` key holds an API-format graph
            (node id -> node dict). Modified in place.
        loras: sequence of LoraLoader input dicts. ``None`` or empty means
            "no LoRAs". The dicts themselves are not modified.

    Returns:
        *params* (the same object).
    """
    if not loras:
        return params

    graph = params["prompt"]
    # High id range, assumed not to collide with the workflow's own node ids.
    start_index = 50000
    # presumably returns (node_id, node) — confirm in process_base
    prev_node, _ = get_prompt_item(graph, "CheckpointLoaderSimple")

    for offset, lora in enumerate(loras):
        node_id = str(start_index + offset)
        # Copy so the caller's (request) dicts are not mutated with wiring info.
        inputs = dict(lora)
        inputs["model"] = [prev_node, 0]
        inputs["clip"] = [prev_node, 1]
        graph[node_id] = {"inputs": inputs, "class_type": "LoraLoader"}
        prev_node = node_id

    clip_nodes, _ = get_prompt_items(graph, "CLIPTextEncode")
    for key in clip_nodes:
        graph[key]["inputs"]["clip"][0] = prev_node

    sampler_nodes, _ = get_prompt_items(graph, "KSampler")
    for key in sampler_nodes:
        graph[key]["inputs"]["model"][0] = prev_node

    return params


class AISticker:
    """Run ComfyUI prompts in-process and pick out two result images.

    The helpers used here (``startup``, ``execution``, ``check_prompt_seed``,
    ``demo_prompt_process``, ``demo_prompt_worker``,
    ``get_prompt_item_with_title``) are presumably provided by the
    ``process_base`` star import — confirm there.
    """

    def __init__(self):
        # startup() yields a server object and a queue (presumably the prompt
        # queue); the executor is bound to that server.
        self.server, self.q = startup()
        self.e = execution.PromptExecutor(self.server)

    def forward(self, params):
        """Execute ``params["prompt"]`` and return ``(output_image, remove_bg_image)``.

        The results are located through the graph nodes titled
        ``output_images`` and ``remove_bg_images``. For each, the first
        ``images`` input names the upstream node, and ``[0][0]`` of that
        node's entry in the worker result is returned (presumably its first
        image — confirm against demo_prompt_worker).
        """
        params = check_prompt_seed(params)
        demo_prompt_process(self.server, params)
        outputs_by_node = demo_prompt_worker(self.q, self.server, self.e)

        graph = params["prompt"]
        main_key, no_bg_key = [
            get_prompt_item_with_title(graph, title)[1]["inputs"]["images"][0]
            for title in ("output_images", "remove_bg_images")
        ]
        return outputs_by_node[main_key][0][0], outputs_by_node[no_bg_key][0][0]


# def base64_to_str(img_list):
#     for i in range(len(img_list)):
#         img_list[i] = str(img_list[i], "utf-8")
#     return img_list


# Module-level instance used by algo_aisticker for every request. Creating it
# runs startup() and builds the PromptExecutor at import time, so importing
# this module (e.g. via create_app) initializes the backend once per process.
# NOTE(review): concurrent requests share this object; confirm forward() is
# safe to call from multiple threads before using a threaded server.
sticker = AISticker()


@app.route('/sdapi/v1/aisticker', methods=['POST'])
def algo_aisticker():
    """Generate a sticker from an API-format prompt graph plus optional LoRAs.

    Expected JSON body::

        {"params": [<prompt graph>, {"sd_loras": [<LoraLoader inputs>, ...]}]}

    The second element (and its ``sd_loras`` key) may be omitted. On success
    the response is ``{"errno": 0, "outputs": <save_image result>}`` with
    HTTP 200. A body that is not JSON, or lacks a non-empty ``params`` list,
    gets ``errno`` 1, a ``msg`` and HTTP 400 instead of an unhandled
    exception (HTTP 500).
    """
    params = request.get_json(silent=True)
    print(params)
    out_params = {"errno": 0, "outputs": ""}

    json_params = params.get('params') if isinstance(params, dict) else None
    if not isinstance(json_params, list) or not json_params:
        # Assumed convention: non-zero errno means failure.
        out_params["errno"] = 1
        out_params["msg"] = "request body must be JSON with a non-empty 'params' list"
        return jsonify(out_params), 400

    prompt_params = {"prompt": json_params[0]}
    # The LoRA spec is optional; a missing or malformed entry means "no LoRAs".
    lora_params = json_params[1] if len(json_params) > 1 else {}
    if not isinstance(lora_params, dict):
        lora_params = {}
    logging.info(params)
    params = update_params_lora(prompt_params, lora_params.get("sd_loras") or [])
    logging.info(params)
    outputs, remove_bg_outputs = sticker.forward(params)
    # logging uses %-style args. Passing two bare objects made the second one
    # a format argument for the first, which breaks the log call.
    logging.info("outputs=%s remove_bg_outputs=%s", outputs, remove_bg_outputs)
    result = save_image(remove_bg_outputs)
    out_params["outputs"] = result
    return jsonify(out_params), 200
    # to_base64(output)


def _parse_command_line():
    """Parse command-line options for running this server with ``python``.

    Only ``--port`` is consumed in this file (by the ``__main__`` block). The
    remaining options are parsed and returned but not read here.
    """
    from argparse import ArgumentParser, RawDescriptionHelpFormatter

    cli = ArgumentParser(
        formatter_class=RawDescriptionHelpFormatter,
        epilog="\n测试novelai\n==============\n",
    )
    # (flags, keyword options) in declaration order; order controls --help output.
    option_specs = [
        (("-p", "--port"), dict(default=1230, type=int, help="Specify the port")),
        (("--no-half-vae",), dict(dest="no_half_vae", action="store_true", default=False)),
        (("--ckpt",), dict(type=str)),
        (("--xformers",), dict(dest="xformers", action="store_true", default=False)),
        (("--lora-dir",), dict(type=str)),
        (("--package-version",), dict(type=int, default=0, help="{0, 1, 2} 0 for official server, 1 for debug server, 2 for webui")),
        (("--output-dir",), dict(type=str, default=None)),
        (("--raise-all",), dict(dest="raise_all", action="store_true", default=False)),
    ]
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()


def create_app():
    """Application factory for WSGI servers.

    Called by ``waitress-serve --call <module>:create_app`` (see the module
    docstring). It returns the module-level Flask ``app`` rather than building
    a new one on each call.
    """
    return app


if __name__ == '__main__':
    # Development entry point: parse CLI flags, echo them, then start
    # Flask's built-in server on the requested port.
    cli_args = _parse_command_line()
    print(cli_args)
    app.run(port=cli_args.port)

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、基础环境
  • 二、本地化运行脚本编写
  • 三、生产部署
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档