Spaces:
Sleeping
Sleeping
Create llm_service.py
Browse files- llm_service.py +184 -0
llm_service.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sqlite3
|
2 |
+
import logging
|
3 |
+
import requests
|
4 |
+
from typing import Dict, Any, List, Optional
|
5 |
+
from dataclasses import dataclass
|
6 |
+
|
7 |
+
@dataclass
class TableInfo:
    """Schema snapshot for one database table.

    Bundles the table's name with descriptors for its columns and for
    the foreign keys it declares.
    """

    # Name of the table this record describes.
    name: str
    # One dict per column; LLMService._load_database_schema fills these
    # with the keys 'name', 'type', 'nullable' and 'primary_key'.
    columns: List[Dict[str, Any]]
    # One dict per foreign key; LLMService._load_database_schema fills
    # these with 'from_table', 'to_table', 'from_column' and 'to_column'.
    relationships: List[Dict[str, str]]
|
13 |
+
|
14 |
+
class LLMService:
    """Turn natural-language questions into SQL via the Groq chat API and run them on SQLite.

    On construction the service introspects the SQLite database at
    ``db_path`` and caches its tables, columns and foreign keys in
    ``self.table_info``; that cached schema is embedded in every prompt
    sent to the model.
    """

    def __init__(
        self,
        api_key: str,
        db_path: str,
        model: str = "llama3-8b-8192",
        timeout: float = 30,
    ):
        """Store credentials and settings, configure logging and load the schema.

        Args:
            api_key: Groq API key, sent as a Bearer token.
            db_path: Path of the SQLite database to introspect and query.
            model: Chat model name placed in the request payload.
            timeout: Timeout in seconds passed to ``requests.post``.
        """
        self.api_key = api_key
        self.db_path = db_path
        self.model = model
        self.timeout = timeout
        self.api_url = "https://api.groq.com/openai/v1/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)
        self.table_info = self._load_database_schema()

    @staticmethod
    def _quote_identifier(identifier: str) -> str:
        """Return *identifier* wrapped in double quotes with embedded quotes doubled."""
        return '"' + identifier.replace('"', '""') + '"'

    def _load_database_schema(self) -> Dict[str, "TableInfo"]:
        """Read every table's columns and foreign keys from the database.

        Returns:
            A mapping of table name to ``TableInfo``. An empty dict is
            returned (and the error logged) if any SQLite error occurs.
        """
        schema_info = {}
        try:
            connection = sqlite3.connect(self.db_path)
            try:
                cursor = connection.cursor()

                # Get all tables
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
                table_names = [row[0] for row in cursor.fetchall()]

                for table_name in table_names:
                    # Table names may contain spaces, quotes or keywords,
                    # so they must be quoted before going into a PRAGMA.
                    quoted = self._quote_identifier(table_name)

                    # Row layout: (cid, name, type, notnull, dflt_value, pk)
                    cursor.execute(f"PRAGMA table_info({quoted});")
                    columns = [
                        {
                            'name': col[1],
                            'type': col[2],
                            'nullable': not col[3],
                            'primary_key': bool(col[5]),
                        }
                        for col in cursor.fetchall()
                    ]

                    # Row layout: (id, seq, table, from, to, ...)
                    cursor.execute(f"PRAGMA foreign_key_list({quoted});")
                    relationships = [
                        {
                            'from_table': table_name,
                            'to_table': fk[2],
                            'from_column': fk[3],
                            'to_column': fk[4],
                        }
                        for fk in cursor.fetchall()
                    ]

                    schema_info[table_name] = TableInfo(
                        name=table_name,
                        columns=columns,
                        relationships=relationships,
                    )
            finally:
                # The sqlite3 connection context manager only manages
                # transactions; the connection must be closed explicitly.
                connection.close()

            return schema_info

        except sqlite3.Error as e:
            self.logger.error("Database error while loading schema: %s", e)
            return {}

    def _prepare_schema_prompt(self) -> str:
        """Render ``self.table_info`` as the plain-text schema description for the LLM."""
        parts = ["Database Schema:\n\n"]

        for table_name, info in self.table_info.items():
            parts.append(f"Table: {table_name}\n")
            parts.append("Columns:\n")
            for col in info.columns:
                line = f"- {col['name']} ({col['type']})"
                if col['primary_key']:
                    line += " PRIMARY KEY"
                if not col['nullable']:
                    line += " NOT NULL"
                parts.append(line + "\n")

            if info.relationships:
                parts.append("Relationships:\n")
                parts.extend(
                    f"- {rel['from_table']}.{rel['from_column']} -> {rel['to_table']}.{rel['to_column']}\n"
                    for rel in info.relationships
                )
            # Blank line separates one table from the next.
            parts.append("\n")

        return "".join(parts)

    def convert_to_sql_query(self, natural_query: str) -> Dict[str, Any]:
        """Convert a natural-language request to SQL using the Groq API.

        Args:
            natural_query: The user's request in plain language.

        Returns:
            ``{"success": True, "query": <sql>}`` when the model's reply
            yields a SELECT statement, otherwise
            ``{"success": False, "error": <message>}``.
        """
        schema_prompt = self._prepare_schema_prompt()

        system_prompt = f"""
        You are an expert SQL query generator. Your task is to convert natural language queries into valid SQL queries based on the provided database schema.

        {schema_prompt}

        Rules for generating SQL queries:
        1. Use proper JOIN syntax when relating multiple tables
        2. Consider table relationships and use appropriate JOIN conditions
        3. Handle NULL values appropriately
        4. Use table aliases when necessary for clarity
        5. Return only the requested columns, use * only when specifically asked
        6. Include WHERE clauses based on the natural language conditions
        7. Use appropriate aggregation functions when needed (COUNT, SUM, AVG, etc.)

        Generate a SQL query for the following natural language request:
        {natural_query}

        Return only the SQL query without any explanation.
        """

        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": natural_query},
            ],
            "max_tokens": 500,
            "temperature": 0.1,
        }

        try:
            self.logger.info("Sending request to Groq API for query: %s", natural_query)
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            result = response.json()
        except requests.exceptions.RequestException as e:
            self.logger.error("API request error: %s", e)
            return {"success": False, "error": f"API request failed: {str(e)}"}
        except ValueError as e:
            # A body that is not valid JSON surfaces as a ValueError
            # subclass from response.json().
            self.logger.error("API response was not valid JSON: %s", e)
            return {"success": False, "error": f"Invalid API response: {str(e)}"}

        self.logger.info("Received response from Groq API")

        # A malformed payload must not escape as KeyError/TypeError.
        try:
            content = result['choices'][0]['message']['content']
        except (KeyError, IndexError, TypeError):
            content = None

        sql_query = self._extract_sql_query(content)
        if sql_query:
            return {"success": True, "query": sql_query}

        return {"success": False, "error": "Failed to generate SQL query"}

    def _extract_sql_query(self, content: str) -> Optional[str]:
        """Extract a SELECT statement from the LLM response text.

        Markdown code fences (with or without an ``sql`` tag, in any
        letter case) are removed. Returns ``None`` when *content* is
        missing, is not a string, or does not start with SELECT once
        cleaned.
        """
        import re

        if not isinstance(content, str) or not content:
            return None

        # Remove markdown code fences if present
        cleaned = re.sub(r"```(?:sql)?", "", content, flags=re.IGNORECASE).strip()

        # Basic validation
        if cleaned.upper().startswith("SELECT"):
            return cleaned

        return None

    def execute_query(self, sql_query: str) -> Dict[str, Any]:
        """Execute *sql_query* against the database and return its rows.

        Returns:
            ``{"success": True, "columns": [...], "results": [...]}`` for
            a statement that produces a result set, otherwise
            ``{"success": False, "error": <message>}``. A statement that
            yields no result set is rolled back and reported as an error.
        """
        try:
            connection = sqlite3.connect(self.db_path)
            try:
                with connection:
                    cursor = connection.cursor()
                    cursor.execute(sql_query)

                    if cursor.description is None:
                        # No result set (e.g. DDL/DML): undo any change
                        # and report failure instead of raising TypeError.
                        connection.rollback()
                        message = "Query did not return a result set"
                        self.logger.error("Database error: %s", message)
                        return {"success": False, "error": message}

                    results = cursor.fetchall()
                    columns = [description[0] for description in cursor.description]
            finally:
                connection.close()

            return {
                "success": True,
                "columns": columns,
                "results": results,
            }

        except (sqlite3.Error, sqlite3.Warning) as e:
            # sqlite3.Warning is raised for multi-statement input on
            # Python versions before 3.12 and is not a sqlite3.Error.
            self.logger.error("Database error: %s", e)
            return {
                "success": False,
                "error": str(e),
            }
|