# visual_data_aug / app.py
# Author: bryanzhou008 — last commit 05065c1 ("Update app.py").
import argparse
import base64
import functools
import os
import re

import gradio as gr
import requests
import torch
from diffusers import DiffusionPipeline
from openai import OpenAI
from PIL import Image
def encode_image(image_path):
    """Return the raw bytes of the file at *image_path* as base64 text.

    Args:
        image_path: Path of the file to read (opened in binary mode).

    Returns:
        The base64 encoding of the file contents, decoded to a UTF-8 ``str``
        so it can be embedded directly in a ``data:`` URL.
    """
    with open(image_path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def vision_gpt(prompt, image_url, api_key, model="gpt-4o", max_tokens=600, mime_type="image/jpeg"):
    """Ask an OpenAI vision-capable chat model about one image and return its reply.

    Args:
        prompt: Text instruction sent alongside the image.
        image_url: Despite the name, ``generate_images`` passes the *base64 text*
            of the image (as produced by ``encode_image``). Such text is wrapped
            into a ``data:<mime_type>;base64,...`` URL. A value that is already a
            ``data:`` or ``http(s)://`` URL is passed through unchanged. Base64
            text never contains ``:``, so the two cases cannot be confused.
        api_key: OpenAI API key used to construct the client.
        model: Chat model name. Defaults to ``"gpt-4o"`` because
            ``"gpt-4-vision-preview"``, which was previously hard-coded here, has
            been retired by OpenAI.
        max_tokens: Upper bound on the length of the reply.
        mime_type: MIME type declared in the data URL built for base64 input.

    Returns:
        The ``content`` of the first choice's message.
    """
    if image_url.startswith(("data:", "http://", "https://")):
        url = image_url
    else:
        url = f"data:{mime_type};base64,{image_url}"

    client = OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": url}},
                ],
            }
        ],
        max_tokens=max_tokens,
    )
    return response.choices[0].message.content
# Hugging Face Hub ids of the SDXL 0.9 base and refiner checkpoints.
_SD_BASE_MODEL = "stabilityai/stable-diffusion-xl-base-0.9"
_SD_REFINER_MODEL = "stabilityai/stable-diffusion-xl-refiner-0.9"
# Denoising steps for the base model.
_SD_INFERENCE_STEPS = 75
# How many times each base image is passed through the refiner.
_REFINE_PASSES = 3

# Instruction sent to the vision model. The reply must contain the
# Stable Diffusion prompt inside square brackets (see _extract_sd_prompt).
_GPT_PROMPT_TEMPLATE = """
List key features of the {ground_truth} itself in this image that make it distinct from a {mistaken}? Then, write a very short and
concise visual midjourney prompt of the {ground_truth} that includes the above features of {ground_truth} (prompt should start
with '4K SLR photo,') and put it inside square brackets []. Do not mention {mistaken} in your prompt, also do not mention
non-essential background scenes like "calm waters, mountains" and sub-components like "paddle of canoe" in the prompt.
"""


def _encode_input_image(image):
    """Return *image* as base64 text; accepts a file path, a PIL image, or an array."""
    if image is None:
        raise gr.Error("Please provide an input image.")
    if isinstance(image, (str, os.PathLike)):
        return encode_image(image)
    import io

    if not isinstance(image, Image.Image):
        # Assumes an HxW(xC) uint8 array, as gr.Image(type="numpy") supplies.
        image = Image.fromarray(image)
    buffer = io.BytesIO()
    # Re-encode as JPEG so the bytes match the image/jpeg data URL built by vision_gpt.
    image.convert("RGB").save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _extract_sd_prompt(response):
    """Return the first square-bracketed span of *response*, stripped of whitespace."""
    # DOTALL lets the bracketed prompt span several lines.
    match = re.search(r"\[(.*?)\]", response or "", re.DOTALL)
    if match is None:
        raise gr.Error("The GPT response did not contain a prompt inside square brackets.")
    return match.group(1).strip()


