"""FastAPI service exposing a speech-in / speech-out model over HTTP.

Endpoints:
    GET  /api/v1/health     -- model initialization status.
    POST /api/v1/inference  -- base64-encoded ``.npy`` audio in, base64 ``.npy`` audio + text out.
"""
import asyncio
import base64
import io
import logging
import time

import librosa
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModel, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Add CORS middleware.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests; tighten allow_origins if this is exposed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class AudioRequest(BaseModel):
    """Inference request payload."""

    # Base64 encoding of a NumPy ``.npy`` file holding the input waveform.
    audio_data: str
    # Sample rate (Hz) of the input waveform; the response uses the same rate.
    sample_rate: int = Field(..., gt=0)


class AudioResponse(BaseModel):
    """Inference response payload."""

    # Base64 encoding of a NumPy ``.npy`` file holding the generated waveform.
    audio_data: str
    # Text produced alongside the audio.
    text: str = ""


# Model initialization status, reported by the health endpoint.
INITIALIZATION_STATUS = {
    "model_loaded": False,
    "error": None,
}


class Model:
    """Owns the model, tokenizer and reference prompt, and runs one-turn chat.

    Construction loads the checkpoint from ``./models/checkpoint`` in bfloat16
    onto CUDA, initializes the TTS head (kept in float32), builds the system
    prompt from ``./ref_audios/female_example.wav``, and runs one warmup
    inference on ``./ref_audios/male_example.wav``.
    """

    def __init__(self):
        self.model = AutoModel.from_pretrained(
            './models/checkpoint',
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            attn_implementation='sdpa',
        )
        # eval()/cuda() modify the module in place and return it; assign back
        # to self.model so the intent does not rely on that aliasing.
        self.model = self.model.eval().cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(
            './models/checkpoint',
            trust_remote_code=True,
        )

        # Initialize TTS.
        self.model.init_tts()
        self.model.tts.float()  # Convert TTS to float32 if needed.

        self.model_in_sr = 16000
        self.model_out_sr = 24000

        # Reference audio used to build the system prompt.
        self.ref_audio, _ = librosa.load(
            './ref_audios/female_example.wav', sr=self.model_in_sr, mono=True
        )
        self.sys_prompt = self.model.get_sys_prompt(
            ref_audio=self.ref_audio, mode='audio_assistant', language='en'
        )

        # Warmup: run one inference so the first real request is not slowed
        # by one-time setup costs.
        warmup_audio = librosa.load(
            './ref_audios/male_example.wav', sr=self.model_in_sr, mono=True
        )[0]
        _ = self.inference(warmup_audio, self.model_in_sr)

    def inference(self, audio_np, input_audio_sr):
        """Run one chat turn on *audio_np* and return ``(audio, text)``.

        The input is resampled to ``self.model_in_sr`` if needed. The generated
        audio is resampled from ``self.model_out_sr`` back to
        *input_audio_sr*.

        NOTE(review): librosa.resample requires floating-point samples; integer
        PCM input would fail here -- confirm what clients send.
        """
        if input_audio_sr != self.model_in_sr:
            audio_np = librosa.resample(
                audio_np, orig_sr=input_audio_sr, target_sr=self.model_in_sr
            )
        user_question = {'role': 'user', 'content': [audio_np]}

        # Round one.
        msgs = [self.sys_prompt, user_question]
        res = self.model.chat(
            msgs=msgs,
            tokenizer=self.tokenizer,
            sampling=True,
            max_new_tokens=128,
            use_tts_template=True,
            generate_audio=True,
            temperature=0.3,
        )
        # NOTE(review): assumes res["audio_wav"] is always a tensor when
        # generate_audio=True -- confirm it can never be None.
        audio = res["audio_wav"].cpu().numpy()
        if self.model_out_sr != input_audio_sr:
            audio = librosa.resample(
                audio, orig_sr=self.model_out_sr, target_sr=input_audio_sr
            )
        return audio, res["text"]


# Shared Model instance; set by initialize_model().
model = None

# Serializes access to the shared model across requests. It is created lazily
# so that it binds to the event loop that actually serves requests.
_inference_lock = None


