Spaces:
Runtime error
Runtime error
File size: 3,635 Bytes
ad74093 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
"""⭐ Text Classification with Optimum and ONNXRuntime
Author:
- @ChainYo - https://github.com/ChainYo
"""
import streamlit as st
from optimum.onnxruntime import ORTModelForTextClassification
from optimum.pipelines import pipeline
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer
from transformers import pipeline as transformers_pipeline
# Hub checkpoint loaded for both the PyTorch and the ONNXRuntime variants.
MODEL_PATH = "ProsusAI/finbert"

# Streamlit's docs require `set_page_config` to run before any other st.* call,
# so it must stay the first Streamlit command in this script.
st.set_page_config(page_title="Optimum Text Classification", page_icon="⭐")
st.title("🤗 Optimum Text Classification")
st.subheader("Classify financial text with 🤗 Optimum and ONNXRuntime")

# Author links. Each link needs visible text: an empty `[](url)` label renders
# as nothing the user can see or click.
# NOTE(review): the Discord URL has no invite code after `discord.gg/` — confirm.
st.markdown("""
[GitHub](https://github.com/ChainYo)
[Hugging Face](https://huggingface.co/ChainYo)
[LinkedIn](https://www.linkedin.com/in/thomas-chaigneau-dev/)
[Discord](https://discord.gg/)
""")
def _cache_in_session(key, factory):
    """Return ``st.session_state[key]``, creating it with ``factory()`` on first use.

    Streamlit re-runs the whole script on every interaction. Keeping the heavy
    tokenizer/model/pipeline objects in session state means they are built once
    per session instead of on every rerun.

    Args:
        key: Session-state key under which the object is stored.
        factory: Zero-argument callable that builds the object when missing.

    Returns:
        The cached (or freshly built) object.
    """
    if key not in st.session_state:
        st.session_state[key] = factory()
    return st.session_state[key]


tokenizer = _cache_in_session("tokenizer", lambda: AutoTokenizer.from_pretrained(MODEL_PATH))

# NOTE(review): newer optimum releases replace `from_transformers=True` with
# `export=True`; confirm against the pinned optimum version.
ort_model = _cache_in_session(
    "ort_model",
    lambda: ORTModelForTextClassification.from_pretrained(MODEL_PATH, from_transformers=True),
)

# A text-classification pipeline needs the sequence-classification head;
# plain `AutoModel` would load only the encoder and produce no logits.
pt_model = _cache_in_session(
    "pt_model", lambda: AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
)

# optimum's `pipeline` wraps the ONNXRuntime model.
ort_pipeline = _cache_in_session(
    "ort_pipeline",
    lambda: pipeline("text-classification", tokenizer=tokenizer, model=ort_model),
)

# optimum's `pipeline` rejects plain PyTorch models, so the PyTorch variant
# must be built with the transformers pipeline factory instead.
pt_pipeline = _cache_in_session(
    "pt_pipeline",
    lambda: transformers_pipeline("text-classification", tokenizer=tokenizer, model=pt_model),
)
model_format = st.radio("Choose the model format", ("PyTorch", "ONNXRuntime"))

# Optimization and quantization only apply to the ONNXRuntime model.
# `st.checkbox` returns a plain bool (it has no `.disabled` attribute), so the
# widgets must be disabled when they are created, via the `disabled=` argument.
onnx_options_disabled = model_format == "PyTorch"
optimized = st.checkbox(
    "Optimize the model for inference", value=False, disabled=onnx_options_disabled
)
quantized = st.checkbox("Quantize the model", value=False, disabled=onnx_options_disabled)
|