Mustehson committed on
Commit c27c631 · 1 Parent(s): c0e9411
Using Transformers

Files changed (2):
  1. app.py +57 -49
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,53 +1,43 @@
  import os
+ import torch
  import duckdb
  import spaces
  import gradio as gr
  import pandas as pd
- from llama_cpp import Llama
- # from dotenv import load_dotenv
- from huggingface_hub import hf_hub_download
- # load_dotenv()
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+
+
+
  # Height of the Tabs Text Area
  TAB_LINES = 8
  # Load Token
  md_token = os.getenv('MD_TOKEN')
+
+ print('Connecting to DB...')
  # Connect to DB
  conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}")

- # Custom CSS styling
- custom_css = """
- .gradio-container {
-     background-color: #f0f4f8;
- }
- .logo {
-     max-width: 200px;
-     margin: 20px auto;
-     display: block;
- }
- .gr-button {
-     background-color: #4a90e2 !important;
- }
- .gr-button:hover {
-     background-color: #3a7bc8 !important;
- }
- """
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+     print(f"Using GPU: {torch.cuda.get_device_name(device)}")
+ else:
+     device = torch.device("cpu")
+     print("Using CPU")
+
  print('Loading Model...')
- # Load Model
- # @spaces.GPU
- # def load_model():
- llama = Llama(
-     model_path=hf_hub_download(
-         repo_id="motherduckdb/DuckDB-NSQL-7B-v0.1-GGUF",
-         filename="DuckDB-NSQL-7B-v0.1-q8_0.gguf",
-         local_dir='.'
-     ),
-     n_ctx=2048,
-     n_gpu_layers=0
- )
- # return llama
-
- # llama = load_model()
+
+ tokenizer = AutoTokenizer.from_pretrained("motherduckdb/DuckDB-NSQL-7B-v0.1")
+
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_compute_dtype=torch.bfloat16,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type= "nf4")
+
+ model = AutoModelForCausalLM.from_pretrained("motherduckdb/DuckDB-NSQL-7B-v0.1", quantization_config=quantization_config,
+                                              device_map="auto", torch_dtype=torch.bfloat16)
  print('Model Loaded...')
+ print(f'Model Device: {model.device}')

  # Get Databases
  def get_databases():
@@ -76,7 +66,7 @@ def get_schema(table):
  def get_prompt(schema, query_input):
      text = f"""
  ### Instruction:
- Your task is to generate valid duckdb SQL to answer the following question.
+ Your task is to generate valid duckdb SQL query to answer the following question.
  ### Input:
  Here is the database schema that the SQL query will run on:
  {schema}
@@ -88,12 +78,7 @@ def get_prompt(schema, query_input):
      return text

  # Generate SQL
- # @spaces.GPU
- def generate_sql(prompt):
-
-     result = llama(prompt, temperature=0.1, max_tokens=1000)
-     return result["choices"][0]["text"]
-
+ @spaces.GPU
  def text2sql(table, query_input):
      if table is None:
          return {
@@ -102,11 +87,18 @@ def text2sql(table, query_input):
              generated_query: "",
              result_output:pd.DataFrame([{"error": f"❌ Unable to get the SQL query based on the text. {e}"}])
          }
+
      schema = get_schema(table)
+     print(f'Schema Generated...')
      prompt = get_prompt(schema, query_input)
-
+     print(f'Prompt Generated...')
      try:
-         result = generate_sql(prompt)
+         print(f'Generating SQL... {model.device}')
+         input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+         input_token_len = input_ids.shape[1]
+         outputs = model.generate(input_ids.to(model.device), max_new_tokens=1024)
+         result = tokenizer.decode(outputs[0][input_token_len:], skip_special_tokens=True)
+         print('SQL Generated...')
      except Exception as e:
          return {
              table_schema: schema,
@@ -116,7 +108,6 @@ def text2sql(table, query_input):
          }
      try:
          query_result = conn.sql(result).df()
-         conn.close()

      except Exception as e:
          return {
@@ -126,7 +117,6 @@ def text2sql(table, query_input):
              result_output:pd.DataFrame([{"error": f"❌ Unable to get the SQL query based on the text. {e}"}])
          }

-     conn.close()
      return {
          table_schema: schema,
          input_prompt: prompt,
@@ -137,6 +127,24 @@ def text2sql(table, query_input):
  # Load Databases Names
  databases = get_databases()

+ # Custom CSS styling
+ custom_css = """
+ .gradio-container {
+     background-color: #f0f4f8;
+ }
+ .logo {
+     max-width: 200px;
+     margin: 20px auto;
+     display: block;
+ }
+ .gr-button {
+     background-color: #4a90e2 !important;
+ }
+ .gr-button:hover {
+     background-color: #3a7bc8 !important;
+ }
+ """
+
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
      gr.Image("logo.png", label=None, show_label=False, container=False, height=100)

@@ -168,8 +176,8 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
      with gr.Tab("Schema"):
          table_schema = gr.Textbox(lines=TAB_LINES, label="Schema", value="", interactive=False)

-     database_dropdown.change(update_tables, inputs=database_dropdown, outputs=tables_dropdown)
-     generate_query_button.click(text2sql, inputs=[tables_dropdown, query_input], outputs=[table_schema, input_prompt, generated_query, result_output])
+     database_dropdown.change(update_tables, inputs=database_dropdown, outputs=tables_dropdown)
+     generate_query_button.click(text2sql, inputs=[tables_dropdown, query_input], outputs=[table_schema, input_prompt, generated_query, result_output])

  if __name__ == "__main__":
      demo.launch()
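For reference, here is a minimal standalone sketch of the inference path this commit switches to: Transformers with 4-bit bitsandbytes quantization, replacing the llama.cpp GGUF build. It mirrors the lines added to app.py; the prompt value is a placeholder (the real one is assembled by get_prompt() from the selected table's schema and the user question), and it assumes a CUDA-capable machine with bitsandbytes installed and access to the model on the Hugging Face Hub.

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

    model_id = "motherduckdb/DuckDB-NSQL-7B-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # 4-bit NF4 quantization, same settings as the BitsAndBytesConfig added in app.py
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        ),
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    prompt = "..."  # placeholder: app.py builds this from the table schema and the question
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    outputs = model.generate(input_ids.to(model.device), max_new_tokens=1024)
    # Drop the prompt tokens so only the newly generated SQL text remains
    sql = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
    print(sql)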
requirements.txt CHANGED
@@ -5,5 +5,7 @@ huggingface_hub
  python-dotenv
  scikit-build-core
  duckdb
- https://github.com/abetlen/llama-cpp-python/releases/download/v0.2.82-cu124/llama_cpp_python-0.2.82-cp310-cp310-linux_x86_64.whl
  gradio
+ transformers
+ bitsandbytes
+ torch
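The net dependency change is that the pinned llama-cpp-python CUDA wheel is dropped and transformers, bitsandbytes and torch are added (unpinned, as in the file). To pull in just the new packages locally, assuming the rest of requirements.txt is already installed, something like the following should work:

    pip install transformers bitsandbytes torch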