hero2002 commited on
Commit
438b9ae
·
1 Parent(s): 4a1b4ed

Add application file

Browse files
Files changed (1) hide show
  1. app.py +25 -0
app.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from typing import List
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+
5
+ tokenizer = AutoTokenizer.from_pretrained("juierror/text-to-sql-with-table-schema")
6
+ model = AutoModelForSeq2SeqLM.from_pretrained("juierror/text-to-sql-with-table-schema")
7
+
8
+ t = st.text_input('enter tables')
9
+ q = st.text_input('enter question')
10
+
11
+ def prepare_input(question: str, table: str):
12
+ table_prefix = "table:"
13
+ question_prefix = "question:"
14
+ inputs = f"{question_prefix} {question} {table_prefix} {table}"
15
+ input_ids = tokenizer(inputs, max_length=700, return_tensors="pt").input_ids
16
+ return input_ids
17
+
18
+ def inference(question: str, table: str) -> str:
19
+ input_data = prepare_input(question=question, table=table)
20
+ input_data = input_data.to(model.device)
21
+ outputs = model.generate(inputs=input_data, num_beams=10, top_k=10, max_length=700)
22
+ result = tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
23
+ return result
24
+
25
+ st.write(inference(q,t))