Spaces:
Runtime error
Runtime error
import gradio as gr | |
import time | |
import base64 | |
from openai import OpenAI | |
def wait_on_run(run, client, thread, poll_interval=0.5, timeout=None):
    """Poll an Assistants run until it leaves its pending states, then return it.

    The run is re-fetched with ``client.beta.threads.runs.retrieve`` while its
    status is ``"queued"``, ``"in_progress"`` or ``"cancelling"``.
    ``"cancelling"`` is still transitional, so polling continues until the run
    reaches its final state.

    Args:
        run: the run object returned by ``runs.create`` (needs ``id`` and ``status``).
        client: an ``openai.OpenAI`` client.
        thread: the thread the run belongs to (needs ``id``).
        poll_interval: seconds to sleep between polls.
        timeout: optional maximum number of seconds to wait; ``None`` waits forever.

    Returns:
        The most recently retrieved run object (or ``run`` itself if it was
        already in a non-pending state).

    Raises:
        TimeoutError: if ``timeout`` elapses while the run is still pending.
    """
    pending_statuses = ("queued", "in_progress", "cancelling")
    deadline = None if timeout is None else time.monotonic() + timeout
    while run.status in pending_statuses:
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Run {run.id} still {run.status!r} after {timeout} seconds"
            )
        # Sleep before polling: a freshly created run is almost never done yet,
        # and this avoids a pointless extra sleep once the run has finished.
        time.sleep(poll_interval)
        run = client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id,
        )
    return run
def _extract_image_file_id(steps):
    """Return the file id of the last image emitted by a code-interpreter call, else None.

    ``steps`` is an iterable of run-step dicts (``RunStep.model_dump()`` output).
    Every tool call and every output is inspected, not just the first of each,
    so an image that follows a ``logs`` output is still found.
    """
    image_id = None
    for step in steps:
        details = step.get("step_details") or {}
        # message_creation steps have no "tool_calls" key and are skipped here.
        for tool_call in details.get("tool_calls") or []:
            interpreter = tool_call.get("code_interpreter") or {}
            for output in interpreter.get("outputs") or []:
                image = output.get("image") or {}
                if image.get("file_id"):
                    # Keep overwriting so the most recent image wins.
                    image_id = image["file_id"]
    return image_id


def GenerateImageByCode(client, message, code_prompt):
    """Ask a code-interpreter assistant to draw ``message`` and return the image.

    A single-use assistant (instructions=``code_prompt``) and a thread are
    created, the run is awaited with ``wait_on_run``, and the last image
    produced by the code interpreter is downloaded and written to
    ``<file_id>.png`` in the current working directory. The assistant is
    deleted afterwards (best effort) so repeated calls do not accumulate
    assistants in the account.

    Args:
        client: an ``openai.OpenAI`` client.
        message: the user's question, sent as the thread's first message.
        code_prompt: system instructions for the assistant.

    Returns:
        ``(path, base64_image)``: the saved PNG path and the image bytes
        base64-encoded as UTF-8 text.

    Raises:
        RuntimeError: if the run finished without producing any image.
    """
    assistant = client.beta.assistants.create(
        name="Chain of Image",
        instructions=code_prompt,
        model="gpt-4-1106-preview",
        tools=[{"type": "code_interpreter"}],
    )
    try:
        thread = client.beta.threads.create()
        client.beta.threads.messages.create(
            thread_id=thread.id,
            role="user",
            content=message,
        )
        run = client.beta.threads.runs.create(
            thread_id=thread.id,
            assistant_id=assistant.id,
        )
        run = wait_on_run(run, client, thread)
        # Iterating the returned page auto-paginates, so runs with more steps
        # than fit on one page are still fully scanned.
        run_steps = client.beta.threads.runs.steps.list(
            thread_id=thread.id, run_id=run.id, order="asc"
        )
        image_id = _extract_image_file_id(step.model_dump() for step in run_steps)
        if image_id is None:
            # Explicit exception (not assert) so the check survives `python -O`.
            raise RuntimeError(
                f"Code interpreter produced no image (run status: {run.status!r})"
            )
        image_bytes = client.files.with_raw_response.content(image_id).content
    finally:
        # A failed cleanup must not hide the real result or error, so it is
        # reported rather than raised.
        try:
            client.beta.assistants.delete(assistant.id)
        except Exception as err:
            print(f"Warning: could not delete assistant {assistant.id}: {err}")

    path = f"{image_id}.png"
    with open(path, "wb") as f:
        f.write(image_bytes)
    base64_image = base64.b64encode(image_bytes).decode("utf-8")
    return path, base64_image
