ashmib commited on
Commit
197fd64
·
verified ·
1 Parent(s): e2b76be

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from sentence_transformers import SentenceTransformer
3
+ import pymongo
4
+ import sys
5
+ from huggingface_hub import InferenceClient
6
+ import gradio as gr
7
+
8
+ sys.path.append("../")
9
+ from config import constants
10
+
11
+ HF_token = constants.HF_TOKEN
12
+
13
+
14
+ def get_embedding(text: str) -> list[float]:
15
+ embedding_model = SentenceTransformer("thenlper/gte-large")
16
+
17
+ if not text.strip():
18
+ print("Attempted to get embedding for empty text.")
19
+ return []
20
+
21
+ embedding = embedding_model.encode(text)
22
+
23
+ return embedding.tolist()
24
+
25
+
26
+ def get_mongo_client(mongo_url):
27
+ """Establish connection to the MongoDB."""
28
+ if not mongo_url:
29
+ print("MONGO_URI not set in environment variables")
30
+ try:
31
+ client = pymongo.MongoClient(mongo_url)
32
+ print("Connection to MongoDB successful")
33
+ return client
34
+ except pymongo.errors.ConnectionFailure as e:
35
+ print(f"Connection failed: {e}")
36
+ return None
37
+
38
+
39
+ def get_mongo_url():
40
+ username = constants.MONGO_USERNAME
41
+ password = constants.MONGO_PW
42
+ mongo_url = f"mongodb+srv://{username}:{password}@cluster0.62unmco.mongodb.net/"
43
+ return mongo_url
44
+
45
+
46
+ def query_results(query, mongo_url):
47
+ mongo_client = get_mongo_client(mongo_url)
48
+ db = mongo_client["EU_Cities"]
49
+
50
+ query_embedding = get_embedding(query)
51
+ results = db.EU_cities_collection.aggregate([
52
+ {
53
+ "$vectorSearch": {
54
+ "index": "vector_index",
55
+ "path": "embedding",
56
+ "queryVector": query_embedding,
57
+ "numCandidates": 150,
58
+ "limit": 5
59
+ }
60
+ }
61
+ ])
62
+ return results
63
+
64
+
65
+ def get_search_result(query, mongo_url):
66
+ get_knowledge = query_results(query, mongo_url)
67
+ print(get_knowledge)
68
+
69
+ search_result = ""
70
+ for result in get_knowledge:
71
+ search_result += f"City: {result.get('city', 'N/A')}, Abstract: {result.get('combined', 'N/A')}\n"
72
+
73
+ return search_result
74
+
75
+
76
+ def generate_text(query, model_name: Optional[str] = "google/gemma-2b-it"):
77
+ if model_name is None:
78
+ model_name = "google/gemma-2b-it"
79
+
80
+ mongo_url = get_mongo_url()
81
+ source_information = get_search_result(query, mongo_url)
82
+ combined_information = (
83
+ f"Query: {query}\nContinue to answer the query by using the Search Results:\n{source_information}."
84
+ )
85
+ client = InferenceClient(model_name, token=HF_token)
86
+
87
+ stream = client.text_generation(prompt=combined_information, details=True, stream=True, max_new_tokens=2048,
88
+ return_full_text=False)
89
+ output = ""
90
+
91
+ for response in stream:
92
+ output += response.token.text
93
+
94
+ if "<eos>" in output:
95
+ output = output.split("<eos>")[0]
96
+ return output
97
+
98
+
99
+ examples = [["I'm planning a vacation to France. Can you suggest a one-week itinerary including must-visit places and "
100
+ "local cuisines to try?", None],
101
+ ["Recommend places that are similar to Istanbul in terms of architecture", None],
102
+ ]
103
+
104
+ demo = gr.Interface(
105
+ fn=generate_text,
106
+ inputs=["text",
107
+ gr.Dropdown(
108
+ ["google/gemma-2b-it","google/gemma-7b", "mistralai/Mixtral-8x7B-Instruct-v0.1"], label="Models", info="Will "
109
+ "add "
110
+ "more "
111
+ "models "
112
+ "later! "
113
+ ),
114
+ ],
115
+ title="🇪🇺 Euro TravelBot 🇪🇺",
116
+ description="Travel related queries for Europe.",
117
+ outputs=["text"],
118
+ examples=examples,
119
+ )
120
+
121
+ if __name__ == "__main__":
122
+ demo.launch()