# Provenance: bryanzhou008 — "Update src/v2_for_hf.py" (commit 3522874, verified)
from openai import OpenAI
import base64
import requests
import re
from diffusers import DiffusionPipeline
import torch
from PIL import Image
import os
from huggingface_hub import login
# Authenticate with the Hugging Face Hub using the token from the HF_token
# environment variable (needed below to download the gated SDXL 0.9 weights).
# NOTE(review): if HF_token is unset this passes token=None — confirm intended.
login(token=os.environ.get("HF_token"))
# Modify this to change the number of images produced per generate_images() call.
NUM_GEN = 3
def encode_image(image_path):
    """Return the contents of the file at *image_path* as a base64 string.

    The file is read in binary mode; the result is ASCII-safe UTF-8 text
    suitable for embedding in a ``data:`` URI.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    return base64.b64encode(raw_bytes).decode('utf-8')
def vision_gpt(prompt, image_url, api_key):
    """Send *prompt* plus a base64-encoded JPEG to GPT-4 Vision; return the reply text.

    Note: despite its name, *image_url* is the base64 payload itself — it is
    embedded into a ``data:image/jpeg;base64,...`` URI, not fetched remotely.
    """
    client = OpenAI(api_key=api_key)
    # One user turn containing the text prompt and the inline image.
    user_content = [
        {"type": "text",
         "text": prompt},
        {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{image_url}", },
        },
    ]
    response = client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=[{"role": "user", "content": user_content}],
        max_tokens=600,
    )
    return response.choices[0].message.content
def generate_images(oai_key, input_path, mistaken_class, ground_truth_class):
    """Generate NUM_GEN counterfactual images for a misclassified example.

    Asks GPT-4 Vision to describe what visually distinguishes
    *ground_truth_class* from *mistaken_class* in the image at *input_path*,
    extracts the bracketed Stable Diffusion prompt from the reply, and renders
    NUM_GEN images with SDXL base + three refiner passes each.

    Parameters:
        oai_key: OpenAI API key.
        input_path: path to the misclassified input image (JPEG assumed by the
            data URI built in vision_gpt — TODO confirm callers only pass JPEGs).
        mistaken_class: the label the model wrongly predicted.
        ground_truth_class: the correct label.

    Returns:
        A tuple of NUM_GEN PIL images.

    Raises:
        ValueError: if the GPT reply contains no [bracketed] prompt.
    """
    print("--------------input_path--------------: \n", input_path, "\n\n")
    base64_image = encode_image(input_path)
    # Six placeholders, filled in order: gt, mistaken, gt, gt, mistaken, mistaken.
    prompt = """
List key features of the {} itself in this image that make it distinct from a {}? Then, write a very short and
concise visual midjourney prompt of the {} that includes the above features of {} (prompt should start
with '4K SLR photo,') and put it inside square brackets []. Do not mention {} in your prompt, also do not mention
non-essential background scenes like "calm waters, mountains" and sub-components like "paddle of canoe" in the prompt.
""".format(ground_truth_class, mistaken_class, ground_truth_class, ground_truth_class, mistaken_class, mistaken_class)
    print("--------------gpt prompt--------------: \n", prompt, "\n\n")
    response = vision_gpt(prompt, base64_image, oai_key)
    print("--------------GPT response--------------: \n", response, "\n\n")
    # The model is instructed to wrap the diffusion prompt in []; fail loudly
    # (instead of an opaque AttributeError) if it did not comply.
    match = re.search(r'\[(.*?)\]', response)
    if match is None:
        raise ValueError(
            "GPT response did not contain a [bracketed] diffusion prompt: " + response
        )
    stable_diffusion_prompt = match.group(1)
    print("--------------stable_diffusion_prompt-------------- \n", stable_diffusion_prompt, "\n\n")
    # Loaded per call (not at module import) so importing this file stays cheap;
    # gated SDXL 0.9 weights require the HF login performed at module level.
    SD_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
    RF_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
    SD_pipe.to("cuda")
    RF_pipe.to("cuda")
    out_images = []
    for _ in range(NUM_GEN):
        generated_images = SD_pipe(prompt=stable_diffusion_prompt, num_inference_steps=75).images
        # Three refiner passes; the first receives the full image list from the
        # base pipeline, subsequent passes refine the single selected image.
        refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=generated_images).images[0]
        refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=refined_image).images[0]
        refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=refined_image).images[0]
        out_images.append(refined_image)
    return tuple(out_images)