# minicpm-fastapi / app.py
# Author: acharyaaditya26 — last commit "Update app.py" (revision a51d495, verified).
import io
import os
import threading
from typing import List

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import login
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer
# Pull HF_TOKEN (and any other settings) from a local .env file, if present.
load_dotenv()
api_key = os.getenv("HF_TOKEN")
# huggingface_hub.login() with token=None falls back to prompting the user
# for a token, which a headless server start cannot answer. Only log in when a
# token was actually provided.
if api_key:
    login(token=api_key)
app = FastAPI()
# Load the MiniCPM-V 2.6 checkpoint once at import time so every request
# shares the same weights. trust_remote_code=True lets transformers execute the
# modeling/tokenizer code published alongside the checkpoint.
model = AutoModel.from_pretrained(
    'openbmb/MiniCPM-V-2_6',
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation='sdpa',
)
# Switch to inference mode, then move the weights to the default CUDA device
# (this requires a CUDA-capable GPU).
model.eval()
model = model.cuda()
tokenizer = AutoTokenizer.from_pretrained(
    'openbmb/MiniCPM-V-2_6',
    trust_remote_code=True,
)
class FewshotExample(BaseModel):
    """One few-shot demonstration: an image, a question about it, and its answer."""

    # Raw image payload.
    # NOTE(review): JSON has no native bytes type; confirm the intended client
    # encoding (e.g. base64) — pydantic does not base64-decode `bytes` by default.
    image: bytes
    # Question asked about `image`.
    question: str
    # Reference answer to `question`.
    answer: str
class PredictRequest(BaseModel):
    """JSON-body shape for a few-shot prediction request.

    NOTE(review): not referenced by the code visible in this file — the
    /predict_with_fewshot endpoint takes multipart form fields instead.
    Confirm whether it is used elsewhere before relying on or removing it.
    """

    # Demonstration examples (presumably presented to the model in list order).
    fewshot_examples: List[FewshotExample]
    # Raw bytes of the image the final question is about.
    # NOTE(review): JSON has no native bytes type; confirm the intended client
    # encoding (e.g. base64) — pydantic does not base64-decode `bytes` by default.
    test_image: bytes
    # Question to answer about `test_image`.
    test_question: str
# Inference used to run directly on the event loop, which (as a side effect)
# allowed only one model.chat() call at a time. Now that it runs on a worker
# thread, this lock preserves that one-at-a-time behaviour for the shared model.
_model_lock = threading.Lock()


def _load_rgb_image(content: bytes, label: str) -> Image.Image:
    """Decode uploaded image bytes into an RGB PIL image.

    Args:
        content: Raw bytes of the uploaded file.
        label: Human-readable name of the upload, used in the error message.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        # Image.open() is lazy; convert() forces the full decode, so truncated
        # or corrupt data fails here rather than later inside the model.
        return Image.open(io.BytesIO(content)).convert('RGB')
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        # An undecodable upload is a client error, not a server failure.
        raise HTTPException(status_code=400, detail=f"Could not decode {label}: {e}") from e


def _chat_locked(msgs: List[dict]):
    """Run model.chat on *msgs* while holding ``_model_lock`` and return its result."""
    with _model_lock:
        return model.chat(
            image=None,
            msgs=msgs,
            tokenizer=tokenizer
        )


@app.post("/predict_with_fewshot")
async def predict_with_fewshot(
    fewshot_images: List[UploadFile] = File(...),
    fewshot_questions: List[str] = Form(...),
    fewshot_answers: List[str] = Form(...),
    test_image: UploadFile = File(...),
    test_question: str = Form(...)
):
    """Answer a question about ``test_image`` using few-shot demonstrations.

    The i-th entries of ``fewshot_images``, ``fewshot_questions`` and
    ``fewshot_answers`` form one demonstration: a user turn (image + question)
    followed by an assistant turn (answer). The test image and question are
    appended as the final user turn and the whole conversation is passed to
    ``model.chat``.

    Returns:
        ``{"answer": <model.chat output>}``.

    Raises:
        HTTPException: 400 if the three few-shot lists differ in length or an
            uploaded image cannot be decoded; 500 for any other failure.
    """
    # Validate input lengths
    if (len(fewshot_images) != len(fewshot_questions)
            or len(fewshot_questions) != len(fewshot_answers)):
        raise HTTPException(status_code=400, detail="Number of few-shot images, questions, and answers must match.")
    msgs = []
    try:
        for idx, (fs_img, fs_q, fs_a) in enumerate(
                zip(fewshot_images, fewshot_questions, fewshot_answers)):
            img = _load_rgb_image(await fs_img.read(), f"few-shot image {idx}")
            msgs.append({'role': 'user', 'content': [img, fs_q]})
            msgs.append({'role': 'assistant', 'content': [fs_a]})
        # Test example: the final user turn the model must answer.
        test_img = _load_rgb_image(await test_image.read(), "test image")
        msgs.append({'role': 'user', 'content': [test_img, test_question]})
        # model.chat() is a synchronous call; running it on a worker thread
        # keeps the event loop free to serve other requests during generation.
        answer = await run_in_threadpool(_chat_locked, msgs)
    except HTTPException:
        # Already carries the intended status (e.g. 400 from _load_rgb_image).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
    return {"answer": answer}