ashish-001 committed
Commit 5e4b1fa (verified) · Parent(s): cf4dabd

Update app.py

Files changed (1)
  1. app.py +145 -145
app.py CHANGED
@@ -1,145 +1,145 @@
- import chromadb
- from chromadb.config import Settings
- import torchvision.models as models
- import torch
- from torchvision import transforms
- from PIL import Image
- import logging
- import streamlit as st
- import requests
- import json
- import uuid
- import os
-
- try:
-
-     logging.basicConfig(level=logging.INFO)
-     logger = logging.getLogger(__name__)
-
-     @st.cache_resource
-     def load_mobilenet_model():
-         device = 'cpu'
-         model = models.mobilenet_v3_small(pretrained=False)
-         model.classifier[3] = torch.nn.Linear(1024, 768)
-         model.load_state_dict(torch.load(
-             'mobilenet_v3_small_distilled_new_state_dict.pth', map_location=device))
-         model.eval().to(device)
-         return model
-
-     @st.cache_resource
-     def load_chromadb():
-         chroma_client = chromadb.PersistentClient(
-             path='data', settings=Settings(anonymized_telemetry=False))
-         collection = chroma_client.get_collection(name='images')
-         return collection
-
-     model = load_mobilenet_model()
-     logger.info("MobileNet loaded")
-     collection = load_chromadb()
-     logger.info("ChromaDB loaded")
-     logger.info(
-         f"Connected to ChromaDB collection images with {collection.count()} items")
-
-     preprocess = transforms.Compose([
-         transforms.Resize((224, 224)),
-         transforms.ToTensor(),
-         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
-                              0.229, 0.224, 0.225])
-     ])
-
-     def get_image_embedding(image):
-         if isinstance(image, str):
-             img = Image.open(image).convert('RGB')
-         else:
-             img = Image.open(image).convert('RGB')
-         input_tensor = preprocess(img).unsqueeze(0).to('cpu')
-         with torch.no_grad():
-             student_embedding = model(input_tensor)
-
-         return torch.nn.functional.normalize(student_embedding, p=2, dim=1).squeeze(0).tolist()
-
-     def save_image(image_file):
-         unique_filename = f"{image_file.name}"
-         save_path = os.path.join('images', unique_filename)
-         with open(save_path, "wb") as f:
-             f.write(image_file.getbuffer())
-         return save_path
-
-     def resize_image(image_path, size=(224, 224)):
-         if isinstance(image_path, str):
-             img = Image.open(image_path).convert("RGB")
-         else:
-             # Handle uploaded file
-             img = Image.open(image_path).convert("RGB")
-         img_resized = img.resize(size, Image.LANCZOS)  # High-quality resizing
-         return img_resized
-
-     st.sidebar.header("Upload Images")
-     image_files = st.sidebar.file_uploader(
-         "Upload images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
-     num_images = st.sidebar.slider(
-         "Number of results to return", min_value=1, max_value=10, value=3)
-
-     if image_files:
-         st.sidebar.subheader(
-             "Add Images to collection")
-         if st.sidebar.button("Add uploaded images"):
-             for idx, image_file in enumerate(image_files):
-                 image_embedding = get_image_embedding(image_file)
-                 saved_path = save_image(image_file)
-                 unique_id = str(uuid.uuid4())
-                 metadata = {
-                     'path': f'images/{image_file.name}', "type": "photo"
-                 }
-                 collection.add(
-                     embeddings=[image_embedding],
-                     ids=[unique_id],
-                     metadatas=[metadata]
-                 )
-                 st.sidebar.success(
-                     f"Image {image_file.name} added to the collection")
-
-     st.title('Image Search Using Text')
-     st.write(
-         "The images stored in this database are sourced from the [COCO 2017 Validation Dataset](https://cocodataset.org/#download).")
-     st.write('Enter the text to search for images with matching description')
-     text_input = st.text_input("Description", "Playground")
-     if st.button("Search"):
-         if text_input.strip():
-             params = {'text': text_input}
-             response = requests.get(
-                 'https://ashish-001-text-embedding-api.hf.space/embedding', params=params)
-             if response.status_code == 200:
-                 logger.info("Embedding returned by API successfully")
-                 data = json.loads(response.content)
-                 embedding = data['embedding']
-                 results = collection.query(
-                     query_embeddings=[embedding],
-                     n_results=num_images
-                 )
-                 images = [results['metadatas'][0][i]['path']
-                           for i in range(len(results['metadatas'][0]))]
-                 distances = [results['distances'][0][i]
-                              for i in range(len(results['metadatas'][0]))]
-                 if images:
-                     cols_per_row = 3
-                     rows = (len(images)+cols_per_row-1)//cols_per_row
-                     for row in range(rows):
-                         cols = st.columns(cols_per_row)
-                         for col_idx, col in enumerate(cols):
-                             img_idx = row*cols_per_row+col_idx
-                             if img_idx < len(images):
-                                 resized_img = resize_image(
-                                     images[img_idx], size=(224, 224))
-                                 col.image(resized_img,
-                                           caption=f"Image {img_idx+1}\ndistance {distances[img_idx]}", use_container_width=True)
-                 else:
-                     st.write("No image found")
-             else:
-                 st.write("Please try again later")
-                 logger.info(f"status code {response.status_code} returned")
-         else:
-             st.write("Please enter text in the text area")
-
- except Exception as e:
-     logger.info(f"Exception occured: {e}")
 
+ import chromadb
+ from chromadb.config import Settings
+ import torchvision.models as models
+ import torch
+ from torchvision import transforms
+ from PIL import Image
+ import logging
+ import streamlit as st
+ import requests
+ import json
+ import uuid
+ import os
+
+ try:
+
+     logging.basicConfig(level=logging.INFO)
+     logger = logging.getLogger(__name__)
+
+     @st.cache_resource
+     def load_mobilenet_model():
+         device = 'cpu'
+         model = models.mobilenet_v3_small(pretrained=False)
+         model.classifier[3] = torch.nn.Linear(1024, 768)
+         model.load_state_dict(torch.load(
+             'mobilenet_v3_small_distilled_new_state_dict.pth', map_location=device))
+         model.eval().to(device)
+         return model
+
+     @st.cache_resource
+     def load_chromadb():
+         chroma_client = chromadb.PersistentClient(
+             path='data', settings=Settings(anonymized_telemetry=False))
+         collection = chroma_client.get_collection(name='images')
+         return collection
+
+     model = load_mobilenet_model()
+     logger.info("MobileNet loaded")
+     collection = load_chromadb()
+     logger.info("ChromaDB loaded")
+     logger.info(
+         f"Connected to ChromaDB collection images with {collection.count()} items")
+
+     preprocess = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
+                              0.229, 0.224, 0.225])
+     ])
+
+     def get_image_embedding(image):
+         if isinstance(image, str):
+             img = Image.open(image).convert('RGB')
+         else:
+             img = Image.open(image).convert('RGB')
+         input_tensor = preprocess(img).unsqueeze(0).to('cpu')
+         with torch.no_grad():
+             student_embedding = model(input_tensor)
+
+         return torch.nn.functional.normalize(student_embedding, p=2, dim=1).squeeze(0).tolist()
+
+     def save_image(image_file):
+         unique_filename = f"{image_file.name}"
+         save_path = os.path.join('images', unique_filename)
+         with open(save_path, "wb") as f:
+             f.write(image_file.getbuffer())
+         return save_path
+
+     def resize_image(image_path, size=(224, 224)):
+         if isinstance(image_path, str):
+             img = Image.open(image_path).convert("RGB")
+         else:
+             # Handle uploaded file
+             img = Image.open(image_path).convert("RGB")
+         img_resized = img.resize(size, Image.LANCZOS)  # High-quality resizing
+         return img_resized
+
+     st.sidebar.header("Upload Images")
+     image_files = st.sidebar.file_uploader(
+         "Upload images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
+     num_images = st.sidebar.slider(
+         "Number of results to return", min_value=1, max_value=10, value=3)
+
+     if image_files:
+         st.sidebar.subheader(
+             "Add Images to collection")
+         if st.sidebar.button("Add uploaded images"):
+             for idx, image_file in enumerate(image_files):
+                 image_embedding = get_image_embedding(image_file)
+                 saved_path = save_image(image_file)
+                 unique_id = str(uuid.uuid4())
+                 metadata = {
+                     'path': f'images/{image_file.name}', "type": "photo"
+                 }
+                 collection.add(
+                     embeddings=[image_embedding],
+                     ids=[unique_id],
+                     metadatas=[metadata]
+                 )
+                 st.sidebar.success(
+                     f"Image {image_file.name} added to the collection")
+
+     st.title('Image Search Using Text')
+     st.write(
+         "The images stored in this database are sourced from the [COCO 2017 Validation Dataset](https://cocodataset.org/#download).")
+     st.write('Enter the text to search for images with matching description')
+     text_input = st.text_input("Description", "Road")
+     if st.button("Search"):
+         if text_input.strip():
+             params = {'text': text_input}
+             response = requests.get(
+                 'https://ashish-001-text-embedding-api.hf.space/embedding', params=params)
+             if response.status_code == 200:
+                 logger.info("Embedding returned by API successfully")
+                 data = json.loads(response.content)
+                 embedding = data['embedding']
+                 results = collection.query(
+                     query_embeddings=[embedding],
+                     n_results=num_images
+                 )
+                 images = [results['metadatas'][0][i]['path']
+                           for i in range(len(results['metadatas'][0]))]
+                 distances = [results['distances'][0][i]
+                              for i in range(len(results['metadatas'][0]))]
+                 if images:
+                     cols_per_row = 3
+                     rows = (len(images)+cols_per_row-1)//cols_per_row
+                     for row in range(rows):
+                         cols = st.columns(cols_per_row)
+                         for col_idx, col in enumerate(cols):
+                             img_idx = row*cols_per_row+col_idx
+                             if img_idx < len(images):
+                                 resized_img = resize_image(
+                                     images[img_idx], size=(224, 224))
+                                 col.image(resized_img,
+                                           caption=f"Image {img_idx+1}\ndistance {distances[img_idx]}", use_container_width=True)
+                 else:
+                     st.write("No image found")
+             else:
+                 st.write("Please try again later")
+                 logger.info(f"status code {response.status_code} returned")
+         else:
+             st.write("Please enter text in the text area")
+
+ except Exception as e:
+     logger.info(f"Exception occured: {e}")
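
For context, the search path this file implements reduces to two calls: fetch a text embedding from the companion API, then run a nearest-neighbour query against the persistent collection. Below is a minimal standalone sketch of that flow, assuming the endpoint returns JSON of the form {"embedding": [...]} (which is what data['embedding'] in app.py implies) and that the data/ directory holds the same ChromaDB collection:

    import requests
    import chromadb
    from chromadb.config import Settings

    API_URL = 'https://ashish-001-text-embedding-api.hf.space/embedding'

    def search_images(text, n_results=3):
        # Ask the companion API for a text embedding; app.py sends the same
        # params but no timeout (the timeout here is an addition).
        response = requests.get(API_URL, params={'text': text}, timeout=30)
        response.raise_for_status()
        embedding = response.json()['embedding']  # assumed response shape

        # Open the same persistent collection that app.py queries.
        client = chromadb.PersistentClient(
            path='data', settings=Settings(anonymized_telemetry=False))
        collection = client.get_collection(name='images')
        results = collection.query(
            query_embeddings=[embedding], n_results=n_results)

        # Each hit pairs the stored image path with its distance to the query.
        return list(zip(results['metadatas'][0], results['distances'][0]))

    if __name__ == '__main__':
        for metadata, distance in search_images('Road'):
            print(f"{metadata['path']}  distance={distance:.4f}")

The timeout and raise_for_status() are robustness additions, not part of the committed code; app.py calls requests.get without a timeout and branches on status_code instead.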