File size: 3,635 Bytes
ad74093
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""⭐ Text Classification with Optimum and ONNXRuntime

Author:
    - @ChainYo - https://github.com/ChainYo
"""

import streamlit as st

from transformers import AutoTokenizer, AutoModel, pipeline
from transformers import AutoModelForSequenceClassification
from transformers import pipeline as transformers_pipeline
from optimum.onnxruntime import ORTModelForTextClassification
from optimum.pipelines import pipeline


# Hub checkpoint shared by the tokenizer and both model back-ends below.
MODEL_PATH = "ProsusAI/finbert"

# --- Page header -----------------------------------------------------------
st.set_page_config(page_icon="⭐", page_title="Optimum Text Classification")

_PAGE_TITLE = "🤗 Optimum Text Classification"
_PAGE_SUBHEADER = "Classify financial text with 🤗 Optimum and ONNXRuntime"
st.title(_PAGE_TITLE)
st.subheader(_PAGE_SUBHEADER)

# Author social-link badges, one Markdown image-link per line.
_SOCIAL_BADGES = (
    "[![GitHub](https://img.shields.io/badge/-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/ChainYo)",
    "[![HuggingFace](https://img.shields.io/badge/-yellow.svg?style=for-the-badge&logo=)](https://huggingface.co/ChainYo)",
    "[![LinkedIn](https://img.shields.io/badge/-%230077B5.svg?style=for-the-badge&logo=linkedin&logoColor=white)](https://www.linkedin.com/in/thomas-chaigneau-dev/)",
    "[![Discord](https://img.shields.io/badge/Chainyo%233610-%237289DA.svg?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/)",
)
# The leading and trailing newline reproduce the original triple-quoted layout.
st.markdown("\n" + "\n".join(_SOCIAL_BADGES) + "\n")

# Load the tokenizer, both models and both pipelines once per session and keep
# them in ``st.session_state``. Streamlit re-executes this script on reruns, and
# the ``not in`` guards stop those reruns from reloading the heavy objects.
if "tokenizer" not in st.session_state:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    st.session_state["tokenizer"] = tokenizer

# ONNX Runtime model. ``from_transformers=True`` presumably exports the PyTorch
# checkpoint to ONNX at load time (newer optimum releases call this flag
# ``export=True``). TODO: confirm against the pinned optimum version.
if "ort_model" not in st.session_state:
    ort_model = ORTModelForTextClassification.from_pretrained(MODEL_PATH, from_transformers=True)
    st.session_state["ort_model"] = ort_model

# PyTorch reference model. It must include the sequence-classification head:
# a bare ``AutoModel`` returns only hidden states, which a
# "text-classification" pipeline cannot map to labels.
if "pt_model" not in st.session_state:
    pt_model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    st.session_state["pt_model"] = pt_model

# ``pipeline`` is optimum's factory here (it shadows the transformers import),
# which is what the ORTModel needs.
if "ort_pipeline" not in st.session_state:
    ort_pipeline = pipeline(
        "text-classification", tokenizer=st.session_state["tokenizer"], model=st.session_state["ort_model"]
    )
    st.session_state["ort_pipeline"] = ort_pipeline

# optimum's ``pipeline`` defaults to the ONNX Runtime accelerator and does not
# accept a plain PyTorch model, so build this one with the transformers factory.
if "pt_pipeline" not in st.session_state:
    pt_pipeline = transformers_pipeline(
        "text-classification", tokenizer=st.session_state["tokenizer"], model=st.session_state["pt_model"]
    )
    st.session_state["pt_pipeline"] = pt_pipeline


# Back-end selector. "PyTorch" is listed first, so it is the default selection.
model_format = st.radio("Choose the model format", ("PyTorch", "ONNXRuntime"))

# Optimization and quantization only apply to the ONNX Runtime back-end, so the
# checkboxes are greyed out when PyTorch is selected. ``st.checkbox`` returns a
# plain bool; the widget must therefore be disabled through its ``disabled=``
# argument, not by setting an attribute on the returned value.
disable_ort_options = model_format == "PyTorch"
optimized = st.checkbox("Optimize the model for inference", value=False, disabled=disable_ort_options)
quantized = st.checkbox("Quantize the model", value=False, disabled=disable_ort_options)