Spaces:
Runtime error
Runtime error
import streamlit as st | |
import pandas as pd | |
from pipelines.keyphrase_extraction_pipeline import KeyphraseExtractionPipeline | |
from pipelines.keyphrase_generation_pipeline import KeyphraseGenerationPipeline | |
import orjson | |
from annotated_text.util import get_annotated_html | |
from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode | |
import re | |
import string | |
import numpy as np | |
def load_pipeline(chosen_model):
    """Instantiate the keyphrase pipeline matching ``chosen_model``.

    The pipeline class is chosen from the model name:

    * names containing ``"keyphrase-extraction"`` get a ``KeyphraseExtractionPipeline``;
    * names containing ``"keyphrase-generation"`` get a ``KeyphraseGenerationPipeline``.

    Args:
        chosen_model: Model identifier passed to the pipeline constructor
            (the script passes ``"<model_author>/<model name>"``).

    Returns:
        The constructed pipeline.

    Raises:
        ValueError: If the name matches neither pipeline type. Raising here
            replaces an implicit ``None``, which only failed later when the
            pipeline was called.
    """
    if "keyphrase-extraction" in chosen_model:
        return KeyphraseExtractionPipeline(chosen_model)
    if "keyphrase-generation" in chosen_model:
        return KeyphraseGenerationPipeline(chosen_model)
    raise ValueError(
        f"Unsupported model {chosen_model!r}: expected 'keyphrase-extraction' "
        "or 'keyphrase-generation' in its name."
    )
def extract_keyphrases():
    """Run the loaded pipeline on the current input text and record the result.

    Stores the pipeline output in ``st.session_state.keyphrases``. It also
    appends a row ``[model, text, keyphrase_0, keyphrase_1, ...]`` to the
    history table ``st.session_state.data_frame``.

    Keyphrase columns are named ``"0"``, ``"1"``, .... Cells missing because a
    row has fewer keyphrases than others are filled with ``""``.

    Relies on the module-level ``pipe`` created in the form section of the
    script. It also relies on ``st.session_state.input_text`` and
    ``st.session_state.chosen_model`` being set.
    """
    st.session_state.keyphrases = pipe(st.session_state.input_text)
    keyphrases = list(st.session_state.keyphrases)

    # Build the row as a plain list. np.concatenate would coerce every value
    # into a single NumPy dtype, which is fragile when the pipeline returns an
    # empty (possibly non-string) array.
    row = [st.session_state.chosen_model, st.session_state.input_text, *keyphrases]
    columns = ["model", "text"] + [str(i) for i in range(len(keyphrases))]

    st.session_state.data_frame = pd.concat(
        [st.session_state.data_frame, pd.DataFrame(data=[row], columns=columns)],
        ignore_index=True,
        axis=0,
    ).fillna("")
def get_annotated_text(text, keyphrases):
    """Split ``text`` into segments for ``annotated_text.util.get_annotated_html``.

    For each keyphrase, the first case-insensitive occurrence that is not
    directly preceded or followed by an ASCII letter is highlighted. The
    space-separated word containing it is emitted as a
    ``(word, "KEY", "#21c354")`` tuple, in which the matched span is replaced by
    the keyphrase itself. Every other word is emitted as a plain string padded
    with spaces, so words stay separated in the rendered HTML.

    Args:
        text: Text to annotate.
        keyphrases: Sequence of keyphrase strings; empty strings are ignored.

    Returns:
        A list of plain strings and ``(str, "KEY", "#21c354")`` tuples.
    """
    # Each matched keyphrase is temporarily replaced by one private-use
    # character that encodes its index. Such a character:
    #   * contains no space, so it survives the split below;
    #   * contains no letters or digits, so later keyphrases cannot match inside
    #     it and trailing digits in the text cannot corrupt the index;
    #   * does not occur in ordinary text.
    marker_base = 0xF0000  # start of Unicode Supplementary Private Use Area-A
    marker_re = re.compile(f"[{chr(marker_base)}-{chr(0xFFFFD)}]")

    for index, keyphrase in enumerate(keyphrases):
        if not keyphrase:
            continue
        text = re.sub(
            # re.escape: keyphrases such as "c++" must match literally.
            # Lookarounds: no match inside a word, but still match at text end.
            rf"(?<![A-Za-z]){re.escape(keyphrase)}(?![A-Za-z])",
            chr(marker_base + index),
            text,
            count=1,
            flags=re.I,
        )

    words = text.split(" ")
    last_index = len(words) - 1
    segments = []
    for position, word in enumerate(words):
        if marker_re.search(word):
            # Restore every marker in the word to its own keyphrase.
            highlighted = marker_re.sub(
                lambda match: keyphrases[ord(match.group()) - marker_base], word
            )
            segments.append((highlighted, "KEY", "#21c354"))
        elif position == last_index:
            segments.append(f" {word}")
        elif position == 0:
            segments.append(f"{word} ")
        else:
            segments.append(f" {word} ")
    return segments
