Spaces:
Sleeping
Sleeping
import os | |
from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
from typing import List | |
from pydantic import BaseModel | |
from PIL import Image | |
import io | |
from transformers import AutoModel, AutoTokenizer | |
import torch | |
from huggingface_hub import login | |
from dotenv import load_dotenv | |
# Pull HF_TOKEN (and any other settings) from a local .env file if one exists.
load_dotenv()
api_key = os.getenv("HF_TOKEN")
# Only authenticate when a token is actually configured: login(token=None)
# falls back to an interactive prompt, which cannot be answered when this
# module is imported by a headless server process.
if api_key:
    login(token=api_key)

app = FastAPI()

# Load model and tokenizer
MODEL_ID = 'openbmb/MiniCPM-V-2_6'  # single source of truth for the checkpoint id
model = AutoModel.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    attn_implementation='sdpa',
    torch_dtype=torch.bfloat16,
)
# Inference only: switch to eval mode and move the weights to the GPU.
model = model.eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# Schema for a single few-shot example: one image plus its question/answer pair.
class FewshotExample(BaseModel):
    image: bytes  # image data, carried as bytes
    question: str  # question asked about the image
    answer: str  # answer that goes with the question
# Schema bundling the few-shot examples with the test image and question.
# NOTE(review): the handler visible in this file takes multipart form fields
# and does not reference this model — confirm it is still needed.
class PredictRequest(BaseModel):
    fewshot_examples: List[FewshotExample]  # examples shown before the test query
    test_image: bytes  # image data for the query, carried as bytes
    test_question: str  # question to answer about test_image
async def _read_rgb_image(upload: UploadFile, label: str) -> Image.Image:
    """Read an uploaded file and decode it into an RGB PIL image.

    Args:
        upload: The uploaded file to read.
        label: Human-readable name of the upload, used in the error detail.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        HTTPException: 400 if the payload cannot be decoded as an image.
    """
    raw = await upload.read()
    try:
        return Image.open(io.BytesIO(raw)).convert('RGB')
    except (OSError, ValueError) as exc:
        # Pillow signals unreadable or truncated image data with OSError
        # subclasses; that is bad client input, not a server fault.
        raise HTTPException(
            status_code=400,
            detail=f"Could not decode {label} as an image: {exc}",
        ) from exc


# NOTE(review): no route decorator (e.g. @app.post(...)) is visible on this
# handler in this view — confirm it is registered on `app`.
async def predict_with_fewshot(
    fewshot_images: List[UploadFile] = File(...),
    fewshot_questions: List[str] = Form(...),
    fewshot_answers: List[str] = Form(...),
    test_image: UploadFile = File(...),
    test_question: str = Form(...)
):
    """Answer a question about a test image, conditioned on few-shot examples.

    Each few-shot image/question/answer triple becomes a user turn followed
    by an assistant turn; the test image and question form the final user
    turn, and the whole conversation is passed to ``model.chat``.

    Args:
        fewshot_images: Uploaded example images.
        fewshot_questions: One question per example image.
        fewshot_answers: One answer per example question.
        test_image: Uploaded image the model should be asked about.
        test_question: Question to ask about ``test_image``.

    Returns:
        A dict ``{"answer": ...}`` holding the value returned by ``model.chat``.

    Raises:
        HTTPException: 400 if the few-shot lists differ in length or an
            upload cannot be decoded as an image; 500 for any other failure
            while building the conversation or running the model.
    """
    # Validate input lengths
    if not (len(fewshot_images) == len(fewshot_questions) == len(fewshot_answers)):
        raise HTTPException(status_code=400, detail="Number of few-shot images, questions, and answers must match.")

    try:
        msgs = []
        examples = zip(fewshot_images, fewshot_questions, fewshot_answers)
        for idx, (fs_img, fs_q, fs_a) in enumerate(examples):
            img = await _read_rgb_image(fs_img, f"few-shot image {idx}")
            msgs.append({'role': 'user', 'content': [img, fs_q]})
            msgs.append({'role': 'assistant', 'content': [fs_a]})

        # Test example
        test_img = await _read_rgb_image(test_image, "test image")
        msgs.append({'role': 'user', 'content': [test_img, test_question]})

        # Get answer; images travel inside msgs, so the image argument is None.
        answer = model.chat(
            image=None,
            msgs=msgs,
            tokenizer=tokenizer
        )
    except HTTPException:
        # Client errors raised above must keep their own status code rather
        # than being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e

    return {"answer": answer}