ryaalbr commited on
Commit
2e7f326
·
1 Parent(s): 6eca14e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -21,7 +21,7 @@ orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=
21
  # Load the Unsplash dataset
22
  dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train") # all 25K images are in train split
23
 
24
- height = 256 # height for resizing images
25
 
26
  def predict(image, labels):
27
  inputs = processor(text=[f"a photo of {c}" for c in labels], images=image, return_tensors="pt", padding=True)
@@ -32,7 +32,7 @@ def predict(image, labels):
32
 
33
 
34
  def predict2(image, labels):
35
- image = orig_clip_processor(img).unsqueeze(0).to(device)
36
  text = clip.tokenize(labels).to(device)
37
  with torch.no_grad():
38
  image_features = orig_clip_model.encode_image(image)
@@ -157,6 +157,6 @@ with gr.Blocks() as demo:
157
  desc = gr.Textbox(show_label=False, placeholder="Enter description").style(container=False)
158
  search_btn = gr.Button("Find Images").style(full_width=False)
159
  gallery = gr.Gallery(show_label=False).style(grid=(2,2,3,5))
160
- search_btn.click(search,inputs=desc, outputs=gallery)
161
 
162
  demo.launch()
 
21
  # Load the Unsplash dataset
22
  dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train") # all 25K images are in train split
23
 
24
+ height = 512 # height for resizing images
25
 
26
  def predict(image, labels):
27
  inputs = processor(text=[f"a photo of {c}" for c in labels], images=image, return_tensors="pt", padding=True)
 
32
 
33
 
34
  def predict2(image, labels):
35
+ image = orig_clip_processor(image).unsqueeze(0).to(device)
36
  text = clip.tokenize(labels).to(device)
37
  with torch.no_grad():
38
  image_features = orig_clip_model.encode_image(image)
 
157
  desc = gr.Textbox(show_label=False, placeholder="Enter description").style(container=False)
158
  search_btn = gr.Button("Find Images").style(full_width=False)
159
  gallery = gr.Gallery(show_label=False).style(grid=(2,2,3,5))
160
+ search_btn.click(search,inputs=desc, outputs=gallery, postprocess=False)
161
 
162
  demo.launch()