Spaces:
Runtime error
Runtime error
from fastapi import FastAPI, Form, Request | |
from fastapi.responses import HTMLResponse | |
from fastapi.templating import Jinja2Templates | |
from fastapi.staticfiles import StaticFiles | |
from langchain_core.prompts import PromptTemplate | |
from langchain_community.llms import Ollama | |
from langchain.chains import LLMChain | |
import sqlite3 | |
import os | |
# Template loader for the HTML pages rendered by the handlers below.
templates = Jinja2Templates(directory="templates")

# ASGI application; files under ./static are served at the /static URL prefix.
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
# Ollama model tag of the text-to-SQL model used to answer questions.
_SQL_MODEL_TAG = "mannix/defog-llama3-sqlcoder-8b"

# LLM handle shared by every request.
llm = Ollama(model=_SQL_MODEL_TAG)
# Prompt sent to the model. It is filled with two variables:
#   user_question           - the natural-language question from the chat form
#   create_table_statements - the DDL the user supplied during setup
# The <|...|> markers are chat-format control tokens and must stay verbatim.
# NOTE(review): the final bare `sql` line looks like a Markdown code fence
# (three backticks followed by `sql`) whose backticks were lost — confirm
# against the model's reference prompt before changing it.
prompt = PromptTemplate(
    input_variables=["user_question", "create_table_statements"],
    template="""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Generate a SQL query to answer this question: {user_question}
Instructions:
- Use valid SQL syntax compatible with SQLite.
- Use the exact table and column names provided in the question.
- For date comparisons, do NOT use `INTERVAL`. Use SQLite-compatible expressions like `DATE('now', '-6 months')` or `DATE('now', '-30 days')`.
- Do not use `NULLS LAST` or `NULLS FIRST` in `ORDER BY` clauses unless explicitly required, as SQLite does not support it.
- Do not use `RANK()` or `OVER()` unless the question specifically requires ranking or window functions. SQLite has limited support for window functions.
- When using joins, ensure that the correct join condition is specified between the related tables. Use `INNER JOIN`, `LEFT JOIN`, etc., and ensure the conditions in `ON` are correct.
- If aggregating data (e.g., `SUM()`, `COUNT()`), ensure that grouping is done correctly using `GROUP BY` to avoid errors.
- Avoid complex SQL expressions that may not be supported in SQLite, such as `INTERVAL` for date calculations or subqueries that are not supported by SQLite.
- Return only the SQL query, no explanation.
DDL statements:
{create_table_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
The following SQL query best answers the question {user_question}:
sql
""",
)
# Prompt + model pipeline used by the question handler.
# NOTE(review): LLMChain.invoke is documented to return a dict (generation
# under the "text" key), not a bare string — callers should unwrap it.
chain = LLMChain(prompt=prompt, llm=llm)
# Module-level state read and written by the request handlers:
#   ddl          - CREATE TABLE statements supplied at setup (None until then)
#   db_path      - path of the SQLite database to query (None until then)
#   chat_history - list of {"question", "sql", "result", "error"} entries
# NOTE(review): this single dict is shared by every client of the process.
session_config = {
    "ddl": None,
    "db_path": None,
    "chat_history": [],
}
def index(request: Request):
    """Render ``chat.html`` with ``ddl_mode`` enabled and an empty chat history."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view — confirm it is registered with `app`.
    context = {"request": request, "ddl_mode": True, "chat_history": []}
    return templates.TemplateResponse("chat.html", context)
async def setup(request: Request, ddl: str = Form(...), db_path: str = Form(...)):
    """Store the submitted DDL and database path, reset the history, and render the chat page.

    Both form fields are required. The values are written to the module-level
    ``session_config`` and ``chat.html`` is rendered with ``ddl_mode`` off.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view — confirm it is registered with `app`.
    session_config.update(ddl=ddl, db_path=db_path, chat_history=[])
    context = {
        "request": request,
        "ddl_mode": False,
        "chat_history": session_config["chat_history"],
    }
    return templates.TemplateResponse("chat.html", context)
def _extract_sql(chain_output):
    """Return the bare SQL text contained in a chain result."""
    # LLMChain.invoke returns a dict with the generation under "text";
    # a runnable pipeline would return a plain string. Accept both.
    if isinstance(chain_output, dict):
        chain_output = chain_output.get("text") or ""
    text = str(chain_output).strip()
    # Drop a leading Markdown fence (with or without a language tag).
    if text.startswith("```"):
        text = text[3:]
        if text[:3].lower() == "sql":
            text = text[3:]
    # Anything after a closing fence is commentary, not SQL.
    return text.split("```", 1)[0].strip()


def _run_query(db_path, sql):
    """Execute ``sql`` against the SQLite file at ``db_path``; return ``(rows, error)``."""
    if not sql:
        return [], "The model returned an empty SQL query."
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        return conn.execute(sql).fetchall(), None
    except (sqlite3.Error, sqlite3.Warning) as exc:
        # Older Python versions raise sqlite3.Warning (not an Error subclass)
        # when the text contains more than one statement.
        return [], str(exc)
    finally:
        if conn is not None:
            conn.close()


async def ask(request: Request, user_question: str = Form(...)):
    """Turn a question into SQL, run it, record the outcome, and render the chat page.

    Args:
        request: Incoming request, passed through to the template.
        user_question: Natural-language question from the chat form.

    Returns:
        The rendered ``chat.html`` response with ``ddl_mode`` off and the
        updated chat history. Failures are reported in the history entry's
        ``error`` field rather than raised.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view — confirm it is registered with `app`.
    ddl = session_config["ddl"]
    db_path = session_config["db_path"]

    generated_sql = ""
    result_rows = []
    error = None

    if not ddl or not db_path:
        # Setup has not run yet: there is nothing to prompt with or connect to.
        error = "No schema or database configured; submit the setup form first."
    else:
        # NOTE(review): chain.invoke is a blocking call inside an async
        # handler; it stalls the event loop while the model generates.
        chain_output = chain.invoke({
            "user_question": user_question,
            "create_table_statements": ddl,
        })
        generated_sql = _extract_sql(chain_output)
        # SECURITY: the model's SQL is executed verbatim against a
        # user-chosen file; nothing here prevents a write statement.
        result_rows, error = _run_query(db_path, generated_sql)

    session_config["chat_history"].append({
        "question": user_question,
        "sql": generated_sql,
        "result": result_rows,
        "error": error,
    })
    return templates.TemplateResponse("chat.html", {
        "request": request,
        "ddl_mode": False,
        "chat_history": session_config["chat_history"],
    })