adbcode committed (verified)
Commit f087194 · Parent: 0e57a25

Create app.py

Files changed (1):
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
+import gradio as gr
+import os
+import pymongo
+import spaces
+
+from huggingface_hub import login
+from sentence_transformers import SentenceTransformer
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+def get_embedding(text: str) -> list[float]:
+    """Embed the given text with the sentence-transformer model."""
+    if not text.strip():
+        print("Attempted to get embedding for empty text.")
+        return []
+
+    embedding = embedding_model.encode(text)
+
+    return embedding.tolist()
+
+
+def get_mongo_client(mongo_uri):
+    """Establish a connection to MongoDB."""
+    try:
+        client = pymongo.MongoClient(mongo_uri)
+        print("Connection to MongoDB successful")
+        return client
+    except pymongo.errors.ConnectionFailure as e:
+        print(f"Connection failed: {e}")
+        return None
+
+
+def vector_search(user_query, collection):
+    """Run a $vectorSearch over the collection and return the top matches."""
+    # Generate embedding for the user query
+    query_embedding = get_embedding(user_query)
+
+    if not query_embedding:
+        print("Invalid query or embedding generation failed.")
+        return []
+
+    # Define the vector search pipeline
+    pipeline = [
+        {
+            "$vectorSearch": {
+                "index": "vector_index",
+                "queryVector": query_embedding,
+                "path": "embedding",
+                "numCandidates": 150,  # Number of candidate matches to consider
+                "limit": 4,  # Return top 4 matches
+            }
+        },
+        {
+            "$project": {
+                "_id": 0,
+                "title": 1,
+                "ingredients": 1,
+                "directions": 1,
+                "score": {"$meta": "vectorSearchScore"},  # Include the search score
+            }
+        },
+    ]
+
+    # Execute the search
+    results = collection.aggregate(pipeline)
+    return list(results)
+
+
+def get_search_result(query, collection):
+    """Format the vector-search matches into a text block for the prompt."""
+    get_knowledge = vector_search(query, collection)
+
+    search_result = ""
+    for result in get_knowledge:
+        search_result += f"Recipe Name: {result.get('title', 'N/A')}, Ingredients: {result.get('ingredients', 'N/A')}, Directions: {result.get('directions', 'N/A')}\n"
+
+    return search_result, get_knowledge
+
+
+@spaces.GPU
+def process_response(message, history):
+    """Chat handler: retrieve matching recipes, then answer with the LLM."""
+    source_information, matches = get_search_result(message, collection)
+    recipe_dict = {}
+    for x in matches:
+        name = x.pop("title")
+        recipe_dict[name] = x
+
+    if not recipe_dict:
+        return "No matching recipes were found for your query."
+
+    combined_information = f"Query: {message}\nContinue to answer the query by using the Search Results:\n{source_information}."
+    input_ids = tokenizer(combined_information, return_tensors="pt").to("cuda")
+    response = model.generate(**input_ids, max_new_tokens=500)
+    response_text = tokenizer.decode(response[0]).split("\n.\n")[-1].split("<eos>")[0].strip()
+
+    # Append the full recipe that the model's answer refers to (or the top match).
+    matched_recipe = ""
+    for title in recipe_dict.keys():
+        if title in response_text:
+            matched_recipe = title
+            break
+    if not matched_recipe:
+        matched_recipe = next(iter(recipe_dict))
+    recipe = recipe_dict[matched_recipe]
+
+    response_text += f"\n\nRecipe for **{matched_recipe}**:"
+    response_text += "\n### List of ingredients:\n- {0}".format("\n- ".join(recipe["ingredients"].split(", ")))
+    response_text += "\n### Directions:\n- {0}".format(".\n- ".join(recipe["directions"].split(". ")))
+
+    return response_text
+
+
+if __name__ == "__main__":
+    # https://huggingface.co/thenlper/gte-large
+    embedding_model = SentenceTransformer("thenlper/gte-large")
+
+    mongo_uri = os.getenv("MONGO_URI")
+    if not mongo_uri:
+        raise ValueError("MONGO_URI not set in environment variables")
+
+    mongo_client = get_mongo_client(mongo_uri)
+    if mongo_client is None:
+        raise ConnectionError("Could not connect to MongoDB")
+
+    # Connect to the recipe database and collection
+    db = mongo_client["recipe"]
+    collection = db["recipe_collection"]
+
+    # login(token=os.getenv("HF_TOKEN"))
+
+    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
+    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", device_map="auto")
+
+    gr.ChatInterface(process_response).launch()
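
Note: the $vectorSearch stage in app.py assumes an Atlas Vector Search index named "vector_index" already exists on the embedding field of recipe.recipe_collection, with a dimension matching the thenlper/gte-large embeddings (1024). The commit does not include that setup, so the sketch below is an assumption of how such an index could be created as a one-off script with a recent pymongo on an Atlas cluster, not part of the committed app.

import os
import pymongo
from pymongo.operations import SearchIndexModel

# Hypothetical setup script (not part of app.py): create the "vector_index"
# that the $vectorSearch pipeline expects. Assumes a MongoDB Atlas cluster and
# pymongo >= 4.7 (for SearchIndexModel's "type" argument); numDimensions=1024
# matches the thenlper/gte-large embedding size.
client = pymongo.MongoClient(os.environ["MONGO_URI"])
collection = client["recipe"]["recipe_collection"]

index_model = SearchIndexModel(
    definition={
        "fields": [
            {
                "type": "vector",
                "path": "embedding",
                "numDimensions": 1024,
                "similarity": "cosine",
            }
        ]
    },
    name="vector_index",
    type="vectorSearch",
)
collection.create_search_index(model=index_model)

With that index in place, each document in the collection would carry title, ingredients, directions, and an embedding list (for example produced by get_embedding), matching the fields projected by the search pipeline.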