MatteoFasulo committed
Commit 976b6b9 · 1 Parent(s): 65f1a6e

Add application file

Files changed (1)
  1. app.py (+154, -0)

app.py ADDED
@@ -0,0 +1,154 @@
+ import re
+ import unicodedata
+ import nltk
+ from nltk import WordNetLemmatizer
+ from datasets import Dataset
+ from transformers import AutoTokenizer
+ from transformers import AutoModelForSequenceClassification
+ from transformers import XLMRobertaForSequenceClassification
+ from transformers import Trainer
+ import gradio as gr
+
+ def preprocess_text(text: str) -> str:
+     """
+     Preprocesses the input text by removing or replacing specific patterns.
+
+     Args:
+         text (str): The input text to be preprocessed.
+
+     Returns:
+         str: The preprocessed text with URLs, mentions, hashtags, and numbers replaced
+         by placeholder tokens, repeated punctuation and elongated words tagged,
+         remaining special characters removed, and surrounding whitespace trimmed.
+     """
+     # Define patterns
+     URL_PATTERN_STR = r"""(?i)((?:https?:(?:/{1,3}|[a-z0-9%])|[a-z0-9.\-]+[.](?:com|net|org|edu|gov|mil|aero|asia|biz|cat|coop|info
+         |int|jobs|mobi|museum|name|post|pro|tel|travel|xxx|ac|ad|ae|af|ag|ai|al|am|an|ao|aq|ar|as|at|au|aw|ax|az|ba|
+         bb|bd|be|bf|bg|bh|bi|bj|bm|bn|bo|br|bs|bt|bv|bw|by|bz|ca|cc|cd|cf|cg|ch|ci|ck|cl|cm|cn|co|cr|cs|cu|cv|cx|cy|
+         cz|dd|de|dj|dk|dm|do|dz|ec|ee|eg|eh|er|es|et|eu|fi|fj|fk|fm|fo|fr|ga|gb|gd|ge|gf|gg|gh|gi|gl|gm|gn|gp|gq|gr|
+         gs|gt|gu|gw|gy|hk|hm|hn|hr|ht|hu|id|ie|il|im|in|io|iq|ir|is|it|je|jm|jo|jp|ke|kg|kh|ki|km|kn|kp|kr|kw|ky|kz|
+         la|lb|lc|li|lk|lr|ls|lt|lu|lv|ly|ma|mc|md|me|mg|mh|mk|ml|mm|mn|mo|mp|mq|mr|ms|mt|mu|mv|mw|mx|my|mz|na|nc|ne|
+         nf|ng|ni|nl|no|np|nr|nu|nz|om|pa|pe|pf|pg|ph|pk|pl|pm|pn|pr|ps|pt|pw|py|qa|re|ro|rs|ru|rw|sa|sb|sc|sd|se|sg|
+         sh|si|sj|Ja|sk|sl|sm|sn|so|sr|ss|st|su|sv|sx|sy|sz|tc|td|tf|tg|th|tj|tk|tl|tm|tn|to|tp|tr|tt|tv|tw|tz|ua|ug|
+         uk|us|uy|uz|va|vc|ve|vg|vi|vn|vu|wf|ws|ye|yt|yu|za|zm|zw)/)(?:[^\s()<>{}\[\]]+|\([^\s()]*?\([^\s()]+\)[^\s()]
+         *?\)|\([^\s]+?\))+(?:\([^\s()]*?\([^\s()]+\)[^\s()]*?\)|\([^\s]+?\)|[^\s`!()\[\]{};:'\".,<>?«»“”‘’])|(?:(?<!@)
+         [a-z0-9]+(?:[.\-][a-z0-9]+)*[.](?:com|net|org|edu|gov|mil|aero|asia|biz|cat|coop|info|int|jobs|mobi|museum|name
+         |post|pro|tel|travel|xxx|ac|ad|ae|af|ag|ai|al|am|an|ao|aq|ar|as|at|au|aw|ax|az|ba|bb|bd|be|bf|bg|bh|bi|bj|bm|bn
+         |bo|br|bs|bt|bv|bw|by|bz|ca|cc|cd|cf|cg|ch|ci|ck|cl|cm|cn|co|cr|cs|cu|cv|cx|cy|cz|dd|de|dj|dk|dm|do|dz|ec|ee|eg
+         |eh|er|es|et|eu|fi|fj|fk|fm|fo|fr|ga|gb|gd|ge|gf|gg|gh|gi|gl|gm|gn|gp|gq|gr|gs|gt|gu|gw|gy|hk|hm|hn|hr|ht|hu|id
+         |ie|il|im|in|io|iq|ir|is|it|je|jm|jo|jp|ke|kg|kh|ki|km|kn|kp|kr|kw|ky|kz|la|lb|lc|li|lk|lr|ls|lt|lu|lv|ly|ma|mc|
+         md|me|mg|mh|mk|ml|mm|mn|mo|mp|mq|mr|ms|mt|mu|mv|mw|mx|my|mz|na|nc|ne|nf|ng|ni|nl|no|np|nr|nu|nz|om|pa|pe|pf|pg|
+         ph|pk|pl|pm|pn|pr|ps|pt|pw|py|qa|re|ro|rs|ru|rw|sa|sb|sc|sd|se|sg|sh|si|sj|Ja|sk|sl|sm|sn|so|sr|ss|st|su|sv|sx|
+         sy|sz|tc|td|tf|tg|th|tj|tk|tl|tm|tn|to|tp|tr|tt|tv|tw|tz|ua|ug|uk|us|uy|uz|va|vc|ve|vg|vi|vn|vu|wf|ws|ye|yt|yu|
+         za|zm|zw)\b/?(?!@)))"""
+     # re.VERBOSE makes the compiled pattern ignore the line breaks and indentation
+     # introduced by the multi-line string above (the pattern contains no significant whitespace).
+     URL_PATTERN = re.compile(URL_PATTERN_STR, re.IGNORECASE | re.VERBOSE)
+     HASHTAG_PATTERN = re.compile(r'#\w*')
+     MENTION_PATTERN = re.compile(r'@\w*')
+     PUNCT_REPEAT_PATTERN = re.compile(r'([!?.]){2,}')
+     ELONG_PATTERN = re.compile(r'\b(\S*?)(.)\2{2,}\b')
+     WORD_PATTERN = re.compile(r'[^\w<>\s]')
+     # Convert URL to <URL> so that GloVe will have a vector for it
+     text = re.sub(URL_PATTERN, ' <URL>', text)
+     # Add spaces around slashes
+     text = re.sub(r"/", " / ", text)
+     # Replace mentions with <USER>
+     text = re.sub(MENTION_PATTERN, ' <USER> ', text)
+     # Replace numbers with <NUMBER>
+     text = re.sub(r"[-+]?[.\d]*[\d]+[:,.\d]*", " <NUMBER> ", text)
+     # Replace hashtags with <HASHTAG>
+     text = re.sub(HASHTAG_PATTERN, ' <HASHTAG> ', text)
+     #text = self.AND_PATTERN.sub('and', text) # &amp; already in the Vocab of GloVe-twitter
+     # Replace multiple punctuation marks with <REPEAT>
+     text = re.sub(PUNCT_REPEAT_PATTERN, lambda match: f" {match.group(1)} <REPEAT> ", text)
+     # Replace elongated words with <ELONG>
+     text = re.sub(ELONG_PATTERN, lambda match: f" {match.group(1)}{match.group(2)} <ELONG> ", text)
+     #text = emoji.replace_emoji(text, replace='') # some emojis are in the vocab so we do not remove them, the others will be OOVs
+     text = text.strip()
+     # Get only words
+     text = re.sub(WORD_PATTERN, ' ', text)
+     text = text.strip()
+     # Convert stylized Unicode characters to plain text (removes bold text, etc.)
+     text = ''.join(c for c in unicodedata.normalize('NFKD', text) if not unicodedata.combining(c))
+     return text
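
For orientation, this is what the substitutions produce on an illustrative tweet (not taken from the repository); splitting the result collapses the extra internal spaces the function leaves behind:

    preprocess_text("Loved it!!! @user https://example.com #topic").split()
    # expected: ['Loved', 'it', '<REPEAT>', '<USER>', '<URL>', '<HASHTAG>']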
+
+ def lemmatize_text(text: str) -> str:
+     """
+     Lemmatizes the input text using the WordNet lemmatizer.
+
+     This function attempts to lemmatize each word in the input text. If the WordNet
+     data is not available, it will download the necessary data and retry.
+
+     Args:
+         text (str): The input text to be lemmatized.
+
+     Returns:
+         str: The lemmatized text.
+     """
+     lemmatizer = WordNetLemmatizer()
+     downloaded = False
+     while not downloaded:
+         try:
+             lemmatizer.lemmatize(text)
+             downloaded = True
+         except LookupError:
+             print("Downloading WordNet...")
+             nltk.download('wordnet')
+     return ' '.join([lemmatizer.lemmatize(word) for word in text.split()])
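
Because no part-of-speech tag is passed, WordNetLemmatizer falls back to its noun default, so only noun inflections are normalized; an illustrative call (the first one may trigger the WordNet download):

    lemmatize_text("the cars were parked")
    # expected: 'the car were parked'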
+
+ def predict(phrase: str, finetuned_model: str):
+     phrase = preprocess_text(phrase)
+     phrase = lemmatize_text(phrase)
+     phrase = phrase.lower()
+
+     # Get the tokenizer and model
+     if 'xlm' in finetuned_model.lower():
+         tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
+         model = XLMRobertaForSequenceClassification.from_pretrained(finetuned_model)
+     else:
+         tokenizer = AutoTokenizer.from_pretrained('cardiffnlp/twitter-roberta-base-hate')
+         model = AutoModelForSequenceClassification.from_pretrained(finetuned_model)
+
+     # Get the trainer (processing_class requires a recent transformers release;
+     # older versions passed the tokenizer via the tokenizer= argument)
+     trainer = Trainer(
+         model=model,
+         processing_class=tokenizer,
+     )
+
+     # Tokenize the phrase
+     tokens = tokenizer(
+         phrase,
+         return_tensors="pt"
+     )
+
+     # Create the dataset
+     phrase_dataset = Dataset.from_dict({
+         "input_ids": tokens["input_ids"],
+         "attention_mask": tokens["attention_mask"],
+     })
+
+     # Get the predictions
+     pred = trainer.predict(phrase_dataset)
+
+     # Check if it is sexist or not
+     sexist = "Sexist" if pred.predictions.argmax() == 1 else "Not sexist"
+     return sexist
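
predict can also be exercised directly, outside the Gradio UI (the input phrase is made up; the model ID is one of the fine-tuned checkpoints offered in the dropdown below, and downloading it requires network access):

    label = predict("some example phrase", "MatteoFasulo/twitter-roberta-base-hate_42")
    print(label)  # "Sexist" or "Not sexist"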
+
+ demo = gr.Interface(
+     fn=predict,
+     inputs=[
+         "textbox",
+         gr.Dropdown([
+                 "MatteoFasulo/twitter-roberta-base-hate_69",
+                 "MatteoFasulo/twitter-roberta-base-hate_1337",
+                 "MatteoFasulo/twitter-roberta-base-hate_42",
+                 "MatteoFasulo/xlm-roberta-base_69",
+                 "MatteoFasulo/xlm-roberta-base_1337",
+                 "MatteoFasulo/xlm-roberta-base_42",
+             ],
+             label="Model",
+             info="Choose the model to use for prediction.",
+         )
+     ],
+     outputs="text",
+ )
+
+ demo.launch()
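
Once the Space is running, the same endpoint can be called programmatically. A minimal sketch with gradio_client, assuming Gradio's default api_name for a single-function Interface and a placeholder Space ID (the real ID is not stated in this commit):

    from gradio_client import Client

    client = Client("<owner>/<space-name>")  # hypothetical Space ID
    result = client.predict(
        "some example phrase",
        "MatteoFasulo/twitter-roberta-base-hate_42",
        api_name="/predict",
    )
    print(result)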