Mustehson committed on
Commit
4aef500
·
1 Parent(s): 317a551
Files changed (4) hide show
  1. README.md +1 -3
  2. app.py +156 -49
  3. logo.png +0 -0
  4. requirements.txt +10 -1
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Datajoi Sql Agent
3
- emoji: 💬
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
@@ -9,5 +9,3 @@ app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
12
-
13
- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
 
1
  ---
2
  title: Datajoi Sql Agent
3
+ emoji: 🐣
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
 
9
  pinned: false
10
  license: mit
11
  ---
 
 
app.py CHANGED
@@ -1,63 +1,170 @@
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
 
 
 
 
 
 
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
 
 
 
 
 
 
 
8
 
 
 
 
 
9
 
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
 
26
- messages.append({"role": "user", "content": message})
 
 
 
 
 
27
 
28
- response = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
 
39
- response += token
40
- yield response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
58
- ],
59
- )
 
 
 
 
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  if __name__ == "__main__":
63
- demo.launch()
 
 
1
+ import os
2
+ import duckdb
3
+ import spaces
4
  import gradio as gr
5
+ import pandas as pd
6
+ from llama_cpp import Llama
7
+ # from dotenv import load_dotenv
8
+ from huggingface_hub import hf_hub_download
9
+ # load_dotenv()
10
# Number of visible text lines in each read-only output tab.
TAB_LINES = 8

# MotherDuck token, read from the MD_TOKEN environment variable
# (None when the variable is not set).
md_token = os.environ.get('MD_TOKEN')

# One module-level DuckDB connection to the "my_db" MotherDuck database;
# the helper functions in this file all run their queries through it.
motherduck_dsn = f"md:my_db?motherduck_token={md_token}"
conn = duckdb.connect(motherduck_dsn)
16
 
17
+ # Custom CSS styling
18
+ custom_css = """
19
+ .gradio-container {
20
+ background-color: #f0f4f8;
21
+ }
22
+ .logo {
23
+ max-width: 200px;
24
+ margin: 20px auto;
25
+ display: block;
26
+ }
27
+ .gr-button {
28
+ background-color: #4a90e2 !important;
29
+ }
30
+ .gr-button:hover {
31
+ background-color: #3a7bc8 !important;
32
+ }
33
  """
34
print('Loading Model...')

# Download the quantized DuckDB-NSQL GGUF file from the Hugging Face Hub into
# the current directory; hf_hub_download returns the local file path.
model_file = hf_hub_download(
    repo_id="motherduckdb/DuckDB-NSQL-7B-v0.1-GGUF",
    filename="DuckDB-NSQL-7B-v0.1-q8_0.gguf",
    local_dir='.',
)

# Load the model through llama-cpp-python.
# NOTE(review): n_gpu_layers=-1 presumably offloads every layer to the GPU --
# confirm against the llama-cpp-python documentation.
llama = Llama(
    model_path=model_file,
    n_ctx=2048,
    n_gpu_layers=-1,
)

print('Model Loaded...')
46
 
47
+ # Get Databases
48
def get_databases():
    """Return the database names reported by ``PRAGMA show_databases``.

    The pragma runs on the module-level ``conn``; the first column of every
    returned row is collected into a list.
    """
    rows = conn.execute("PRAGMA show_databases").fetchall()
    return [name for (name, *_rest) in rows]
51
 
52
+ # Get Tables
53
def _quote_identifier(name):
    """Return *name* as a double-quoted SQL identifier (embedded quotes doubled)."""
    return '"' + str(name).replace('"', '""') + '"'


def get_tables(database):
    """Switch the shared connection to *database* and list its tables.

    Args:
        database: Name of the database to make current. ``None`` or an empty
            string (e.g. a cleared dropdown) yields an empty list without
            touching the connection.

    Returns:
        A list with the first column of every ``SHOW TABLES`` row.

    Side effects:
        Executes ``USE`` on the module-level ``conn``, so later unqualified
        queries on that connection resolve against *database*.
    """
    if not database:
        return []
    # The name cannot be bound as a parameter in USE, so quote it as an
    # identifier instead of splicing the raw text into the statement.
    conn.execute(f"USE {_quote_identifier(database)}")
    tables = conn.execute("SHOW TABLES").fetchall()
    return [table[0] for table in tables]
 
 
 
 
