Ankitajadhav committed on
Commit
2fe908e
·
verified ·
1 Parent(s): 28fb49c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -4
app.py CHANGED
@@ -26,10 +26,24 @@ class VectorStore:
26
  self.chroma_client = chromadb.Client()
27
  self.collection = self.chroma_client.create_collection(name=collection_name)
28
 
29
- def populate_vectors(self, texts, ids):
30
- embeddings = self.embedding_model.encode(texts, batch_size=32).tolist()
31
- for text, embedding, doc_id in zip(texts, embeddings, ids):
32
- self.collection.add(embeddings=[embedding], documents=[text], ids=[doc_id])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  def search_context(self, query, n_results=1):
35
  query_embedding = self.embedding_model.encode([query]).tolist()
 
26
  self.chroma_client = chromadb.Client()
27
  self.collection = self.chroma_client.create_collection(name=collection_name)
28
 
29
+ # def populate_vectors(self, texts, ids):
30
+ # embeddings = self.embedding_model.encode(texts, batch_size=32).tolist()
31
+ # for text, embedding, doc_id in zip(texts, embeddings, ids):
32
+ # self.collection.add(embeddings=[embedding], documents=[text], ids=[doc_id])
33
+
34
+ # Method to populate the vector store with embeddings from a dataset
35
+ def populate_vectors(self, dataset):
36
+ # Select the text columns to concatenate
37
+ # title = dataset['train']['title_cleaned'][:1000] # Limiting to 1000 examples for the demo
38
+ recipe = dataset['train']['recipe_new'][:1000]
39
+ allergy = dataset['train']['allergy_type'][:1000]
40
+ ingredients = dataset['train']['ingredients_alternatives'][:1000]
41
+
42
+ # Concatenate the text from the three selected columns (recipe, ingredients, allergy)
43
+ texts = [f"{rep} {ingr} {alle}" for rep, ingr,alle in zip(recipe, ingredients,allergy)]
44
+ for i, item in enumerate(texts):
45
+ embeddings = self.embedding_model.encode(item).tolist()
46
+ self.collection.add(embeddings=[embeddings], documents=[item], ids=[str(i)])
47
 
48
  def search_context(self, query, n_results=1):
49
  query_embedding = self.embedding_model.encode([query]).tolist()