def _get_inference_lock() -> asyncio.Lock:
    """Return the lock that serializes calls into the shared model."""
    global _inference_lock
    if _inference_lock is None:
        _inference_lock = asyncio.Lock()
    return _inference_lock


def _decode_audio(audio_b64: str) -> np.ndarray:
    """Decode a base64 ``.npy`` payload into a flat, non-empty sample array.

    Raises:
        ValueError: bad base64 (``binascii.Error`` subclasses ValueError), a
            pickled/non-npy payload, an ``.npz`` archive, or zero samples.
        EOFError / OSError: truncated or unreadable ``.npy`` data.
    """
    audio_bytes = base64.b64decode(audio_b64)
    # allow_pickle=False: never unpickle data supplied by the client.
    loaded = np.load(io.BytesIO(audio_bytes), allow_pickle=False)
    if not isinstance(loaded, np.ndarray):
        raise ValueError("expected a single .npy array, got an .npz archive")
    audio = loaded.flatten()
    if audio.size == 0:
        raise ValueError("audio contains no samples")
    return audio


def _encode_audio(audio: np.ndarray) -> str:
    """Serialize *audio* as ``.npy`` bytes and return them base64-encoded."""
    buffer = io.BytesIO()
    np.save(buffer, audio)
    return base64.b64encode(buffer.getvalue()).decode()


def initialize_model():
    """Initialize the MiniCPM model.

    Stores the instance in the module-global ``model`` and records the outcome
    in INITIALIZATION_STATUS.

    Returns:
        True on success, False if construction raised.
    """
    global model
    try:
        logger.info("Initializing model...")
        model = Model()
        INITIALIZATION_STATUS["model_loaded"] = True
        INITIALIZATION_STATUS["error"] = None
        logger.info("MiniCPM model initialized successfully")
        return True
    except Exception as e:
        INITIALIZATION_STATUS["error"] = str(e)
        # logger.exception keeps the traceback, which str(e) alone loses.
        logger.exception(f"Failed to initialize model: {e}")
        return False


@app.on_event("startup")
async def startup_event():
    """Initialize model on startup.

    This runs synchronously, so the server does not accept requests until
    loading has finished or failed.
    """
    initialize_model()


@app.get("/api/v1/health")
def health_check():
    """Health check endpoint.

    ``status`` is "healthy" once the model is loaded, "unhealthy" if
    initialization failed, and "initializing" otherwise.
    """
    loaded = INITIALIZATION_STATUS["model_loaded"]
    error = INITIALIZATION_STATUS["error"]
    if loaded:
        state = "healthy"
    elif error is not None:
        state = "unhealthy"
    else:
        state = "initializing"
    return {
        "status": state,
        "model_loaded": loaded,
        "error": error,
    }


@app.post("/api/v1/inference")
async def inference(request: AudioRequest) -> AudioResponse:
    """Run inference with MiniCPM model.

    Raises:
        HTTPException(503): the model is not loaded.
        HTTPException(400): ``audio_data`` cannot be decoded into samples.
        HTTPException(500): the model call or response encoding failed.
    """
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}",
        )

    try:
        audio_np = _decode_audio(request.audio_data)
    except (ValueError, EOFError, OSError) as e:
        logger.warning("Rejected malformed audio payload: %s", e)
        raise HTTPException(status_code=400, detail=f"Invalid audio_data: {e}") from e

    try:
        # The model call is blocking GPU work. Run it in a worker thread so the
        # event loop (and /health) stays responsive. The lock keeps requests
        # from driving the shared model concurrently.
        async with _get_inference_lock():
            start = time.perf_counter()
            logger.info("starting inference with audio length %s", audio_np.shape)
            audio_response, text_response = await run_in_threadpool(
                model.inference, audio_np, request.sample_rate
            )
            logger.info("inference took %.3f seconds", time.perf_counter() - start)
        audio_b64 = _encode_audio(audio_response)
    except Exception as e:
        logger.exception(f"Inference failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return AudioResponse(audio_data=audio_b64, text=text_response)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)