Spaces:
Running
Running
# coding=utf-8 | |
# Copyright 2023 The GlotLID Authors. | |
# Lint as: python3 | |
""" | |
GlotLID Space | |
""" | |
""" This space is built based on AMR-KELEG/ALDi space """ | |
import constants | |
import pandas as pd | |
import streamlit as st | |
from huggingface_hub import hf_hub_download | |
from GlotScript import get_script_predictor | |
import matplotlib.pyplot as plt | |
import fasttext | |
import altair as alt | |
from altair import X, Y, Scale | |
import base64 | |
@st.cache_resource
def load_sp():
    """Build the GlotScript script predictor.

    Streamlit re-executes this whole script on every user interaction, so the
    predictor is wrapped in ``st.cache_resource`` to build it once and reuse
    the same object on later reruns instead of rebuilding it each time.

    Returns:
        The predictor returned by ``GlotScript.get_script_predictor()``.
    """
    return get_script_predictor()


sp = load_sp()
def get_script(text):
    """Identify the writing system (script) of ``text``.

    Args:
        text: The text whose script should be identified.

    Returns:
        The first element of the module-level script predictor's output for
        ``text``, i.e. the predicted writing system.
    """
    prediction = sp(text)
    return prediction[0]
def render_svg(svg):
    """Display an SVG image centred on the page.

    The SVG markup is base64-encoded and embedded as a data URI in an HTML
    ``<img>`` tag, which is written into a new Streamlit container.

    Args:
        svg: The SVG document as a string.
    """
    encoded_svg = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    markup = (
        '<p align="center"> '
        f'<img src="data:image/svg+xml;base64,{encoded_svg}"/>'
        " </p>"
    )
    st.container().write(markup, unsafe_allow_html=True)
@st.cache_data
def convert_df(df):
    """Serialize ``df`` to UTF-8 encoded CSV bytes, without the index column.

    Cached with ``st.cache_data`` so the conversion is not recomputed on every
    Streamlit rerun for the same DataFrame.

    Args:
        df: The DataFrame to export.

    Returns:
        The CSV text encoded as UTF-8 ``bytes``, ready for ``st.download_button``.
    """
    return df.to_csv(index=False).encode("utf-8")
@st.cache_resource
def load_model(model_name):
    """Download (if needed) and load the fastText model stored in ``model_name``.

    Streamlit re-executes this whole script on every user interaction; without
    ``st.cache_resource`` the model would be re-loaded into memory on each
    rerun. With the cache it is loaded once and the same object is reused.

    Args:
        model_name: Hugging Face Hub repo id that contains a ``model.bin`` file.

    Returns:
        The model returned by ``fasttext.load_model``.
    """
    model_path = hf_hub_download(repo_id=model_name, filename="model.bin")
    return fasttext.load_model(model_path)


model = load_model(constants.MODEL_NAME)
def compute(sentences, batch_size=1):
    """Computes the top-1 language label and confidence for each sentence.

    A Streamlit progress bar is shown while predictions run and is removed
    when they finish.

    Args:
        sentences: A sequence of sentences (strings). Each sentence should be a
            single line; fastText's ``predict`` processes one line at a time.
        batch_size: Number of sentences passed to ``model.predict`` per call.
            Defaults to 1, which updates the progress bar after every
            sentence. Values below 1 are treated as 1.

    Returns:
        A ``(probs, labels)`` tuple of lists, each with exactly one entry per
        input sentence: the top-1 confidence clamped to ``[0, 1]`` and the
        corresponding predicted label. If the model returns no prediction for
        a sentence, its entries are ``0.0`` and ``None``.
    """
    progress_text = "Computing Language..."
    my_bar = st.progress(0, text=progress_text)

    # fastText's ``predict`` only takes its multi-line (batched) path for a real
    # ``list``; a tuple or other sequence would be treated as one string.
    sentences = list(sentences)
    batch_size = max(1, int(batch_size))
    total = len(sentences)

    probs = []
    labels = []
    for first_index in range(0, total, batch_size):
        batch = sentences[first_index : first_index + batch_size]
        batch_labels, batch_probs = model.predict(batch)
        # One entry per sentence in the batch. Keep only the top-1 prediction
        # so the outputs stay aligned one-to-one with ``sentences``.
        for sentence_labels, sentence_probs in zip(batch_labels, batch_probs):
            if not sentence_labels:
                # Defensive: keep alignment even if no label comes back.
                labels.append(None)
                probs.append(0.0)
                continue
            labels.append(sentence_labels[0])
            # Clamp into [0, 1]; raw fastText scores may drift slightly outside.
            probs.append(max(min(sentence_probs[0], 1), 0))
        my_bar.progress(
            min((first_index + batch_size) / total, 1),
            text=progress_text,
        )
    my_bar.empty()
    return probs, labels
