import torch
import transformers

DOWNLOAD_URL = "https://github.com/unitaryai/detoxify/releases/download/"
MODEL_URLS = {
    "original": DOWNLOAD_URL + "v0.1-alpha/toxic_original-c1212f89.ckpt",
    "unbiased": DOWNLOAD_URL + "v0.3-alpha/toxic_debiased-c7548aa0.ckpt",
    "multilingual": DOWNLOAD_URL + "v0.4-alpha/multilingual_debiased-0b549669.ckpt",
    "original-small": DOWNLOAD_URL + "v0.1.2/original-albert-0e1d6498.ckpt",
    "unbiased-small": DOWNLOAD_URL + "v0.1.2/unbiased-albert-c8519128.ckpt",
}

PRETRAINED_MODEL = None


def get_model_and_tokenizer(
    model_type, model_name, tokenizer_name, num_classes, state_dict, huggingface_config_path=None
):
    """Instantiate a transformers model and tokenizer from a checkpoint's architecture config."""
    model_class = getattr(transformers, model_name)
    model = model_class.from_pretrained(
        pretrained_model_name_or_path=None,
        config=huggingface_config_path or model_type,
        num_labels=num_classes,
        state_dict=state_dict,
        local_files_only=huggingface_config_path is not None,
    )
    tokenizer = getattr(transformers, tokenizer_name).from_pretrained(
        huggingface_config_path or model_type,
        local_files_only=huggingface_config_path is not None,
        # TODO: may be needed to let it work with Kaggle competition
        # model_max_length=512,
    )
    return model, tokenizer


def load_checkpoint(model_type="original", checkpoint=None, device="cpu", huggingface_config_path=None):
    """Load a released checkpoint by model type, or a local checkpoint by path."""
    if checkpoint is None:
        checkpoint_path = MODEL_URLS[model_type]
        loaded = torch.hub.load_state_dict_from_url(checkpoint_path, map_location=device)
    else:
        loaded = torch.load(checkpoint, map_location=device)
        if "config" not in loaded or "state_dict" not in loaded:
            raise ValueError(
                "Checkpoint needs to contain the config it was trained "
                "with as well as the state dict"
            )
    class_names = loaded["config"]["dataset"]["args"]["classes"]
    # standardise class names between models
    change_names = {
        "toxic": "toxicity",
        "identity_hate": "identity_attack",
        "severe_toxic": "severe_toxicity",
    }
    class_names = [change_names.get(cl, cl) for cl in class_names]
    model, tokenizer = get_model_and_tokenizer(
        **loaded["config"]["arch"]["args"],
        state_dict=loaded["state_dict"],
        huggingface_config_path=huggingface_config_path,
    )
    return model, tokenizer, class_names
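

# Offline loading sketch (illustrative, not part of the upstream module;
# assumes the HF config and tokenizer files for the matching architecture
# were downloaded beforehand, and the directory path below is a placeholder):
#
#   model, tokenizer, class_names = load_checkpoint(
#       model_type="original",
#       huggingface_config_path="path/to/local/hf_config",
#   )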


def load_model(model_type, checkpoint=None):
    if checkpoint is None:
        model, _, _ = load_checkpoint(model_type=model_type)
    else:
        model, _, _ = load_checkpoint(checkpoint=checkpoint)
    return model


class Detoxify:
    """Detoxify

    Easily predict if a comment or list of comments is toxic.

    Can initialize 5 different model types from model type or checkpoint path:
        - original:
            model trained on data from the Jigsaw Toxic Comment
            Classification Challenge
        - unbiased:
            model trained on data from the Jigsaw Unintended Bias in
            Toxicity Classification Challenge
        - multilingual:
            model trained on data from the Jigsaw Multilingual
            Toxic Comment Classification Challenge
        - original-small:
            lightweight version of the original model
        - unbiased-small:
            lightweight version of the unbiased model

    Args:
        model_type(str): model type to be loaded, one of original, unbiased,
            multilingual, original-small or unbiased-small
        checkpoint(str): checkpoint path, defaults to None
        device(str or torch.device): accepts any torch.device input or
            torch.device object, defaults to cpu
        huggingface_config_path(str): path to HF config and tokenizer files
            needed for offline model loading

    Returns:
        results(dict): dictionary of output scores for each class
    """

    def __init__(self, model_type="original", checkpoint=PRETRAINED_MODEL, device="cpu", huggingface_config_path=None):
        super().__init__()
        self.model, self.tokenizer, self.class_names = load_checkpoint(
            model_type=model_type,
            checkpoint=checkpoint,
            device=device,
            huggingface_config_path=huggingface_config_path,
        )
        self.device = device
        self.model.to(self.device)

    @torch.no_grad()
    def predict(self, text):
        self.model.eval()
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(self.model.device)
        out = self.model(**inputs)[0]
        scores = torch.sigmoid(out).cpu().detach().numpy()
        results = {}
        for i, cla in enumerate(self.class_names):
            results[cla] = (
                scores[0][i] if isinstance(text, str) else [scores[ex_i][i].tolist() for ex_i in range(len(scores))]
            )
        return results
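

# Module-level constructors for each pretrained variant; each returns only
# the model (use the Detoxify class above for tokenization and prediction).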


def toxic_bert():
    return load_model("original")


def toxic_albert():
    return load_model("original-small")


def unbiased_toxic_roberta():
    return load_model("unbiased")


def unbiased_albert():
    return load_model("unbiased-small")


def multilingual_toxic_xlm_r():
    return load_model("multilingual")
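

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative addition, not part of the upstream
    # module): downloads the "original" checkpoint on first run, then prints a
    # dict of per-class sigmoid scores for a single string and for a batch.
    detector = Detoxify("original")
    print(detector.predict("this is an example sentence"))
    print(detector.predict(["example text one", "example text two"]))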