File size: 6,240 Bytes
92be70e 5da0eba 92be70e c09ae13 92be70e 5da0eba 92be70e 7bbe89d a985ed6 7bbe89d 92be70e 70ebd67 92be70e 70ebd67 92be70e 70ebd67 92be70e 5da0eba 70ebd67 92be70e 7dd13e4 92be70e 7dd13e4 df36680 7bbe89d 92be70e b3cae23 92be70e 8ac7f12 92be70e 1c9f7e0 92be70e 20062d7 7dd13e4 37fe09a 1fd8303 ff69146 1fd8303 ab91900 1fd8303 92be70e 72681cf 92be70e 7dd13e4 7bbe89d 7dd13e4 92be70e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import datetime
import gradio as gr
from huggingface_hub import hf_hub_download
import fasttext, torch, clip
from sentence_transformers import SentenceTransformer, util
# Model used when the detected language is English (see sequence_to_classify):
# OpenAI CLIP "ViT-B/32". The second value returned by clip.load is not needed
# here and is discarded.
model_en, _ = clip.load("ViT-B/32")
# Model used for every other detected language (see sequence_to_classify).
model_multi = SentenceTransformer("sentence-transformers/clip-ViT-B-32-multilingual-v1")
# fastText model used by detect_lang; the file "lid.176.bin" is fetched from
# the Hugging Face Hub repo "julien-c/fasttext-language-id".
fasttext_model = fasttext.load_model(hf_hub_download("julien-c/fasttext-language-id", "lid.176.bin"))
def prep_examples():
    """Build the example rows offered to the Gradio interface.

    Each row has the form ``[text, labels, hypothesis_template]``: ``labels``
    is a ``;;``-separated string of candidate labels and the hypothesis
    template is always left empty.

    Returns:
        list[list[str]]: Eleven example rows in several languages.
    """
    # Long texts use implicit string concatenation with explicit separators.
    # Backslash continuations inside a literal join lines with no space at
    # all, which previously fused words in the Norwegian example
    # ("hevetdet", "landet,som").
    samples = [
        (
            "Coronavirus disease (COVID-19) is an infectious disease caused by "
            "the SARS-CoV-2 virus. Most "
            "people who fall sick with COVID-19 will experience mild to "
            "moderate symptoms and recover without special treatment. "
            "However, some will become seriously ill and require medical "
            "attention.",
            "business;;health related;;politics;;climate change",
        ),
        (
            "Elephants are",
            "big;;small;;strong;;fast;;carnivorous",
        ),
        (
            "Elephants",
            "are big;;can be very small;;generally not strong enough;;"
            "are faster than you think",
        ),
        (
            "Dogs are man's best friend",
            "positive;;negative;;neutral",
        ),
        (
            "Şampiyonlar Ligi’nde 5. hafta oynanan karşılaşmaların ardından "
            "sona erdi. Real Madrid, "
            "Inter ve Sporting oynadıkları mücadeleler sonrasında Son 16 "
            "turuna yükselmeyi başardı. "
            "Gecenin dev mücadelesinde ise Manchester City, PSG’yi yenerek "
            "liderliği garantiledi.",
            "dünya;;ekonomi;;kültür;;siyaset;;spor;;teknoloji",
        ),
        (
            "Letzte Woche gab es einen Selbstmord in einer nahe gelegenen "
            "kolonie",
            "verbrechen;;tragödie;;stehlen",
        ),
        (
            "El autor se perfila, a los 50 años de su muerte, como uno de los "
            "grandes de su siglo",
            "cultura;;sociedad;;economia;;salud;;deportes",
        ),
        (
            "Россия в среду заявила, что военные учения в аннексированном "
            "Москвой Крыму закончились "
            "и что солдаты возвращаются в свои гарнизоны, на следующий день "
            "после того, как она объявила о первом выводе "
            "войск от границ Украины.",
            "новости;;комедия",
        ),
        (
            "I quattro registi - Federico Fellini, Pier Paolo Pasolini, "
            "Bernardo Bertolucci e Vittorio De Sica - "
            "hanno utilizzato stili di ripresa diversi, ma hanno fortemente "
            "influenzato le giovani generazioni di registi.",
            "cinema;;politica;;cibo",
        ),
        (
            "Ja, vi elsker dette landet, "
            "som det stiger frem, "
            "furet, værbitt over vannet, "
            "med de tusen hjem. "
            "Og som fedres kamp har hevet "
            "det av nød til seir",
            "helse;;sport;;religion;;mat;;patriotisme og nasjonalisme",
        ),
        (
            "Amar sonar bangla ami tomay bhalobasi",
            "bhalo;;kharap",
        ),
    ]
    # The third column is the (empty) hypothesis template input.
    return [[text, labels, ""] for text, labels in samples]
