---
license: apache-2.0
datasets:
- laion/OIG
language:
- en
pipeline_tag: text2text-generation
tags:
- nl2sql
widget:
- text: 'Given the following schema:\ntrack (Track_ID, Name, Location, Seating, Year_Opened)\nrace (Race_ID, Name, Class, Date, Track_ID)\nWrite a SQL query to count the number of tracks.'
  example_title: 'count'
- text: 'Given the following schema:\nmountain (Mountain_ID, Name, Height, Prominence, Range, Country)\nclimber (Climber_ID, Name, Country, Time, Points, Mountain_ID)\nWrite a SQL query to list the countries that have more than one mountain.'
  example_title: 'having'
---

# How to Use

```python
import torch
from transformers import T5ForConditionalGeneration, AutoTokenizer

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("LarkAI/codet5p-770m_nl2sql_oig")
model = T5ForConditionalGeneration.from_pretrained("LarkAI/codet5p-770m_nl2sql_oig").to(device)

# Prompt format: the schema (one table per line) followed by the request.
text = "Given the following schema:\ntrack (Track_ID, Name, Location, Seating, Year_Opened)\nrace (Race_ID, Name, Class, Date, Track_ID)\nWrite a SQL query to count the number of tracks."
inputs = tokenizer.encode(text, return_tensors="pt").to(device)
output_ids = model.generate(inputs, max_length=512)
response_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(response_text)
# SELECT COUNT( * ) FROM track
```
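
The prompt in the example above follows a fixed layout: a header line, one `table (col, col, ...)` line per table, and a final `Write a SQL query to ...` line. A minimal sketch of a helper that assembles this layout from a schema dictionary is shown below. The name `build_prompt` and its signature are hypothetical and not part of this model or of `transformers`; the layout itself is taken from the widget examples on this card.

```python
from typing import Dict, List


def build_prompt(schema: Dict[str, List[str]], request: str) -> str:
    """Assemble a prompt in the layout used by the examples on this card.

    `schema` maps table names to column names; `request` is the text that
    follows "Write a SQL query to ". This helper is illustrative only.
    """
    lines = ["Given the following schema:"]
    for table, columns in schema.items():
        # One line per table: name, space, parenthesised comma-separated columns.
        lines.append(f"{table} ({', '.join(columns)})")
    lines.append(f"Write a SQL query to {request}")
    return "\n".join(lines)


schema = {
    "track": ["Track_ID", "Name", "Location", "Seating", "Year_Opened"],
    "race": ["Race_ID", "Name", "Class", "Date", "Track_ID"],
}
prompt = build_prompt(schema, "count the number of tracks.")
print(prompt)
```

The resulting `prompt` string can be passed to `tokenizer.encode` in place of the hard-coded `text`.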

# How to Train

Dataset:
- https://huggingface.co/datasets/laion/OIG#unified_sqlv1jsonl-17000
- https://huggingface.co/datasets/laion/OIG#unified_sqlv2jsonl24000

Each record is a JSON object of the following form:

```json
{
  "text":"<human>: Given the following schema:\nlocation (restaurant_id, house_number, street_name, city_name)\nrestaurant (id, name, food_type, city_name, rating)\ngeographic (city_name, county, region)\nWrite a SQL query to give me some good arabic -s on buchanan in san francisco ?\n<bot>: SELECT location.house_number , restaurant.name FROM location , restaurant WHERE location.city_name = \"san francisco\" AND location.street_name = \"buchanan\" AND restaurant.food_type = \"arabic\" AND restaurant.id = location.restaurant_id AND restaurant.rating > 2.5 ;",
  "metadata":{
    "source":"unified_sqlv1"
  }
}
```
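
This card does not describe how the records were turned into training pairs. One plausible preprocessing step, sketched here as an assumption rather than the authors' actual pipeline, is to split the `text` field on the `<human>: ` and `<bot>: ` markers so that the prompt becomes the encoder input and the SQL becomes the target. The function name `split_record` and the shortened sample record are made up for illustration.

```python
import json
from typing import Tuple

HUMAN_TAG = "<human>: "
BOT_TAG = "\n<bot>: "


def split_record(line: str) -> Tuple[str, str]:
    """Split one JSONL line into (prompt, sql).

    Assumes the single-turn layout shown above: the text starts with
    "<human>: " and contains a "\n<bot>: " marker before the SQL answer.
    """
    text = json.loads(line)["text"]
    if not text.startswith(HUMAN_TAG) or BOT_TAG not in text:
        raise ValueError("record does not match the expected <human>/<bot> layout")
    # Drop the leading tag, then split once at the first bot marker.
    prompt, sql = text[len(HUMAN_TAG):].split(BOT_TAG, 1)
    return prompt, sql.strip()


# Shortened, made-up record in the same shape as the dataset sample.
example = json.dumps({
    "text": "<human>: Given the following schema:\ntrack (Track_ID, Name)\n"
            "Write a SQL query to count the number of tracks.\n"
            "<bot>: SELECT COUNT( * ) FROM track",
    "metadata": {"source": "unified_sqlv1"},
})
source, target = split_record(example)
print(source)
print(target)
```

Here `source` has the same layout as the prompt in "How to Use", and `target` is the SQL string the model would be trained to generate.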