"""FastAPI service exposing MiniCPM-V-2.6 visual question answering with
few-shot (in-context) examples supplied as multipart form data."""

from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from typing import List
from pydantic import BaseModel
from PIL import Image, UnidentifiedImageError
import io
from transformers import AutoModel, AutoTokenizer
import torch

app = FastAPI()

# Load model and tokenizer once at import time.
# SECURITY NOTE: trust_remote_code=True executes Python code shipped in the
# model repository -- acceptable only because the repo id is a fixed, known
# source; do not make it user-configurable.
model = AutoModel.from_pretrained(
    'openbmb/MiniCPM-V-2_6',
    trust_remote_code=True,
    attn_implementation='sdpa',
    torch_dtype=torch.bfloat16,
)
# Fall back to CPU when no GPU is present instead of crashing at import time.
model = model.eval().cuda() if torch.cuda.is_available() else model.eval()
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True)


class FewshotExample(BaseModel):
    # NOTE(review): this model and PredictRequest below are not used by the
    # endpoint (it accepts multipart form fields instead). Kept so any external
    # code importing them keeps working; consider wiring up a JSON-body route
    # or deleting them.
    image: bytes
    question: str
    answer: str


class PredictRequest(BaseModel):
    fewshot_examples: List[FewshotExample]
    test_image: bytes
    test_question: str


def _load_rgb_image(data: bytes, label: str) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image.

    Raises HTTPException(400) when the bytes are not a decodable image, so a
    client sending a bad file gets a client error rather than a 500.
    """
    try:
        return Image.open(io.BytesIO(data)).convert('RGB')
    except UnidentifiedImageError as e:
        raise HTTPException(status_code=400, detail=f"Invalid image file: {label}") from e


@app.post("/predict_with_fewshot")
async def predict_with_fewshot(
    fewshot_images: List[UploadFile] = File(...),
    fewshot_questions: List[str] = Form(...),
    fewshot_answers: List[str] = Form(...),
    test_image: UploadFile = File(...),
    test_question: str = Form(...)
):
    """Answer *test_question* about *test_image*, conditioning the model on
    parallel lists of few-shot (image, question, answer) examples.

    Returns:
        {"answer": <model response string>}

    Raises:
        HTTPException 400: mismatched few-shot list lengths or undecodable image.
        HTTPException 500: unexpected failure during model inference.
    """
    # The three few-shot lists must line up 1:1 (zip below would otherwise
    # silently truncate to the shortest list).
    if len(fewshot_images) != len(fewshot_questions) or len(fewshot_questions) != len(fewshot_answers):
        raise HTTPException(
            status_code=400,
            detail="Number of few-shot images, questions, and answers must match.",
        )

    # Build the chat transcript: each few-shot pair is a user turn
    # (image + question) followed by the assistant's known answer.
    msgs = []
    for fs_img, fs_q, fs_a in zip(fewshot_images, fewshot_questions, fewshot_answers):
        img = _load_rgb_image(await fs_img.read(), fs_img.filename or "few-shot image")
        msgs.append({'role': 'user', 'content': [img, fs_q]})
        msgs.append({'role': 'assistant', 'content': [fs_a]})

    # The real query goes last; the model completes this final user turn.
    test_img = _load_rgb_image(await test_image.read(), test_image.filename or "test image")
    msgs.append({'role': 'user', 'content': [test_img, test_question]})

    try:
        # MiniCPM-V's chat API takes images inline in msgs, hence image=None.
        answer = model.chat(image=None, msgs=msgs, tokenizer=tokenizer)
    except Exception as e:
        # Keep the try body minimal: only inference failures map to 500;
        # chain the cause so server logs retain the original traceback.
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e

    return {"answer": answer}