ashish-001's picture
Update app.py
50fa949 verified
import chromadb
from chromadb.config import Settings
import torchvision.models as models
import torch
from torchvision import transforms
from PIL import Image
import logging
import streamlit as st
import requests
import json
import uuid
import os
try:
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@st.cache_resource
def load_mobilenet_model():
device = 'cpu'
model = models.mobilenet_v3_small(pretrained=False)
model.classifier[3] = torch.nn.Linear(1024, 768)
model.load_state_dict(torch.load(
'mobilenet_v3_small_distilled_new_state_dict.pth', map_location=device))
model.eval().to(device)
return model
@st.cache_resource
def load_chromadb():
chroma_client = chromadb.PersistentClient(
path='data', settings=Settings(anonymized_telemetry=False))
collection = chroma_client.get_collection(name='images')
return collection
model = load_mobilenet_model()
logger.info("MobileNet loaded")
collection = load_chromadb()
logger.info("ChromaDB loaded")
logger.info(
f"Connected to ChromaDB collection images with {collection.count()} items")
preprocess = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
0.229, 0.224, 0.225])
])
def get_image_embedding(image):
if isinstance(image, str):
img = Image.open(image).convert('RGB')
else:
img = Image.open(image).convert('RGB')
input_tensor = preprocess(img).unsqueeze(0).to('cpu')
with torch.no_grad():
student_embedding = model(input_tensor)
return torch.nn.functional.normalize(student_embedding, p=2, dim=1).squeeze(0).tolist()
def save_image(image_file):
unique_filename = f"{image_file.name}"
save_path = os.path.join('images', unique_filename)
with open(save_path, "wb") as f:
f.write(image_file.getbuffer())
return save_path
def resize_image(image_path, size=(224, 224)):
if isinstance(image_path, str):
img = Image.open(image_path).convert("RGB")
else:
# Handle uploaded file
img = Image.open(image_path).convert("RGB")
img_resized = img.resize(size, Image.LANCZOS) # High-quality resizing
return img_resized
st.sidebar.header("Upload Images")
image_files = st.sidebar.file_uploader(
"Upload images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
num_images = st.sidebar.slider(
"Number of results to return", min_value=1, max_value=10, value=3)
if image_files:
st.sidebar.subheader(
"Add Images to collection")
if st.sidebar.button("Add uploaded images"):
for idx, image_file in enumerate(image_files):
image_embedding = get_image_embedding(image_file)
saved_path = save_image(image_file)
unique_id = str(uuid.uuid4())
metadata = {
'path': f'images/{image_file.name}', "type": "photo"
}
collection.add(
embeddings=[image_embedding],
ids=[unique_id],
metadatas=[metadata]
)
st.sidebar.success(
f"Image {image_file.name} added to the collection")
st.title('Image Search Using Text')
st.write(
"The images stored in this database are sourced from the [COCO 2017 Validation Dataset](https://cocodataset.org/#download).")
st.write('Enter the text to search for images with matching description')
text_input = st.text_input("Description", "Road")
if st.button("Search"):
if text_input.strip():
params = {'text': text_input}
response = requests.get(
'https://ashish-001-text-embedding-api.hf.space/embedding', params=params)
if response.status_code == 200:
logger.info("Embedding returned by API successfully")
data = json.loads(response.content)
embedding = data['embedding']
results = collection.query(
query_embeddings=[embedding],
n_results=num_images
)
images = [results['metadatas'][0][i]['path']
for i in range(len(results['metadatas'][0]))]
distances = [results['distances'][0][i]
for i in range(len(results['metadatas'][0]))]
if images:
cols_per_row = 3
rows = (len(images)+cols_per_row-1)//cols_per_row
for row in range(rows):
cols = st.columns(cols_per_row)
for col_idx, col in enumerate(cols):
img_idx = row*cols_per_row+col_idx
if img_idx < len(images):
resized_img = resize_image(
images[img_idx], size=(224, 224))
col.image(resized_img,
caption=f"Image {img_idx+1}\ndistance {distances[img_idx]}", use_container_width=True)
else:
st.write("No image found")
else:
st.write("Please try again later")
logger.info(f"status code {response.status_code} returned")
else:
st.write("Please enter text in the text area")
except Exception as e:
logger.info(f"Exception occured: {e}")