winfred2027 committed on
Commit
39570c3
·
verified ·
1 Parent(s): e0eb046

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -77
app.py CHANGED
@@ -7,7 +7,7 @@ import openshape
7
  import transformers
8
  from PIL import Image
9
  from huggingface_hub import HfFolder, snapshot_download
10
- from demo_support import retrieval, utils, lvis
11
  from collections import OrderedDict
12
 
13
  @st.cache_resource
@@ -23,6 +23,7 @@ def load_openclip():
23
  clip_model.cuda()
24
  return clip_model, clip_prep
25
 
 
26
  @st.cache_resource
27
  def load_openshape(name, to_cpu=False):
28
  pce = openshape.load_pc_encoder(name)
@@ -30,6 +31,7 @@ def load_openshape(name, to_cpu=False):
30
  pce = pce.cpu()
31
  return pce
32
 
 
33
  def retrieval_filter_expand():
34
  sim_th = st.sidebar.slider("Similarity Threshold", 0.05, 0.5, 0.1, key='rsimth')
35
  tag = ""
@@ -47,6 +49,7 @@ def retrieval_filter_expand():
47
  )
48
  return sim_th, filter_fn
49
 
 
50
  def retrieval_results(results):
51
  st.caption("Click the link to view the 3D shape")
52
  for i in range(len(results) // 4):
@@ -65,76 +68,6 @@ def retrieval_results(results):
65
  st.markdown(f"[{quote_name}]({ext_link})")
66
 
67
 
68
-
69
- def demo_captioning():
70
- with st.form("capform"):
71
- cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 2.0, 0.1, key='capcondscl')
72
-
73
- def demo_pc2img():
74
- with st.form("sdform"):
75
- prompt = st.text_input("Prompt (Optional)", key='sdtprompt')
76
-
77
- def demo_retrieval():
78
- with tab_pc:
79
- with st.form("rpcform"):
80
- k = st.slider("Number of items to retrieve", 1, 100, 16, key='rpc')
81
- load_data = utils.input_3d_shape('rpcinput')
82
- sim_th, filter_fn = retrieval_filter_expand('pc')
83
- if st.form_submit_button("Retrieve with Point Cloud"):
84
- prog.progress(0.49, "Computing Embeddings")
85
- pc = load_data(prog)
86
- col2 = utils.render_pc(pc)
87
- ref_dev = next(model_g14.parameters()).device
88
- enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
89
-
90
- sim = torch.matmul(torch.nn.functional.normalize(lvis.feats, dim=-1), torch.nn.functional.normalize(enc, dim=-1).squeeze())
91
- argsort = torch.argsort(sim, descending=True)
92
- pred = OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
93
- with col2:
94
- for i, (cat, sim) in zip(range(5), pred.items()):
95
- st.text(cat)
96
- st.caption("Similarity %.4f" % sim)
97
-
98
- prog.progress(0.7, "Running Retrieval")
99
- retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))
100
-
101
- prog.progress(1.0, "Idle")
102
-
103
- with tab_img:
104
- with st.form("rimgform"):
105
- k = st.slider("Number of items to retrieve", 1, 100, 16, key='rimage')
106
- pic = st.file_uploader("Upload an Image", key='rimageinput')
107
- sim_th, filter_fn = retrieval_filter_expand('image')
108
- if st.form_submit_button("Retrieve with Image"):
109
- prog.progress(0.49, "Computing Embeddings")
110
- img = Image.open(pic)
111
- st.image(img)
112
- device = clip_model.device
113
- tn = clip_prep(images=[img], return_tensors="pt").to(device)
114
- enc = clip_model.get_image_features(pixel_values=tn['pixel_values'].type(half)).float().cpu()
115
-
116
- prog.progress(0.7, "Running Retrieval")
117
- retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))
118
-
119
- prog.progress(1.0, "Idle")
120
-
121
- with tab_text:
122
- with st.form("rtextform"):
123
- k = st.slider("Number of items to retrieve", 1, 100, 16, key='rtext')
124
- text = st.text_input("Input Text", key='rtextinput')
125
- sim_th, filter_fn = retrieval_filter_expand('text')
126
- if st.form_submit_button("Retrieve with Text"):
127
- prog.progress(0.49, "Computing Embeddings")
128
- device = clip_model.device
129
- tn = clip_prep(text=[text], return_tensors='pt', truncation=True, max_length=76).to(device)
130
- enc = clip_model.get_text_features(**tn).float().cpu()
131
-
132
- prog.progress(0.7, "Running Retrieval")
133
- retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))
134
-
135
- prog.progress(1.0, "Idle")
136
-
137
-
138
  def classification_lvis(load_data):
