import torch
import transformers

# Base URL of the GitHub release assets that hold the trained checkpoints.
DOWNLOAD_URL = "https://github.com/unitaryai/detoxify/releases/download/"
MODEL_URLS = {
"original": DOWNLOAD_URL + "v0.1-alpha/toxic_original-c1212f89.ckpt",
"unbiased": DOWNLOAD_URL + "v0.3-alpha/toxic_debiased-c7548aa0.ckpt",
"multilingual": DOWNLOAD_URL + "v0.4-alpha/multilingual_debiased-0b549669.ckpt",
"original-small": DOWNLOAD_URL + "v0.1.2/original-albert-0e1d6498.ckpt",
"unbiased-small": DOWNLOAD_URL + "v0.1.2/unbiased-albert-c8519128.ckpt",
}

# Default value for the checkpoint argument; None means "download the
# released checkpoint for the requested model_type".
PRETRAINED_MODEL = None


def get_model_and_tokenizer(
model_type, model_name, tokenizer_name, num_classes, state_dict, huggingface_config_path=None
):
    # Resolve the model class by name, e.g. "BertForSequenceClassification".
    model_class = getattr(transformers, model_name)
model = model_class.from_pretrained(
pretrained_model_name_or_path=None,
config=huggingface_config_path or model_type,
num_labels=num_classes,
state_dict=state_dict,
local_files_only=huggingface_config_path is not None,
)
tokenizer = getattr(transformers, tokenizer_name).from_pretrained(
huggingface_config_path or model_type,
local_files_only=huggingface_config_path is not None,
        # TODO: may be needed to make this work in the Kaggle competition environment
# model_max_length=512,
)
return model, tokenizer
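
# Sketch of the argument shapes this function receives from checkpoint
# configs; the values below are illustrative assumptions, not read from a
# real checkpoint:
#
#   model, tokenizer = get_model_and_tokenizer(
#       model_type="bert-base-uncased",
#       model_name="BertForSequenceClassification",
#       tokenizer_name="BertTokenizer",
#       num_classes=6,
#       state_dict=checkpoint_state_dict,
#   )
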
def load_checkpoint(model_type="original", checkpoint=None, device="cpu", huggingface_config_path=None):
    if checkpoint is None:
        # No local checkpoint given: download (and cache) the released weights.
        checkpoint_path = MODEL_URLS[model_type]
loaded = torch.hub.load_state_dict_from_url(checkpoint_path, map_location=device)
else:
loaded = torch.load(checkpoint, map_location=device)
if "config" not in loaded or "state_dict" not in loaded:
raise ValueError(
"Checkpoint needs to contain the config it was trained \
with as well as the state dict"
)
class_names = loaded["config"]["dataset"]["args"]["classes"]
# standardise class names between models
change_names = {
"toxic": "toxicity",
"identity_hate": "identity_attack",
"severe_toxic": "severe_toxicity",
}
class_names = [change_names.get(cl, cl) for cl in class_names]
model, tokenizer = get_model_and_tokenizer(
**loaded["config"]["arch"]["args"],
state_dict=loaded["state_dict"],
huggingface_config_path=huggingface_config_path,
)
return model, tokenizer, class_names
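
# Sketch: load the released "unbiased" weights onto a GPU, assuming the
# GitHub release asset above is reachable (torch.hub caches the download):
#
#   model, tokenizer, class_names = load_checkpoint(model_type="unbiased", device="cuda")
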
def load_model(model_type, checkpoint=None):
if checkpoint is None:
model, _, _ = load_checkpoint(model_type=model_type)
else:
model, _, _ = load_checkpoint(checkpoint=checkpoint)
    return model


class Detoxify:
"""Detoxify
Easily predict if a comment or list of comments is toxic.
Can initialize 5 different model types from model type or checkpoint path:
- original:
model trained on data from the Jigsaw Toxic Comment
Classification Challenge
- unbiased:
model trained on data from the Jigsaw Unintended Bias in
Toxicity Classification Challenge
- multilingual:
model trained on data from the Jigsaw Multilingual
Toxic Comment Classification Challenge
- original-small:
lightweight version of the original model
- unbiased-small:
lightweight version of the unbiased model
Args:
model_type(str): model type to be loaded, can be either original,
unbiased or multilingual
checkpoint(str): checkpoint path, defaults to None
device(str or torch.device): accepts any torch.device input or
torch.device object, defaults to cpu
huggingface_config_path: path to HF config and tokenizer files needed for offline model loading
Returns:
results(dict): dictionary of output scores for each class
"""
def __init__(self, model_type="original", checkpoint=PRETRAINED_MODEL, device="cpu", huggingface_config_path=None):
super().__init__()
self.model, self.tokenizer, self.class_names = load_checkpoint(
model_type=model_type,
checkpoint=checkpoint,
device=device,
huggingface_config_path=huggingface_config_path,
)
self.device = device
        self.model.to(self.device)

    @torch.no_grad()
    def predict(self, text):
        self.model.eval()
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(self.model.device)
        out = self.model(**inputs)[0]
        # Multi-label setup: an independent sigmoid score per class.
        scores = torch.sigmoid(out).cpu().detach().numpy()
        results = {}
        for i, cla in enumerate(self.class_names):
            # A single string input yields one score per class; a list of
            # strings yields a list of scores (one per input) for each class.
            results[cla] = (
                scores[0][i] if isinstance(text, str) else [scores[ex_i][i].tolist() for ex_i in range(len(scores))]
            )
        return results
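
# Convenience constructors: each returns just the fine-tuned model for one
# released checkpoint, discarding the tokenizer and class names (in the
# upstream detoxify repo, functions like these serve as torch.hub entry
# points, though that usage is an assumption here).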

def toxic_bert():
    return load_model("original")


def toxic_albert():
    return load_model("original-small")


def unbiased_toxic_roberta():
    return load_model("unbiased")


def unbiased_albert():
    return load_model("unbiased-small")


def multilingual_toxic_xlm_r():
    return load_model("multilingual")
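
# Minimal usage sketch, assuming the default "original" checkpoint can be
# downloaded; predict() accepts a single string or a list of strings.
if __name__ == "__main__":
    detoxifier = Detoxify("original")
    print(detoxifier.predict("example comment"))
    print(detoxifier.predict(["first comment", "second comment"]))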