Canstralian committed on
Commit
cc1cc95
·
verified ·
1 Parent(s): 975ac6b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -57
app.py CHANGED
@@ -1,75 +1,82 @@
1
- ## https://www.kaggle.com/code/unravel/fine-tuning-of-a-sql-model
2
-
3
- import spaces
4
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
5
- import gradio as gr
6
  import torch
 
7
  from transformers.utils import logging
8
- from example_queries import small_query, long_query
9
 
 
10
  logging.set_verbosity_info()
11
  logger = logging.get_logger("transformers")
12
 
13
- model_name='t5-small'
14
- tokenizer = AutoTokenizer.from_pretrained(model_name)
15
- original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
16
-
17
- ft_model_name="daljeetsingh/sql_ft_t5small_kag" #"cssupport/t5-small-awesome-text-to-sql"
18
- ft_model = AutoModelForSeq2SeqLM.from_pretrained(ft_model_name, torch_dtype=torch.bfloat16)
19
 
20
- original_model.to('cuda')
21
- ft_model.to('cuda')
 
 
22
 
23
- @spaces.GPU
24
- def translate_text(text):
25
- prompt = f"{text}"
26
- inputs = tokenizer(prompt, return_tensors='pt')
27
- inputs = inputs.to('cuda')
28
 
29
- try:
30
- output = tokenizer.decode(
31
- original_model.generate(
32
- inputs["input_ids"],
33
- max_new_tokens=200,
34
- )[0],
 
 
 
 
 
 
 
35
  skip_special_tokens=True
36
  )
37
- ft_output = tokenizer.decode(
38
- ft_model.generate(
39
- inputs["input_ids"],
40
- max_new_tokens=200,
41
- )[0],
 
 
 
42
  skip_special_tokens=True
43
  )
44
- return [output, ft_output]
45
- except Exception as e:
46
- return f"Error: {str(e)}"
47
 
 
 
 
 
48
 
49
- with gr.Blocks() as demo:
50
- with gr.Row():
51
- with gr.Column():
52
- prompt = gr.Textbox(
53
- value=small_query,
54
- lines=8,
55
- placeholder="Enter prompt...",
56
- label="Prompt"
57
- )
58
- submit_btn = gr.Button(value="Generate")
59
- with gr.Column():
60
- orig_output = gr.Textbox(label="OriginalModel", lines=2)
61
- ft_output = gr.Textbox(label="FTModel", lines=8)
62
 
63
- submit_btn.click(
64
- translate_text, inputs=[prompt], outputs=[orig_output, ft_output], api_name=False
65
- )
66
- examples = gr.Examples(
67
- examples=[
68
- [small_query],
69
- [long_query],
70
- ],
71
- inputs=[prompt],
72
- )
73
 
74
- demo.launch(show_api=False, share=True, debug=True)
 
 
 
 
 
 
 
 
 
75
 
 
 
 
 
 
 
 
1
+ import streamlit as st
 
 
 
 
2
  import torch
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
  from transformers.utils import logging
 
5
 
6
+ # Set up logging
7
  logging.set_verbosity_info()
8
  logger = logging.get_logger("transformers")
9
 
10
+ # Model names
11
+ original_model_name = 't5-small'
12
+ fine_tuned_model_name = 'daljeetsingh/sql_ft_t5small_kag'
 
 
 
13
 
14
+ # Load models and tokenizer
15
+ tokenizer = AutoTokenizer.from_pretrained(original_model_name)
16
+ original_model = AutoModelForSeq2SeqLM.from_pretrained(original_model_name, torch_dtype=torch.bfloat16)
17
+ fine_tuned_model = AutoModelForSeq2SeqLM.from_pretrained(fine_tuned_model_name, torch_dtype=torch.bfloat16)
18
 
19
+ # Move models to GPU
20
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
21
+ original_model.to(device)
22
+ fine_tuned_model.to(device)
 
23
 
24
+ def generate_sql_query(prompt):
25
+ """
26
+ Generate SQL queries using both the original and fine-tuned models.
27
+ """
28
+ inputs = tokenizer(prompt, return_tensors='pt').to(device)
29
+ try:
30
+ # Generate output from the original model
31
+ original_output = original_model.generate(
32
+ inputs["input_ids"],
33
+ max_new_tokens=200,
34
+ )
35
+ original_sql = tokenizer.decode(
36
+ original_output[0],
37
  skip_special_tokens=True
38
  )
39
+
40
+ # Generate output from the fine-tuned model
41
+ fine_tuned_output = fine_tuned_model.generate(
42
+ inputs["input_ids"],
43
+ max_new_tokens=200,
44
+ )
45
+ fine_tuned_sql = tokenizer.decode(
46
+ fine_tuned_output[0],
47
  skip_special_tokens=True
48
  )
 
 
 
49
 
50
+ return original_sql, fine_tuned_sql
51
+ except Exception as e:
52
+ logger.error(f"Error: {str(e)}")
53
+ return f"Error: {str(e)}", None
54
 
55
+ # Streamlit App Interface
56
+ st.title("SQL Query Generation")
57
+ st.markdown("This application generates SQL queries based on your input prompt.")
 
 
 
 
 
 
 
 
 
 
58
 
59
+ # Input prompt
60
+ prompt = st.text_area(
61
+ "Enter your prompt here...",
62
+ value="Find all employees who joined after 2020.",
63
+ height=150
64
+ )
 
 
 
 
65
 
66
+ # Generate button
67
+ if st.button("Generate"):
68
+ if prompt:
69
+ original_sql, fine_tuned_sql = generate_sql_query(prompt)
70
+ st.subheader("Original Model Output")
71
+ st.text_area("Original SQL Query", value=original_sql, height=200)
72
+ st.subheader("Fine-Tuned Model Output")
73
+ st.text_area("Fine-Tuned SQL Query", value=fine_tuned_sql, height=200)
74
+ else:
75
+ st.warning("Please enter a prompt to generate SQL queries.")
76
 
77
+ # Examples
78
+ st.sidebar.title("Examples")
79
+ st.sidebar.markdown("""
80
+ - **Example 1**: Find all employees who joined after 2020.
81
+ - **Example 2**: Retrieve the names of customers who purchased product X in the last month.
82
+ """)