Spaces:
Runtime error
Runtime error
Add initial attempt of a code framework.
Browse files- README.md +5 -5
- app.py +342 -0
- models.py +112 -0
- requirements.txt +8 -0
- scrollbar.css +30 -0
- utils.py +161 -0
README.md
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
|
|
1 |
---
|
2 |
+
title: Athena's Lens
|
3 |
+
emoji: 🦉
|
4 |
+
colorFrom: red
|
5 |
+
colorTo: gray
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.3.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
app.py
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple, List, Union, Dict, Mapping
|
2 |
+
import base64
|
3 |
+
import os
|
4 |
+
|
5 |
+
from bs4 import BeautifulSoup
|
6 |
+
import gradio as gr
|
7 |
+
from spacy import displacy
|
8 |
+
from transformers import (
|
9 |
+
AutoTokenizer,
|
10 |
+
AutoModelForTokenClassification,
|
11 |
+
BatchEncoding,
|
12 |
+
AutoModelForSeq2SeqLM,
|
13 |
+
DataCollatorForTokenClassification,
|
14 |
+
)
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from utils import get_dependencies, preprocess_text
|
18 |
+
from models import (
|
19 |
+
DependencyRobertaForTokenClassification,
|
20 |
+
LabelRobertaForTokenClassification,
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
DEFAULT_TEXT = "τίω δέ μιν ἐν καρὸς αἴσῃ."
|
25 |
+
BUTTON_CSS = "float: right; --tw-border-opacity: 1; border-color: rgb(229 231 235 / var(--tw-border-opacity)); --tw-gradient-from: rgb(243 244 246 / 0.7); --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to, rgb(243 244 246 / 0)); --tw-gradient-to: rgb(229 231 235 / 0.8); --tw-text-opacity: 1; color: rgb(55 65 81 / var(--tw-text-opacity)); border-width: 1px; --tw-bg-opacity: 1; background-color: rgb(255 255 255 / var(--tw-bg-opacity)); background-image: linear-gradient(to bottom right, var(--tw-gradient-stops)); display: inline-flex; flex: 1 1 0%; align-items: center; justify-content: center; --tw-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05); --tw-shadow-colored: 0 1px 2px 0 var(--tw-shadow-color); box-shadow: var(--tw-ring-offset-shadow, 0 0 #0000), var(--tw-ring-shadow, 0 0 #0000), var(--tw-shadow); -webkit-appearance: button; border-radius: 0.5rem; padding-top: 0.5rem; padding-bottom: 0.5rem; padding-left: 1rem; padding-right: 1rem; font-size: 1rem; line-height: 1.5rem; font-weight: 600;"
|
26 |
+
DEFAULT_COLOR = "white"
|
27 |
+
|
28 |
+
MODEL_PATHS = {
|
29 |
+
"POS": "bowphs/testid",
|
30 |
+
"LEMMATIZATION": "bowphs/lemmatization-demo",
|
31 |
+
"DEPENDENCY": "bowphs/depenBERTa_perseus",
|
32 |
+
"LABELS": "bowphs/depenBERTa_labler_perseus",
|
33 |
+
}
|
34 |
+
MODEL_MAX_LENGTH = 512
|
35 |
+
|
36 |
+
AUTH_TOKEN = os.environ.get("TOKEN") or True
|
37 |
+
# PoS
|
38 |
+
pos_tokenizer = AutoTokenizer.from_pretrained(
|
39 |
+
MODEL_PATHS["POS"], model_max_length=MODEL_MAX_LENGTH, use_auth_token=AUTH_TOKEN
|
40 |
+
)
|
41 |
+
pos_model = AutoModelForTokenClassification.from_pretrained(
|
42 |
+
MODEL_PATHS["POS"], use_auth_token=AUTH_TOKEN
|
43 |
+
)
|
44 |
+
|
45 |
+
# Lemmatization
|
46 |
+
lemmatizer_tokenizer = AutoTokenizer.from_pretrained(
|
47 |
+
MODEL_PATHS["LEMMATIZATION"],
|
48 |
+
model_max_length=MODEL_MAX_LENGTH,
|
49 |
+
use_auth_token=AUTH_TOKEN,
|
50 |
+
)
|
51 |
+
lemmatizer_model = AutoModelForSeq2SeqLM.from_pretrained(
|
52 |
+
MODEL_PATHS["LEMMATIZATION"], use_auth_token=AUTH_TOKEN
|
53 |
+
)
|
54 |
+
|
55 |
+
# Dependency Parsing
|
56 |
+
dependency_tokenizer = AutoTokenizer.from_pretrained(
|
57 |
+
MODEL_PATHS["DEPENDENCY"],
|
58 |
+
model_max_length=MODEL_MAX_LENGTH,
|
59 |
+
use_auth_token=AUTH_TOKEN,
|
60 |
+
)
|
61 |
+
arcs_model = DependencyRobertaForTokenClassification.from_pretrained(
|
62 |
+
MODEL_PATHS["DEPENDENCY"], use_auth_token=AUTH_TOKEN
|
63 |
+
)
|
64 |
+
labels_model = LabelRobertaForTokenClassification.from_pretrained(
|
65 |
+
MODEL_PATHS["LABELS"], use_auth_token=AUTH_TOKEN
|
66 |
+
)
|
67 |
+
|
68 |
+
data_collator = DataCollatorForTokenClassification(dependency_tokenizer)
|
69 |
+
|
70 |
+
|
71 |
+
def is_valid_selection(col_arcs, col_labels) -> bool:
|
72 |
+
if not col_arcs and col_labels:
|
73 |
+
return False
|
74 |
+
return True
|
75 |
+
|
76 |
+
|
77 |
+
def get_pos_predictions(inputs) -> torch.Tensor:
|
78 |
+
"""Get part of speech predictions."""
|
79 |
+
return pos_model(inputs["input_ids"]).logits.argmax(-1) # type: ignore
|
80 |
+
|
81 |
+
|
82 |
+
def execute_parse(
|
83 |
+
text_input: str,
|
84 |
+
col_pos: bool,
|
85 |
+
col_arcs: bool,
|
86 |
+
col_labels: bool,
|
87 |
+
col_lemmata: bool,
|
88 |
+
compact: bool,
|
89 |
+
bg: str,
|
90 |
+
text: str,
|
91 |
+
) -> Tuple[str, str]:
|
92 |
+
if is_valid_selection(col_arcs, col_labels):
|
93 |
+
return parse(
|
94 |
+
text_input, col_pos, col_arcs, col_labels, col_lemmata, compact, bg, text
|
95 |
+
)
|
96 |
+
return "Please check 'Dependency Arcs' before checking 'Dependency Labels'", ""
|
97 |
+
|
98 |
+
|
99 |
+
def lemmatize(tokens: List[str]) -> List[str]:
|
100 |
+
def construct_task(word_idx: int) -> str:
|
101 |
+
return f"lemmatize: {' '.join(tokens[:word_idx])} <extra_id_0> {tokens[word_idx]} <extra_id_1> {' '.join(list(tokens[word_idx]))} <extra_id_2> {' '.join(tokens[word_idx+1:])}"
|
102 |
+
|
103 |
+
predictions = [
|
104 |
+
lemmatizer_tokenizer.decode(
|
105 |
+
lemmatizer_model.generate(
|
106 |
+
lemmatizer_tokenizer(construct_task(word_idx), return_tensors="pt")[
|
107 |
+
"input_ids"
|
108 |
+
],
|
109 |
+
max_length=20,
|
110 |
+
num_beams=5,
|
111 |
+
num_return_sequences=1,
|
112 |
+
early_stopping=True,
|
113 |
+
)[0],
|
114 |
+
skip_special_tokens=True,
|
115 |
+
)
|
116 |
+
for word_idx in range(len(tokens))
|
117 |
+
]
|
118 |
+
|
119 |
+
return predictions
|
120 |
+
|
121 |
+
|
122 |
+
def add_lemma_visualization(soup, lemmata: List[str], col_arcs: bool) -> str:
|
123 |
+
for token, lemma in zip(soup.find_all(class_="displacy-token")[col_arcs:], lemmata):
|
124 |
+
pos_tag = token.find(class_="displacy-tag")
|
125 |
+
lemma_tag = soup.new_tag(
|
126 |
+
"tspan",
|
127 |
+
class_="displacy-lemma",
|
128 |
+
dy="2em",
|
129 |
+
fill="currentColor",
|
130 |
+
x=pos_tag.attrs["x"],
|
131 |
+
)
|
132 |
+
lemma_tag.string = lemma
|
133 |
+
pos_tag.insert_after(lemma_tag)
|
134 |
+
return str(soup)
|
135 |
+
|
136 |
+
|
137 |
+
def download_svg(svg):
|
138 |
+
encode = base64.b64encode(bytes(svg, "utf-8"))
|
139 |
+
img = "data:image/svg+xml;base64," + str(encode)[2:-1]
|
140 |
+
html = f'<a download="displacy.svg" href="{img}" style="{BUTTON_CSS}">Download as SVG</a>'
|
141 |
+
return html
|
142 |
+
|
143 |
+
|
144 |
+
def prepare_doc(
|
145 |
+
tokens: List[str], col_pos: bool, pos_outputs: torch.Tensor, inputs: BatchEncoding,
|
146 |
+
) -> Dict[str, List[Dict[str, str]]]:
|
147 |
+
doc: Dict[str, List[Dict[str, str]]] = {
|
148 |
+
"words": [], #[{"text": "ROOT", "tag": ""}],
|
149 |
+
"arcs": [],
|
150 |
+
}
|
151 |
+
word_ids = inputs.word_ids()
|
152 |
+
previous_word_idx = None
|
153 |
+
|
154 |
+
for idx, word_idx in enumerate(word_ids):
|
155 |
+
if word_idx != previous_word_idx and word_idx is not None:
|
156 |
+
tag_repr = (
|
157 |
+
pos_model.config.id2label[pos_outputs[0][idx].item()] if col_pos else ""
|
158 |
+
)
|
159 |
+
doc["words"].append({"text": tokens[word_idx], "tag": tag_repr})
|
160 |
+
previous_word_idx = word_idx
|
161 |
+
|
162 |
+
return doc
|
163 |
+
|
164 |
+
|
165 |
+
def parse(
|
166 |
+
text_input: str,
|
167 |
+
col_pos: bool,
|
168 |
+
col_arcs: bool,
|
169 |
+
col_labels: bool,
|
170 |
+
col_lemmata: bool,
|
171 |
+
compact: bool,
|
172 |
+
bg: str,
|
173 |
+
text: str,
|
174 |
+
) -> Tuple[str, str]:
|
175 |
+
tokens = preprocess_text(text_input)
|
176 |
+
inputs = pos_tokenizer(
|
177 |
+
tokens,
|
178 |
+
return_tensors="pt",
|
179 |
+
truncation=True,
|
180 |
+
padding=True,
|
181 |
+
is_split_into_words=True,
|
182 |
+
)
|
183 |
+
pos_outputs = get_pos_predictions(inputs)
|
184 |
+
|
185 |
+
doc = prepare_doc(tokens, col_pos, pos_outputs, inputs)
|
186 |
+
|
187 |
+
if col_arcs:
|
188 |
+
doc["words"].insert(0, {"text": "ROOT", "tag": ""})
|
189 |
+
doc["arcs"] = get_dependencies(
|
190 |
+
arcs_model,
|
191 |
+
labels_model,
|
192 |
+
dependency_tokenizer,
|
193 |
+
data_collator,
|
194 |
+
col_labels,
|
195 |
+
tokens,
|
196 |
+
)["arcs"]
|
197 |
+
|
198 |
+
options = {"compact": compact, "bg": bg, "color": text}
|
199 |
+
svg = displacy.render(doc, manual=True, style="dep", options=options)
|
200 |
+
|
201 |
+
if col_lemmata:
|
202 |
+
soup = BeautifulSoup(svg, "lxml-xml")
|
203 |
+
lemmata = lemmatize(tokens)
|
204 |
+
svg = add_lemma_visualization(soup, lemmata, col_arcs)
|
205 |
+
|
206 |
+
download_link = download_svg(svg)
|
207 |
+
|
208 |
+
return svg, download_link
|
209 |
+
|
210 |
+
|
211 |
+
def setup_parser_ui():
|
212 |
+
demo = gr.Blocks(css="scrollbar.css")
|
213 |
+
with demo:
|
214 |
+
with gr.Box():
|
215 |
+
with gr.Row():
|
216 |
+
with gr.Column():
|
217 |
+
gr.Markdown("# Athena's Lens")
|
218 |
+
gr.Markdown(
|
219 |
+
"### From Ἀlkaios to Ὠrigen: A Modern Lens on Timeless Texts"
|
220 |
+
)
|
221 |
+
with gr.Box():
|
222 |
+
with gr.Column():
|
223 |
+
gr.Markdown(" ## Enter some text")
|
224 |
+
with gr.Row():
|
225 |
+
with gr.Column(scale=0.5):
|
226 |
+
text_input = gr.Textbox(
|
227 |
+
value=DEFAULT_TEXT, interactive=True, label="Input Text"
|
228 |
+
)
|
229 |
+
with gr.Row():
|
230 |
+
with gr.Column(scale=0.25):
|
231 |
+
button = gr.Button("Update", variant="primary").style(
|
232 |
+
full_width=False
|
233 |
+
)
|
234 |
+
with gr.Box():
|
235 |
+
with gr.Column():
|
236 |
+
with gr.Row():
|
237 |
+
with gr.Column():
|
238 |
+
gr.Markdown("## Parser")
|
239 |
+
with gr.Row():
|
240 |
+
with gr.Column():
|
241 |
+
col_pos = gr.Checkbox(label="PoS Labels", value=True)
|
242 |
+
col_arcs = gr.Checkbox(label="Dependency Arcs", value=False)
|
243 |
+
col_labels = gr.Checkbox(label="Dependency Labels", value=False)
|
244 |
+
col_lemmata = gr.Checkbox(label="Lemmata", value=False)
|
245 |
+
compact = gr.Checkbox(label="Compact", value=False)
|
246 |
+
with gr.Column():
|
247 |
+
bg = gr.Textbox(label="Background Color", value=DEFAULT_COLOR)
|
248 |
+
with gr.Column():
|
249 |
+
text = gr.Textbox(label="Text Color", value="black")
|
250 |
+
with gr.Row():
|
251 |
+
dep_output = gr.HTML(
|
252 |
+
value=parse(
|
253 |
+
DEFAULT_TEXT,
|
254 |
+
True,
|
255 |
+
False,
|
256 |
+
False,
|
257 |
+
False,
|
258 |
+
False,
|
259 |
+
DEFAULT_COLOR,
|
260 |
+
"black",
|
261 |
+
)[0]
|
262 |
+
)
|
263 |
+
with gr.Row():
|
264 |
+
with gr.Column(scale=0.25):
|
265 |
+
dep_button = gr.Button(
|
266 |
+
"Update Parser", variant="primary"
|
267 |
+
).style(full_width=False)
|
268 |
+
with gr.Column():
|
269 |
+
dep_download_button = gr.HTML(
|
270 |
+
value=download_svg(dep_output.value)
|
271 |
+
)
|
272 |
+
|
273 |
+
with gr.Box():
|
274 |
+
with gr.Column():
|
275 |
+
with gr.Row():
|
276 |
+
with gr.Column():
|
277 |
+
gr.Markdown("## Contact")
|
278 |
+
gr.Markdown(
|
279 |
+
"If you have any questions, suggestions, comments, or problems, feel free to [reach out](mailto:[email protected])."
|
280 |
+
)
|
281 |
+
gr.Markdown("## Citation")
|
282 |
+
gr.Markdown(
|
283 |
+
"This space uses models from [this](https://aclanthology.org/2023.acl-long.846.pdf) paper."
|
284 |
+
)
|
285 |
+
gr.Markdown(
|
286 |
+
"""```bibtex
|
287 |
+
@incollection{riemenschneider-frank-2023-exploring,
|
288 |
+
title = "Exploring Large Language Models for Classical Philology",
|
289 |
+
author = "Riemenschneider, Frederick and Frank, Anette",
|
290 |
+
booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
|
291 |
+
month = jul,
|
292 |
+
year = "2023",
|
293 |
+
address = "Toronto, Canada",
|
294 |
+
publisher = "Association for Computational Linguistics",
|
295 |
+
url = "https://aclanthology.org/2023.acl-long.846",
|
296 |
+
doi = "10.18653/v1/2023.acl-long.846",
|
297 |
+
pages = "15181--15199",
|
298 |
+
}
|
299 |
+
```
|
300 |
+
"""
|
301 |
+
)
|
302 |
+
|
303 |
+
button.click(
|
304 |
+
execute_parse,
|
305 |
+
inputs=[
|
306 |
+
text_input,
|
307 |
+
col_pos,
|
308 |
+
col_arcs,
|
309 |
+
col_labels,
|
310 |
+
col_lemmata,
|
311 |
+
compact,
|
312 |
+
bg,
|
313 |
+
text,
|
314 |
+
],
|
315 |
+
outputs=[dep_output, dep_download_button],
|
316 |
+
)
|
317 |
+
|
318 |
+
dep_button.click(
|
319 |
+
execute_parse,
|
320 |
+
inputs=[
|
321 |
+
text_input,
|
322 |
+
col_pos,
|
323 |
+
col_arcs,
|
324 |
+
col_labels,
|
325 |
+
col_lemmata,
|
326 |
+
compact,
|
327 |
+
bg,
|
328 |
+
text,
|
329 |
+
],
|
330 |
+
outputs=[dep_output, dep_download_button],
|
331 |
+
)
|
332 |
+
|
333 |
+
demo.launch()
|
334 |
+
|
335 |
+
|
336 |
+
def main():
|
337 |
+
demo = setup_parser_ui()
|
338 |
+
demo.launch()
|
339 |
+
|
340 |
+
|
341 |
+
if __name__ == "__main__":
|
342 |
+
main()
|
models.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from transformers import RobertaPreTrainedModel
|
4 |
+
from transformers.modeling_outputs import TokenClassifierOutput
|
5 |
+
from transformers.models.roberta.modeling_roberta import RobertaConfig, RobertaModel
|
6 |
+
|
7 |
+
from utils import batched_index_select
|
8 |
+
|
9 |
+
|
10 |
+
class DependencyRobertaForTokenClassification(RobertaPreTrainedModel):
|
11 |
+
config_class = RobertaConfig # type: ignore
|
12 |
+
|
13 |
+
def __init__(self, config):
|
14 |
+
super().__init__(config)
|
15 |
+
self.roberta = RobertaModel(config, add_pooling_layer=False)
|
16 |
+
self.u_a = nn.Linear(768, 768)
|
17 |
+
self.w_a = nn.Linear(768, 768)
|
18 |
+
self.v_a_inv = nn.Linear(768, 1, bias=False)
|
19 |
+
self.criterion = nn.NLLLoss()
|
20 |
+
self.init_weights()
|
21 |
+
|
22 |
+
def forward(
|
23 |
+
self,
|
24 |
+
input_ids=None,
|
25 |
+
attention_mask=None,
|
26 |
+
token_type_ids=None,
|
27 |
+
labels=None,
|
28 |
+
**kwargs,
|
29 |
+
):
|
30 |
+
loss = 0.0
|
31 |
+
output = self.roberta(
|
32 |
+
input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
|
33 |
+
)[0]
|
34 |
+
batch_size, seq_len, _ = output.size()
|
35 |
+
|
36 |
+
parent_prob_table = []
|
37 |
+
for i in range(0, seq_len):
|
38 |
+
target = output[:, i, :].expand(seq_len, batch_size, -1).transpose(0, 1)
|
39 |
+
mask = output.eq(target)[:, :, 0].unsqueeze(2)
|
40 |
+
p_head = self.attention(output, target, mask)
|
41 |
+
if labels is not None:
|
42 |
+
current_loss = self.criterion(p_head.squeeze(-1), labels[:, i])
|
43 |
+
if not torch.all(labels[:, i] == -100):
|
44 |
+
loss += current_loss
|
45 |
+
parent_prob_table.append(torch.exp(p_head))
|
46 |
+
|
47 |
+
parent_prob_table = torch.cat((parent_prob_table), dim=2).data.transpose(1, 2)
|
48 |
+
prob, topi = parent_prob_table.topk(k=1, dim=2)
|
49 |
+
preds = topi.squeeze(-1)
|
50 |
+
loss = loss / seq_len
|
51 |
+
output = TokenClassifierOutput(loss=loss, logits=preds)
|
52 |
+
|
53 |
+
if labels is not None:
|
54 |
+
return output, preds, parent_prob_table, labels
|
55 |
+
else:
|
56 |
+
return output, preds, parent_prob_table
|
57 |
+
|
58 |
+
def attention(self, source, target, mask=None):
|
59 |
+
function_g = self.v_a_inv(torch.tanh(self.u_a(source) + self.w_a(target)))
|
60 |
+
if mask is not None:
|
61 |
+
function_g.masked_fill_(mask, -1e4)
|
62 |
+
return nn.functional.log_softmax(function_g, dim=1)
|
63 |
+
|
64 |
+
|
65 |
+
class LabelRobertaForTokenClassification(RobertaPreTrainedModel):
|
66 |
+
config_class = RobertaConfig # type: ignore
|
67 |
+
|
68 |
+
def __init__(self, config):
|
69 |
+
super().__init__(config)
|
70 |
+
self.roberta = RobertaModel(config, add_pooling_layer=False)
|
71 |
+
self.num_labels = 33
|
72 |
+
self.hidden = nn.Linear(768 * 2, 768)
|
73 |
+
self.relu = nn.ReLU()
|
74 |
+
self.out = nn.Linear(768, self.num_labels)
|
75 |
+
self.loss_fct = nn.CrossEntropyLoss()
|
76 |
+
|
77 |
+
def forward(
|
78 |
+
self,
|
79 |
+
input_ids=None,
|
80 |
+
attention_mask=None,
|
81 |
+
token_type_ids=None,
|
82 |
+
labels=None,
|
83 |
+
**kwargs,
|
84 |
+
):
|
85 |
+
loss = 0.0
|
86 |
+
output = self.roberta(
|
87 |
+
input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
|
88 |
+
)[0]
|
89 |
+
batch_size, seq_len, _ = output.size()
|
90 |
+
logits = []
|
91 |
+
for i in range(seq_len):
|
92 |
+
current_token = output[:, i, :]
|
93 |
+
connected_with_index = kwargs["head_labels"][:, i]
|
94 |
+
connected_with_index[connected_with_index == -100] = 0
|
95 |
+
connected_with_embedding = batched_index_select(
|
96 |
+
output.clone(), 1, connected_with_index.clone()
|
97 |
+
)
|
98 |
+
combined_embeddings = torch.cat(
|
99 |
+
(current_token, connected_with_embedding.squeeze(1)), -1
|
100 |
+
)
|
101 |
+
pred = self.out(self.relu(self.hidden(combined_embeddings)))
|
102 |
+
pred = pred.view(-1, self.num_labels)
|
103 |
+
logits.append(pred)
|
104 |
+
if labels is not None:
|
105 |
+
current_loss = self.loss_fct(pred, labels[:, i].view(-1))
|
106 |
+
if not torch.all(labels[:, i] == -100):
|
107 |
+
loss += current_loss
|
108 |
+
|
109 |
+
loss = loss / seq_len
|
110 |
+
logits = torch.stack(logits, dim=1)
|
111 |
+
output = TokenClassifierOutput(loss=loss, logits=logits)
|
112 |
+
return output
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pandas==1.4.2
|
2 |
+
gradio==3.3.1
|
3 |
+
beautifulsoup4
|
4 |
+
lxml
|
5 |
+
ufal.chu-liu-edmonds
|
6 |
+
spacy
|
7 |
+
transformers
|
8 |
+
torch
|
scrollbar.css
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.output-html {
|
2 |
+
overflow-x: auto;
|
3 |
+
}
|
4 |
+
|
5 |
+
.output-html::-webkit-scrollbar {
|
6 |
+
-webkit-appearance: none;
|
7 |
+
}
|
8 |
+
|
9 |
+
.output-html::-webkit-scrollbar:vertical {
|
10 |
+
width: 0px;
|
11 |
+
}
|
12 |
+
|
13 |
+
.output-html::-webkit-scrollbar:horizontal {
|
14 |
+
height: 11px;
|
15 |
+
}
|
16 |
+
|
17 |
+
.output-html::-webkit-scrollbar-thumb {
|
18 |
+
border-radius: 8px;
|
19 |
+
border: 2px solid white;
|
20 |
+
background-color: rgba(0, 0, 0, .5);
|
21 |
+
}
|
22 |
+
|
23 |
+
.output-html::-webkit-scrollbar-track {
|
24 |
+
background-color: #fff;
|
25 |
+
border-radius: 8px;
|
26 |
+
}
|
27 |
+
|
28 |
+
.spans {
|
29 |
+
min-height: 75px;
|
30 |
+
}
|
utils.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from typing import List, Dict, Set
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from ufal.chu_liu_edmonds import chu_liu_edmonds
|
6 |
+
|
7 |
+
DEPENDENCY_RELATIONS = [
|
8 |
+
"acl",
|
9 |
+
"advcl",
|
10 |
+
"advmod",
|
11 |
+
"amod",
|
12 |
+
"appos",
|
13 |
+
"aux",
|
14 |
+
"case",
|
15 |
+
"cc",
|
16 |
+
"ccomp",
|
17 |
+
"conj",
|
18 |
+
"cop",
|
19 |
+
"csubj",
|
20 |
+
"det",
|
21 |
+
"iobj",
|
22 |
+
"mark",
|
23 |
+
"nmod",
|
24 |
+
"nsubj",
|
25 |
+
"nummod",
|
26 |
+
"obj",
|
27 |
+
"obl",
|
28 |
+
"parataxis",
|
29 |
+
"punct",
|
30 |
+
"root",
|
31 |
+
"vocative",
|
32 |
+
"xcomp",
|
33 |
+
]
|
34 |
+
INDEX2TAG = {idx: tag for idx, tag in enumerate(DEPENDENCY_RELATIONS)}
|
35 |
+
TAG2INDEX = {tag: idx for idx, tag in enumerate(DEPENDENCY_RELATIONS)}
|
36 |
+
|
37 |
+
|
38 |
+
def preprocess_text(text: str) -> List[str]:
|
39 |
+
text = text.strip()
|
40 |
+
text = re.sub("(?<! )(?=[.,!?()·;:])|(?<=[.,!?()·;:])(?! )", r" ", text)
|
41 |
+
return text.split()
|
42 |
+
|
43 |
+
|
44 |
+
def batched_index_select(
|
45 |
+
input: torch.Tensor, dim: int, index: torch.Tensor
|
46 |
+
) -> torch.Tensor:
|
47 |
+
views = [input.shape[0]] + [
|
48 |
+
1 if i != dim else -1 for i in range(1, len(input.shape))
|
49 |
+
]
|
50 |
+
expanse = list(input.shape)
|
51 |
+
expanse[0] = -1
|
52 |
+
expanse[dim] = -1
|
53 |
+
index = index.view(views).expand(expanse)
|
54 |
+
return torch.gather(input, dim, index)
|
55 |
+
|
56 |
+
|
57 |
+
def get_relevant_tokens(tokenized: torch.Tensor, start_ids: Set[int]) -> List[int]:
|
58 |
+
return [tokenized[idx].item() for idx in range(len(tokenized)) if idx in start_ids]
|
59 |
+
|
60 |
+
|
61 |
+
def resolve(
|
62 |
+
edmonds_head: List[int], word_ids: List[int], parent_probs_table: torch.Tensor
|
63 |
+
) -> torch.Tensor:
|
64 |
+
multiple_roots = [i for i, x in enumerate(edmonds_head) if x == 0]
|
65 |
+
if len(multiple_roots) > 1:
|
66 |
+
main_root = max(multiple_roots, key=edmonds_head.count)
|
67 |
+
secondary_roots = set(multiple_roots) - {main_root}
|
68 |
+
for root in secondary_roots:
|
69 |
+
parent_probs_table[0][word_ids.index(root)][0] = 0
|
70 |
+
return parent_probs_table
|
71 |
+
|
72 |
+
|
73 |
+
def apply_chu_liu_edmonds(
|
74 |
+
parent_probs_table: torch.Tensor, tokenized_input: Dict, start_ids: Set[int]
|
75 |
+
) -> List[int]:
|
76 |
+
parent_probs_table = (
|
77 |
+
parent_probs_table
|
78 |
+
if parent_probs_table.shape[1] == parent_probs_table.shape[2]
|
79 |
+
else parent_probs_table[:, :, 1:]
|
80 |
+
)
|
81 |
+
edmonds_heads, _ = chu_liu_edmonds(
|
82 |
+
parent_probs_table.squeeze(0).cpu().numpy().astype("double")
|
83 |
+
)
|
84 |
+
edmonds_heads = torch.tensor(edmonds_heads).unsqueeze(0)
|
85 |
+
edmonds_heads[edmonds_heads == -1] = 0
|
86 |
+
tokenized_input["head_labels"] = edmonds_heads
|
87 |
+
return get_relevant_tokens(edmonds_heads[0], start_ids)
|
88 |
+
|
89 |
+
|
90 |
+
def get_word_endings(tokenized_input):
|
91 |
+
word_ids = tokenized_input.word_ids(batch_index=0)
|
92 |
+
start_ids = set()
|
93 |
+
word_endings = {0: (1, 0)}
|
94 |
+
for word_id in word_ids:
|
95 |
+
if word_id is not None:
|
96 |
+
start, end = tokenized_input.word_to_tokens(
|
97 |
+
batch_or_word_index=0, word_index=word_id
|
98 |
+
)
|
99 |
+
start_ids.add(start)
|
100 |
+
word_endings[start] = (end, word_id + 1)
|
101 |
+
for a in range(start + 1, end + 1):
|
102 |
+
word_endings[a] = (end, word_id + 1)
|
103 |
+
return word_endings, start_ids, word_ids
|
104 |
+
|
105 |
+
|
106 |
+
def get_dependencies(
|
107 |
+
dependency_parser,
|
108 |
+
label_parser,
|
109 |
+
tokenizer,
|
110 |
+
collator,
|
111 |
+
labels: bool,
|
112 |
+
sentence: List[str],
|
113 |
+
) -> Dict:
|
114 |
+
tokenized_input = tokenizer(
|
115 |
+
sentence, truncation=True, is_split_into_words=True, add_special_tokens=True
|
116 |
+
)
|
117 |
+
dep_dict: Dict[str, List[Dict[str, str]]] = {
|
118 |
+
"words": [{"text": "ROOT", "tag": ""}],
|
119 |
+
"arcs": [],
|
120 |
+
}
|
121 |
+
|
122 |
+
word_endings, start_ids, word_ids = get_word_endings(tokenized_input)
|
123 |
+
tokenized_input = collator([tokenized_input])
|
124 |
+
_, _, parent_probs_table = dependency_parser(**tokenized_input)
|
125 |
+
|
126 |
+
irrelevant = torch.tensor(
|
127 |
+
[
|
128 |
+
idx.item()
|
129 |
+
for idx in torch.arange(parent_probs_table.size(1))
|
130 |
+
if idx.item() not in start_ids and idx.item() != 0
|
131 |
+
]
|
132 |
+
)
|
133 |
+
if irrelevant.nelement() > 0:
|
134 |
+
parent_probs_table.index_fill_(1, irrelevant, torch.nan)
|
135 |
+
parent_probs_table.index_fill_(2, irrelevant, torch.nan)
|
136 |
+
|
137 |
+
edmonds_head = apply_chu_liu_edmonds(parent_probs_table, tokenized_input, start_ids)
|
138 |
+
parent_probs_table = resolve(edmonds_head, word_ids, parent_probs_table)
|
139 |
+
edmonds_head = apply_chu_liu_edmonds(parent_probs_table, tokenized_input, start_ids)
|
140 |
+
|
141 |
+
if labels:
|
142 |
+
predictions_labels = np.argmax(
|
143 |
+
label_parser(**tokenized_input).logits.detach().cpu().numpy(), axis=-1
|
144 |
+
)
|
145 |
+
predicted_relations = get_relevant_tokens(predictions_labels[0], start_ids)
|
146 |
+
predicted_relations = [
|
147 |
+
INDEX2TAG[predicted_relations[idx]] for idx in range(len(sentence))
|
148 |
+
]
|
149 |
+
else:
|
150 |
+
predicted_relations = [""] * len(sentence)
|
151 |
+
|
152 |
+
for idx, head in enumerate(edmonds_head):
|
153 |
+
arc = {
|
154 |
+
"start": min(idx + 1, word_endings[head][1]),
|
155 |
+
"end": max(idx + 1, word_endings[head][1]),
|
156 |
+
"label": predicted_relations[idx],
|
157 |
+
"dir": "left" if idx + 1 < word_endings[head][1] else "right",
|
158 |
+
}
|
159 |
+
dep_dict["arcs"].append(arc)
|
160 |
+
|
161 |
+
return dep_dict
|