"""Chat-style natural-language-to-SQL demo for a SQLite database.

A user first submits CREATE TABLE statements and the path to a SQLite file
(``/setup``). Each question posted to ``/ask`` is turned into a SQL query by a
local Ollama model, the query is run against that database, and the question,
SQL, result rows and any error are appended to an in-memory chat history that
is rendered with ``templates/chat.html``.
"""

import logging
import os
import re
import sqlite3
from contextlib import closing
from pathlib import Path

from fastapi import FastAPI, Form, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from langchain.chains import LLMChain
from langchain_community.llms import Ollama
from langchain_core.prompts import PromptTemplate

logger = logging.getLogger(__name__)

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Initialize LLM
llm = Ollama(model="mannix/defog-llama3-sqlcoder-8b")

# Llama-3 chat-format prompt for the sqlcoder model. The assistant turn ends by
# opening a ```sql fence so the model continues directly with the query;
# _extract_sql() removes any fence markers from the completion.
prompt = PromptTemplate(
    input_variables=["user_question", "create_table_statements"],
    template="""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: {user_question}

Instructions:
- Use valid SQL syntax compatible with SQLite.
- Use the exact table and column names provided in the DDL statements.
- For date comparisons, do NOT use `INTERVAL`. Use SQLite-compatible expressions like `DATE('now', '-6 months')` or `DATE('now', '-30 days')`.
- Do not use `NULLS LAST` or `NULLS FIRST` in `ORDER BY` clauses unless explicitly required, as SQLite does not support it.
- Do not use `RANK()` or `OVER()` unless the question specifically requires ranking or window functions. SQLite has limited support for window functions.
- When using joins, ensure that the correct join condition is specified between the related tables. Use `INNER JOIN`, `LEFT JOIN`, etc., and ensure the conditions in `ON` are correct.
- If aggregating data (e.g., `SUM()`, `COUNT()`), ensure that grouping is done correctly using `GROUP BY` to avoid errors.
- Avoid complex SQL expressions that may not be supported in SQLite, such as `INTERVAL` for date calculations or subqueries that are not supported by SQLite.
- Return only the SQL query, no explanation.

DDL statements:
{create_table_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question {user_question}:
```sql
""",
)

# NOTE: LLMChain.invoke() returns a dict keyed by the chain's output key
# ("text" by default), not a plain string -- see _extract_sql().
chain = LLMChain(llm=llm, prompt=prompt)

# Single process-wide session: every client shares the same DDL, database path
# and chat history. NOTE(review): not per-user -- confirm this is only meant as
# a single-user demo.
session_config = {
    "ddl": None,
    "db_path": None,
    "chat_history": [],
}

# Leading markdown fence such as "```sql" / "```" that the model may echo.
_LEADING_FENCE = re.compile(r"^```[\w-]*\s*")


def _extract_sql(raw) -> str:
    """Turn raw chain output into a bare SQL string.

    Accepts either the dict returned by ``LLMChain.invoke`` (reads ``"text"``)
    or a plain string. It strips an echoed opening fence and drops everything
    from the first closing ```` ``` ```` onward, because the prompt already
    opened a ```sql fence.
    """
    text = raw.get("text", "") if isinstance(raw, dict) else str(raw)
    text = _LEADING_FENCE.sub("", text.strip())
    return text.split("```", 1)[0].strip()


def _generate_sql(user_question: str, ddl: str) -> str:
    """Ask the LLM for a SQL query answering *user_question* over *ddl*."""
    raw = chain.invoke({
        "user_question": user_question,
        "create_table_statements": ddl,
    })
    return _extract_sql(raw)


def _run_readonly_query(db_path: str, sql: str) -> list:
    """Execute *sql* against the SQLite file at *db_path* and return all rows.

    The database is opened read-only through a URI (``mode=ro``) because the
    SQL is model-generated from untrusted user text. Write statements
    therefore fail instead of modifying data, and a missing file raises an
    error rather than silently creating an empty database. Connection and
    cursor are always closed.
    """
    uri = f"{Path(db_path).resolve().as_uri()}?mode=ro"
    with closing(sqlite3.connect(uri, uri=True)) as conn:
        with closing(conn.cursor()) as cur:
            cur.execute(sql)
            return cur.fetchall()


def _answer_question(user_question: str) -> tuple:
    """Return ``(generated_sql, result_rows, error)`` for one question.

    *error* is ``None`` on success, otherwise a human-readable message; in
    that case *result_rows* is empty and *generated_sql* may be ``""``.
    """
    ddl = session_config["ddl"]
    db_path = session_config["db_path"]
    if ddl is None or db_path is None:
        return "", [], "No database configured yet: submit the DDL and database path first."

    try:
        generated_sql = _generate_sql(user_question, ddl)
    except Exception as exc:  # boundary to the external LLM service
        logger.exception("SQL generation failed")
        return "", [], f"SQL generation failed: {exc}"

    if not generated_sql:
        return "", [], "The model did not return a SQL query."

    try:
        return generated_sql, _run_readonly_query(db_path, generated_sql), None
    # sqlite3.Warning: older Pythons raise it (not sqlite3.Error) for
    # multi-statement SQL; ValueError: e.g. NUL characters in the path/query.
    except (sqlite3.Error, sqlite3.Warning, ValueError) as exc:
        return generated_sql, [], str(exc)


@app.get("/", response_class=HTMLResponse)
def index(request: Request):
    """Render the setup form (DDL mode) with an empty chat history."""
    return templates.TemplateResponse("chat.html", {"request": request, "ddl_mode": True, "chat_history": []})


@app.post("/setup", response_class=HTMLResponse)
async def setup(request: Request, ddl: str = Form(...), db_path: str = Form(...)):
    """Store the schema and database path, clear the history, open the chat view."""
    session_config["ddl"] = ddl
    session_config["db_path"] = db_path
    session_config["chat_history"] = []
    return templates.TemplateResponse("chat.html", {"request": request, "ddl_mode": False, "chat_history": session_config["chat_history"]})


@app.post("/ask", response_class=HTMLResponse)
def ask(request: Request, user_question: str = Form(...)):
    """Generate SQL for *user_question*, run it, and re-render the chat.

    Declared as a plain ``def`` so FastAPI runs it in its threadpool. Both the
    LLM call and the SQLite query are blocking, and inside ``async def`` they
    would stall the event loop for every other request.

    Failures (no setup yet, LLM error, SQL error) are recorded in the history
    entry's ``"error"`` field instead of surfacing as an HTTP 500.
    """
    generated_sql, result_rows, error = _answer_question(user_question)

    session_config["chat_history"].append({
        "question": user_question,
        "sql": generated_sql,
        "result": result_rows,
        "error": error,
    })

    return templates.TemplateResponse("chat.html", {
        "request": request,
        "ddl_mode": False,
        "chat_history": session_config["chat_history"],
    })