Johannes commited on
Commit
ea23bc4
·
1 Parent(s): 9c281ac

updated version

Browse files
Files changed (1) hide show
  1. app.py +104 -15
app.py CHANGED
@@ -6,7 +6,34 @@ from diffusers import StableDiffusionPipeline
6
  import urllib, urllib.request
7
  import os
8
  from xml.etree import ElementTree
9
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  context = autocast if device == "cuda" else nullcontext
@@ -26,11 +53,9 @@ if disable_safety:
26
  pipe.safety_checker = null_safety
27
 
28
 
29
- def infer(prompt, n_samples, steps, scale):
30
- paper_title = get_paper_name(prompt)
31
-
32
  with context("cuda"):
33
- images = pipe(n_samples*[paper_title], guidance_scale=scale, num_inference_steps=steps).images
34
 
35
  return images
36
 
@@ -38,12 +63,13 @@ def get_paper_name(url: str):
38
  paper_id = os.path.basename(url)
39
  paper_id = paper_id.split(".pdf")[0]
40
  query_url = f"http://export.arxiv.org/api/query?id_list={paper_id}"
41
- hdr = { 'Content-Type' : 'application/atom+xml' }
42
  req = urllib.request.Request(query_url, headers=hdr)
43
  response = urllib.request.urlopen(req)
44
- tree = ElementTree.fromstring(response.read().decode('utf-8'))
45
- paper_title = tree.find('{http://www.w3.org/2005/Atom}entry').find('{http://www.w3.org/2005/Atom}title').text
46
-
 
47
  return paper_title
48
 
49
 
@@ -52,10 +78,25 @@ block = gr.Blocks()
52
 
53
  examples = [
54
  [
55
- 'https://arxiv.org/abs/1706.03762',
 
 
 
 
 
56
  2,
57
  7.5,
58
  ],
 
 
 
 
 
 
 
 
 
 
59
  ]
60
 
61
  with block:
@@ -77,10 +118,10 @@ with block:
77
  with gr.Box():
78
  with gr.Row().style(mobile_collapse=False, equal_height=True):
