import base64
import os
import re

import requests
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import login
from openai import OpenAI
from PIL import Image

# Authenticate with the Hugging Face Hub at import time. The token is read
# from the HF_token environment variable and is None when it is unset.
login(token=os.environ.get("HF_token"))

# Modify this to change the number of generations
NUM_GEN = 3

# How many times each generated image is passed through the refiner pipeline.
_REFINER_PASSES = 3

# Matches the innermost non-empty "[...]" span; the character class also
# matches newlines, so a bracketed prompt may wrap across lines.
_BRACKETED_RE = re.compile(r"\[([^\[\]]+)\]")

# Prefix the GPT prompt asks the model to start its image prompt with.
_SD_PROMPT_PREFIX = "4k slr photo"

# Instruction sent to the vision model. Wording is kept verbatim; only the
# positional "{}" fields were replaced by named ones.
_GPT_PROMPT_TEMPLATE = (
    " List key features of the {ground_truth} itself in this image that make"
    " it distinct from a {mistaken}? Then, write a very short and concise"
    " visual midjourney prompt of the {ground_truth} that includes the above"
    " features of {ground_truth} (prompt should start with '4K SLR photo,')"
    " and put it inside square brackets []. Do no mention {mistaken} in your"
    " prompt, also do not mention non-essential background scenes like"
    ' "calm waters, mountains" and sub-components like "paddle of canoe"'
    " in the prompt. \n "
)


def encode_image(image_path: str) -> str:
    """Return the contents of the file at *image_path* as a base64 string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def vision_gpt(prompt: str, image_url: str, api_key: str) -> str:
    """Send a text prompt plus one image to the vision chat model.

    Args:
        prompt: Text part of the user message.
        image_url: Base64-encoded image data (not a URL, despite the name);
            it is embedded in a ``data:image/jpeg;base64,`` URL.
        api_key: OpenAI API key used to create the client.

    Returns:
        The content of the first choice's message.
    """
    client = OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image_url}",
                        },
                    },
                ],
            }
        ],
        max_tokens=600,
    )
    return response.choices[0].message.content


def _extract_sd_prompt(response):
    """Pull the bracketed image prompt out of the GPT reply.

    Empty brackets are skipped because the instruction text itself contains
    "[]" and the model may echo it. A candidate starting with "4K SLR photo"
    is preferred; otherwise the first non-empty bracketed span is used.

    Raises:
        ValueError: If the reply is empty or has no non-empty bracketed text.
    """
    if not response:
        raise ValueError("GPT returned an empty response; no prompt to extract.")

    # Collapse internal whitespace so a prompt wrapped over several lines
    # becomes a single-line string.
    candidates = [" ".join(span.split()) for span in _BRACKETED_RE.findall(response)]
    candidates = [text for text in candidates if text]
    if not candidates:
        raise ValueError(
            "GPT response contains no prompt inside square brackets: "
            f"{response!r}"
        )

    for text in candidates:
        if text.lower().startswith(_SD_PROMPT_PREFIX):
            return text
    return candidates[0]


def generate_images(oai_key, input_path, mistaken_class, ground_truth_class):
    """Generate NUM_GEN refined images of *ground_truth_class*.

    The input image is described by the vision model, which is asked to
    write an image-generation prompt inside square brackets. That prompt is
    rendered with the SDXL base pipeline and each result is passed through
    the refiner pipeline ``_REFINER_PASSES`` times.

    Args:
        oai_key: OpenAI API key.
        input_path: Path of the image file to analyse.
        mistaken_class: Class name the image should be distinguished from.
        ground_truth_class: Class name of the object in the image.

    Returns:
        A tuple with NUM_GEN refined images.

    Raises:
        ValueError: If no bracketed prompt can be found in the GPT reply.
    """
    print("--------------input_path--------------: \n", input_path, "\n\n")
    base64_image = encode_image(input_path)

    prompt = _GPT_PROMPT_TEMPLATE.format(
        ground_truth=ground_truth_class,
        mistaken=mistaken_class,
    )
    print("--------------gpt prompt--------------: \n", prompt, "\n\n")

    response = vision_gpt(prompt, base64_image, oai_key)
    print("--------------GPT response--------------: \n", response, "\n\n")

    stable_diffusion_prompt = _extract_sd_prompt(response)
    print("--------------stable_diffusion_prompt-------------- \n", stable_diffusion_prompt, "\n\n")

    SD_pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-0.9",
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    )
    RF_pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-0.9",
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    )
    SD_pipe.to("cuda")
    RF_pipe.to("cuda")

    out_images = []
    for _ in range(NUM_GEN):
        # The first refiner pass receives the base pipeline's image list;
        # later passes receive the single image from the previous pass.
        image = SD_pipe(prompt=stable_diffusion_prompt, num_inference_steps=75).images
        for _pass in range(_REFINER_PASSES):
            image = RF_pipe(prompt=stable_diffusion_prompt, image=image).images[0]
        out_images.append(image)
    return tuple(out_images)