|
import json |
|
from pathlib import Path |
|
|
|
import pandas as pd |
|
import streamlit as st |
|
|
|
from category_classification.models import models as class_models |
|
from languages import * |
|
from results import process_results |
|
|
|
# User-facing UI labels, keyed by the language codes `en` / `ru` provided by
# the `languages` module.
page_title = {en: "Papers classification", ru: "Классификация статей"}
# Fixed typo: the imperative form is "Выберите" (not "Выберете").
model_label = {en: "Select model", ru: "Выберите модель"}
title_label = {en: "Title", ru: "Название статьи"}
authors_label = {en: "Author(s)", ru: "Автор(ы)"}
abstract_label = {en: "Abstract", ru: "Аннотация"}
# The metrics come from `test_results.json`, so both translations must say
# "test" dataset (the Russian text previously said "training" dataset).
metrics_label = {en: "Test metrics", ru: "Метрики на тестовом датасете"}
|
|
|
# Evaluation metrics shipped alongside the classification models. The JSON
# top level is indexed by language code further down (`metrics[lang]`).
# The encoding is pinned to UTF-8 so decoding does not depend on the platform
# default (e.g. cp1252 on Windows). Presumably the file may hold non-ASCII
# text; TODO confirm.
with open(
    Path(__file__).parent / "category_classification" / "test_results.json",
    "r",
    encoding="utf-8",
) as metric_f:
    metrics = json.load(metric_f)
|
|
|
|
|
def text_area_height(line_height: int, px_per_line: int = 34) -> int:
    """Return a text-area height in pixels for a given number of text lines.

    Args:
        line_height: Number of visible text lines (a count, not pixels).
        px_per_line: Pixel height allotted to each line; defaults to 34.

    Returns:
        ``line_height * px_per_line``, suitable for ``st.text_area(height=...)``.
        NOTE(review): recent Streamlit versions reject heights below 68 px,
        so with the default ``px_per_line`` use at least 2 lines.
    """
    return px_per_line * line_height
|
|
|
|
|
@st.cache_resource
def load_class_model(name):
    """Load a classification model by name, caching it across reruns.

    `st.cache_resource` is used instead of `st.cache_data`. The latter pickles
    the return value and returns a fresh copy on every cache hit, which is
    wasteful for a model and fails outright for unpicklable objects. With
    `cache_resource`, the same model instance is shared between reruns and
    sessions.

    Args:
        name: A model name as offered by `class_models.get_model_names_by_lang`.

    Returns:
        Whatever `class_models.get_model(name)` returns. It is called later
        with a dict of paper fields.
    """
    return class_models.get_model(name)
|
|
|
|
|
# Language picker; when nothing is selected the UI falls back to English.
lang = st.pills(label=langs_str, options=langs)
lang = en if lang is None else lang

st.title(page_title[lang])

# Only models available for the chosen language are offered.
model_name = st.selectbox(
    model_label[lang],
    options=class_models.get_model_names_by_lang(lang),
)

# Paper input fields; heights are expressed as a number of text lines.
title = st.text_area(title_label[lang], height=text_area_height(2))
authors = st.text_area(authors_label[lang], height=text_area_height(2))
abstract = st.text_area(abstract_label[lang], height=text_area_height(5))
|
|
|
# Run the classifier only once the user has entered a non-blank title.
if title.strip():
    # Named `paper` rather than `input` to avoid shadowing the builtin.
    paper = {"title": title, "abstract": abstract, "authors": authors}
    model = load_class_model(model_name)
    # Raw model output is post-processed for the selected language before display.
    results = process_results(model(paper), lang)
    st.dataframe(results, hide_index=True)

# Stored evaluation metrics for the selected language, shown if present.
# `.get` avoids a KeyError when the metrics file has no entry for `lang`.
lang_metrics = pd.DataFrame(metrics.get(lang, {}))
if not lang_metrics.empty:
    with st.expander(metrics_label[lang]):
        st.dataframe(lang_metrics)
|
|