license: apache-2.0
datasets:
  - laion/OIG
language:
  - en
pipeline_tag: text2text-generation
tags:
  - nl2sql
widget:
  - text: >-
      Given the following schema:\ntrack (Track_ID, Name, Location, Seating,
      Year_Opened)\nrace (Race_ID, Name, Class, Date, Track_ID)\nWrite a SQL
      query to count the number of tracks.
    example_title: count
  - text: >-
      Given the following schema:\nmountain (Mountain_ID, Name, Height,
      Prominence, Range, Country)\nclimber (Climber_ID, Name, Country, Time,
      Points, Mountain_ID)\nWrite a SQL query to list the countries that have
      more than one mountain.
    example_title: having

How to Use

import torch
from transformers import T5ForConditionalGeneration, AutoTokenizer

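# Use the first GPU when available, otherwise fall back to CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")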

tokenizer = AutoTokenizer.from_pretrained("LarkAI/codet5p-770m_nl2sql_oig")
model = T5ForConditionalGeneration.from_pretrained("LarkAI/codet5p-770m_nl2sql_oig").to(device)

text = "Given the following schema:\ntrack (Track_ID, Name, Location, Seating, Year_Opened)\nrace (Race_ID, Name, Class, Date, Track_ID)\nWrite a SQL query to count the number of tracks."
inputs = tokenizer.encode(text, return_tensors="pt").to(device)
output_ids = model.generate(inputs, max_length=512)
response_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
# SELECT COUNT( * ) FROM track
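The prompt in the example above follows a fixed layout: a header line, one line per table with its columns in parentheses, and a closing request. A minimal sketch of a helper that assembles that layout is shown here; `build_prompt` and the `schema` dictionary are hypothetical names introduced for illustration and are not part of the model's API.

```python
def build_prompt(schema, request):
    """Assemble a prompt in the layout used by the example above.

    `schema` maps table names to column lists; `request` is the text
    that follows "Write a SQL query to". Both names are illustrative.
    """
    lines = ["Given the following schema:"]
    for table, columns in schema.items():
        # One line per table: name (col1, col2, ...)
        lines.append(f"{table} ({', '.join(columns)})")
    lines.append(f"Write a SQL query to {request}")
    return "\n".join(lines)


schema = {
    "track": ["Track_ID", "Name", "Location", "Seating", "Year_Opened"],
    "race": ["Race_ID", "Name", "Class", "Date", "Track_ID"],
}
prompt = build_prompt(schema, "count the number of tracks.")
print(prompt)
```

The resulting string can be passed to the tokenizer in place of the hand-written `text` value.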

How to Train

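Dataset: each training record is a JSON object whose "text" field holds the prompt after <human>: and the target SQL after <bot>:, for example: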

{
    "text":"<human>: Given the following schema:\nlocation (restaurant_id, house_number, street_name, city_name)\nrestaurant (id, name, food_type, city_name, rating)\ngeographic (city_name, county, region)\nWrite a SQL query to give me some good arabic -s on buchanan in san francisco ?\n<bot>: SELECT location.house_number , restaurant.name FROM location , restaurant WHERE location.city_name = \"san francisco\" AND location.street_name = \"buchanan\" AND restaurant.food_type = \"arabic\" AND restaurant.id = location.restaurant_id AND restaurant.rating > 2.5 ;",
    "metadata":{
        "source":"unified_sqlv1"
    }
}
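The record above stores the prompt and its SQL target in a single "text" field, separated by the <human>: and <bot>: markers. One way to turn such a record into a (source, target) pair for sequence-to-sequence training might look like the sketch below. The `split_record` helper and the shortened sample record are assumptions made for illustration; the actual preprocessing used to train this model is not described in the card.

```python
HUMAN_TAG = "<human>: "
BOT_TAG = "\n<bot>: "


def split_record(record):
    """Split an OIG-style record into (prompt, target_sql).

    Hypothetical helper: assumes exactly the layout shown above, with
    the prompt after "<human>: " and the SQL after "\n<bot>: ".
    """
    text = record["text"]
    if not text.startswith(HUMAN_TAG) or BOT_TAG not in text:
        raise ValueError("record does not follow the <human>/<bot> layout")
    # Split on the first bot marker only, in case the SQL contains the tag text
    prompt, target = text[len(HUMAN_TAG):].split(BOT_TAG, 1)
    return prompt.strip(), target.strip()


# Shortened, illustrative record in the same shape as the example above
record = {
    "text": (
        "<human>: Given the following schema:\n"
        "track (Track_ID, Name)\n"
        "Write a SQL query to count the number of tracks.\n"
        "<bot>: SELECT COUNT( * ) FROM track"
    ),
    "metadata": {"source": "unified_sqlv1"},
}
prompt, target = split_record(record)
print(prompt)
print(target)
```

The prompt would then be tokenized as the encoder input and the SQL string as the decoder labels.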