osiria commited on
Commit
9bc0963
·
1 Parent(s): eaa395a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -6
app.py CHANGED
@@ -1,18 +1,173 @@
1
  import os
2
  import gradio as gr
 
3
  import subprocess
4
  import sys
5
 
6
  def install(package):
7
  subprocess.check_call([sys.executable, "-m", "pip", "install", package])
8
 
9
- install("datasets")
 
 
 
 
 
10
 
11
- from datasets import load_dataset
 
 
 
 
 
 
 
12
 
13
- auth_token=os.environ.get("AUTH-TOKEN")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- code = load_dataset("osiria/classifier-zero-shot-italian", data_files="classifier-zero-shot-italian-code.txt", use_auth_token=auth_token)
16
- code = "\n".join(code["train"]["text"][0:])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- exec(code)
 
1
  import os
2
  import gradio as gr
3
+ from gradio.components import Label
4
  import subprocess
5
  import sys
6
 
7
  def install(package):
8
  subprocess.check_call([sys.executable, "-m", "pip", "install", package])
9
 
10
+ install("numpy")
11
+ install("transformers")
12
+ install("torch")
13
+ install("tensorflow")
14
+ install("tensorflow-text")
15
+ install("tensorflow-hub")
16
 
17
+ import tensorflow_hub as hub
18
+ import tensorflow_text
19
+ import tensorflow as tf
20
+ import torch
21
+ from transformers import AutoTokenizer
22
+ from transformers import BertForTokenClassification
23
+ import numpy as np
24
+ import re
25
 
26
+ auth_token = os.environ.get("AUTH-TOKEN")
27
+
28
+
29
+ header = '''--------------------------------------------------------------------------------------------------
30
+ <style>
31
+ .vertical-text {
32
+ writing-mode: vertical-lr;
33
+ text-orientation: upright;
34
+ background-color:red;
35
+ }
36
+ </style>
37
+ <center>
38
+ <body>
39
+ <span class="vertical-text" style="background-color:lightgreen;border-radius: 3px;padding: 3px;"> </span>
40
+ <span class="vertical-text" style="background-color:orange;border-radius: 3px;padding: 3px;"> D</span>
41
+ <span class="vertical-text" style="background-color:lightblue;border-radius: 3px;padding: 3px;">    E</span>
42
+ <span class="vertical-text" style="background-color:tomato;border-radius: 3px;padding: 3px;">    M</span>
43
+ <span class="vertical-text" style="background-color:lightgrey;border-radius: 3px;padding: 3px;"> O</span>
44
+ <span class="vertical-text" style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"> </span>
45
+ </body>
46
+ </center>
47
+ <br>
48
+ <center>(BETA)</center>
49
+
50
+ --------------------------------------------------------------------------------------------------'''
51
+
52
+ model1 = hub.load("https://tfhub.dev/google/universal-sentence-encoder-multilingual-large/3")
53
+
54
+ tokenizer = AutoTokenizer.from_pretrained("osiria/mbert-base-cased-pos-it", use_auth_token=auth_token)
55
+ model = BertForTokenClassification.from_pretrained("osiria/mbert-base-cased-pos-it", num_labels = 17, use_auth_token=auth_token)
56
+ model.eval()
57
+
58
+ from transformers import pipeline
59
+ pos = pipeline('ner', model=model, tokenizer=tokenizer, device=-1)
60
+
61
+
62
+ def classify(text, classes):
63
+
64
+ text = text[:10000]
65
+ res = pos(text, aggregation_strategy = "first")
66
+ text = " ".join([r["word"] for r in res if r["entity_group"] in ["AGGETTIVO", "NOME", "NOME PROPRIO"]])
67
+ classes = {el.split(":")[0].strip(): el.split(":")[1].strip() for el in classes.split("\n")}
68
+
69
+ t_vec = model1(text).numpy()
70
+ t_vec = t_vec/np.linalg.norm(t_vec)
71
+ t_vec = t_vec.reshape(-1, 1)
72
+ cl_vecs = model1(["L'argomento di cui parliamo è quindi: " + re.sub("\s+", " ", classes[cl].lower().replace(",", " ")).strip() for cl in classes]).numpy()
73
+ cl_vecs = cl_vecs/np.sqrt(np.sum(cl_vecs**2, axis = 1).reshape(-1,1))
74
+ scores1 = [el[0] for el in np.dot(cl_vecs, t_vec).tolist()]
75
+
76
+ scores = np.array([s if s > 0 else 0 for s in scores1])
77
+ scores = (scores/np.sum(scores)).tolist()
78
+
79
+ classes = list(classes.keys())
80
+
81
+ output = sorted(classes, key = lambda cl: scores[classes.index(cl)], reverse = True)
82
+ scores = sorted(scores, reverse=True)
83
+ out = {tpl[0].capitalize(): tpl[1] for tpl in list(zip(output, scores))}
84
 