139
  pc = load_data(prog)
140
  col2 = utils.render_pc(pc)
@@ -151,6 +84,7 @@ def classification_lvis(load_data):
151
  st.caption("Similarity %.4f" % sim)
152
  prog.progress(1.0, "Idle")
153
 
 
154
  def classification_custom(load_data, cats):
155
  pc = load_data(prog)
156
  col2 = utils.render_pc(pc)
@@ -174,7 +108,7 @@ def classification_custom(load_data, cats):
174
 
175
  def retrieval_pc(load_data, k, sim_th, filter_fn):
176
  pc = load_data(prog)
177
- prog.progress(0.49, "Computing Embeddings")
178
  col2 = utils.render_pc(pc)
179
  ref_dev = next(model_g14.parameters()).device
180
  enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
@@ -192,9 +126,10 @@ def retrieval_pc(load_data, k, sim_th, filter_fn):
192
 
193
  prog.progress(1.0, "Idle")
194
 
 
195
  def retrieval_img(pic, k, sim_th, filter_fn):
196
  img = Image.open(pic)
197
- prog.progress(0.49, "Computing Embeddings")
198
  st.image(img)
199
  device = clip_model.device
200
  tn = clip_prep(images=[img], return_tensors="pt").to(device)
@@ -205,8 +140,9 @@ def retrieval_img(pic, k, sim_th, filter_fn):
205
 
206
  prog.progress(1.0, "Idle")
207
 
 
208
  def retrieval_text(text, k, sim_th, filter_fn):
209
- prog.progress(0.49, "Computing Embeddings")
210
  device = clip_model.device
211
  tn = clip_prep(text=[text], return_tensors='pt', truncation=True, max_length=76).to(device)
212
  enc = clip_model.get_text_features(**tn).float().cpu()
@@ -216,9 +152,10 @@ def retrieval_text(text, k, sim_th, filter_fn):
216
 
217
  prog.progress(1.0, "Idle")
218
 
 
219
  def generation_img(load_data, prompt, noise_scale, cfg_scale, steps):
220
  pc = load_data(prog)
221
- prog.progress(0.49, "Running Generation")
222
  col2 = utils.render_pc(pc)
223
  if torch.cuda.is_available():
224
  with sys.clip_move_lock:
@@ -226,6 +163,27 @@ def generation_img(load_data, prompt, noise_scale, cfg_scale, steps):
226
 
227
  width = 640
228
  height = 640
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
  try:
231
  f32 = numpy.float32
@@ -294,7 +252,7 @@ try:
294
  if generation_mode == "PointCloud-to-Image":
295
  st.title("Image Generation")
296
  prog = st.progress(0.0, "Idle")
297
- prompt = st.sidebar.text_input("Prompt (Optional)", key='sdtprompt')
298
  noise_scale = st.sidebar.slider('Variation Level', 0, 5, 1)
299
  cfg_scale = st.sidebar.slider('Guidance Scale', 0.0, 30.0, 10.0)
300
  steps = st.sidebar.slider('Diffusion Steps', 8, 50, 25)
@@ -303,8 +261,9 @@ try:
303
  elif generation_mode == "PointCloud-to-Text":
304
  st.title("Text Generation")
305
  prog = st.progress(0.0, "Idle")
 
306
  if st.sidebar.button("submit"):
307
- pc = st.text_input("Input pc", key='rtextinput')
308
 
309
  except Exception:
310
  import traceback
 
7
  import transformers
8
  from PIL import Image
9
  from huggingface_hub import HfFolder, snapshot_download
10
+ from demo_support import retrieval, generation, utils, lvis
11
  from collections import OrderedDict
12
 
13
  @st.cache_resource
 
23
  clip_model.cuda()
24
  return clip_model, clip_prep
25
 
26
+
27
  @st.cache_resource
28
  def load_openshape(name, to_cpu=False):
29
  pce = openshape.load_pc_encoder(name)
 
31
  pce = pce.cpu()
32
  return pce
33
 
34
+
35
  def retrieval_filter_expand():
