|
import base64 |
|
import io |
|
import json |
|
import os |
|
import os.path |
|
import re |
|
import time |
|
from abc import ABC |
|
from typing import Any |
|
from uuid import uuid4 |
|
|
|
import gradio as gr |
|
import requests |
|
from PIL import Image |
|
from langchain.agents import initialize_agent |
|
from langchain.chat_models import AzureChatOpenAI |
|
from langchain.memory import ConversationBufferWindowMemory |
|
from langchain.tools import BaseTool |
|
|
|
# Directory where generated PNGs are written; the chat UI reads images back from here.
SAVE_FOLDER = "./img"

# Credentials are taken from the environment so real secrets never have to be
# committed to source control. The literal fallbacks are the original placeholders.
# Key/secret for the Baidu (aip.baidubce.com) SDXL text-to-image service.
SDXL_API_KEY = os.environ.get("SDXL_API_KEY", "XXX")
SDXL_API_SECRET = os.environ.get("SDXL_API_SECRET", "XXXX")
# Azure OpenAI endpoint and key used by the chat model driving the agent.
AZURE_END_POINT = os.environ.get("AZURE_OPENAI_ENDPOINT", "https://aimodelgpt.openai.azure.com")
AZURE_OPEN_KEY = os.environ.get("AZURE_OPENAI_API_KEY", "XXXX")
|
|
|
|
|
class SdxlImage(BaseTool, ABC):
    """LangChain tool that generates images with Baidu's SDXL text-to-image API.

    The agent passes one string, ``"<description>"`` or
    ``"<description>[SEP]<count>"``. Every returned image is decoded and saved
    as a PNG under ``SAVE_FOLDER``. The tool's answer lists the saved file
    names, which the chat UI later extracts in order to display the images.
    """

    # Annotated so pydantic (which backs BaseTool) treats these as field
    # overrides; pydantic v2 rejects non-annotated attributes on models.
    name: str = "AI SDXL Image Generator"

    description: str = 'use this tool when you need to generate images by using SDXL model, To use the tool, you must ' \
        'provide prompt parameters prompt, prompt is the description and number of the image, for example, ' \
        'if you want to generate two images about a cute cat, set prompt = a cute cat[SEP]2'

    NEGATIVE_PROMPT: str = "worst quality, low quality, normal quality, lowres, watermark, monochrome, grayscale, ugly, " \
        "blurry, Tan skin, dark skin, black skin, skin spots, skin blemishes, age spot, glans, " \
        "disabled, distorted, bad anatomy, morbid, malformation, amputation, bad proportions, twins, " \
        "missing body, fused body, extra head, poorly drawn face, bad eyes, deformed eye, unclear eyes, " \
        "cross-eyed, long neck, malformed limbs, extra limbs, extra arms, missing arms, bad tongue, " \
        "strange fingers, mutated hands, missing hands, poorly drawn hands, extra hands, fused hands, " \
        "connected hand, bad hands, wrong fingers, missing fingers, extra fingers, 4 fingers, " \
        "3 fingers, deformed hands, extra legs, bad legs, many legs, more than two legs, bad feet, " \
        "wrong feet, extra feets,"

    # Client credentials exchanged for an OAuth access token (see get_access_token).
    api_key: str
    api_secret: str

    def _run(
        self,
        prompt,
        **kwargs: Any,
    ) -> Any:
        """Parse the agent's input and generate the requested images.

        Args:
            prompt: ``"<description>"`` or ``"<description>[SEP]<count>"``.
        Returns:
            The result string produced by :meth:`generate_image`.
        """
        print(f"execute SDXL Image Tool {prompt}")
        query, number = self._parse_prompt(prompt)
        return self.generate_image(query=query, number=number)

    @staticmethod
    def _parse_prompt(prompt):
        """Split ``"<description>[SEP]<count>"`` into (description, count >= 1)."""
        # Split on the LAST separator only. Two-way tuple unpacking raised
        # ValueError whenever the text contained more than one "[SEP]".
        query, sep, count = prompt.rpartition("[SEP]")
        if not sep:
            return prompt.strip(), 1
        try:
            number = int(count.strip())
        except ValueError:
            # The LLM sometimes writes the count in words (e.g. "two"); fall back to one image.
            number = 1
        return query.strip(), max(number, 1)

    def get_access_token(self):
        """Exchange the API key / secret for a Baidu OAuth access token.

        Returns:
            The ``access_token`` string, or ``None`` if the response lacks one.
        Raises:
            requests.RequestException: on network failure, timeout, or an HTTP error status.
        """
        # Passing the credentials via ``params`` lets requests URL-encode them
        # instead of splicing raw secrets into the query string.
        params = {
            "grant_type": "client_credentials",
            "client_id": self.api_key,
            "client_secret": self.api_secret,
        }
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }
        response = requests.post(
            "https://aip.baidubce.com/oauth/2.0/token",
            params=params,
            headers=headers,
            data=json.dumps(""),
            timeout=30,  # never hang the agent indefinitely on the token endpoint
        )
        response.raise_for_status()
        return response.json().get("access_token")

    def save_image(self, base64_string):
        """Decode a base64-encoded image and save it as a PNG in ``SAVE_FOLDER``.

        Returns:
            The generated file name, relative to ``SAVE_FOLDER``.
        """
        file_name = f"{uuid4()}.png"
        image = Image.open(io.BytesIO(base64.b64decode(base64_string)))
        # exist_ok avoids the check-then-mkdir race when queued requests run concurrently.
        os.makedirs(SAVE_FOLDER, exist_ok=True)
        image.save(os.path.join(SAVE_FOLDER, file_name))
        return file_name

    def generate_image(self, query: str, number: int = 1):
        """Call the SDXL text-to-image endpoint and save every returned image.

        Args:
            query: positive prompt describing the image.
            number: how many images to request (sent as ``n``).
        Returns:
            On success, a sentence listing the saved PNG file names. Otherwise a
            failure message that includes the error, so the agent can report it.
        """
        try:
            token = self.get_access_token()
            if not token:
                # Previously this crashed with TypeError outside the try block.
                raise ValueError("no access_token in OAuth response")
            payload = json.dumps({
                "prompt": query,
                "negative_prompt": self.NEGATIVE_PROMPT,
                "size": "768x1024",
                "steps": 25,
                "n": number,
                "sampler_index": "DPM++ SDE Karras"
            })
            response = requests.post(
                "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/text2image/sd_xl",
                params={"access_token": token},
                headers={'Content-Type': 'application/json'},
                data=payload,
                timeout=120,  # image generation is slow, but must not block forever
            )
            data = response.json().get('data')
            if data:
                filenames = ",".join(self.save_image(item['b64_image']) for item in data)
                # Report the number actually returned, not the number requested.
                return f"generate total {len(data)} of the {query}, output is all the files {filenames}"
            # No "data" key: the body carries the service's error details.
            error = response.text
        except Exception as err:
            # Tool boundary: turn any failure into a message for the agent instead of crashing the turn.
            print(err)
            error = err
        return f"failed to call tool, got error message: {error}"