79
  text = gr.Textbox(
80
- label="Enter your prompt",
81
  show_label=False,
82
  max_lines=1,
83
- placeholder="Enter your prompt",
84
  ).style(
85
  border=(True, False, True, True),
86
  rounded=(True, False, False, True),
@@ -90,6 +131,17 @@ with block:
90
  margin=False,
91
  rounded=(False, True, True, False),
92
  )
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  gallery = gr.Gallery(
95
  label="Generated images", show_label=False, elem_id="gallery"
@@ -107,9 +159,46 @@ with block:
107
  ex = gr.Examples(examples=examples, fn=infer, inputs=[text, samples, scale], outputs=gallery, cache_examples=False)
108
  ex.dataset.headers = [""]
109
 
110
-
111
- text.submit(infer, inputs=[text, samples, steps, scale], outputs=gallery)
112
- btn.click(infer, inputs=[text, samples, steps, scale], outputs=gallery)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  gr.HTML(
114
  """
115
  <div class="footer" style="text-align: center; max-width: 650px; margin: 50px auto;">
 
6
  import urllib, urllib.request
7
  import os
8
  from xml.etree import ElementTree
9
+ import random
10
+ import re
11
+
12
+
13
+ pokemon_types = ["Normal",
14
+ "Water",
15
+ "Fire",
16
+ "Ice",
17
+ "Psychic",
18
+ "Rock",
19
+ "Dark",
20
+ "Electric",
21
+ "Grass",
22
+ "Fighting",
23
+ "Poison",
24
+ "Ground",
25
+ "Flying",
26
+ "Bug",
27
+ "Ghost",
28
+ "Dragon",
29
+ "Steel",
30
+ "Fairy"
31
+ ]
32
+
33
+ type_choices=["None", "Random"]
34
+ type_choices.extend(pokemon_types)
35
+
36
+ paper_name = None
37
 
38
  device = "cuda" if torch.cuda.is_available() else "cpu"
39
  context = autocast if device == "cuda" else nullcontext
 
53
  pipe.safety_checker = null_safety
54
 
55
 
56
+ def infer(prompt, n_samples, steps, scale):
 
 
57
  with context("cuda"):
58
+ images = pipe(n_samples*[prompt], guidance_scale=scale, num_inference_steps=steps).images
59
 
60
  return images
61
 
 
63
  paper_id = os.path.basename(url)
64
  paper_id = paper_id.split(".pdf")[0]
65
  query_url = f"http://export.arxiv.org/api/query?id_list={paper_id}"
66
+ hdr = { "Content-Type" : "application/atom+xml" }
67
  req = urllib.request.Request(query_url, headers=hdr)
68
  response = urllib.request.urlopen(req)
69
+ tree = ElementTree.fromstring(response.read().decode("utf-8"))
70
+ paper_title = tree.find("{http://www.w3.org/2005/Atom}entry").find("{http://www.w3.org/2005/Atom}title").text
71
+ paper_title = paper_title.replace("\n", "")
72
+ paper_title = re.sub(' +', ' ', paper_title)
73
  return paper_title
74
 
75
 
 
78
 
79
  examples = [
80
  [
81
+ "https://arxiv.org/abs/1706.03762",
82
+ 2,
83
+ 7.5,
84
+ ],
85
+ [
86
+ "https://arxiv.org/abs/1404.5997v2",
87
  2,
88
  7.5,
89
  ],
90
+ [
91
+ "https://arxiv.org/abs/2010.11929",
92
+ 2,
93
+ 7.5,
94
+ ],
95
+ [
96
+ "https://arxiv.org/abs/1810.04805v2",
97
+ 2,
98
+ 7.5,
99
+ ]
100
  ]
101
 
102
  with block:
 
118
  with gr.Box():
119
  with gr.Row().style(mobile_collapse=False, equal_height=True):
120
  text = gr.Textbox(
121
+ label="Link or ID for paper",
122
  show_label=False,
123
  max_lines=1,
124
+ placeholder="Give arXiv the link or ID for the paper",
125
  ).style(
126
  border=(True, False, True, True),
127
  rounded=(True, False, False, True),
 
131
  margin=False,
132
  rounded=(False, True, True, False),
133
  )
134
+ poke_type = gr.Radio(choices=type_choices, value="None", label="Pokemon Type")
135
+
136
+ prompt_ideas = gr.CheckboxGroup(choices=["as a bird",
137
+ "with four legs",
138
+ "with wings",
139
+ "as a koala",
140
+ "with a beak",
141
+ "looking like a llama"],
142
+ label="Additional prompt ideas")
143
+
144
+ prompt_box = gr.Textbox(placeholder="Your prompt appears here", interactive=True, label="Prompt")
145
 
146
  gallery = gr.Gallery(
147
  label="Generated images", show_label=False, elem_id="gallery"
 
159
  ex = gr.Examples(examples=examples, fn=infer, inputs=[text, samples, scale], outputs=gallery, cache_examples=False)
160
  ex.dataset.headers = [""]
161
 
162
+ def resolve_poke_type(pok_type: str):
163
+ if pok_type == "None":
164
+ return ""
165
+ elif pok_type == "Random":
166
+ idx = random.randint(0,len(pokemon_types)-1)
167
+ return pokemon_types[idx]
168
+ else:
169
+ return pok_type
170
+
171
+ def update_prompt_link(new_link: str, pok_type: str, prompt_ideas: list[str]):
172
+ global paper_name
173
+ paper_name = get_paper_name(new_link)
174
+ pok_type = resolve_poke_type(pok_type)
175
+
176
+ prompt_text = f"{paper_name} as {pok_type} type" if pok_type != "" else f"{paper_name}"
177
+
178
+ return build_prompt_text(paper_name, pok_type, prompt_ideas)
179
+
180
+ def update_prompt_type(paper_link: str, pok_type: str, prompt_ideas: list[str]):
181
+ global paper_name
182
+ if paper_name is None:
183
+ paper_name = get_paper_name(paper_link)
184
+
185
+ pok_type = resolve_poke_type(pok_type)
186
+
187
+ return build_prompt_text(paper_name, pok_type, prompt_ideas)
188
+
189
+ def build_prompt_text(paper_name, pok_type, add_ideas):
190
+ prompt_text = f"{paper_name} as {pok_type} type" if pok_type != "" else f"{paper_name}"
191
+ prompt_text = f"""{prompt_text} {" ".join(add_ideas)}"""
192
+ return prompt_text
193
+
194
+ text.change(update_prompt_link, inputs=[text, poke_type, prompt_ideas], outputs=prompt_box)
195
+ text.submit(update_prompt_link, inputs=[text, poke_type, prompt_ideas], outputs=prompt_box)
196
+
197
+ poke_type.change(update_prompt_type, inputs=[text, poke_type, prompt_ideas], outputs=prompt_box)
198
+ prompt_ideas.change(update_prompt_type, inputs=[text, poke_type, prompt_ideas], outputs=prompt_box)
199
+
200
+
201
+ btn.click(infer, inputs=[prompt_box, samples, steps, scale], outputs=gallery)
202
  gr.HTML(
203
  """
204
  <div class="footer" style="text-align: center; max-width: 650px; margin: 50px auto;">