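# FastAPI service that wraps the meta-llama/Llama-3.2-3B base model with a
# finance-tuned LoRA adapter and exposes it through a /generate endpoint.
# Assumed runtime dependencies (not pinned in this file): fastapi, uvicorn,
# pydantic, torch, transformers, peft, huggingface_hub.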
import os
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, LoraConfig
import uvicorn
import json
from huggingface_hub import hf_hub_download, login
# Authenticate with Hugging Face Hub
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    raise ValueError("Hugging Face token not found. Please set the HF_TOKEN environment variable.")
# Define a Pydantic model for request validation
class Query(BaseModel):
    text: str
app = FastAPI(title="Financial Chatbot API")
# Load the base model
base_model_name = "meta-llama/Llama-3.2-3B" # Update if using a different base model
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",
    trust_remote_code=True
)
# Load adapter from your checkpoint with a fix for the 'eva_config' issue
peft_model_id = "Phoenix21/llama-3-2-3b-finetuned-finance_checkpoint2"
# Manually download and load the adapter config to filter out problematic fields
try:
    # Download the adapter_config.json file
    config_file = hf_hub_download(
        repo_id=peft_model_id,
        filename="adapter_config.json",
        token=HF_TOKEN
    )
    # Load and clean the config
    with open(config_file, 'r') as f:
        config_dict = json.load(f)
    # Remove problematic fields if they exist
    if "eva_config" in config_dict:
        del config_dict["eva_config"]
    # PeftModel.from_pretrained expects a PeftConfig instance rather than a plain
    # dict, so rebuild a LoraConfig from the cleaned entries
    peft_config = LoraConfig(**config_dict)
    # Load the adapter with the cleaned config
    model = PeftModel.from_pretrained(
        model,
        peft_model_id,
        config=peft_config
    )
except Exception as e:
    print(f"Error loading adapter: {e}")
    # Fallback: let PEFT download and parse the adapter config itself
    model = PeftModel.from_pretrained(
        model,
        peft_model_id
    )
# Load tokenizer from the base model
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
# Create a text-generation pipeline
chat_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
    temperature=0.7,
    top_p=0.95,
)
@app.post("/generate")
def generate(query: Query):
    prompt = f"Question: {query.text}\nAnswer: "
    response = chat_pipe(prompt)[0]["generated_text"]
    # Extract only the answer part from the response
    answer = response.split("Answer: ")[-1].strip()
    return {"response": answer}
if __name__ == "__main__":
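    # Default to 7860, the port Hugging Face Spaces expect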
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
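# Example request (illustrative; assumes the server is running locally on port 7860):
#   curl -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "What is a balance sheet?"}'
#
# Example response shape: {"response": "<generated answer>"}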