import { HfInference } from "@huggingface/inference";
import type { NextApiRequest, NextApiResponse } from "next";
import { createConnection, executeQuery } from "@/utils/database";

/** Hugging Face model used to translate natural language into SQL. */
const MODEL_ID = "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF";

/** Chart kinds this endpoint can describe. */
type ChartKind = "bar" | "pie" | "line" | "doughnut";

/** One result row keyed by column name (exact shape depends on the database driver). */
type Row = Record<string, unknown>;

interface ChartDataset {
  label?: string;
  data: unknown[];
  backgroundColor?: string | string[];
  borderColor?: string;
  borderWidth?: number;
  tension?: number;
}

/** Chart description returned to the client (labels + datasets). */
interface Visualization {
  type: ChartKind;
  config: { labels: string[]; datasets: ChartDataset[] };
}

interface ColumnAnalysis {
  numericColumns: string[];
  dateColumns: string[];
  stringColumns: string[];
}

interface ParsedModelResponse {
  sqlQuery: string;
  chartType: unknown;
}

/** Greedy match from the first `{` to the last `}`, tolerating prose or code fences around it. */
const JSON_OBJECT_PATTERN = /\{.*\}/s;
/** Statement types accepted as read-only. */
const READ_ONLY_LEADING_KEYWORD = /^(?:select|with|show|describe|desc|explain)\b/i;
/** Keywords indicating writes, DDL, privilege changes, procedure calls or SELECT ... INTO. */
const WRITE_KEYWORD =
  /\b(?:insert|update|delete|drop|alter|create|truncate|grant|revoke|merge|call|exec|execute|copy|into)\b/i;
/** `id`, `user_id`, `id_user`, `customer id` — "id" as a whole token. */
const SNAKE_CASE_ID = /(?:^|[_\s-])id(?:[_\s-]|$)/i;
/** `userId`, `UserID` — camelCase id suffix. */
const CAMEL_CASE_ID = /[a-z]I[dD]$/;
/** Column names that make good category labels even when their values are not strings. */
const LABEL_COLUMN_HINT = /name|title/i;
/** Numeric columns preferred as the slice value of pie/doughnut charts. */
const PIE_VALUE_HINT = /status|count|amount|total/i;
/** Hue step that keeps consecutive colours well separated. */
const GOLDEN_ANGLE_DEGREES = 137.508;

/**
 * POST endpoint that turns a natural-language request into SQL with an LLM,
 * runs it against the caller-supplied database and returns the rows plus an
 * optional chart description.
 *
 * Request body: `{ dbUri: string, userPrompt: string }`.
 *
 * Responses:
 * - 200 `{ results, query, visualization }`
 * - 400 `dbUri` or `userPrompt` missing, empty or not a string
 * - 405 method other than POST
 * - 422 the model produced something other than a single read-only statement
 * - 500 configuration, model, parsing, connection or query-execution failures
 */
export default async function handler(
  req: NextApiRequest,
  res: NextApiResponse
) {
  if (req.method !== "POST") {
    return res.status(405).json({ message: "Method not allowed" });
  }

  // `req.body` is untyped (and may be a raw string for non-JSON requests), so
  // validate the field types rather than trusting them.
  const body: { dbUri?: unknown; userPrompt?: unknown } = req.body ?? {};
  const { dbUri, userPrompt } = body;
  if (!isNonEmptyString(dbUri) || !isNonEmptyString(userPrompt)) {
    return res.status(400).json({
      message: "Missing required fields",
      details: {
        dbUri: isNonEmptyString(dbUri) ? null : "Database URI is required",
        userPrompt: isNonEmptyString(userPrompt) ? null : "Query prompt is required",
      },
    });
  }

  try {
    const apiKey = process.env.HUGGINGFACE_API_KEY;
    if (!apiKey) {
      return res.status(500).json({
        message: "Server configuration error",
        details: "API key is not configured",
      });
    }

    let hf: HfInference;
    try {
      hf = new HfInference(apiKey);
    } catch (error) {
      return res.status(500).json({
        message: "Failed to initialize AI model",
        details: getErrorMessage(error),
      });
    }

    let content: string;
    try {
      const response = await hf.chatCompletion({
        model: MODEL_ID,
        messages: [{ role: "user", content: buildPrompt(userPrompt) }],
        temperature: 0.1,
        max_tokens: 500,
      });
      content = response?.choices?.[0]?.message?.content?.trim() ?? "";
    } catch (error) {
      console.error("AI model error:", error);
      return res.status(500).json({
        message: "AI model error",
        details: getErrorMessage(error, "Failed to generate SQL query"),
      });
    }

    let parsed: ParsedModelResponse;
    try {
      parsed = parseModelResponse(content);
    } catch (error) {
      return res.status(500).json({
        message: "Failed to parse AI response",
        details: getErrorMessage(error),
      });
    }
    const { sqlQuery, chartType } = parsed;

    // Defense in depth: the model is steered by untrusted user text, so refuse
    // anything that is not a single read-only statement before connecting.
    // Connecting with read-only database credentials remains the real safeguard.
    if (!isReadOnlyQuery(sqlQuery)) {
      return res.status(422).json({
        message: "Generated query rejected",
        details:
          "Only a single read-only statement (SELECT, WITH, SHOW, DESCRIBE or EXPLAIN) can be executed",
        query: sqlQuery,
      });
    }

    let connection: Awaited<ReturnType<typeof createConnection>>;
    try {
      connection = await createConnection(dbUri);
    } catch (error) {
      return res.status(500).json({
        message: "Database connection error",
        details: getErrorMessage(error),
      });
    }

    try {
      const results = await executeQuery(connection, sqlQuery);
      const visualization = buildVisualization(results, chartType);
      return res.status(200).json({ results, query: sqlQuery, visualization });
    } catch (error) {
      return res.status(500).json({
        message: "Query execution error",
        details: getErrorMessage(error),
      });
    } finally {
      // Always release the connection; a failing close is logged but must not
      // replace the response already chosen above.
      try {
        await connection.end();
      } catch (closeError) {
        console.error("Failed to close database connection:", closeError);
      }
    }
  } catch (error) {
    console.error("Unexpected API Error:", error);
    return res.status(500).json({
      message: "Unexpected error occurred",
      details: getErrorMessage(error),
    });
  }
}

