File size: 2,350 Bytes
a51d495
 
b0095ef
 
 
 
 
 
 
a51d495
 
 
 
 
 
 
b0095ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os

from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from typing import List
from pydantic import BaseModel
from PIL import Image
import io
from transformers import AutoModel, AutoTokenizer
import torch
from huggingface_hub import login
from dotenv import load_dotenv
# Load environment variables from a local .env file (no-op if the file is absent).
load_dotenv()

# Hugging Face access token used to authenticate model downloads.
api_key = os.getenv("HF_TOKEN")

# NOTE(review): if HF_TOKEN is unset this passes token=None, in which case
# login() falls back to any locally cached credentials or fails — confirm the
# env var is set in deployment.
login(token = api_key)

app = FastAPI()

# Load model and tokenizer once at import time; both are shared (read-only at
# inference) across all requests handled by this process.
# Fall back to CPU when CUDA is unavailable so the service can still start on
# non-GPU hosts (the original unconditional .cuda() crashed there). On GPU
# machines behavior is unchanged.
_device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True,
                                  attn_implementation='sdpa', torch_dtype=torch.bfloat16)
model = model.eval().to(_device)
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True)

class FewshotExample(BaseModel):
    """One few-shot demonstration: an image plus its question and gold answer.

    NOTE(review): not referenced by the visible /predict_with_fewshot endpoint,
    which accepts multipart Form/File fields instead — possibly intended for a
    JSON request variant; confirm before removing.
    """
    # Raw encoded image bytes (decoded with PIL at inference time).
    image: bytes
    question: str
    answer: str

class PredictRequest(BaseModel):
    """JSON request schema: few-shot examples plus the test image/question.

    NOTE(review): unused by the visible /predict_with_fewshot endpoint, which
    takes the same data as multipart Form/File parameters — confirm whether a
    JSON endpoint was planned or this is dead code.
    """
    fewshot_examples: List[FewshotExample]
    test_image: bytes
    test_question: str

@app.post("/predict_with_fewshot")
async def predict_with_fewshot(
    fewshot_images: List[UploadFile] = File(...),
    fewshot_questions: List[str] = Form(...),
    fewshot_answers: List[str] = Form(...),
    test_image: UploadFile = File(...),
    test_question: str = Form(...)
):
    """Answer ``test_question`` about ``test_image`` using few-shot prompting.

    Builds a multi-turn message list — one user/assistant pair per few-shot
    example, then the test turn — and passes it to ``model.chat``.

    Raises:
        HTTPException 400: mismatched few-shot list lengths, or an upload that
            cannot be decoded as an image (client errors, not server faults).
        HTTPException 500: the model itself failed during inference.
    """
    # Validate input lengths: the three few-shot lists must align element-wise.
    if len(fewshot_images) != len(fewshot_questions) or len(fewshot_questions) != len(fewshot_answers):
        raise HTTPException(status_code=400, detail="Number of few-shot images, questions, and answers must match.")

    async def _load_image(upload: UploadFile) -> Image.Image:
        # A corrupt/non-image upload is the caller's fault: report 400 instead
        # of letting it fall into the 500 handler around model inference.
        content = await upload.read()
        try:
            return Image.open(io.BytesIO(content)).convert('RGB')
        except Exception as e:
            raise HTTPException(status_code=400,
                                detail=f"Could not decode uploaded image '{upload.filename}': {e}")

    msgs = []
    for fs_img, fs_q, fs_a in zip(fewshot_images, fewshot_questions, fewshot_answers):
        img = await _load_image(fs_img)
        msgs.append({'role': 'user', 'content': [img, fs_q]})
        msgs.append({'role': 'assistant', 'content': [fs_a]})

    # Test example: final user turn the model must answer.
    test_img = await _load_image(test_image)
    msgs.append({'role': 'user', 'content': [test_img, test_question]})

    # Keep the 500 wrapper narrowly around inference so it cannot mask the
    # client-error HTTPExceptions raised above.
    try:
        answer = model.chat(
            image=None,
            msgs=msgs,
            tokenizer=tokenizer
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")

    return {"answer": answer}