// Hugging Face Spaces — status: Sleeping
import { HfInference } from "@huggingface/inference"; | |
import { NextApiRequest, NextApiResponse } from "next"; | |
import { createConnection, executeQuery } from "@/utils/database"; | |
/**
 * Next.js API route: converts a natural-language request into SQL with a
 * Hugging Face chat model, runs that SQL against the caller-supplied database,
 * and returns the rows together with an optional chart configuration.
 *
 * Request body (POST): `{ dbUri: string; userPrompt: string }`.
 *
 * Responses:
 * - 200 `{ results, query, visualization }` — `visualization` is `null` when no
 *   chart type was requested or the rows cannot be charted.
 * - 400 missing/blank fields, or the generated SQL is not a single read-only
 *   (SELECT / WITH) statement.
 * - 405 any method other than POST.
 * - 500 `{ message, details }` for configuration, model, parsing, connection or
 *   execution failures.
 */
export default async function handler(
  req: NextApiRequest,
  res: NextApiResponse
) {
  if (req.method !== "POST") {
    return res.status(405).json({ message: "Method not allowed" });
  }

  // `req.body` can be undefined (e.g. no parsable body); don't let the
  // destructuring throw before any error handling is in place.
  const { dbUri, userPrompt } = (req.body ?? {}) as {
    dbUri?: unknown;
    userPrompt?: unknown;
  };
  if (!isNonEmptyString(dbUri) || !isNonEmptyString(userPrompt)) {
    return res.status(400).json({
      message: "Missing required fields",
      details: {
        dbUri: isNonEmptyString(dbUri) ? null : "Database URI is required",
        userPrompt: isNonEmptyString(userPrompt)
          ? null
          : "Query prompt is required",
      },
    });
  }

  try {
    const apiKey = process.env.HUGGINGFACE_API_KEY;
    if (!apiKey) {
      return res.status(500).json({
        message: "Server configuration error",
        details: "API key is not configured",
      });
    }

    let hf: HfInference;
    try {
      hf = new HfInference(apiKey);
    } catch (error) {
      return res.status(500).json({
        message: "Failed to initialize AI model",
        details: errorMessage(error),
      });
    }

    let rawContent: unknown;
    try {
      const response = await hf.chatCompletion({
        model: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
        messages: [{ role: "user", content: buildPrompt(userPrompt) }],
        temperature: 0.1,
        max_tokens: 500,
      });
      rawContent = response?.choices?.[0]?.message?.content;
    } catch (error) {
      console.error("AI model error:", error);
      return res.status(500).json({
        message: "AI model error",
        details: errorMessage(error, "Failed to generate SQL query"),
      });
    }

    let parsed: ParsedModelOutput;
    try {
      parsed = parseModelOutput(rawContent);
    } catch (error) {
      return res.status(500).json({
        message: "Failed to parse AI response",
        details: errorMessage(error),
      });
    }
    const { query: sqlQuery, chartType: requestedChartType } = parsed;

    // The prompt only *asks* for safe SQL; the model's output is untrusted
    // because the user prompt can steer it. Refuse anything that is not a
    // single read-only statement before it reaches the database.
    const rejection = findUnsafeSqlReason(sqlQuery);
    if (rejection) {
      return res.status(400).json({
        message: "Generated query rejected",
        details: rejection,
      });
    }

    let connection: Awaited<ReturnType<typeof createConnection>>;
    try {
      connection = await createConnection(dbUri);
    } catch (error) {
      return res.status(500).json({
        message: "Database connection error",
        details: errorMessage(error),
      });
    }

    try {
      const results = await executeQuery(connection, sqlQuery);
      const visualization = buildVisualization(results, requestedChartType);
      return res.status(200).json({
        results,
        query: sqlQuery,
        visualization,
      });
    } catch (error) {
      return res.status(500).json({
        message: "Query execution error",
        details: errorMessage(error),
      });
    } finally {
      // Release the connection exactly once. A failure to close must not
      // replace the response that was already chosen above.
      try {
        await connection.end();
      } catch (closeError) {
        console.error("Failed to close database connection:", closeError);
      }
    }
  } catch (error) {
    console.error("Unexpected API Error:", error);
    return res.status(500).json({
      message: "Unexpected error occurred",
      details: errorMessage(error),
    });
  }
}

/** A result row from `executeQuery`; presumably column name -> value (TODO confirm against the driver). */
type Row = Record<string, unknown>;