/** Narrows `value` to a string containing at least one non-whitespace character. */
function isNonEmptyString(value: unknown): value is string {
  return typeof value === "string" && value.trim() !== "";
}

/** Extracts a readable message from a thrown value of unknown type. */
function getErrorMessage(error: unknown, fallback = "Unknown error"): string {
  if (error instanceof Error && error.message) return error.message;
  if (typeof error === "string" && error) return error;
  return fallback;
}

/** Builds the instruction prompt that asks the model for `{ query, chartType }` JSON. */
function buildPrompt(userPrompt: string): string {
  return `You are a SQL expert. Convert the following text to a SQL query.
Rules:
- Return a JSON object with exactly this format: {"query": "YOUR SQL QUERY HERE", "chartType": "CHART TYPE HERE"}
- For chartType use one of: "bar", "pie", "line", "doughnut", or null
- The query must be a single read-only statement (SELECT or WITH); never modify data or schema
- The query should be safe and only return the requested data
- Keep table names exactly as provided
- Do not include any explanations or comments
Example input: "Show me sales data as a pie chart"
Example output: {"query": "SELECT * FROM sales LIMIT 10", "chartType": "pie"}
Text: ${userPrompt}`;
}

/**
 * Extracts the SQL statement and requested chart type from the model reply.
 *
 * Prefers the JSON object the prompt asks for. Otherwise it treats the whole
 * reply as SQL after removing Markdown fences and a leading "SQL query:" or
 * "query:" label.
 *
 * @throws Error if the JSON span is malformed or no SQL text remains.
 */
