# Hugging Face Spaces status: Sleeping
import chromadb
import faiss
import gradio as gr
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

# Sentence-embedding checkpoint shared by the stored descriptions and the
# user's queries; tokenizer and model must come from the same checkpoint.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Chroma client created with no settings (Chroma's default client).
client = chromadb.Client()

# Use get_or_create_collection rather than create_collection: the latter
# raises if "tree_images" already exists, which happens when this module is
# executed again in the same process (e.g. an app reload).
collection = client.get_or_create_collection(name="tree_images")
# Toy corpus of 50 short tree descriptions. Entries alternate: odd-numbered
# trees are decorated, even-numbered trees are undecorated. The "Tree N: "
# label is derived from each entry's 1-based position, so the numbering can
# never drift out of sync with the list order.
_TREE_DESCRIPTIONS = (
    "Decorated with colorful lights and a star on top",
    "Undecorated, only a bare tree with no lights or ornaments",
    "Decorated with silver tinsel and baubles",
    "Undecorated, only green branches",
    "Decorated with red ribbons and golden ornaments",
    "Undecorated, a plain pine tree",
    "Decorated with multicolored lights and silver bells",
    "Undecorated, just a tree with no decorations",
    "Decorated with handmade ornaments and garlands",
    "Undecorated, a simple tree without any adornment",
    "Decorated with blue and white lights and a snowflake theme",
    "Undecorated, only branches with no adornments",
    "Decorated with red and green ornaments and candy canes",
    "Undecorated, just a tall and natural-looking tree",
    "Decorated with silver garlands and a star topper",
    "Undecorated, a bare spruce tree",
    "Decorated with gold and red ornaments, and a snowman figure",
    "Undecorated, a simple fir tree with no extras",
    "Decorated with colorful LED lights and a bow on top",
    "Undecorated, just a bare tree with no lights or baubles",
    "Decorated with small fairy lights and a red ribbon",
    "Undecorated, just a simple tree with green branches",
    "Decorated with golden stars and white snowflakes",
    "Undecorated, just a natural green pine tree",
    "Decorated with pink ornaments and a gold topper",
    "Undecorated, no decorations, just a plain tree",
    "Decorated with silver and blue ornaments and a festive ribbon",
    "Undecorated, just a fresh pine tree",
    "Decorated with white fairy lights and a Christmas angel on top",
    "Undecorated, only the green foliage of the tree",
    "Decorated with bright red ornaments and golden tinsel",
    "Undecorated, no lights or decorations, just branches",
    "Decorated with silver tinsel, green ribbons, and a star",
    "Undecorated, plain with no adornments",
    "Decorated with red and white ornaments, and a Santa figurine",
    "Undecorated, just a bare tree",
    "Decorated with rainbow lights and colorful ornaments",
    "Undecorated, a simple evergreen tree with no additions",
    "Decorated with small golden bells and a red bow",
    "Undecorated, no ornaments or lights",
    "Decorated with silver baubles, white snowflakes, and a red star",
    "Undecorated, just a natural tree with no accessories",
    "Decorated with multicolor ribbons and white angel decorations",
    "Undecorated, no adornments, just the tree itself",
    "Decorated with large silver baubles and a golden star",
    "Undecorated, no decorations, just a green tree",
    "Decorated with white snowflakes, red ribbons, and a Santa hat",
    "Undecorated, a bare tree with no lights or ornaments",
    "Decorated with small lights and star-shaped ornaments",
    "Undecorated, only the tree, no adornments or lights",
)

content = [
    f"Tree {number}: {description}"
    for number, description in enumerate(_TREE_DESCRIPTIONS, start=1)
]
def generate_embeddings(texts):
    """Embed each text independently by mean-pooling the model's last hidden state.

    Each string is tokenized and run through ``model`` on its own. Because a
    single sequence is never padded, a plain mean over the token axis is the
    same as attention-masked mean pooling.

    Args:
        texts: Iterable of strings to embed.

    Returns:
        A list holding one NumPy vector per input text, in input order.
    """

    def _embed_one(text):
        # Embed a single string and return its pooled vector as NumPy.
        encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            hidden = model(**encoded).last_hidden_state
        return hidden.mean(dim=1).squeeze().numpy()

    return [_embed_one(text) for text in texts]
# Embed every description once, in the same order as `content`.
embeddings = generate_embeddings(content)

# Register all descriptions with Chroma in one batched call.
# - Chroma collections expose `add`; there is no `add_documents` method, so
#   the previous code raised AttributeError at startup.
# - `ids` is a required argument: reuse each entry's position in `content`
#   (as a string) and keep the integer in metadata for mapping back.
# - Vectors are passed as plain Python lists for compatibility with Chroma
#   versions that do not accept NumPy arrays.
collection.add(
    ids=[str(idx) for idx in range(len(content))],
    documents=content,
    metadatas=[{"id": idx} for idx in range(len(content))],
    embeddings=[vector.tolist() for vector in embeddings],
)

# Exact (brute-force) L2 FAISS index over the same vectors. FAISS works on
# float32 matrices with one row per vector, so cast explicitly.
embeddings_np = np.asarray(embeddings, dtype="float32")
faiss_index = faiss.IndexFlatL2(embeddings_np.shape[1])
faiss_index.add(embeddings_np)
def search(query, top_k=3):
    """Return the tree descriptions closest to *query* from FAISS and from Chroma.

    Both stores were filled with the embeddings of ``content`` when the module
    was loaded, so their answers can be compared side by side.

    Args:
        query: Free-text search string (the Gradio text input).
        top_k: Number of neighbours to retrieve from each store. Defaults to
            3, the value that was previously hard-coded.

    Returns:
        A two-line string, ``"FAISS Results: ..."`` followed by
        ``"Chroma Results: ..."``. Each line is a comma-separated list of
        matching descriptions, nearest first.
    """
    # FAISS expects a float32 matrix with one row per query vector.
    query_embedding = np.asarray(
        generate_embeddings([query])[0], dtype="float32"
    ).reshape(1, -1)

    # FAISS pads missing neighbours with index -1 when top_k exceeds the
    # number of stored vectors. Skip those so content[-1] (the last
    # description) is not returned as a bogus hit.
    _, indices = faiss_index.search(query_embedding, top_k)
    faiss_results = [content[i] for i in indices[0] if i != -1]

    # Chroma returns one list of documents per query embedding. Take the
    # first (and only) inner list: joining the outer list of lists raised
    # TypeError.
    chroma_response = collection.query(
        query_embeddings=query_embedding.tolist(), n_results=top_k
    )
    chroma_results = chroma_response["documents"][0]

    return (
        "FAISS Results: " + ", ".join(faiss_results)
        + "\nChroma Results: " + ", ".join(chroma_results)
    )
# Wrap `search` in a single text-in / text-out web UI, then start serving it.
interface = gr.Interface(
    fn=search,
    inputs="text",
    outputs="text",
)
interface.launch()