# minicpm-fastapi / app.py
import io
from typing import List

import torch
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer

app = FastAPI()

# Load model and tokenizer
model = AutoModel.from_pretrained(
    'openbmb/MiniCPM-V-2_6',
    trust_remote_code=True,
    attn_implementation='sdpa',
    torch_dtype=torch.bfloat16,
)
model = model.eval().cuda()
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True)
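# Note (an assumption, not in the original file): the load path above expects
# a CUDA GPU with bfloat16 support. If the host may lack a GPU, the .cuda()
# call could be guarded, e.g.:
#   device = 'cuda' if torch.cuda.is_available() else 'cpu'
#   model = model.eval().to(device)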


# Schema models describing the few-shot request payload. Note: the endpoint
# below accepts the equivalent fields as multipart form data, so these models
# are not referenced directly.
class FewshotExample(BaseModel):
    image: bytes
    question: str
    answer: str


class PredictRequest(BaseModel):
    fewshot_examples: List[FewshotExample]
    test_image: bytes
    test_question: str


@app.post("/predict_with_fewshot")
async def predict_with_fewshot(
    fewshot_images: List[UploadFile] = File(...),
    fewshot_questions: List[str] = Form(...),
    fewshot_answers: List[str] = Form(...),
    test_image: UploadFile = File(...),
    test_question: str = Form(...)
):
    # Validate input lengths
    if len(fewshot_images) != len(fewshot_questions) or len(fewshot_questions) != len(fewshot_answers):
        raise HTTPException(status_code=400, detail="Number of few-shot images, questions, and answers must match.")
    msgs = []
    try:
        # Build an interleaved few-shot conversation: each example becomes a
        # user turn (image + question) followed by an assistant turn (answer)
        for fs_img, fs_q, fs_a in zip(fewshot_images, fewshot_questions, fewshot_answers):
            img_content = await fs_img.read()
            img = Image.open(io.BytesIO(img_content)).convert('RGB')
            msgs.append({'role': 'user', 'content': [img, fs_q]})
            msgs.append({'role': 'assistant', 'content': [fs_a]})
        # Test example: the final user turn that the model should answer
        test_img_content = await test_image.read()
        test_img = Image.open(io.BytesIO(test_img_content)).convert('RGB')
        msgs.append({'role': 'user', 'content': [test_img, test_question]})
        # Get answer
        answer = model.chat(
            image=None,
            msgs=msgs,
            tokenizer=tokenizer
        )
        return {"answer": answer}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
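

# ---------------------------------------------------------------------------
# Usage sketch (assumptions, not part of the original app): the app would
# typically be served with uvicorn, e.g.
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
# and the endpoint can then be called with multipart form data; repeated
# fields of the same name populate the List parameters. The URL, file names,
# questions, and answers below are hypothetical placeholders.
#
#   import requests
#
#   files = [
#       ('fewshot_images', open('example1.jpg', 'rb')),
#       ('fewshot_images', open('example2.jpg', 'rb')),
#       ('test_image', open('test.jpg', 'rb')),
#   ]
#   data = {
#       'fewshot_questions': ['What animal is this?', 'What animal is this?'],
#       'fewshot_answers': ['A cat.', 'A dog.'],
#       'test_question': 'What animal is this?',
#   }
#   resp = requests.post('http://localhost:8000/predict_with_fewshot',
#                        files=files, data=data)
#   print(resp.json()['answer'])
# ---------------------------------------------------------------------------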