samnji committed on
Commit
eb2cd48
·
1 Parent(s): 997704d
Files changed (3) hide show
  1. app.py +1 -18
  2. appv2.py +92 -0
  3. requirements.txt +4 -0
app.py CHANGED
@@ -63,23 +63,6 @@ def initial_query(image, text):
63
  top_3_products = [deepfashion_database[i] for i in top_3_indices]
64
  return top_3_products
65
 
66
- def send_message(txt, btn):
67
- if btn is not None:
68
- image = Image.open(btn)
69
- image = preprocess_image(image)
70
- else:
71
- image = None
72
- top_3_products = initial_query(image, txt)
73
- output_html = generate_output_html(top_3_products)
74
- chatbot.append_message("You", txt)
75
- chatbot.append_message("AI", output_html)
76
-
77
-
78
- chatbot = gr.Chatbot([]).style(height=750)
79
- txt = gr.Textbox(placeholder="Enter text and press enter, or upload an image", show_label=False)
80
- btn = gr.UploadButton("📁", file_types=["image", "video", "audio"])
81
-
82
- gr.Interface(send_message, inputs=[txt, btn], outputs=chatbot).launch()
83
  @app.post("/initial_query/")
84
  async def api_initial_query(text: str, image: UploadFile = File(None)):
85
  if image is not None:
@@ -89,4 +72,4 @@ async def api_initial_query(text: str, image: UploadFile = File(None)):
89
  else:
90
  image = None
91
  top_3_products = initial_query(image, text)
92
- return {"top_3_products": top_3_products}
 
63
  top_3_products = [deepfashion_database[i] for i in top_3_indices]
64
  return top_3_products
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  @app.post("/initial_query/")
67
  async def api_initial_query(text: str, image: UploadFile = File(None)):
68
  if image is not None:
 
72
  else:
73
  image = None
74
  top_3_products = initial_query(image, text)
75
+ return {"top_3_products": top_3_products}
appv2.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from PIL import Image
3
+ from torchvision import transforms
4
+ from transformers import CLIPProcessor, CLIPModel
5
+ from datasets import load_dataset
6
+ import torch
7
+ import io
8
+
9
+ app = FastAPI()
10
+
11
+ # Load the pre-trained CLIP model and its tokenizer
12
+ model_name = "openai/clip-vit-base-patch32"
13
+ processor = CLIPProcessor.from_pretrained(model_name)
14
+ model = CLIPModel.from_pretrained(model_name)
15
+
16
+ # Load the fashion product images dataset from Hugging Face
17
+ dataset = load_dataset("ashraq/fashion-product-images-small")
18
+ deepfashion_database = dataset["train"]
19
+
20
+ def preprocess_image(image):
21
+ pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
22
+ preprocess = transforms.Compose([
23
+ transforms.Resize(224),
24
+ transforms.CenterCrop(224),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
27
+ ])
28
+ return preprocess(pil_image).unsqueeze(0)
29
+
30
+ def encode_text(text):
31
+ inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
32
+ return inputs
33
+
34
+ def encode_image(image):
35
+ inputs = processor(images=[image], return_tensors="pt", padding=True, truncation=True)
36
+ return inputs
37
+
38
+ def calculate_similarities(query_image, query_text):
39
+ query_image_features = model.get_image_features(query_image)
40
+ query_text_features = model.get_text_features(query_text)
41
+
42
+ similarities = []
43
+ for product in deepfashion_database:
44
+ product_image_features = torch.Tensor(product["image_features"])
45
+ product_text_features = torch.Tensor(product["text_features"])
46
+
47
+ image_similarity = torch.nn.CosineSimilarity(dim=-1)(query_image_features, product_image_features)
48
+ text_similarity = torch.nn.CosineSimilarity(dim=-1)(query_text_features, product_text_features)
49
+
50
+ similarity_score = image_similarity * text_similarity
51
+ similarities.append(similarity_score)
52
+
53
+ return similarities
54
+
55
+ def initial_query(image, text):
56
+ query_image = encode_image(image)
57
+ query_text = encode_text(text)
58
+
59
+ similarities = calculate_similarities(query_image, query_text)
60
+ sorted_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)
61
+ top_3_indices = sorted_indices[:3]
62
+
63
+ top_3_products = [deepfashion_database[i] for i in top_3_indices]
64
+ return top_3_products
65
+
66
+ def send_message(txt, btn):
67
+ if btn is not None:
68
+ image = Image.open(btn)
69
+ image = preprocess_image(image)
70
+ else:
71
+ image = None
72
+ top_3_products = initial_query(image, txt)
73
+ output_html = generate_output_html(top_3_products)
74
+ chatbot.append_message("You", txt)
75
+ chatbot.append_message("AI", output_html)
76
+
77
+
78
+ chatbot = gr.Chatbot([]).style(height=750)
79
+ txt = gr.Textbox(placeholder="Enter text and press enter, or upload an image", show_label=False)
80
+ btn = gr.UploadButton("📁", file_types=["image", "video", "audio"])
81
+
82
+ gr.Interface(send_message, inputs=[txt, btn], outputs=chatbot).launch()
83
+ @app.post("/initial_query/")
84
+ async def api_initial_query(text: str, image: UploadFile = File(None)):
85
+ if image is not None:
86
+ image_content = await image.read()
87
+ image = Image.open(io.BytesIO(image_content))
88
+ image = preprocess_image(image)
89
+ else:
90
+ image = None
91
+ top_3_products = initial_query(image, text)
92
+ return {"top_3_products": top_3_products}
requirements.txt CHANGED
@@ -3,3 +3,7 @@ torchvision
3
  Pillow
4
  gradio
5
  transformers
 
 
 
 
 
3
  Pillow
4
  gradio
5
  transformers
6
+ fastapi
7
+ uvicorn
8
+ fastapi
9
+ datasets