import os
import logging
from fastapi import FastAPI, HTTPException
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, PeftConfig

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Global variables for model, tokenizer, and pipeline
model = None
tokenizer = None
pipe = None
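
# NOTE: @app.on_event("startup") is deprecated in newer FastAPI versions in
# favor of lifespan handlers; it is kept here because it still works and keeps
# the code simple.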

@app.on_event("startup")
async def load_model():
    global model, tokenizer, pipe
    
    try:
        # Get Hugging Face token from environment variable
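        # The Mistral base model is gated on the Hugging Face Hub, so a token
        # with access is typically required for the first download.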
        hf_token = os.environ.get("HUGGINGFACE_TOKEN")
        
        logger.info("Loading PEFT configuration...")
        config = PeftConfig.from_pretrained("frankmorales2020/Mistral-7B-text-to-sql-flash-attention-2-dataeval")
        
        logger.info("Loading base model...")
        # Passing both `token` and the deprecated `use_auth_token` raises an
        # error in recent transformers releases, so only `token` is used here;
        # a None value falls back to any cached `huggingface-cli login` credentials.
        base_model = AutoModelForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.3",
            token=hf_token,
        )
        
        logger.info("Loading PEFT model...")
        model = PeftModel.from_pretrained(base_model, "frankmorales2020/Mistral-7B-text-to-sql-flash-attention-2-dataeval")
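        # The LoRA adapter runs on top of the frozen base weights; optionally,
        # model = model.merge_and_unload() would merge them for faster inference.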

        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.3",
            token=hf_token,
        )

        logger.info("Creating pipeline...")
        # Mistral is a decoder-only (causal) model, so the pipeline task is
        # "text-generation"; "text2text-generation" is for encoder-decoder models.
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        
        logger.info("Model, tokenizer, and pipeline loaded successfully.")
    except Exception as e:
        logger.error(f"Error loading model or creating pipeline: {e}")
        raise

@app.get("/")
def home():
    return {"message": "Hello World"}

@app.get("/generate")
async def generate(text: str):
    if not pipe:
        raise HTTPException(status_code=503, detail="Model not loaded")
    
    try:
        # max_new_tokens bounds only the completion; max_length would also count
        # the prompt tokens and can cut generation short for long inputs.
        output = pipe(text, max_new_tokens=100, num_return_sequences=1)
        # generated_text is the prompt followed by the completion; pass
        # return_full_text=False above to return only the completion.
        return {"output": output[0]["generated_text"]}
    except Exception as e:
        logger.error(f"Error during text generation: {e}")
        raise HTTPException(status_code=500, detail=f"Error during text generation: {str(e)}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
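
# Example usage (a minimal sketch; assumes the server is running locally on the
# port configured above and the model has finished loading):
#
#   curl "http://localhost:7860/generate?text=Show%20all%20users%20older%20than%2030"
#
# which should return JSON of the form:
#
#   {"output": "<prompt followed by the generated continuation>"}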