Spaces:
Sleeping
Sleeping
File size: 3,885 Bytes
90d8ee3 df8873e 687f4c4 5a153bd 90d8ee3 df8873e fa87edc 5a153bd df8873e 08343a9 90d8ee3 b2d9f7b dd3a93a df8873e 3a05353 08343a9 cf16b57 08343a9 90d8ee3 1c4b940 b2d9f7b dd3a93a b2d9f7b dd3a93a b2d9f7b dd3a93a 1c4b940 b030295 5e1e57a 1c4b940 b030295 1c4b940 df8873e 90d8ee3 c802c0e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import sqlparse
# from modelscope import snapshot_download
# Load the tokenizer and model once at import time; generate_sql reuses them.
model_name = "defog/llama-3-sqlcoder-8b"  # newer model, chosen for better performance
# model_name = snapshot_download("stevie/llama-3-sqlcoder-8b")  # ModelScope alternative
# Options forwarded unchanged to from_pretrained for the model.
_model_load_kwargs = {
    "trust_remote_code": True,
    "torch_dtype": torch.float16,
    "device_map": "auto",
    "use_cache": True,
}
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, **_model_load_kwargs)
def generate_sql(user_question, instructions, create_table_statements):
    """Generate a SQL query that answers ``user_question`` for the given schema.

    Builds the SQLCoder-8B prompt (Llama-3 chat markup ending in an open
    ```sql fence), runs greedy decoding with the module-level ``model`` and
    ``tokenizer``, and returns the completion pretty-printed by ``sqlparse``.

    Args:
        user_question: Natural-language question to translate into SQL.
        instructions: Free-form guidance inserted verbatim into the prompt;
            may be an empty string.
        create_table_statements: DDL describing the tables the query may use.

    Returns:
        The generated SQL (or whatever text the model produced instead),
        reindented with ``sqlparse.format(..., reindent=True)``.
    """
    prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Generate a SQL query to answer this question: `{user_question}`
{instructions}
DDL statements:
{create_table_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
The following SQL query best answers the question `{user_question}`:
```sql
"""
    # The prompt already begins with <|begin_of_text|>, so stop the tokenizer
    # from prepending a second BOS token. Move the inputs to wherever
    # device_map placed the model instead of guessing "cuda".
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )
    # For decoder-only models, generate() returns prompt + completion. Decode
    # only the new tokens so the prompt text never ends up in the answer.
    prompt_length = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(generated_ids[0][prompt_length:], skip_special_tokens=True)
    # The prompt opened a ```sql fence. Anything from the closing fence
    # onward is not part of the query.
    sql = completion.split("```")[0].strip()
    if torch.cuda.is_available():
        # Release cached GPU memory between requests. This is guarded because
        # torch.cuda.synchronize() raises on CPU-only machines.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    return sqlparse.format(sql, reindent=True)
# Sample inputs used to pre-fill the demo's text boxes. These are plain
# strings: none of them interpolates anything, so no f-prefix is needed.
question = "What are our top 3 products by revenue in the New York region?"
# Extra guidance inserted into the prompt right after the question.
# NOTE(review): the "current date" is hard-coded; update it if date-relative
# questions need to resolve against the real date.
instructions = """- if the question cannot be answered given the database schema, return "I do not know"
- recall that the current date in YYYY-MM-DD format is 2024-09-15
"""
# Example schema: CREATE TABLE statements followed by join hints.
schema = """CREATE TABLE products (
product_id INTEGER PRIMARY KEY, -- Unique ID for each product
name VARCHAR(50), -- Name of the product
price DECIMAL(10,2), -- Price of each unit of the product
quantity INTEGER -- Current quantity in stock
);
CREATE TABLE customers (
customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
name VARCHAR(50), -- Name of the customer
address VARCHAR(100) -- Mailing address of the customer
);
CREATE TABLE salespeople (
salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
name VARCHAR(50), -- Name of the salesperson
region VARCHAR(50) -- Geographic sales region
);
CREATE TABLE sales (
sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
product_id INTEGER, -- ID of product sold
customer_id INTEGER, -- ID of customer who made purchase
salesperson_id INTEGER, -- ID of salesperson who made the sale
sale_date DATE, -- Date the sale occurred
quantity INTEGER -- Quantity of product sold
);
CREATE TABLE product_suppliers (
supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
product_id INTEGER, -- Product ID supplied
supply_price DECIMAL(10,2) -- Unit price charged by supplier
);
-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id
"""
# Gradio UI: three text boxes (pre-filled with the module-level `question`,
# `instructions` and `schema` samples) feed generate_sql, whose SQL is shown
# as plain text.
demo = gr.Interface(
    fn=generate_sql,
    title="SQLCoder-8b",
    description=(
        "Defog's SQLCoder-8B is a state-of-the-art model for generating "
        "SQL queries from natural language."
    ),
    inputs=[
        gr.Textbox(label="User Question", placeholder="Enter your question here...", value=question),
        gr.Textbox(label="Instructions (optional)", placeholder="Enter any additional instructions here...", value=instructions),
        gr.Textbox(label="Create Table Statements", placeholder="Enter DDL statements here...", value=schema),
    ],
    outputs="text",
)

if __name__ == "__main__":
    # share=True asks Gradio to also create a publicly reachable share link.
    demo.launch(share=True)