sushruthsam commited on
Commit
bf9ae34
·
verified ·
1 Parent(s): 7cdb8e4

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +18 -15
main.py CHANGED
@@ -1,28 +1,31 @@
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from ctransformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
- # Model loading
6
  llm = AutoModelForCausalLM.from_pretrained("sqlcoder-7b.Q4_K_S.gguf")
7
- tokenizer = AutoTokenizer.from_pretrained("sqlcoder-7b.Q4_K_S.gguf")
8
 
9
- # Pydantic object for request validation
10
  class Validation(BaseModel):
11
- prompt: str
12
 
13
- # Initialize FastAPI app
14
  app = FastAPI()
15
 
16
- # Endpoint for SQL query generation
17
  @app.post("/generate_sql")
18
  async def generate_sql(item: Validation):
19
- # Tokenize the input prompt
20
- input_ids = tokenizer.encode(item.prompt, return_tensors="pt")
 
 
21
 
22
- # Use the tokenized prompt for model completion
23
- completion = llm.generate(input_ids)
 
24
 
25
- # Decode the generated SQL query
26
- generated_sql = tokenizer.decode(completion[0], skip_special_tokens=True)
27
-
28
- return {"generated_sql": generated_sql}
 
 
 
 
 
1
+ from ctransformers import AutoModelForCausalLM
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
 
4
 
5
+ # Model loading with the new model name
6
  llm = AutoModelForCausalLM.from_pretrained("sqlcoder-7b.Q4_K_S.gguf")
 
7
 
 
8
  class Validation(BaseModel):
9
+ prompt: str # Assuming this includes both user_question and table_metadata_string
10
 
 
11
  app = FastAPI()
12
 
 
13
  @app.post("/generate_sql")
14
  async def generate_sql(item: Validation):
15
+ # Updated system prompt
16
+ system_prompt = """### Task
17
+ Generate a SQL query to answer the following question:
18
+ `{question}`
19
 
20
+ ### Database Schema
21
+ The query will run on a database with the following schema:
22
+ {schema}
23
 
24
+ ### Answer
25
+ Given the database schema, here is the SQL query that answers `{question}`:
26
+ ```sql
27
+ """
28
+ # Format the actual prompt using item.prompt
29
+ prompt = system_prompt.format(user_question="Your question here", table_metadata_string="Your schema here")
30
+ completion = llm(prompt)
31
+ return completion