File size: 4,784 Bytes
a103d54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from openai import OpenAI
import base64
import requests
import re

from diffusers import DiffusionPipeline
import torch
from PIL import Image
import os
import argparse


# Function to encode the image
def encode_image(image_path):
    """Return the contents of the file at ``image_path`` as a base64 string.

    The file is opened in binary mode, read in full, base64-encoded and the
    resulting bytes are decoded to ``str`` as UTF-8.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')


# Function to retrieve openai api key
def get_openai_key(key_path):
    """Read the OpenAI API key stored in the text file at ``key_path``.

    Leading and trailing whitespace (including the final newline) is stripped
    from the file contents. After a successful read, the path the key was
    loaded from is printed to stdout; the key itself is never printed.
    """
    with open(key_path) as key_file:
        api_key = key_file.read().strip()

    print("Reading OpenAI API key from: ", key_path)
    return api_key


# Function to obtain GPT4V response
def vision_gpt(prompt, image_url, api_key):
    """Ask the ``gpt-4-vision-preview`` model about one image and return its reply.

    Args:
        prompt: Text instruction sent alongside the image.
        image_url: Despite the name, this value is appended after
            ``data:image/jpeg;base64,`` to form a data URL, so it is expected
            to be the base64-encoded image payload rather than an http URL.
        api_key: OpenAI API key used to construct the client.

    Returns:
        The ``content`` of the first choice's message in the completion.
    """
    client = OpenAI(api_key=api_key)

    # A single user turn carrying two content parts: the text and the image.
    text_part = {"type": "text", "text": prompt}
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{image_url}"},
    }
    user_message = {"role": "user", "content": [text_part, image_part]}

    completion = client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=[user_message],
        max_tokens=600,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content



def extract_sd_prompt(response):
    """Return the text inside the first ``[...]`` span of a GPT response.

    The GPT prompt asks the model to put the image-generation prompt inside
    square brackets; this pulls out that span (non-greedy, single line).

    Args:
        response: The model's reply text. ``None`` or an empty string is
            treated as a reply with no bracketed span.

    Returns:
        The text between the first ``[`` and the next ``]`` on the same line.

    Raises:
        ValueError: If no bracketed span is present, so the failure is
            reported clearly instead of surfacing as an ``AttributeError``
            on ``None``.
    """
    match = re.search(r'\[(.*?)\]', response or "")
    if match is None:
        raise ValueError(
            "GPT response does not contain a prompt inside square brackets []: "
            "{!r}".format(response)
        )
    return match.group(1)


def build_output_path(output_dir, index):
    """Return the path ``<output_dir>/<index>.png``.

    Uses ``os.path.join`` so the file lands inside ``output_dir`` whether or
    not the directory was given with a trailing path separator.
    """
    return os.path.join(output_dir, "{}.png".format(index))


def main():
    """Command-line entry point.

    Steps, in order:
      1. Parse the command-line arguments.
      2. Read the OpenAI key and make sure the output folder exists.
      3. Send the input image and a comparison prompt to GPT-4V.
      4. Extract the bracketed image-generation prompt from the reply.
      5. Load the SDXL 0.9 base and refiner pipelines onto CUDA.
      6. Generate, refine and save ``num_generations`` PNG images.
    """
    parser = argparse.ArgumentParser(description="extract differentiating attributes of the gt object class from the mistaken object class, generate synthatic images of the gt class highlighting such attributes")
    parser.add_argument('-i', "--input_path", type=str, metavar='', required=True, help="path to input image")
    parser.add_argument('-o', "--output_path", type=str, metavar='', required=True, help="path to output folder")
    parser.add_argument('-k', "--api_key_path", type=str, metavar='', required=True, help="path to file containing openai api key")
    parser.add_argument('-m', "--mistaken_class", type=str, metavar='', required=True, help="model wrongly predicted this class")
    parser.add_argument('-g', "--ground_truth_class", type=str, metavar='', required=True, help="the ground truth class of the image")
    parser.add_argument('-n', "--num_generations", type=int, metavar='', required=False, default=5, help="number of generations")
    args = parser.parse_args()

    gt, ms = args.ground_truth_class, args.mistaken_class
    oai_key = get_openai_key(args.api_key_path)

    # makedirs(exist_ok=True) covers both an already-existing folder and a
    # nested path whose parents do not exist yet (os.mkdir fails on the latter).
    os.makedirs(args.output_path, exist_ok=True)

    base64_image = encode_image(args.input_path)

    # Five placeholders, in order: gt, ms, gt, gt, ms.
    prompt = """
    List key features of the {} itself in this image that make it distinct from a {}? Then, write a very short and 
    concise visual midjourney prompt of the {} that includes the above features of {} (prompt should start 
    with '4K SLR photo,') and put it inside square brackets []. Do no mention {} in your prompt, also do not mention
    non-essential background scenes like "calm waters, mountains" and sub-components like "paddle of canoe" in the prompt.
    """.format(gt, ms, gt, gt, ms)

    # Alternative prompt wording, kept for reference:
    # prompt = """
    # List features of the {} in this image that make it distinct from a {}? Then, write a very short and
    # concise non-artistic visual diffusion prompt of a {} that includes the above features of {} (starting
    # with 'photo,') and put it inside square brackets []. Do no mention {} in
    # your prompt, ignore unrelated background scenes, non-essential sub-components, objects, and people.
    # """.format(gt, ms, gt, gt, ms)

    print("--------------gpt prompt--------------: \n", prompt, "\n\n")
    response = vision_gpt(prompt, base64_image, oai_key)
    print("--------------GPT response--------------: \n", response, "\n\n")
    # Extract before loading the diffusion models so a malformed reply fails fast.
    stable_diffusion_prompt = extract_sd_prompt(response)
    print("--------------stable_diffusion_prompt-------------- \n", stable_diffusion_prompt, "\n\n")

    SD_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
    SD_pipe.to("cuda")

    RF_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
    RF_pipe.to("cuda")

    for i in range(args.num_generations):
        generated_images = SD_pipe(prompt=stable_diffusion_prompt, num_inference_steps=75).images
        refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=generated_images).images[0]
        # Optional extra refiner passes (disabled):
        # refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=refined_image).images[0]
        # refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=refined_image).images[0]
        refined_image.save(build_output_path(args.output_path, i), 'PNG')


if __name__ == "__main__":
    main()