Spaces:
Runtime error
Runtime error
"""⭐ Text Classification with Optimum and ONNXRuntime | |
Author: | |
- @ChainYo - https://github.com/ChainYo | |
""" | |
import streamlit as st | |
from transformers import AutoTokenizer, AutoModel, pipeline | |
from optimum.onnxruntime import ORTModelForSequenceClassification | |
from optimum.pipelines import pipeline | |
# Model identifier passed to every `from_pretrained` call in this app
# (tokenizer, PyTorch model and ONNXRuntime model).
MODEL_PATH = "cardiffnlp/twitter-roberta-base-sentiment-latest"

# Browser-tab settings are configured before any other Streamlit output.
st.set_page_config(
    page_title="Optimum Text Classification",
    page_icon="⭐",
)

# Page header: app title and a one-line description of the demo.
st.title("🤗 Optimum Text Classification")
st.subheader("Sentiment analysis with 🤗 Optimum and ONNXRuntime")

# Profile / community links rendered under the header.
# NOTE(review): every link has empty link text, so Markdown renders nothing
# visible for it (badge images were probably intended here) — confirm and restore.
_AUTHOR_LINKS_MD = """
[](https://github.com/ChainYo)
[](https://huggingface.co/ChainYo)
[](https://www.linkedin.com/in/thomas-chaigneau-dev/)
[](https://discord.gg/)
"""
st.markdown(_AUTHOR_LINKS_MD)
# Load the tokenizer, both models and both pipelines once per browser session.
# Streamlit reruns this script top to bottom on every widget interaction, so each
# object is cached in `st.session_state` under a fixed key and is only built
# when that key is missing.
from optimum.pipelines import pipeline as optimum_pipeline
from transformers import AutoModelForSequenceClassification
from transformers import pipeline as transformers_pipeline


def _session_resource(key, factory):
    """Return ``st.session_state[key]``, building it with ``factory()`` on first use.

    Args:
        key: Session-state key the object is stored under.
        factory: Zero-argument callable that creates the object; it is only
            called when ``key`` is not yet present in the session state.

    Returns:
        The cached (or freshly created) object.
    """
    if key not in st.session_state:
        st.session_state[key] = factory()
    return st.session_state[key]


tokenizer = _session_resource(
    "tokenizer", lambda: AutoTokenizer.from_pretrained(MODEL_PATH)
)

# `from_transformers=True` asks optimum to convert the transformers checkpoint to
# ONNX while loading. NOTE(review): newer optimum releases renamed this flag to
# `export=True` — confirm against the pinned optimum version.
ort_model = _session_resource(
    "ort_model",
    lambda: ORTModelForSequenceClassification.from_pretrained(
        MODEL_PATH, from_transformers=True
    ),
)

# The PyTorch model must include its sequence-classification head: `AutoModel`
# loads only the bare encoder, which yields no logits for the
# text-classification pipeline to map to labels.
pt_model = _session_resource(
    "pt_model",
    lambda: AutoModelForSequenceClassification.from_pretrained(MODEL_PATH),
)

# The module-level imports bind `pipeline` twice (transformers first, then
# optimum), so the bare name refers to optimum's factory. Use explicit aliases:
# optimum's factory for the ORT model, and transformers' factory for the PyTorch
# model, since optimum's ORT pipeline does not accept a plain PyTorch model.
ort_pipeline = _session_resource(
    "ort_pipeline",
    lambda: optimum_pipeline(
        "text-classification", tokenizer=tokenizer, model=ort_model
    ),
)
pt_pipeline = _session_resource(
    "pt_pipeline",
    lambda: transformers_pipeline(
        "text-classification", tokenizer=tokenizer, model=pt_model
    ),
)
# Runtime selection plus options that are disabled when PyTorch is chosen.
model_format = st.radio("Choose the model format", ("PyTorch", "ONNXRuntime"))

# `st.checkbox` returns a plain bool, so a widget cannot be disabled after it is
# created (assigning `.disabled` on a bool raises AttributeError). Pass
# `disabled=` at creation time instead: both options are unavailable for PyTorch.
pytorch_selected = model_format == "PyTorch"
optimized = st.checkbox(
    "Optimize the model for inference", value=False, disabled=pytorch_selected
)
quantized = st.checkbox("Quantize the model", value=False, disabled=pytorch_selected)