// Author: fullstuckdev
// Commit message: Update src/pages/api/sql.ts
// Commit: 75db0a2 (verified)
import { HfInference } from "@huggingface/inference";
import { NextApiRequest, NextApiResponse } from "next";
import { createConnection, executeQuery } from "@/utils/database";
/** A single result row; column values are not known statically. */
type ResultRow = Record<string, unknown>;

/** Chart payload returned to the client alongside the raw rows. */
interface Visualization {
  type: string;
  config: {
    labels: string[];
    datasets: Array<Record<string, unknown>>;
  };
}

/** Column names of the first result row, grouped by the kind of value they hold. */
interface ColumnGroups {
  numericColumns: string[];
  dateColumns: string[];
  stringColumns: string[];
}

/** Name fragments that make a numeric column a preferred slice value for pie/doughnut charts. */
const PIE_VALUE_HINTS = ["status", "count", "amount", "total"];

/** True when `value` is a string containing at least one non-whitespace character. */
function isNonEmptyString(value: unknown): value is string {
  return typeof value === "string" && value.trim() !== "";
}

/** Extracts a human-readable message from anything that may have been thrown. */
function describeError(error: unknown, fallback = "Unknown error"): string {
  if (typeof error === "string" && error !== "") {
    return error;
  }
  if (typeof error === "object" && error !== null) {
    const { message } = error as { message?: unknown };
    if (typeof message === "string" && message !== "") {
      return message;
    }
  }
  return fallback;
}

/** Builds the instruction prompt sent to the model for the given user request. */
function buildPrompt(userPrompt: string): string {
  // The template lines are intentionally flush-left: leading whitespace would become part of the prompt.
  return `You are a SQL expert. Convert the following text to a SQL query.
Rules:
- Return a JSON object with exactly this format: {"query": "YOUR SQL QUERY HERE", "chartType": "CHART TYPE HERE"}
- For chartType use one of: "bar", "pie", "line", "doughnut", or null
- The query should be safe and only return the requested data
- Keep table names exactly as provided
- Do not include any explanations or comments
Example input: "Show me sales data as a pie chart"
Example output: {"query": "SELECT * FROM sales LIMIT 10", "chartType": "pie"}
Text: ${userPrompt}`;
}

/**
 * Pulls the SQL statement and requested chart type out of the model's reply.
 *
 * A reply containing a `{...}` span is parsed as JSON; otherwise the reply is
 * treated as bare SQL after removing code fences and a leading "SQL query:" label.
 *
 * @throws Error when the JSON is malformed or no non-empty query can be found.
 */
function extractSqlFromReply(rawContent: string): { sqlQuery: string; chartType: unknown } {
  const content = rawContent.trim();
  const jsonMatch = content.match(/\{[\s\S]*\}/);

  let sqlQuery = "";
  let chartType: unknown = null;

  if (jsonMatch) {
    const parsed: unknown = JSON.parse(jsonMatch[0]);
    if (typeof parsed === "object" && parsed !== null) {
      const fields = parsed as { query?: unknown; chartType?: unknown };
      if (typeof fields.query === "string") {
        sqlQuery = fields.query.trim();
      }
      chartType = fields.chartType ?? null;
    }
  } else {
    // Only strip a label at the very start of the reply; removing "query"
    // everywhere would corrupt identifiers such as `query_log`.
    sqlQuery = content
      .replace(/```sql/gi, "")
      .replace(/```/g, "")
      .trim()
      .replace(/^(?:sql\s+)?query\b:?\s*/i, "")
      .trim();
  }

  if (!sqlQuery) {
    throw new Error("No valid SQL query found in response");
  }
  return { sqlQuery, chartType };
}

/** Classifies the columns of a sample row by the runtime type of its values and by column name. */
function groupColumns(sampleRow: ResultRow): ColumnGroups {
  const columns = Object.keys(sampleRow);
  return {
    // Columns whose name contains "id" are treated as identifiers, not measures.
    numericColumns: columns.filter(
      (col) => typeof sampleRow[col] === "number" && !col.toLowerCase().includes("id")
    ),
    dateColumns: columns.filter((col) => sampleRow[col] instanceof Date),
    stringColumns: columns.filter((col) => {
      const lowered = col.toLowerCase();
      return (
        typeof sampleRow[col] === "string" ||
        lowered.includes("name") ||
        lowered.includes("title")
      );
    }),
  };
}

/** Returns a random-hue colour string; with `alpha` it is emitted in `hsla` form. */
function randomColor(alpha?: number): string {
  const hue = Math.random() * 360;
  return alpha === undefined
    ? `hsl(${hue}, 70%, 50%)`
    : `hsla(${hue}, 70%, 50%, ${alpha})`;
}

/** One label per row: the value of `labelColumn` when given, otherwise "<prefix> <1-based index>". */
function rowLabels(rows: ResultRow[], labelColumn: string | undefined, prefix: string): string[] {
  return rows.map((row, idx) =>
    labelColumn !== undefined ? String(row[labelColumn]) : `${prefix} ${idx + 1}`
  );
}

/**
 * Derives a chart description from the query results.
 *
 * @returns `null` when no chart type was requested, the result set is empty or
 * not an array of objects, or the requested chart has no usable columns.
 * Any chart type other than pie/doughnut/line is rendered as a bar chart.
 */
