JeetSuthar commited on
Commit
10ee73c
·
verified ·
1 Parent(s): 9a3b033

reverted back to basic code

Browse files
Files changed (1) hide show
  1. app.py +21 -15
app.py CHANGED
@@ -1,33 +1,39 @@
1
  from fastapi import FastAPI, HTTPException
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import sqlite3
4
  import torch
5
 
6
  app = FastAPI()
7
 
8
- # Load Model & Tokenizer
9
  MODEL_NAME = "deepseek-ai/deepseek-coder-1.3b-instruct"
10
- device = "cuda" if torch.cuda.is_available() else "cpu"
11
-
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to(device)
14
 
15
- def generate_sql_query(user_input):
16
- """ Convert natural language input into an SQL query """
17
- inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True).to(device)
18
- outputs = model.generate(**inputs, max_length=600, do_sample=False, num_beams=2)
19
- return tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
20
 
21
- @app.post("/chat")
22
- def chat(request: dict):
23
- user_input = request.get("message", "")
24
- if not user_input:
25
- raise HTTPException(status_code=400, detail="Message cannot be empty")
 
 
 
 
 
 
26
 
 
 
 
 
 
27
  sql_query = generate_sql_query(user_input)
28
  print(f"Generated SQL Query: {sql_query}")
 
29
  return {"response": sql_query}
30
 
31
  @app.get("/")
32
  def home():
33
- return {"message": "DeepSeek SQL Query API is running"}
 
1
  from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import sqlite3
5
  import torch
6
 
7
  app = FastAPI()
8
 
9
+ # Load the DeepSeek model and tokenizer
10
  MODEL_NAME = "deepseek-ai/deepseek-coder-1.3b-instruct"
 
 
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to("cpu") # Use "cuda" if available
13
 
 
 
 
 
 
14
 
15
+ class ChatRequest(BaseModel):
16
+ message: str
17
+
18
+ def generate_sql_query(user_input: str) -> str:
19
+ """
20
+ Generate an SQL query from a natural language query using the DeepSeek model.
21
+ """
22
+ inputs = tokenizer(user_input, return_tensors="pt")
23
+ outputs = model.generate(**inputs, max_length=600)
24
+ sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
25
+ return sql_query
26
 
27
+
28
+ @app.post("/chat")
29
+ def chat(request: ChatRequest):
30
+ user_input = request.message
31
+
32
  sql_query = generate_sql_query(user_input)
33
  print(f"Generated SQL Query: {sql_query}")
34
+
35
  return {"response": sql_query}
36
 
37
  @app.get("/")
38
  def home():
39
+ return {"message": "DeepSeek SQL Query API is running"}