"""FastAPI service that serves text generation from a PEFT-adapted Mistral-7B model."""

import logging
import os

from fastapi import FastAPI, HTTPException
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Hugging Face Hub repositories: the base checkpoint and the PEFT adapter
# that is layered on top of it.
BASE_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
ADAPTER_ID = "frankmorales2020/Mistral-7B-text-to-sql-flash-attention-2-dataeval"

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Global variables for model, tokenizer, and pipeline. They stay ``None``
# until the startup hook below has finished loading everything.
model = None
tokenizer = None
pipe = None


@app.on_event("startup")
async def load_model():
    """Load the base model, PEFT adapter, and tokenizer, then build the pipeline.

    Runs once at application startup and populates the module-level
    ``model``, ``tokenizer`` and ``pipe`` globals. The Hugging Face access
    token is read from the ``HUGGINGFACE_TOKEN`` environment variable; if it
    is unset, ``None`` is passed through to the loaders.

    Raises:
        Exception: Any failure while loading is logged with its traceback
            and re-raised so that startup fails instead of serving without
            a model.
    """
    global model, tokenizer, pipe
    try:
        # Get Hugging Face token from environment variable
        hf_token = os.environ.get("HUGGINGFACE_TOKEN")

        logger.info("Loading PEFT configuration...")
        config = PeftConfig.from_pretrained(ADAPTER_ID)
        # Debugging: log the adapter configuration
        logger.info("Configuration: %s", config)

        logger.info("Loading base model...")
        # ``token`` replaces the deprecated ``use_auth_token`` keyword.
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            token=hf_token,
        )

        logger.info("Loading PEFT model...")
        model = PeftModel.from_pretrained(base_model, ADAPTER_ID)

        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL_ID,
            token=hf_token,
        )

        logger.info("Creating pipeline...")
        # The model is a causal (decoder-only) LM, so the matching task is
        # "text-generation"; "text2text-generation" is meant for
        # encoder-decoder models.
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

        logger.info("Model, tokenizer, and pipeline loaded successfully.")
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Error loading model or creating pipeline: %s", e)
        raise


@app.get("/")
def home():
    """Return a static greeting payload."""
    return {"message": "Hello World"}


@app.get("/generate")
async def generate(text: str):
    """Generate text for the prompt ``text`` using the loaded pipeline.

    Args:
        text: Prompt passed to the pipeline (query parameter).

    Returns:
        A dict ``{"output": <generated text>}`` built from the first
        returned sequence.

    Raises:
        HTTPException: 503 if the pipeline has not been loaded; 500 if the
            pipeline call or result extraction fails.
    """
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # NOTE(review): this is a synchronous call inside an async handler,
        # so the event loop is blocked while generation runs.
        output = pipe(text, max_length=100, num_return_sequences=1)
        return {"output": output[0]["generated_text"]}
    except Exception as e:
        logger.exception("Error during text generation: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Error during text generation: {str(e)}",
        ) from e
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn. The import is local so
    # uvicorn is only required when the module is executed directly.
    import uvicorn

    # Listen on all interfaces on port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)