File size: 1,309 Bytes
80d5bbb 960e156 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import logging
import sqlite3

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

# Application object; the POST route below is registered on it and uvicorn serves it.
app = FastAPI()

# Text-to-SQL checkpoint on the Hugging Face Hub, loaded once at import time
# as a decoder-only (causal) language model together with its tokenizer.
MODEL_NAME = "budecosystem/sql-millennials-13b"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
# NOTE(review): no dtype/device arguments are passed, so weights load with the
# library defaults — confirm the host has enough memory for this (13B) model.


def generate_sql(query, max_new_tokens=256):
    """Translate a natural-language question into SQL text using the loaded model.

    Args:
        query: Natural-language text, tokenized verbatim and used as the prompt.
            NOTE(review): the checkpoint may expect a specific prompt template
            (e.g. including the table schema) — confirm against the model card.
        max_new_tokens: Maximum number of tokens to generate after the prompt.
            Without an explicit limit, ``generate`` falls back to the model's
            generation config, whose library default length can cut the SQL
            statement short.

    Returns:
        The decoded completion (special tokens removed, surrounding whitespace
        stripped). Only the newly generated tokens are decoded.
    """
    logger = logging.getLogger(__name__)
    inputs = tokenizer(query, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # A causal LM's generate() output starts with the prompt tokens; decoding
    # them would prepend the user's question to the SQL and break execution.
    prompt_length = inputs["input_ids"].shape[-1]
    completion_ids = outputs[0][prompt_length:]
    sql_query = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    logger.debug("Generated SQL for %r: %s", query, sql_query)
    return sql_query


def execute_sql(sql_query, db_path="./ecommerce.db", read_only=False):
    """Execute a single SQL statement against a SQLite database and return its rows.

    Args:
        sql_query: One SQL statement; ``Cursor.execute`` rejects multiple
            statements. In this service the text comes from model output driven
            by user input, so treat it as untrusted.
        db_path: Path to the SQLite database file. In the default (read/write)
            mode, ``sqlite3.connect`` creates an empty file if it does not exist.
        read_only: If True, open the database with the SQLite ``mode=ro`` URI
            parameter so write statements fail instead of being committed.
            Defaults to False to keep the original read/write behaviour.

    Returns:
        The list of row tuples from ``fetchall()`` on success (empty for
        statements that produce no rows), or the exception message as a string
        if executing or committing the statement fails.
    """
    # NOTE(security): with read_only=False, any statement the model produces
    # (INSERT/UPDATE/DELETE/DROP) is executed and committed.
    if read_only:
        from pathlib import Path

        conn = sqlite3.connect(Path(db_path).resolve().as_uri() + "?mode=ro", uri=True)
    else:
        conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(sql_query)
            result = cursor.fetchall()
            conn.commit()
        except Exception as exc:  # deliberate boundary: the error text is returned to the API client
            result = str(exc)
    finally:
        # Always release the connection, even if a non-Exception escapes.
        conn.close()
    return result


class QueryRequest(BaseModel):
    """JSON request body for ``POST /generate_sql/``, i.e. ``{"text": "<question>"}``."""

    # Natural-language question that the endpoint hands to generate_sql().
    text: str


@app.post("/generate_sql/")
def get_sql(query: QueryRequest):
    """Translate ``query.text`` into SQL, execute it, and return both.

    Response body: ``{"sql": <generated SQL>, "result": <rows, or error message string>}``.
    """
    # Declared with plain ``def`` (not ``async def``), so the blocking model
    # inference and database call run in FastAPI's threadpool.
    generated = generate_sql(query.text)
    return {"sql": generated, "result": execute_sql(generated)}


if __name__ == "__main__":
    # Serve the API directly with uvicorn on every interface, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)