rroset commited on
Commit
9b164d1
1 Parent(s): b6b5f88

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +62 -0
handler.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig
4
+ from PIL import Image
5
+ import requests
6
+ from io import BytesIO
7
+ import re
8
+
9
+ class EndpointHandler():
10
+ def __init__(self, path=""):
11
+ # Configuració de la quantització
12
+ quantization_config = BitsAndBytesConfig(
13
+ load_in_4bit=True,
14
+ bnb_4bit_quant_type="nf4",
15
+ bnb_4bit_compute_dtype=torch.float16,
16
+ )
17
+
18
+ # Carrega el processador i model de forma global
19
+ self.processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
20
+ self.model = LlavaNextForConditionalGeneration.from_pretrained(
21
+ "llava-hf/llava-v1.6-mistral-7b-hf",
22
+ quantization_config=quantization_config,
23
+ device_map="auto"
24
+ )
25
+
26
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
27
+ image_url = data.get("url")
28
+ prompt = data.get("prompt")
29
+
30
+ try:
31
+ response = requests.get(image_url, stream=True)
32
+ image = Image.open(response.raw)
33
+
34
+ if image.format == 'PNG':
35
+ image = image.convert('RGB')
36
+ buffer = BytesIO()
37
+ image.save(buffer, format="JPEG")
38
+ buffer.seek(0)
39
+ image = Image.open(buffer)
40
+
41
+ except Exception as e:
42
+ return {"error": str(e)}
43
+
44
+ inputs = self.processor(prompt, image, return_tensors="pt").to("cuda")
45
+ output = self.model.generate(**inputs, max_new_tokens=100)
46
+ result = self.processor.decode(output[0], skip_special_tokens=True)
47
+
48
+ scores = self.extract_scores(result)
49
+ sorted_scores = sorted(scores.items(), key=lambda item: item[1], reverse=True)
50
+ return sorted_scores
51
+
52
+ def extract_scores(self, response):
53
+ scores = {}
54
+ result_part = response.split("[/INST]")[-1].strip()
55
+ pattern = re.compile(r'(\d+)\.\s*(.*?):\s*(\d+)')
56
+ matches = pattern.findall(result_part)
57
+ for match in matches:
58
+ category_number = int(match[0])
59
+ category_name = match[1].strip()
60
+ score = int(match[2])
61
+ scores[category_name] = score
62
+ return scores