# bryanzhou008 — "Update src/v2_for_hf.py" (commit 3522874, verified)
# Hugging Face file view: raw | history | blame — 3.28 kB
import base64
import functools
import os
import re

import requests
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import login
from openai import OpenAI
from PIL import Image
# Authenticate with the Hugging Face Hub using the token stored in the
# "HF_token" environment variable. login() is skipped when the variable is
# unset or empty: calling it with token=None makes huggingface_hub fall back
# to an interactive prompt, which cannot be answered in a headless deployment.
_hf_token = os.environ.get("HF_token")
if _hf_token:
    login(token=_hf_token)

# Number of images generate_images() produces per call; modify to change it.
NUM_GEN = 3
def encode_image(image_path):
    """Return the bytes of the file at *image_path* as a base64 text string.

    The file is read in binary mode; the base64 output is decoded as UTF-8 so
    the result is a plain ``str`` suitable for embedding in a data URL.
    """
    with open(image_path, "rb") as handle:
        payload = handle.read()
    encoded = base64.b64encode(payload)
    return encoded.decode('utf-8')
def vision_gpt(prompt, image_url, api_key, model="gpt-4-vision-preview",
               max_tokens=600, mime_type="image/jpeg"):
    """Send a text prompt plus one inline image to an OpenAI chat model.

    Args:
        prompt: Text instruction sent alongside the image.
        image_url: Base64-encoded image data (e.g. from encode_image), not a
            real URL; it is wrapped into a ``data:`` URL before sending.
        api_key: OpenAI API key used to construct the client.
        model: Chat-completions model name.
        max_tokens: Upper bound on the number of tokens in the reply.
        mime_type: MIME type declared in the data URL; set this to match the
            actual image format (e.g. "image/png").

    Returns:
        The text content of the first choice in the model's response.
    """
    # NOTE(review): OpenAI has deprecated "gpt-4-vision-preview"; confirm it is
    # still available, or pass a current vision-capable model via `model`.
    client = OpenAI(api_key=api_key)
    data_url = f"data:{mime_type};base64,{image_url}"
    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": data_url}},
                ],
            }
        ],
        max_tokens=max_tokens,
    )
    return response.choices[0].message.content
# Hugging Face checkpoints for the SDXL base model and its refiner.
_SD_BASE_MODEL = "stabilityai/stable-diffusion-xl-base-0.9"
_SD_REFINER_MODEL = "stabilityai/stable-diffusion-xl-refiner-0.9"

# Number of successive refiner passes applied to every base-model output.
_REFINER_PASSES = 3

# Instruction sent to the vision model. {gt} is the ground-truth class and
# {wrong} the class it must be distinguished from; the reply is asked to put
# the image prompt inside square brackets so it can be extracted.
_GPT_PROMPT_TEMPLATE = """
List key features of the {gt} itself in this image that make it distinct from a {wrong}? Then, write a very short and
concise visual midjourney prompt of the {gt} that includes the above features of {gt} (prompt should start
with '4K SLR photo,') and put it inside square brackets []. Do no mention {wrong} in your prompt, also do not mention
non-essential background scenes like "calm waters, mountains" and sub-components like "paddle of canoe" in the prompt.
"""

# Shortest [...] span on a single line; group 1 is the text inside it.
_BRACKETED_RE = re.compile(r'\[(.*?)\]')


def _extract_bracketed_prompt(response):
    """Return the text of the first non-empty [...] span in *response*.

    Raises:
        ValueError: if *response* is empty/None or has no non-empty
            bracketed text.
    """
    for match in _BRACKETED_RE.finditer(response or ""):
        # Skip empty/whitespace-only brackets, e.g. an echoed "square brackets []".
        if match.group(1).strip():
            return match.group(1)
    raise ValueError(
        f"Vision model response contains no prompt in square brackets: {response!r}"
    )


@functools.lru_cache(maxsize=1)
def _load_pipelines():
    """Load the fp16 SDXL base and refiner pipelines onto CUDA, once.

    The pair is created on the first call and the same objects are returned
    afterwards, so the weights are not reloaded for every request.
    """
    sd_pipe = DiffusionPipeline.from_pretrained(
        _SD_BASE_MODEL, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
    )
    rf_pipe = DiffusionPipeline.from_pretrained(
        _SD_REFINER_MODEL, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
    )
    sd_pipe.to("cuda")
    rf_pipe.to("cuda")
    return sd_pipe, rf_pipe


def generate_images(oai_key, input_path, mistaken_class, ground_truth_class, num_images=None):
    """Generate refined SDXL images highlighting what sets one class apart from another.

    The image at *input_path* is sent to a vision model. The model is asked
    which features of *ground_truth_class* distinguish it from *mistaken_class*.
    It is also asked to write a short text-to-image prompt, in square brackets,
    that features them. That prompt is rendered with the SDXL base pipeline.
    Each result is then passed through the refiner ``_REFINER_PASSES`` times.

    Args:
        oai_key: OpenAI API key forwarded to vision_gpt.
        input_path: Path of the reference image sent to the vision model.
        mistaken_class: Class the ground-truth object should be distinguished
            from; it is excluded from the generated prompt.
        ground_truth_class: Class whose distinguishing features are described.
        num_images: Number of images to generate; defaults to the module-level
            NUM_GEN (read at call time).

    Returns:
        A tuple of ``num_images`` refined images, each the refiner pipeline's
        ``.images[0]`` output (presumably PIL images — confirm).

    Raises:
        ValueError: if the vision model's reply contains no bracketed prompt.
    """
    if num_images is None:
        num_images = NUM_GEN

    print("--------------input_path--------------: \n", input_path, "\n\n")
    base64_image = encode_image(input_path)

    prompt = _GPT_PROMPT_TEMPLATE.format(gt=ground_truth_class, wrong=mistaken_class)
    print("--------------gpt prompt--------------: \n", prompt, "\n\n")
    response = vision_gpt(prompt, base64_image, oai_key)
    print("--------------GPT response--------------: \n", response, "\n\n")

    stable_diffusion_prompt = _extract_bracketed_prompt(response)
    print("--------------stable_diffusion_prompt-------------- \n", stable_diffusion_prompt, "\n\n")

    sd_pipe, rf_pipe = _load_pipelines()

    out_images = []
    for _ in range(num_images):
        image = sd_pipe(prompt=stable_diffusion_prompt, num_inference_steps=75).images
        # The first refiner pass receives the base pipeline's image list; each
        # later pass re-refines the single image returned by the previous pass.
        for _pass in range(_REFINER_PASSES):
            image = rf_pipe(prompt=stable_diffusion_prompt, image=image).images[0]
        out_images.append(image)
    return tuple(out_images)