nileshhanotia committed on
Commit
5e16d9f
·
verified ·
1 Parent(s): 39485d9

Create llm_service.py

Browse files
Files changed (1) hide show
  1. llm_service.py +184 -0
llm_service.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import logging
3
+ import requests
4
+ from typing import Dict, Any, List, Optional
5
+ from dataclasses import dataclass
6
+
7
@dataclass
class TableInfo:
    """Store table information including schema and relationships.

    Instances are built by ``LLMService._load_database_schema`` (one per
    table) and read back by ``LLMService._prepare_schema_prompt``.
    """

    # Table name as reported by ``sqlite_master``.
    name: str
    # One dict per column with the keys 'name', 'type', 'nullable' and
    # 'primary_key', as filled in by the schema loader.
    columns: List[Dict[str, Any]]
    # One dict per foreign key with the keys 'from_table', 'to_table',
    # 'from_column' and 'to_column'.
    relationships: List[Dict[str, str]]
13
+
14
class LLMService:
    """Turn natural-language questions into SQL with the Groq chat API and run them on SQLite.

    On construction the service introspects the SQLite database at
    ``db_path`` and caches its schema in ``self.table_info``; that schema is
    embedded in the system prompt of every request.
    """

    DEFAULT_API_URL = "https://api.groq.com/openai/v1/chat/completions"
    DEFAULT_MODEL = "llama3-8b-8192"
    DEFAULT_TIMEOUT = 30

    def __init__(
        self,
        api_key: str,
        db_path: str,
        model: str = DEFAULT_MODEL,
        timeout: float = DEFAULT_TIMEOUT,
    ):
        """Configure API access and load the database schema.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            db_path: Path handed to ``sqlite3.connect``.
            model: Chat model name placed in the request payload.
            timeout: Timeout in seconds for the HTTP request.
        """
        self.api_key = api_key
        self.db_path = db_path
        self.model = model
        self.timeout = timeout
        self.api_url = self.DEFAULT_API_URL
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        # Kept for backward compatibility; basicConfig does nothing when the
        # root logger already has handlers.
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)
        self.table_info = self._load_database_schema()

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return ``name`` as a double-quoted SQL identifier with embedded quotes doubled."""
        return '"' + name.replace('"', '""') + '"'

    def _load_database_schema(self) -> "Dict[str, TableInfo]":
        """Load every user table's columns and foreign keys.

        Returns:
            Mapping of table name to ``TableInfo``. An empty dict is returned
            when any SQLite error occurs (the error is logged).
        """
        schema_info = {}
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()

            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            table_names = [row[0] for row in cursor.fetchall()]

            for table_name in table_names:
                # Names starting with "sqlite_" are SQLite's own bookkeeping
                # tables; they are noise for the query generator.
                if table_name.lower().startswith("sqlite_"):
                    continue

                # PRAGMA arguments cannot be bound as parameters, so quote the
                # identifier to survive spaces, quotes and keywords.
                quoted = self._quote_identifier(table_name)

                cursor.execute(f"PRAGMA table_info({quoted});")
                columns = [
                    {
                        'name': col[1],
                        'type': col[2],
                        'nullable': not col[3],
                        'primary_key': bool(col[5]),
                    }
                    for col in cursor.fetchall()
                ]

                cursor.execute(f"PRAGMA foreign_key_list({quoted});")
                relationships = [
                    {
                        'from_table': table_name,
                        'to_table': fk[2],
                        'from_column': fk[3],
                        'to_column': fk[4],
                    }
                    for fk in cursor.fetchall()
                ]

                schema_info[table_name] = TableInfo(
                    name=table_name,
                    columns=columns,
                    relationships=relationships,
                )

            return schema_info

        except sqlite3.Error as e:
            self.logger.error("Database error while loading schema: %s", e)
            return {}
        finally:
            # The connection's context manager only commits/rolls back, so
            # close explicitly to avoid leaking the handle.
            if conn is not None:
                conn.close()

    def _prepare_schema_prompt(self) -> str:
        """Render ``self.table_info`` as the plain-text schema description for the LLM."""
        parts = ["Database Schema:\n\n"]

        for table_name, info in self.table_info.items():
            parts.append(f"Table: {table_name}\n")
            parts.append("Columns:\n")
            for col in info.columns:
                line = f"- {col['name']} ({col['type']})"
                if col['primary_key']:
                    line += " PRIMARY KEY"
                if not col['nullable']:
                    line += " NOT NULL"
                parts.append(line + "\n")

            if info.relationships:
                parts.append("Relationships:\n")
                for rel in info.relationships:
                    parts.append(
                        f"- {rel['from_table']}.{rel['from_column']} -> "
                        f"{rel['to_table']}.{rel['to_column']}\n"
                    )
            parts.append("\n")

        return "".join(parts)

    def convert_to_sql_query(self, natural_query: str) -> Dict[str, Any]:
        """Convert natural language to a SQL query using the Groq API.

        Args:
            natural_query: The user's request in plain language.

        Returns:
            ``{"success": True, "query": <sql>}`` when a SELECT statement was
            extracted from the model's reply, otherwise
            ``{"success": False, "error": <message>}``.
        """
        # Reject blank input before spending an API call on it.
        if not isinstance(natural_query, str) or not natural_query.strip():
            return {"success": False, "error": "Natural language query must be a non-empty string"}

        schema_prompt = self._prepare_schema_prompt()

        system_prompt = (
            "\n"
            "You are an expert SQL query generator. Your task is to convert natural language queries into valid SQL queries based on the provided database schema.\n"
            "\n"
            f"{schema_prompt}\n"
            "\n"
            "Rules for generating SQL queries:\n"
            "1. Use proper JOIN syntax when relating multiple tables\n"
            "2. Consider table relationships and use appropriate JOIN conditions\n"
            "3. Handle NULL values appropriately\n"
            "4. Use table aliases when necessary for clarity\n"
            "5. Return only the requested columns, use * only when specifically asked\n"
            "6. Include WHERE clauses based on the natural language conditions\n"
            "7. Use appropriate aggregation functions when needed (COUNT, SUM, AVG, etc.)\n"
            "\n"
            "Generate a SQL query for the following natural language request:\n"
            f"{natural_query}\n"
            "\n"
            "Return only the SQL query without any explanation.\n"
        )

        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": natural_query},
            ],
            "max_tokens": 500,
            "temperature": 0.1,
        }

        try:
            self.logger.info("Sending request to Groq API for query: %s", natural_query)
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            result = response.json()
        except requests.exceptions.RequestException as e:
            self.logger.error("API request error: %s", e)
            return {"success": False, "error": f"API request failed: {str(e)}"}
        except ValueError as e:
            # Older versions of requests raise a plain ValueError when the
            # body is not valid JSON.
            self.logger.error("API returned invalid JSON: %s", e)
            return {"success": False, "error": f"API request failed: {str(e)}"}

        self.logger.info("Received response from Groq API")

        # A malformed reply must produce a failure result, not an exception.
        try:
            content = result['choices'][0]['message']['content']
        except (KeyError, IndexError, TypeError):
            self.logger.error("Unexpected response structure from Groq API")
            content = None

        sql_query = self._extract_sql_query(content)
        if sql_query:
            return {"success": True, "query": sql_query}

        return {"success": False, "error": "Failed to generate SQL query"}

    def _extract_sql_query(self, content: str) -> Optional[str]:
        """Extract a SELECT statement from the LLM response.

        Args:
            content: Raw message text from the model; non-string values
                (e.g. ``None``) are tolerated.

        Returns:
            The stripped SQL text when it starts with the keyword SELECT,
            otherwise ``None``.
        """
        import re  # local import: the file's import block does not provide it

        if not isinstance(content, str):
            return None

        # Prefer the body of the first closed fenced block so that any prose
        # the model wrapped around it is ignored.
        fenced = re.search(
            r"```[ \t]*(?:sql)?\s*(.*?)```",
            content,
            flags=re.IGNORECASE | re.DOTALL,
        )
        if fenced:
            candidate = fenced.group(1)
        else:
            # No closed fence (e.g. a truncated reply): just drop stray markers.
            candidate = re.sub(r"```(?:sql)?", "", content, flags=re.IGNORECASE)
        candidate = candidate.strip()

        # Basic validation: only statements whose first keyword is SELECT.
        if re.match(r"select\b", candidate, flags=re.IGNORECASE):
            return candidate

        return None

    def execute_query(self, sql_query: str) -> Dict[str, Any]:
        """Execute the SQL query and return results.

        Args:
            sql_query: A single SQL statement.

        Returns:
            ``{"success": True, "columns": [...], "results": [...]}`` on
            success (both lists are empty for statements that return no
            rows), or ``{"success": False, "error": <message>}`` on a
            database error.
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            with conn:  # commit on success, roll back on error
                cursor = conn.cursor()
                cursor.execute(sql_query)
                results = cursor.fetchall()
                # description is None for statements that return no rows.
                columns = [description[0] for description in (cursor.description or ())]

            return {
                "success": True,
                "columns": columns,
                "results": results,
            }

        # sqlite3.Warning is not a subclass of sqlite3.Error; Python < 3.12
        # raises it when more than one statement is passed to execute().
        except (sqlite3.Error, sqlite3.Warning) as e:
            self.logger.error("Database error: %s", e)
            return {
                "success": False,
                "error": str(e),
            }
        finally:
            if conn is not None:
                conn.close()