import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the model and tokenizer once, at import time, so every request reuses them.
# NOTE(review): the prompt template below is the Llama-3 chat format
# (<|begin_of_text|>, <|start_header_id|>, <|eot_id|>) and the UI is titled
# "SQLCoder-8b", but the checkpoint loaded here is defog/sqlcoder-7b-2, whose
# model card documents a different "### Task ... ### Answer" prompt. Confirm
# which checkpoint is intended and align the template/title with it.
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# float16 weights roughly halve memory use; device_map="auto" lets transformers
# place the weights on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)

# Budget for tokens generated *after* the prompt. (`max_length` would also count
# the prompt, and the DDL alone is far longer than a small budget.)
MAX_NEW_TOKENS = 300


def _extract_sql(completion):
    """Return the SQL portion of a raw model completion.

    The prompt ends with an opening ```sql fence, so the answer runs until the
    closing fence (if the model emits one); anything after it is discarded.
    """
    return completion.split("```", 1)[0].strip()


def generate_sql(user_question, instructions, create_table_statements):
    """Generate a SQL query that answers *user_question* against the given schema.

    Args:
        user_question: natural-language question to translate into SQL.
        instructions: extra guidance inserted verbatim into the prompt (may be empty).
        create_table_statements: DDL describing the tables the query may use.

    Returns:
        The generated SQL text only: the echoed prompt is removed and the
        closing code fence (and anything after it) is stripped.
    """
    prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{user_question}`
{instructions}

DDL statements:
{create_table_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{user_question}`:
```sql
"""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    # For a decoder-only model, generate() returns prompt + continuation;
    # decode only the newly generated tokens so the prompt is not echoed back.
    prompt_length = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return _extract_sql(completion)


# Default values pre-filled into the UI text boxes.
question = "What are our top 3 products by revenue in the New York region?"
instructions = """- if the question cannot be answered given the database schema, return "I do not know"
- recall that the current date in YYYY-MM-DD format is 2024-09-15
"""

# Example schema. Newlines matter here: each SQL `--` comment runs to end of line.
schema = """CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER  -- Current quantity in stock
);

CREATE TABLE customers (
  customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
  name VARCHAR(50), -- Name of the customer
  address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER,  -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id
"""

demo = gr.Interface(
    fn=generate_sql,
    title="SQLCoder-8b",
    description="Defog's SQLCoder-8B is a state of the art-models for generating SQL queries from natural language. ",
    inputs=[
        gr.Textbox(label="User Question", placeholder="Enter your question here...", value=question),
        gr.Textbox(label="Instructions (optional)", placeholder="Enter any additional instructions here...", value=instructions),
        gr.Textbox(label="Create Table Statements", placeholder="Enter DDL statements here...", value=schema),
    ],
    outputs="text",
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the local server.
    demo.launch(share=True)