rxavier committed
Commit 7fdac21
1 Parent(s): d0490bb

Update off_topic.py

Files changed (1)
  1. off_topic.py +57 -10
off_topic.py CHANGED
@@ -10,28 +10,71 @@ import numpy as np
 import torch
 import PIL
 import imagehash
-from transformers import CLIPModel, CLIPProcessor
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, CLIPModel, CLIPProcessor
 from PIL import Image
 
 
+class Translator:
+    def __init__(self, model_id: str, device: Optional[str] = None):
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.model_id = model_id
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_id)
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(self.device)
+        self.bos_token_map = self.tokenizer.get_lang_id if hasattr(self.tokenizer, "get_lang_id") else self.tokenizer.lang_code_to_id
+
+    @property
+    def _language_code_mapper(self):
+        if "nllb" in self.model_id.lower():
+            return {"en": "eng_Latn",
+                    "es": "spa_Latn",
+                    "pt": "por_Latn"}
+        elif "m2m" in self.model_id.lower():
+            return {"en": "en",
+                    "es": "es",
+                    "pt": "pt"}
+
+    def translate(self, texts: List[str], src_lang: str, dest_lang: str = "en", max_length: int = 100):
+        self.tokenizer.src_lang = self._language_code_mapper[src_lang]
+        inputs = self.tokenizer(texts, return_tensors="pt").to(self.device)
+        translated_tokens = self.model.generate(
+            **inputs, forced_bos_token_id=self.bos_token_map["eng_Latn"], max_length=max_length
+        )
+        return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
+
+
 class OffTopicDetector:
-    def __init__(self, model_id: str, device: Optional[str] = None, image_size: str = "E"):
+    def __init__(self, model_id: str, device: Optional[str] = None, image_size: str = "E", translator: Optional[Translator] = None):
         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         self.processor = CLIPProcessor.from_pretrained(model_id)
         self.model = CLIPModel.from_pretrained(model_id).to(self.device)
         self.image_size = image_size
+        self.translator = translator
 
     def predict_probas(self, images: List[PIL.Image.Image], domain: str,
+                       title: Optional[str] = None,
                        valid_templates: Optional[List[str]] = None,
                        invalid_classes: Optional[List[str]] = None,
                        autocast: bool = True):
+        site, domain = domain.split("-")
+        domain = re.sub("_", " ", domain).lower()
         if valid_templates:
             valid_classes = [template.format(domain) for template in valid_templates]
         else:
-            valid_classes = [f"a photo of {domain}", f"brochure with {domain} image", f"instructions for {domain}", f"{domain} diagram",
-                             f"{domain} packaging", f"box of {domain}"]
+            valid_classes = [f"a photo of {domain}", f"brochure with {domain} image", f"instructions for {domain}", f"{domain} diagram"]
+        if title:
+            if site == "CBT":
+                translated_title = title
+            else:
+                if site == "MLB":
+                    src_lang = "pt"
+                else:
+                    src_lang = "es"
+                translated_title = self.translator.translate(title, src_lang=src_lang, dest_lang="en", max_length=100)[0]
+            valid_classes.append(translated_title)
         if not invalid_classes:
             invalid_classes = ["promotional ad with store information", "promotional text", "google maps screenshot", "business card", "qr code"]
+
         n_valid = len(valid_classes)
         classes = valid_classes + invalid_classes
         print(f"Valid classes: {valid_classes}", f"Invalid classes: {invalid_classes}", sep="\n")
@@ -59,18 +102,21 @@ class OffTopicDetector:
         return probas, valid_probas, invalid_probas
 
     def predict_probas_url(self, img_urls: List[str], domain: str,
+                           title: Optional[str] = None,
                            valid_templates: Optional[List[str]] = None,
                            invalid_classes: Optional[List[str]] = None,
                            autocast: bool = True):
         images = self.get_images(img_urls)
         dedup_images = self._filter_dups(images)
-        return self.predict_probas(images, domain, valid_templates, invalid_classes, autocast)
+        return self.predict_probas(images, domain, title, valid_templates, invalid_classes, autocast)
 
     def predict_probas_item(self, url_or_id: str,
+                            use_title: bool = False,
                             valid_templates: Optional[List[str]] = None,
                             invalid_classes: Optional[List[str]] = None):
-        images, domain = self.get_item_data(url_or_id)
-        probas, valid_probas, invalid_probas = self.predict_probas(images, domain, valid_templates,
+        images, domain, title = self.get_item_data(url_or_id)
+        title = title if use_title else None
+        probas, valid_probas, invalid_probas = self.predict_probas(images, domain, title, valid_templates,
                                                                     invalid_classes)
         return images, domain, probas, valid_probas, invalid_probas
 
@@ -84,7 +130,8 @@ class OffTopicDetector:
         item_id = re.sub("-", "", url_or_id)
         start = time.time()
         response = httpx.get(f"https://api.mercadolibre.com/items/{item_id}").json()
-        domain = re.sub("_", " ", response["domain_id"].split("-")[-1]).lower()
+        domain = response["domain_id"]
+        title = response["title"]
         img_urls = [x["url"] for x in response["pictures"]]
         img_urls = [x.replace("-O.jpg", f"-{self.image_size}.jpg") for x in img_urls]
         end = time.time()
@@ -92,7 +139,7 @@ class OffTopicDetector:
         print(f"Items API time: {round(duration * 1000, 0)} ms")
         images = self.get_images(img_urls)
         dedup_images = self._filter_dups(images)
-        return dedup_images, domain
+        return dedup_images, domain, title
 
     def _filter_dups(self, images: List):
         if len(images) > 1:
@@ -166,4 +213,4 @@ class OffTopicDetector:
         if title:
             fig.suptitle(title)
         fig.tight_layout()
-        return
+        return
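
For context, a minimal usage sketch of the new Translator class, assuming an NLLB checkpoint; the model id and sample title below are illustrative and not pinned by the commit. With an NLLB tokenizer, lang_code_to_id supplies the forced BOS token, which matches the hard-coded "eng_Latn" target in translate().

    # Hypothetical usage sketch; the checkpoint name and sample title are
    # illustrative, not part of the commit.
    from off_topic import Translator

    translator = Translator("facebook/nllb-200-distilled-600M")
    # Portuguese (MLB) title -> English; translate() takes a list of texts
    # and returns a list of decoded strings.
    print(translator.translate(["Furadeira de impacto 650W com maleta"], src_lang="pt"))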
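
And a sketch of how the detector might be wired to it through the new translator argument and use_title flag; the CLIP checkpoint and item id are placeholders.

    # Hypothetical wiring; checkpoint names and the item id are placeholders.
    from off_topic import OffTopicDetector, Translator

    translator = Translator("facebook/nllb-200-distilled-600M")
    detector = OffTopicDetector("openai/clip-vit-base-patch32", translator=translator)

    # use_title=True appends the item title (translated to English, except for
    # CBT items) to the valid classes before zero-shot scoring.
    images, domain, probas, valid_probas, invalid_probas = detector.predict_probas_item(
        "MLB1234567890", use_title=True
    )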