Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -9,6 +9,13 @@ import os
|
|
9 |
import sys
|
10 |
from datetime import datetime
|
11 |
import time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
# Enable GPU if available
|
14 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
@@ -31,6 +38,7 @@ GLOBAL_TOKENIZER = None
|
|
31 |
def initialize_model():
|
32 |
"""Initialize model and tokenizer globally"""
|
33 |
global GLOBAL_MODEL, GLOBAL_TOKENIZER
|
|
|
34 |
st.write("Initializing model and tokenizer...")
|
35 |
start_time = time.time()
|
36 |
|
@@ -44,11 +52,12 @@ def initialize_model():
|
|
44 |
# Set model to evaluation mode
|
45 |
GLOBAL_MODEL.eval()
|
46 |
|
47 |
-
|
48 |
|
49 |
def test_db_connection():
|
50 |
"""Test database connection with timeout"""
|
51 |
try:
|
|
|
52 |
connection = mysql.connector.connect(
|
53 |
**DB_CONFIG,
|
54 |
connect_timeout=10
|
@@ -60,33 +69,41 @@ def test_db_connection():
|
|
60 |
db_name = cursor.fetchone()[0]
|
61 |
cursor.close()
|
62 |
connection.close()
|
|
|
63 |
return True, f"Successfully connected to MySQL Server version {db_info}\nDatabase: {db_name}"
|
64 |
except Error as e:
|
|
|
65 |
return False, f"Error connecting to MySQL database: {e}"
|
66 |
return False, "Unable to establish database connection"
|
67 |
|
68 |
def get_db_connection():
|
69 |
"""Get database connection from pool"""
|
|
|
70 |
return mysql.connector.connect(**DB_CONFIG)
|
71 |
|
72 |
def execute_query(query):
|
73 |
"""Execute SQL query with timeout and connection pooling"""
|
|
|
74 |
connection = None
|
75 |
try:
|
76 |
connection = get_db_connection()
|
77 |
cursor = connection.cursor(dictionary=True, buffered=True)
|
78 |
cursor.execute(query)
|
79 |
results = cursor.fetchall()
|
|
|
80 |
return results
|
81 |
except Error as e:
|
|
|
82 |
return f"Error executing query: {e}"
|
83 |
finally:
|
84 |
if connection and connection.is_connected():
|
85 |
cursor.close()
|
86 |
connection.close()
|
|
|
87 |
|
88 |
def generate_sql(natural_language_query):
|
89 |
"""Generate SQL query with performance optimizations"""
|
|
|
90 |
try:
|
91 |
start_time = time.time()
|
92 |
|
@@ -138,18 +155,21 @@ def generate_sql(natural_language_query):
|
|
138 |
generated_query = GLOBAL_TOKENIZER.decode(outputs[0], skip_special_tokens=True)
|
139 |
sql_query = generated_query.split("### SQL Query:")[-1].strip()
|
140 |
|
141 |
-
|
142 |
return sql_query
|
143 |
|
144 |
except Exception as e:
|
|
|
145 |
return f"Error generating SQL query: {str(e)}"
|
146 |
|
147 |
def format_result(query_result):
|
148 |
"""Format query results efficiently"""
|
149 |
if isinstance(query_result, str) and "Error" in query_result:
|
|
|
150 |
return query_result
|
151 |
|
152 |
if not query_result:
|
|
|
153 |
return "No results found."
|
154 |
|
155 |
# Use list comprehension for better performance
|
@@ -177,6 +197,7 @@ def main():
|
|
177 |
st.write(db_message)
|
178 |
|
179 |
if not db_success:
|
|
|
180 |
st.write("Could not connect to the database. Exiting.")
|
181 |
return
|
182 |
|
@@ -202,6 +223,7 @@ def main():
|
|
202 |
st.write("Human-Readable Response:")
|
203 |
st.text(formatted_result)
|
204 |
else:
|
|
|
205 |
st.write("Please enter a query.")
|
206 |
|
207 |
if __name__ == "__main__":
|
|
|
9 |
import sys
|
10 |
from datetime import datetime
|
11 |
import time
|
12 |
+
import logging
|
13 |
+
|
14 |
+
# Set up logging
|
15 |
+
logging.basicConfig(
|
16 |
+
level=logging.INFO,
|
17 |
+
format='%(asctime)s - %(levelname)s - %(message)s',
|
18 |
+
)
|
19 |
|
20 |
# Enable GPU if available
|
21 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
38 |
def initialize_model():
|
39 |
"""Initialize model and tokenizer globally"""
|
40 |
global GLOBAL_MODEL, GLOBAL_TOKENIZER
|
41 |
+
logging.info("Initializing model and tokenizer...")
|
42 |
st.write("Initializing model and tokenizer...")
|
43 |
start_time = time.time()
|
44 |
|
|
|
52 |
# Set model to evaluation mode
|
53 |
GLOBAL_MODEL.eval()
|
54 |
|
55 |
+
logging.info(f"Model initialization took {time.time() - start_time:.2f} seconds")
|
56 |
|
57 |
def test_db_connection():
|
58 |
"""Test database connection with timeout"""
|
59 |
try:
|
60 |
+
logging.info("Testing database connection...")
|
61 |
connection = mysql.connector.connect(
|
62 |
**DB_CONFIG,
|
63 |
connect_timeout=10
|
|
|
69 |
db_name = cursor.fetchone()[0]
|
70 |
cursor.close()
|
71 |
connection.close()
|
72 |
+
logging.info(f"Successfully connected to MySQL Server version {db_info} - Database: {db_name}")
|
73 |
return True, f"Successfully connected to MySQL Server version {db_info}\nDatabase: {db_name}"
|
74 |
except Error as e:
|
75 |
+
logging.error(f"Error connecting to MySQL database: {e}")
|
76 |
return False, f"Error connecting to MySQL database: {e}"
|
77 |
return False, "Unable to establish database connection"
|
78 |
|
79 |
def get_db_connection():
|
80 |
"""Get database connection from pool"""
|
81 |
+
logging.info("Getting database connection from pool...")
|
82 |
return mysql.connector.connect(**DB_CONFIG)
|
83 |
|
84 |
def execute_query(query):
|
85 |
"""Execute SQL query with timeout and connection pooling"""
|
86 |
+
logging.info(f"Executing query: {query}")
|
87 |
connection = None
|
88 |
try:
|
89 |
connection = get_db_connection()
|
90 |
cursor = connection.cursor(dictionary=True, buffered=True)
|
91 |
cursor.execute(query)
|
92 |
results = cursor.fetchall()
|
93 |
+
logging.info(f"Query executed successfully, retrieved {len(results)} records.")
|
94 |
return results
|
95 |
except Error as e:
|
96 |
+
logging.error(f"Error executing query: {e}")
|
97 |
return f"Error executing query: {e}"
|
98 |
finally:
|
99 |
if connection and connection.is_connected():
|
100 |
cursor.close()
|
101 |
connection.close()
|
102 |
+
logging.info("Database connection closed.")
|
103 |
|
104 |
def generate_sql(natural_language_query):
|
105 |
"""Generate SQL query with performance optimizations"""
|
106 |
+
logging.info(f"Generating SQL for query: {natural_language_query}")
|
107 |
try:
|
108 |
start_time = time.time()
|
109 |
|
|
|
155 |
generated_query = GLOBAL_TOKENIZER.decode(outputs[0], skip_special_tokens=True)
|
156 |
sql_query = generated_query.split("### SQL Query:")[-1].strip()
|
157 |
|
158 |
+
logging.info(f"SQL generation took {time.time() - start_time:.2f} seconds")
|
159 |
return sql_query
|
160 |
|
161 |
except Exception as e:
|
162 |
+
logging.error(f"Error generating SQL query: {str(e)}")
|
163 |
return f"Error generating SQL query: {str(e)}"
|
164 |
|
165 |
def format_result(query_result):
|
166 |
"""Format query results efficiently"""
|
167 |
if isinstance(query_result, str) and "Error" in query_result:
|
168 |
+
logging.warning(f"Query result contains an error: {query_result}")
|
169 |
return query_result
|
170 |
|
171 |
if not query_result:
|
172 |
+
logging.info("No results found.")
|
173 |
return "No results found."
|
174 |
|
175 |
# Use list comprehension for better performance
|
|
|
197 |
st.write(db_message)
|
198 |
|
199 |
if not db_success:
|
200 |
+
logging.error("Could not connect to the database. Exiting.")
|
201 |
st.write("Could not connect to the database. Exiting.")
|
202 |
return
|
203 |
|
|
|
223 |
st.write("Human-Readable Response:")
|
224 |
st.text(formatted_result)
|
225 |
else:
|
226 |
+
logging.warning("User did not enter a query.")
|
227 |
st.write("Please enter a query.")
|
228 |
|
229 |
if __name__ == "__main__":
|