function parseModelResponse(content: string): ParsedModelResponse {
  const jsonMatch = content.match(JSON_OBJECT_PATTERN);
  if (jsonMatch) {
    const parsed: { query?: unknown; chartType?: unknown } = JSON.parse(jsonMatch[0]);
    // Guard the type: a non-string `query` would otherwise crash on `.trim()`.
    const sqlQuery = typeof parsed.query === "string" ? parsed.query.trim() : "";
    return requireQuery({ sqlQuery, chartType: parsed.chartType ?? null });
  }

  // Only strip a *leading* label. A global replace would also delete "query"
  // inside identifiers such as `search_query`, corrupting the statement.
  const sqlQuery = content
    .replace(/```(?:sql)?/gi, "")
    .trim()
    .replace(/^(?:sql\s+)?query\s*:?\s*/i, "")
    .trim();
  return requireQuery({ sqlQuery, chartType: null });
}

/** Rejects a parse result whose SQL text is empty. */
function requireQuery(parsed: ParsedModelResponse): ParsedModelResponse {
  if (!parsed.sqlQuery) {
    throw new Error("No valid SQL query found in response");
  }
  return parsed;
}

/**
 * Conservative check that `sql` is exactly one read-only statement.
 *
 * Accepts a single statement (one optional trailing `;`) that starts with
 * SELECT/WITH/SHOW/DESCRIBE/DESC/EXPLAIN and contains no write/DDL keyword.
 * Quoting is deliberately ignored, so a literal such as `'delete'` causes a
 * false rejection. That is preferable to a dialect-specific tokenizer that
 * could be tricked into hiding a second statement inside a "string".
 * Side-effecting functions are not detected.
 */
function isReadOnlyQuery(sql: string): boolean {
  const statement = sql.trim().replace(/;\s*$/, "");
  return (
    READ_ONLY_LEADING_KEYWORD.test(statement) &&
    !statement.includes(";") &&
    !WRITE_KEYWORD.test(statement)
  );
}

/** Maps the model's chartType to a supported kind; unknown truthy values fall back to "bar". */
function normalizeChartType(value: unknown): ChartKind | null {
  if (!value) return null;
  switch (typeof value === "string" ? value.trim().toLowerCase() : value) {
    case "pie":
      return "pie";
    case "doughnut":
      return "doughnut";
    case "line":
      return "line";
    default:
      return "bar";
  }
}

/** True for identifier columns (`id`, `user_id`, `userId`) that should not be plotted as values. */
function isIdColumn(column: string): boolean {
  return SNAKE_CASE_ID.test(column) || CAMEL_CASE_ID.test(column);
}

/**
 * Classifies the columns of `rows` (keys of the first row) by the type of
 * each column's first non-null value, so a leading NULL does not hide a
 * numeric or date column.
 */
function analyzeColumns(rows: Row[]): ColumnAnalysis {
  const columns = Object.keys(rows[0] ?? {});
  const samples = new Map<string, unknown>(
    columns.map((column): [string, unknown] => [
      column,
      rows.find((row) => row[column] != null)?.[column],
    ])
  );
  return {
    numericColumns: columns.filter(
      (column) => typeof samples.get(column) === "number" && !isIdColumn(column)
    ),
    dateColumns: columns.filter((column) => samples.get(column) instanceof Date),
    stringColumns: columns.filter(
      (column) => typeof samples.get(column) === "string" || LABEL_COLUMN_HINT.test(column)
    ),
  };
}

/** Deterministic, well-separated hue for the `index`-th series or slice. */
function seriesHue(index: number): number {
  return Math.round((index * GOLDEN_ANGLE_DEGREES) % 360);
}

/** Semi-transparent fill colour for the `index`-th series or slice. */
function fillColor(index: number): string {
  return `hsla(${seriesHue(index)}, 70%, 50%, 0.6)`;
}

/** Opaque stroke colour matching `fillColor(index)`. */
function strokeColor(index: number): string {
  return `hsl(${seriesHue(index)}, 70%, 50%)`;
}

/** Category labels from `labelColumn`, or "Row N" when there is no label column. */
function rowLabels(rows: Row[], labelColumn: string | undefined): string[] {
  return rows.map((row, index) =>
    labelColumn === undefined ? `Row ${index + 1}` : String(row[labelColumn])
  );
}

/** Formats a date-like cell as a locale date string; empty for null/unsupported values. */
function toDateLabel(value: unknown): string {
  if (value instanceof Date || typeof value === "string" || typeof value === "number") {
    return new Date(value).toLocaleDateString();
  }
  return "";
}

/**
 * Builds a chart description for the query results, or `null` when no chart
 * was requested, there are no rows, or there is no numeric column to plot.
 */
function buildVisualization(results: unknown, chartType: unknown): Visualization | null {
  const kind = normalizeChartType(chartType);
  if (kind === null || !Array.isArray(results) || results.length === 0) {
    return null;
  }

  const rows: Row[] = results;
  const { numericColumns, dateColumns, stringColumns } = analyzeColumns(rows);
  const labelColumn = stringColumns[0];

  switch (kind) {
    case "pie":
    case "doughnut": {
      // Circular charts plot a single value column; prefer obvious measures.
      const preferred = numericColumns.filter((column) => PIE_VALUE_HINT.test(column));
      const valueColumn = (preferred.length > 0 ? preferred : numericColumns)[0];
      if (valueColumn === undefined) return null;
      return {
        type: kind,
        config: {
          labels: rowLabels(rows, labelColumn),
          datasets: [
            {
              data: rows.map((row) => row[valueColumn]),
              backgroundColor: rows.map((_, index) => fillColor(index)),
            },
          ],
        },
      };
    }
    case "line": {
      if (numericColumns.length === 0) return null;
      const dateColumn = dateColumns[0];
      return {
        type: "line",
        config: {
          labels:
            dateColumn === undefined
              ? rows.map((_, index) => `Point ${index + 1}`)
              : rows.map((row) => toDateLabel(row[dateColumn])),
          datasets: numericColumns.map((column, index) => ({
            label: column,
            data: rows.map((row) => row[column]),
            borderColor: strokeColor(index),
            tension: 0.1,
          })),
        },
      };
    }
    default: {
      // "bar" — also the fallback for unrecognised chart types.
      if (numericColumns.length === 0) return null;
      return {
        type: "bar",
        config: {
          labels: rowLabels(rows, labelColumn),
          datasets: numericColumns.map((column, index) => ({
            label: column,
            data: rows.map((row) => row[column]),
            backgroundColor: fillColor(index),
            borderColor: strokeColor(index),
            borderWidth: 1,
          })),
        },
      };
    }
  }
}