def rerender_output(layout):
    """Render the annotated output section into ``layout``.

    Shows the latest extraction (``st.session_state.input_text`` with
    ``st.session_state.keyphrases``) when it exists and no history row is
    selected. Otherwise shows the first row of ``st.session_state.selected_rows``.
    For keyphrase-generation models, keyphrases that do not occur in the text
    are additionally listed below the annotated text.

    Args:
        layout: Object offering Streamlit's ``write``/``markdown`` methods,
            e.g. the ``st`` module or a container.
    """
    layout.write("βοΈ Output")
    if (
        len(st.session_state.keyphrases) > 0
        and len(st.session_state.selected_rows) == 0
    ):
        text = st.session_state.input_text
        keyphrases = list(st.session_state.keyphrases)
        model = st.session_state.chosen_model
    else:
        row = st.session_state.selected_rows.iloc[0]
        text = row["text"]
        # Use the model recorded with the history row, not the one currently
        # chosen in the form, to decide whether to list abstractive keyphrases.
        model = str(row["model"])
        # Keyphrase columns are named "0", "1", ... Order them numerically
        # (a lexicographic sort puts "10" before "2") and skip any other
        # columns the grid may return.
        keyphrase_columns = sorted(
            (column for column in row.index if str(column).isdigit()), key=int
        )
        keyphrases = [
            value
            for value in (str(row[column]) for column in keyphrase_columns)
            if value != ""
        ]

    layout.markdown(
        get_annotated_html(*get_annotated_text(text, keyphrases)),
        unsafe_allow_html=True,
    )
    if "generation" in model:
        # Generation models can produce keyphrases absent from the text. Those
        # cannot be highlighted, so they are listed separately.
        abstractive_keyphrases = [
            keyphrase
            for keyphrase in keyphrases
            if keyphrase.lower() not in text.lower()
        ]
        layout.write(", ".join(abstractive_keyphrases))
# Initialise per-session state once: the app configuration (from config.json),
# an empty history table and the latest extraction result.
if "config" not in st.session_state:
    # orjson parses bytes directly. Reading in binary mode also avoids depending
    # on the platform's default text encoding.
    with open("config.json", "rb") as config_file:
        st.session_state.config = orjson.loads(config_file.read())
    st.session_state.data_frame = pd.DataFrame(columns=["model"])
    st.session_state.keyphrases = []
# Check the same key that is assigned. Checking a never-set key would reset the
# history selection on every rerun.
if "selected_rows" not in st.session_state:
    st.session_state.selected_rows = []
# Page setup. st.set_page_config must precede any other Streamlit element call.
# NOTE(review): "π" looks like a mis-encoded emoji in this copy; confirm the
# intended icon.
st.set_page_config(
    page_icon="π",
    page_title="Keyphrase extraction/generation with Transformers",
    layout="centered",
)
# Read the stylesheet as UTF-8 explicitly instead of relying on the platform's
# default text encoding.
with open("css/style.css", encoding="utf-8") as css_file:
    st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)
