Spaces:
Sleeping
Sleeping
First upload
Browse files- app.py +230 -0
- requirements.txt +67 -0
app.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import gradio as gr
|
3 |
+
import torch
|
4 |
+
from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer, AutoModelForSeq2SeqLM, RobertaForQuestionAnswering
|
5 |
+
|
6 |
+
# 0.モデルのロード, Examplesの準備
|
7 |
+
# TODO 評価対象の要約モデルをロード
|
8 |
+
tokenizer_sum = AutoTokenizer.from_pretrained("tsmatz/mt5_summarize_japanese")
|
9 |
+
model_sum = AutoModelForSeq2SeqLM.from_pretrained("tsmatz/mt5_summarize_japanese")
|
10 |
+
|
11 |
+
# TODO 質問文の生成モデルをロード
|
12 |
+
tokenizer_gen_q = T5Tokenizer.from_pretrained("sonoisa/t5-base-japanese-question-generation")
|
13 |
+
model_gen_q = T5ForConditionalGeneration.from_pretrained("sonoisa/t5-base-japanese-question-generation")
|
14 |
+
|
15 |
+
# TODO 回答の生成モデルをロード
|
16 |
+
tokenizer_qa = AutoTokenizer.from_pretrained("tsmatz/roberta_qa_japanese")
|
17 |
+
model_qa = RobertaForQuestionAnswering.from_pretrained("tsmatz/roberta_qa_japanese")
|
18 |
+
|
19 |
+
# Example 1
|
20 |
+
eg_text_1 = """
|
21 |
+
ポケットモンスターの原点は、1996年2月27日に発売されたゲームボーイ用ソフト『ポケットモンスター 赤・緑』である。
|
22 |
+
開発元はゲームフリーク。コンセプトメーカーにしてディレクターを務めたのは、同社代表取締役でもある田尻智。
|
23 |
+
この作品が小学生を中心に、口コミから火が点き大ヒットとなり、以降も多くの続編が発売されている(詳しくは「ポケットモンスター(ゲーム)」を参照)。
|
24 |
+
ゲーム本編作品だけでなく、派生作品や関連作品が数多く発売されている(詳しくはポケットモンスターの関連ゲームを参照)。
|
25 |
+
|
26 |
+
ポケモンはゲームのみならず、アニメ化、キャラクター商品化、カードゲーム、アーケードゲームと様々なメディアミックス展開がなされ、日本国外でも人気を獲得している。
|
27 |
+
|
28 |
+
ポケモン関連ゲームソフトの累計出荷数は、全世界で2017年11月時点で3億本以上[1]、2022年3月時点で4億4000万本以上に達している[2]。
|
29 |
+
その中で、メインシリーズの累計販売本数は2016年2月時点での最新作、ニンテンドー3DS『オメガルビー・アルファサファイア』までの25作品で2億100万本となる[3]。
|
30 |
+
"""
|
31 |
+
eg_ans_1_1 = "2月27日"
|
32 |
+
eg_ans_1_2 = "ポケットモンスター 赤・緑"
|
33 |
+
|
34 |
+
# Example 2
|
35 |
+
eg_text_2 = """
|
36 |
+
アンパンマンの生みの親であるやなせたかしの作品で1968年に「バラの花とジョー」、
|
37 |
+
「チリンの鈴」の絵本や映画にいち早くアンパンマンが登場しているが、この時はまだ人間の姿。
|
38 |
+
この童話は一年間連載された。[5]アンパンマン、やなせたかしの作品としての、「アンパンマン」は、
|
39 |
+
PHP研究所が発行する青年向け雑誌『PHP』の通巻第257号に当たる、『こどものえほん』の1969年10月号[6](同年10月1日刊行)に掲載された青年向け読物、
|
40 |
+
やなせたかし(絵と文)「アンパンマン」という形が初出である[7][8][9]。
|
41 |
+
この時期、やなせが『こどものえほん』のために執筆した読物は連載12本の短編で、「アンパンマン」はその6本目の作品であった。
|
42 |
+
これら12篇は、株式会社山梨シルクセンター(※3年後、株式会社サンリオへ社名変更)より単行本『十二の真珠』名義で1970年に刊行された。
|
43 |
+
|
44 |
+
空腹に喘ぐ人の所へ駆け付けて、自らの大事な持ち物であるパンを差し出して食べるよう勧めるという、のちのアンパンマンに通じる物語の骨組みが、
|
45 |
+
この作品のおいて早くも整えられている[10][6]。
|
46 |
+
絵本・漫画・アニメなど、のちに描かれるアンパンマンとの大きな違いと言えば、第一に主人公のアンパンマンが普通の人間のおじさんであり[10][6]、
|
47 |
+
パンは所有物に過ぎなかったことである。
|
48 |
+
"""
|
49 |
+
eg_ans_2_1 = "アンパンマン"
|
50 |
+
eg_ans_2_2 = "やなせたかし"
|
51 |
+
|
52 |
+
"""#### イベント用の関数の実装
|
53 |
+
- 要約生成の関数作成
|
54 |
+
- 質問生成の関数作成
|
55 |
+
- 回答生成の関数作成
|
56 |
+
"""
|
57 |
+
|
58 |
+
# 1. イベント用の関数
|
59 |
+
def summy(text):
|
60 |
+
"""要約
|
61 |
+
|
62 |
+
Args
|
63 |
+
text: str
|
64 |
+
要約対象のテキスト
|
65 |
+
|
66 |
+
Returns
|
67 |
+
summarize_text: str
|
68 |
+
要約結果のテキスト
|
69 |
+
|
70 |
+
TODO
|
71 |
+
処理の実装
|
72 |
+
"""
|
73 |
+
inputs = tokenizer_sum("summarize: " + text, return_tensors="pt")
|
74 |
+
outputs = model_sum.generate(
|
75 |
+
inputs["input_ids"],
|
76 |
+
max_new_tokens=300,
|
77 |
+
min_length=150,
|
78 |
+
num_beams=5
|
79 |
+
)
|
80 |
+
summarize_text = tokenizer_sum.decode(outputs[0], skip_special_tokens=True)
|
81 |
+
return summarize_text
|
82 |
+
|
83 |
+
def generate_questions(answer_1, answer_2, text):
|
84 |
+
"""質問生成
|
85 |
+
|
86 |
+
Args
|
87 |
+
answers: list[str]
|
88 |
+
質問生成のための正解単語のリスト
|
89 |
+
text: str
|
90 |
+
質問文を生成する際に参照するテキスト
|
91 |
+
|
92 |
+
Returns
|
93 |
+
generated_questions: list[str]
|
94 |
+
生成された質問文のリスト
|
95 |
+
|
96 |
+
TODO
|
97 |
+
処理の実装
|
98 |
+
"""
|
99 |
+
# 質問文の生成
|
100 |
+
answer_context_list = [(answer_1, text), (answer_2, text)]
|
101 |
+
|
102 |
+
generated_questions = []
|
103 |
+
|
104 |
+
for answer, context in answer_context_list:
|
105 |
+
input = tokenizer_gen_q(f"answer: {answer} context: {context}", return_tensors="pt")
|
106 |
+
|
107 |
+
# 質問文を生成する
|
108 |
+
output = model_gen_q.generate(
|
109 |
+
input['input_ids'],
|
110 |
+
max_new_tokens=100,
|
111 |
+
num_beams=4
|
112 |
+
)
|
113 |
+
|
114 |
+
# 生成された問題文のトークン列を文字列に変換する。
|
115 |
+
output = tokenizer_gen_q.decode(output[0], skip_special_tokens=True)
|
116 |
+
|
117 |
+
generated_questions.append(output)
|
118 |
+
|
119 |
+
return generated_questions
|
120 |
+
|
121 |
+
def extract_answer(question, text):
|
122 |
+
"""質問応答
|
123 |
+
|
124 |
+
Args
|
125 |
+
question: str
|
126 |
+
質問文のテキスト
|
127 |
+
text: str
|
128 |
+
質問に回答するために参照するテキスト
|
129 |
+
|
130 |
+
Returns
|
131 |
+
answer: str
|
132 |
+
回答のテキスト
|
133 |
+
|
134 |
+
TODO
|
135 |
+
処理の実装
|
136 |
+
"""
|
137 |
+
inputs = tokenizer_qa(question, text, return_tensors="pt")
|
138 |
+
|
139 |
+
outputs = model_qa(**inputs)
|
140 |
+
answer_start_scores = outputs.start_logits
|
141 |
+
answer_end_scores = outputs.end_logits
|
142 |
+
|
143 |
+
answer_start = torch.argmax(answer_start_scores)
|
144 |
+
answer_end = torch.argmax(answer_end_scores) + 1
|
145 |
+
|
146 |
+
input_ids = inputs["input_ids"].tolist()[0]
|
147 |
+
|
148 |
+
answer = tokenizer_qa.decode(input_ids[answer_start:answer_end])
|
149 |
+
|
150 |
+
return answer
|
151 |
+
|
152 |
+
def extract_answer_all(gen_q_1, gen_q_2, source_text, sum_text):
|
153 |
+
"""extract_answer()をまとめて実行する
|
154 |
+
TODO
|
155 |
+
処理の実装
|
156 |
+
"""
|
157 |
+
a_source_1 = extract_answer(gen_q_1, source_text)
|
158 |
+
a_sum_1 = extract_answer(gen_q_1, sum_text)
|
159 |
+
a_source_2 = extract_answer(gen_q_2, source_text)
|
160 |
+
a_sum_2 = extract_answer(gen_q_2, sum_text)
|
161 |
+
|
162 |
+
return a_source_1, a_sum_1, a_source_2, a_sum_2
|
163 |
+
|
164 |
+
"""#### UIの実装
|
165 |
+
- `gr.Blocks()`を使ったUIの実装
|
166 |
+
- 要約生成のUI作成
|
167 |
+
- 質問生成のUI作成
|
168 |
+
- 回答生成のUI作成
|
169 |
+
- イベントの実装(btn.clickなど)
|
170 |
+
- 要約生成の実行ボタン作成
|
171 |
+
- 質問生成の実行ボタン作成
|
172 |
+
- 回答生成の実行ボタン作成
|
173 |
+
"""
|
174 |
+
|
175 |
+
# 2. UIの定義
|
176 |
+
with gr.Blocks() as demo:
|
177 |
+
gr.Markdown("### 1. 要約生成")
|
178 |
+
# TODO 要約のための入出力UIの作成
|
179 |
+
text_source = gr.Textbox(label="要約対象")
|
180 |
+
btn_summy = gr.Button("要約生成")
|
181 |
+
text_summy = gr.Textbox(label="要約結果")
|
182 |
+
|
183 |
+
gr.Markdown("### 2. 質問生成")
|
184 |
+
# TODO 質問文生成のための入力UIの作成
|
185 |
+
with gr.Row():
|
186 |
+
text_q_1 = gr.Textbox(label="正解1")
|
187 |
+
text_q_2 = gr.Textbox(label="正解2")
|
188 |
+
btn_generate_questions = gr.Button("質問生成")
|
189 |
+
|
190 |
+
gr.Markdown("### 3. 回答生成")
|
191 |
+
# TODO 質問文を表示するUIの作成
|
192 |
+
with gr.Row():
|
193 |
+
text_gq_1 = gr.Textbox(label="1番目の質問")
|
194 |
+
text_gq_2 = gr.Textbox(label="2番目の質問")
|
195 |
+
btn_extract_answer = gr.Button("回答生成")
|
196 |
+
# TODO それぞれの回答を表示するUIの作成
|
197 |
+
with gr.Row():
|
198 |
+
with gr.Column():
|
199 |
+
text_asrc_1 = gr.Textbox(label="sourceからの答え1")
|
200 |
+
text_asum_1 = gr.Textbox(label="sumからの答え1")
|
201 |
+
with gr.Column():
|
202 |
+
text_asrc_2 = gr.Textbox(label="sourceからの答え2")
|
203 |
+
text_asum_2 = gr.Textbox(label="sumからの答え2")
|
204 |
+
|
205 |
+
# 2. イベント発火
|
206 |
+
btn_summy.click(
|
207 |
+
summy,
|
208 |
+
inputs=text_source, # TODO 定義したUIのコンポーネントを与える
|
209 |
+
outputs=text_summy # TODO 定義したUIのコンポーネントを与える
|
210 |
+
)
|
211 |
+
btn_generate_questions.click(
|
212 |
+
generate_questions,
|
213 |
+
inputs=[text_q_1, text_q_2, text_summy], # TODO 定義したUIのコンポーネントを与える
|
214 |
+
outputs=[text_gq_1, text_gq_2] # TODO 定義したUIのコンポーネントを与える
|
215 |
+
)
|
216 |
+
btn_extract_answer.click(extract_answer_all,
|
217 |
+
inputs=[text_gq_1, text_gq_2, text_source, text_summy], # TODO 定義したUIのコンポーネントを与える
|
218 |
+
outputs=[text_asrc_1, text_asum_1, text_asrc_2, text_asum_2] # TODO 定義したUIのコンポーネントを与える
|
219 |
+
)
|
220 |
+
|
221 |
+
# Examplesの定義
|
222 |
+
gr.Markdown("## Examples")
|
223 |
+
gr.Examples(
|
224 |
+
examples = [[eg_text_1, eg_ans_1_1, eg_ans_1_2], [eg_text_2, eg_ans_2_1, eg_ans_2_2]],
|
225 |
+
inputs = [text_source, text_q_1, text_q_2]
|
226 |
+
) # TODO Exampleにデータを与えて、表示させる
|
227 |
+
|
228 |
+
"""#### demoの起動"""
|
229 |
+
|
230 |
+
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiofiles==22.1.0
|
2 |
+
aiohttp==3.8.3
|
3 |
+
aiosignal==1.3.1
|
4 |
+
altair==4.2.2
|
5 |
+
anyio==3.6.2
|
6 |
+
async-timeout==4.0.2
|
7 |
+
attrs==22.2.0
|
8 |
+
certifi==2022.12.7
|
9 |
+
charset-normalizer==2.1.1
|
10 |
+
click==8.1.3
|
11 |
+
contourpy==1.0.7
|
12 |
+
cycler==0.11.0
|
13 |
+
entrypoints==0.4
|
14 |
+
fastapi==0.89.1
|
15 |
+
ffmpy==0.3.0
|
16 |
+
filelock==3.9.0
|
17 |
+
fonttools==4.38.0
|
18 |
+
frozenlist==1.3.3
|
19 |
+
fsspec==2023.1.0
|
20 |
+
gradio==3.17.1
|
21 |
+
h11==0.14.0
|
22 |
+
httpcore==0.16.3
|
23 |
+
httpx==0.23.3
|
24 |
+
huggingface-hub==0.12.0
|
25 |
+
idna==3.4
|
26 |
+
Jinja2==3.1.2
|
27 |
+
jsonschema==4.17.3
|
28 |
+
kiwisolver==1.4.4
|
29 |
+
linkify-it-py==1.0.3
|
30 |
+
markdown-it-py==2.1.0
|
31 |
+
MarkupSafe==2.1.2
|
32 |
+
matplotlib==3.6.3
|
33 |
+
mdit-py-plugins==0.3.3
|
34 |
+
mdurl==0.1.2
|
35 |
+
multidict==6.0.4
|
36 |
+
numpy==1.24.2
|
37 |
+
orjson==3.8.5
|
38 |
+
packaging==23.0
|
39 |
+
pandas==1.5.3
|
40 |
+
Pillow==9.4.0
|
41 |
+
pycryptodome==3.17
|
42 |
+
pydantic==1.10.4
|
43 |
+
pydub==0.25.1
|
44 |
+
pyparsing==3.0.9
|
45 |
+
pyrsistent==0.19.3
|
46 |
+
python-dateutil==2.8.2
|
47 |
+
python-multipart==0.0.5
|
48 |
+
pytz==2022.7.1
|
49 |
+
PyYAML==6.0
|
50 |
+
regex==2022.10.31
|
51 |
+
requests==2.28.2
|
52 |
+
rfc3986==1.5.0
|
53 |
+
sentencepiece==0.1.97
|
54 |
+
six==1.16.0
|
55 |
+
sniffio==1.3.0
|
56 |
+
starlette==0.22.0
|
57 |
+
tokenizers==0.13.2
|
58 |
+
toolz==0.12.0
|
59 |
+
torch==1.13.1
|
60 |
+
tqdm==4.64.1
|
61 |
+
transformers==4.26.0
|
62 |
+
typing_extensions==4.4.0
|
63 |
+
uc-micro-py==1.0.1
|
64 |
+
urllib3==1.26.14
|
65 |
+
uvicorn==0.20.0
|
66 |
+
websockets==10.4
|
67 |
+
yarl==1.8.2
|