/** SQL and chart preference extracted from the model's reply. */
interface ParsedModelOutput {
  /** Trimmed, non-empty SQL text. */
  query: string;
  /** Lower-cased chart type requested by the model, or `null` for none. */
  chartType: string | null;
}

/**
 * Chart payload (labels + datasets) returned to the client. Field names follow
 * Chart.js conventions; presumably rendered with Chart.js on the frontend.
 */
interface Visualization {
  type: string;
  config: {
    labels: string[];
    datasets: Array<Record<string, unknown>>;
  };
}

/** After quoted text is blanked, a statement must start with one of these. */
const READ_ONLY_PREFIX = /^(?:select|with)\b/i;

/** Keywords indicating a write, DDL, permission change or procedure call. */
const WRITE_KEYWORDS =
  /\b(?:insert|update|delete|drop|alter|create|truncate|grant|revoke|merge|call|exec|execute|attach|detach)\b/i;

/** Single-quoted literals and double-quoted / backtick-quoted identifiers. */
const QUOTED_TEXT = /'(?:[^']|'')*'|"(?:[^"]|"")*"|`[^`]*`/g;

/** Column-name hints that make a numeric column the preferred pie/doughnut value. */
const PIE_VALUE_HINTS = ["status", "count", "amount", "total"];

/** True for strings containing at least one non-whitespace character. */
function isNonEmptyString(value: unknown): value is string {
  return typeof value === "string" && value.trim() !== "";
}

/** Extracts a client-facing message from an arbitrary thrown value. */
function errorMessage(error: unknown, fallback = "Unknown error"): string {
  if (error instanceof Error && error.message) {
    return error.message;
  }
  if (typeof error === "string" && error) {
    return error;
  }
  return fallback;
}

/** Builds the instruction prompt sent to the model for `userPrompt`. */
function buildPrompt(userPrompt: string): string {
  return `You are a SQL expert. Convert the following text to a SQL query.
Rules:
- Return a JSON object with exactly this format: {"query": "YOUR SQL QUERY HERE", "chartType": "CHART TYPE HERE"}
- For chartType use one of: "bar", "pie", "line", "doughnut", or null
- The query should be safe and only return the requested data
- Keep table names exactly as provided
- Do not include any explanations or comments
Example input: "Show me sales data as a pie chart"
Example output: {"query": "SELECT * FROM sales LIMIT 10", "chartType": "pie"}
Text: ${userPrompt}`;
}

/**
 * Extracts `{ query, chartType }` from the model reply. It prefers the first
 * `{...}` JSON object. Otherwise it treats the reply as bare SQL and strips
 * Markdown fences and a leading "SQL query:" / "Query:" label.
 *
 * @throws Error when the JSON is malformed or no SQL text remains.
 */
function parseModelOutput(rawContent: unknown): ParsedModelOutput {
  const content = typeof rawContent === "string" ? rawContent.trim() : "";
  let query = "";
  let chartType: string | null = null;

  const jsonMatch = content.match(/\{.*\}/s);
  if (jsonMatch) {
    const parsed: unknown = JSON.parse(jsonMatch[0]);
    if (typeof parsed === "object" && parsed !== null) {
      const record = parsed as Record<string, unknown>;
      if (typeof record.query === "string") {
        query = record.query.trim();
      }
      if (typeof record.chartType === "string" && record.chartType.trim()) {
        chartType = record.chartType.trim().toLowerCase();
      }
    }
  } else {
    query = content
      .replace(/```sql/gi, "")
      .replace(/```/g, "")
      // Only drop a *leading* label. A global replace would also delete
      // "query" inside identifiers such as `query_log`.
      .replace(/^\s*(?:sql\s+)?query\s*:?\s*/i, "")
      .trim();
  }

  if (!query) {
    throw new Error("No valid SQL query found in response");
  }
  return { query, chartType };
}

/**
 * Best-effort check that `sql` is a single read-only statement. Returns a
 * human-readable reason when it is not, or `null` when it passes. This is
 * defense in depth only; connecting with a read-only database role is the
 * real guarantee.
 */
function findUnsafeSqlReason(sql: string): string | null {
  // Blank out quoted text so keywords inside literals/identifiers don't count,
  // and allow a single trailing semicolon.
  const statement = sql.replace(QUOTED_TEXT, "''").trim().replace(/;\s*$/, "");
  if (!READ_ONLY_PREFIX.test(statement)) {
    return "Only SELECT queries are allowed";
  }
  if (statement.includes(";")) {
    return "Multiple SQL statements are not allowed";
  }
  const writeKeyword = WRITE_KEYWORDS.exec(statement);
  if (writeKeyword) {
    return `Disallowed SQL keyword: ${writeKeyword[0]}`;
  }
  return null;
}

