Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -3,9 +3,10 @@ from sentence_transformers import SentenceTransformer
|
|
3 |
import faiss
|
4 |
import numpy as np
|
5 |
import gradio as gr
|
|
|
6 |
|
7 |
# Load a small subset (10,000 rows)
|
8 |
-
dataset = load_dataset("wiki40b", "en", split="train[:
|
9 |
|
10 |
# Extract only text
|
11 |
docs = [d["text"] for d in dataset]
|
@@ -18,19 +19,29 @@ embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
|
18 |
# Convert texts to embeddings
|
19 |
embeddings = embed_model.encode(docs, show_progress_bar=True)
|
20 |
|
21 |
-
# Store in FAISS index
|
22 |
-
dimension = embeddings.shape[1]
|
23 |
-
index = faiss.IndexFlatL2(dimension)
|
24 |
-
index.add(np.array(embeddings))
|
25 |
|
26 |
-
|
|
|
|
|
27 |
|
28 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
def search_wikipedia(query, top_k=3):
|
30 |
-
query_embedding = embed_model.encode([query])
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
34 |
|
35 |
# Gradio Interface
|
36 |
iface = gr.Interface(
|
|
|
3 |
import faiss
|
4 |
import numpy as np
|
5 |
import gradio as gr
|
6 |
+
import chromadb
|
7 |
|
8 |
# Load a small subset (10,000 rows)
|
9 |
+
dataset = load_dataset("wiki40b", "en", split="train[:1000]")
|
10 |
|
11 |
# Extract only text
|
12 |
docs = [d["text"] for d in dataset]
|
|
|
19 |
# Convert texts to embeddings
|
20 |
embeddings = embed_model.encode(docs, show_progress_bar=True)
|
21 |
|
|
|
|
|
|
|
|
|
22 |
|
23 |
+
# Initialize ChromaDB client
|
24 |
+
chroma_client = chromadb.PersistentClient(path="./chroma_db") # Stores data persistently
|
25 |
+
collection = chroma_client.get_or_create_collection(name="wikipedia_docs")
|
26 |
|
27 |
+
# Store embeddings in ChromaDB
|
28 |
+
for i, (doc, embedding) in enumerate(zip(docs, embeddings)):
|
29 |
+
collection.add(
|
30 |
+
ids=[str(i)], # Unique ID for each doc
|
31 |
+
embeddings=[embedding.tolist()], # Convert numpy array to list
|
32 |
+
documents=[doc]
|
33 |
+
)
|
34 |
+
|
35 |
+
print("Stored embeddings in ChromaDB!")
|
36 |
+
|
37 |
+
# Search function using ChromaDB
|
38 |
def search_wikipedia(query, top_k=3):
|
39 |
+
query_embedding = embed_model.encode([query]).tolist()
|
40 |
+
results = collection.query(
|
41 |
+
query_embeddings=query_embedding,
|
42 |
+
n_results=top_k
|
43 |
+
)
|
44 |
+
return "\n\n".join(results["documents"][0]) # Return top results
|
45 |
|
46 |
# Gradio Interface
|
47 |
iface = gr.Interface(
|