osiria commited on
Commit
ee9cd4e
·
1 Parent(s): 58b2de4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -0
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import subprocess
4
+ import sys
5
+
6
+ def install(package):
7
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
8
+
9
+ install("numpy")
10
+ install("torch")
11
+ install("transformers")
12
+ install("unidecode")
13
+
14
+ import numpy as np
15
+ import torch
16
+ from transformers import AutoTokenizer
17
+ from transformers import BertForTokenClassification
18
+ from collections import Counter
19
+ from unidecode import unidecode
20
+ import string
21
+ import re
22
+
23
+ tokenizer = AutoTokenizer.from_pretrained("osiria/bert-italian-uncased-ner")
24
+ model = BertForTokenClassification.from_pretrained("osiria/bert-italian-uncased-ner", num_labels = 5)
25
+ device = torch.device("cpu")
26
+ model = model.to(device)
27
+ model.eval()
28
+
29
+ from transformers import pipeline
30
+ ner = pipeline('ner', model=model, tokenizer=tokenizer, device=-1)
31
+
32
+
33
+ header = '''--------------------------------------------------------------------------------------------------
34
+ <style>
35
+ .vertical-text {
36
+ writing-mode: vertical-lr;
37
+ text-orientation: upright;
38
+ background-color:red;
39
+ }
40
+ </style>
41
+ <center>
42
+ <body>
43
+ <span class="vertical-text" style="background-color:lightgreen;border-radius: 3px;padding: 3px;"> </span>
44
+ <span class="vertical-text" style="background-color:orange;border-radius: 3px;padding: 3px;"> D</span>
45
+ <span class="vertical-text" style="background-color:lightblue;border-radius: 3px;padding: 3px;">    E</span>
46
+ <span class="vertical-text" style="background-color:tomato;border-radius: 3px;padding: 3px;">    M</span>
47
+ <span class="vertical-text" style="background-color:lightgrey;border-radius: 3px;padding: 3px;"> O</span>
48
+ <span class="vertical-text" style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"> </span>
49
+ </body>
50
+ </center>
51
+ <br>
52
+ <center>(BETA)</center>
53
+ '''
54
+
55
+ maps = {"O": "NONE", "PER": "PER", "LOC": "LOC", "ORG": "ORG", "MISC": "MISC", "DATE": "DATE"}
56
+ reg_month = "(?:gennaio|febbraio|marzo|aprile|maggio|giugno|luglio|agosto|settembre|ottobre|novembre|dicembre|january|february|march|april|may|june|july|august|september|october|november|december)"
57
+ reg_date = "(?:\d{1,2}\°{0,1}|primo|\d{1,2}\º{0,1})" + " " + reg_month + " " + "\d{4}|"
58
+ reg_date = reg_date + reg_month + " " + "\d{4}|"
59
+ reg_date = reg_date + "\d{1,2}" + " " + reg_month
60
+ reg_date = reg_date + "\d{1,2}" + "(?:\/|\.)\d{1,2}(?:\/|\.)" + "\d{4}|"
61
+ reg_date = reg_date + "(?<=dal )\d{4}|(?<=al )\d{4}|(?<=nel )\d{4}|(?<=anno )\d{4}|(?<=del )\d{4}|"
62
+ reg_date = reg_date + "\d{1,5} a\.c\.|\d{1,5} d\.c\."
63
+ map_punct = {"’": "'", "«": '"', "»": '"', "”": '"', "“": '"', "–": "-", "$": ""}
64
+ unk_tok = 9005
65
+
66
+ merge_th_1 = 0.8
67
+ merge_th_2 = 0.4
68
+ min_th = 0.55
69
+
70
+ def extract(text):
71
+
72
+ text = text.strip()
73
+ for mp in map_punct:
74
+ text = text.replace(mp, map_punct[mp])
75
+ text = re.sub("\[\d+\]", "", text)
76
+
77
+ warn_flag = False
78
+
79
+ res_total = []
80
+ out_text = ""
81
+
82
+ for p_text in text.split("\n"):
83
+
84
+ if p_text:
85
+
86
+ toks = tokenizer.encode(p_text)
87
+ if unk_tok in toks:
88
+ warn_flag = True
89
+
90
+ res_orig = ner(p_text, aggregation_strategy = "first")
91
+ res_orig = [el for r, el in enumerate(res_orig) if len(el["word"].strip()) > 1]
92
+ res = []
93
+
94
+ for r, ent in enumerate(res_orig):
95
+ if r > 0 and ent["score"] < merge_th_1 and ent["start"] <= res[-1]["end"] + 1 and ent["score"] <= res[-1]["score"]:
96
+ res[-1]["word"] = res[-1]["word"] + " " + ent["word"]
97
+ res[-1]["score"] = merge_th_1*(res[-1]["score"] > merge_th_2)
98
+ res[-1]["end"] = ent["end"]
99
+ elif r < len(res_orig) - 1 and ent["score"] < merge_th_1 and res_orig[r+1]["start"] <= ent["end"] + 1 and res_orig[r+1]["score"] > ent["score"]:
100
+ res_orig[r+1]["word"] = ent["word"] + " " + res_orig[r+1]["word"]
101
+ res_orig[r+1]["score"] = merge_th_1*(res_orig[r+1]["score"] > merge_th_2)
102
+ res_orig[r+1]["start"] = ent["start"]
103
+ else:
104
+ res.append(ent)
105
+
106
+ res = [el for r, el in enumerate(res) if el["score"] >= min_th]
107
+
108
+ dates = [{"entity_group": "DATE", "score": 1.0, "word": p_text[el.span()[0]:el.span()[1]], "start": el.span()[0], "end": el.span()[1]} for el in re.finditer(reg_date, p_text, flags = re.IGNORECASE)]
109
+ res.extend(dates)
110
+ res = sorted(res, key = lambda t: t["start"])
111
+ res_total.extend(res)
112
+
113
+ chunks = [("", "", 0, "NONE")]
114
+
115
+ for el in res:
116
+ if maps[el["entity_group"]] != "NONE":
117
+ tag = maps[el["entity_group"]]
118
+ chunks.append((p_text[el["start"]: el["end"]], p_text[chunks[-1][2]:el["end"]], el["end"], tag))
119
+
120
+ if chunks[-1][2] < len(p_text):
121
+ chunks.append(("END", p_text[chunks[-1][2]:], -1, "NONE"))
122
+ chunks = chunks[1:]
123
+
124
+ n_text = []
125
+
126
+ for i, chunk in enumerate(chunks):
127
+
128
+ rep = chunk[0]
129
+
130
+ if chunk[3] == "PER":
131
+ rep = '<span style="background-color:lightgreen;border-radius: 3px;padding: 3px;"><b>ᴘᴇʀ</b> ' + chunk[0] + '</span>'
132
+ elif chunk[3] == "LOC":
133
+ rep = '<span style="background-color:orange;border-radius: 3px;padding: 3px;"><b>ʟᴏᴄ</b> ' + chunk[0] + '</span>'
134
+ elif chunk[3] == "ORG":
135
+ rep = '<span style="background-color:lightblue;border-radius: 3px;padding: 3px;"><b>ᴏʀɢ</b> ' + chunk[0] + '</span>'
136
+ elif chunk[3] == "MISC":
137
+ rep = '<span style="background-color:tomato;border-radius: 3px;padding: 3px;"><b>ᴍɪsᴄ</b> ' + chunk[0] + '</span>'
138
+ elif chunk[3] == "DATE":
139
+ rep = '<span style="background-color:lightgrey;border-radius: 3px;padding: 3px;"><b>ᴅᴀᴛᴇ</b> ' + chunk[0] + '</span>'
140
+
141
+ n_text.append(chunk[1].replace(chunk[0], rep))
142
+
143
+ n_text = "".join(n_text)
144
+ if out_text:
145
+ out_text = out_text + "<br>" + n_text
146
+ else:
147
+ out_text = n_text
148
+
149
+
150
+ tags = [el["word"] for el in res_total if el["entity_group"] not in ['DATE', None]]
151
+ cnt = Counter(tags)
152
+ tags = sorted(list(set([el for el in tags if cnt[el] > 1])), key = lambda t: cnt[t]*np.exp(-tags.index(t)))[::-1]
153
+ tags = [" ".join(re.sub("[^A-Za-z0-9\s]", "", unidecode(tag)).split()) for tag in tags]
154
+ tags = ['<span style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"><b>ᴛᴀɢ </b> ' + el + '</span>' for el in tags]
155
+ tags = " ".join(tags)
156
+
157
+ if tags:
158
+ out_text = out_text + "<br><br><b>Tags:</b> " + tags
159
+
160
+ if warn_flag:
161
+ out_text = out_text + "<br><br><b>Warning ⚠️:</b> Unknown tokens detected in text. The model might behave erratically"
162
+
163
+ return out_text
164
+
165
+
166
+
167
+ init_text = '''L'Agenzia spaziale europea, nota internazionalmente con l'acronimo ESA dalla denominazione inglese European Space Agency, è un'agenzia internazionale fondata nel 1975 incaricata di coordinare i progetti spaziali di 22 Paesi europei. Il suo quartier generale si trova a Parigi in Francia, con uffici a Mosca, Bruxelles, Washington e Houston. Il personale dell'ESA del 2016 ammontava a 2 200 persone (esclusi sub-appaltatori e le agenzie nazionali) e il budget del 2022 è di 7,15 miliardi di euro. Attualmente il direttore generale dell'agenzia è l'austriaco Josef Aschbacher, il quale ha sostituito il tedesco Johann-Dietrich Wörner il primo marzo 2021.
168
+ Lo spazioporto dell'ESA è il Centre Spatial Guyanais a Kourou, nella Guyana francese, un sito scelto, come tutte le basi di lancio, per via della sua vicinanza con l'equatore. Durante gli ultimi anni il lanciatore Ariane 5 ha consentito all'ESA di raggiungere una posizione di primo piano nei lanci commerciali e l'ESA è il principale concorrente della NASA nell'esplorazione spaziale.
169
+ Le missioni scientifiche dell'ESA hanno le loro basi al Centro europeo per la ricerca e la tecnologia spaziale (ESTEC) di Noordwijk, nei Paesi Bassi. Il Centro europeo per le operazioni spaziali (ESOC), di Darmstadt in Germania, è responsabile del controllo dei satelliti ESA in orbita. Le responsabilità del Centro europeo per l'osservazione della Terra (ESRIN) di Frascati, in Italia, includono la raccolta, l'archiviazione e la distribuzione di dati satellitari ai partner dell'ESA; oltre a ciò, la struttura agisce come centro di informazione tecnologica per l'intera agenzia. [...]
170
+ L'Agenzia Spaziale Italiana (ASI) venne fondata nel 1988 per promuovere, coordinare e condurre le attività spaziali in Italia. Opera in collaborazione con il Ministero dell'università e della ricerca scientifica e coopera in numerosi progetti con entità attive nella ricerca scientifica e nelle attività commerciali legate allo spazio. Internazionalmente l'ASI fornisce la delegazione italiana per l'Agenzia Spaziale Europea e le sue sussidiarie.'''
171
+
172
+ init_output = extract(init_text)
173
+
174
+
175
+
176
+
177
+ with gr.Blocks(css="footer {visibility: hidden}", theme=gr.themes.Default(text_size="lg", spacing_size="lg")) as interface:
178
+
179
+ with gr.Row():
180
+ gr.Markdown(header)
181
+ with gr.Row():
182
+ text = gr.Text(label="Extract entities", lines = 10, value = init_text)
183
+ with gr.Row():
184
+ with gr.Column():
185
+ button = gr.Button("Extract").style(full_width=False)
186
+ with gr.Row():
187
+ with gr.Column():
188
+ entities = gr.Markdown(init_output)
189
+
190
+ with gr.Row():
191
+ with gr.Column():
192
+ gr.Markdown("<center>The input examples in this demo are extracted from https://it.wikipedia.org</center>")
193
+
194
+ button.click(extract, inputs=[text], outputs = [entities])
195
+
196
+
197
+ interface.launch()