"""Gradio demo that renders PoS tags, lemmata and dependency arcs as a displaCy SVG.

The module loads four Hugging Face checkpoints at import time (PoS tagger,
seq2seq lemmatizer, dependency-arc model, dependency-label model) and exposes
``setup_parser_ui`` / ``main`` to build and launch the interface.
"""

import base64
import os
from typing import Tuple, List, Union, Dict, Mapping

from bs4 import BeautifulSoup
import gradio as gr
from spacy import displacy
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    BatchEncoding,
    AutoModelForSeq2SeqLM,
    DataCollatorForTokenClassification,
)
import torch

from utils import get_dependencies, preprocess_text
from models import (
    DependencyRobertaForTokenClassification,
    LabelRobertaForTokenClassification,
)

DEFAULT_TEXT = "τίω δέ μιν ἐν καρὸς αἴσῃ."
# Inline style applied to the "Download as SVG" anchor so it looks like a button.
BUTTON_CSS = "float: right; --tw-border-opacity: 1; border-color: rgb(229 231 235 / var(--tw-border-opacity)); --tw-gradient-from: rgb(243 244 246 / 0.7); --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to, rgb(243 244 246 / 0)); --tw-gradient-to: rgb(229 231 235 / 0.8); --tw-text-opacity: 1; color: rgb(55 65 81 / var(--tw-text-opacity)); border-width: 1px; --tw-bg-opacity: 1; background-color: rgb(255 255 255 / var(--tw-bg-opacity)); background-image: linear-gradient(to bottom right, var(--tw-gradient-stops)); display: inline-flex; flex: 1 1 0%; align-items: center; justify-content: center; --tw-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05); --tw-shadow-colored: 0 1px 2px 0 var(--tw-shadow-color); box-shadow: var(--tw-ring-offset-shadow, 0 0 #0000), var(--tw-ring-shadow, 0 0 #0000), var(--tw-shadow); -webkit-appearance: button; border-radius: 0.5rem; padding-top: 0.5rem; padding-bottom: 0.5rem; padding-left: 1rem; padding-right: 1rem; font-size: 1rem; line-height: 1.5rem; font-weight: 600;"
DEFAULT_COLOR = "white"
MODEL_PATHS = {
    "POS": "bowphs/testid",
    "LEMMATIZATION": "bowphs/lemmatization-demo",
    "DEPENDENCY": "bowphs/depenBERTa_perseus",
    "LABELS": "bowphs/depenBERTa_labler_perseus",
}
MODEL_MAX_LENGTH = 512
# Pinned commit of the PoS checkpoint; tokenizer and model must use the same one.
POS_REVISION = "8bd84df2bcaee089307fd604c80139a34ac71f12"
# Falls back to True when the TOKEN environment variable is unset or empty.
AUTH_TOKEN = os.environ.get("TOKEN") or True

# Markdown shown in the "Citation" section. The fenced block needs real line
# breaks, otherwise the code fence does not render.
CITATION_BIBTEX = """```bibtex
@incollection{riemenschneider-frank-2023-exploring,
    title = "Exploring Large Language Models for Classical Philology",
    author = "Riemenschneider, Frederick and Frank, Anette",
    booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
    month = jul,
    year = "2023",
    address = "Toronto, Canada",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2023.acl-long.846",
    doi = "10.18653/v1/2023.acl-long.846",
    pages = "15181--15199",
}
```
"""

# PoS
pos_tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATHS["POS"],
    model_max_length=MODEL_MAX_LENGTH,
    use_auth_token=AUTH_TOKEN,
    revision=POS_REVISION,
)
pos_model = AutoModelForTokenClassification.from_pretrained(
    MODEL_PATHS["POS"],
    use_auth_token=AUTH_TOKEN,
    revision=POS_REVISION,
)

# Lemmatization
lemmatizer_tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATHS["LEMMATIZATION"],
    model_max_length=MODEL_MAX_LENGTH,
    use_auth_token=AUTH_TOKEN,
)
lemmatizer_model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_PATHS["LEMMATIZATION"], use_auth_token=AUTH_TOKEN
)

# Dependency Parsing
dependency_tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATHS["DEPENDENCY"],
    model_max_length=MODEL_MAX_LENGTH,
    use_auth_token=AUTH_TOKEN,
)
arcs_model = DependencyRobertaForTokenClassification.from_pretrained(
    MODEL_PATHS["DEPENDENCY"], use_auth_token=AUTH_TOKEN
)
labels_model = LabelRobertaForTokenClassification.from_pretrained(
    MODEL_PATHS["LABELS"], use_auth_token=AUTH_TOKEN
)
data_collator = DataCollatorForTokenClassification(dependency_tokenizer)


def is_valid_selection(col_arcs, col_labels) -> bool:
    """Return False only when labels are requested without arcs."""
    return bool(col_arcs) or not col_labels


