import inspect
import json
import re
import sqlite3

import pandas as pd
import streamlit as st

# Optional response cache used by generate_sql_query_v2. Assumption: a
# Redis-like client (exposing .get() and .setex()) is configured elsewhere;
# None disables caching, so the use_cache=False paths work regardless.
cache = None


def log_groq_token_usage(response, prompt=None, function_name=None, filename="efficiency_log.txt"):
    """Append token-usage stats for one Groq chat completion to a local log file."""
    usage = response.usage
    log_message = (
        f"Function: {function_name or 'unknown'}\n"
        f"Prompt tokens: {usage.prompt_tokens}\n"
        f"Completion tokens: {usage.completion_tokens}\n"
        f"Total tokens: {usage.total_tokens}\n"
        f"Prompt: {prompt}\n"
        "---\n"
    )
    with open(filename, "a", encoding="utf-8") as f:
        f.write(log_message)

# --- Database Execution ---
def execute_transaction(sql_statements):
    """Run a list of SQL statements atomically: commit all of them, or roll back on any failure."""
    txn_conn = None
    try:
        txn_conn = sqlite3.connect("db/restaurant_reservation.db")
        cursor = txn_conn.cursor()
        for stmt in sql_statements:
            cursor.execute(stmt)
        txn_conn.commit()
        return "βœ… Booking Executed"
    except Exception as e:
        if txn_conn:
            txn_conn.rollback()
        return f"❌ Booking failed: {e}"
    finally:
        if txn_conn:
            txn_conn.close()
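
# Usage sketch (hedged): the statements below assume `reservations` and `slots`
# tables with these columns, which may not match the real schema. Kept as
# comments so importing this module has no side effects.
#
#   booking_stmts = [
#       "INSERT INTO reservations (restaurant_id, party_size, hour) VALUES (1, 4, '19:00')",
#       "UPDATE slots SET capacity = capacity - 4 WHERE restaurant_id = 1 AND hour = '19:00'",
#   ]
#   print(execute_transaction(booking_stmts))  # "βœ… Booking Executed" or "❌ Booking failed: ..."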


def execute_query(sql_query, db_path="db/restaurant_reservation.db"):
    """Run a read query and return the rows as a DataFrame, or an error string on failure."""
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        cursor.execute(sql_query)
        rows = cursor.fetchall()
        columns = [desc[0] for desc in cursor.description] if cursor.description else []
        return pd.DataFrame(rows, columns=columns)
    except Exception as e:
        return f"❌ Error executing query: {e}"
    finally:
        if conn:
            conn.close()
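
# Usage sketch (hedged — the table and column names are assumptions, not verified schema):
#
#   df = execute_query("SELECT name, cuisine FROM restaurants LIMIT 5")
#   if isinstance(df, pd.DataFrame):   # errors come back as plain strings
#       print(df.to_string(index=False))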

def generate_sql_query_v2(user_input, SCHEMA_DESCRIPTIONS, history_prompt, vector_db, client, use_cache=False):
    """Generate SQL from natural language, prompting with only the schema elements relevant to the request."""
    # Get relevant schema elements
    relevant_tables = vector_db.get_relevant_schema(user_input)
    schema_prompt = "\n".join(f"Table {table}:\n{SCHEMA_DESCRIPTIONS[table]}" for table in relevant_tables)
    # Cache check (guarded: `cache` may be None when no backend is configured)
    cache_key = f"query:{user_input[:50]}"
    if use_cache and cache is not None and (cached := cache.get(cache_key)):
        return cached.decode()
    # Generate SQL with Groq
    prompt = f"""Based on these tables:

{schema_prompt}

Previous assistant reply:

{history_prompt}

Convert this request to SQL: {user_input}



Only return the SQL query, nothing else."""
    response = client.chat.completions.create(
        model="llama3-8b-8192",
        messages=[
            {"role": "system", "content": "You are a helpful assistant that only returns SQL queries."},
            {"role": "user", "content": prompt}
        ],
        temperature=0.3,
        max_tokens=200
    )
    log_groq_token_usage(response, prompt, function_name=inspect.currentframe().f_code.co_name)
    sql = response.choices[0].message.content.strip()
    if use_cache and cache is not None:
        cache.setex(cache_key, 3600, sql)  # cache the generated SQL for an hour
    return sql
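
# Usage sketch (hedged — SCHEMA_DESCRIPTIONS maps table names to text blurbs and
# vector_db.get_relevant_schema() returns a subset of its keys, as used above):
#
#   sql = generate_sql_query_v2("cheap italian near the river", SCHEMA_DESCRIPTIONS,
#                               history_prompt="", vector_db=vector_db, client=client)
#   df = execute_query(sql)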

def interpret_result_v2(result, user_query, sql_query, client):
    """Summarize a query result DataFrame for the user; error strings pass through unchanged."""
    if isinstance(result, str):
        return result
    try:
        # Compress to commonly useful columns when present, to shrink the prompt
        cols = [c for c in result.columns if c in [
            'name', 'cuisine', 'location', 'seating_capacity', 'rating', 'address',
            'contact', 'price_range', 'special_features', 'capacity', 'date', 'hour']]
        compressed = result[cols] if cols else result
        json_data = compressed.to_json(orient='records', indent=2)
        # Summarize with Groq
        prompt = f"""User query: {user_query}

SQL query: {sql_query}

Result data (JSON): {json_data}



Summarize the results for the user."""
        response = client.chat.completions.create(
            model="llama3-8b-8192",
            messages=[
                {"role": "system", "content": "Summarize database query results for a restaurant reservation assistant."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.3,
            max_tokens=300
        )
        log_groq_token_usage(response, prompt, function_name=inspect.currentframe().f_code.co_name)
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"Error interpreting results: {e}"
    
def handle_query(user_input, vector_db, client, schema_descriptions=None, history_prompt=""):
    """Answer a user request via semantic search, falling back to SQL generation when nothing matches."""
    try:
        # First try semantic search
        semantic_results = {}
        
        # Search across all collections
        restaurant_results = vector_db.semantic_search(user_input, "restaurants")
        table_results = vector_db.semantic_search(user_input, "tables")
        slot_results = vector_db.semantic_search(user_input, "slots")
        
        if any([restaurant_results, table_results, slot_results]):
            semantic_results = {
                "restaurants": restaurant_results,
                "tables": table_results,
                "slots": slot_results
            }
            
            # Format semantic results
            summary = []
            for category, items in semantic_results.items():
                if items:
                    summary.append(f"Found {len(items)} relevant {category}:")
                    summary.extend([f"- {item['name']}" if 'name' in item else f"- {item}" 
                                  for item in items[:3]])
            
            return "\n".join(summary)
        else:
            # Fall back to SQL generation, threading the schema descriptions
            # and conversation history through to the generator
            sql = generate_sql_query_v2(user_input, schema_descriptions or {}, history_prompt, vector_db, client)
            result = execute_query(sql)
            return interpret_result_v2(result, user_input, sql, client)
            
    except Exception as e:
        return f"Error: {e}"


def is_large_output_request(query):
    query = query.lower()
    # Single words and multi-word triggers. A multi-word trigger matches when all
    # of its words appear somewhere in the query, not necessarily adjacent.
    triggers = [
        ['all'], ['every'], ['entire'], ['complete'], ['full'], ['each'],
        ['list'], ['show'], ['display'], ['give', 'me'], ['get'],
        ['every', 'single'], ['each', 'and', 'every'],
        ['whole'], ['total'], ['collection'], ['set'],
        ['no', 'filters'], ['without', 'filters'],
        ['everything'], ['entirety'],
        ['comprehensive'], ['exhaustive'], ['record'],
        ['don\'t', 'filter'], ['without', 'limitations']
    ]
    query_words = query.split()
    for trigger in triggers:
        if all(word in query_words for word in trigger):
            return True
    return False
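
# Behavior sketch (illustrative calls, not from a test suite):
#
#   is_large_output_request("show all restaurants")         -> True   ('all' present)
#   is_large_output_request("give me every single slot")    -> True   ('give' and 'me' both present)
#   is_large_output_request("book a table for two at 7pm")  -> False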


def generate_reservation_conversation(user_query, history_prompt, sql_summary, user_data, generate_reservation_conversation_prompt, client):
    """Compose a conversational reply from the SQL summary, stored user data, and a trimmed history snippet."""
    words = history_prompt.split() if history_prompt else []
    if len(words) > 25:
        history_prompt_snippet = " ".join(words[:15]) + " ... " + " ".join(words[-10:])
    else:
        history_prompt_snippet = " ".join(words)

    # Serialize user_data as pretty JSON for readability in prompt
    user_data_json = json.dumps(user_data, indent=2)

    prompt = generate_reservation_conversation_prompt.format(
        user_query=user_query,
        user_data=user_data_json,
        sql_summary=sql_summary,
        history_prompt_snippet=history_prompt_snippet
    )

    response = client.chat.completions.create(
        model="llama3-8b-8192",
        messages=[
            {"role": "system", "content": "You are a helpful restaurant reservation assistant."},
            {"role": "user", "content": prompt}
        ],
        temperature=0.4
    )

    log_groq_token_usage(response, prompt, function_name=inspect.currentframe().f_code.co_name)
    if not response.choices:
        return "Sorry, I couldn't generate a response right now."

    return response.choices[0].message.content.strip()

 
# --- Helper Functions ---

def determine_intent(user_input, determine_intent_prompt, client):
    """Classify the user's message as SELECT, STORE, BOOK, GREET, or RUBBISH."""
    prompt = determine_intent_prompt.format(user_input=user_input)
    response = client.chat.completions.create(
        model="llama3-8b-8192",
        messages=[
            {"role": "system", "content": "Classify user intent into SELECT, STORE, BOOK, GREET, or RUBBISH based on message content."},
            {"role": "user", "content": prompt}
        ],
        temperature=0
    )
    log_groq_token_usage(response, prompt, function_name=inspect.currentframe().f_code.co_name)
    return response.choices[0].message.content.strip().upper()
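
# Usage sketch (hedged — `determine_intent_prompt` is a template defined elsewhere
# in the app; the .upper() above normalizes the label for comparison):
#
#   intent = determine_intent("book a table for 4 tonight", determine_intent_prompt, client)
#   if intent == "BOOK":
#       ...  # collect party size, time, restaurant, then build the transaction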



def store_user_info(user_input, history_prompt, store_user_info_prompt, client):
    """Extract or update the user's booking info as a JSON object via the LLM."""
    # history_prompt is currently unused here; kept in the signature for existing call sites
    previous_info = json.dumps(st.session_state.user_data)
    prompt = store_user_info_prompt.format(previous_info=previous_info, user_input=user_input)
    response = client.chat.completions.create(
        model="llama3-8b-8192",
        messages=[{"role": "system", "content": "Extract or update user booking info in JSON."},
                  {"role": "user", "content": prompt}],
        temperature=0.3
    )
    log_groq_token_usage(response,prompt, function_name=inspect.currentframe().f_code.co_name)

    try:
        raw_output = response.choices[0].message.content
        # Extract a JSON object from anywhere in the response. The match is
        # greedy (first '{' to last '}') so nested objects are captured whole;
        # a non-greedy match would stop at the first '}' and truncate them.
        json_match = re.search(r'{[\s\S]*}', raw_output)
        if not json_match:
            return None
        # Safely parse using json.loads
        return json.loads(json_match.group())
    except Exception as e:
        st.error(f"⚠️ Failed to parse JSON: {e}")
        return {}
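
# Parsing sketch: illustrative raw model output and the extracted object
# (invented example, not a recorded response):
#
#   raw = 'Here you go: {"name": "Asha", "party_size": 4, "time": "19:00"} Anything else?'
#   re.search(r'{[\s\S]*}', raw).group()
#   -> '{"name": "Asha", "party_size": 4, "time": "19:00"}'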
    
def generate_sql_query(user_input, restaurant_name, party_size, time, history_prompt, schema_prompt, client):
    """Generate a SQL query from the user's request and a trimmed history snippet."""
    # restaurant_name, party_size, and time are unused here; kept in the
    # signature so existing call sites keep working.
    words = history_prompt.split()
    if len(words) > 25:
        history_prompt_snippet = " ".join(words[:15]) + " ... " + " ".join(words[-10:])
    else:
        history_prompt_snippet = " ".join(words)
    prompt = schema_prompt.format(
        history_prompt=history_prompt_snippet,  # use the trimmed snippet, not the full history
        user_input=user_input
    )

    response = client.chat.completions.create(
        model="llama3-8b-8192",
        messages=[
            {"role": "system", "content": "You are a helpful assistant that only returns SQL queries."},
            {"role": "user", "content": prompt}
        ],
        temperature=0.3
    )
    log_groq_token_usage(response, prompt, function_name=inspect.currentframe().f_code.co_name)
    raw_sql = response.choices[0].message.content.strip()
    # Pull the first SELECT statement out of possibly chatty model output
    extracted_sql = re.findall(r"(SELECT[\s\S]+?)(?:;|$)", raw_sql, re.IGNORECASE)
    sql_query = extracted_sql[0].strip() + ";" if extracted_sql else raw_sql
    return sql_query
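
# Extraction sketch: illustrative input/output for the regex above (invented,
# not a recorded model response):
#
#   raw = "Sure! Here is your query:\nSELECT * FROM restaurants WHERE cuisine = 'thai';\nHope that helps."
#   re.findall(r"(SELECT[\s\S]+?)(?:;|$)", raw, re.IGNORECASE)[0].strip() + ";"
#   -> "SELECT * FROM restaurants WHERE cuisine = 'thai';"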
    
def interpret_sql_result(user_query, sql_query, result, interpret_sql_result_prompt, client):
    """Summarize SQL results (DataFrame or raw string) in natural language for the user."""
    if isinstance(result, pd.DataFrame):
        # Convert DataFrame to list of dicts
        result_dict = result.to_dict(orient="records")
    else:
        # Fall back to raw string if not a DataFrame
        result_dict = result

    prompt = interpret_sql_result_prompt.format(
        user_query=user_query,
        sql_query=sql_query,
        result_str=json.dumps(result_dict, indent=2)  # Pass as formatted JSON string
    )
    
    response = client.chat.completions.create(
        model="llama3-8b-8192",
        messages=[
            {"role": "system", "content": "You summarize database query results for a restaurant reservation assistant."},
            {"role": "user", "content": prompt}
        ],
        temperature=0.3
    )
    log_groq_token_usage(response, prompt, function_name=inspect.currentframe().f_code.co_name)
    return response.choices[0].message.content.strip()
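
# End-to-end sketch (hedged): a minimal smoke test assuming a Groq client; the
# prompt templates (determine_intent_prompt, schema_prompt, ...) are built
# elsewhere in the app. Kept as comments so importing this module stays
# side-effect free.
#
#   from groq import Groq
#   client = Groq(api_key="...")
#   df = execute_query("SELECT name FROM restaurants LIMIT 3")
#   print(df if isinstance(df, pd.DataFrame) else f"query failed: {df}")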