santiviquez's picture
text classification app
c56be05
raw
history blame
2.66 kB
import streamlit as st
from transformers import pipeline
import pandas as pd
import nannyml as nml
# Seed per-session defaults without overwriting values that already exist
# (Streamlit re-executes this script on every user interaction).
#   'count'    - how many 400-review batches have been requested so far
#   'dissable' - value passed to the batch button's ``disabled`` argument
#                (the key's spelling is kept as-is; it is read further down)
for state_key, initial_value in (('count', 0), ('dissable', False)):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value
def increment_counter():
    """Advance the session's batch counter by one.

    Registered as the ``on_click`` callback of the performance-estimation
    button, so each click moves the app on to the next batch.
    """
    st.session_state.count = st.session_state.count + 1
@st.cache_resource
def get_model(url):
    """Build a ``transformers`` pipeline for the model identified by ``url``.

    The pipeline is created with padding and truncation enabled and a
    maximum length of 512. The function is wrapped in
    ``st.cache_resource`` so the pipeline is not rebuilt on every rerun.

    Args:
        url: Model identifier forwarded to ``pipeline(model=...)``.

    Returns:
        The pipeline object returned by ``transformers.pipeline``.
    """
    return pipeline(model=url, padding=True, truncation=True, max_length=512)
# Pipeline used to classify a single review typed in by the user.
rating_classification_model = get_model("NannyML/amazon-reviews-sentiment-bert-base-uncased-6000-samples")

# Human-readable names for the raw labels returned by the pipeline.
label_mapping = {
    'LABEL_0': 'Negative',
    'LABEL_1': 'Neutral',
    'LABEL_2': 'Positive',
}

review = st.text_input(label='write a review', value='I love this book!')
single_review_button = st.button(label='Classify Single Review')

# Classify only when the button was pressed on this run and the text box is
# not empty.
if review and single_review_button:
    # The pipeline returns a list of predictions; the first entry is used.
    rating = rating_classification_model(review)[0]
    # Fall back to the raw label rather than raising KeyError if the model
    # ever returns a label that is not in the mapping.
    label = label_mapping.get(rating['label'], rating['label'])
    score = rating['score']
    # The separator is an em dash; it was previously stored as mis-decoded
    # bytes and rendered as garbage in the UI.
    st.write(f"{label} — confidence: {round(score, 2)}")
# # # # # # # #
# Load the reference data (used to fit the NannyML objects below) and the
# analysis data (served in batches by the performance-estimation button).
# NOTE(review): everything in this section runs again on every Streamlit
# rerun, including the two fits.
reference_df = pd.read_csv('../reference.csv')
analysis_df = pd.read_csv('../analysis.csv')

# Cast the true and predicted label columns to str in both frames --
# presumably so they match the string class keys ('0', '1', '2') used below.
for frame in (reference_df, analysis_df):
    for column in ('label', 'pred_label'):
        frame[column] = frame[column].astype(str)

# Class key -> name of the column holding that class's predicted probability.
class_proba_columns = {
    '0': 'pred_proba_label_negative',
    '1': 'pred_proba_label_neutral',
    '2': 'pred_proba_label_positive',
}

# Settings shared by the estimator and the calculator.
shared_settings = {
    'y_pred': 'pred_label',
    'y_true': 'label',
    'problem_type': 'classification_multiclass',
    'chunk_size': 400,
}

# CBPE produces the estimated F1; it is fitted on the reference data.
estimator = nml.CBPE(
    y_pred_proba=dict(class_proba_columns),
    metrics='f1',
    **shared_settings,
)
estimator.fit(reference_df)

# PerformanceCalculator produces the F1 computed from the actual labels, for
# comparison against the estimate.
calculator = nml.PerformanceCalculator(
    y_pred_proba=dict(class_proba_columns),
    metrics=['f1'],
    **shared_settings,
)
calculator.fit(reference_df)
# The on_click callback bumps the batch counter before the script reruns, so
# by the time the button reads as pressed the counter already includes the
# batch that was just requested.
multiple_reviews_button = st.button(
    'Estimate Model Performance on 400 Reviews',
    on_click=increment_counter,
    disabled=st.session_state.dissable,
)

if multiple_reviews_button:
    batch_number = st.session_state.count

    # Cumulative slice: every batch requested so far, 400 rows per batch.
    prod_data = analysis_df[:batch_number * 400]

    # The estimator is given the data without the true labels; the
    # calculator receives the full frame including them.
    results = estimator.estimate(prod_data.drop(columns=['label']))
    realize_results = calculator.calculate(prod_data)

    # Plot estimated and realized performance together.
    fig = results.compare(realize_results).plot()
    st.plotly_chart(fig, use_container_width=True, theme=None)
    st.write(f'Batch {batch_number} / 5')

    # After the fifth batch, start over from the first one on the next click.
    if batch_number >= 5:
        st.session_state.count = 0