Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from example_strings import example1, example2, example3 | |
# tokenizer6B = AutoTokenizer.from_pretrained(f"NumbersStation/nsql-6B") | |
# model6B = AutoModelForCausalLM.from_pretrained(f"NumbersStation/nsql-6B") | |
# tokenizer2B = AutoTokenizer.from_pretrained(f"NumbersStation/nsql-2B") | |
# model2B = AutoModelForCausalLM.from_pretrained(f"NumbersStation/nsql-2B") | |
# tokenizer350M = AutoTokenizer.from_pretrained(f"NumbersStation/nsql-2B") | |
# model350M = AutoModelForCausalLM.from_pretrained(f"NumbersStation/nsql-2B") | |
# Holds at most one (tokenizer, model) pair, keyed by model name, so repeated
# requests for the same model do not re-read the checkpoint every time.
_MODEL_CACHE = {}


def load_model(model_name: str):
    """Return ``(tokenizer, model)`` for the ``NumbersStation/<model_name>`` repo.

    The most recently loaded pair is reused when the same ``model_name`` is
    requested again. Requesting a different model first empties the cache,
    so this cache never keeps two checkpoints alive at the same time.

    Args:
        model_name: Repository name under the ``NumbersStation`` namespace,
            e.g. ``"nsql-350M"``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    cached = _MODEL_CACHE.get(model_name)
    if cached is not None:
        return cached

    # Drop the previous model before loading the next one so that two large
    # checkpoints are not held by this cache simultaneously.
    _MODEL_CACHE.clear()

    repo_id = f"NumbersStation/{model_name}"
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id)
    _MODEL_CACHE[model_name] = (tokenizer, model)
    return tokenizer, model
def infer(input_text, model_choice, max_new_tokens: int = 500):
    """Run the chosen NSQL model on ``input_text`` and return the decoded text.

    The returned string is the full decoded sequence: the prompt followed by
    whatever the model generated after it.

    Args:
        input_text: Prompt text (table schemas followed by a natural-language
            question).
        model_choice: Model name passed to ``load_model``, e.g. ``"nsql-2B"``.
        max_new_tokens: Upper bound on the number of generated tokens,
            counted separately from the prompt.

    Returns:
        The decoded output with special tokens removed, or ``""`` when the
        prompt is empty or whitespace-only.
    """
    # An empty prompt tokenizes to an empty tensor, which generate() cannot
    # continue from.
    if not (input_text or "").strip():
        return ""

    tokenizer, model = load_model(model_choice)
    encoded = tokenizer(input_text, return_tensors="pt")
    # max_new_tokens (rather than max_length) keeps the generation budget
    # independent of prompt length, so long schemas still leave room for SQL.
    generated_ids = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_new_tokens=max_new_tokens,
    )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# Markdown text rendered above the interface (passed as gr.Interface's
# description). Blank lines around the list keep the following paragraph
# from being folded into the last bullet.
description = """The NSQL model family was published by [Numbers Station](https://www.numbersstation.ai/) and is available in three flavors:

- [nsql-6B](https://huggingface.co/NumbersStation/nsql-6B)
- [nsql-2B](https://huggingface.co/NumbersStation/nsql-2B)
- [nsql-350M](https://huggingface.co/NumbersStation/nsql-350M)

This demo lets you choose which one you want to use and provides the three examples you can also find in their model cards.
In general you should first provide the table schemas of the tables you have questions about and then prompt it with a natural language question.
The model will then generate a SQL query that you can run against your database.
"""
# Gradio UI: a multi-line prompt box plus a model selector, wired to infer().
iface = gr.Interface(
    title="Text to SQL with NSQL",
    description=description,
    fn=infer,
    inputs=[
        # Multi-line so CREATE TABLE schemas can be pasted and edited; a
        # single-line box submits on Enter.
        gr.Textbox(lines=10, label="Table schemas and question"),
        gr.Dropdown(
            ["nsql-6B", "nsql-2B", "nsql-350M"],
            value="nsql-6B",
            label="Model",
        ),
    ],
    outputs=gr.Textbox(label="Model output"),
    examples=[
        [example1, "nsql-350M"],
        [example2, "nsql-2B"],
        [example3, "nsql-350M"],
    ],
)

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()