# Spaces status: Runtime error
import asyncio

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Initialize FastAPI app
# A single module-level application instance, created once at import time.
app = FastAPI()
# Load the tokenizer and causal-LM weights once, at import time, so every
# request reuses the same objects instead of reloading them.
# NOTE(review): full float32 weights for a model named "7b" need roughly
# 7e9 params x 4 bytes ~= 28 GB of RAM; confirm the host has that much, since
# an out-of-memory failure during this load could explain a "Runtime error".
_MODEL_NAME = "chatdb/natural-sql-7b"

tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_NAME,
    torch_dtype=torch.float32,
    device_map="cpu",
)
# PostgreSQL DDL for the target database. generate_sql interpolates this text
# verbatim into the model prompt, so these four tables are the only schema
# the model is shown. The inline "--" comments list the expected values of
# enum-like columns (role, country, leave type, leave status); because they
# are part of the string, the model sees them too.
schema = """
CREATE TABLE users (
id SERIAL PRIMARY KEY,
manager_id INTEGER,
first_name VARCHAR(100) NOT NULL,
last_name VARCHAR(100) NOT NULL,
designation VARCHAR(100),
email VARCHAR(100) UNIQUE NOT NULL,
phone VARCHAR(15) UNIQUE NOT NULL,
password TEXT NOT NULL,
role VARCHAR(50) NOT NULL, -- employee, manager, hr
country VARCHAR(50) NOT NULL, -- pakistan, uae, uk
fcm_token VARCHAR(255),
image VARCHAR(255) DEFAULT '',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE leaves_balances (
id SERIAL PRIMARY KEY,
sick_available FLOAT NOT NULL,
casual_available FLOAT NOT NULL,
wfh_available FLOAT NOT NULL,
sick_taken FLOAT NOT NULL,
casual_taken FLOAT NOT NULL,
wfh_taken FLOAT NOT NULL,
user_id INTEGER UNIQUE REFERENCES users(id) ON DELETE CASCADE,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE leaves (
id SERIAL PRIMARY KEY,
user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
manager_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
username VARCHAR(100) NOT NULL,
type VARCHAR(50) NOT NULL, -- sick, casual, wfh
from_date TIMESTAMP NOT NULL,
to_date TIMESTAMP NOT NULL,
comments TEXT,
status VARCHAR(50) DEFAULT 'pending', -- pending, approved, rejected
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE user_otps (
id SERIAL PRIMARY KEY,
user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
otp INTEGER NOT NULL,
otp_expiry TIMESTAMP NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
"""
# Define the request body model using Pydantic.
# Body accepted by generate_sql: a JSON object with one string field.
class QuestionRequest(BaseModel):
    # Natural-language question to translate into SQL. This model does not
    # reject empty strings; generate_sql performs that check itself.
    question: str
# NOTE(review): no route decorator was present, so FastAPI never served this
# handler; "/" is assumed as the path for this greeting endpoint.
@app.get("/")
async def root():
    """Return a static greeting.

    This can double as a cheap liveness check, because it does not touch
    the model.

    Returns:
        dict: ``{"message": "Hello, World!"}``.
    """
    return {"message": "Hello, World!"}
# NOTE(review): no route decorator was present, so this endpoint was never
# registered; the path below is an assumption -- confirm it matches clients.
@app.post("/generate-sql")
async def generate_sql(request: QuestionRequest):
    """Generate a SQL query that answers the given natural-language question.

    The schema is defined within the code (in the module-level ``schema``
    variable). It is embedded with the question into the prompt, and the
    model is asked to continue an opened ```sql code fence.

    Args:
        request: Request body carrying the ``question`` text.

    Returns:
        JSONResponse: ``{"sql_query": <generated SQL>}``. The SQL is the
        model's continuation, cut at its closing code fence (if any) and
        stripped of surrounding whitespace.

    Raises:
        HTTPException: 400 if the question is empty or whitespace-only.
    """
    # Strip first so whitespace-only questions are rejected, and so stray
    # newlines cannot break the backtick quoting inside the prompt.
    question = request.question.strip()
    if not question:
        raise HTTPException(status_code=400, detail="No question provided")

    prompt = f"""
### Task
Generate a SQL query to answer the following question: `{question}`
### PostgreSQL Database Schema
The query will run on a database with the following schema:
{schema}
### Answer
Here is the SQL query that answers the question: `{question}`
```sql
"""

    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    # Number of prompt tokens, used to slice the prompt off the model output.
    prompt_length = inputs["input_ids"].shape[1]

    def _generate():
        # Greedy, single-sequence decoding with the original settings.
        # NOTE(review): eos/pad id 100001 is kept unchanged. If that id is not
        # this tokenizer's end-of-sequence token, generation only stops at
        # max_new_tokens; compare with tokenizer.eos_token_id.
        return model.generate(
            **inputs,
            num_return_sequences=1,
            eos_token_id=100001,
            pad_token_id=100001,
            max_new_tokens=400,
            do_sample=False,
            num_beams=1,
        )

    # model.generate is synchronous and CPU-bound, and up to 400 new tokens on
    # CPU can take a long time. Running it in a worker thread keeps the event
    # loop responsive to other requests meanwhile. Note that concurrent
    # requests now run generation in parallel threads on the shared model.
    loop = asyncio.get_running_loop()
    generated_ids = await loop.run_in_executor(None, _generate)

    # generate() returns prompt + continuation. Decode only the continuation,
    # so no prompt text (including the user's question) can leak into the
    # result.
    completion = tokenizer.decode(
        generated_ids[0][prompt_length:], skip_special_tokens=True
    )
    # The prompt already opened the ```sql fence. Keep everything up to the
    # model's closing fence and drop any commentary it emits afterwards.
    sql_query = completion.split("```", 1)[0].strip()
    return JSONResponse(content={"sql_query": sql_query})