# Chain-of-Image / app.py
import gradio as gr
import time
import base64
from openai import OpenAI
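# Chain-of-Image flow (as implemented below): the first user message is turned
# into a drawing by a GPT-4 assistant with the Code Interpreter tool; every
# follow-up message is answered by GPT-4V (gpt-4-vision-preview) looking at
# that generated image.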
def wait_on_run(run, client, thread):
    # Poll the Assistants API until the run leaves the queued/in-progress states.
    while run.status in ("queued", "in_progress"):
        run = client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id,
        )
        time.sleep(0.5)
    return run
def GenerateImageByCode(client, message, code_prompt):
    # Ask a Code Interpreter assistant to draw the user's question as an image.
    assistant = client.beta.assistants.create(
        name="Chain of Image",
        instructions=code_prompt,
        model="gpt-4-1106-preview",
        tools=[{"type": "code_interpreter"}],
    )
    thread = client.beta.threads.create()
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=message,
    )
    run = client.beta.threads.runs.create(
        thread_id=thread.id,
        assistant_id=assistant.id,
    )
    run = wait_on_run(run, client, thread)
    # Walk the run steps and pick up the file id of the image the tool produced.
    run_steps = client.beta.threads.runs.steps.list(
        thread_id=thread.id, run_id=run.id, order="asc"
    )
    image_id = None
    for data in run_steps.model_dump()['data']:
        if "tool_calls" in data['step_details']:
            outputs = data['step_details']['tool_calls'][0]['code_interpreter']['outputs']
            if outputs and 'image' in outputs[0]:
                image_id = outputs[0]['image']['file_id']
    assert image_id is not None, "the Code Interpreter run produced no image output"
    # Download the image, save it locally, and keep a base64 copy for GPT-4V.
    image_bytes = client.files.with_raw_response.content(image_id).content
    with open(f'{image_id}.png', 'wb') as f:
        f.write(image_bytes)
    base64_image = base64.b64encode(image_bytes).decode('utf-8')
    return f"{image_id}.png", base64_image
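# Note (not in the original flow): GenerateImageByCode creates a fresh assistant
# on every call and never deletes it. If accumulation is a concern, something
# like `client.beta.assistants.delete(assistant.id)` could be added before
# returning; treat this as a hedged suggestion rather than required behavior.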
def visual_question_answer(client, base64_image, question, vqa_prompt, max_tokens=256):
    # Answer a follow-up question with GPT-4V, grounded in the generated image.
    # The image is saved as PNG above, so it is sent with a PNG data URL.
    response = client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=[
            {"role": "system", "content": vqa_prompt},
            {"role": "user", "content": [
                {"type": "image_url",
                 "image_url": {"url": f"data:image/png;base64,{base64_image}"}},
                {"type": "text", "text": f"Question:\n{question}\nAnswer:\n"},
            ]},
        ],
        max_tokens=max_tokens,
    )
    return response.choices[0].message.content
def chain_of_images(message, history, code_prompt, vqa_prompt, api_token, max_tokens):
    client = OpenAI(api_key=api_token)
    if len(history):
        # The first bot reply was a (filepath, base64_image) tuple, so
        # history[0][1][1] recovers the base64 image for follow-up VQA turns.
        return visual_question_answer(client, history[0][1][1], message, vqa_prompt, max_tokens=max_tokens)
    else:
        # First turn: draw the question as an image with the Code Interpreter.
        return GenerateImageByCode(client, message, code_prompt)
def vote(data: gr.LikeData):
    # Log like/dislike feedback from the chatbot UI. The first response is a
    # (filepath, base64) tuple rather than a string, so format via f-string.
    if data.liked:
        print(f"You upvoted this response: {data.value}")
    else:
        print(f"You downvoted this response: {data.value}")
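# `vote` is defined but never attached to the interface below. One possible
# wiring, following the pattern in the Gradio docs (the exact API depends on
# the installed Gradio version), would be to pass an explicit chatbot and
# register the like handler on it:
#
#   chatbot = gr.Chatbot()
#   chatbot.like(vote, None, None)
#   demo = gr.ChatInterface(chain_of_images, chatbot=chatbot, ...)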
demo = gr.ChatInterface(
    chain_of_images,
    additional_inputs=[
        gr.Textbox(
            "You are a research drawing assistant. Your primary role is to help visualize questions posed by users. Instead of directly answering questions, you will use code to invoke the most suitable toolkit, transforming these questions into images. This helps users quickly understand the question and find answers through visualization. You should prioritize clarity and effectiveness in your visual representations, ensuring that complex scientific or technical concepts are made accessible and comprehensible through your drawings.",
            label="Code Interpreter Prompt",
        ),
        gr.Textbox(
            "You are a visual thinking expert. Your primary role is to answer questions about an image posed by users.",
            label="VQA Prompt",
        ),
        gr.Textbox(label="API Key"),
        gr.Slider(32, 128, label="Max Tokens"),
    ],
).queue()
if __name__ == "__main__":
    demo.launch()
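# Usage sketch: run `python app.py`, open the printed local URL, paste an
# OpenAI API key into the "API Key" box, then ask a question. The first turn
# replies with a generated drawing; later turns answer questions about it.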