anilajax commited on
Commit
2e751e0
·
verified ·
1 Parent(s): cdb0e6a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +40 -1
README.md CHANGED
@@ -15,4 +15,43 @@ tags:
15
  metrics:
16
  - accuracy
17
  - character
18
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  metrics:
16
  - accuracy
17
  - character
18
+ library_name: transformers
19
+ ---
20
+
21
+ Industry standard text to sql generation with high accuracy.
22
+
23
+ sample code to start with:
24
+
25
+ import torch
26
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
27
+
28
+ # Initialize the tokenizer from Hugging Face Transformers library
29
+ tokenizer = T5Tokenizer.from_pretrained('anilajax/text2sql_industry_standard')
30
+
31
+ # Load the model
32
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+ model = T5ForConditionalGeneration.from_pretrained('anilajax/text2sql_industry_standard')
34
+ model = model.to(device)
35
+ model.eval()
36
+
37
+ def generate_sql(input_prompt):
38
+ # Tokenize the input prompt
39
+ inputs = tokenizer(input_prompt, padding=True, truncation=True, return_tensors="pt").to(device)
40
+
41
+ # Forward pass
42
+ with torch.no_grad():
43
+ outputs = model.generate(**inputs, max_length=512)
44
+
45
+ # Decode the output IDs to a string (SQL query in this case)
46
+ generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
47
+
48
+ return generated_sql
49
+
50
+
51
+ input_prompt = "provide count of students where class = 10"
52
+
53
+ generated_sql = generate_sql(input_prompt)
54
+
55
+ print(f"The generated SQL query is: {generated_sql}")
56
+ ## expected output - SELECT COUNT(*) FROM students WHERE class = 10
57
+