import torch
import transformers

DOWNLOAD_URL = "https://github.com/unitaryai/detoxify/releases/download/"
MODEL_URLS = {
    "original": DOWNLOAD_URL + "v0.1-alpha/toxic_original-c1212f89.ckpt",
    "unbiased": DOWNLOAD_URL + "v0.3-alpha/toxic_debiased-c7548aa0.ckpt",
    "multilingual": DOWNLOAD_URL + "v0.4-alpha/multilingual_debiased-0b549669.ckpt",
    "original-small": DOWNLOAD_URL + "v0.1.2/original-albert-0e1d6498.ckpt",
    "unbiased-small": DOWNLOAD_URL + "v0.1.2/unbiased-albert-c8519128.ckpt",
}

PRETRAINED_MODEL = None

def get_model_and_tokenizer(
    model_type, model_name, tokenizer_name, num_classes, state_dict, huggingface_config_path=None
):
    # Resolve the model class by name (e.g. "BertForSequenceClassification")
    # and instantiate it directly from the checkpoint's state dict rather than
    # downloading pretrained weights.
    model_class = getattr(transformers, model_name)
    model = model_class.from_pretrained(
        pretrained_model_name_or_path=None,
        config=huggingface_config_path or model_type,
        num_labels=num_classes,
        state_dict=state_dict,
        local_files_only=huggingface_config_path is not None,
    )
    tokenizer = getattr(transformers, tokenizer_name).from_pretrained(
        huggingface_config_path or model_type,
        local_files_only=huggingface_config_path is not None,
        # TODO: may be needed to let it work with Kaggle competition
        # model_max_length=512,
    )
    return model, tokenizer
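
# Sketch of how the helper is typically invoked for the "original" checkpoint.
# The argument values below are illustrative; the real ones are read from the
# config stored inside the checkpoint itself:
#
#   model, tokenizer = get_model_and_tokenizer(
#       model_type="bert-base-uncased",
#       model_name="BertForSequenceClassification",
#       tokenizer_name="BertTokenizer",
#       num_classes=6,
#       state_dict=loaded["state_dict"],
#   )
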
def load_checkpoint(model_type="original", checkpoint=None, device="cpu", huggingface_config_path=None):
    if checkpoint is None:
        checkpoint_path = MODEL_URLS[model_type]
        loaded = torch.hub.load_state_dict_from_url(checkpoint_path, map_location=device)
    else:
        loaded = torch.load(checkpoint, map_location=device)
        if "config" not in loaded or "state_dict" not in loaded:
            raise ValueError(
                "Checkpoint needs to contain the config it was trained "
                "with as well as the state dict"
            )
    class_names = loaded["config"]["dataset"]["args"]["classes"]
    # standardise class names between models
    change_names = {
        "toxic": "toxicity",
        "identity_hate": "identity_attack",
        "severe_toxic": "severe_toxicity",
    }
    class_names = [change_names.get(cl, cl) for cl in class_names]
    model, tokenizer = get_model_and_tokenizer(
        **loaded["config"]["arch"]["args"],
        state_dict=loaded["state_dict"],
        huggingface_config_path=huggingface_config_path,
    )
    return model, tokenizer, class_names
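
# Example (sketch): pull the default "original" checkpoint from the release
# URL above and unpack it. Requires network access on the first call, after
# which torch.hub serves the file from its local cache:
#
#   model, tokenizer, class_names = load_checkpoint(model_type="original", device="cpu")
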
def load_model(model_type, checkpoint=None):
    # Convenience wrapper that discards the tokenizer and class names.
    if checkpoint is None:
        model, _, _ = load_checkpoint(model_type=model_type)
    else:
        model, _, _ = load_checkpoint(checkpoint=checkpoint)
    return model
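
# Example (sketch): fetch just the classifier for the "unbiased" variant.
#
#   toxicity_model = load_model("unbiased")
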
class Detoxify:
    """Detoxify

    Easily predict if a comment or list of comments is toxic.

    Can initialize 5 different model types from model type or checkpoint path:
        - original:
            model trained on data from the Jigsaw Toxic Comment
            Classification Challenge
        - unbiased:
            model trained on data from the Jigsaw Unintended Bias in
            Toxicity Classification Challenge
        - multilingual:
            model trained on data from the Jigsaw Multilingual
            Toxic Comment Classification Challenge
        - original-small:
            lightweight version of the original model
        - unbiased-small:
            lightweight version of the unbiased model

    Args:
        model_type(str): model type to be loaded, one of original, unbiased,
            multilingual, original-small or unbiased-small
        checkpoint(str): checkpoint path, defaults to None
        device(str or torch.device): accepts any torch.device input or
            torch.device object, defaults to cpu
        huggingface_config_path: path to HF config and tokenizer files needed
            for offline model loading

    Returns:
        results(dict): dictionary of output scores for each class,
            as returned by predict()
    """
    def __init__(self, model_type="original", checkpoint=PRETRAINED_MODEL, device="cpu", huggingface_config_path=None):
        super().__init__()
        self.model, self.tokenizer, self.class_names = load_checkpoint(
            model_type=model_type,
            checkpoint=checkpoint,
            device=device,
            huggingface_config_path=huggingface_config_path,
        )
        self.device = device
        self.model.to(self.device)

    def predict(self, text):
        self.model.eval()
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(self.model.device)
        out = self.model(**inputs)[0]
        scores = torch.sigmoid(out).cpu().detach().numpy()
        results = {}
        for i, cla in enumerate(self.class_names):
            # Single string input -> scalar score per class;
            # list input -> per-class list of scores, parallel to the input.
            results[cla] = (
                scores[0][i] if isinstance(text, str) else [scores[ex_i][i].tolist() for ex_i in range(len(scores))]
            )
        return results
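
    # Example output shape (sketch; scores are illustrative, classes are those
    # of the "original" model):
    #   predict("hello")         -> {"toxicity": 0.0007, "severe_toxicity": ..., ...}
    #   predict(["hi", "hello"]) -> {"toxicity": [0.0007, 0.0009], ...}
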

def toxic_bert():
    return load_model("original")


def toxic_albert():
    return load_model("original-small")


def unbiased_toxic_roberta():
    return load_model("unbiased")


def unbiased_albert():
    return load_model("unbiased-small")


def multilingual_toxic_xlm_r():
    return load_model("multilingual")