redfernstech commited on
Commit
68491b8
·
verified ·
1 Parent(s): ff9773c

Upload 3 files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. company.db +3 -0
  3. main.py +97 -0
  4. requirements.txt +7 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ company.db filter=lfs diff=lfs merge=lfs -text
company.db ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8dd011e031b3ac60d7444f703c8b9c251a151fc7d846a5226b291a60a14e0843
3
+ size 184320
main.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Form, Request
2
+ from fastapi.responses import HTMLResponse
3
+ from fastapi.templating import Jinja2Templates
4
+ from fastapi.staticfiles import StaticFiles
5
+ from langchain_core.prompts import PromptTemplate
6
+ from langchain_community.llms import Ollama
7
+ from langchain.chains import LLMChain
8
+ import sqlite3
9
+ import os
10
+
11
+ app = FastAPI()
12
+ app.mount("/static", StaticFiles(directory="static"), name="static")
13
+ templates = Jinja2Templates(directory="templates")
14
+
15
+ # Initialize LLM
16
+ llm = Ollama(model="mannix/defog-llama3-sqlcoder-8b")
17
+
18
+ # LangChain Prompt
19
+ prompt = PromptTemplate(
20
+ input_variables=["user_question", "create_table_statements"],
21
+ template="""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
22
+
23
+ Generate a SQL query to answer this question: {user_question}
24
+ Instructions:
25
+ - Use valid SQL syntax compatible with SQLite.
26
+ - Use the exact table and column names provided in the question.
27
+ - For date comparisons, do NOT use `INTERVAL`. Use SQLite-compatible expressions like `DATE('now', '-6 months')` or `DATE('now', '-30 days')`.
28
+ - Do not use `NULLS LAST` or `NULLS FIRST` in `ORDER BY` clauses unless explicitly required, as SQLite does not support it.
29
+ - Do not use `RANK()` or `OVER()` unless the question specifically requires ranking or window functions. SQLite has limited support for window functions.
30
+ - When using joins, ensure that the correct join condition is specified between the related tables. Use `INNER JOIN`, `LEFT JOIN`, etc., and ensure the conditions in `ON` are correct.
31
+ - If aggregating data (e.g., `SUM()`, `COUNT()`), ensure that grouping is done correctly using `GROUP BY` to avoid errors.
32
+ - Avoid complex SQL expressions that may not be supported in SQLite, such as `INTERVAL` for date calculations or subqueries that are not supported by SQLite.
33
+ - Return only the SQL query, no explanation.
34
+
35
+ DDL statements:
36
+ {create_table_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
37
+
38
+ The following SQL query best answers the question {user_question}:
39
+ sql
40
+
41
+ """
42
+ )
43
+
44
+ chain = LLMChain(llm=llm, prompt=prompt)
45
+
46
+ session_config = {
47
+ "ddl": None,
48
+ "db_path": None,
49
+ "chat_history": []
50
+ }
51
+
52
+ @app.get("/", response_class=HTMLResponse)
53
+ def index(request: Request):
54
+ return templates.TemplateResponse("chat.html", {"request": request, "ddl_mode": True, "chat_history": []})
55
+
56
+ @app.post("/setup", response_class=HTMLResponse)
57
+ async def setup(request: Request, ddl: str = Form(...), db_path: str = Form(...)):
58
+ session_config["ddl"] = ddl
59
+ session_config["db_path"] = db_path
60
+ session_config["chat_history"] = []
61
+ return templates.TemplateResponse("chat.html", {"request": request, "ddl_mode": False, "chat_history": session_config["chat_history"]})
62
+
63
+ @app.post("/ask", response_class=HTMLResponse)
64
+ async def ask(request: Request, user_question: str = Form(...)):
65
+ generated_sql = chain.invoke({
66
+ "user_question": user_question,
67
+ "create_table_statements": session_config["ddl"]
68
+ })
69
+
70
+ result_rows = []
71
+ error = None
72
+
73
+ try:
74
+ conn = sqlite3.connect(session_config["db_path"])
75
+ cur = conn.cursor()
76
+ cur.execute(generated_sql)
77
+ result_rows = cur.fetchall()
78
+ except sqlite3.Error as e:
79
+ error = str(e)
80
+ finally:
81
+ if 'cur' in locals():
82
+ cur.close()
83
+ if 'conn' in locals():
84
+ conn.close()
85
+
86
+ session_config["chat_history"].append({
87
+ "question": user_question,
88
+ "sql": generated_sql,
89
+ "result": result_rows,
90
+ "error": error
91
+ })
92
+
93
+ return templates.TemplateResponse("chat.html", {
94
+ "request": request,
95
+ "ddl_mode": False,
96
+ "chat_history": session_config["chat_history"]
97
+ })
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ jinja2
4
+ langchain
5
+ langchain-community
6
+ ollama
7
+ python-multipart