samnji committed on
Commit c37abe8 · 1 Parent(s): eb2cd48
Files changed (2)
  1. __pycache__/app.cpython-38.pyc +0 -0
  2. app.py +44 -52
__pycache__/app.cpython-38.pyc ADDED
Binary file (2.93 kB).
 
app.py CHANGED
@@ -1,12 +1,9 @@
- from fastapi import FastAPI, File, UploadFile
+ import gradio as gr
  from PIL import Image
- from torchvision import transforms
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
  from transformers import CLIPProcessor, CLIPModel
  from datasets import load_dataset
  import torch
- import io
-
- app = FastAPI()

  # Load the pre-trained CLIP model and its tokenizer
  model_name = "openai/clip-vit-base-patch32"
@@ -17,59 +14,54 @@ model = CLIPModel.from_pretrained(model_name)
  dataset = load_dataset("ashraq/fashion-product-images-small")
  deepfashion_database = dataset["train"]

+ # Define the preprocessing function for images
  def preprocess_image(image):
-     pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
-     preprocess = transforms.Compose([
-         transforms.Resize(224),
-         transforms.CenterCrop(224),
-         transforms.ToTensor(),
-         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     preprocess = Compose([
+         Resize(256, interpolation=Image.BICUBIC),
+         CenterCrop(224),
+         ToTensor(),
+         Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
      ])
-     return preprocess(pil_image).unsqueeze(0)
-
- def encode_text(text):
-     inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
-     return inputs
-
- def encode_image(image):
-     inputs = processor(images=[image], return_tensors="pt", padding=True, truncation=True)
-     return inputs
+     return preprocess(image).unsqueeze(0)

- def calculate_similarities(query_image, query_text):
-     query_image_features = model.get_image_features(query_image)
-     query_text_features = model.get_text_features(query_text)
-
-     similarities = []
-     for product in deepfashion_database:
-         product_image_features = torch.Tensor(product["image_features"])
-         product_text_features = torch.Tensor(product["text_features"])
+ # Modify initial_query function to return the same input the user gives 3 times
+ def initial_query(image, text):
+     input_product = {"description": text, "image_path": None}
+     top_3_products = [(input_product, 1), (input_product, 1), (input_product, 1)]
+     return top_3_products

-         image_similarity = torch.nn.CosineSimilarity(dim=-1)(query_image_features, product_image_features)
-         text_similarity = torch.nn.CosineSimilarity(dim=-1)(query_text_features, product_text_features)
-
-         similarity_score = image_similarity * text_similarity
-         similarities.append(similarity_score)
+ # Keep the rest of the code unchanged

-     return similarities
+ def generate_output_html(products):
+     html_output = "<ol>"
+     for product in products:
+         html_output += f'<li>{product[0]["description"]}</li>'
+     html_output += "</ol>"
+     return html_output

- def initial_query(image, text):
-     query_image = encode_image(image)
-     query_text = encode_text(text)
+ def initial_query_wrapper(image, text):
+     top_3_products = initial_query(image, text)
+     return generate_output_html(top_3_products),

-     similarities = calculate_similarities(query_image, query_text)
-     sorted_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)
-     top_3_indices = sorted_indices[:3]
+ def product_search_wrapper(image=None, text=None, selected_product_index=None, additional_text=None):
+     if image is not None or text is not None:
+         top_3_products = initial_query(image, text)
+         return generate_output_html(top_3_products),
+     else:
+         return "",

-     top_3_products = [deepfashion_database[i] for i in top_3_indices]
-     return top_3_products
+ iface = gr.Interface(
+     fn=product_search_wrapper,
+     inputs=[
+         gr.inputs.Image(optional=True),
+         gr.inputs.Textbox(lines=3, label="Initial Text Query", optional=True),
+     ],
+     outputs=[
+         gr.outputs.HTML(label="Results")
+     ],
+     title="Product Search",
+     description="Find the best matching products using images and text queries.",
+     layout="vertical"
+ )

- @app.post("/initial_query/")
- async def api_initial_query(text: str, image: UploadFile = File(None)):
-     if image is not None:
-         image_content = await image.read()
-         image = Image.open(io.BytesIO(image_content))
-         image = preprocess_image(image)
-     else:
-         image = None
-     top_3_products = initial_query(image, text)
-     return {"top_3_products": top_3_products}
+ iface.launch()
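
Note on the new preprocess_image: the Resize(256)/CenterCrop(224) pipeline with the hard-coded CLIP mean/std is roughly what the already-imported CLIPProcessor does internally. A minimal sketch, assuming a processor built from the same checkpoint (the line creating `processor` sits in the unchanged region between the two hunks, so it is an assumption here) and a hypothetical input file:

from PIL import Image
from transformers import CLIPProcessor

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
pil_image = Image.open("example.jpg").convert("RGB")  # hypothetical input image

# The processor resizes, center-crops to 224x224, rescales, and applies the same
# CLIP normalization constants, returning a batched float tensor.
pixel_values = processor(images=pil_image, return_tensors="pt").pixel_values
print(pixel_values.shape)  # torch.Size([1, 3, 224, 224])

Either path yields a tensor that model.get_image_features can consume; the manual transforms simply avoid a second pass through the processor.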
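
The gr.inputs / gr.outputs namespaces, the optional= flags, and layout="vertical" used above belong to the Gradio 2.x interface API. As a hedged sketch only (it assumes a later move to Gradio 3.x or newer, which this commit does not make), the same interface would pass components directly; the wrapper below is a stand-in for the product_search_wrapper defined in app.py:

import gradio as gr

def product_search_wrapper(image=None, text=None):
    # Stand-in for the product_search_wrapper defined in app.py above.
    return f"<ol><li>{text}</li><li>{text}</li><li>{text}</li></ol>"

iface = gr.Interface(
    fn=product_search_wrapper,
    inputs=[
        gr.Image(type="pil", label="Query Image"),        # replaces gr.inputs.Image(optional=True)
        gr.Textbox(lines=3, label="Initial Text Query"),  # replaces gr.inputs.Textbox(..., optional=True)
    ],
    outputs=gr.HTML(label="Results"),                     # replaces gr.outputs.HTML(label="Results")
    title="Product Search",
    description="Find the best matching products using images and text queries.",
)

iface.launch()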