57
 
58
+ # Update Tables
59
def update_tables(selected_db):
    """Refresh the tables dropdown after the database selection changes.

    Args:
        selected_db: Database name chosen in the database dropdown, or
            ``None`` when the selection was cleared.

    Returns:
        A ``gr.update`` payload carrying the new choices. The selected value
        is reset to ``None`` so a table picked for the previous database does
        not stay selected against the new one.
    """
    if not selected_db:
        return gr.update(choices=[], value=None)
    tables = get_tables(selected_db)
    return gr.update(choices=tables, value=None)
 
62
 
63
+ # Get Schema
64
def get_schema(table):
    """Return the ``CREATE TABLE`` statement DuckDB stores for *table*.

    Args:
        table: Name of a table reachable from the shared connection.

    Returns:
        The ``sql`` column of the matching ``duckdb_tables()`` row.

    Raises:
        ValueError: If *table* is empty or no catalog row matches it.
        duckdb.Error: If the probe query cannot resolve the table.
    """
    if not table:
        raise ValueError("No table selected.")

    # Probe query kept from the original implementation: its result is
    # discarded, but it raises when the table cannot be read.
    # NOTE(review): it may also exist to force MotherDuck to load the table --
    # confirm before removing. Single quotes are doubled so the name cannot
    # terminate the literal.
    escaped = str(table).replace("'", "''")
    conn.execute(f"SELECT * FROM '{escaped}' LIMIT 1;")

    # The table name is bound as a parameter. duckdb_tables() spans every
    # attached database, so rows from the current database are preferred when
    # several databases hold a table with the same name.
    row = conn.execute(
        "SELECT sql FROM duckdb_tables() WHERE table_name = ? "
        "ORDER BY (database_name = current_database()) DESC LIMIT 1;",
        [table],
    ).fetchone()
    if row is None:
        raise ValueError(f"No CREATE statement found for table {table!r}.")
    return row[0]
69
 
70
+ # Get Prompt
71
def get_prompt(schema, query_input):
    """Build the text-to-SQL prompt for the DuckDB-NSQL model.

    The prompt is assembled from explicit lines rather than a triple-quoted
    f-string inside the function body, so its whitespace never depends on how
    the source file happens to be indented.

    Args:
        schema: DDL text describing the table the query will run on.
        query_input: The user's natural-language question.

    Returns:
        The prompt string: a leading newline, the instruction / input /
        question sections flush-left, and a trailing newline after the
        response header.
    """
    prompt_lines = [
        "",
        "### Instruction:",
        "Your task is to generate valid duckdb SQL to answer the following question.",
        "### Input:",
        "Here is the database schema that the SQL query will run on:",
        f"{schema}",
        "",
        "### Question:",
        f"{query_input}",
        "### Response (use duckdb shorthand if possible):",
        "",
    ]
    return "\n".join(prompt_lines)
84
 
85
+ # Generate SQL
86
@spaces.GPU
def generate_sql(prompt):
    """Run the loaded model on *prompt* and return the generated text.

    The module-level ``llama`` object is called with ``temperature=0.1`` and
    ``max_tokens=1000``; the ``text`` field of the first returned choice is
    handed back unchanged.
    NOTE(review): the ``spaces.GPU`` decorator presumably requests a GPU for
    the duration of the call on Hugging Face Spaces -- confirm.
    """
    completion = llama(prompt, temperature=0.1, max_tokens=1000)
    first_choice = completion["choices"][0]
    return first_choice["text"]
 
 
 
90
 
91
def _text2sql_outputs(schema, prompt, query, result):
    """Map the four output components to the values they should display."""
    return {
        table_schema: schema,
        input_prompt: prompt,
        generated_query: query,
        result_output: result,
    }


def _error_frame(message, error=None):
    """Build the one-row DataFrame shown in the result tab when a step fails."""
    text = f"❌ {message}" if error is None else f"❌ {message} {error}"
    return pd.DataFrame([{"error": text}])


