Mustehson committed
Commit cd66976 · Parent: 11dc9b2

Added Langsmith

Files changed (2)
  1. app.py +9 -8
  2. requirements.txt +3 -1
app.py CHANGED
@@ -4,8 +4,9 @@ import duckdb
 import spaces
 import gradio as gr
 import pandas as pd
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
-
+from langchain_huggingface.llms import HuggingFacePipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
+from langsmith import traceable
 
 
 # Height of the Tabs Text Area
@@ -36,6 +37,9 @@ quantization_config = BitsAndBytesConfig(
 
 model = AutoModelForCausalLM.from_pretrained("motherduckdb/DuckDB-NSQL-7B-v0.1", quantization_config=quantization_config,
                                              device_map="auto", torch_dtype=torch.bfloat16)
+
+pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024, return_full_text=False)
+hf = HuggingFacePipeline(pipeline=pipe)
 print('Model Loaded...')
 print(f'Model Device: {model.device}')
 
@@ -88,13 +92,10 @@ def get_prompt(schema, query_input):
     return text
 
 @spaces.GPU(duration=60)
+@traceable()
 def generate_sql(prompt):
-
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-    input_token_len = input_ids.shape[1]
-    outputs = model.generate(input_ids.to(model.device), max_new_tokens=1024)
-    result = tokenizer.decode(outputs[0][input_token_len:], skip_special_tokens=True)
-    return result
+    result = hf.invoke(prompt)
+    return result.strip()
 
 # Generate SQL
 def text2sql(table, query_input):
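The new @traceable() decorator only reports runs when LangSmith credentials are present in the environment; on a Hugging Face Space these would normally be set as repository secrets rather than in code. A minimal sketch of the assumed configuration (the key and project name below are placeholders, not values from this repo):

import os
from langsmith import traceable

# Assumed setup, not part of the commit: LangSmith reads these variables at runtime.
os.environ["LANGCHAIN_TRACING_V2"] = "true"                # enable tracing
os.environ["LANGCHAIN_API_KEY"] = "<your-langsmith-key>"   # placeholder secret
os.environ["LANGCHAIN_PROJECT"] = "duckdb-nsql-demo"       # hypothetical project name

@traceable()  # each call is recorded as a run, with inputs, output, and latency
def generate_sql(prompt: str) -> str:
    # Stand-in for the real model call in app.py.
    return "SELECT 1;"

print(generate_sql("how many rows does the table have?"))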
requirements.txt CHANGED
@@ -1,4 +1,6 @@
 accelerate
 bitsandbytes
 transformers
-duckdb
+duckdb
+langsmith
+langchain-huggingface
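For reference, the generation path the commit switches to (a transformers text-generation pipeline wrapped in LangChain's HuggingFacePipeline) can be exercised on its own. A rough sketch that substitutes a small placeholder model for DuckDB-NSQL-7B-v0.1 and skips the BitsAndBytesConfig quantization:

from transformers import pipeline
from langchain_huggingface.llms import HuggingFacePipeline

# Placeholder model so the sketch runs on CPU; app.py loads
# motherduckdb/DuckDB-NSQL-7B-v0.1 with a quantization config instead.
pipe = pipeline(
    "text-generation",
    model="gpt2",
    max_new_tokens=32,
    return_full_text=False,  # return only the completion, not the echoed prompt
)
hf = HuggingFacePipeline(pipeline=pipe)

# invoke() accepts the prompt string and returns the generated text,
# which app.py then strips of leading/trailing whitespace.
print(hf.invoke("-- list all tables\nSELECT").strip())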