bryanzhou008 committed on
Commit
7f9334d
·
verified ·
1 Parent(s): fa80443

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ import base64
3
+ import requests
4
+ import re
5
+
6
+ from diffusers import DiffusionPipeline
7
+ import torch
8
+ from PIL import Image
9
+ import os
10
+ import argparse
11
+
12
+ import gradio as gr
13
+
14
+ from huggingface_hub import HfFolder
15
+ from transformers import AutoModel
16
+
17
# Register a Hugging Face token only when one is supplied through the
# HF_TOKEN environment variable. The previous code saved a literal
# placeholder string as the credential on every import, which is never a
# usable token and keeps secrets-by-editing-source as the only workflow.
# NOTE(review): saving a bogus value presumably replaces any valid token
# already cached on the machine -- confirm against huggingface_hub docs.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    HfFolder.save_token(_hf_token)
18
+
19
def encode_image(image_path):
    """Return the contents of the file at ``image_path`` as base64 text.

    The file is opened in binary mode, its bytes are base64-encoded, and the
    encoded bytes are decoded to ``str`` as UTF-8.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
22
+
23
def vision_gpt(prompt, image_url, api_key, *, model="gpt-4-vision-preview", max_tokens=600):
    """Send a text prompt plus one image to an OpenAI chat model.

    Args:
        prompt: Text instruction sent alongside the image.
        image_url: Base64-encoded image data. Despite the name this is not a
            URL; it is embedded into a ``data:image/jpeg;base64,`` URI.
        api_key: OpenAI API key used to construct the client.
        model: Chat model identifier. Keyword-only; defaults to the id this
            function previously hard-coded.
        max_tokens: Upper bound passed to the API for the reply length.
            Keyword-only; defaults to the previously hard-coded 600.

    Returns:
        The ``message.content`` of the first choice in the response.
    """
    # NOTE(review): "-preview" model ids tend to be retired by the provider;
    # confirm the default is still served, or pass ``model=`` explicitly.
    client = OpenAI(api_key=api_key)

    # The MIME type is fixed to JPEG regardless of the actual file format.
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{image_url}"},
    }
    text_part = {"type": "text", "text": prompt}

    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": [text_part, image_part]}],
        max_tokens=max_tokens,
    )
    return response.choices[0].message.content
44
+
45
+
46
# Instruction sent to the vision model. Rendered text is identical to the
# previous positional-format version; named fields just make the mapping
# between placeholders and arguments explicit.
_PROMPT_TEMPLATE = """
List key features of the {ground_truth} itself in this image that make it distinct from a {mistaken}? Then, write a very short and
concise visual midjourney prompt of the {ground_truth} that includes the above features of {ground_truth} (prompt should start
with '4K SLR photo,') and put it inside square brackets []. Do no mention {mistaken} in your prompt, also do not mention
non-essential background scenes like "calm waters, mountains" and sub-components like "paddle of canoe" in the prompt.
"""

# Loaded pipelines keyed by model id, so repeated calls do not reload the
# weights from disk and move them to the GPU again.
_PIPELINES = {}


def _get_pipeline(model_id):
    """Return the fp16 CUDA pipeline for ``model_id``, loading it on first use."""
    if model_id not in _PIPELINES:
        pipe = DiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
        pipe.to("cuda")
        _PIPELINES[model_id] = pipe
    return _PIPELINES[model_id]


def _extract_bracketed_prompt(response):
    """Return the text inside the first ``[...]`` of ``response``.

    Raises:
        ValueError: If ``response`` is empty or contains no bracketed text.
    """
    match = re.search(r'\[(.*?)\]', response or "")
    if match is None:
        raise ValueError(
            "GPT response did not contain a prompt in square brackets: {!r}".format(response)
        )
    return match.group(1)


def generate_images(oai_key, input_path, mistaken_class, ground_truth_class, num_generations):
    """Generate refined images of ``ground_truth_class`` guided by an input image.

    The input image is described by a vision model, the bracketed prompt in
    its reply is extracted, and that prompt drives a base diffusion pipeline
    followed by three passes through a refiner pipeline.

    Args:
        oai_key: OpenAI API key forwarded to ``vision_gpt``.
        input_path: Path of the image file to base64-encode and send.
        mistaken_class: Class name the image should be distinguished from.
        ground_truth_class: Class name the generated images should depict.
        num_generations: How many images to produce. Coerced with ``int``
            because UI number widgets may hand over a float.

    Returns:
        List of file paths (``out/0.png``, ``out/1.png``, ...) of the saved
        images, one per generation.

    Raises:
        ValueError: If the model reply has no ``[...]`` section.
    """
    output_path = "out/"
    num_generations = int(num_generations)
    # The save below fails if the target directory is missing.
    os.makedirs(output_path, exist_ok=True)

    base64_image = encode_image(input_path)

    prompt = _PROMPT_TEMPLATE.format(
        ground_truth=ground_truth_class,
        mistaken=mistaken_class,
    )

    print("--------------gpt prompt--------------: \n", prompt, "\n\n")
    response = vision_gpt(prompt, base64_image, oai_key)
    print("--------------GPT response--------------: \n", response, "\n\n")
    stable_diffusion_prompt = _extract_bracketed_prompt(response)
    print("--------------stable_diffusion_prompt-------------- \n", stable_diffusion_prompt, "\n\n")

    SD_pipe = _get_pipeline("stabilityai/stable-diffusion-xl-base-0.9")
    RF_pipe = _get_pipeline("stabilityai/stable-diffusion-xl-refiner-0.9")

    saved_paths = []
    for i in range(num_generations):
        generated_images = SD_pipe(prompt=stable_diffusion_prompt, num_inference_steps=75).images
        # First refiner pass takes the base pipeline's image list; the two
        # further passes re-refine the single resulting image.
        refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=generated_images).images[0]
        for _ in range(2):
            refined_image = RF_pipe(prompt=stable_diffusion_prompt, image=refined_image).images[0]

        image_path = output_path + "{}.png".format(i)
        refined_image.save(image_path, 'PNG')
        saved_paths.append(image_path)

    return saved_paths
80
+
81
# Gradio front end for generate_images. Component settings are chosen so the
# values handed to the callback match what it does with them:
#   * the image is passed as a file path, because the callback opens it
#     from disk to base64-encode it;
#   * the generation count is rounded to an int, because it is used in
#     range();
#   * the output is a gallery, because the callback returns a list of image
#     paths rather than a single image.
iface = gr.Interface(
    fn=generate_images,
    inputs=[
        # Mask the key so it is not displayed in clear text in the browser.
        gr.Textbox(label="OpenAI API Key", type="password"),
        gr.Image(label="Input Image", type="filepath"),
        gr.Textbox(label="Mistaken Class"),
        gr.Textbox(label="Ground Truth Class"),
        gr.Number(label="Number of Generations", value=1, precision=0),
    ],
    outputs=[
        gr.Gallery(label="Output Image"),
    ],
    title="Image Generation and Refinement",
    description="Generates and refines images based on input classes and parameters.",
)

if __name__ == "__main__":
    iface.launch()