36
  sim_th = st.sidebar.slider("Similarity Threshold", 0.05, 0.5, 0.1, key='rsimth')
37
  tag = ""
 
49
  )
50
  return sim_th, filter_fn
51
 
52
+
53
  def retrieval_results(results):
54
  st.caption("Click the link to view the 3D shape")
55
  for i in range(len(results) // 4):
 
68
  st.markdown(f"[{quote_name}]({ext_link})")
69
 
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def classification_lvis(load_data):
72
  pc = load_data(prog)
73
  col2 = utils.render_pc(pc)
 
84
  st.caption("Similarity %.4f" % sim)
85
  prog.progress(1.0, "Idle")
86
 
87
+
88
  def classification_custom(load_data, cats):
89
  pc = load_data(prog)
90
  col2 = utils.render_pc(pc)
 
108
 
109
  def retrieval_pc(load_data, k, sim_th, filter_fn):
110
  pc = load_data(prog)
111
+ prog.progress(0.5, "Computing Embeddings")
112
  col2 = utils.render_pc(pc)
113
  ref_dev = next(model_g14.parameters()).device
114
  enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
 
126
 
127
  prog.progress(1.0, "Idle")
128
 
129
+
130
  def retrieval_img(pic, k, sim_th, filter_fn):
131
  img = Image.open(pic)
132
+ prog.progress(0.5, "Computing Embeddings")
133
  st.image(img)
134
  device = clip_model.device
135
  tn = clip_prep(images=[img], return_tensors="pt").to(device)
 
140
 
141
  prog.progress(1.0, "Idle")
142
 
143
+
144
  def retrieval_text(text, k, sim_th, filter_fn):
145
+ prog.progress(0.5, "Computing Embeddings")
146
  device = clip_model.device
147
  tn = clip_prep(text=[text], return_tensors='pt', truncation=True, max_length=76).to(device)
148
  enc = clip_model.get_text_features(**tn).float().cpu()
 
152
 
153
  prog.progress(1.0, "Idle")
154
 
155
+
156
  def generation_img(load_data, prompt, noise_scale, cfg_scale, steps):
157
  pc = load_data(prog)
158
+ prog.progress(0.5, "Running Generation")
159
  col2 = utils.render_pc(pc)
160
  if torch.cuda.is_available():
161
  with sys.clip_move_lock:
 
163
 
164
  width = 640
165
  height = 640
166
+ img = generation.pc_to_image(
167
+ model_g14, pc, prompt, noise_scale, width, height, cfg_scale, steps,
168
+ lambda i, t, _: prog.progress(0.49 + i / (steps + 1) / 2, "Running Diffusion Step %d" % i)
169
+ )
170
+ if torch.cuda.is_available():
171
+ with sys.clip_move_lock:
172
+ clip_model.cuda()
173
+ with col2:
174
+ st.image(img)
175
+ prog.progress(1.0, "Idle")
176
+
177
+
178
+ def generation_text(load_data, cond_scale):
179
+ pc = load_data(prog)
180
+ prog.progress(0.5, "Running Generation")
181
+ col2 = utils.render_pc(pc)
182
+
183
+ cap = generation.pc_to_text(model_g14, pc, cond_scale)
184
+ st.text(cap)
185
+ prog.progress(1.0, "Idle")
186
+
187
 
188
  try:
189
  f32 = numpy.float32
 
252
  if generation_mode == "PointCloud-to-Image":
253
  st.title("Image Generation")
254
  prog = st.progress(0.0, "Idle")
255
+ prompt = st.sidebar.text_input("Prompt (Optional)", key='gprompt')
256
  noise_scale = st.sidebar.slider('Variation Level', 0, 5, 1)
257
  cfg_scale = st.sidebar.slider('Guidance Scale', 0.0, 30.0, 10.0)
258
  steps = st.sidebar.slider('Diffusion Steps', 8, 50, 25)
 
261
  elif generation_mode == "PointCloud-to-Text":
262
  st.title("Text Generation")
263
  prog = st.progress(0.0, "Idle")
264
+ cond_scale = st.sidebar.slider('Conditioning Scale', 0.0, 4.0, 2.0, 0.1, key='gcond')
265
  if st.sidebar.button("submit"):
266
+ generation_text(load_data, cond_scale)
267
 
268
  except Exception:
269
  import traceback