# RAG_AI_BOT / app.py
# Author: willco-afk (Hugging Face Space)
# Last commit: "Update app.py" (9eaaba5, verified) — 5.56 kB
import gradio as gr
import chromadb
from transformers import AutoTokenizer, AutoModel
import faiss
import numpy as np
import torch
# Sentence-embedding checkpoint pulled from the Hugging Face Hub. The tokenizer
# and encoder are loaded once, at import time, and used by generate_embeddings().
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModel.from_pretrained(model_name),
)
# Initialize a Chroma client (the default chromadb.Client() keeps data in memory)
client = chromadb.Client()
# Get the "tree_images" collection, creating it on first use. create_collection()
# raises if the name already exists (e.g. when this module is re-executed in the
# same process), whereas get_or_create_collection() simply returns it.
collection = client.get_or_create_collection(name="tree_images")
# Corpus of 50 tree descriptions, alternating decorated (odd numbers) and
# undecorated (even numbers). Every entry is embedded and indexed in both FAISS
# and Chroma; a document's position in this list is the index FAISS returns,
# which search() maps back to text via content[i].
content = [
"Tree 1: Decorated with colorful lights and a star on top",
"Tree 2: Undecorated, only a bare tree with no lights or ornaments",
"Tree 3: Decorated with silver tinsel and baubles",
"Tree 4: Undecorated, only green branches",
"Tree 5: Decorated with red ribbons and golden ornaments",
"Tree 6: Undecorated, a plain pine tree",
"Tree 7: Decorated with multicolored lights and silver bells",
"Tree 8: Undecorated, just a tree with no decorations",
"Tree 9: Decorated with handmade ornaments and garlands",
"Tree 10: Undecorated, a simple tree without any adornment",
"Tree 11: Decorated with blue and white lights and a snowflake theme",
"Tree 12: Undecorated, only branches with no adornments",
"Tree 13: Decorated with red and green ornaments and candy canes",
"Tree 14: Undecorated, just a tall and natural-looking tree",
"Tree 15: Decorated with silver garlands and a star topper",
"Tree 16: Undecorated, a bare spruce tree",
"Tree 17: Decorated with gold and red ornaments, and a snowman figure",
"Tree 18: Undecorated, a simple fir tree with no extras",
"Tree 19: Decorated with colorful LED lights and a bow on top",
"Tree 20: Undecorated, just a bare tree with no lights or baubles",
"Tree 21: Decorated with small fairy lights and a red ribbon",
"Tree 22: Undecorated, just a simple tree with green branches",
"Tree 23: Decorated with golden stars and white snowflakes",
"Tree 24: Undecorated, just a natural green pine tree",
"Tree 25: Decorated with pink ornaments and a gold topper",
"Tree 26: Undecorated, no decorations, just a plain tree",
"Tree 27: Decorated with silver and blue ornaments and a festive ribbon",
"Tree 28: Undecorated, just a fresh pine tree",
"Tree 29: Decorated with white fairy lights and a Christmas angel on top",
"Tree 30: Undecorated, only the green foliage of the tree",
"Tree 31: Decorated with bright red ornaments and golden tinsel",
"Tree 32: Undecorated, no lights or decorations, just branches",
"Tree 33: Decorated with silver tinsel, green ribbons, and a star",
"Tree 34: Undecorated, plain with no adornments",
"Tree 35: Decorated with red and white ornaments, and a Santa figurine",
"Tree 36: Undecorated, just a bare tree",
"Tree 37: Decorated with rainbow lights and colorful ornaments",
"Tree 38: Undecorated, a simple evergreen tree with no additions",
"Tree 39: Decorated with small golden bells and a red bow",
"Tree 40: Undecorated, no ornaments or lights",
"Tree 41: Decorated with silver baubles, white snowflakes, and a red star",
"Tree 42: Undecorated, just a natural tree with no accessories",
"Tree 43: Decorated with multicolor ribbons and white angel decorations",
"Tree 44: Undecorated, no adornments, just the tree itself",
"Tree 45: Decorated with large silver baubles and a golden star",
"Tree 46: Undecorated, no decorations, just a green tree",
"Tree 47: Decorated with white snowflakes, red ribbons, and a Santa hat",
"Tree 48: Undecorated, a bare tree with no lights or ornaments",
"Tree 49: Decorated with small lights and star-shaped ornaments",
"Tree 50: Undecorated, only the tree, no adornments or lights"
]
# Function to generate embeddings using the pre-trained model
def generate_embeddings(texts, batch_size=32):
    """Embed each text as the mean of the encoder's token states.

    Texts are encoded in batches of ``batch_size`` instead of one forward pass
    per text. Because a padded batch contains filler tokens, the mean is taken
    over real tokens only (weighted by the attention mask). That gives the same
    vector per text as encoding it on its own, up to floating-point noise.

    Args:
        texts: Iterable of strings to embed.
        batch_size: Maximum number of texts per forward pass; must be positive.

    Returns:
        A list with one 1-D numpy vector per input text, in input order
        (empty list for empty input).

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    texts = list(texts)
    embeddings = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            output = model(**inputs)
        hidden = output.last_hidden_state
        # Zero out padding positions, then divide by the number of real tokens.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against division by zero
        pooled = (summed / counts).numpy()
        # Iterating a 2-D array yields its rows: one 1-D vector per text.
        embeddings.extend(pooled)
    # NOTE(review): the all-MiniLM-L6-v2 model card also L2-normalizes the pooled
    # vectors; that is not done here (FAISS/Chroma use raw L2) — confirm intent.
    return embeddings
# Generate embeddings for the content
embeddings = generate_embeddings(content)
# Add every document to Chroma in a single call. The Collection API has no
# `add_documents` method, and `add` requires a unique string id per record;
# the list position is used as that id (and kept in metadata as before).
collection.add(
    ids=[str(idx) for idx in range(len(content))],
    documents=content,
    metadatas=[{"id": idx} for idx in range(len(content))],
    # Plain Python lists of floats, one per document
    embeddings=[embedding.tolist() for embedding in embeddings],
)
# Stack the per-document vectors into one float32 matrix (FAISS expects float32)
embeddings_np = np.asarray(embeddings, dtype=np.float32)
# Exact (brute-force) L2 index whose width matches the embedding size
faiss_index = faiss.IndexFlatL2(embeddings_np.shape[-1])
faiss_index.add(embeddings_np)
# Define the search function for Gradio interface
def search(query, top_k=3):
    """Retrieve the tree descriptions closest to ``query`` from FAISS and Chroma.

    Args:
        query: Free-text search string (the Gradio text box value).
        top_k: Number of neighbours requested from each store, capped at the
            number of indexed documents. Should be positive.

    Returns:
        A two-line string: the FAISS hits, then the Chroma hits, each as a
        comma-separated list of documents.
    """
    # Embed the query as a (1, dim) float32 matrix, the shape FAISS.search takes
    query_embedding = np.asarray(
        generate_embeddings([query])[0], dtype=np.float32
    ).reshape(1, -1)
    # Never request more neighbours than there are documents
    k = min(top_k, faiss_index.ntotal)

    # FAISS-based search. Missing hits are labelled -1; drop them instead of
    # letting content[-1] silently return the last document.
    _, indices = faiss_index.search(query_embedding, k)
    faiss_results = [content[i] for i in indices[0] if i >= 0]

    # Chroma-based search. "documents" holds one result list per query
    # embedding, so take the list for our single query before joining.
    chroma_response = collection.query(
        query_embeddings=query_embedding.tolist(), n_results=k
    )
    chroma_results = chroma_response["documents"][0]

    return (
        "FAISS Results: " + ", ".join(faiss_results)
        + "\nChroma Results: " + ", ".join(chroma_results)
    )
# Create the Gradio interface: one text box in, plain text out
interface = gr.Interface(fn=search, inputs="text", outputs="text")

# Launch only when run as a script (`python app.py`), so importing this module
# (e.g. from tests or another app) does not start a web server.
if __name__ == "__main__":
    interface.launch()