Mustehson committed on
Commit
2bcd76f
·
1 Parent(s): f603f74
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +51 -4
  3. requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ app2.py
app.py CHANGED
@@ -2,12 +2,15 @@ import os
2
  import torch
3
  import duckdb
4
  import spaces
 
5
  import gradio as gr
6
  import pandas as pd
 
 
 
 
7
  from langchain_huggingface.llms import HuggingFacePipeline
8
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
9
- from langsmith import traceable
10
- from langchain import hub
11
 
12
  # Height of the Tabs Text Area
13
  TAB_LINES = 8
@@ -16,8 +19,8 @@ TAB_LINES = 8
16
  #----------CONNECT TO DATABASE----------
17
  md_token = os.getenv('MD_TOKEN')
18
  conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
19
-
20
  #---------------------------------------
 
21
  if torch.cuda.is_available():
22
  device = torch.device("cuda")
23
  print(f"Using GPU: {torch.cuda.get_device_name(device)}")
@@ -26,6 +29,25 @@ else:
26
  print("Using CPU")
27
  #---------------------------------------
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  #-------LOAD HUGGINGFACE PIPELINE-------
30
  tokenizer = AutoTokenizer.from_pretrained("motherduckdb/DuckDB-NSQL-7B-v0.1")
31
 
@@ -46,7 +68,9 @@ hf = HuggingFacePipeline(pipeline=pipe)
46
  prompt = hub.pull("sql-agent-prompt")
47
  #---------------------------------------
48
 
49
-
 
 
50
 
51
  #--------------ALL UTILS----------------
52
  # Get Databases
@@ -91,6 +115,20 @@ def get_prompt(schema, query_input):
91
  def generate_sql(prompt):
92
  result = hf.invoke(prompt)
93
  return result.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  #---------------------------------------
95
 
96
 
@@ -108,6 +146,7 @@ def text2sql(table, query_input):
108
  print(f'Schema Generated...')
109
  prompt = get_prompt(schema, query_input)
110
  print(f'Prompt Generated...')
 
111
  try:
112
  print(f'Generating SQL... {model.device}')
113
  result = generate_sql(prompt)
@@ -119,6 +158,14 @@ def text2sql(table, query_input):
119
  generated_query: "",
120
  result_output:pd.DataFrame([{"error": f"❌ Unable to get the SQL query based on the text. {e}"}])
121
  }
 
 
 
 
 
 
 
 
122
  try:
123
  query_result = conn.sql(result).df()
124
 
 
2
  import torch
3
  import duckdb
4
  import spaces
5
+ import lancedb
6
  import gradio as gr
7
  import pandas as pd
8
+ import pyarrow as pa
9
+ from langchain import hub
10
+ from langsmith import traceable
11
+ from sentence_transformers import SentenceTransformer
12
  from langchain_huggingface.llms import HuggingFacePipeline
13
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
 
 
14
 
15
  # Height of the Tabs Text Area
16
  TAB_LINES = 8
 
19
  #----------CONNECT TO DATABASE----------
20
  md_token = os.getenv('MD_TOKEN')
21
  conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
 
22
  #---------------------------------------
23
+
24
  if torch.cuda.is_available():
25
  device = torch.device("cuda")
26
  print(f"Using GPU: {torch.cuda.get_device_name(device)}")
 
29
  print("Using CPU")
30
  #---------------------------------------
31
 
32
+ #--------------LanceDB-------------
33
+
34
+ lance_db = lancedb.connect(
35
+ uri=os.getenv('lancedb_uri'),
36
+ api_key=os.getenv('lancedb_api_key'),
37
+ region=os.getenv('lancedb_region')
38
+ )
39
+
40
+ lance_schema = pa.schema([
41
+ pa.field("vector", pa.list_(pa.float32())),
42
+ pa.field("sql-query", pa.utf8())
43
+ ])
44
+
45
+ try:
46
+ table = lance_db.create_table(name="SQL-Queries", schema=lance_schema)
47
+ except:
48
+ table = lance_db.open_table(name="SQL-Queries")
49
+ #---------------------------------------
50
+
51
  #-------LOAD HUGGINGFACE PIPELINE-------
52
  tokenizer = AutoTokenizer.from_pretrained("motherduckdb/DuckDB-NSQL-7B-v0.1")
53
 
 
68
  prompt = hub.pull("sql-agent-prompt")
69
  #---------------------------------------
70
 
71
+ #-----LOAD EMBEDDING MODEL-----
72
+ embedding_model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
73
+ #---------------------------------------
74
 
75
  #--------------ALL UTILS----------------
76
  # Get Databases
 
115
  def generate_sql(prompt):
116
  result = hf.invoke(prompt)
117
  return result.strip()
118
+ @spaces.GPU(duration=10)
119
+ def embed_query(sql_query):
120
+ print(f'Creating Emebeddings {sql_query}')
121
+ if sql_query is not None:
122
+ embeddings = embedding_model.encode(sql_query, normalize_embeddings=True).tolist()
123
+ return embeddings
124
+
125
+ def log2lancedb(embeddings, sql_query):
126
+ data = [{
127
+ "sql-query": sql_query,
128
+ "vector": embeddings
129
+ }]
130
+ table.add(data)
131
+ print(f'Added to Lance DB.')
132
  #---------------------------------------
133
 
134
 
 
146
  print(f'Schema Generated...')
147
  prompt = get_prompt(schema, query_input)
148
  print(f'Prompt Generated...')
149
+
150
  try:
151
  print(f'Generating SQL... {model.device}')
152
  result = generate_sql(prompt)
 
158
  generated_query: "",
159
  result_output:pd.DataFrame([{"error": f"❌ Unable to get the SQL query based on the text. {e}"}])
160
  }
161
+
162
+ try:
163
+ embeddings = embed_query(result)
164
+ log2lancedb(embeddings, result)
165
+ except Exception as e:
166
+ print("Error Generating and Logging Embeddings...")
167
+ print(e)
168
+
169
  try:
170
  query_result = conn.sql(result).df()
171
 
requirements.txt CHANGED
@@ -4,4 +4,7 @@ transformers==4.44.2
4
  duckdb==1.1.1
5
  langsmith==0.1.135
6
  langchain==0.3.4
 
 
 
7
  langchain-huggingface
 
4
  duckdb==1.1.1
5
  langsmith==0.1.135
6
  langchain==0.3.4
7
+ lancedb==0.15.0
8
+ sentence-transformers==3.2.1
9
+ pyarrow==17.0.0
10
  langchain-huggingface