File size: 2,190 Bytes
261f61b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dc717b
be59b56
9ad10c7
 
be59b56
 
 
 
 
 
 
 
9ad10c7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60

from io import BytesIO
from typing import Any, List, Dict

from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import requests
import copy
import base64



class EndpointHandler():
    """Inference handler that runs Florence-2 vision-language tasks on one image.

    Each request supplies an image (base64 payload or URL), a Florence-2 task
    token such as ``'<MORE_DETAILED_CAPTION>'`` and an optional text suffix.
    The handler runs generation on the GPU and returns the processor's parsed
    answer.
    """

    # Seconds to wait for a remote image before giving up, so an unresponsive
    # host cannot block the worker indefinitely.
    IMAGE_DOWNLOAD_TIMEOUT = 30

    def __init__(self, path="", model_id='microsoft/Florence-2-large'):
        """Load the Florence-2 model (on CUDA, in eval mode) and its processor.

        Args:
            path: Presumably the repository path supplied by the endpoint
                runtime. It is not used; weights are loaded from ``model_id``.
            model_id: Hub id of the checkpoint to load. Defaults to the id
                this handler has always used.
        """
        # trust_remote_code lets transformers execute the checkpoint's custom
        # model/processor code published alongside the weights.
        model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval().cuda()
        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        self.model = model
        self.processor = processor

    def run_example(self, image, task_prompt, text_input=None) -> Dict[str, Any]:
        """Run one Florence-2 task on ``image`` and return the parsed answer.

        Args:
            image: PIL image to analyse (its ``width``/``height`` are passed to
                post-processing).
            task_prompt: Florence-2 task token, e.g. ``'<OD>'``.
            text_input: Optional text appended directly after the task token.

        Returns:
            The result of ``processor.post_process_generation`` for
            ``task_prompt``.
        """
        prompt = task_prompt if text_input is None else task_prompt + text_input
        inputs = self.processor(text=prompt, images=image, return_tensors="pt")
        # Deterministic decoding: no sampling, beam search with 3 beams.
        generated_ids = self.model.generate(
            input_ids=inputs["input_ids"].cuda(),
            pixel_values=inputs["pixel_values"].cuda(),
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        # Special tokens are kept, presumably because post_process_generation
        # parses task-specific tokens out of the raw text.
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        return self.processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )

    def _load_image(self, image_b64, image_url):
        """Return an RGB PIL image decoded from base64 data, or else downloaded from a URL.

        Raises:
            ValueError: If neither source is given, or the download does not
                return HTTP 200.
        """
        if image_b64:
            raw = base64.b64decode(image_b64)
        elif image_url:
            response = requests.get(image_url, timeout=self.IMAGE_DOWNLOAD_TIMEOUT)
            if response.status_code != 200:
                raise ValueError(f"Unable to download image from URL: {image_url}")
            raw = response.content
        else:
            raise ValueError("Request must include either 'image' (base64) or 'image_url'.")
        # Image.open is lazy; convert() decodes now (surfacing corrupt data
        # here) and yields a 3-channel image for palette/alpha/grayscale input.
        return Image.open(BytesIO(raw)).convert("RGB")

    def __call__(self, data: Any) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request payload. ``data['inputs']`` is a mapping that may
                hold ``image`` (base64-encoded bytes), ``image_url``, ``type``
                (task token, default ``'<MORE_DETAILED_CAPTION>'``) and
                ``text`` (optional text appended to the task token).

        Returns:
            The parsed answer from :meth:`run_example`.

        Raises:
            ValueError: If no image source is provided or the download fails.
        """
        inputs = data['inputs']
        # Read with .get() instead of .pop() so the caller's payload is left intact.
        image_b64 = inputs.get("image", None)
        image_url = inputs.get("image_url", None)
        task = inputs.get("type", '<MORE_DETAILED_CAPTION>')
        text = inputs.get("text", None)
        image = self._load_image(image_b64, image_url)
        return self.run_example(image, task, text_input=text)