def get_pos_predictions(inputs) -> torch.Tensor:
    """Get part of speech predictions.

    Returns the arg-max label id over the last logits dimension for the
    ``input_ids`` in *inputs*.
    """
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = pos_model(inputs["input_ids"]).logits
    return logits.argmax(-1)  # type: ignore


def execute_parse(
    text_input: str,
    col_pos: bool,
    col_arcs: bool,
    col_labels: bool,
    col_lemmata: bool,
    compact: bool,
    bg: str,
    text: str,
) -> Tuple[str, str]:
    """Validate the checkbox selection, then delegate to ``parse``.

    Returns:
        ``(svg_markup, download_link_html)`` on success, or an error message
        and an empty string when labels are selected without arcs.
    """
    if not is_valid_selection(col_arcs, col_labels):
        return "Please check 'Dependency Arcs' before checking 'Dependency Labels'", ""
    return parse(
        text_input, col_pos, col_arcs, col_labels, col_lemmata, compact, bg, text
    )


def lemmatize(tokens: List[str]) -> List[str]:
    """Predict one lemma per token with the seq2seq lemmatizer.

    Each token gets its own prompt made of the left context, the token, the
    token spelled out character by character, and the right context.
    """

    def construct_task(word_idx: int) -> str:
        # NOTE(review): the prompt layout is reproduced as found; if the
        # checkpoint expects marker tokens around the target word, confirm
        # they were not lost from this string.
        left = " ".join(tokens[:word_idx])
        word = tokens[word_idx]
        spelled = " ".join(word)
        right = " ".join(tokens[word_idx + 1:])
        return f"lemmatize: {left} {word} {spelled} {right}"

    lemmata = []
    for word_idx in range(len(tokens)):
        input_ids = lemmatizer_tokenizer(
            construct_task(word_idx), return_tensors="pt"
        )["input_ids"]
        generated = lemmatizer_model.generate(
            input_ids,
            max_length=20,
            num_beams=5,
            num_return_sequences=1,
            early_stopping=True,
        )
        lemmata.append(
            lemmatizer_tokenizer.decode(generated[0], skip_special_tokens=True)
        )
    return lemmata


def add_lemma_visualization(soup, lemmata: List[str], col_arcs: bool) -> str:
    """Insert a lemma ``<tspan>`` after each token's tag element and return the markup.

    Args:
        soup: Parsed displaCy SVG.
        lemmata: One lemma per input token, in order.
        col_arcs: Whether arcs are shown; ``parse`` then prepends a ROOT word,
            which has no lemma and is skipped here.
    """
    first_real_token = 1 if col_arcs else 0
    token_nodes = soup.find_all(class_="displacy-token")[first_real_token:]
    for token, lemma in zip(token_nodes, lemmata):
        pos_tag = token.find(class_="displacy-tag")
        if pos_tag is None:
            # Nothing to anchor the lemma to; leave this token untouched.
            continue
        # ``class`` must go through ``attrs``: a ``class_=`` keyword to
        # ``new_tag`` would create an attribute literally named "class_".
        lemma_tag = soup.new_tag(
            "tspan",
            attrs={
                "class": "displacy-lemma",
                "dy": "2em",
                "fill": "currentColor",
                "x": pos_tag.attrs["x"],
            },
        )
        lemma_tag.string = lemma
        pos_tag.insert_after(lemma_tag)
    return str(soup)


def download_svg(svg):
    """Return an HTML anchor that downloads *svg* via a base64 data URI."""
    encoded = base64.b64encode(svg.encode("utf-8")).decode("ascii")
    img = "data:image/svg+xml;base64," + encoded
    # The anchor markup is required for the link to be clickable; the label
    # alone would render as plain text.
    html = f'<a download="displacy.svg" href="{img}" style="{BUTTON_CSS}">Download as SVG</a>'
    return html


def prepare_doc(
    tokens: List[str],
    col_pos: bool,
    pos_outputs: torch.Tensor,
    inputs: BatchEncoding,
) -> Dict[str, List[Dict[str, str]]]:
    """Build the manual displaCy document (``words`` filled, ``arcs`` empty).

    One entry is emitted per word, using the prediction at the word's first
    sub-token. The tag is an empty string when *col_pos* is false.
    """
    doc: Dict[str, List[Dict[str, str]]] = {"words": [], "arcs": []}
    previous_word_idx = None
    for idx, word_idx in enumerate(inputs.word_ids()):
        # Skip special tokens (None) and continuation sub-tokens of a word.
        if word_idx is not None and word_idx != previous_word_idx:
            if col_pos:
                tag_repr = pos_model.config.id2label[pos_outputs[0][idx].item()]
            else:
                tag_repr = ""
            doc["words"].append({"text": tokens[word_idx], "tag": tag_repr})
        previous_word_idx = word_idx
    return doc


