import io
import os
from typing import List

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from huggingface_hub import login
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer

# Authenticate with the Hugging Face Hub using the token from .env
load_dotenv()
api_key = os.getenv("HF_TOKEN")
if api_key:
    login(token=api_key)

app = FastAPI()

# Load the MiniCPM-V-2.6 model and tokenizer once at startup
model = AutoModel.from_pretrained(
    'openbmb/MiniCPM-V-2_6',
    trust_remote_code=True,
    attn_implementation='sdpa',
    torch_dtype=torch.bfloat16,
)
model = model.eval().cuda()
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True)


# Schemas documenting the expected payload shape; the endpoint below
# accepts the same fields as multipart form data
class FewshotExample(BaseModel):
    image: bytes
    question: str
    answer: str


class PredictRequest(BaseModel):
    fewshot_examples: List[FewshotExample]
    test_image: bytes
    test_question: str


@app.post("/predict_with_fewshot")
async def predict_with_fewshot(
    fewshot_images: List[UploadFile] = File(...),
    fewshot_questions: List[str] = Form(...),
    fewshot_answers: List[str] = Form(...),
    test_image: UploadFile = File(...),
    test_question: str = Form(...),
):
    # Validate input lengths
    if len(fewshot_images) != len(fewshot_questions) or len(fewshot_questions) != len(fewshot_answers):
        raise HTTPException(
            status_code=400,
            detail="Number of few-shot images, questions, and answers must match.",
        )

    msgs = []
    try:
        # Build the few-shot context: each example becomes a user turn
        # (image + question) followed by an assistant turn (answer)
        for fs_img, fs_q, fs_a in zip(fewshot_images, fewshot_questions, fewshot_answers):
            img_content = await fs_img.read()
            img = Image.open(io.BytesIO(img_content)).convert('RGB')
            msgs.append({'role': 'user', 'content': [img, fs_q]})
            msgs.append({'role': 'assistant', 'content': [fs_a]})

        # Append the test example as the final user turn
        test_img_content = await test_image.read()
        test_img = Image.open(io.BytesIO(test_img_content)).convert('RGB')
        msgs.append({'role': 'user', 'content': [test_img, test_question]})

        # Generate the answer; the images travel inside msgs, so image=None
        answer = model.chat(
            image=None,
            msgs=msgs,
            tokenizer=tokenizer,
        )
        return {"answer": answer}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
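
# --- Example client (separate script; an illustrative sketch, not part of the server) ---
# Assumes the server above is running at http://localhost:8000 and that
# cat1.jpg, cat2.jpg, and test.jpg are hypothetical local image files.
# Repeated form fields map onto the List[...] parameters of the endpoint.
import requests

url = "http://localhost:8000/predict_with_fewshot"

files = [
    ("fewshot_images", open("cat1.jpg", "rb")),
    ("fewshot_images", open("cat2.jpg", "rb")),
    ("test_image", open("test.jpg", "rb")),
]
data = [
    ("fewshot_questions", "What animal is shown in this image?"),
    ("fewshot_questions", "What animal is shown in this image?"),
    ("fewshot_answers", "A cat."),
    ("fewshot_answers", "Another cat."),
    ("test_question", "What animal is shown in this image?"),
]

response = requests.post(url, files=files, data=data)
response.raise_for_status()
print(response.json()["answer"])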