# kapllan's picture
# Update app.py
# 1a6e249 verified
import argparse
import json as js
import os
import re
from pathlib import Path
from typing import List, Tuple
import fasttext
import gradio as gr
import joblib
import omikuji
from huggingface_hub import snapshot_download
from prepare_everything import download_model
# Fetch the fastText language-identification model used by predict_lang.
download_model(
    "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin",
    Path("lid.176.bin"))
# Download the model files from Hugging Face
model_names = [
    "omikuji-bonsai-parliament-spacy-de-all_topics-input_long",
    "omikuji-bonsai-parliament-spacy-fr-all_topics-input_long",
    "omikuji-bonsai-parliament-spacy-it-all_topics-input_long",
]
for repo_id in model_names:
    # NOTE(review): this creates a bare `repo_id` directory, but the snapshot
    # is downloaded into `kapllan/{repo_id}` below — the bare directory looks
    # unused; confirm whether makedirs is still needed.
    if not os.path.exists(repo_id):
        os.makedirs(repo_id)
    model_dir = snapshot_download(repo_id=f"kapllan/{repo_id}", local_dir=f"kapllan/{repo_id}")
# fastText language identifier (expects single-line input; see predict_lang).
lang_model = fasttext.load_model("lid.176.bin")
# label2id maps topic label -> numeric id; build the inverse for decoding
# Omikuji prediction ids back into labels.
with open(Path("label2id.json"), "r") as f:
    label2id = js.load(f)
id2label = {}
for key, value in label2id.items():
    id2label[str(value)] = key
# Mapping of main topics to their sub topics (consumed by restructure_topics).
with open(Path("topics_hierarchy.json"), "r") as f:
    topics_hierarchy = js.load(f)
def map_language(language: str) -> str:
    """Translate an ISO 639-1 code into its English language name.

    Args:
        language: Language code as produced by predict_lang (e.g. "de").

    Returns:
        "German"/"Italian"/"French" for the supported codes; any other
        value is returned unchanged.
    """
    language_mapping = {"de": "German", "it": "Italian", "fr": "French"}
    # dict.get with a default replaces the explicit membership test.
    return language_mapping.get(language, language)
def find_model(language: str):
    """Load the TF-IDF vectorizer and Omikuji model for a language.

    Args:
        language: ISO 639-1 code as returned by predict_lang.

    Returns:
        A (vectorizer, model) pair for the supported languages
        "de"/"fr"/"it"; (None, None) for anything else.
    """
    vectorizer, model = None, None
    if language in ["de", "fr", "it"]:
        # Single base path instead of two copies of the repo directory name.
        model_dir = (
            f"./kapllan/omikuji-bonsai-parliament-spacy-{language}-all_topics-input_long"
        )
        vectorizer = joblib.load(f"{model_dir}/vectorizer")
        model = omikuji.Model.load(f"{model_dir}/omikuji-model")
    return vectorizer, model
def predict_lang(text: str) -> str:
    """Detect the language of *text* with the fastText lid.176 model.

    Args:
        text: The document to classify.

    Returns:
        An ISO 639-1 language code such as "de", "fr" or "it".
    """
    # Remove linebreaks because fasttext cannot process multi-line input.
    text = text.replace("\n", "")
    # k=1: only the single best-matching language is requested.
    predictions = lang_model.predict(text, k=1)
    language = predictions[0][0]  # e.g. "__label__de"
    # Strip fastText's label prefix, leaving the bare language code.
    return language.replace("__label__", "", 1)
def predict_topic(text: str) -> Tuple[List[Tuple[str, float]], str]:
    """Predict all topics of a document with the language-specific model.

    Args:
        text: The input document.

    Returns:
        A tuple (results, language): results is a list of
        (topic_label, score) pairs (empty when no model exists for the
        detected language or the text yields no features), and language
        is the mapped language name, e.g. "German".
    """
    results = []
    language = predict_lang(text)
    vectorizer, model = find_model(language)
    language = map_language(language)
    if vectorizer is not None:
        vector = vectorizer.transform([text])
        for row in vector:
            if row.nnz == 0:  # all-zero vector: no usable features, skip
                continue
            feature_values = [(col, row[0, col]) for col in row.nonzero()[1]]
            for subj_id, score in model.predict(feature_values, top_k=1000):
                results.append((id2label[str(subj_id)], round(score, 2)))
    return results, language
def get_row_color(row_type: str) -> str:
    """Return the CSS background for a table row by its topic type.

    Args:
        row_type: Row label; "Main Topic" rows get dark grey,
            "Sub Topic" rows light grey (matched case-insensitively).

    Returns:
        A CSS declaration, or "" for unrecognized types (previously this
        fell through returning None, which would render as style="None").
    """
    # Parameter renamed from `type` to avoid shadowing the builtin.
    lowered = row_type.lower()
    if "main" in lowered:
        return "background-color: darkgrey;"
    if "sub" in lowered:
        return "background-color: lightgrey;"
    return ""
