"""FastAPI service that turns natural-language questions into PostgreSQL
queries using the ``chatdb/natural-sql-7b`` causal language model."""

import threading

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "chatdb/natural-sql-7b"

# Hard-coded token id passed to ``generate`` as both eos and pad token.
# NOTE(review): confirm this matches tokenizer.eos_token_id if MODEL_NAME changes.
EOS_TOKEN_ID = 100001

# Upper bound on the number of tokens generated per request.
MAX_NEW_TOKENS = 400

# Initialize FastAPI app
app = FastAPI()

# Load the tokenizer and model once at import time; every request reuses them.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="cpu",
    torch_dtype=torch.float32,
)

# Generation runs in worker threads (see generate_sql). This lock keeps the
# shared tokenizer/model used by one request at a time, so concurrent requests
# don't compete for the same CPU cores.
_generation_lock = threading.Lock()

schema = """
CREATE TABLE users (
    id SERIAL PRIMARY KEY,
    manager_id INTEGER,
    first_name VARCHAR(100) NOT NULL,
    last_name VARCHAR(100) NOT NULL,
    designation VARCHAR(100),
    email VARCHAR(100) UNIQUE NOT NULL,
    phone VARCHAR(15) UNIQUE NOT NULL,
    password TEXT NOT NULL,
    role VARCHAR(50) NOT NULL, -- employee, manager, hr
    country VARCHAR(50) NOT NULL, -- pakistan, uae, uk
    fcm_token VARCHAR(255),
    image VARCHAR(255) DEFAULT '',
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE leaves_balances (
    id SERIAL PRIMARY KEY,
    sick_available FLOAT NOT NULL,
    casual_available FLOAT NOT NULL,
    wfh_available FLOAT NOT NULL,
    sick_taken FLOAT NOT NULL,
    casual_taken FLOAT NOT NULL,
    wfh_taken FLOAT NOT NULL,
    user_id INTEGER UNIQUE REFERENCES users(id) ON DELETE CASCADE,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE leaves (
    id SERIAL PRIMARY KEY,
    user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
    manager_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
    username VARCHAR(100) NOT NULL,
    type VARCHAR(50) NOT NULL, -- sick, casual, wfh
    from_date TIMESTAMP NOT NULL,
    to_date TIMESTAMP NOT NULL,
    comments TEXT,
    status VARCHAR(50) DEFAULT 'pending', -- pending, approved, rejected
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE user_otps (
    id SERIAL PRIMARY KEY,
    user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
    otp INTEGER NOT NULL,
    otp_expiry TIMESTAMP NOT NULL,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
"""


# Define the request body model using Pydantic
class QuestionRequest(BaseModel):
    """Request body for ``POST /generate-sql``."""

    # Natural-language question to translate into SQL.
    question: str


@app.get("/")
async def root():
    """Liveness endpoint returning a fixed greeting."""
    return {"message": "Hello, World!"}


def _build_prompt(question: str) -> str:
    """Return the model prompt asking for SQL that answers *question* against ``schema``."""
    return f"""
### Task
Generate a SQL query to answer the following question: `{question}`

### PostgreSQL Database Schema
The query will run on a database with the following schema:
{schema}

### Answer
Here is the SQL query that answers the question: `{question}`
```sql
"""


def _extract_sql(completion: str) -> str:
    """Return the text of *completion* before the first ``` fence, stripped.

    The prompt opens a ```sql fence, so the model's answer is expected to run
    until it closes that fence. Anything generated after the fence is dropped.
    """
    sql, _fence, _rest = completion.partition("```")
    return sql.strip()


def _generate_sql_blocking(question: str) -> str:
    """Run greedy generation for *question* and return the extracted SQL.

    This is blocking, CPU-bound work. Call it from a worker thread, never
    directly on the event loop.
    """
    prompt = _build_prompt(question)
    with _generation_lock:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        generated_ids = model.generate(
            **inputs,
            num_return_sequences=1,
            eos_token_id=EOS_TOKEN_ID,
            pad_token_id=EOS_TOKEN_ID,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            num_beams=1,
        )
        # For a causal LM, generate() returns the prompt ids followed by the
        # continuation. Decode only the new tokens so that backticks or "```sql"
        # inside the user's question cannot confuse parsing.
        prompt_length = inputs["input_ids"].shape[-1]
        completion = tokenizer.decode(
            generated_ids[0, prompt_length:], skip_special_tokens=True
        )
    return _extract_sql(completion)


@app.post('/generate-sql')
async def generate_sql(request: QuestionRequest):
    """Generate a SQL query that answers the question in *request*.

    The schema is defined within the code (in the ``schema`` variable).

    Returns:
        JSONResponse with body ``{"sql_query": <generated SQL>}``.

    Raises:
        HTTPException: 400 if the question is empty or whitespace-only.
    """
    question = request.question.strip()
    if not question:
        raise HTTPException(status_code=400, detail="No question provided")

    # model.generate blocks for a long time. Running it in the threadpool keeps
    # the event loop free to serve other requests (e.g. "/") in the meantime.
    sql_query = await run_in_threadpool(_generate_sql_blocking, question)
    return JSONResponse(content={'sql_query': sql_query})