import os
import json
from dataclasses import fields

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, LoraConfig
from huggingface_hub import login, hf_hub_download

# Authenticate with the Hugging Face Hub using the HF_TOKEN environment variable
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    raise ValueError("Hugging Face token not found. Please set the HF_TOKEN environment variable.")

# Pydantic model for request validation
class Query(BaseModel):
    text: str

app = FastAPI(title="Financial Chatbot API")

# Load the base model from Meta Llama
base_model_name = "meta-llama/Llama-3.2-3B"
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",
    trust_remote_code=True,
)

# Load the tokenizer from the base model
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

# Load the finetuned LoRA adapter, with a fallback for the eva_config compatibility issue
peft_model_id = "Phoenix21/llama-3-2-3b-finetuned-finance_checkpoint2"
try:
    # Try loading the adapter directly first
    model = PeftModel.from_pretrained(base_model, peft_model_id)
except TypeError as e:
    if "eva_config" in str(e):
        # The adapter was saved with a newer peft release whose adapter_config.json
        # contains fields (such as eva_config) that the installed peft does not accept.
        print("Handling eva_config compatibility issue...")
        # Download the adapter config and rebuild it manually
        config_path = hf_hub_download(repo_id=peft_model_id, filename="adapter_config.json")
        with open(config_path, "r") as f:
            config_dict = json.load(f)
        # Keep only the fields the installed LoraConfig actually knows about
        valid_fields = {field.name for field in fields(LoraConfig)}
        config_dict = {k: v for k, v in config_dict.items() if k in valid_fields}
        config = LoraConfig(**config_dict)
        # Load the adapter weights with the cleaned-up config
        model = PeftModel.from_pretrained(base_model, peft_model_id, config=config)
    else:
        # A different error: re-raise it
        raise

# Create a text-generation pipeline using the loaded model and tokenizer
chat_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=True,  # required for temperature/top_p to take effect
    temperature=0.7,
    top_p=0.95,
)

@app.post("/generate")
def generate(query: Query):
    prompt = f"Question: {query.text}\nAnswer: "
    response = chat_pipe(prompt)[0]["generated_text"]
    # Return only the answer part of the generated text
    if "Answer: " in response:
        response = response.split("Answer: ", 1)[1]
    return {"response": response}

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
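
# Example client call (a minimal sketch, assuming the server is running locally on
# the default port 7860 set above; the question text is only an illustration):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={"text": "What is dollar-cost averaging?"},
#   )
#   print(resp.json()["response"])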