import logging
import os
from fastapi import FastAPI, HTTPException
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, PeftConfig

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Global variables for model, tokenizer, and pipeline
model = None
tokenizer = None
pipe = None

# Get the Hugging Face token from environment variable
hf_token = os.environ.get("HUGGINGFACE_TOKEN")
if not hf_token:
    raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")

@app.on_event("startup")
async def load_model():
    global model, tokenizer, pipe
    
    try:
        logger.info("Loading PEFT configuration...")
        config = PeftConfig.from_pretrained("frankmorales2020/Mistral-7B-text-to-sql-flash-attention-2-dataeval", token=hf_token)
        
        logger.info("Loading base model...")
        base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", token=hf_token)
        
        logger.info("Loading PEFT model...")
        model = PeftModel.from_pretrained(base_model, "frankmorales2020/Mistral-7B-text-to-sql-flash-attention-2-dataeval", token=hf_token)

        logger.info("Loading tokenizer...")
        tokenizer = MistralTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", token=hf_token)

        logger.info("Creating pipeline...")
        pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
        
        logger.info("Model, tokenizer, and pipeline loaded successfully.")
    except ImportError as e:
        logger.error(f"Error importing required modules. Please check your installation: {e}")
        raise
    except Exception as e:
        logger.error(f"Error loading model or creating pipeline: {e}")
        raise
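
# A minimal sketch of prompt formatting, assuming the adapter was fine-tuned on
# Mistral's instruction format (check the adapter's model card before relying on this):
#
#   prompt = tokenizer.apply_chat_template(
#       [{"role": "user", "content": text}],
#       tokenize=False,
#       add_generation_prompt=True,
#   )
#   output = pipe(prompt, max_new_tokens=100, return_full_text=False)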

@app.get("/")
def home():
    return {"message": "Hello World"}

@app.get("/generate")
async def generate(text: str):
    if not pipe:
        raise HTTPException(status_code=503, detail="Model not loaded")
    
    try:
        # max_new_tokens bounds only the generated continuation, and
        # return_full_text=False strips the echoed prompt from the output.
        output = pipe(text, max_new_tokens=100, num_return_sequences=1, return_full_text=False)
        return {"output": output[0]["generated_text"]}
    except Exception as e:
        logger.error(f"Error during text generation: {e}")
        raise HTTPException(status_code=500, detail=f"Error during text generation: {str(e)}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
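
# Example request once the server is running (illustrative query and output):
#   curl "http://localhost:7860/generate?text=Show%20all%20orders%20from%202023"
# Expected response shape: {"output": "<generated SQL>"}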