hasnanmr commited on
Commit
c2bc8b9
·
1 Parent(s): 069c659

change model

Browse files
Files changed (1) hide show
  1. app.py +61 -104
app.py CHANGED
@@ -8,35 +8,43 @@ import sqlite3
8
  import pandas as pd
9
  from tqdm import tqdm
10
 
11
- # Get the Groq API key from environment variables (in Hugging Face, this is stored as a secret)
 
 
 
12
  client = Groq(
13
- # This is the default and can be omitted
14
- api_key=os.environ.get("GROQ_API_KEY"),
15
  )
16
 
17
 
 
 
 
 
 
 
 
 
 
18
  con = sqlite3.connect("file::memory:?cache=shared", check_same_thread=False)
19
  con.row_factory = sqlite3.Row
20
  cur = con.cursor()
21
 
22
  # create table if not exists
23
-
24
  cur.execute("""
25
  CREATE TABLE IF NOT EXISTS places (
26
- Place_Id INTEGER PRIMARY KEY, -- SQLite auto-increments INTEGER PRIMARY KEY automatically
27
- Place_Name TEXT NOT NULL, -- SQLite uses TEXT instead of VARCHAR
28
  Description TEXT,
29
  Category TEXT,
30
  City TEXT,
31
- Price REAL, -- SQLite uses REAL instead of DECIMAL or FLOAT
32
  Rating REAL,
33
  Embedding TEXT
34
  );
35
  """)
36
 
37
-
38
- data = pd.read_csv('tourism_place.csv')
39
-
40
 
41
  # check if the table is empty
42
  cur.execute("SELECT * FROM places")
@@ -45,171 +53,120 @@ if cur.fetchone() is None:
45
  # Store the places in the database
46
  for i in tqdm(range(len(data))):
47
  cur.execute("""
48
- INSERT INTO places (Place_Name, Description, Category, City, Price, Rating)
49
  VALUES (?, ?, ?, ?, ?, ?)
50
  """, (data['Place_Name'][i], data['Description'][i], data['Category'][i], data['City'][i], float(data['Price'][i]), float(data['Rating'][i]))
51
  )
52
-
53
- # Commit the changes to the database
54
  con.commit()
55
 
56
- # Compute and store embeddings
57
  def compute_and_store_embeddings():
58
- model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
59
-
60
- # Select all places from the database
61
  cur.execute("SELECT Place_Id, Place_Name, Category, Description, City FROM places")
62
  places = cur.fetchall()
63
-
64
  for place in places:
65
- # Combine PlaceName, Category, Description, and City into one string
66
  text = f"{place[1]} {place[2]} {place[3]} {place[4]}"
67
-
68
- # Generate embedding for the combined text
69
  embedding = model.encode(text)
70
-
71
- # Convert embedding to a string format to store in the database
72
  embedding_str = ','.join([str(x) for x in embedding])
73
-
74
- # Update the place in the database with the embedding
75
- cur.execute(
76
- "UPDATE places SET Embedding = ? WHERE Place_Id = ?",
77
- (embedding_str, place[0])
78
- )
79
-
80
- # Commit the changes to the database
81
  con.commit()
82
- # Run the function to compute and store embeddings
83
- compute_and_store_embeddings()
84
-
85
 
86
- # Load Hugging Face model for generating embeddings
87
- model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
88
 
89
  # Normalize user query using Groq VM
90
  def normalize_query(user_query):
91
  try:
92
  response = client.chat.completions.create(
93
- model="llama-3.1-70b-versatile",
94
- temperature=0.5,
95
  messages=[{
96
  "role": "user",
97
  "content": f"""
98
- Please analyze the query: \"{user_query}\", extract Place name, Category, Description, and City.
99
  Return the response as: "Place name, Category, Description, City".
100
  """
101
  }]
102
  )
103
  normalized_user_query = response.choices[0].message.content.split('\n')[-1].strip()
104
- print(f"Normalized Query: {normalized_user_query}")
105
-
106
- return normalized_user_query
107
-
108
  except Exception as e:
109
  print(f"Error normalizing query: {e}")
110
  return ""
111
 
112
- # Generate user embedding using Hugging Face model
113
  def get_user_embedding(query):
114
  try:
115
  return model.encode(query)
116
  except Exception as e:
117
  print(f"Error generating embedding: {e}")
118
- return np.zeros()
119
 
120
- # Find similar places based on cosine similarity
121
  def get_similar_places(user_embedding):
122
  similarities = []
123
- # Select all places from the database
124
  res = cur.execute("SELECT * FROM places").fetchall()
125
-
126
  for place in res:
127
- embedding_str = place['Embedding'] # Assuming embeddings are stored as comma-separated strings in the database
128
- embedding = np.array([float(x) for x in embedding_str.split(',')]) # Convert the string back to a numpy array
129
-
130
- # Compute cosine similarity
131
  similarity = cosine_similarity([user_embedding], [embedding])[0][0]
132
  similarities.append((place, similarity))