def text2sql(table, query_input):
    """Turn a natural-language question into SQL, run it, and report each stage.

    Args:
        table: Name of the selected table, or ``None`` when nothing is selected.
        query_input: The user's question in plain text.

    Returns:
        A dict keyed by the output components ``table_schema``,
        ``input_prompt``, ``generated_query`` and ``result_output``. On
        failure the result component holds a one-row DataFrame with an
        ``error`` column, and the other values show how far processing got.
    """
    if not table:
        # There is no exception to report here, only the missing selection.
        return _text2sql_outputs(
            "", "", "",
            _error_frame("Unable to get the SQL query based on the text. No table is selected."),
        )
    if not query_input or not query_input.strip():
        return _text2sql_outputs(
            "", "", "",
            _error_frame("Unable to get the SQL query based on the text. The text query is empty."),
        )

    # This is the UI boundary, so each stage catches broadly and reports the
    # failure in the result tab instead of letting it surface as a raw error.
    try:
        schema = get_schema(table)
    except Exception as e:
        return _text2sql_outputs(
            "", "", "",
            _error_frame("Unable to read the schema of the selected table.", e),
        )

    prompt = get_prompt(schema, query_input)

    try:
        result = generate_sql(prompt)
    except Exception as e:
        return _text2sql_outputs(
            schema, prompt, "",
            _error_frame("Unable to get the SQL query based on the text.", e),
        )

    # NOTE(security): the model output is executed as-is with the privileges
    # of the shared connection; restrict the MotherDuck token accordingly.
    try:
        relation = conn.sql(result)
        # NOTE(review): conn.sql() is expected to return None for statements
        # that produce no result set -- confirm against the DuckDB Python API.
        query_result = relation.df() if relation is not None else pd.DataFrame()
    except Exception as e:
        return _text2sql_outputs(
            schema, prompt, result,
            _error_frame("Unable to run the generated SQL query.", e),
        )

    # The shared connection is intentionally left open: it is module-level and
    # every later request (table listing, schema lookup, queries) reuses it.
    return _text2sql_outputs(schema, prompt, result, query_result)
130
 
131
+ # Load Databases Names
132
+ databases = get_databases()
133
+
134
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
135
+ gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
136
+
137
+ gr.Markdown("""
138
+ <div style='text-align: center;'>
139
+ <strong style='font-size: 36px;'>Datajoi SQL Agent</strong>
140
+ <br>
141
+ <span style='font-size: 20px;'>Generate and Run SQL queries based on a given text for the dataset.</span>
142
+ </div>
143
+ """)
144
+
145
+ with gr.Row():
146
+
147
+ with gr.Column(scale=1, variant='panel'):
148
+ database_dropdown = gr.Dropdown(choices=databases, label="Select Database", interactive=True)
149
+ tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
150
+
151
+ with gr.Column(scale=2):
152
+ query_input = gr.Textbox(lines=5, label="Text Query", placeholder="Enter your text query here...")
153
+ generate_query_button = gr.Button("Run Query", variant="primary")
154
 
155
+ with gr.Tabs():
156
+ with gr.Tab("Result"):
157
+ result_output = gr.DataFrame(label="Query Results", value=[], interactive=False)
158
+ with gr.Tab("SQL Query"):
159
+ generated_query = gr.Textbox(lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False)
160
+ with gr.Tab("Prompt"):
161
+ input_prompt = gr.Textbox(lines=TAB_LINES, label="Input Prompt", value="", interactive=False)
162
+ with gr.Tab("Schema"):
163
+ table_schema = gr.Textbox(lines=TAB_LINES, label="Schema", value="", interactive=False)
164
+
165
+ database_dropdown.change(update_tables, inputs=database_dropdown, outputs=tables_dropdown)
166
+ generate_query_button.click(text2sql, inputs=[tables_dropdown, query_input], outputs=[table_schema, input_prompt, generated_query, result_output])
167
 
168
  if __name__ == "__main__":
169
+ demo.launch()
170
+
logo.png ADDED
requirements.txt CHANGED
@@ -1 +1,10 @@
1
- huggingface_hub==0.22.2
 
 
 
 
 
 
 
 
 
 
1
+ gradio_huggingfacehub_search==0.0.7
2
+ pandas<=2.1.4
3
+ numpy<=1.26.4
4
+ httpx
5
+ huggingface_hub
6
+ python-dotenv
7
+ duckdb
8
+ scikit-build-core
9
+ duckdb
10
+ https://github.com/abetlen/llama-cpp-python/releases/download/v0.2.82-cu124/llama_cpp_python-0.2.82-cp310-cp310-linux_x86_64.whl