|
import os |
|
import subprocess |
|
import sys |
|
import gradio as gr |
|
|
|
def install(package): |
|
subprocess.check_call([sys.executable, "-m", "pip", "install", package]) |
|
|
|
install("numpy") |
|
install("torch") |
|
install("transformers") |
|
install("unidecode") |
|
|
|
import numpy as np |
|
import torch |
|
from transformers import DebertaV2TokenizerFast, DebertaV2ForQuestionAnswering |
|
from transformers.pipelines import QuestionAnsweringPipeline |
|
from transformers import pipeline |
|
from collections import Counter |
|
from unidecode import unidecode |
|
import re |
|
import string |
|
|
|
tokenizer = DebertaV2TokenizerFast.from_pretrained("osiria/deberta-italian-question-answering", revision="liteqa") |
|
model = DebertaV2ForQuestionAnswering.from_pretrained("osiria/deberta-italian-question-answering", revision="liteqa") |
|
|
|
class OsiriaQA(QuestionAnsweringPipeline): |
|
|
|
def __init__(self, punctuation = ',;.:!?()[\]{}', **kwargs): |
|
|
|
QuestionAnsweringPipeline.__init__(self, **kwargs) |
|
self.post_regex_left = "^[\s" + punctuation + "]+" |
|
self.post_regex_right = "[\s" + punctuation + "]+$" |
|
|
|
def postprocess(self, output): |
|
|
|
output = QuestionAnsweringPipeline.postprocess(self, model_outputs=output) |
|
output_length = len(output["answer"]) |
|
output["answer"] = re.sub(self.post_regex_left, "", output["answer"]) |
|
output["start"] = output["start"] + (output_length - len(output["answer"])) |
|
output_length = len(output["answer"]) |
|
output["answer"] = re.sub(self.post_regex_right, "", output["answer"]) |
|
output["end"] = output["end"] - (output_length - len(output["answer"])) |
|
|
|
return output |
|
|
|
|
|
device = torch.device("cpu") |
|
model = model.to(device) |
|
model.eval() |
|
|
|
|
|
pipeline_qa = OsiriaQA(model = model, tokenizer = tokenizer) |
|
|
|
|
|
header = '''-------------------------------------------------------------------------------------------------- |
|
<style> |
|
.vertical-text { |
|
writing-mode: vertical-lr; |
|
text-orientation: upright; |
|
background-color:red; |
|
} |
|
</style> |
|
<center> |
|
<body> |
|
<span class="vertical-text" style="background-color:lightgreen;border-radius: 3px;padding: 3px;"> </span> |
|
<span class="vertical-text" style="background-color:orange;border-radius: 3px;padding: 3px;"> D</span> |
|
<span class="vertical-text" style="background-color:lightblue;border-radius: 3px;padding: 3px;"> E</span> |
|
<span class="vertical-text" style="background-color:tomato;border-radius: 3px;padding: 3px;"> M</span> |
|
<span class="vertical-text" style="background-color:lightgrey;border-radius: 3px;padding: 3px;"> O</span> |
|
<span class="vertical-text" style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"> </span> |
|
</body> |
|
</center> |
|
<br> |
|
''' |
|
|
|
def extract(question, context): |
|
|
|
res = pipeline_qa(context = context, |
|
question = question) |
|
|
|
out_text = context[0:res["start"]] + '<span style="background-color:lightgreen;border-radius: 3px;padding: 3px;"><b>ᴀɴs </b> ' + context[res["start"]:res["end"]] + '</span>' + context[res["end"]:] |
|
|
|
return out_text |
|
|
|
|
|
init_question= "Cos'è l'Agenzia Spaziale Italiana?" |
|
|
|
init_context = '''L'Agenzia Spaziale Italiana (ASI) è un ente governativo italiano, istituito nel 1988, che ha il compito di predisporre e attuare la politica aerospaziale italiana. Dipende e utilizza i fondi ricevuti dal Governo italiano per finanziare il progetto, lo sviluppo e la gestione operativa di missioni spaziali, con obiettivi scientifici e applicativi. |
|
Gestisce missioni spaziali in proprio e in collaborazione con i maggiori organismi spaziali internazionali, prima tra tutte l'Agenzia Spaziale Europea (dove l'Italia è il terzo maggior contribuente dopo Francia e Germania, e a cui l'ASI corrisponde una parte del proprio budget), quindi la NASA e le altre agenzie spaziali nazionali. Per la realizzazione di satelliti e strumenti scientifici, l'ASI stipula contratti con le imprese, italiane e non, operanti nel settore aerospaziale. |
|
Ha la sede principale a Roma e centri operativi a Matera (sede del Centro di geodesia spaziale Giuseppe Colombo) e Malindi, Kenya (sede del Centro spaziale Luigi Broglio). Il centro di Trapani-Milo, usato per i lanci di palloni stratosferici dal 1975, non è più operativo dal 2010.''' |
|
|
|
init_output = extract(question = init_question, context = init_context) |
|
|
|
|
|
with gr.Blocks(css="footer {visibility: hidden}", theme=gr.themes.Default(text_size="lg", spacing_size="lg")) as interface: |
|
|
|
with gr.Row(): |
|
gr.Markdown(header) |
|
with gr.Row(): |
|
context = gr.Text(label="Context", lines = 10, value = init_context) |
|
with gr.Row(): |
|
question = gr.Text(label="Question", lines = 1, value = init_question) |
|
with gr.Row(): |
|
gr.Examples([["Cosa fa l'Agenzia Spaziale Italiana?"], |
|
["Qual è la sigla dell'Agenzia Spaziale Italiana?"], |
|
["Quando è stata fondata l'ASI?"], |
|
["Chi finanzia l'ASI?"], |
|
["Chi altro contribuisce all'Agenzia Spaziale Europea oltre all'Italia?"], |
|
["Dove ha sede l'Agenzia Spaziale Italiana?"], |
|
["Dove si trova il centro spaziale Giuseppe Colombo?"], |
|
["Dove si trova il centro spaziale Luigi Broglio?"], |
|
["Il centro di Trapani-Milo è ancora in funzione?"]], |
|
inputs=[question]) |
|
with gr.Row(): |
|
with gr.Column(): |
|
button = gr.Button("Ask") |
|
with gr.Row(): |
|
with gr.Column(): |
|
output = gr.Markdown(init_output) |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
gr.Markdown("<center>The input examples in this demo are extracted from https://it.wikipedia.org</center>") |
|
|
|
button.click(extract, inputs=[question, context], outputs = [output]) |
|
|
|
|
|
interface.launch() |