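# Gradio app: given an input image, a mistaken class, and a ground-truth class,
# ask GPT-4 Vision for a short prompt describing the features that distinguish the
# ground-truth class, then use SDXL (base + refiner) to generate images from that prompt.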
import base64
import os
import re

import gradio as gr
import torch
from diffusers import DiffusionPipeline
from openai import OpenAI
def encode_image(image_path):
    """Read an image file and return its contents as a base64-encoded string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
def vision_gpt(prompt, image_url, api_key):
    """Send a text prompt plus a base64-encoded image to GPT-4 Vision and return the reply text."""
    client = OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{image_url}"},
                    },
                ],
            }
        ],
        max_tokens=600,
    )
    return response.choices[0].message.content
def generate_images(oai_key, input_path, mistaken_class, ground_truth_class, num_generations):
    output_path = "out/"
    os.makedirs(output_path, exist_ok=True)  # make sure the output directory exists
    num_generations = int(num_generations)  # gr.Number passes a float

    base64_image = encode_image(input_path)

    # Ask GPT-4 Vision for the distinguishing features and a bracketed Stable Diffusion prompt.
    prompt = """
    List the key features of the {} itself in this image that make it distinct from a {}. Then, write a very short and
    concise visual midjourney prompt of the {} that includes the above features of {} (prompt should start
    with '4K SLR photo,') and put it inside square brackets []. Do not mention {} in your prompt, also do not mention
    non-essential background scenes like "calm waters, mountains" and sub-components like "paddle of canoe" in the prompt.
    """.format(ground_truth_class, mistaken_class, ground_truth_class, ground_truth_class, mistaken_class, mistaken_class)
    print("--------------gpt prompt--------------: \n", prompt, "\n\n")

    response = vision_gpt(prompt, base64_image, oai_key)
    print("--------------GPT response--------------: \n", response, "\n\n")

    # Extract the prompt that GPT was asked to place inside square brackets.
    match = re.search(r'\[(.*?)\]', response)
    if match is None:
        raise ValueError("GPT response did not contain a prompt in square brackets.")
    stable_diffusion_prompt = match.group(1)
    print("--------------stable_diffusion_prompt-------------- \n", stable_diffusion_prompt, "\n\n")

    # Load the SDXL base and refiner pipelines in fp16 on the GPU.
    SD_pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-0.9",
        torch_dtype=torch.float16, use_safetensors=True, variant="fp16",
    )
    SD_pipe.to("cuda")
    RF_pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-0.9",
        torch_dtype=torch.float16, use_safetensors=True, variant="fp16",
    )
    RF_pipe.to("cuda")

    for i in range(num_generations):
        # Generate with the base model, then pass the result through the refiner three times.
        generated_image = SD_pipe(prompt=stable_diffusion_prompt, num_inference_steps=75).images[0]
        refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=generated_image).images[0]
        refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=refined_image).images[0]
        refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=refined_image).images[0]
        refined_image.save(output_path + "{}.png".format(i), "PNG")

    return [output_path + "{}.png".format(i) for i in range(num_generations)]
# Gradio UI: the image input is passed as a file path so encode_image can read it from disk,
# and the outputs are shown in a gallery since the function returns a list of image paths.
iface = gr.Interface(
    fn=generate_images,
    inputs=[
        gr.Textbox(label="OpenAI API Key"),
        gr.Image(label="Input Image", type="filepath"),
        gr.Textbox(label="Mistaken Class"),
        gr.Textbox(label="Ground Truth Class"),
        gr.Number(label="Number of Generations"),
    ],
    outputs=[
        gr.Gallery(label="Output Images"),
    ],
    title="Image Generation and Refinement",
    description="Generates and refines images based on input classes and parameters.",
)
if __name__ == "__main__":
    iface.launch()