function buildVisualization(results: unknown, chartType: unknown): Visualization | null {
  if (!chartType || !Array.isArray(results) || results.length === 0) {
    return null;
  }
  const rows: ResultRow[] = results;
  const sampleRow = rows[0];
  if (typeof sampleRow !== "object" || sampleRow === null) {
    return null;
  }

  const groups = groupColumns(sampleRow);
  const labelColumn = groups.stringColumns[0];
  const columnValues = (column: string) => rows.map((row) => row[column]);

  if (chartType === "pie" || chartType === "doughnut") {
    const preferred = groups.numericColumns.filter((col) =>
      PIE_VALUE_HINTS.some((hint) => col.toLowerCase().includes(hint))
    );
    const valueColumn = (preferred.length > 0 ? preferred : groups.numericColumns)[0];
    if (valueColumn === undefined) {
      return null;
    }
    return {
      type: chartType === "pie" ? "pie" : "doughnut",
      config: {
        labels: rowLabels(rows, labelColumn, "Row"),
        datasets: [
          {
            data: columnValues(valueColumn),
            backgroundColor: rows.map(() => randomColor(0.6)),
          },
        ],
      },
    };
  }

  if (chartType === "line") {
    const dateColumn = groups.dateColumns[0];
    if (dateColumn === undefined && groups.numericColumns.length === 0) {
      return null;
    }
    return {
      type: "line",
      config: {
        labels:
          dateColumn !== undefined
            ? rows.map((row) =>
                // The column was selected because the sample row holds a Date there.
                new Date(row[dateColumn] as string | number | Date).toLocaleDateString()
              )
            : rowLabels(rows, undefined, "Point"),
        datasets: groups.numericColumns.map((col) => ({
          label: col,
          data: columnValues(col),
          borderColor: randomColor(),
          tension: 0.1,
        })),
      },
    };
  }

  // "bar" and every unrecognised chart type.
  return {
    type: "bar",
    config: {
      labels: rowLabels(rows, labelColumn, "Row"),
      datasets: groups.numericColumns.map((col) => ({
        label: col,
        data: columnValues(col),
        backgroundColor: randomColor(0.6),
        borderColor: randomColor(),
        borderWidth: 1,
      })),
    },
  };
}

/** Closes the connection, logging (not propagating) a failure so it cannot mask the query outcome. */
async function closeConnection(connection: { end(): unknown } | null | undefined): Promise<void> {
  if (!connection) {
    return;
  }
  try {
    await connection.end();
  } catch (error: unknown) {
    console.error("Failed to close database connection:", error);
  }
}

/**
 * POST endpoint that turns a natural-language prompt into SQL via a hosted
 * model, runs it against the caller-supplied database and returns the rows
 * together with an optional chart description.
 *
 * Request body: `{ dbUri: string, userPrompt: string }`.
 *
 * Responses:
 * - 405 for any method other than POST.
 * - 400 when `dbUri` or `userPrompt` is missing, blank or not a string.
 * - 500 with `message`/`details` for configuration, model, parsing,
 *   connection or query failures.
 * - 200 with `{ results, query, visualization }` on success.
 */
export default async function handler(
  req: NextApiRequest,
  res: NextApiResponse
) {
  if (req.method !== "POST") {
    return res.status(405).json({ message: "Method not allowed" });
  }

  // `req.body` may be absent (e.g. no parsed payload); treat that as missing fields.
  const { dbUri, userPrompt } = req.body ?? {};
  const hasDbUri = isNonEmptyString(dbUri);
  const hasPrompt = isNonEmptyString(userPrompt);
  if (!hasDbUri || !hasPrompt) {
    return res.status(400).json({
      message: "Missing required fields",
      details: {
        dbUri: hasDbUri ? null : "Database URI is required",
        userPrompt: hasPrompt ? null : "Query prompt is required"
      }
    });
  }

  try {
    const apiKey = process.env.API_TOKEN;
    if (!apiKey) {
      return res.status(500).json({
        message: "Server configuration error",
        details: "API key is not configured"
      });
    }

    let hf;
    try {
      hf = new HfInference(apiKey);
    } catch (error: unknown) {
      return res.status(500).json({
        message: "Failed to initialize AI model",
        details: describeError(error)
      });
    }

    let response;
    try {
      response = await hf.chatCompletion({
        model: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
        messages: [{ role: "user", content: buildPrompt(userPrompt) }],
        temperature: 0.1,
        max_tokens: 500,
      });
      console.log(response);
    } catch (error: unknown) {
      console.error(error);
      return res.status(500).json({
        message: "AI model error",
        details: describeError(error, "Failed to generate SQL query")
      });
    }

    let sqlQuery: string;
    let requestedChartType: unknown;
    try {
      const replyContent: unknown = response?.choices?.[0]?.message?.content;
      const extracted = extractSqlFromReply(
        typeof replyContent === "string" ? replyContent : ""
      );
      sqlQuery = extracted.sqlQuery;
      requestedChartType = extracted.chartType;
    } catch (error: unknown) {
      return res.status(500).json({
        message: "Failed to parse AI response",
        details: describeError(error)
      });
    }

    // NOTE(review): the SQL below is model-generated from untrusted user text and
    // is executed verbatim against a caller-supplied database. Nothing here
    // prevents destructive statements; confirm that `executeQuery` or the
    // database credentials enforce read-only access.
    let connection;
    try {
      connection = await createConnection(dbUri);
    } catch (error: unknown) {
      return res.status(500).json({
        message: "Database connection error",
        details: describeError(error)
      });
    }

    let results: unknown;
    try {
      results = await executeQuery(connection, sqlQuery);
    } catch (error: unknown) {
      return res.status(500).json({
        message: "Query execution error",
        details: describeError(error)
      });
    } finally {
      // Release the connection exactly once, on success and failure alike.
      await closeConnection(connection);
    }

    return res.status(200).json({
      results,
      query: sqlQuery,
      visualization: buildVisualization(results, requestedChartType)
    });
  } catch (error: unknown) {
    console.error('Unexpected API Error:', error);
    return res.status(500).json({
      message: "Unexpected error occurred",
      details: describeError(error)
    });
  }
}