133
-
134
- # Sort results based on similarity and then by rating
135
  ranked_results = sorted(similarities, key=lambda x: (x[1], x[0]['Rating']), reverse=True)
136
-
137
- # Return top places
138
  return ranked_results
139
 
140
- # Main function to get top 5 destinations
141
- def get_top_5_destinations(user_query):
142
  normalized_query = normalize_query(user_query)
143
  user_embedding = get_user_embedding(normalized_query)
144
  similar_places = get_similar_places(user_embedding)
145
-
146
  if not similar_places:
147
  return "Tidak ada tempat yang ditemukan."
 
148
 
149
- top_places = []
150
- for i, (place, similarity) in enumerate(similar_places[:10]):
151
- top_places.append({
152
- 'name': place['Place_Name'],
153
- 'city': place['City'],
154
- 'category': place['Category'],
155
- 'rating': place['Rating'],
156
- 'description': place['Description'],
157
- 'similarity': similarity
158
- })
159
- print(top_places)
160
- return top_places
161
-
162
- # Generate response to user using Groq VM
163
- def generate_response(user_query, top_places):
164
  try:
165
- # Prepare the destinations data in JSON format for the model to use directly
166
  destinations_data = ", ".join([
167
- f'{{"name": "{place["name"]}", "city": "{place["city"]}", "category": "{place["category"]}", "rating": {place["rating"]}, "description": "{place["description"]}"}}'
168
  for place in top_places
169
  ])
170
-
171
- # System prompt: Simplified and focused on returning only the recommendations
172
- system_prompt = """
173
- You are a tour guide assistant. Your task is to present the following tourism recommendations based on what user want and needs to the user in Bahasa Indonesia.
174
- - For each destination, include the name, city, category, rating, and a short description.
175
- - Do not provide any additional commentary.
176
- - Only and must only return 5 places that suitable what user wants and provided the data in a clear and concise format.
177
- """
178
-
179
- # Generate the response using the model
180
  response = client.chat.completions.create(
181
- model="llama-3.1-70b-versatile",
182
- temperature=0.2,
183
  messages=[
184
- {"role": "system", "content": system_prompt}, # System prompt defines behavior
185
  {"role": "user", "content": f"Berikut adalah rekomendasi berdasarkan data: {destinations_data}"}
186
- ]
 
187
  )
188
-
189
- # Return the response content generated by the model
190
  return response.choices[0].message.content
191
  except Exception as e:
192
  print(f"Error generating response: {e}")
193
  return "Maaf, terjadi kesalahan dalam menghasilkan rekomendasi."
194
 
195
- # Gradio Interface - User Input and Output
196
- def chatbot(user_query):
197
- # Step 1: Get the top 5 destinations
198
- top_places = get_top_5_destinations(user_query)
199
-
200
- if isinstance(top_places, str): # Error case, e.g. "No places found"
201
  return top_places
202
-
203
- # only the first 5 element of top_places
204
- response = generate_response(user_query, top_places)
205
-
206
  return response
207
 
208
  # Define Gradio Interface
209
  iface = gr.Interface(
210
  fn=chatbot,
211
- inputs="text",
212
- outputs="text",
 
 
 
 
 
 
 
 
 
213
  title="Tourism Recommendation Chatbot",
214
  description="Masukkan pertanyaan wisata Anda dan dapatkan rekomendasi tempat terbaik!"
215
  )
 
8
  import pandas as pd
9
  from tqdm import tqdm
10
 
11
+ # Define the SentenceTransformer model globally
12
+ model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
13
+
14
+ # Get the Groq API key from environment variables
15
  client = Groq(
16
+ api_key="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"
 
17
  )
18
 
19
 
20
+ # Generate user embedding using the globally defined model
21
+ def get_user_embedding(query):
22
+ try:
23
+ return model.encode(query)
24
+ except Exception as e:
25
+ print(f"Error generating embedding: {e}")
26
+ return np.zeros(384) # Return a zero-vector of the correct size if there is an error
27
+
28
+
29
  con = sqlite3.connect("file::memory:?cache=shared", check_same_thread=False)
30
  con.row_factory = sqlite3.Row
31
  cur = con.cursor()
32
 
33
  # create table if not exists
 
34
  cur.execute("""
35
  CREATE TABLE IF NOT EXISTS places (
36
+ Place_Id INTEGER PRIMARY KEY,
37
+ Place_Name TEXT NOT NULL,
38
  Description TEXT,
39
  Category TEXT,
40
  City TEXT,
41
+ Price REAL,
42
  Rating REAL,
43
  Embedding TEXT
44
  );
45
  """)
46
 
47
+ data = pd.read_csv('dataset/tourism_place.csv')
 
 
48
 
49
  # check if the table is empty
50
  cur.execute("SELECT * FROM places")
 
53
  # Store the places in the database
54
  for i in tqdm(range(len(data))):
55
  cur.execute("""
56
+ INSERT INTO places (Place_Name, Description, Category, City, Price, Rating)
57
  VALUES (?, ?, ?, ?, ?, ?)
58
  """, (data['Place_Name'][i], data['Description'][i], data['Category'][i], data['City'][i], float(data['Price'][i]), float(data['Rating'][i]))
59
  )
 
 
