sushruthsam commited on
Commit
ea3fea4
Β·
verified Β·
1 Parent(s): 9c386dd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ import gradio as gr
3
+ import sqlparse
4
+ import torch
5
+
6
+ model_name = "defog/sqlcoder-7b"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+
9
+ # Update the model loading process with potential disk offloading
10
+ model = AutoModelForCausalLM.from_pretrained(
11
+ model_name,
12
+ trust_remote_code=True,
13
+ torch_dtype=torch.float16, # Use reduced precision
14
+ device_map="auto", # Automatically distribute model layers
15
+ use_cache=True,
16
+ # Specify an offload folder if your setup requires offloading to disk
17
+ offload_folder="text_to_sql_defog_7b/offfolder", # Uncomment and set path as necessary
18
+ offload_state_dict=True, # Uncomment if offloading state dict is needed
19
+ )
20
+
21
+ def generate_response(prompt):
22
+ # Ensure the tokenizer and model are on the correct device
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ model.to(device)
25
+
26
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
27
+ generated_ids = model.generate(
28
+ **inputs,
29
+ num_return_sequences=1,
30
+ max_new_tokens=400,
31
+ do_sample=False,
32
+ num_beams=1,
33
+ )
34
+
35
+ outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
36
+ formatted_sql = sqlparse.format(outputs[0], reindent=True)
37
+
38
+ return formatted_sql
39
+
40
+ iface = gr.Interface(
41
+ fn=generate_response,
42
+ inputs=gr.Textbox(lines=7, label="Input Prompt", placeholder="Enter your prompt here..."),
43
+ outputs=gr.Textbox(label="Generated SQL"),
44
+ title="SQL Query Generator",
45
+ description="Generates SQL queries based on the provided natural language prompt. Powered by the 'defog/sqlcoder-7b' model."
46
+ )
47
+
48
+ iface.launch(share=True)