Spaces: app (runtime error)

Browse files
- __pycache__/app.cpython-38.pyc  +0 -0
- app.py  +44 -52

__pycache__/app.cpython-38.pyc  ADDED
Binary file (2.93 kB)

app.py  CHANGED
@@ -1,12 +1,9 @@
-
+import gradio as gr
 from PIL import Image
-from torchvision import transforms
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
 from transformers import CLIPProcessor, CLIPModel
 from datasets import load_dataset
 import torch
-import io
-
-app = FastAPI()

 # Load the pre-trained CLIP model and its tokenizer
 model_name = "openai/clip-vit-base-patch32"
@@ -17,59 +14,54 @@ model = CLIPModel.from_pretrained(model_name)
 dataset = load_dataset("ashraq/fashion-product-images-small")
 deepfashion_database = dataset["train"]

+# Define the preprocessing function for images
 def preprocess_image(image):
-
-
-
-
-
-    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    preprocess = Compose([
+        Resize(256, interpolation=Image.BICUBIC),
+        CenterCrop(224),
+        ToTensor(),
+        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
     ])
-    return preprocess(
-
-def encode_text(text):
-    inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
-    return inputs
-
-def encode_image(image):
-    inputs = processor(images=[image], return_tensors="pt", padding=True, truncation=True)
-    return inputs
+    return preprocess(image).unsqueeze(0)

-
-
-
-
-
-    for product in deepfashion_database:
-        product_image_features = torch.Tensor(product["image_features"])
-        product_text_features = torch.Tensor(product["text_features"])
+# Modify initial_query function to return the same input the user gives 3 times
+def initial_query(image, text):
+    input_product = {"description": text, "image_path": None}
+    top_3_products = [(input_product, 1), (input_product, 1), (input_product, 1)]
+    return top_3_products

-
-        text_similarity = torch.nn.CosineSimilarity(dim=-1)(query_text_features, product_text_features)
-
-        similarity_score = image_similarity * text_similarity
-        similarities.append(similarity_score)
+# Keep the rest of the code unchanged

-
+def generate_output_html(products):
+    html_output = "<ol>"
+    for product in products:
+        html_output += f'<li>{product[0]["description"]}</li>'
+    html_output += "</ol>"
+    return html_output

-def
-
-
+def initial_query_wrapper(image, text):
+    top_3_products = initial_query(image, text)
+    return generate_output_html(top_3_products),

-
-
-
+def product_search_wrapper(image=None, text=None, selected_product_index=None, additional_text=None):
+    if image is not None or text is not None:
+        top_3_products = initial_query(image, text)
+        return generate_output_html(top_3_products),
+    else:
+        return "",

-
-
+iface = gr.Interface(
+    fn=product_search_wrapper,
+    inputs=[
+        gr.inputs.Image(optional=True),
+        gr.inputs.Textbox(lines=3, label="Initial Text Query", optional=True),
+    ],
+    outputs=[
+        gr.outputs.HTML(label="Results")
+    ],
+    title="Product Search",
+    description="Find the best matching products using images and text queries.",
+    layout="vertical"
+)

-
-async def api_initial_query(text: str, image: UploadFile = File(None)):
-    if image is not None:
-        image_content = await image.read()
-        image = Image.open(io.BytesIO(image_content))
-        image = preprocess_image(image)
-    else:
-        image = None
-    top_3_products = initial_query(image, text)
-    return {"top_3_products": top_3_products}
+iface.launch()
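
For context on the removed code: the old FastAPI version looped over deepfashion_database, compared CLIP features of the query against each product's precomputed features with cosine similarity, and multiplied the image and text scores (the cosine-similarity and score-product lines are still visible in the removed hunk). A minimal sketch of that ranking step, assuming each dataset row carries precomputed "image_features" / "text_features" columns and the query has already been encoded with the same CLIP model; rank_products and its arguments are hypothetical names, not from this repo:

import torch

def rank_products(query_image_features, query_text_features, database, top_k=3):
    # Cosine similarity over the last (feature) dimension, as in the removed code
    cosine = torch.nn.CosineSimilarity(dim=-1)
    scored = []
    for product in database:
        product_image_features = torch.Tensor(product["image_features"])
        product_text_features = torch.Tensor(product["text_features"])
        image_similarity = cosine(query_image_features, product_image_features)
        text_similarity = cosine(query_text_features, product_text_features)
        # Combine both modalities into a single score by multiplying them
        scored.append((product, (image_similarity * text_similarity).item()))
    # Highest combined score first; keep the top_k matches
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]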
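
One caveat on the new Gradio code: gr.inputs.*, gr.outputs.*, the optional= flag, and Interface(layout=...) come from the legacy Gradio 2.x-era API and are not available in recent Gradio releases, so whether this file imports cleanly depends on the version pinned for the Space. A rough equivalent using current component classes (a sketch, not code from this repo) would be:

import gradio as gr

# Sketch assuming a recent Gradio release and the product_search_wrapper
# defined in app.py above; component classes replace gr.inputs / gr.outputs,
# and an "optional" input is simply one the user leaves empty.
iface = gr.Interface(
    fn=product_search_wrapper,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(lines=3, label="Initial Text Query"),
    ],
    outputs=gr.HTML(label="Results"),
    title="Product Search",
    description="Find the best matching products using images and text queries.",
)

iface.launch()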