# Page layout: logo, then one tab for a single sentence and one for a file.
with open("assets/GlotLID_logo.svg", encoding="utf-8") as logo_file:
    render_svg(logo_file.read())

tab1, tab2 = st.tabs(["Input a Sentence", "Upload a File"])

with tab1:
    sent = st.text_input(
        "Sentence:", placeholder="Enter a sentence.", on_change=None
    )

    # TODO: Check if this is needed! Clicking it triggers a rerun, but the
    # returned value is never read.
    clicked = st.button("Submit")

    if sent:
        probs, labels = compute([sent])
        prob = probs[0]
        label = labels[0]

        # Draw a single horizontal bar showing the confidence on a [0, 1] axis.
        ORANGE_COLOR = "#FF8000"
        fig, ax = plt.subplots(figsize=(8, 1))
        fig.patch.set_facecolor("none")
        ax.set_facecolor("none")
        ax.spines["left"].set_color(ORANGE_COLOR)
        ax.spines["bottom"].set_color(ORANGE_COLOR)
        ax.tick_params(axis="x", colors=ORANGE_COLOR)
        ax.spines[["right", "top"]].set_visible(False)
        ax.barh(y=[0], width=[prob], color=ORANGE_COLOR)
        ax.set_xlim(0, 1)
        ax.set_ylim(-1, 1)
        ax.set_title(f"Language is: {label}", color=ORANGE_COLOR)
        ax.get_yaxis().set_visible(False)
        ax.set_xlabel("Confidence", color=ORANGE_COLOR)
        st.pyplot(fig)
        # The script reruns on every interaction; close the figure so pyplot's
        # global figure registry does not keep growing.
        plt.close(fig)

        # NOTE(review): every submitted sentence is echoed to stdout and
        # appended to logs.txt; confirm this logging is intended.
        print(sent)
        # Input can be in any script, so write UTF-8 explicitly rather than
        # relying on the platform's default encoding.
        with open("logs.txt", "a", encoding="utf-8") as f:
            f.write(sent + "\n")

with tab2:
    file = st.file_uploader("Upload a file", type=["txt"])
    sentences = []
    if file is not None:
        try:
            # "utf-8-sig" also drops a leading byte-order mark, if present.
            raw_text = file.getvalue().decode("utf-8-sig")
        except UnicodeDecodeError:
            st.error("The uploaded file must be UTF-8 encoded.")
        else:
            # One sentence per line. Split manually (instead of pd.read_csv) so
            # tabs, quote characters and numeric-looking lines stay verbatim
            # text; blank lines are skipped.
            sentences = [
                line.rstrip("\r") for line in raw_text.split("\n") if line.strip()
            ]
            if not sentences:
                st.warning("The uploaded file does not contain any sentences.")

    if sentences:
        df = pd.DataFrame({"Sentence": sentences})
        df["Probs"], df["Language"] = compute(df["Sentence"].tolist())

        # A horizontal rule
        st.markdown("""---""")

        # Confidence per sentence, plotted against the sentence's row index.
        chart = (
            alt.Chart(df.reset_index())
            .mark_area(color="darkorange", opacity=0.5)
            .encode(
                x=X(field="index", title="Sentence Index"),
                y=Y("Probs", scale=Scale(domain=[0, 1])),
            )
        )
        st.altair_chart(chart.interactive(), use_container_width=True)

        col1, col2 = st.columns([4, 1])

        with col1:
            # Display the output
            st.table(df)

        with col2:
            # Add a download button
            csv = convert_df(df)
            st.download_button(
                label=":file_folder: Download predictions as CSV",
                data=csv,
                file_name="GlotLID.csv",
                mime="text/csv",
            )