# RAG_AI_BOT / app.py
# willco-afk's picture
# Update app.py
# 1764725 verified
# raw
# history blame
# 2.2 kB
import gradio as gr
import chromadb
from transformers import AutoTokenizer, AutoModel
import faiss
import numpy as np
import torch
# Load the pre-trained model and tokenizer. Both are module-level because
# generate_embeddings() reads them as globals.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Initialize Chroma client
client = chromadb.Client()

# Fetch the collection if the client already holds it, otherwise create it.
# A plain create_collection() raises when the name is already taken, which
# breaks re-executing this module against the same client.
collection = client.get_or_create_collection(name="tree_images")

# Example data (you can replace this with your actual content or dataset)
content = [
    "Tree 1: Decorated with lights",
    "Tree 2: Undecorated",
    "Tree 3: Decorated with ornaments",
]
# Function to generate embeddings using the pre-trained model
def generate_embeddings(texts):
    """Embed each text with the module-level tokenizer and model.

    Every text is tokenized and passed through the model on its own. The
    model's last hidden state is averaged over dimension 1 and squeezed to
    give one vector per text.

    Args:
        texts: Iterable of strings to embed.

    Returns:
        A list of NumPy arrays, one per input text, in input order.
    """

    def _embed_one(text):
        encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            model_out = model(**encoded)
        pooled = model_out.last_hidden_state.mean(dim=1)
        return pooled.squeeze().numpy()

    return [_embed_one(text) for text in texts]
# Generate embeddings for the content
embeddings = generate_embeddings(content)

# Stack the per-document vectors into one float32 matrix; FAISS requires
# float32 input, and the same matrix feeds Chroma below.
embeddings_np = np.array(embeddings).astype('float32')

# Add the embeddings to Chroma in a single batched call. Collection.add()
# takes plural keyword arguments holding parallel lists and requires a unique
# string id per record; the metadata keeps the original integer position.
collection.add(
    ids=[str(position) for position in range(len(content))],
    embeddings=embeddings_np.tolist(),
    documents=list(content),
    metadatas=[{"id": position} for position in range(len(content))],
)

# Build FAISS index for efficient retrieval
faiss_index = faiss.IndexFlatL2(embeddings_np.shape[1])
faiss_index.add(embeddings_np)
# Define the search function for Gradio interface
def search(query):
    """Return the documents closest to ``query`` from both vector stores.

    The query is embedded with generate_embeddings() and looked up in the
    module-level FAISS index and Chroma collection.

    Args:
        query: Free-text query string entered by the user.

    Returns:
        A two-line string: the FAISS hits on the first line and the Chroma
        hits on the second, each as a comma-separated list of documents.
    """
    # Generate embedding for the query as a (1, dim) float32 matrix, which is
    # the layout and dtype FAISS expects for a single query.
    query_embedding = np.asarray(
        generate_embeddings([query])[0], dtype="float32"
    ).reshape(1, -1)

    # Never ask for more neighbours than there are documents.
    top_k = min(3, len(content))

    # FAISS-based search. FAISS pads missing neighbours with index -1, which
    # would silently pick content[-1], so those entries are dropped.
    _distances, indices = faiss_index.search(query_embedding, top_k)
    faiss_results = [content[i] for i in indices[0] if i >= 0]

    # Chroma-based search. The "documents" entry holds one list per query
    # embedding; only one query was sent, so take the first list.
    chroma_response = collection.query(
        query_embeddings=query_embedding.tolist(),
        n_results=top_k,
    )
    chroma_results = chroma_response["documents"][0]

    # Return results
    return (
        "FAISS Results: " + ", ".join(faiss_results)
        + "\nChroma Results: " + ", ".join(chroma_results)
    )
# Expose search() through a Gradio UI with one text input and one text output.
interface = gr.Interface(
    fn=search,
    inputs="text",
    outputs="text",
)

# Start serving the Gradio app.
interface.launch()