def generate_html_table(topics: List[Tuple[str, str, float]]) -> str:
    """Render (type, topic, score) rows as an HTML table.

    Main-topic cells are bolded; row background comes from get_row_color.

    Args:
        topics: Rows as produced by restructure_topics.

    Returns:
        The table as an HTML string.
    """
    html = '<table style="width:100%; border: 1px solid black; border-collapse: collapse;">'
    html += "<tr><th>Type</th><th>Topic</th><th>Score</th></tr>"
    # Loop variable renamed from `type` to avoid shadowing the builtin.
    for row_type, topic, score in topics:
        color = get_row_color(row_type)
        # Compute the "is main topic" flag once instead of three times.
        if "main" in row_type.lower():
            topic = f"<strong>{topic}</strong>"
            score = f"<strong>{score}</strong>"
            row_type = f"<strong>{row_type}</strong>"
        html += (
            f'<tr style="{color}"><td>{row_type}</td><td>{topic}</td><td>{score}</td></tr>'
        )
    html += "</table>"
    return html
def restructure_topics(topics: List[Tuple[str, float]]) -> List[Tuple[str, str, float]]:
    """Group flat (topic, score) predictions into main/sub-topic rows.

    A main topic is emitted only if at least one of its sub topics (per
    the global topics_hierarchy) was also predicted; its sub-topic rows
    follow directly after it. The "hauptthema: "/"unterthema: " prefixes
    are stripped from the displayed labels.

    Args:
        topics: (label, score) pairs as returned by predict_topic.

    Returns:
        Rows of ("Main Topic" | "Sub Topic", cleaned_label, score).
    """
    topics = [(str(label).lower(), score) for label, score in topics]

    # First-occurrence score per label. Replaces the O(n) list scan that
    # was previously performed for every single lookup (O(n^2) overall);
    # setdefault keeps the first score, matching the old [0] indexing.
    score_by_label = {}
    for label, score in topics:
        score_by_label.setdefault(label, score)

    # Predicted main topics, in prediction order, each mapped to the
    # predicted sub topics that belong to it.
    topics_as_dict = {}
    for label, _ in topics:
        if label in topics_hierarchy:
            topics_as_dict[label] = []
    for label, _ in topics:
        for main_topic, sub_topics in topics_hierarchy.items():
            if (
                main_topic in topics_as_dict
                and label != main_topic
                and label in sub_topics
            ):
                topics_as_dict[main_topic].append(label)

    topics_restructured = []
    for main_topic, sub_topics in topics_as_dict.items():
        # Main topics without any predicted sub topic are dropped.
        if len(sub_topics) > 0:
            topics_restructured.append(
                ("Main Topic",
                 main_topic.replace("hauptthema: ", ""),
                 score_by_label[main_topic])
            )
            # Deduplicate sub-topic rows while preserving order.
            sub_rows = []
            for sub in sub_topics:
                entry = ("Sub Topic",
                         sub.replace("unterthema: ", ""),
                         score_by_label[sub])
                if entry not in sub_rows:
                    sub_rows.append(entry)
            topics_restructured.extend(sub_rows)
    return topics_restructured
def topic_modeling(text: str, threshold: float) -> Tuple[str, str]:
    """Gradio callback: predict, filter, and render topics for a document.

    Args:
        text: The input document.
        threshold: Minimum score a topic must reach to be displayed.

    Returns:
        A tuple (html_table, language) — the rendered topic table and the
        detected language name.
    """
    sorted_topics, language = predict_topic(text)
    # Only the three supported languages produce meaningful predictions.
    if len(sorted_topics) > 0 and language in ["German", "French", "Italian"]:
        sorted_topics = [t for t in sorted_topics if t[1] >= threshold]
    else:
        sorted_topics = []
    sorted_topics = restructure_topics(sorted_topics)
    return generate_html_table(sorted_topics), language
# Build the Gradio UI: text input and threshold slider on the left,
# rendered topic table on the right.
with gr.Blocks() as iface:
    gr.Markdown("# Topic Modeling")
    gr.Markdown("Enter a document and get each topic along with its score.")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=10, placeholder="Enter a document")
            submit_button = gr.Button("Submit")
            # Minimum score a topic needs to be shown (0.0 shows everything).
            threshold_slider = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, label="Score Threshold", value=0.0
            )
            language_text = gr.Textbox(
                lines=1,
                placeholder="Detected language will be shown here...",
                interactive=False,
                label="Detected Language",
            )
        with gr.Column():
            output_data = gr.HTML()
    # Wire the button: topic_modeling returns (html_table, language_name).
    submit_button.click(
        topic_modeling,
        inputs=[input_text, threshold_slider],
        outputs=[output_data, language_text],
    )
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-ipa",
        "--ip_address",
        default=None,
        type=str,
        help="Specify the IP address of your computer.",
    )
    args = parser.parse_args()
    # Launch the app
    if args.ip_address is None:
        # No IP given: create a public share link.
        # NOTE(review): current Gradio's launch() returns a 3-tuple
        # (app, local_url, share_url), so this 2-tuple unpacking may be
        # version-sensitive — confirm against the pinned Gradio version.
        _, public_url = iface.launch(share=True)
        print(f"The app runs here: {public_url}")
    else:
        # Bind to the given address on port 8080 without a share link.
        iface.launch(server_name=args.ip_address, server_port=8080, show_error=True)