def detect_lang(text):
    """Detect the language of *text* with the module-level fastText model.

    Detection is best-effort: any failure is logged and the default language
    is returned instead of propagating the error.

    Args:
        text: The input string. Newlines are replaced by spaces before the
            text is handed to ``fasttext_model.predict``.

    Returns:
        str: The top predicted label with its ``__label__`` prefix removed,
        or ``'en'`` when the text is blank or prediction fails.
    """
    default_lang = 'en'
    if not text or not text.strip():
        # Nothing to analyse; skip the model call and the failure log.
        return default_lang

    # fastText predicts one line at a time, so newlines are flattened first.
    text = text.replace('\n', ' ')
    try:
        top_label = fasttext_model.predict(text, k=1)[0][0]
        return top_label.split("__label__")[1]
    except Exception:
        # Deliberately broad (this is a fallback path), but unlike a bare
        # ``except`` it lets KeyboardInterrupt / SystemExit propagate.
        print("Language detection failed!",
              "Date:{}, Sequence: {}".format(
                  str(datetime.datetime.now()),
                  text))
        return default_lang
def _build_hypotheses(labels, hypothesis_template):
    """Turn the ``;;``-separated label string into one hypothesis per label."""
    template = (hypothesis_template or "").strip()
    hypotheses = []
    for label in labels.split(";;"):
        if "{}" in template:
            # The user supplied an explicit placeholder: fill it in place.
            hypotheses.append(template.replace("{}", label))
        else:
            # Plain concatenation instead of str.format, so stray braces in
            # the user's template cannot raise; strip() avoids a leading
            # space when the template is empty.
            hypotheses.append(f"{template} {label}".strip())
    return hypotheses


def _encode_with_clip(model, texts):
    """Embed *texts* (a string or list of strings) with CLIP's text encoder."""
    # Tokens are created on the CPU; move them to wherever the model lives.
    device = next(model.parameters()).device
    # truncate=True keeps inputs longer than CLIP's context window from
    # raising inside clip.tokenize.
    tokens = clip.tokenize(texts, truncate=True).to(device)
    return model.encode_text(tokens)


def sequence_to_classify(text, labels, hypothesis_template):
    """Score *text* against candidate labels and return label probabilities.

    The language of ``text`` is detected with ``detect_lang``. English text is
    embedded with ``model_en`` (CLIP); any other language is embedded with
    ``model_multi``. Cosine similarities between the text and each hypothesis
    are scaled by 100 and passed through a softmax.

    Args:
        text: The text to classify.
        labels: Candidate labels separated by ``;;``.
        hypothesis_template: Optional text combined with every label. If it
            contains ``{}``, the label replaces that placeholder; otherwise
            the label is appended after a space.

    Returns:
        dict[str, float]: Mapping from each hypothesis (template plus label)
        to its probability, ordered from most to least likely.
    """
    hypotheses = _build_hypotheses(labels, hypothesis_template)
    lang = detect_lang(text)

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        if lang == 'en':
            text_features = _encode_with_clip(model_en, text)
            labels_features = _encode_with_clip(model_en, hypotheses)
        else:
            text_features = torch.tensor(model_multi.encode(text))
            labels_features = torch.tensor(model_multi.encode(hypotheses))

        # One input text, so only the first row of the similarity matrix is
        # used. Cosine similarities lie in [-1, 1]; the factor of 100 spreads
        # them out so the softmax is not nearly uniform.
        sim_scores = util.cos_sim(text_features, labels_features)
        probs = (sim_scores[0].float() * 100).softmax(dim=-1).tolist()

    ranked = sorted(zip(probs, hypotheses), key=lambda pair: pair[0], reverse=True)
    predicted_labels = [label for _, label in ranked]
    output = {label: float(score) for score, label in ranked}

    print("Date:{}, Sequence:{}, Labels: {}".format(
        str(datetime.datetime.now()),
        text,
        predicted_labels))
    return output
# Gradio front end: three text boxes (text, candidate labels, optional
# hypothesis template) feeding sequence_to_classify, with the result shown as
# a label widget limited to the five top classes.
iface = gr.Interface(
    fn=sequence_to_classify,
    inputs=[
        gr.inputs.Textbox(
            lines=10,
            label="Please enter the text you would like to classify...",
            placeholder="Text here...",
        ),
        gr.inputs.Textbox(
            lines=2,
            label="Please enter the candidate labels (separated by 2 consecutive semicolons)...",
            placeholder="Labels here separated by ;;",
        ),
        gr.inputs.Textbox(
            lines=2,
            label="Please enter the text for hypothesis template (optional)...",
            placeholder="Text here...",
        ),
    ],
    outputs=gr.outputs.Label(num_top_classes=5),
    # NOTE: interpretation="default" is currently not enabled.
    examples=prep_examples(),
    title="Light-weight Zero-shot NLP Classifier",
    description="Multi-label Multilingual classifier which uses Sentence Transformer / OpenAI CLIP.",
)

# Start serving the interface.
iface.launch()
|