def parse(
    text_input: str,
    col_pos: bool,
    col_arcs: bool,
    col_labels: bool,
    col_lemmata: bool,
    compact: bool,
    bg: str,
    text: str,
) -> Tuple[str, str]:
    """Run the selected analyses on *text_input* and render them.

    Args:
        text_input: Raw text; tokenized by ``preprocess_text``.
        col_pos: Show PoS tags.
        col_arcs: Show dependency arcs (adds a leading ROOT word).
        col_labels: Passed to ``get_dependencies``.
        col_lemmata: Add a lemma line under each token.
        compact: displaCy ``compact`` option.
        bg: displaCy background colour.
        text: displaCy text colour.

    Returns:
        ``(svg_markup, download_link_html)``.
    """
    tokens = preprocess_text(text_input)
    inputs = pos_tokenizer(
        tokens,
        return_tensors="pt",
        truncation=True,
        padding=True,
        is_split_into_words=True,
    )
    pos_outputs = get_pos_predictions(inputs)
    doc = prepare_doc(tokens, col_pos, pos_outputs, inputs)

    if col_arcs:
        doc["words"].insert(0, {"text": "ROOT", "tag": ""})
        doc["arcs"] = get_dependencies(
            arcs_model,
            labels_model,
            dependency_tokenizer,
            data_collator,
            col_labels,
            tokens,
        )["arcs"]

    options = {"compact": compact, "bg": bg, "color": text}
    svg = displacy.render(doc, manual=True, style="dep", options=options)

    if col_lemmata:
        soup = BeautifulSoup(svg, "lxml-xml")
        lemmata = lemmatize(tokens)
        svg = add_lemma_visualization(soup, lemmata, col_arcs)

    download_link = download_svg(svg)
    return svg, download_link


def setup_parser_ui():
    """Build the Gradio Blocks interface and return it (not yet launched)."""
    theme = gr.themes.Monochrome()
    with gr.Blocks(theme=theme) as demo:
        with gr.Group():
            gr.Markdown("# Athena's Lens")
            gr.Markdown(
                "### From Ἀlkaios to Ὠrigen: A Modern Lens on Timeless Texts"
            )

        with gr.Group():
            gr.Markdown("## Enter some text")
            with gr.Row():
                with gr.Column(scale=0.5):
                    text_input = gr.Textbox(
                        value=DEFAULT_TEXT, interactive=True, label="Input Text"
                    )
            with gr.Row():
                with gr.Column(scale=0.25):
                    button = gr.Button("Update", variant="primary")

        with gr.Group():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("## Parser")
                with gr.Row():
                    with gr.Column():
                        col_pos = gr.Checkbox(label="PoS Labels", value=True)
                        col_arcs = gr.Checkbox(label="Dependency Arcs", value=False)
                        col_labels = gr.Checkbox(
                            label="Dependency Labels", value=False
                        )
                        col_lemmata = gr.Checkbox(label="Lemmata", value=False)
                        compact = gr.Checkbox(label="Compact", value=False)
                    with gr.Column():
                        bg = gr.Textbox(label="Background Color", value=DEFAULT_COLOR)
                    with gr.Column():
                        text = gr.Textbox(label="Text Color", value="black")

        with gr.Group():
            # Render the default sentence once and reuse it for both widgets.
            initial_svg = parse(
                DEFAULT_TEXT,
                True,
                False,
                False,
                False,
                False,
                DEFAULT_COLOR,
                "black",
            )[0]
            with gr.Row():
                dep_output = gr.HTML(value=initial_svg)
            with gr.Row():
                with gr.Column():
                    dep_button = gr.Button("Update Parser", variant="primary")
                with gr.Column():
                    dep_download_button = gr.HTML(value=download_svg(initial_svg))

        with gr.Group():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("## Contact")
                        gr.Markdown(
                            "If you have any questions, suggestions, comments, or problems, feel free to [reach out](mailto:riemenschneider@cl.uni-heidelberg.de)."
                        )
                        gr.Markdown("## Citation")
                        gr.Markdown(
                            "This space uses models from [this](https://aclanthology.org/2023.acl-long.846.pdf) paper."
                        )
                        gr.Markdown(CITATION_BIBTEX)

        # Both buttons trigger the same parse with the same widgets.
        parse_inputs = [
            text_input,
            col_pos,
            col_arcs,
            col_labels,
            col_lemmata,
            compact,
            bg,
            text,
        ]
        parse_outputs = [dep_output, dep_download_button]
        for trigger in (button, dep_button):
            trigger.click(execute_parse, inputs=parse_inputs, outputs=parse_outputs)

    return demo


def main():
    """Build the interface and start the Gradio server."""
    demo = setup_parser_ui()
    demo.launch()


if __name__ == "__main__":
    main()