Kalyani8 commited on
Commit
b20fcd1
·
verified ·
1 Parent(s): e25de72

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -11
app.py CHANGED
@@ -3,9 +3,10 @@ from sentence_transformers import SentenceTransformer
3
  import faiss
4
  import numpy as np
5
  import gradio as gr
 
6
 
7
  # Load a small subset (10,000 rows)
8
- dataset = load_dataset("wiki40b", "en", split="train[:10000]")
9
 
10
  # Extract only text
11
  docs = [d["text"] for d in dataset]
@@ -18,19 +19,29 @@ embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
18
  # Convert texts to embeddings
19
  embeddings = embed_model.encode(docs, show_progress_bar=True)
20
 
21
- # Store in FAISS index
22
- dimension = embeddings.shape[1]
23
- index = faiss.IndexFlatL2(dimension)
24
- index.add(np.array(embeddings))
25
 
26
- print("Stored embeddings in FAISS!")
 
 
27
 
28
- # Search function
 
 
 
 
 
 
 
 
 
 
29
  def search_wikipedia(query, top_k=3):
30
- query_embedding = embed_model.encode([query])
31
- distances, indices = index.search(np.array(query_embedding), top_k)
32
- results = [docs[i] for i in indices[0]]
33
- return "\n\n".join(results)
 
 
34
 
35
  # Gradio Interface
36
  iface = gr.Interface(
 
3
  import faiss
4
  import numpy as np
5
  import gradio as gr
6
+ import chromadb
7
 
8
  # Load a small subset (10,000 rows)
9
+ dataset = load_dataset("wiki40b", "en", split="train[:1000]")
10
 
11
  # Extract only text
12
  docs = [d["text"] for d in dataset]
 
19
  # Convert texts to embeddings
20
  embeddings = embed_model.encode(docs, show_progress_bar=True)
21
 
 
 
 
 
22
 
23
+ # Initialize ChromaDB client
24
+ chroma_client = chromadb.PersistentClient(path="./chroma_db") # Stores data persistently
25
+ collection = chroma_client.get_or_create_collection(name="wikipedia_docs")
26
 
27
+ # Store embeddings in ChromaDB
28
+ for i, (doc, embedding) in enumerate(zip(docs, embeddings)):
29
+ collection.add(
30
+ ids=[str(i)], # Unique ID for each doc
31
+ embeddings=[embedding.tolist()], # Convert numpy array to list
32
+ documents=[doc]
33
+ )
34
+
35
+ print("Stored embeddings in ChromaDB!")
36
+
37
+ # Search function using ChromaDB
38
  def search_wikipedia(query, top_k=3):
39
+ query_embedding = embed_model.encode([query]).tolist()
40
+ results = collection.query(
41
+ query_embeddings=query_embedding,
42
+ n_results=top_k
43
+ )
44
+ return "\n\n".join(results["documents"][0]) # Return top results
45
 
46
  # Gradio Interface
47
  iface = gr.Interface(