File size: 1,154 Bytes
5dae26f
222b8b3
5dae26f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os
import PIL.Image
import torch
from huggingface_hub import login
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import spaces

# Authenticate with the Hugging Face Hub so gated models (PaliGemma) can be
# downloaded. Only attempt login when a token is actually configured:
# login(token=None) would raise or fall back to an interactive prompt,
# crashing the Space at import time.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token, add_to_git_credential=True)

class PaliGemmaModel:
    """Thin wrapper around google/paligemma-3b-mix-448 for image+text inference.

    Loads model weights and processor once at construction, placing the model
    on GPU when available, and exposes a single greedy-decoding entry point.
    """

    def __init__(self) -> None:
        self.model_id = "google/paligemma-3b-mix-448"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # .eval() disables dropout etc.; inference-only usage.
        self.model = PaliGemmaForConditionalGeneration.from_pretrained(self.model_id).eval().to(self.device)
        self.processor = PaliGemmaProcessor.from_pretrained(self.model_id)

    @spaces.GPU
    def infer(self, image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
        """Run greedy generation on (image, text) and return only the new text.

        Args:
            image: Input image for the vision tower.
            text: Prompt / task string (e.g. "caption en").
            max_new_tokens: Generation budget for the answer.

        Returns:
            The decoded continuation, with the prompt tokens removed.
        """
        inputs = self.processor(text=text, images=image, return_tensors="pt").to(self.device)
        # Number of prompt tokens; used below to drop the echoed prompt.
        input_len = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False
            )
        # Slice off the prompt at the TOKEN level before decoding. The previous
        # character-level strip (result[0][len(text):]) assumed the decoded
        # string begins with the exact prompt text, which breaks whenever the
        # tokenizer's decode normalizes whitespace or special characters.
        new_token_ids = generated_ids[:, input_len:]
        result = self.processor.batch_decode(new_token_ids, skip_special_tokens=True)
        return result[0].lstrip("\n")