Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -21,7 +21,7 @@ orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=
|
|
21 |
# Load the Unsplash dataset
|
22 |
dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train") # all 25K images are in train split
|
23 |
|
24 |
-
height =
|
25 |
|
26 |
def predict(image, labels):
|
27 |
inputs = processor(text=[f"a photo of {c}" for c in labels], images=image, return_tensors="pt", padding=True)
|
@@ -32,7 +32,7 @@ def predict(image, labels):
|
|
32 |
|
33 |
|
34 |
def predict2(image, labels):
|
35 |
-
image = orig_clip_processor(
|
36 |
text = clip.tokenize(labels).to(device)
|
37 |
with torch.no_grad():
|
38 |
image_features = orig_clip_model.encode_image(image)
|
@@ -157,6 +157,6 @@ with gr.Blocks() as demo:
|
|
157 |
desc = gr.Textbox(show_label=False, placeholder="Enter description").style(container=False)
|
158 |
search_btn = gr.Button("Find Images").style(full_width=False)
|
159 |
gallery = gr.Gallery(show_label=False).style(grid=(2,2,3,5))
|
160 |
-
search_btn.click(search,inputs=desc, outputs=gallery)
|
161 |
|
162 |
demo.launch()
|
|
|
21 |
# Load the Unsplash dataset
|
22 |
dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train") # all 25K images are in train split
|
23 |
|
24 |
+
height = 512 # height for resizing images
|
25 |
|
26 |
def predict(image, labels):
|
27 |
inputs = processor(text=[f"a photo of {c}" for c in labels], images=image, return_tensors="pt", padding=True)
|
|
|
32 |
|
33 |
|
34 |
def predict2(image, labels):
|
35 |
+
image = orig_clip_processor(image).unsqueeze(0).to(device)
|
36 |
text = clip.tokenize(labels).to(device)
|
37 |
with torch.no_grad():
|
38 |
image_features = orig_clip_model.encode_image(image)
|
|
|
157 |
desc = gr.Textbox(show_label=False, placeholder="Enter description").style(container=False)
|
158 |
search_btn = gr.Button("Find Images").style(full_width=False)
|
159 |
gallery = gr.Gallery(show_label=False).style(grid=(2,2,3,5))
|
160 |
+
search_btn.click(search,inputs=desc, outputs=gallery, postprocess=False)
|
161 |
|
162 |
demo.launch()
|