|
import chromadb |
|
from chromadb.config import Settings |
|
import torchvision.models as models |
|
import torch |
|
from torchvision import transforms |
|
from PIL import Image |
|
import logging |
|
import streamlit as st |
|
import requests |
|
import json |
|
import uuid |
|
import os |
|
|
|
try: |
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
@st.cache_resource
def load_mobilenet_model():
    """Build the MobileNetV3-Small image encoder and load its saved weights.

    The last classifier layer is replaced by a linear layer with 768 outputs
    before the state dict is loaded, so the returned network emits one
    768-dimensional vector per input image.

    Returns:
        torch.nn.Module: The model in evaluation mode, placed on the CPU.

    Raises:
        FileNotFoundError: If the state-dict file is missing from the working
            directory.
        RuntimeError: If the state dict does not match the architecture.
    """
    device = 'cpu'
    # `weights=None` is the current spelling of the deprecated
    # `pretrained=False`: start from random initialisation, the real
    # parameters come from the state dict below.
    model = models.mobilenet_v3_small(weights=None)
    # Keep the head's input width in sync with the layer it replaces
    # instead of hard-coding it.
    head_inputs = model.classifier[3].in_features
    model.classifier[3] = torch.nn.Linear(head_inputs, 768)
    # NOTE: torch.load unpickles the file. `weights_only=True` restricts
    # unpickling to tensors and plain containers, which is all a state dict
    # needs, so a tampered checkpoint cannot execute arbitrary code.
    state_dict = torch.load(
        'mobilenet_v3_small_distilled_new_state_dict.pth',
        map_location=device,
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval().to(device)
    return model
|
|
|
@st.cache_resource
def load_chromadb():
    """Open the persistent Chroma store in ``data`` and return its ``images`` collection.

    Telemetry is switched off through the client settings. The collection is
    looked up with ``get_collection`` and is not created here.

    Returns:
        The Chroma collection named ``images``.
    """
    client_settings = Settings(anonymized_telemetry=False)
    client = chromadb.PersistentClient(path='data', settings=client_settings)
    return client.get_collection(name='images')
|
|
|
# Obtain the encoder and the vector store, logging progress after each step.
# Both loaders are wrapped in st.cache_resource.
model = load_mobilenet_model()
logger.info("MobileNet loaded")

collection = load_chromadb()
logger.info("ChromaDB loaded")

stored_items = collection.count()
logger.info(
    f"Connected to ChromaDB collection images with {stored_items} items")
|
|
|
# Input pipeline applied to every image before it reaches the model:
# resize to 224x224, convert to a tensor, then normalise each channel
# (these mean/std values are the standard ImageNet statistics).
_pipeline_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
preprocess = transforms.Compose(_pipeline_steps)
|
|
|
def get_image_embedding(image):
    """Return the L2-normalised model embedding of an image as a flat list.

    Args:
        image: Anything ``PIL.Image.open`` accepts: a filesystem path or a
            binary file-like object (such as an uploaded file).

    Returns:
        list[float]: The embedding of the single image, scaled to unit L2
        norm.
    """
    # Image.open handles paths and file objects alike, so no type dispatch
    # is needed (the original had two identical branches).
    img = Image.open(image).convert('RGB')
    # unsqueeze(0) adds the batch dimension the model expects.
    input_tensor = preprocess(img).unsqueeze(0).to('cpu')
    with torch.no_grad():
        student_embedding = model(input_tensor)

    unit_embedding = torch.nn.functional.normalize(
        student_embedding, p=2, dim=1)
    # Drop the batch dimension again before converting to plain Python floats.
    return unit_embedding.squeeze(0).tolist()
|
|
|
def save_image(image_file):
    """Write an uploaded file into the local ``images`` directory.

    Args:
        image_file: Object exposing a ``name`` attribute and a
            ``getbuffer()`` method returning the file's bytes (for example a
            Streamlit uploaded file).

    Returns:
        str: Path of the written file, ``images/<file name>``. An existing
        file with the same name is overwritten.

    Raises:
        OSError: If the file cannot be written.
    """
    # The name comes from the client: keep only its last path component
    # (treating both separator styles as separators) so a crafted name such
    # as "../../x.png" cannot escape the images directory.
    filename = os.path.basename(str(image_file.name).replace('\\', '/'))
    # A fresh checkout may not contain the target directory yet.
    os.makedirs('images', exist_ok=True)
    save_path = os.path.join('images', filename)
    with open(save_path, "wb") as f:
        f.write(image_file.getbuffer())
    return save_path
|
|
|
def resize_image(image_path, size=(224, 224)):
    """Load an image, convert it to RGB and resize it with Lanczos resampling.

    Args:
        image_path: A filesystem path or binary file-like object accepted by
            ``PIL.Image.open``.
        size: Target ``(width, height)`` in pixels. The aspect ratio is not
            preserved.

    Returns:
        PIL.Image.Image: The resized RGB image.
    """
    # Image.open handles paths and file objects alike, so no type dispatch
    # is needed (the original had two identical branches).
    img = Image.open(image_path).convert("RGB")
    return img.resize(size, Image.LANCZOS)
|
|
|
# Sidebar controls: a multi-file uploader restricted to PNG/JPEG files and a
# slider (1-10, default 3) choosing how many search results to request.
st.sidebar.header("Upload Images")
image_files = st.sidebar.file_uploader(
    label="Upload images",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
)
num_images = st.sidebar.slider(
    label="Number of results to return",
    min_value=1,
    max_value=10,
    value=3,
)
|
|
|
# When files have been uploaded, offer a button that embeds each one, stores
# it on disk and registers it in the collection under a random UUID.
if image_files:
    st.sidebar.subheader("Add Images to collection")
    if st.sidebar.button("Add uploaded images"):
        for image_file in image_files:
            try:
                # Embed first: a file that cannot be decoded is never written
                # to disk or added to the collection.
                image_embedding = get_image_embedding(image_file)
                saved_path = save_image(image_file)
            except OSError as exc:
                # One unreadable or unwritable file should not abort the rest
                # of the batch.
                logger.warning(f"Skipping {image_file.name}: {exc}")
                st.sidebar.error(
                    f"Image {image_file.name} could not be added")
                continue
            metadata = {
                # Record the location the file was actually written to,
                # with forward slashes so the stored form is the same on
                # every platform.
                'path': saved_path.replace(os.sep, '/'),
                "type": "photo",
            }
            collection.add(
                embeddings=[image_embedding],
                ids=[str(uuid.uuid4())],
                metadatas=[metadata],
            )
            st.sidebar.success(
                f"Image {image_file.name} added to the collection")
|
|
|
# Main page header, a note on where the indexed images come from, and the
# query box (pre-filled with "Road").
st.title('Image Search Using Text')
st.write(
    "The images stored in this database are sourced from the [COCO 2017 Validation Dataset](https://cocodataset.org/#download).")
st.write('Enter the text to search for images with matching description')
text_input = st.text_input(label="Description", value="Road")
|
def _fetch_text_embedding(text, timeout=60):
    """Ask the remote embedding API for the embedding of *text*.

    Args:
        text: Query string sent as the ``text`` request parameter.
        timeout: Seconds to wait for the API before giving up.

    Returns:
        The value of the ``embedding`` field of the JSON response, or
        ``None`` when the request fails, returns a non-200 status, or the
        body does not have the expected form.
    """
    try:
        response = requests.get(
            'https://ashish-001-text-embedding-api.hf.space/embedding',
            params={'text': text},
            timeout=timeout,  # without this a stalled API hangs the app
        )
    except requests.RequestException as exc:
        logger.warning(f"Embedding API request failed: {exc}")
        return None

    if response.status_code != 200:
        logger.info(f"status code {response.status_code} returned")
        return None

    logger.info("Embedding returned by API successfully")
    try:
        return response.json()['embedding']
    except (ValueError, KeyError, TypeError) as exc:
        # Invalid JSON, or JSON without an "embedding" entry.
        logger.warning(f"Unexpected embedding API response: {exc}")
        return None


def _show_image_grid(paths, distances, cols_per_row=3):
    """Render the images at *paths* in rows of *cols_per_row* columns.

    Each image is resized to 224x224 and captioned with its 1-based rank and
    the matching entry of *distances*.
    """
    for row_start in range(0, len(paths), cols_per_row):
        row_end = row_start + cols_per_row
        cells = zip(
            st.columns(cols_per_row),
            paths[row_start:row_end],
            distances[row_start:row_end],
        )
        # zip stops at the shorter input, leaving surplus columns of a
        # partially filled last row empty.
        for rank, (col, path, distance) in enumerate(cells, start=row_start + 1):
            col.image(
                resize_image(path, size=(224, 224)),
                caption=f"Image {rank}\ndistance {distance}",
                use_container_width=True,
            )


# Run a search when the button is pressed: embed the query text remotely,
# look up the nearest stored image embeddings and show the matches.
if st.button("Search"):
    if not text_input.strip():
        st.write("Please enter text in the text area")
    else:
        embedding = _fetch_text_embedding(text_input)
        if embedding is None:
            st.write("Please try again later")
        else:
            results = collection.query(
                query_embeddings=[embedding],
                n_results=num_images,
            )
            # The query holds a single embedding, so only the first result
            # list is relevant.
            images = [meta['path'] for meta in results['metadatas'][0]]
            distances = list(results['distances'][0])
            if images:
                _show_image_grid(images, distances)
            else:
                st.write("No image found")
|
|
|
except Exception as e: |
|
logger.info(f"Exception occured: {e}") |
|
|