dsivakumar committed on
Commit
73a545c
·
1 Parent(s): 5c0597d

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +33 -0
README.md ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #import transformers
2
+ from transformers import (
3
+ T5ForConditionalGeneration,
4
+ T5Tokenizer,
5
+ )
6
+
7
# Fetch the tokenizer and the model weights for the same checkpoint id.
tokenizer = T5Tokenizer.from_pretrained('dsivakumar/text2sql')
model = T5ForConditionalGeneration.from_pretrained('dsivakumar/text2sql')
10
+
11
+ #predict function
12
def get_sql(query, tokenizer, model):
    """Generate SQL text for an English question.

    Args:
        query: Natural-language question. Runs of whitespace are collapsed
            to single spaces before encoding.
        tokenizer: Object providing ``batch_encode_plus`` and ``decode``
            (the tokenizer loaded for the checkpoint).
        model: Object providing ``generate`` (the model loaded for the
            checkpoint).

    Returns:
        A list holding one decoded string per sequence returned by
        ``model.generate``.
    """
    # Prepend the task prefix, then normalize whitespace by split/re-join.
    source_text = ' '.join(("English to SQL: " + query).split())

    # ``padding="max_length"`` already requests fixed-length padding; the
    # deprecated ``pad_to_max_length`` flag is not passed alongside it.
    source = tokenizer.batch_encode_plus(
        [source_text],
        max_length=128,
        truncation=True,
        padding="max_length",
        return_tensors='pt',
    )
    source_ids = source['input_ids']
    source_mask = source['attention_mask']

    # Tensor.long() casts to int64 without requiring the ``torch`` module
    # name to be in scope at this point.
    generated_ids = model.generate(
        input_ids=source_ids.long(),
        attention_mask=source_mask.long(),
        max_length=150,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
    )

    return [
        tokenizer.decode(
            g,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for g in generated_ids
    ]
29
+
30
# Example: run one question through get_sql and print the returned list.
query = "Show me the average age of wines in Italy by provinces"
sql = get_sql(query, tokenizer, model)
print(sql)