60
  con.commit()
61
 
62
+ # Compute and store embeddings for places using the same model
63
  def compute_and_store_embeddings():
 
 
 
64
  cur.execute("SELECT Place_Id, Place_Name, Category, Description, City FROM places")
65
  places = cur.fetchall()
66
+
67
  for place in places:
 
68
  text = f"{place[1]} {place[2]} {place[3]} {place[4]}"
 
 
69
  embedding = model.encode(text)
 
 
70
  embedding_str = ','.join([str(x) for x in embedding])
71
+ cur.execute("UPDATE places SET Embedding = ? WHERE Place_Id = ?", (embedding_str, place[0]))
 
 
 
 
 
 
 
72
  con.commit()
 
 
 
73
 
74
+ compute_and_store_embeddings()
 
75
 
76
  # Normalize user query using Groq VM
77
  def normalize_query(user_query):
78
  try:
79
  response = client.chat.completions.create(
80
+ model="llama-3.1-8b-instant",
 
81
  messages=[{
82
  "role": "user",
83
  "content": f"""
84
+ Please analyze the query: \"{user_query}\", extract Place name, Category, Description, and City.
85
  Return the response as: "Place name, Category, Description, City".
86
  """
87
  }]
88
  )
89
  normalized_user_query = response.choices[0].message.content.split('\n')[-1].strip()
90
+ return normalized_user_query + str(user_query)
 
 
 
91
  except Exception as e:
92
  print(f"Error normalizing query: {e}")
93
  return ""
94
 
95
+ # Generate user embedding
96
  def get_user_embedding(query):
97
  try:
98
  return model.encode(query)
99
  except Exception as e:
100
  print(f"Error generating embedding: {e}")
101
+ return np.zeros(512)
102
 
103
+ # Find similar places
104
  def get_similar_places(user_embedding):
105
  similarities = []
 
106
  res = cur.execute("SELECT * FROM places").fetchall()
 
107
  for place in res:
108
+ embedding_str = place['Embedding']
109
+ embedding = np.array([float(x) for x in embedding_str.split(',')])
 
 
110
  similarity = cosine_similarity([user_embedding], [embedding])[0][0]
111
  similarities.append((place, similarity))
 
 
112
  ranked_results = sorted(similarities, key=lambda x: (x[1], x[0]['Rating']), reverse=True)
 
 
113
  return ranked_results
114
 
115
+ # Get top 10 destinations
116
+ def get_top_10_destinations(user_query):
117
  normalized_query = normalize_query(user_query)
118
  user_embedding = get_user_embedding(normalized_query)
119
  similar_places = get_similar_places(user_embedding)
 
120
  if not similar_places:
121
  return "Tidak ada tempat yang ditemukan."
122
+ return similar_places[:10]
123
 
124
+ # Generate response using Groq VM
125
+ def generate_response(user_query, top_places, temperature):
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  try:
 
127
  destinations_data = ", ".join([
128
+ f'{{"name": "{place[0]["Place_Name"]}", "city": "{place[0]["City"]}", "category": "{place[0]["Category"]}", "rating": {place[0]["Rating"]}, "description": "{place[0]["Description"]}"}}'
129
  for place in top_places
130
  ])
131
+ system_prompt = f"""
132
+ You are a tour guide assistant. Present the tourism recommendations to the user in Bahasa Indonesia.
133
+ Only return maximum 5 places that suitable what user wants and provided the data in a clear and concise format. Only return the city that mentioned in \"{user_query}\".
134
+ """
 
 
 
 
 
 
135
  response = client.chat.completions.create(
136
+ model="llama-3.1-8b-instant",
 
137
  messages=[
138
+ {"role": "system", "content": system_prompt},
139
  {"role": "user", "content": f"Berikut adalah rekomendasi berdasarkan data: {destinations_data}"}
140
+ ],
141
+ temperature=temperature
142
  )
 
 
143
  return response.choices[0].message.content
144
  except Exception as e:
145
  print(f"Error generating response: {e}")
146
  return "Maaf, terjadi kesalahan dalam menghasilkan rekomendasi."
147
 
148
+ # Main chatbot function
149
+ def chatbot(user_query, temperature):
150
+ top_places = get_top_10_destinations(user_query)
151
+ if isinstance(top_places, str):
 
 
152
  return top_places
153
+ response = generate_response(user_query, top_places[:5], temperature)
 
 
 
154
  return response
155
 
156
  # Define Gradio Interface
157
  iface = gr.Interface(
158
  fn=chatbot,
159
+ inputs=[
160
+ "text",
161
+ gr.Slider(
162
+ minimum=0,
163
+ maximum=1,
164
+ step=0.1,
165
+ value=0.8,
166
+ label="Temperature"
167
+ )
168
+ ],
169
+ outputs="text",
170
  title="Tourism Recommendation Chatbot",
171
  description="Masukkan pertanyaan wisata Anda dan dapatkan rekomendasi tempat terbaik!"
172
  )