File size: 5,797 Bytes
5e4b1fa 50fa949 5e4b1fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import chromadb
from chromadb.config import Settings
import torchvision.models as models
import torch
from torchvision import transforms
from PIL import Image
import logging
import streamlit as st
import requests
import json
import uuid
import os
# Wrap the whole app in one try/except so any startup or runtime failure
# is logged by the handler at the bottom of the file.
try:
# Root logging config + module-level logger used throughout the app.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@st.cache_resource
def load_mobilenet_model():
    """Load the distilled MobileNetV3-Small embedding model (cached by Streamlit).

    The classifier head is replaced with a 1024 -> 768 linear layer so the
    network emits 768-d embeddings (matching the text-embedding service).

    Returns:
        torch.nn.Module: the model, in eval mode, on CPU.
    """
    device = 'cpu'
    model = models.mobilenet_v3_small(pretrained=False)
    # Swap the final classifier layer so the model outputs 768-d embeddings.
    model.classifier[3] = torch.nn.Linear(1024, 768)
    # weights_only=True: the checkpoint is a plain state_dict, so avoid
    # arbitrary pickle execution on load (supported since torch 1.13).
    state_dict = torch.load(
        'mobilenet_v3_small_distilled_new_state_dict.pth',
        map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval().to(device)
    return model
@st.cache_resource
def load_chromadb():
    """Open the persistent ChromaDB store and return the 'images' collection.

    Cached by Streamlit so the client is created once per session/server.
    """
    client = chromadb.PersistentClient(
        path='data',
        settings=Settings(anonymized_telemetry=False),
    )
    return client.get_collection(name='images')
# Initialise the embedding model and vector store once (both cached above).
model = load_mobilenet_model()
logger.info("MobileNet loaded")
collection = load_chromadb()
logger.info("ChromaDB loaded")
logger.info(
    f"Connected to ChromaDB collection images with {collection.count()} items")
# Standard ImageNet preprocessing: fixed 224x224 input, then channel-wise
# normalisation with the usual ImageNet mean/std statistics.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
        0.229, 0.224, 0.225])
])
def get_image_embedding(image):
    """Compute an L2-normalised 768-d embedding for an image.

    Args:
        image: filesystem path or file-like object accepted by PIL.Image.open.

    Returns:
        list[float]: unit-length embedding vector.
    """
    # PIL.Image.open accepts both paths and file-like objects, so the
    # previous isinstance(image, str) if/else (with identical branches)
    # collapses into a single call.
    img = Image.open(image).convert('RGB')
    input_tensor = preprocess(img).unsqueeze(0).to('cpu')
    with torch.no_grad():
        student_embedding = model(input_tensor)
    # Normalise so cosine similarity reduces to a dot product in the store.
    return torch.nn.functional.normalize(
        student_embedding, p=2, dim=1).squeeze(0).tolist()
def save_image(image_file):
    """Persist an uploaded image under the local 'images' directory.

    Args:
        image_file: uploaded file object exposing .name and .getbuffer()
            (Streamlit UploadedFile).

    Returns:
        str: the path the file was written to ('images/<original name>').

    NOTE(review): despite the variable name, the filename is NOT made unique
    — re-uploading a file with the same name overwrites the previous one.
    The original name is kept deliberately because the collection metadata
    references 'images/<name>'.
    """
    # Create the target directory if missing (first run / fresh checkout).
    os.makedirs('images', exist_ok=True)
    unique_filename = f"{image_file.name}"
    save_path = os.path.join('images', unique_filename)
    with open(save_path, "wb") as f:
        f.write(image_file.getbuffer())
    return save_path
def resize_image(image_path, size=(224, 224)):
    """Open an image (path or file-like object) and resize it.

    Args:
        image_path: filesystem path or file-like object readable by PIL.
        size: (width, height) target size; defaults to the 224x224 display size.

    Returns:
        PIL.Image.Image: the resized RGB image.
    """
    # PIL.Image.open handles both paths and file-like objects, so the
    # previous if/else with identical branches collapses into one call.
    img = Image.open(image_path).convert("RGB")
    # LANCZOS resampling for high-quality downscaling.
    return img.resize(size, Image.LANCZOS)
# Sidebar: upload controls and result-count slider.
st.sidebar.header("Upload Images")
image_files = st.sidebar.file_uploader(
    "Upload images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
num_images = st.sidebar.slider(
    "Number of results to return", min_value=1, max_value=10, value=3)
if image_files:
    st.sidebar.subheader(
        "Add Images to collection")
    if st.sidebar.button("Add uploaded images"):
        for image_file in image_files:
            # Embed the image, then persist it to disk.
            image_embedding = get_image_embedding(image_file)
            saved_path = save_image(image_file)
            unique_id = str(uuid.uuid4())
            # Use the path actually returned by save_image so the stored
            # metadata always matches the file on disk (the original
            # rebuilt the path by hand and left saved_path unused).
            metadata = {
                'path': saved_path, "type": "photo"
            }
            collection.add(
                embeddings=[image_embedding],
                ids=[unique_id],
                metadatas=[metadata]
            )
            st.sidebar.success(
                f"Image {image_file.name} added to the collection")
# Main page: text-to-image search UI.
st.title('Image Search Using Text')
st.write(
    "The images stored in this database are sourced from the [COCO 2017 Validation Dataset](https://cocodataset.org/#download).")
st.write('Enter the text to search for images with matching description')
text_input = st.text_input("Description", "Road")
if st.button("Search"):
    if text_input.strip():
        # Ask the external service for a text embedding in the same 768-d
        # space as the image embeddings.
        params = {'text': text_input}
        response = requests.get(
            'https://ashish-001-text-embedding-api.hf.space/embedding',
            params=params,
            timeout=30)  # without a timeout an unreachable service hangs the app
        if response.status_code == 200:
            logger.info("Embedding returned by API successfully")
            # response.json() replaces the manual json.loads(response.content).
            embedding = response.json()['embedding']
            # Nearest-neighbour search in the image collection.
            results = collection.query(
                query_embeddings=[embedding],
                n_results=num_images
            )
            metadatas = results['metadatas'][0]
            images = [meta['path'] for meta in metadatas]
            distances = results['distances'][0][:len(metadatas)]
            if images:
                # Lay results out in a grid, 3 images per row.
                cols_per_row = 3
                rows = (len(images) + cols_per_row - 1) // cols_per_row
                for row in range(rows):
                    cols = st.columns(cols_per_row)
                    for col_idx, col in enumerate(cols):
                        img_idx = row * cols_per_row + col_idx
                        if img_idx < len(images):
                            resized_img = resize_image(
                                images[img_idx], size=(224, 224))
                            col.image(
                                resized_img,
                                caption=f"Image {img_idx+1}\ndistance {distances[img_idx]}",
                                use_container_width=True)
            else:
                st.write("No image found")
        else:
            st.write("Please try again later")
            logger.info(f"status code {response.status_code} returned")
    else:
        st.write("Please enter text in the text area")
except Exception as e:
logger.info(f"Exception occured: {e}")
|