# Uploaded to Hugging Face by bryanzhou008 ("Upload 5 files", commit a103d54, 3.71 kB).
from openai import OpenAI
import base64
import requests
import re
from diffusers import DiffusionPipeline
import torch
from PIL import Image
import os
import argparse
# Shared loading options for both SDXL checkpoints: half-precision weights,
# loaded from safetensors, using the repository's "fp16" weight variant.
_FP16_PIPE_KWARGS = {
    "torch_dtype": torch.float16,
    "use_safetensors": True,
    "variant": "fp16",
}

# SDXL base pipeline (text -> image), moved to the GPU at import time.
SD_pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9", **_FP16_PIPE_KWARGS
)
SD_pipe.to("cuda")

# SDXL refiner pipeline, applied to the base pipeline's images later on.
RF_pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-0.9", **_FP16_PIPE_KWARGS
)
RF_pipe.to("cuda")
def encode_image(image_path):
    """Return the contents of the file at *image_path* as a base64 string.

    The file is read in binary mode and the base64 bytes are decoded to a
    UTF-8 ``str`` so they can be interpolated into a ``data:`` URL.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
def vision_gpt(prompt, image_url, api_key, *, model="gpt-4-vision-preview",
               max_tokens=600, mime_type="image/jpeg"):
    """Send a text prompt plus an inline image to an OpenAI chat model.

    Args:
        prompt: Text instruction sent alongside the image.
        image_url: Base64-encoded image data (e.g. the output of
            ``encode_image``). Despite the name, this is not a URL: it is
            embedded into a ``data:`` URL below.
        api_key: OpenAI API key used to construct the client.
        model: Chat model to query. NOTE(review): OpenAI has deprecated
            ``gpt-4-vision-preview``; if it is rejected, pass a current
            vision-capable model (e.g. ``"gpt-4o"``).
        max_tokens: Maximum number of tokens in the model's reply.
        mime_type: MIME type declared in the data URL. The default matches
            the previous hard-coded value; pass ``"image/png"`` etc. when the
            encoded image is not a JPEG.

    Returns:
        The text content of the first choice in the response.
    """
    client = OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        # The image is sent inline as a base64 data URL, not
                        # as a link the API would have to fetch.
                        "image_url": {
                            "url": f"data:{mime_type};base64,{image_url}",
                        },
                    },
                ],
            }
        ],
        max_tokens=max_tokens,
    )
    return response.choices[0].message.content
# Instruction sent to the vision model. {gt} is the ground-truth class and
# {ms} the class the classifier mistakenly predicted.
_GPT_PROMPT_TEMPLATE = """List features of the {gt} in this image that make it distinct from a {ms}? Then, write a short and
concise non-artistic visual diffusion prompt of a {gt} that includes the above features of {gt} (starting
with 'photorealistic candid portrait of') and put it inside square brackets []. Do not mention {ms} in
your prompt and ignore unrelated background scenes."""

# Phrase the template asks the model to start its diffusion prompt with.
_PROMPT_PREFIX = "photorealistic candid portrait of"


def _build_arg_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="extract differentiating attributes of the gt object class from the mistaken object class, generate synthetic images of the gt class highlighting such attributes")
    parser.add_argument('-i', "--input_path", type=str, metavar='', required=True, help="path to input image")
    parser.add_argument('-o', "--output_path", type=str, metavar='', required=True, help="path to output folder")
    parser.add_argument('-k', "--api_key", type=str, metavar='', required=True, help="valid openai api key")
    parser.add_argument('-m', "--mistaken_class", type=str, metavar='', required=True, help="model wrongly predicted this class")
    parser.add_argument('-g', "--ground_truth_class", type=str, metavar='', required=True, help="the ground truth class of the image")
    parser.add_argument('-n', "--num_generations", type=int, metavar='', required=False, default=5, help="number of generations")
    return parser


def _extract_diffusion_prompt(response):
    """Return the bracketed diffusion prompt contained in a GPT reply.

    Prefers a ``[...]`` span that starts with the requested prefix phrase and
    otherwise falls back to the first non-empty bracketed span.

    Raises:
        ValueError: if *response* is empty/None or has no non-empty ``[...]``.
    """
    if not response:
        raise ValueError("GPT returned an empty response; no diffusion prompt to extract")
    # DOTALL lets a bracketed prompt span line breaks; without it such a
    # prompt would not match at all.
    candidates = [c.strip() for c in re.findall(r"\[(.*?)\]", response, flags=re.DOTALL)]
    candidates = [c for c in candidates if c]
    if not candidates:
        raise ValueError("no [bracketed] diffusion prompt found in GPT response:\n" + response)
    for candidate in candidates:
        if candidate.lower().startswith(_PROMPT_PREFIX):
            return candidate
    return candidates[0]


def main():
    """Describe distinguishing features with GPT, then render them with SDXL.

    Asks the vision model how the ground-truth object in the input image
    differs from the mistaken class, extracts the diffusion prompt it writes,
    and saves ``num_generations`` refined images as ``0.png``, ``1.png``, ...
    inside the output folder.
    """
    args = _build_arg_parser().parse_args()
    gt, ms = args.ground_truth_class, args.mistaken_class

    # makedirs creates nested paths too and is a no-op if the folder exists.
    os.makedirs(args.output_path, exist_ok=True)

    base64_image = encode_image(args.input_path)
    prompt = _GPT_PROMPT_TEMPLATE.format(gt=gt, ms=ms)
    print("--------------gpt prompt--------------: \n", prompt, "\n\n")
    response = vision_gpt(prompt, base64_image, args.api_key)
    print("--------------GPT response--------------: \n", response, "\n\n")

    stable_diffusion_prompt = _extract_diffusion_prompt(response)
    print("--------------stable_diffusion_prompt-------------- \n", stable_diffusion_prompt, "\n\n")

    for i in range(args.num_generations):
        # Base SDXL pass, then run the refiner pipeline on the base output.
        generated_images = SD_pipe(prompt=stable_diffusion_prompt, num_inference_steps=75).images
        refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=generated_images).images[0]
        # os.path.join puts the file inside the folder even when output_path
        # has no trailing separator (plain "+" produced e.g. "out0.png").
        refined_image.save(os.path.join(args.output_path, f"{i}.png"), 'PNG')


if __name__ == "__main__":
    main()