|
import os
|
|
from flask import Flask, request, jsonify, render_template
|
|
from transformers import pipeline
|
|
import mysql.connector
|
|
from groq import Groq
|
|
|
|
# Flask application object. Pass the module's import name (``__name__``);
# the previous ``Flask(name)`` referenced an undefined identifier and raised
# NameError as soon as the module was imported.
app = Flask(__name__)
|
|
|
|
|
|
# Text-generation pipeline used by generate_sql() to turn questions into SQL.
# Built once at import time and shared by every request.
# NOTE(review): judging by the model id this is a ~7B-parameter model loaded
# eagerly with default dtype/device settings -- confirm memory requirements.
pipe = pipeline(
    task="text-generation",
    model="defog/sqlcoder-7b-2",
)
|
|
|
|
|
|
# Groq API client used by chatbot() to summarise query results.
# The key comes from the GROQ_API_KEY environment variable; it is None when the
# variable is unset (TODO confirm how the Groq client reacts to that).
groq_client = Groq(
    api_key=os.getenv("GROQ_API_KEY"),
)
|
|
|
|
|
|
# Keyword arguments for mysql.connector.connect() used by execute_query().
# Host, user and database can be overridden via DB_HOST / DB_USER / DB_NAME so
# the deployment-specific values are no longer baked into the code; the
# defaults keep the previous behaviour. The password is only ever read from
# DB_PASSWORD (None when unset).
DB_CONFIG = {
    'host': os.environ.get("DB_HOST", 'auth-db579.hstgr.io'),
    'user': os.environ.get("DB_USER", 'u121769371_ki_aiml_test'),
    'password': os.environ.get("DB_PASSWORD"),
    'database': os.environ.get("DB_NAME", 'u121769371_ki_aiml_test'),
}
|
|
|
|
def generate_sql(text, max_new_tokens=50):
    """Generate a SQL statement for the natural-language request *text*.

    Args:
        text: Prompt passed to the module-level text-generation pipeline ``pipe``.
        max_new_tokens: Maximum number of tokens the model may generate
            (defaults to the previously hard-coded 50).

    Returns:
        Only the newly generated text, with surrounding whitespace removed.
    """
    # A text-generation pipeline includes the prompt at the start of
    # ``generated_text`` unless told otherwise; that natural-language prefix
    # made the returned "SQL" invalid when execute_query() ran it.
    # NOTE(review): 50 new tokens may truncate longer SQL, and the prompt
    # carries no schema information -- confirm this is sufficient for the model.
    output = pipe(text, max_new_tokens=max_new_tokens, return_full_text=False)
    return output[0]['generated_text'].strip()
|
|
|
|
def execute_query(query):
    """Run *query* on the database described by DB_CONFIG and return all rows.

    Returns:
        The rows from ``cursor.fetchall()``, or None when connecting,
        executing or fetching raises ``mysql.connector.Error`` (the error is
        printed). Statements without a result set make ``fetchall()`` raise,
        so they also yield None.

    SECURITY(review): *query* is model-generated from user text and executed
    verbatim. Nothing here blocks destructive statements; connect with a
    SELECT-only database account.
    """
    connection = None
    cursor = None
    try:
        connection = mysql.connector.connect(**DB_CONFIG)
        cursor = connection.cursor()
        cursor.execute(query)
        return cursor.fetchall()
    except mysql.connector.Error as err:
        print(f"Error: {err}")
        return None
    finally:
        # Release the cursor and connection on every path. Previously they
        # leaked whenever execute()/fetchall() raised. Cleanup is best-effort:
        # a failure while closing must not mask the result computed above.
        for resource in (cursor, connection):
            if resource is not None:
                try:
                    resource.close()
                except mysql.connector.Error as close_err:
                    print(f"Error while closing: {close_err}")
|
|
|
|
@app.route('/')
def index():
    """Return the rendered ``index.html`` template for the root URL."""
    page = 'index.html'
    return render_template(page)
|
|
|
|
def _build_summary_prompt(user_query, sql_query, query_result):
    """Return the prompt asking the LLM to summarise *query_result* in plain language."""
    return (
        f"Original query: {user_query}\n"
        f"SQL query: {sql_query}\n"
        f"Query result: {query_result}\n"
        "Please provide a natural language summary of the query result."
    )


@app.route('/chatbot', methods=['POST'])
def chatbot():
    """Answer a natural-language question about the database.

    Expects a JSON object body with a non-empty string field ``text``. The
    question is converted to SQL by generate_sql(), executed by
    execute_query(), and the rows are summarised by a Groq chat model.

    Returns:
        200 ``{"response": <summary>}`` on success;
        400 ``{"error": "No query provided"}`` when ``text`` is missing,
        blank or not a string;
        500 ``{"error": ...}`` when the query fails or a later step raises.
    """
    # get_json(silent=True) returns None instead of raising on a missing or
    # malformed JSON body. A non-object body (list, string, number) has no
    # .get(), so treat it like an empty request rather than crashing outside
    # the try block below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    user_query = data.get('text')

    if not isinstance(user_query, str) or not user_query.strip():
        return jsonify({"error": "No query provided"}), 400

    try:
        sql_query = generate_sql(user_query)

        query_result = execute_query(sql_query)
        if query_result is None:
            return jsonify({"error": "Database query execution failed"}), 500

        # NOTE(review): the whole result set is embedded in the prompt; very
        # large results may exceed the model's context window.
        # NOTE(review): confirm "llama3-8b-8192" is still served by Groq.
        chat_completion = groq_client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": _build_summary_prompt(user_query, sql_query, query_result),
                }
            ],
            model="llama3-8b-8192",
        )

        natural_language_response = chat_completion.choices[0].message.content
        return jsonify({"response": natural_language_response})

    except Exception:
        # Top-level request boundary: log the traceback server-side, but do not
        # echo internal exception text (DB/API details) back to the client.
        app.logger.exception("chatbot request failed")
        return jsonify({"error": "Internal server error"}), 500
|
|
|
|
if __name__ == '__main__':
    # Start Flask's built-in server when run as a script. The previous guard
    # compared an undefined ``name`` to 'main' and never started the server.
    # Binds all interfaces; PORT may override the default port 8000.
    app.run(host='0.0.0.0', port=int(os.environ.get("PORT", "8000")))