MatteoFasulo committed
Commit 976b6b9 · 1 Parent(s): 65f1a6e

Add application file

app.py ADDED
@@ -0,0 +1,154 @@
import re
import unicodedata

import nltk
from nltk import WordNetLemmatizer
from datasets import Dataset
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import XLMRobertaForSequenceClassification
from transformers import Trainer
import gradio as gr


def preprocess_text(text: str) -> str:
    """
    Preprocesses the input text by removing or replacing specific patterns.

    Args:
        text (str): The input text to be preprocessed.

    Returns:
        str: The preprocessed text, with URLs, mentions, numbers, and
        hashtags replaced by placeholder tokens, special characters
        removed, and extra spaces trimmed.
    """
    # Define patterns. The URL pattern is assembled from adjacent raw-string
    # literals so no literal newlines or indentation end up inside the
    # compiled regex (a triple-quoted string would embed them).
    URL_PATTERN_STR = (
        r"(?i)((?:https?:(?:/{1,3}|[a-z0-9%])|[a-z0-9.\-]+[.](?:com|net|org|edu|gov|mil|aero|asia|biz|cat|coop|info"
        r"|int|jobs|mobi|museum|name|post|pro|tel|travel|xxx|ac|ad|ae|af|ag|ai|al|am|an|ao|aq|ar|as|at|au|aw|ax|az|ba|"
        r"bb|bd|be|bf|bg|bh|bi|bj|bm|bn|bo|br|bs|bt|bv|bw|by|bz|ca|cc|cd|cf|cg|ch|ci|ck|cl|cm|cn|co|cr|cs|cu|cv|cx|cy|"
        r"cz|dd|de|dj|dk|dm|do|dz|ec|ee|eg|eh|er|es|et|eu|fi|fj|fk|fm|fo|fr|ga|gb|gd|ge|gf|gg|gh|gi|gl|gm|gn|gp|gq|gr|"
        r"gs|gt|gu|gw|gy|hk|hm|hn|hr|ht|hu|id|ie|il|im|in|io|iq|ir|is|it|je|jm|jo|jp|ke|kg|kh|ki|km|kn|kp|kr|kw|ky|kz|"
        r"la|lb|lc|li|lk|lr|ls|lt|lu|lv|ly|ma|mc|md|me|mg|mh|mk|ml|mm|mn|mo|mp|mq|mr|ms|mt|mu|mv|mw|mx|my|mz|na|nc|ne|"
        r"nf|ng|ni|nl|no|np|nr|nu|nz|om|pa|pe|pf|pg|ph|pk|pl|pm|pn|pr|ps|pt|pw|py|qa|re|ro|rs|ru|rw|sa|sb|sc|sd|se|sg|"
        r"sh|si|sj|Ja|sk|sl|sm|sn|so|sr|ss|st|su|sv|sx|sy|sz|tc|td|tf|tg|th|tj|tk|tl|tm|tn|to|tp|tr|tt|tv|tw|tz|ua|ug|"
        r"uk|us|uy|uz|va|vc|ve|vg|vi|vn|vu|wf|ws|ye|yt|yu|za|zm|zw)/)(?:[^\s()<>{}\[\]]+|\([^\s()]*?\([^\s()]+\)[^\s()]"
        r"*?\)|\([^\s]+?\))+(?:\([^\s()]*?\([^\s()]+\)[^\s()]*?\)|\([^\s]+?\)|[^\s`!()\[\]{};:'\".,<>?«»“”‘’])|(?:(?<!@)"
        r"[a-z0-9]+(?:[.\-][a-z0-9]+)*[.](?:com|net|org|edu|gov|mil|aero|asia|biz|cat|coop|info|int|jobs|mobi|museum|name"
        r"|post|pro|tel|travel|xxx|ac|ad|ae|af|ag|ai|al|am|an|ao|aq|ar|as|at|au|aw|ax|az|ba|bb|bd|be|bf|bg|bh|bi|bj|bm|bn"
        r"|bo|br|bs|bt|bv|bw|by|bz|ca|cc|cd|cf|cg|ch|ci|ck|cl|cm|cn|co|cr|cs|cu|cv|cx|cy|cz|dd|de|dj|dk|dm|do|dz|ec|ee|eg"
        r"|eh|er|es|et|eu|fi|fj|fk|fm|fo|fr|ga|gb|gd|ge|gf|gg|gh|gi|gl|gm|gn|gp|gq|gr|gs|gt|gu|gw|gy|hk|hm|hn|hr|ht|hu|id"
        r"|ie|il|im|in|io|iq|ir|is|it|je|jm|jo|jp|ke|kg|kh|ki|km|kn|kp|kr|kw|ky|kz|la|lb|lc|li|lk|lr|ls|lt|lu|lv|ly|ma|mc|"
        r"md|me|mg|mh|mk|ml|mm|mn|mo|mp|mq|mr|ms|mt|mu|mv|mw|mx|my|mz|na|nc|ne|nf|ng|ni|nl|no|np|nr|nu|nz|om|pa|pe|pf|pg|"
        r"ph|pk|pl|pm|pn|pr|ps|pt|pw|py|qa|re|ro|rs|ru|rw|sa|sb|sc|sd|se|sg|sh|si|sj|Ja|sk|sl|sm|sn|so|sr|ss|st|su|sv|sx|"
        r"sy|sz|tc|td|tf|tg|th|tj|tk|tl|tm|tn|to|tp|tr|tt|tv|tw|tz|ua|ug|uk|us|uy|uz|va|vc|ve|vg|vi|vn|vu|wf|ws|ye|yt|yu|"
        r"za|zm|zw)\b/?(?!@)))"
    )
    URL_PATTERN = re.compile(URL_PATTERN_STR, re.IGNORECASE)
    HASHTAG_PATTERN = re.compile(r'#\w*')
    MENTION_PATTERN = re.compile(r'@\w*')
    PUNCT_REPEAT_PATTERN = re.compile(r'([!?.]){2,}')
    ELONG_PATTERN = re.compile(r'\b(\S*?)(.)\2{2,}\b')
    WORD_PATTERN = re.compile(r'[^\w<>\s]')
    # Convert URLs to <URL> so that GloVe will have a vector for them
    text = re.sub(URL_PATTERN, ' <URL>', text)
    # Add spaces around slashes
    text = re.sub(r"/", " / ", text)
    # Replace mentions with <USER>
    text = re.sub(MENTION_PATTERN, ' <USER> ', text)
    # Replace numbers with <NUMBER>
    text = re.sub(r"[-+]?[.\d]*[\d]+[:,.\d]*", " <NUMBER> ", text)
    # Replace hashtags with <HASHTAG>
    text = re.sub(HASHTAG_PATTERN, ' <HASHTAG> ', text)
    # '&' is left as-is: it is already in the GloVe-twitter vocabulary
    # Replace repeated punctuation marks with <REPEAT>
    text = re.sub(PUNCT_REPEAT_PATTERN, lambda match: f" {match.group(1)} <REPEAT> ", text)
    # Replace elongated words with <ELONG>
    text = re.sub(ELONG_PATTERN, lambda match: f" {match.group(1)}{match.group(2)} <ELONG> ", text)
    # Emojis are kept: some are in the vocabulary, the rest become OOVs
    text = text.strip()
    # Keep only word characters, whitespace, and the <...> placeholder brackets
    text = re.sub(WORD_PATTERN, ' ', text)
    text = text.strip()
    # Convert stylized Unicode characters to plain text (removes bold text, etc.)
    text = ''.join(c for c in unicodedata.normalize('NFKD', text) if not unicodedata.combining(c))
    return text


def lemmatize_text(text: str) -> str:
    """
    Lemmatizes the input text using the WordNet lemmatizer.

    This function attempts to lemmatize each word in the input text. If the
    WordNet data is not available, it downloads the necessary data and retries.

    Args:
        text (str): The input text to be lemmatized.

    Returns:
        str: The lemmatized text.
    """
    lemmatizer = WordNetLemmatizer()
    downloaded = False
    while not downloaded:
        try:
            lemmatizer.lemmatize(text)
            downloaded = True
        except LookupError:
            print("Downloading WordNet...")
            nltk.download('wordnet')
    return ' '.join(lemmatizer.lemmatize(word) for word in text.split())


def predict(phrase: str, finetuned_model: str) -> str:
    phrase = preprocess_text(phrase)
    phrase = lemmatize_text(phrase)
    phrase = phrase.lower()

    # Load the tokenizer and the fine-tuned model
    if 'xlm' in finetuned_model.lower():
        tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
        model = XLMRobertaForSequenceClassification.from_pretrained(finetuned_model)
    else:
        tokenizer = AutoTokenizer.from_pretrained('cardiffnlp/twitter-roberta-base-hate')
        model = AutoModelForSequenceClassification.from_pretrained(finetuned_model)

    # Wrap the model in a Trainer to reuse its prediction loop
    trainer = Trainer(
        model=model,
        processing_class=tokenizer,
    )

    # Tokenize the phrase
    tokens = tokenizer(
        phrase,
        return_tensors="pt",
    )

    # Build a single-example dataset from the encoded phrase
    phrase_dataset = Dataset.from_dict({
        "input_ids": tokens["input_ids"],
        "attention_mask": tokens["attention_mask"],
    })

    # Get the predictions
    pred = trainer.predict(phrase_dataset)

    # Map the argmax over the logits to a label
    sexist = "Sexist" if pred.predictions.argmax() == 1 else "Not sexist"
    return sexist


demo = gr.Interface(
    fn=predict,
    inputs=[
        "textbox",
        gr.Dropdown(
            [
                "MatteoFasulo/twitter-roberta-base-hate_69",
                "MatteoFasulo/twitter-roberta-base-hate_1337",
                "MatteoFasulo/twitter-roberta-base-hate_42",
                "MatteoFasulo/xlm-roberta-base_69",
                "MatteoFasulo/xlm-roberta-base_1337",
                "MatteoFasulo/xlm-roberta-base_42",
            ],
            label="Model",
            info="Choose the model to use for prediction.",
        ),
    ],
    outputs="text",
)

demo.launch()
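
For a quick smoke test without the UI, the predict function defined above can be called directly (for example in a REPL where the functions have been defined, before demo.launch() runs). This is a minimal sketch, not part of the committed file: the input phrase is a made-up example, and it assumes the dependencies (transformers, datasets, nltk, gradio) are installed and the checkpoints listed in the dropdown are reachable on the Hugging Face Hub.

# Sketch only, not part of app.py.
# Exercises the full pipeline: preprocess -> lemmatize -> tokenize -> Trainer.predict.
# The checkpoint is downloaded from the Hub on first use.
label = predict("I love this movie", "MatteoFasulo/twitter-roberta-base-hate_42")
print(label)  # prints "Sexist" or "Not sexist"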