st.header("π Keyphrase extraction/generation with Transformers")
# Introductory Markdown shown above the model form.
# NOTE(review): the trailing "π" looks like a mis-encoded emoji; confirm the
# intended character.
description = """
Keyphrase extraction is a technique in text analysis where you extract the important keyphrases
from a text. Since this is a time-consuming process, Artificial Intelligence is used to automate it.
Currently, classical machine learning methods that use statistics and linguistics are widely used
for the extraction process. The fact that these methods have been widely used in the community has
the advantage that there are many easy-to-use libraries. Now with the recent innovations in
deep learning methods (such as recurrent neural networks, transformers, GANs, …),
keyphrase extraction can be improved. These new methods also focus on the semantics and
context of a document, which is quite an improvement.

This space gives you the ability to experiment with some keyphrase extraction and generation models.
Keyphrase extraction models are Transformer models fine-tuned as a token classification problem where
the tokens in a text are annotated as:

* B: Beginning of a keyphrase
* I: Inside a keyphrase
* O: Outside a keyphrase

While keyphrase extraction can only extract keyphrases that appear in a given text, keyphrase generation models
work a bit differently. Here you use an encoder-decoder model like BART to generate keyphrases from a given text.
These models also have the ability to generate keyphrases which are not present in the text 🤯.

Do you want to see some magic 🧙‍♂️? Try it out yourself! π
"""
st.write(description)
# Model selection and input form.
with st.form("test"):
    chosen_model = st.selectbox(
        "Choose your model:",
        st.session_state.config.get("models"),
    )
    st.session_state.chosen_model = chosen_model
    # Link to the same repository the pipeline is loaded from.
    model_id = f"{st.session_state.config.get('model_author')}/{chosen_model}"
    st.markdown(
        f"For more information about the chosen model, please be sure to check it out the [🤗 Model Card](https://huggingface.co/{model_id})."
    )
    # The script reruns on every interaction and loading a model is expensive,
    # so keep the pipeline in session state and reload it only on model change.
    if st.session_state.get("pipe_model_id") != model_id:
        with st.spinner("Loading pipeline..."):
            st.session_state.pipe = load_pipeline(model_id)
        st.session_state.pipe_model_id = model_id
    pipe = st.session_state.pipe
    st.session_state.input_text = st.text_area(
        "β Input", st.session_state.config.get("example_text"), height=300
    ).replace("\n", " ")
    pressed = st.form_submit_button("Extract")

# Extract here rather than in an on_click callback. Callbacks run before the
# script, i.e. before input_text, chosen_model and pipe above reflect the
# submitted form values, so they would use the previous run's values.
if pressed:
    with st.spinner("Extracting keyphrases..."):
        extract_keyphrases()
# The output section is shown above the history grid, but it must be rendered
# after the grid has reported the current selection. Otherwise a newly selected
# row would only appear on the following rerun. Reserve its place first and
# fill it at the end.
output_layout = st.container()

# Show the history only once it has rows. The table always has at least the
# "model" column, so checking its columns never hid it.
if len(st.session_state.data_frame) > 0:
    st.subheader("π History")
    builder = GridOptionsBuilder.from_dataframe(
        st.session_state.data_frame, sortable=False
    )
    builder.configure_selection(selection_mode="single", use_checkbox=True)
    builder.configure_column("text", hide=True)
    grid_options = builder.build()
    data = AgGrid(
        st.session_state.data_frame,
        gridOptions=grid_options,
        update_mode=GridUpdateMode.SELECTION_CHANGED,
    )
    st.session_state.selected_rows = pd.DataFrame(data["selected_rows"])

if len(st.session_state.selected_rows) > 0 or len(st.session_state.keyphrases) > 0:
    rerender_output(output_layout)