|
|
|
|
|
class AgentBot:
    """Conversational LangChain agent backed by Azure OpenAI and the SDXL image tool."""

    def __init__(self):
        # Build the collaborators in the same order as before: LLM, then memory, then tools.
        llm = self._build_llm()
        memory = self._build_memory()
        toolbox = [SdxlImage(api_key=SDXL_API_KEY, api_secret=SDXL_API_SECRET)]

        self.agent = initialize_agent(
            agent='chat-conversational-react-description',
            tools=toolbox,
            llm=llm,
            verbose=True,
            max_iterations=3,
            early_stopping_method='generate',
            memory=memory
        )

    @staticmethod
    def _build_llm():
        """Create the deterministic (temperature 0) Azure chat model."""
        return AzureChatOpenAI(
            azure_endpoint=AZURE_END_POINT,
            openai_api_key=AZURE_OPEN_KEY,
            deployment_name="gpt-35-turbo",
            openai_api_version="2023-10-01-preview",
            temperature=0.0
        )

    @staticmethod
    def _build_memory():
        """Create a memory that keeps the last 5 exchanges as message objects."""
        return ConversationBufferWindowMemory(
            memory_key='chat_history',
            k=5,
            return_messages=True
        )

    def run(self, txt) -> str:
        """Send *txt* to the agent and return its final answer text."""
        return self.agent(txt)["output"]

    def clear(self):
        """Forget the conversation history held in the agent's memory."""
        self.agent.memory.clear()
|
|
|
|
|
# A single agent instance is shared by every browser session (and its memory with it).
bot = AgentBot()

block_css = """#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
#chatbot {height: 520px; overflow: auto;}"""

# Matches the PNG file names the SDXL tool reports in its answer; compiled once.
IMAGE_NAME_PATTERN = re.compile(r'\b[\w-]+\.png')

with gr.Blocks(css=block_css) as demo:
    gr.Markdown("<h3><center>ChatGPT LangChain</center></h3>")
    gr.Markdown(
        """
    This LangChain GPT can generate SD-XL Image
    """
    )

    with gr.Row() as input_raw:
        with gr.Column(elem_id="col_container"):
            chatbot = gr.Chatbot([],
                                 elem_id="chatbot",
                                 label="ChatBot LangChain for AIGC",
                                 bubble_full_width=False,
                                 avatar_images=(None, (os.path.join(os.path.dirname(__file__), "avatar.png"))),
                                 )

            msg = gr.Textbox()

            with gr.Row():
                with gr.Column(scale=0.10, min_width=0):
                    run = gr.Button("🏃♂️Run")
                with gr.Column(scale=0.10, min_width=0):
                    clear = gr.Button("🔄Clear️")

    def respond(message, chat_history):
        """Run the agent on *message* and append its reply, plus any generated images.

        Returns ``("", chat_history)`` so Gradio clears the textbox and refreshes the chat.
        """
        try:
            bot_message = bot.run(message)
        except Exception as err:
            # Agent failures (e.g. unparsable LLM output) are shown in the chat
            # instead of leaving the UI in a generic error state.
            bot_message = f"Sorry, the agent failed: {err}"
        chat_history.append((message, bot_message))
        for image in IMAGE_NAME_PATTERN.findall(bot_message):
            image_path = os.path.join(SAVE_FOLDER, image)
            # The LLM may echo or invent file names; only show images that exist on disk.
            if os.path.isfile(image_path):
                chat_history.append(
                    (None, (image_path,)),
                )
        # The former time.sleep(2) here only stalled a queue worker: the images
        # are already written synchronously before the tool returns.
        return "", chat_history

    def clearMessage():
        """Reset the agent's conversation memory and return an empty chat history."""
        bot.clear()
        return []

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    run.click(respond, [msg, chatbot], [msg, chatbot])
    # One handler clears memory and display together.
    clear.click(clearMessage, None, chatbot)

    gr.Examples(
        examples=["generate a image about a boy reading books using SDXL",
                  # "gril" typo fixed: this text is sent verbatim to the image model.
                  "generate two images about a girl in the classroom using SDXL",
                  ],
        inputs=msg
    )

demo.queue(concurrency_count=10).launch()
|
|