bryanzhou008 committed on
Commit
a103d54
1 parent: 1dab74f

Upload 5 files

Files changed (5)
  1. app.py +21 -0
  2. environment.yml +220 -0
  3. src/v1.py +87 -0
  4. src/v2.py +110 -0
  5. src/v2_for_hf.py +90 -0
app.py ADDED
@@ -0,0 +1,21 @@
import gradio as gr

from src.v2_for_hf import generate_images, NUM_GEN

# One output slot per generated image: generate_images returns a tuple of
# NUM_GEN PIL images, so the outputs list must have the same length.
iface = gr.Interface(
    fn=generate_images,
    inputs=[
        gr.Textbox(label="OpenAI API Key"),
        gr.Image(label="Input Image", type="filepath"),
        gr.Textbox(label="Mistaken Class"),
        gr.Textbox(label="Ground Truth Class"),
    ],
    outputs=[
        gr.Image(label="Output Image") for _ in range(NUM_GEN)
    ],
    title="visual-data-aug",
)

if __name__ == "__main__":
    iface.launch(share=True)
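For checking this interface wiring locally without a GPU or an OpenAI key, a minimal sketch with a hypothetical stub in place of generate_images (the stub function and its placeholder images are illustrative, not part of this commit):

# Hypothetical smoke test: a stub with the same signature as
# generate_images, so the Gradio input/output wiring can be verified
# without model downloads or API calls.
import gradio as gr
from PIL import Image

NUM_GEN = 2  # mirrors src.v2_for_hf.NUM_GEN

def stub_generate_images(api_key, image_path, mistaken_class, ground_truth_class):
    # Return NUM_GEN gray placeholders, matching the outputs list length.
    return tuple(Image.new("RGB", (256, 256), "gray") for _ in range(NUM_GEN))

demo = gr.Interface(
    fn=stub_generate_images,
    inputs=[gr.Textbox(), gr.Image(type="filepath"), gr.Textbox(), gr.Textbox()],
    outputs=[gr.Image() for _ in range(NUM_GEN)],
)
# demo.launch()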
environment.yml ADDED
@@ -0,0 +1,220 @@
name: torch_env
channels:
  - pytorch
  - defaults
  - conda-forge
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - blas=1.0=mkl
  - brotli-python=1.0.9=py38h6a678d5_7
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2023.12.12=h06a4308_0
  - cryptography=41.0.7=py38hdda0065_0
  - cudatoolkit=10.2.89=h713d32c_10
  - ffmpeg=4.3=hf484d3e_0
  - freetype=2.12.1=h4a9f257_0
  - giflib=5.2.1=h5eee18b_3
  - gmp=6.2.1=h295c915_3
  - gmpy2=2.1.2=py38heeb90bb_0
  - gnutls=3.6.15=he1e5248_0
  - idna=3.4=py38h06a4308_0
  - intel-openmp=2023.1.0=hdb19cb5_46306
  - jinja2=3.1.2=py38h06a4308_0
  - jpeg=9e=h5eee18b_1
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=3.0=h295c915_0
  - libdeflate=1.17=h5eee18b_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libiconv=1.16=h7f8727e_2
  - libidn2=2.3.4=h5eee18b_0
  - libjpeg-turbo=2.0.0=h9bf148f_0
  - libpng=1.6.39=h5eee18b_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.19.0=h5eee18b_0
  - libtiff=4.5.1=h6a678d5_0
  - libunistring=0.9.10=h27cfd23_0
  - libwebp=1.3.2=h11a3e52_0
  - libwebp-base=1.3.2=h5eee18b_0
  - llvm-openmp=14.0.6=h9e868ea_0
  - lz4-c=1.9.4=h6a678d5_0
  - markupsafe=2.1.3=py38h5eee18b_0
  - mkl=2023.1.0=h213fc3f_46344
  - mkl-service=2.4.0=py38h5eee18b_1
  - mkl_fft=1.3.8=py38h5eee18b_0
  - mkl_random=1.2.4=py38hdb19cb5_0
  - mpc=1.1.0=h10f8cd9_1
  - mpfr=4.0.2=hb69a4c5_1
  - mpmath=1.3.0=py38h06a4308_0
  - ncurses=6.4=h6a678d5_0
  - nettle=3.7.3=hbbd107a_1
  - networkx=3.1=py38h06a4308_0
  - numpy=1.24.3=py38hf6e8229_1
  - numpy-base=1.24.3=py38h060ed82_1
  - openh264=2.1.1=h4ff587b_0
  - openjpeg=2.4.0=h3ad879b_0
  - openssl=3.0.12=h7f8727e_0
  - pip=23.3.1=py38h06a4308_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pyopenssl=23.2.0=py38h06a4308_0
  - pysocks=1.7.1=py38h06a4308_0
  - python=3.8.18=h955ad1f_0
  - pytorch-mutex=1.0=cpu
  - readline=8.2=h5eee18b_0
  - requests=2.31.0=py38h06a4308_0
  - sqlite=3.41.2=h5eee18b_0
  - sympy=1.12=py38h06a4308_0
  - tbb=2021.8.0=hdb19cb5_0
  - tk=8.6.12=h1ccaba5_0
  - torchaudio=2.1.2=py38_cpu
  - torchvision=0.16.2=py38_cpu
  - typing_extensions=4.9.0=py38h06a4308_0
  - wheel=0.41.2=py38h06a4308_0
  - xz=5.4.5=h5eee18b_0
  - yaml=0.2.5=h7b6447c_0
  - zlib=1.2.13=h5eee18b_0
  - zstd=1.5.5=hc292b87_0
  - pip:
      - accelerate==0.26.1
      - aiofiles==23.2.1
      - aiohttp==3.8.4
      - aiosignal==1.3.1
      - altair==5.2.0
      - annotated-types==0.6.0
      - anyio==4.2.0
      - argon2-cffi==21.3.0
      - argon2-cffi-bindings==21.2.0
      - argparse==1.4.0
      - asttokens==2.4.1
      - async-timeout==4.0.3
      - attrs==23.1.0
      - backcall==0.2.0
      - beautifulsoup4==4.12.2
      - bleach==6.0.0
      - certifi==2023.5.7
      - cffi==1.15.1
      - charset-normalizer==3.1.0
      - click==8.1.3
      - cmake==3.28.1
      - colorama==0.4.6
      - comm==0.2.1
      - contourpy==1.1.1
      - cycler==0.12.1
      - datasets==2.13.1
      - debugpy==1.8.0
      - decorator==5.1.1
      - diffusers==0.24.0
      - dill==0.3.6
      - distro==1.9.0
      - exceptiongroup==1.2.0
      - executing==2.0.1
      - fastapi==0.109.0
      - fastjsonschema==2.17.1
      - ffmpy==0.3.1
      - filelock==3.12.2
      - fonttools==4.47.2
      - frozenlist==1.4.1
      - fsspec==2023.12.2
      - gradio==4.14.0
      - gradio-client==0.8.0
      - h11==0.14.0
      - httpcore==1.0.2
      - httpx==0.26.0
      - huggingface-hub==0.20.1
      - importlib-metadata==6.7.0
      - importlib-resources==6.1.1
      - ipykernel==6.24.0
      - ipython==8.12.2
      - jedi==0.18.2
      - joblib==1.3.1
      - jsonschema==4.17.3
      - jupyter-client==8.6.0
      - jupyter-core==5.7.1
      - kiwisolver==1.4.5
      - lit==17.0.6
      - markdown-it-py==3.0.0
      - matplotlib==3.7.1
      - matplotlib-inline==0.1.6
      - mdurl==0.1.2
      - multidict==6.0.4
      - multiprocess==0.70.14
      - nest-asyncio==1.5.8
      - nltk==3.8.1
      - nvidia-cublas-cu11==11.10.3.66
      - nvidia-cuda-cupti-cu11==11.7.101
      - nvidia-cuda-nvrtc-cu11==11.7.99
      - nvidia-cuda-runtime-cu11==11.7.99
      - nvidia-cudnn-cu11==8.5.0.96
      - nvidia-cufft-cu11==10.9.0.58
      - nvidia-curand-cu11==10.2.10.91
      - nvidia-cusolver-cu11==11.4.0.1
      - nvidia-cusparse-cu11==11.7.4.91
      - nvidia-nccl-cu11==2.14.3
      - nvidia-nvtx-cu11==11.7.91
      - openai==1.6.1
      - orjson==3.9.10
      - packaging==23.2
      - pandas==2.0.3
      - parso==0.8.3
      - pexpect==4.9.0
      - pickleshare==0.7.5
      - pillow==10.0.0
      - pkgutil-resolve-name==1.3.10
      - platformdirs==4.1.0
      - prompt-toolkit==3.0.38
      - psutil==5.9.5
      - ptyprocess==0.7.0
      - pure-eval==0.2.2
      - pyarrow==14.0.2
      - pydantic==2.5.3
      - pydantic-core==2.14.6
      - pydub==0.25.1
      - pygments==2.15.1
      - pyparsing==3.1.0
      - pyrsistent==0.20.0
      - python-dateutil==2.8.2
      - python-multipart==0.0.6
      - pytz==2023.3
      - pyyaml==6.0
      - pyzmq==25.1.2
      - regex==2023.12.25
      - rich==13.7.0
      - safetensors==0.4.1
      - scikit-learn==1.3.0
      - scipy==1.10.1
      - semantic-version==2.10.0
      - sentence-transformers==2.2.2
      - sentencepiece==0.1.99
      - setuptools==67.8.0
      - shellingham==1.5.4
      - six==1.16.0
      - sniffio==1.3.0
      - soupsieve==2.5
      - stack-data==0.6.3
      - starlette==0.35.1
      - threadpoolctl==3.2.0
      - tokenizers==0.13.3
      - tomlkit==0.12.0
      - toolz==0.12.0
      - torch==2.0.1
      - tornado==6.4
      - tqdm==4.65.0
      - traitlets==5.14.1
      - transformers==4.30.2
      - triton==2.0.0
      - typer==0.9.0
      - tzdata==2023.4
      - urllib3==2.0.3
      - uvicorn==0.25.0
      - wcwidth==0.2.13
      - webencodings==0.5.1
      - websockets==11.0.3
      - xxhash==3.4.1
      - yarl==1.9.4
      - zipp==3.17.0
prefix: /home/bingxuan/anaconda3/envs/torch_env
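Assuming a working conda installation, this environment can presumably be recreated with "conda env create -f environment.yml" and activated with "conda activate torch_env"; the machine-specific prefix: line does not need editing, since conda names the environment from the name: field.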
src/v1.py ADDED
@@ -0,0 +1,87 @@
from openai import OpenAI
import base64
import re

from diffusers import DiffusionPipeline
import torch
import os
import argparse


# Load the SDXL 0.9 base pipeline for text-to-image generation.
SD_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
SD_pipe.to("cuda")

# Load the SDXL 0.9 refiner pipeline for a second img2img pass.
RF_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
RF_pipe.to("cuda")


# Encode an image file as a base64 string for the OpenAI vision API.
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


# Send the prompt and the base64-encoded image to GPT-4V and return its reply.
def vision_gpt(prompt, base64_image, api_key):
    client = OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                    },
                ],
            }
        ],
        max_tokens=600,
    )
    return response.choices[0].message.content


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="extract attributes that differentiate the ground-truth object class from the mistaken class, then generate synthetic images of the ground-truth class highlighting those attributes")
    parser.add_argument('-i', "--input_path", type=str, metavar='', required=True, help="path to input image")
    parser.add_argument('-o', "--output_path", type=str, metavar='', required=True, help="path to output folder")
    parser.add_argument('-k', "--api_key", type=str, metavar='', required=True, help="valid OpenAI API key")
    parser.add_argument('-m', "--mistaken_class", type=str, metavar='', required=True, help="class the model wrongly predicted")
    parser.add_argument('-g', "--ground_truth_class", type=str, metavar='', required=True, help="ground truth class of the image")
    parser.add_argument('-n', "--num_generations", type=int, metavar='', required=False, default=5, help="number of images to generate")
    args = parser.parse_args()

    gt, ms = args.ground_truth_class, args.mistaken_class

    os.makedirs(args.output_path, exist_ok=True)

    base64_image = encode_image(args.input_path)

    prompt = """List features of the {} in this image that make it distinct from a {}. Then, write a short and
    concise non-artistic visual diffusion prompt of a {} that includes the above features of {} (starting
    with 'photorealistic candid portrait of') and put it inside square brackets []. Do not mention {} in
    your prompt and ignore unrelated background scenes.""".format(gt, ms, gt, gt, ms)

    print("-------------- GPT prompt --------------\n", prompt, "\n\n")
    response = vision_gpt(prompt, base64_image, args.api_key)
    print("-------------- GPT response --------------\n", response, "\n\n")
    # Extract the bracketed diffusion prompt from the GPT response.
    stable_diffusion_prompt = re.search(r'\[(.*?)\]', response).group(1)
    print("-------------- stable_diffusion_prompt --------------\n", stable_diffusion_prompt, "\n\n")

    # Generate with the base pipeline, then refine, for each requested image.
    for i in range(args.num_generations):
        generated_images = SD_pipe(prompt=stable_diffusion_prompt, num_inference_steps=75).images
        refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=generated_images).images[0]
        refined_image.save(os.path.join(args.output_path, "{}.png".format(i)), 'PNG')
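The loop above decodes the base pipeline's output to PIL images before handing it to the refiner. The diffusers SDXL documentation also shows keeping the hand-off in latent space, which skips one VAE decode/encode round trip; a hedged sketch of that variant, reusing the SD_pipe, RF_pipe, and stable_diffusion_prompt names defined above:

# Sketch: latent hand-off between the SDXL base and refiner pipelines,
# following the pattern in the diffusers SDXL documentation.
latents = SD_pipe(
    prompt=stable_diffusion_prompt,
    num_inference_steps=75,
    output_type="latent",  # keep the result in latent space
).images
refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=latents).images[0]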
src/v2.py ADDED
@@ -0,0 +1,110 @@
from openai import OpenAI
import base64
import re

from diffusers import DiffusionPipeline
import torch
import os
import argparse


# Encode an image file as a base64 string for the OpenAI vision API.
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


# Read the OpenAI API key from a file.
def get_openai_key(key_path):
    with open(key_path) as f:
        key = f.read().strip()

    print("Reading OpenAI API key from: ", key_path)
    return key


# Send the prompt and the base64-encoded image to GPT-4V and return its reply.
def vision_gpt(prompt, base64_image, api_key):
    client = OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                    },
                ],
            }
        ],
        max_tokens=600,
    )
    return response.choices[0].message.content


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="extract attributes that differentiate the ground-truth object class from the mistaken class, then generate synthetic images of the ground-truth class highlighting those attributes")
    parser.add_argument('-i', "--input_path", type=str, metavar='', required=True, help="path to input image")
    parser.add_argument('-o', "--output_path", type=str, metavar='', required=True, help="path to output folder")
    parser.add_argument('-k', "--api_key_path", type=str, metavar='', required=True, help="path to file containing the OpenAI API key")
    parser.add_argument('-m', "--mistaken_class", type=str, metavar='', required=True, help="class the model wrongly predicted")
    parser.add_argument('-g', "--ground_truth_class", type=str, metavar='', required=True, help="ground truth class of the image")
    parser.add_argument('-n', "--num_generations", type=int, metavar='', required=False, default=5, help="number of images to generate")
    args = parser.parse_args()

    gt, ms = args.ground_truth_class, args.mistaken_class
    oai_key = get_openai_key(args.api_key_path)

    os.makedirs(args.output_path, exist_ok=True)

    base64_image = encode_image(args.input_path)

    prompt = """
    List key features of the {} itself in this image that make it distinct from a {}. Then, write a very short and
    concise visual midjourney prompt of the {} that includes the above features of {} (the prompt should start
    with '4K SLR photo,') and put it inside square brackets []. Do not mention {} in your prompt, and do not mention
    non-essential background scenes like "calm waters, mountains" or sub-components like "paddle of canoe" in the prompt.
    """.format(gt, ms, gt, gt, ms)

    # Earlier prompt variant, kept for reference:
    # prompt = """
    # List features of the {} in this image that make it distinct from a {}. Then, write a very short and
    # concise non-artistic visual diffusion prompt of a {} that includes the above features of {} (starting
    # with 'photo,') and put it inside square brackets []. Do not mention {} in
    # your prompt, and ignore unrelated background scenes, non-essential sub-components, objects, and people.
    # """.format(gt, ms, gt, gt, ms)

    print("-------------- GPT prompt --------------\n", prompt, "\n\n")
    response = vision_gpt(prompt, base64_image, oai_key)
    print("-------------- GPT response --------------\n", response, "\n\n")
    # Extract the bracketed diffusion prompt from the GPT response.
    stable_diffusion_prompt = re.search(r'\[(.*?)\]', response).group(1)
    print("-------------- stable_diffusion_prompt --------------\n", stable_diffusion_prompt, "\n\n")

    # Load the SDXL 0.9 pipelines only after the GPT call has succeeded.
    SD_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
    SD_pipe.to("cuda")

    RF_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
    RF_pipe.to("cuda")

    # Generate with the base pipeline, then refine, for each requested image.
    for i in range(args.num_generations):
        generated_images = SD_pipe(prompt=stable_diffusion_prompt, num_inference_steps=75).images
        refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=generated_images).images[0]
        # Optional extra refinement passes, kept disabled:
        # refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=refined_image).images[0]
        # refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=refined_image).images[0]
        refined_image.save(os.path.join(args.output_path, "{}.png".format(i)), 'PNG')
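For reference, a hypothetical invocation (all paths and class names are placeholders): python src/v2.py -i input.jpg -o out/ -k openai_key.txt -m canoe -g kayak -n 5. The only interface change from v1.py is that -k now takes the path to a key file rather than the key itself.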
src/v2_for_hf.py ADDED
@@ -0,0 +1,90 @@
from openai import OpenAI
import base64
import re

from diffusers import DiffusionPipeline
import torch

# Log in to the Hugging Face Hub with the token stored in key.txt
# (needed to download the gated SDXL 0.9 weights).
from huggingface_hub import login
with open("key.txt", "r") as f:
    login(token=f.read().strip())

# Modify this to change the number of generations
NUM_GEN = 2


# Encode an image file as a base64 string for the OpenAI vision API.
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


# Send the prompt and the base64-encoded image to GPT-4V and return its reply.
def vision_gpt(prompt, base64_image, api_key):
    client = OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                    },
                ],
            }
        ],
        max_tokens=600,
    )
    return response.choices[0].message.content


def generate_images(oai_key, input_path, mistaken_class, ground_truth_class):

    output_path = "out/"
    print("-------------- input_path --------------\n", input_path, "\n\n")
    base64_image = encode_image(input_path)

    prompt = """
    List key features of the {} itself in this image that make it distinct from a {}. Then, write a very short and
    concise visual midjourney prompt of the {} that includes the above features of {} (the prompt should start
    with '4K SLR photo,') and put it inside square brackets []. Do not mention {} in your prompt, and do not mention
    non-essential background scenes like "calm waters, mountains" or sub-components like "paddle of canoe" in the prompt.
    """.format(ground_truth_class, mistaken_class, ground_truth_class, ground_truth_class, mistaken_class)

    print("-------------- GPT prompt --------------\n", prompt, "\n\n")
    response = vision_gpt(prompt, base64_image, oai_key)
    print("-------------- GPT response --------------\n", response, "\n\n")
    # Extract the bracketed diffusion prompt from the GPT response.
    stable_diffusion_prompt = re.search(r'\[(.*?)\]', response).group(1)
    print("-------------- stable_diffusion_prompt --------------\n", stable_diffusion_prompt, "\n\n")

    # Load the SDXL 0.9 base and refiner pipelines.
    SD_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
    RF_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16, use_safetensors=True, variant="fp16")

    SD_pipe.to("cuda")
    RF_pipe.to("cuda")

    out_images = []
    for i in range(NUM_GEN):
        # Generate with the base pipeline, then run three refiner passes.
        generated_images = SD_pipe(prompt=stable_diffusion_prompt, num_inference_steps=75).images
        refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=generated_images).images[0]
        refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=refined_image).images[0]
        refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=refined_image).images[0]
        # refined_image.save(output_path + "{}.png".format(i), 'PNG')
        out_images.append(refined_image)

    return tuple(out_images)


if __name__ == "__main__":
    oai_key = "YOUR_OPENAI_API_KEY"  # never commit a real key
    input_path = "out/0.png"
    mistaken_class = "dog"
    ground_truth_class = "cat"
    generate_images(oai_key, input_path, mistaken_class, ground_truth_class)
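For the smoke test above, a hedged sketch of reading the key from the environment instead of hard-coding it (OPENAI_API_KEY is the conventional variable name the openai client also reads by default; this commit does not define it):

# Sketch: read the OpenAI key from the environment so no secret is
# committed to version control.
import os

oai_key = os.environ.get("OPENAI_API_KEY")
if not oai_key:
    raise RuntimeError("Set OPENAI_API_KEY before running the smoke test.")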