# sql-lite-chat / main.py
# Uploaded by redfernstech ("Upload 3 files", commit 68491b8, verified; 3.79 kB).
import logging
import os
import sqlite3
from contextlib import closing
from pathlib import Path

from fastapi import FastAPI, Form, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from langchain.chains import LLMChain
from langchain_community.llms import Ollama
from langchain_core.prompts import PromptTemplate
# Resolve asset directories relative to this module instead of the process
# working directory, so the app starts no matter where uvicorn is launched from.
BASE_DIR = Path(__file__).resolve().parent

app = FastAPI()
# Files under ./static (next to this module) are served at /static/...
app.mount("/static", StaticFiles(directory=str(BASE_DIR / "static")), name="static")
# Page templates such as chat.html are loaded from ./templates next to this module.
templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
# Initialize LLM (served by a local Ollama instance).
# NOTE(review): the prompt below embeds Llama-3 special tokens by hand; confirm
# Ollama is not also wrapping it in the model's own chat template.
llm = Ollama(model="mannix/defog-llama3-sqlcoder-8b")

# LangChain prompt in the sqlcoder format. It deliberately ends inside an opened
# ```sql fence so that the model's completion is the query itself.
prompt = PromptTemplate(
    input_variables=["user_question", "create_table_statements"],
    template="""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Generate a SQL query to answer this question: {user_question}
Instructions:
- Use valid SQL syntax compatible with SQLite.
- Use the exact table and column names provided in the DDL statements below.
- For date comparisons, do NOT use `INTERVAL`. Use SQLite-compatible expressions like `DATE('now', '-6 months')` or `DATE('now', '-30 days')`.
- Do not use `NULLS LAST` or `NULLS FIRST` in `ORDER BY` clauses unless explicitly required, as SQLite does not support it.
- Do not use `RANK()` or `OVER()` unless the question specifically requires ranking or window functions. SQLite has limited support for window functions.
- When using joins, ensure that the correct join condition is specified between the related tables. Use `INNER JOIN`, `LEFT JOIN`, etc., and ensure the conditions in `ON` are correct.
- If aggregating data (e.g., `SUM()`, `COUNT()`), ensure that grouping is done correctly using `GROUP BY` to avoid errors.
- Avoid complex SQL expressions that may not be supported in SQLite, such as `INTERVAL` for date calculations or subqueries that are not supported by SQLite.
- Return only the SQL query, no explanation.
DDL statements:
{create_table_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
The following SQL query best answers the question {user_question}:
```sql
""",
)


def _strip_sql_fences(text: str) -> str:
    """Return the model completion with Markdown ``` code fences removed.

    The prompt opens a ```sql fence, so the completion typically ends with the
    closing fence (and may re-open one); SQLite would reject either marker.
    """
    body = text.strip()
    if body.startswith("```"):
        # Drop a re-opened fence and its optional "sql" language tag.
        body = body[3:]
        if body[:3].lower() == "sql":
            body = body[3:]
    # Everything from the next fence onward is the closing marker or commentary.
    return body.split("```", 1)[0].strip()


# `prompt | llm` yields the completion string directly; LLMChain.invoke returns
# a dict keyed by "text", which is not executable SQL. The last step strips fences.
chain = prompt | llm | _strip_sql_fences

# NOTE: process-wide state shared by every client. A /setup from any browser
# replaces the schema, database path and chat history for everyone.
session_config = {
    "ddl": None,  # CREATE TABLE statements fed to the prompt
    "db_path": None,  # filesystem path of the SQLite database to query
    "chat_history": [],  # dicts with question / sql / result / error
}
@app.get("/", response_class=HTMLResponse)
def index(request: Request):
    """Render chat.html in setup mode (DDL/database form) with an empty history."""
    page_context = dict(request=request, ddl_mode=True, chat_history=[])
    return templates.TemplateResponse("chat.html", page_context)
@app.post("/setup", response_class=HTMLResponse)
async def setup(request: Request, ddl: str = Form(...), db_path: str = Form(...)):
    """Store the submitted schema and database path, then open a fresh chat.

    Writes ``ddl`` and ``db_path`` into ``session_config``, starts a new empty
    history list, and renders chat.html in chat mode (``ddl_mode`` False).
    """
    # One update call rebinds all three entries; chat_history becomes a new list.
    session_config.update(ddl=ddl, db_path=db_path, chat_history=[])
    fresh_history = session_config["chat_history"]
    return templates.TemplateResponse(
        "chat.html",
        {"request": request, "ddl_mode": False, "chat_history": fresh_history},
    )
def _extract_sql(raw_output) -> str:
    """Return the bare SQL statement contained in the chain's raw output.

    ``LLMChain.invoke`` returns a dict carrying the completion under ``"text"``,
    while a plain runnable pipeline returns the string itself; both are accepted.
    The completion may also be wrapped in, or terminated by, Markdown ``` fences,
    which SQLite would reject, so those are removed.
    """
    if isinstance(raw_output, dict):
        raw_output = raw_output.get("text", "")
    text = str(raw_output).strip()
    if text.startswith("```"):
        # Drop a re-opened fence and its optional "sql" language tag.
        text = text[3:]
        if text[:3].lower() == "sql":
            text = text[3:]
    # Everything from the next fence onward is the closing marker or commentary.
    return text.split("```", 1)[0].strip()


def _run_readonly_query(db_path: str, sql: str):
    """Execute one SQL statement against the SQLite file at *db_path*, read-only.

    Returns ``(rows, None)`` on success, or ``([], message)`` when SQLite raises.
    """
    # mode=ro: model-generated SQL must not be able to modify or drop data, and a
    # mistyped path should fail instead of silently creating an empty database.
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    try:
        with closing(sqlite3.connect(uri, uri=True)) as conn, closing(conn.cursor()) as cur:
            cur.execute(sql)
            return cur.fetchall(), None
    except sqlite3.Error as exc:
        return [], str(exc)


@app.post("/ask", response_class=HTMLResponse)
async def ask(request: Request, user_question: str = Form(...)):
    """Turn *user_question* into SQL, run it, and re-render the chat page.

    The exchange (question, generated SQL, result rows, error message) is
    appended to the chat history. Failures are reported through the ``error``
    field rather than raised, so the page always renders.
    """
    # Snapshot the configuration before awaiting: a concurrent /setup may
    # replace it while this request waits on the model.
    ddl = session_config["ddl"]
    db_path = session_config["db_path"]
    history = session_config["chat_history"]

    generated_sql = ""
    result_rows = []
    error = None

    if not db_path:
        error = "No database configured. Submit the DDL statements and database path first."
    else:
        try:
            # The LLM call is synchronous and slow; run it in a worker thread so
            # the event loop keeps serving other requests meanwhile.
            raw_output = await run_in_threadpool(
                chain.invoke,
                {"user_question": user_question, "create_table_statements": ddl},
            )
        except Exception as exc:  # request boundary: show backend failures in the chat
            logging.getLogger(__name__).exception("SQL generation failed")
            error = f"SQL generation failed: {exc}"
        else:
            generated_sql = _extract_sql(raw_output)
            if generated_sql:
                result_rows, error = await run_in_threadpool(
                    _run_readonly_query, db_path, generated_sql
                )
            else:
                error = "The model did not return a SQL query."

    history.append({
        "question": user_question,
        "sql": generated_sql,
        "result": result_rows,
        "error": error,
    })
    return templates.TemplateResponse("chat.html", {
        "request": request,
        "ddl_mode": False,
        "chat_history": history,
    })