import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.responses import JSONResponse

# Initialize FastAPI app
app = FastAPI()

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("chatdb/natural-sql-7b")
model = AutoModelForCausalLM.from_pretrained(
    "chatdb/natural-sql-7b",
    device_map="cpu",
    torch_dtype=torch.float32,
)
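# Note: a 7B-parameter model in float32 needs roughly 28 GB of RAM on CPU
# (7e9 params x 4 bytes); loading in torch.float16 or bfloat16 roughly halves
# that, if the CPU build supports those dtypes.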

# Database schema embedded in the prompt so the model can ground its queries.
schema = """
CREATE TABLE users (
    id SERIAL PRIMARY KEY,
    manager_id INTEGER,
    first_name VARCHAR(100) NOT NULL,
    last_name VARCHAR(100) NOT NULL,
    designation VARCHAR(100),
    email VARCHAR(100) UNIQUE NOT NULL,
    phone VARCHAR(15) UNIQUE NOT NULL,
    password TEXT NOT NULL,
    role VARCHAR(50) NOT NULL, -- employee, manager, hr
    country VARCHAR(50) NOT NULL, -- pakistan, uae, uk
    fcm_token VARCHAR(255),
    image VARCHAR(255) DEFAULT '',
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE leaves_balances (
    id SERIAL PRIMARY KEY,
    sick_available FLOAT NOT NULL,
    casual_available FLOAT NOT NULL,
    wfh_available FLOAT NOT NULL,
    sick_taken FLOAT NOT NULL,
    casual_taken FLOAT NOT NULL,
    wfh_taken FLOAT NOT NULL,
    user_id INTEGER UNIQUE REFERENCES users(id) ON DELETE CASCADE,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE leaves (
    id SERIAL PRIMARY KEY,
    user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
    manager_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
    username VARCHAR(100) NOT NULL,
    type VARCHAR(50) NOT NULL, -- sick, casual, wfh
    from_date TIMESTAMP NOT NULL,
    to_date TIMESTAMP NOT NULL,
    comments TEXT,
    status VARCHAR(50) DEFAULT 'pending', -- pending, approved, rejected
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE user_otps (
    id SERIAL PRIMARY KEY,
    user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
    otp INTEGER NOT NULL,
    otp_expiry TIMESTAMP NOT NULL,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
"""

# Define the request body model using Pydantic
class QuestionRequest(BaseModel):
    question: str

@app.get("/")
async def root():
    return {"message": "Hello, World!"}

@app.post("/generate-sql")
async def generate_sql(request: QuestionRequest):
    """
    Endpoint to generate a SQL query based on a given question.
    The schema is defined within the code (in the `schema` variable).
    """
    question = request.question
    if not question:
        raise HTTPException(status_code=400, detail="No question provided")

    # Build the prompt left-aligned: a triple-quoted string keeps its source
    # indentation, and stray leading spaces would leak into the model input.
    prompt = f"""### Task

Generate a SQL query to answer the following question: `{question}`

### PostgreSQL Database Schema
The query will run on a database with the following schema:
{schema}

### Answer
Here is the SQL query that answers the question: `{question}`
```sql
"""

    # Generate the SQL query deterministically (greedy decoding, no sampling)
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=100001,  # this model's end-of-text token id
        pad_token_id=100001,  # reuse EOS as pad to avoid a padding warning
        max_new_tokens=400,
        do_sample=False,  # greedy decoding for reproducible queries
        num_beams=1,
    )

    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    # Keep only the text between the opening ```sql fence and any closing
    # fence the model may emit, so the response contains pure SQL.
    sql_query = outputs[0].split("```sql")[-1].split("```")[0].strip()

    return JSONResponse(content={'sql_query': sql_query})
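
# --- Local run / smoke test ---------------------------------------------------
# A minimal sketch for serving this app locally; it assumes `uvicorn` is
# installed (pip install uvicorn). Example request once the server is running:
#
#   curl -X POST http://127.0.0.1:8000/generate-sql \
#        -H "Content-Type: application/json" \
#        -d '{"question": "How many leaves are still pending approval?"}'
#
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)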