File size: 3,794 Bytes
68491b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import sqlite3
from contextlib import closing
from pathlib import Path

from fastapi import FastAPI, Form, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from langchain.chains import LLMChain
from langchain_community.llms import Ollama
from langchain_core.prompts import PromptTemplate

app = FastAPI()
# Static assets and Jinja2 templates; both directory paths are relative, so they
# resolve against the process working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Text-to-SQL model served through Ollama. The model tag can be overridden via
# the SQLCODER_MODEL environment variable; when it is unset, the original
# hard-coded model is used.
llm = Ollama(model=os.environ.get("SQLCODER_MODEL", "mannix/defog-llama3-sqlcoder-8b"))

# SQLCoder prompt in Llama-3 chat-token format. The user turn carries the
# question, SQLite-specific rules and the schema DDL. The assistant turn is
# pre-seeded so the model continues straight into the SQL.
# Placeholders: {user_question} (used twice) and {create_table_statements}.
_SQLCODER_TEMPLATE = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>



Generate a SQL query to answer this question: {user_question}

Instructions:

- Use valid SQL syntax compatible with SQLite.

- Use the exact table and column names provided in the question.

- For date comparisons, do NOT use `INTERVAL`. Use SQLite-compatible expressions like `DATE('now', '-6 months')` or `DATE('now', '-30 days')`.

- Do not use `NULLS LAST` or `NULLS FIRST` in `ORDER BY` clauses unless explicitly required, as SQLite does not support it.

- Do not use `RANK()` or `OVER()` unless the question specifically requires ranking or window functions. SQLite has limited support for window functions.

- When using joins, ensure that the correct join condition is specified between the related tables. Use `INNER JOIN`, `LEFT JOIN`, etc., and ensure the conditions in `ON` are correct.

- If aggregating data (e.g., `SUM()`, `COUNT()`), ensure that grouping is done correctly using `GROUP BY` to avoid errors.

- Avoid complex SQL expressions that may not be supported in SQLite, such as `INTERVAL` for date calculations or subqueries that are not supported by SQLite.

- Return only the SQL query, no explanation.



DDL statements:

{create_table_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>



The following SQL query best answers the question {user_question}:

sql



"""

prompt = PromptTemplate(
    template=_SQLCODER_TEMPLATE,
    input_variables=["user_question", "create_table_statements"],
)

# Runnable pipeline: format the prompt, then call the LLM. The legacy LLMChain's
# invoke() returns a dict with the completion under "text". Invoking this
# sequence returns the completion string itself, which the /ask handler passes
# on to SQLite as the query.
chain = prompt | llm

# Module-level state that every request handler reads and mutates, so all
# clients share this one dict:
#   ddl          - schema text submitted to /setup, fed into the prompt
#   db_path      - SQLite file path submitted to /setup
#   chat_history - list of {question, sql, result, error} entries from /ask
session_config = dict(
    ddl=None,
    db_path=None,
    chat_history=[],
)

@app.get("/", response_class=HTMLResponse)
def index(request: Request):
    """Render chat.html in DDL-entry mode (``ddl_mode=True``) with no history."""
    context = {
        "request": request,
        "ddl_mode": True,
        "chat_history": [],
    }
    return templates.TemplateResponse("chat.html", context)

@app.post("/setup", response_class=HTMLResponse)
async def setup(request: Request, ddl: str = Form(...), db_path: str = Form(...)):
    """Record the schema DDL and database path, then start a fresh chat.

    Args:
        request: Incoming request, forwarded to the template.
        ddl: Schema text later substituted into the prompt.
        db_path: Path of the SQLite database that /ask queries.

    Returns:
        chat.html rendered in chat mode (``ddl_mode=False``) with the
        newly emptied history.
    """
    session_config.update(ddl=ddl, db_path=db_path, chat_history=[])
    context = {
        "request": request,
        "ddl_mode": False,
        "chat_history": session_config["chat_history"],
    }
    return templates.TemplateResponse("chat.html", context)

def _extract_sql(raw_output):
    """Return the bare SQL text from the chain's output.

    Accepts either a plain string (from a ``prompt | llm`` runnable) or the
    dict an ``LLMChain`` returns, whose completion lives under ``"text"``.
    A leading Markdown fence line (e.g. "```sql") is dropped. Everything from
    the first closing "```" onward is also discarded, because the model may
    wrap the query in a code block or append text after it.
    """
    if isinstance(raw_output, dict):
        text = raw_output.get("text", "")
    else:
        text = raw_output
    text = str(text).strip()
    if text.startswith("```"):
        _, _, text = text.partition("\n")
    text, _, _ = text.partition("```")
    return text.strip()


@app.post("/ask", response_class=HTMLResponse)
def ask(request: Request, user_question: str = Form(...)):
    """Generate SQL for ``user_question``, run it, and re-render the chat.

    The handler is a plain ``def`` so FastAPI runs it in its threadpool.
    Both the LLM call and sqlite3 block, and inside ``async def`` they would
    stall the event loop for every client.

    The query runs on a read-only connection (SQLite URI ``mode=ro``). The
    SQL is model-generated from free-form user text, so it must not be able
    to drop or modify data. A mistyped path also must not silently create a
    new, empty database file.

    Each exchange (question, SQL, fetched rows, error message or None) is
    appended to ``session_config["chat_history"]``.
    """
    if session_config["db_path"] is None:
        # /ask reached before /setup: show the DDL-entry view again instead of
        # crashing on sqlite3.connect(None).
        return templates.TemplateResponse(
            "chat.html", {"request": request, "ddl_mode": True, "chat_history": []}
        )

    generated_sql = _extract_sql(chain.invoke({
        "user_question": user_question,
        "create_table_statements": session_config["ddl"],
    }))

    result_rows = []
    error = None
    # as_uri() needs an absolute path and percent-encodes special characters.
    db_uri = Path(session_config["db_path"]).absolute().as_uri() + "?mode=ro"

    try:
        with closing(sqlite3.connect(db_uri, uri=True)) as conn, closing(conn.cursor()) as cur:
            cur.execute(generated_sql)
            result_rows = cur.fetchall()
    except (sqlite3.Error, sqlite3.Warning) as exc:
        # Older Python versions raise sqlite3.Warning, which is not a
        # subclass of sqlite3.Error, when the SQL holds multiple statements.
        error = str(exc)

    session_config["chat_history"].append({
        "question": user_question,
        "sql": generated_sql,
        "result": result_rows,
        "error": error
    })

    return templates.TemplateResponse("chat.html", {
        "request": request,
        "ddl_mode": False,
        "chat_history": session_config["chat_history"]
    })