@functools.lru_cache(maxsize=1)
def _load_pipelines():
    """Load the SDXL base and refiner pipelines onto CUDA once and reuse them."""
    base = DiffusionPipeline.from_pretrained(
        _SD_BASE_MODEL, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
    )
    base.to("cuda")
    refiner = DiffusionPipeline.from_pretrained(
        _SD_REFINER_MODEL, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
    )
    refiner.to("cuda")
    return base, refiner


def generate_images(oai_key, input_path, mistaken_class, ground_truth_class, num_generations):
    """Generate refined SDXL images of the ground-truth class, guided by GPT.

    The flow has three stages:
    1. GPT-vision lists features that distinguish ``ground_truth_class`` from
       ``mistaken_class`` in the input image.
    2. GPT writes a bracketed text-to-image prompt from those features.
    3. That prompt drives the SDXL base model plus ``_REFINE_PASSES`` refiner
       passes, once per requested generation.

    Args:
        oai_key: OpenAI API key.
        input_path: The input image. This may be a file path, a PIL image, or an
            array (for example, Gradio's default numpy output).
        mistaken_class: Class name the image was wrongly predicted as.
        ground_truth_class: Correct class name of the image.
        num_generations: Number of images to produce. It may be a float, as
            ``gr.Number`` supplies; it is truncated with ``int``.

    Returns:
        List of saved PNG paths ``out/0.png``, ``out/1.png``, ... (empty when
        ``num_generations`` < 1).

    Raises:
        gr.Error: Raised if the image or count is missing, or if GPT's reply
            has no bracketed prompt.
    """
    if num_generations is None:
        raise gr.Error("Please enter the number of generations.")
    num_generations = int(num_generations)
    output_path = "out/"

    base64_image = _encode_input_image(input_path)
    prompt = _GPT_PROMPT_TEMPLATE.format(ground_truth=ground_truth_class, mistaken=mistaken_class)
    print("--------------gpt prompt--------------: \n", prompt, "\n\n")
    response = vision_gpt(prompt, base64_image, oai_key)
    print("--------------GPT response--------------: \n", response, "\n\n")
    stable_diffusion_prompt = _extract_sd_prompt(response)
    print("--------------stable_diffusion_prompt-------------- \n", stable_diffusion_prompt, "\n\n")

    SD_pipe, RF_pipe = _load_pipelines()
    os.makedirs(output_path, exist_ok=True)
    saved_paths = []
    for i in range(num_generations):
        # The first refiner pass receives the base model's image list; later passes get one PIL image.
        image = SD_pipe(prompt=stable_diffusion_prompt, num_inference_steps=_SD_INFERENCE_STEPS).images
        for _ in range(_REFINE_PASSES):
            image = RF_pipe(prompt=stable_diffusion_prompt, image=image).images[0]
        path = os.path.join(output_path, "{}.png".format(i))
        image.save(path, "PNG")
        saved_paths.append(path)
    return saved_paths
# Gradio front end for generate_images. The inputs map positionally onto its
# parameters: (oai_key, input_path, mistaken_class, ground_truth_class,
# num_generations).
iface = gr.Interface(
    fn=generate_images,
    inputs=[
        # Masked so the API key is not shown on screen.
        gr.Textbox(label="OpenAI API Key", type="password"),
        # type="filepath" hands generate_images a path, which encode_image can open.
        gr.Image(label="Input Image", type="filepath"),
        gr.Textbox(label="Mistaken Class"),
        gr.Textbox(label="Ground Truth Class"),
        # precision=0 makes Gradio deliver an int, which range() requires.
        gr.Number(label="Number of Generations", value=1, precision=0),
    ],
    outputs=[
        # generate_images returns a *list* of image paths; a Gallery displays all of them.
        gr.Gallery(label="Output Image"),
    ],
    title="Image Generation and Refinement",
    description="Generates and refines images based on input classes and parameters.",
)

if __name__ == "__main__":
    iface.launch()