def _guess_image_mime(base64_image):
    """Guess an image MIME type from the leading characters of its base64 text.

    Recognises PNG, JPEG, GIF and WebP signatures; anything else (including
    non-str input) is reported as ``image/jpeg``.
    """
    signatures = (
        ("iVBORw0KGgo", "image/png"),  # b"\x89PNG\r\n\x1a\n"
        ("/9j/", "image/jpeg"),        # b"\xff\xd8\xff"
        ("R0lGOD", "image/gif"),       # b"GIF8"
        ("UklGR", "image/webp"),       # b"RIFF" container
    )
    if isinstance(base64_image, str):
        for prefix, mime in signatures:
            if base64_image.startswith(prefix):
                return mime
    return "image/jpeg"


def visual_question_answer(client, base64_image, question, vqa_prompt, max_tokens=256):
    """Answer ``question`` about a base64-encoded image with a vision chat model.

    The image is sent inline as a ``data:`` URL. Its MIME type is sniffed from
    the payload, because the images produced upstream are PNGs; labelling them
    JPEG would mislabel the data.

    Args:
        client: an ``openai.OpenAI`` client.
        base64_image: the image bytes, base64-encoded as text.
        question: the user's question about the image.
        vqa_prompt: system prompt for the vision model.
        max_tokens: completion token limit; integral floats (e.g. from a UI
            slider) are converted to int because the API expects an integer.

    Returns:
        The text content of the first completion choice.
    """
    if isinstance(max_tokens, float):
        max_tokens = int(max_tokens)
    mime_type = _guess_image_mime(base64_image)
    response = client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=[
            {"role": "system", "content": vqa_prompt},
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{mime_type};base64,{base64_image}"},
                    },
                    {"type": "text", "text": f"Question:\n{question}\nAnswer:\n"},
                ],
            },
        ],
        max_tokens=max_tokens,
    )
    return response.choices[0].message.content
def chain_of_images(message, history, code_prompt, vqa_prompt, api_token, max_tokens):
    """Gradio chat handler: draw the first question, then answer follow-ups about the drawing.

    If ``history`` already contains a bot reply of the form
    ``(image_path, base64_image)`` (as returned by ``GenerateImageByCode``),
    the new message is answered by ``visual_question_answer`` against that
    image. Otherwise a new image is generated for the message.

    Args:
        message: the current user message.
        history: Gradio chat history, a list of ``[user_message, bot_reply]`` pairs.
        code_prompt: instructions for the code-interpreter assistant.
        vqa_prompt: system prompt for the vision model.
        api_token: OpenAI API key from the UI; an empty value falls back to the
            ``OPENAI_API_KEY`` environment variable.
        max_tokens: completion token limit for follow-up answers.

    Returns:
        A text answer, or an ``(image_path, base64_image)`` tuple for a new image.
    """
    # `or None` lets the OpenAI client fall back to its environment variable
    # instead of sending an empty credential.
    client = OpenAI(api_key=api_token or None)

    # Find the first bot reply that is an (image_path, base64_image) pair,
    # rather than blindly indexing the first entry (which on a text reply would
    # silently yield a single character).
    base64_image = None
    for entry in history or []:
        if isinstance(entry, (list, tuple)) and len(entry) >= 2:
            bot_reply = entry[1]
            if isinstance(bot_reply, (list, tuple)) and len(bot_reply) >= 2:
                base64_image = bot_reply[1]
                break

    if base64_image is not None:
        return visual_question_answer(
            client, base64_image, message, vqa_prompt, max_tokens=int(max_tokens)
        )
    return GenerateImageByCode(client, message, code_prompt)
def vote(data: gr.LikeData):
    """Print whether a chat response was upvoted or downvoted.

    ``data.value`` is formatted with an f-string rather than ``+``. A liked
    image reply is a tuple, and string concatenation with a tuple raises
    TypeError.

    NOTE(review): this handler is not attached to any component in the visible
    code (e.g. via ``chatbot.like(vote)``); confirm whether it is meant to be wired up.
    """
    if data.liked:
        print(f"You upvoted this response: {data.value}")
    else:
        print(f"You downvoted this response: {data.value}")
# Chat UI: the first message is turned into an image by the code interpreter;
# later messages are answered about that image (see chain_of_images).
demo = gr.ChatInterface(chain_of_images,
    additional_inputs=[
        gr.Textbox("You are a research drawing assistant. Your primary role is to help visualize questions posed by users. Instead of directly answering questions, you will use code to invoke the most suitable toolkit, transforming these questions into images. This helps users quickly understand the question and find answers through visualization. You should prioritize clarity and effectiveness in your visual representations, ensuring that complex scientific or technical concepts are made accessible and comprehensible through your drawings.", label="Code Interpreter Prompt"),
        gr.Textbox("You are a visual thinking expert. Your primary role is to answer questions about an image posed by users.", label="VQA Prompt"),
        # Masked input so the secret key is not displayed on screen.
        gr.Textbox(label="API Key", type="password"),
        # step=1 keeps the value integral; without a step Gradio derives a
        # fractional one, and the OpenAI API rejects non-integer max_tokens.
        gr.Slider(32, 128, step=1, label="Max Tokens"),
    ],
).queue()

if __name__ == "__main__":
    demo.launch()