/** True for key-like column names: `id`, `*_id`, or camelCase `*Id` / `*ID`. */
function isIdColumn(column: string): boolean {
  const lower = column.toLowerCase();
  return lower === "id" || lower.endsWith("_id") || /[a-z0-9]I[dD]$/.test(column);
}

/**
 * Classifies `columns` using the first non-null value found in each column,
 * so a null in the first row does not hide a numeric or date column.
 */
function analyzeColumns(rows: Row[], columns: string[]) {
  const samples = new Map<string, unknown>(
    columns.map((col): [string, unknown] => [
      col,
      rows.find((row) => row[col] != null)?.[col],
    ])
  );

  const numericColumns = columns.filter(
    (col) => typeof samples.get(col) === "number" && !isIdColumn(col)
  );
  const dateColumns = columns.filter((col) => samples.get(col) instanceof Date);
  const stringColumns = columns.filter((col) => {
    const lower = col.toLowerCase();
    return (
      typeof samples.get(col) === "string" ||
      lower.includes("name") ||
      lower.includes("title")
    );
  });

  return { numericColumns, dateColumns, stringColumns };
}

/** One label per row: the label column's value, or "Row N" (1-based) when there is none. */
function categoryLabels(rows: Row[], labelColumn: string | undefined): string[] {
  return rows.map((row, idx) =>
    labelColumn !== undefined ? String(row[labelColumn]) : `Row ${idx + 1}`
  );
}

/** Formats a date-like cell value with the server's locale date format. */
function formatDateLabel(value: unknown): string {
  const date = value instanceof Date ? value : new Date(value as string | number);
  return date.toLocaleDateString();
}

/** Random semi-transparent fill colour. */
function randomFillColor(): string {
  return `hsla(${Math.random() * 360}, 70%, 50%, 0.6)`;
}

/** Random opaque stroke colour. */
function randomStrokeColor(): string {
  return `hsl(${Math.random() * 360}, 70%, 50%)`;
}

/**
 * Builds a chart payload for `results`. Returns `null` in these cases:
 * - no chart type was requested;
 * - there are no object rows;
 * - a pie/doughnut chart has no numeric value column;
 * - a line chart has neither date nor numeric columns.
 * Unknown chart types fall back to a bar chart.
 */
function buildVisualization(
  results: unknown,
  requestedChartType: string | null
): Visualization | null {
  if (!requestedChartType || !Array.isArray(results) || results.length === 0) {
    return null;
  }
  const firstRow: unknown = results[0];
  if (typeof firstRow !== "object" || firstRow === null) {
    return null;
  }

  const rows = results as Row[];
  const { numericColumns, dateColumns, stringColumns } = analyzeColumns(
    rows,
    Object.keys(firstRow)
  );
  const labelColumn = stringColumns[0];

  switch (requestedChartType) {
    case "pie":
    case "doughnut": {
      const preferred = numericColumns.filter((col) =>
        PIE_VALUE_HINTS.some((hint) => col.toLowerCase().includes(hint))
      );
      const valueColumn = preferred[0] ?? numericColumns[0];
      if (valueColumn === undefined) {
        return null;
      }
      return {
        type: requestedChartType,
        config: {
          labels: categoryLabels(rows, labelColumn),
          datasets: [
            {
              data: rows.map((row) => row[valueColumn]),
              backgroundColor: rows.map(() => randomFillColor()),
            },
          ],
        },
      };
    }
    case "line": {
      if (dateColumns.length === 0 && numericColumns.length === 0) {
        return null;
      }
      const dateColumn = dateColumns[0];
      return {
        type: "line",
        config: {
          labels: rows.map((row, idx) =>
            dateColumn !== undefined
              ? formatDateLabel(row[dateColumn])
              : `Point ${idx + 1}`
          ),
          datasets: numericColumns.map((col) => ({
            label: col,
            data: rows.map((row) => row[col]),
            borderColor: randomStrokeColor(),
            tension: 0.1,
          })),
        },
      };
    }
    case "bar":
    default:
      return {
        type: "bar",
        config: {
          labels: categoryLabels(rows, labelColumn),
          datasets: numericColumns.map((col) => ({
            label: col,
            data: rows.map((row) => row[col]),
            backgroundColor: randomFillColor(),
            borderColor: randomStrokeColor(),
            borderWidth: 1,
          })),
        },
      };
  }
}