rroset committed
Commit 266aad6
1 Parent(s): 3aad73d

Update handler.py

Files changed (1)
  1. handler.py +24 -24
handler.py CHANGED
@@ -1,37 +1,37 @@
-from typing import Dict, List, Any
 from PIL import Image
 from io import BytesIO
 import base64
-import torch
-import open_clip

 class EndpointHandler():
     def __init__(self, path=""):
-        self.model, self.preprocess, _ = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K')
-        self.tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K')

-    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
-        image_base64 = data.get("inputs", None)
-        parameters = data.get("parameters", None)
-        if image_base64 is None or parameters is None:
-            raise ValueError("Input data or parameters not provided")

-        candidate_labels = parameters.get("candidate_labels", None)
-        if candidate_labels is None:
-            raise ValueError("Candidate labels not provided")

-        image = Image.open(BytesIO(base64.b64decode(image_base64)))
-        image = self.preprocess(image).unsqueeze(0)
-        text = self.tokenizer(candidate_labels)

-        with torch.no_grad():
-            image_features = self.model.encode_image(image)
-            text_features = self.model.encode_text(text)
-            image_features /= image_features.norm(dim=-1, keepdim=True)
-            text_features /= text_features.norm(dim=-1, keepdim=True)

-            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

-        results = [{"label": label, "score": score.item()} for label, score in zip(candidate_labels, text_probs[0])]
         return results
-
 
+from transformers import CLIPProcessor, CLIPModel, pipeline
 from PIL import Image
+import requests
 from io import BytesIO
 import base64
+from typing import Dict, List, Any

 class EndpointHandler():
     def __init__(self, path=""):
+        # Load the CLIP-specific model and processor
+        self.model = CLIPModel.from_pretrained(path)
+        self.processor = CLIPProcessor.from_pretrained(path)

+        # Create the zero-shot image classification pipeline, passing the
+        # tokenizer and image processor separately so the pipeline gets both
+        self.classifier = pipeline(
+            "zero-shot-image-classification",
+            model=self.model,
+            tokenizer=self.processor.tokenizer,
+            image_processor=self.processor.image_processor,
+        )

+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        image_input = data.get("inputs", None)
+        candidate_labels = data.get("candidate_labels", None)

+        if image_input is None or candidate_labels is None:
+            raise ValueError("Image input or candidate labels not provided")

+        # Determine whether the input is a URL or a base64 string
+        if image_input.startswith("http"):
+            response = requests.get(image_input)
+            image = Image.open(BytesIO(response.content))
+        else:
+            # Assume the input is base64 and decode it
+            image_data = base64.b64decode(image_input)
+            image = Image.open(BytesIO(image_data))

+        # Run the zero-shot classification
+        results = self.classifier(images=image, candidate_labels=candidate_labels)

+        # Return the processed results
         return results
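
For a quick local sanity check of the new handler, the snippet below exercises the base64 input path. It is a minimal sketch, not part of the commit: the model directory ./model, the test image cat.jpg, and the example labels are placeholder assumptions; the payload keys "inputs" and "candidate_labels" are the ones read in __call__ above.

import base64
from handler import EndpointHandler

# Hypothetical local path to the model repo files (config, weights, processor).
handler = EndpointHandler(path="./model")

# Base64 path: encode a local test image the way the endpoint would receive it.
with open("cat.jpg", "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": image_base64,  # may also be an http(s) URL
    "candidate_labels": ["a photo of a cat", "a photo of a dog"],
}
results = handler(payload)

# The pipeline returns a list of {"score": float, "label": str} dicts,
# sorted by descending score.
print(results)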