import argparse
import json as js
import os
import re
from pathlib import Path
from typing import List, Tuple

import fasttext
import gradio as gr
import joblib
import omikuji
from huggingface_hub import snapshot_download

from prepare_everything import download_model

# Download the fastText language identification model
download_model(
    "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin",
    Path("lid.176.bin"),
)

# Download the model files from Hugging Face
model_names = [
    "omikuji-bonsai-parliament-spacy-de-all_topics-input_long",
    "omikuji-bonsai-parliament-spacy-fr-all_topics-input_long",
    "omikuji-bonsai-parliament-spacy-it-all_topics-input_long",
]
for repo_id in model_names:
    if not os.path.exists(repo_id):
        os.makedirs(repo_id)
    model_dir = snapshot_download(
        repo_id=f"kapllan/{repo_id}", local_dir=f"kapllan/{repo_id}"
    )

lang_model = fasttext.load_model("lid.176.bin")

with open(Path("label2id.json"), "r") as f:
    label2id = js.load(f)

id2label = {str(value): key for key, value in label2id.items()}

with open(Path("topics_hierarchy.json"), "r") as f:
    topics_hierarchy = js.load(f)


def map_language(language: str) -> str:
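    """Map a language code to its English name; return the code unchanged if it is not supported."""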
    language_mapping = {"de": "German", "it": "Italian", "fr": "French"}
    if language in language_mapping:
        return language_mapping[language]
    return language


def find_model(language: str):
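    """Return the vectorizer and Omikuji model for a supported language (de, fr, it), or (None, None) otherwise."""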
    vectorizer, model = None, None
    if language in ["de", "fr", "it"]:
        path_to_vectorizer = (
            f"./kapllan/omikuji-bonsai-parliament-spacy-{language}-all_topics-input_long/vectorizer"
        )
        path_to_model = (
            f"./kapllan/omikuji-bonsai-parliament-spacy-{language}-all_topics-input_long/omikuji-model"
        )
        vectorizer = joblib.load(path_to_vectorizer)
        model = omikuji.Model.load(path_to_model)
    return vectorizer, model


def predict_lang(text: str) -> str:
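    """Identify the language of a text with fastText and return its short language code (e.g. "de")."""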
    text = re.sub(r"\n", "", text)  # Remove line breaks; fastText cannot process them
    predictions = lang_model.predict(text, k=1)  # top-1 language prediction
    language = predictions[0][0]  # e.g. "__label__de"
    language = re.sub(r"__label__", "", language)  # strip the fastText label prefix
    return language


def predict_topic(text: str) -> Tuple[List[Tuple[str, float]], str]:
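    """Detect the text's language, run the matching Omikuji model, and return (label, score) pairs plus the language name."""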
    results = []
    language = predict_lang(text)
    vectorizer, model = find_model(language)
    language = map_language(language)
    if vectorizer is not None:
        texts = [text]
        vector = vectorizer.transform(texts)
        for row in vector:
            if row.nnz == 0:  # All-zero vector, empty result
                continue
            feature_values = [(col, row[0, col]) for col in row.nonzero()[1]]
            for subj_id, score in model.predict(feature_values, top_k=1000):
                score = round(score, 2)
                results.append((id2label[str(subj_id)], score))
    return results, language


def get_row_color(topic_type: str) -> str:
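    """Return a background-color style for a table row: dark grey for main topics, light grey for sub topics."""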
    if "main" in topic_type.lower():
        return "background-color: darkgrey;"
    if "sub" in topic_type.lower():
        return "background-color: lightgrey;"
    return ""


def generate_html_table(topics: List[Tuple[str, str, float]]) -> str:
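    """Render (type, topic, score) rows as an HTML table, highlighting main topics in bold."""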
    html = '<table style="width:100%; border: 1px solid black; border-collapse: collapse;">'
    html += "<tr><th>Type</th><th>Topic</th><th>Score</th></tr>"
    for topic_type, topic, score in topics:
        color = get_row_color(topic_type)
        topic = f"<strong>{topic}</strong>" if "main" in topic_type.lower() else topic
        topic_type = f"<strong>{topic_type}</strong>" if "main" in topic_type.lower() else topic_type
        score = f"<strong>{score}</strong>" if "main" in topic_type.lower() else score
        html += (
            f'<tr style="{color}"><td>{topic_type}</td><td>{topic}</td><td>{score}</td></tr>'
        )
    html += "</table>"
    return html


def restructure_topics(topics: List[Tuple[str, float]]) -> List[Tuple[str, str, float]]:
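    """Group flat (topic, score) predictions into main topics followed by their sub topics, using topics_hierarchy."""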
    topics = [(str(x[0]).lower(), x[1]) for x in topics]
    topics_as_dict = {}
    # Collect all predicted main topics
    for predicted_topic, score in topics:
        if str(predicted_topic).lower() in topics_hierarchy:
            topics_as_dict[str(predicted_topic).lower()] = []
    # Attach predicted sub topics to their main topics
    for predicted_topic, score in topics:
        for main_topic, sub_topics in topics_hierarchy.items():
            if (
                main_topic in topics_as_dict
                and predicted_topic != main_topic
                and predicted_topic in sub_topics
            ):
                topics_as_dict[main_topic].append(predicted_topic)
    topics_restructured = []
    for predicted_main_topic, predicted_sub_topics in topics_as_dict.items():
        if len(predicted_sub_topics) > 0:
            score = [t for t in topics if t[0] == predicted_main_topic][0][1]
            predicted_main_topic = predicted_main_topic.replace("hauptthema: ", "")
            topics_restructured.append(("Main Topic", predicted_main_topic, score))
            predicted_sub_topics_with_scores = []
            for pst in predicted_sub_topics:
                score = [t for t in topics if t[0] == pst][0][1]
                pst = pst.replace("unterthema: ", "")
                entry = ("Sub Topic", pst, score)
                if entry not in predicted_sub_topics_with_scores:
                    predicted_sub_topics_with_scores.append(entry)
            for x in predicted_sub_topics_with_scores:
                topics_restructured.append(x)
    return topics_restructured


def topic_modeling(text: str, threshold: float) -> Tuple[str, str]:
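    """Gradio handler: classify the text, keep topics scoring at or above the threshold, and return an HTML table plus the detected language."""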
    # Predict topics and keep those at or above the score threshold
    sorted_topics, language = predict_topic(text)
    if len(sorted_topics) > 0 and language in ["German", "French", "Italian"]:
        sorted_topics = [t for t in sorted_topics if t[1] >= threshold]
    else:
        sorted_topics = []
    sorted_topics = restructure_topics(sorted_topics)
    sorted_topics = generate_html_table(sorted_topics)
    return sorted_topics, language


with gr.Blocks() as iface:
    gr.Markdown("# Topic Modeling")
    gr.Markdown("Enter a document and get each topic along with its score.")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=10, placeholder="Enter a document")
            submit_button = gr.Button("Submit")
            threshold_slider = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, label="Score Threshold", value=0.0
            )
            language_text = gr.Textbox(
                lines=1,
                placeholder="Detected language will be shown here...",
                interactive=False,
                label="Detected Language",
            )
        with gr.Column():
            output_data = gr.HTML()
    submit_button.click(
        topic_modeling,
        inputs=[input_text, threshold_slider],
        outputs=[output_data, language_text],
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-ipa",
        "--ip_address",
        default=None,
        type=str,
        help="Specify the IP address of your computer.",
    )
    args = parser.parse_args()

    # Launch the app
    if args.ip_address is None:
        # launch() returns (app, local_url, share_url)
        _, _, public_url = iface.launch(share=True)
        print(f"The app runs here: {public_url}")
    else:
        iface.launch(server_name=args.ip_address, server_port=8080, show_error=True)