# Phoenix21's picture
# Update app.py
# 8087bbe verified
# raw
# history blame
# 1.34 kB
import os
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, PeftConfig
import uvicorn
class Query(BaseModel):
    """Request body accepted by the ``/generate`` endpoint."""

    # Free-form question text; it is interpolated into the model prompt.
    text: str
app = FastAPI(title="Financial Chatbot API")

# Fine-tuned adapter checkpoint that this service exposes.
peft_model_id = "Phoenix21/llama-3-2-3b-finetuned-finance_checkpoint2"

# Resolve the base model from the adapter's own config instead of hard-coding
# it: an adapter can only be attached to the architecture it was trained on.
# NOTE(review): the adapter id mentions "llama-3-2-3b", which suggests it was
# not trained on the 8B model that used to be hard-coded here -- confirm.
# The previous hard-coded name is kept as a fallback for adapter configs that
# do not record a base model.
peft_config = PeftConfig.from_pretrained(peft_model_id)
base_model_name = (
    peft_config.base_model_name_or_path or "meta-llama/Meta-Llama-3-8B"
)

# Load the base model, then wrap it with the fine-tuned adapter weights.
# NOTE(review): trust_remote_code=True allows Python code shipped in the model
# repository to run locally; keep it only if the checkpoint requires it.
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(model, peft_model_id)

# The tokenizer comes from the base model; EOS is reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

# Shared text-generation pipeline used by the request handlers.
# do_sample=True is set explicitly so that temperature/top_p take effect
# regardless of the checkpoint's default generation config.
chat_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
)
@app.post("/generate")
def generate(query: Query):
    """Generate an answer to the submitted question.

    Args:
        query: Request body; ``query.text`` is inserted into the prompt
            template as the question.

    Returns:
        A dict of the form ``{"response": <str>}`` containing only the text
        produced by the model, without the prompt scaffold.
    """
    prompt = f"Question: {query.text}\nAnswer: "
    # By default the text-generation pipeline returns prompt + completion.
    # return_full_text=False keeps only the completion, so clients do not get
    # the "Question: ...\nAnswer: " template echoed back in the response.
    outputs = chat_pipe(prompt, return_full_text=False)
    return {"response": outputs[0]["generated_text"]}
if __name__ == "__main__":
    # Listen on all interfaces; the PORT environment variable overrides the
    # default port 7860.
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)