85
+ return out
86
+
87
+
88
+
89
+ init_text = '''L'Agenzia spaziale europea, nota internazionalmente con l'acronimo ESA dalla denominazione inglese European Space Agency, è un'agenzia internazionale fondata nel 1975 incaricata di coordinare i progetti spaziali di 22 Paesi europei. Il suo quartier generale si trova a Parigi in Francia, con uffici a Mosca, Bruxelles, Washington e Houston. Il personale dell'ESA del 2016 ammontava a 2 200 persone (esclusi sub-appaltatori e le agenzie nazionali) e il budget del 2022 è di 7,15 miliardi di euro. Attualmente il direttore generale dell'agenzia è l'austriaco Josef Aschbacher, il quale ha sostituito il tedesco Johann-Dietrich Wörner il primo marzo 2021.
90
+
91
+ Lo spazioporto dell'ESA è il Centre Spatial Guyanais a Kourou, nella Guyana francese, un sito scelto, come tutte le basi di lancio, per via della sua vicinanza con l'equatore. Durante gli ultimi anni il lanciatore Ariane 5 ha consentito all'ESA di raggiungere una posizione di primo piano nei lanci commerciali e l'ESA è il principale concorrente della NASA nell'esplorazione spaziale.
92
+
93
+ Le missioni scientifiche dell'ESA hanno le loro basi al Centro europeo per la ricerca e la tecnologia spaziale (ESTEC) di Noordwijk, nei Paesi Bassi. Il Centro europeo per le operazioni spaziali (ESOC), di Darmstadt in Germania, è responsabile del controllo dei satelliti ESA in orbita. Le responsabilità del Centro europeo per l'osservazione della Terra (ESRIN) di Frascati, in Italia, includono la raccolta, l'archiviazione e la distribuzione di dati satellitari ai partner dell'ESA; oltre a ciò, la struttura agisce come centro di informazione tecnologica per l'intera agenzia.
94
+ '''
95
+
96
+ init_classes = '''alimentazione: alimentazione, cibo, agricoltura, allevamento, nutrizione
97
+ arte: arte, pittura, scultura, moda
98
+ animali: animali, zoologia, botanica, piante
99
+ ambiente: ambiente, clima, sostenibilità, ecologia, inquinamento
100
+ economia: aziende, banche, economia, finanza, borsa
101
+ filosofia: etica, filosofia, religione, teologia
102
+ geografia: città, regioni, nazioni, geografia, geologia
103
+ giustizia: giustizia, magistratura, reati, criminalità
104
+ musica: musica, cantanti, gruppi musicali, generi musicali
105
+ cinema: cinema, film, televisione, spettacolo
106
+ intrattenimento: intrattenimento, tempo libero, svago, videogiochi
107
+ letteratura: letteratura, romanzi, narrativa, poesia
108
+ medicina: medicina, salute, farmaci, malattie, patologie
109
+ governo: governo, legge, politica, partiti, settore pubblico
110
+ scienza: scienza, ingegneria, tecnologia
111
+ sport: competizioni, sport
112
+ guerra: guerra, conflitti, battaglie, tematiche militari
113
+ storia: eventi, storia
114
+ società: tematiche sociali, tematiche internazionali
115
+ trasporti: automobili, treni, aerei, trasporti, veicoli
116
+ informatica: computer, smartphone, applicazioni, internet, social networks'''
117
+
118
+ init_output = classify(init_text, init_classes)
119
+
120
+ with gr.Blocks(css="footer {visibility: hidden}", theme=gr.themes.Default(text_size="lg", spacing_size="lg")) as interface:
121
+
122
+ with gr.Row():
123
+ gr.Markdown(header)
124
+ with gr.Row():
125
+ text = gr.Text(label="Write or paste a text", lines = 5, value = init_text)
126
+ with gr.Row():
127
+ gr.Examples([["Alessandro Manzoni, nome completo Alessandro Francesco Tommaso Antonio Manzoni (Milano, 7 marzo 1785 – Milano, 22 maggio 1873), è stato uno scrittore, poeta e drammaturgo italiano. Considerato uno dei maggiori romanzieri italiani di tutti i tempi per il suo celebre romanzo I promessi sposi, caposaldo della letteratura italiana, Manzoni ebbe il merito principale di aver gettato le basi per il romanzo moderno e di aver così patrocinato l'unità linguistica italiana, sulla scia di quella letteratura moralmente e civilmente impegnata propria dell'Illuminismo italiano."],
128
+ ["Oggi sto male perchè ho la febbre"],
129
+ ["Mi sono fatto un profilo su Facebook"],
130
+ ["Stasera mi guardo Netflix"],
131
+ ["La battaglia delle Termòpili, o delle Termòpile, fu una battaglia combattuta da un'alleanza di poleis greche, guidata dal re di Sparta Leonida I, contro l'Impero persiano governato da Serse I. Si svolse in tre giorni, durante la seconda invasione persiana della Grecia, nell'agosto o nel settembre del 480 a.C. presso lo stretto passaggio delle Termopili (o, più correttamente, Termopile, 'Le porte calde'), contemporaneamente alla battaglia navale di Capo Artemisio."],
132
+ ["Ieri ho comprato l'Xbox One"],
133
+ ["Domani per pranzo preparo la pasta alle vongole"],
134
+ ["Ho appena ascoltato l'ultimo album dei Green Day"],
135
+ ["Sono chiamati gas serra quei gas presenti nell'atmosfera che riescono a trattenere, in maniera consistente, una parte considerevole della componente nell'infrarosso della radiazione solare che colpisce la Terra ed è emessa dalla superficie terrestre, dall'atmosfera e dalle nuvole. Tale proprietà causa il fenomeno noto come 'effetto serra' ed è verificabile da un'analisi spettroscopica in laboratorio."]],
136
+ inputs=[text])
137
+ with gr.Row():
138
+ classes = gr.Text(label="Classes (write a few classes in the form 'class_name: word1, word2, word3...' using 1 to 5 descriptive words for each class)", lines = 1, value = '''alimentazione: alimentazione, cibo, agricoltura, allevamento, nutrizione
139
+ arte: arte, pittura, scultura, moda
140
+ animali: animali, zoologia, botanica, piante
141
+ ambiente: ambiente, clima, sostenibilità, ecologia, inquinamento
142
+ economia: aziende, banche, economia, finanza, borsa
143
+ filosofia: etica, filosofia, religione, teologia
144
+ geografia: città, regioni, nazioni, geografia, geologia
145
+ giustizia: giustizia, magistratura, reati, criminalità
146
+ musica: musica, cantanti, gruppi musicali, generi musicali
147
+ cinema: cinema, film, televisione, spettacolo
148
+ intrattenimento: intrattenimento, tempo libero, svago, videogiochi
149
+ letteratura: letteratura, romanzi, narrativa, poesia
150
+ medicina: medicina, salute, farmaci, malattie, patologie
151
+ governo: governo, legge, politica, partiti, settore pubblico
152
+ scienza: scienza, ingegneria, tecnologia
153
+ sport: competizioni, sport
154
+ guerra: guerra, conflitti, battaglie, tematiche militari
155
+ storia: eventi, storia
156
+ società: tematiche sociali, tematiche internazionali
157
+ trasporti: automobili, treni, aerei, trasporti, veicoli
158
+ informatica: computer, smartphone, applicazioni, internet, social networks''')
159
+ with gr.Row():
160
+ button = gr.Button("Classify").style(full_width=False)
161
+
162
+ with gr.Row():
163
+ with gr.Column():
164
+ output = Label(label="Result")
165
+
166
+ with gr.Row():
167
+ with gr.Column():
168
+ footer = gr.Markdown("<center>A few examples in this demo are extracted from Wikipedia</center>")
169
+
170
+ button.click(classify, inputs=[text, classes], outputs = [output])
171
+
172
 
173
+ interface.launch()