File size: 2,304 Bytes
9b164d1
 
f6721ff
9b164d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6721ff
 
9b164d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44df4d6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from typing import Dict, List, Any
import torch
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig
from PIL import Image
import requests
from io import BytesIO
import re

class EndpointHandler:
    """Score an image against a text prompt with a 4-bit quantized LLaVA-NeXT model.

    Each request downloads an image from a URL and runs it through the model
    together with a caller-supplied prompt. It then parses numbered
    ``"<n>. <category>: <score>"`` lines out of the generated text.
    """

    # Hub checkpoint used for both the processor and the model.
    DEFAULT_MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"
    # Upper bound on the number of tokens generated per request.
    MAX_NEW_TOKENS = 100
    # Seconds to wait for the image host before giving up on the download.
    REQUEST_TIMEOUT = 30
    # Matches entries such as "3. Violence: 7"; groups are (number, name, score).
    _SCORE_PATTERN = re.compile(r'(\d+)\.\s*(.*?):\s*(\d+)')

    def __init__(self, path: str = "", model_id: str = DEFAULT_MODEL_ID):
        """Load the processor and quantized model once, for reuse across requests.

        Args:
            path: Accepted for interface compatibility but not used here.
                Presumably this is the Inference Endpoints repository path;
                confirm against the deployment.
            model_id: Hub id (or local path) of the LLaVA-NeXT checkpoint.
                Defaults to the checkpoint the handler was written for.
        """
        # Quantize weights to 4-bit NF4 while computing in float16.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # Load the processor and model once so every request reuses them.
        self.processor = LlavaNextProcessor.from_pretrained(model_id)
        self.model = LlavaNextForConditionalGeneration.from_pretrained(
            model_id,
            quantization_config=quantization_config,
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run the model on the image at ``data["url"]`` with ``data["prompt"]``.

        Args:
            data: Request payload. The ``"url"`` key holds the image URL and
                the ``"prompt"`` key holds the full text prompt for the model.

        Returns:
            On success, a list of ``(category, score)`` tuples sorted by
            descending score. If the payload is incomplete, or the image
            cannot be downloaded or decoded, returns ``{"error": message}``.
        """
        image_url = data.get("url")
        prompt = data.get("prompt")
        if not image_url or not prompt:
            return {"error": "Both 'url' and 'prompt' must be provided."}

        try:
            image = self._load_image(image_url)
        except Exception as e:
            # Request boundary: report fetch/decode failures to the client
            # rather than crashing the handler.
            return {"error": str(e)}

        # Keyword arguments avoid relying on the processor's positional
        # argument order. Move inputs to wherever device_map placed the model
        # instead of assuming "cuda".
        inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(
            self.model.device
        )
        output = self.model.generate(**inputs, max_new_tokens=self.MAX_NEW_TOKENS)
        result = self.processor.decode(output[0], skip_special_tokens=True)

        scores = self.extract_scores(result)
        return sorted(scores.items(), key=lambda item: item[1], reverse=True)

    def _load_image(self, image_url: str) -> "Image.Image":
        """Download ``image_url`` and return it decoded as an RGB PIL image.

        Raises:
            requests.RequestException: on network errors, timeouts or non-2xx
                responses.
            PIL.UnidentifiedImageError / OSError: if the payload is not a
                decodable image.
        """
        response = requests.get(image_url, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # convert() forces a full decode, so corrupt data fails here inside
        # the caller's try. It also normalizes alpha/palette/CMYK modes to
        # RGB without the lossy JPEG round trip.
        return image.convert("RGB")

    def extract_scores(self, response: str) -> Dict[str, int]:
        """Parse ``"<n>. <category>: <score>"`` entries from decoded model output.

        Only the text after the last ``[/INST]`` tag is searched, so numbered
        lines inside the echoed prompt are ignored. That text is presumably
        the model's answer. If the tag is absent, the whole string is
        searched.

        Args:
            response: Decoded generation text.

        Returns:
            A mapping of category name to integer score. If a category repeats,
            the last occurrence wins.
        """
        answer = response.split("[/INST]")[-1].strip()
        return {
            name.strip(): int(score)
            for _number, name, score in self._SCORE_PATTERN.findall(answer)
        }