---
license: llama2
datasets:
- bigcode/the-stack
- NumbersStation/NSText2SQL
language:
- en
---

# nova-nsql-Llama-2-70B

## Model Description

NSQL is a family of autoregressive open-source large foundation models (FMs) designed specifically for SQL generation tasks.

This repository introduces NSQL-Llama-2-70B, a new member of the NSQL family. It is based on Meta's original [Llama-2 70B model](https://huggingface.co/meta-llama/Llama-2-70b), further pre-trained on a dataset of general SQL queries, and then fine-tuned on a dataset of text-to-SQL pairs.

Use of this model is governed by Meta's Llama 2 Community License Agreement. Please review and accept the license before downloading the model weights and tokenizer.

### Basic Information

- **Blog Post**: [Link](TBA)
- **HF Hosting**: [Chat with me!](TBA)

## Training Data

The general SQL queries are the SQL subset of [The Stack](https://huggingface.co/datasets/bigcode/the-stack), containing 1M training samples. The labeled text-to-SQL pairs come from the [NSText2SQL](https://huggingface.co/datasets/NumbersStation/NSText2SQL) dataset.

## Evaluation Data

We evaluate our models on three text-to-SQL benchmarks: Spider, Bird, and text2sql.

## Training Procedure

NSQL was trained with a cross-entropy loss to maximize the likelihood of sequential inputs. During fine-tuning on text-to-SQL pairs, the loss is computed only over the SQL portion of each pair. The model was trained on SambaNova's in-house Reconfigurable Dataflow Unit (RDU), using data and model parallelism. We pre-trained for 2 epochs and fine-tuned for 10 epochs.
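
Restricting the fine-tuning loss to the SQL portion of a pair is commonly implemented by masking the prompt positions in the labels. The sketch below is a minimal illustration of that idea, not SambaNova's training code: it assumes the PyTorch / Hugging Face convention of marking ignored positions with `-100` (the default `ignore_index` of `torch.nn.CrossEntropyLoss`), and the helper names `build_labels` and `masked_nll` as well as the token ids are hypothetical.

```python
from typing import List, Tuple

# Label value skipped by the loss; -100 is the default ignore_index of
# torch.nn.CrossEntropyLoss and the usual masking convention in Hugging Face models.
IGNORE_INDEX = -100


def build_labels(prompt_ids: List[int], sql_ids: List[int]) -> Tuple[List[int], List[int]]:
    """Concatenate prompt and SQL token ids and mask the prompt positions,
    so that only the SQL tokens contribute to the fine-tuning loss."""
    input_ids = prompt_ids + sql_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(sql_ids)
    return input_ids, labels


def masked_nll(token_logprobs: List[float], labels: List[int]) -> float:
    """Mean negative log-likelihood over positions whose label is not ignored.

    token_logprobs[i] is assumed to be the model's log-probability of labels[i],
    i.e. already aligned for next-token prediction.
    """
    kept = [-lp for lp, lab in zip(token_logprobs, labels) if lab != IGNORE_INDEX]
    if not kept:
        raise ValueError("no unmasked positions")
    return sum(kept) / len(kept)


# Hypothetical token ids: three prompt tokens followed by two SQL tokens.
input_ids, labels = build_labels([101, 102, 103], [7, 8])
print(labels)  # [-100, -100, -100, 7, 8]
print(masked_nll([-5.0, -5.0, -5.0, -1.0, -3.0], labels))  # 2.0
```

For continued pre-training on plain SQL, no positions are masked: the labels are simply the input ids, so every token contributes to the loss.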

### Hyperparameters

**Continuous pre-training on Stack-SQL dataset**

- Hardware: SambaNova Reconfigurable Dataflow Unit (RDU)
- Optimizer: AdamW
- Epochs: 2
- Global Batch size: 256
- Batch tokens: 256 * 4096 = 1,048,576 tokens
- Learning Rate: 1e-5
- Learning Rate Scheduler: Fixed
- Warmup Steps: 0
- Weight decay: 0.1

**Fine-tuning on NSText2SQL dataset**

- Hardware: SambaNova Reconfigurable Dataflow Unit (RDU)
- Optimizer: AdamW
- Epochs: 10
- Global Batch size: 64
- Batch tokens: 64 * 4096 = 262,144 tokens
- Learning Rate: 1e-5
- Learning Rate Scheduler: Cosine Schedule with Warmup
- Warmup Steps: 0
- End Learning Ratio: 0.1
- Weight decay: 0.1
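
The fine-tuning learning-rate schedule can be made concrete with a short function. This is a sketch under assumptions, not the exact SambaNova scheduler: it uses the common cosine-decay form lr(t) = lr_min + ½(lr_peak − lr_min)(1 + cos(π·t/T)) and reads "End Learning Ratio: 0.1" as lr_min = 0.1 × lr_peak (so 1e-5 decays to 1e-6). With 0 warmup steps the warmup branch is never taken. The function name `cosine_lr` and the step counts are hypothetical.

```python
import math


def cosine_lr(step: int, total_steps: int, peak_lr: float = 1e-5,
              end_ratio: float = 0.1, warmup_steps: int = 0) -> float:
    """Learning rate at `step` for linear warmup followed by cosine decay.

    Assumption: after warmup the rate decays from peak_lr to
    end_ratio * peak_lr over the remaining steps.
    """
    if warmup_steps > 0 and step < warmup_steps:
        # Linear warmup; unused with the card's setting of 0 warmup steps.
        return peak_lr * (step + 1) / warmup_steps
    min_lr = end_ratio * peak_lr
    progress = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1.0)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Hypothetical run of 1,000 optimizer steps.
for s in (0, 500, 1000):
    print(s, cosine_lr(s, total_steps=1000))
```

Under this reading the rate starts at 1e-5, passes roughly 5.5e-6 halfway through, and ends at 1e-6. The pre-training run's "Fixed" scheduler would instead return `peak_lr` at every step.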

## Intended Use and Limitations

The model was designed for text-to-SQL generation from a given table schema and a natural-language question. It works best with the prompt format defined below and when generating `SELECT` queries.

## How to Use

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("sambanovasystems/nova-nsql-Llama-2-70B")
# A 70B model in bfloat16 needs roughly 140 GB of memory; device_map="auto"
# (requires the `accelerate` package) spreads the weights over available devices.
model = AutoModelForCausalLM.from_pretrained(
    "sambanovasystems/nova-nsql-Llama-2-70B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Prompt: table schemas, an instruction comment, the question as a SQL comment,
# and a trailing "SELECT" for the model to continue.
text = """CREATE TABLE stadium (
stadium_id number,
location text,
name text,
capacity number,
highest number,
lowest number,
average number
)

CREATE TABLE singer (
singer_id number,
name text,
country text,
song_name text,
song_release_year text,
age number,
is_male others
)

CREATE TABLE concert (
concert_id number,
concert_name text,
theme text,
stadium_id text,
year text
)

CREATE TABLE singer_in_concert (
concert_id number,
singer_id text
)


-- Using valid SQLite, answer the following questions for the tables provided above.

-- What is the average, minimum, and maximum age of all singers from France?
SELECT"""

input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)

generated_ids = model.generate(input_ids, max_length=500)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
```
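
The prompt in the example above follows a fixed layout, which can be assembled programmatically. The helper below is a hypothetical sketch, not part of any NSQL or `transformers` API: it joins `CREATE TABLE` statements with one blank line, adds two blank lines before the instruction comment, one blank line before the question, and ends with `SELECT`, mirroring the example's text.

```python
from typing import List

# Instruction line copied from the example prompt.
INSTRUCTION = "-- Using valid SQLite, answer the following questions for the tables provided above."


def build_nsql_prompt(create_table_statements: List[str], question: str) -> str:
    """Assemble a prompt in the same layout as the example (hypothetical helper).

    Layout: schemas separated by one blank line, two blank lines, the
    instruction comment, a blank line, the question as a SQL comment, and a
    trailing "SELECT" that the model continues.
    """
    schema = "\n\n".join(stmt.strip() for stmt in create_table_statements)
    return f"{schema}\n\n\n{INSTRUCTION}\n\n-- {question}\nSELECT"


prompt = build_nsql_prompt(
    ["CREATE TABLE singer (\nsinger_id number,\nage number\n)"],
    "What is the average age of all singers?",
)
print(prompt)
```

The returned string can be passed to `tokenizer(...)` in place of the hand-written `text`. Because `generate` returns the prompt tokens followed by the new tokens for a decoder-only model, decoding `generated_ids[0][input_ids.shape[1]:]` yields only the continuation after `SELECT`.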