Davide Fiocco committed
Commit 57fe04c · 1 Parent(s): 76b4c44

Try out better caching for the pipeline

Files changed (1): app.py (+17, -2)
app.py CHANGED
@@ -1,6 +1,8 @@
 import pandas as pd
 import streamlit as st
-from transformers import pipeline
+import tokenizers
+import torch
+from transformers import Pipeline, pipeline
 
 st.set_page_config(
     page_title="Zero-shot classification from tabular data",
@@ -10,8 +12,21 @@ st.set_page_config(
     menu_items=None,
 )
 
-with st.spinner("Setting stuff up related to the inference engine..."):
+
+@st.cache(
+    hash_funcs={
+        torch.nn.parameter.Parameter: lambda _: None,
+        tokenizers.Tokenizer: lambda _: None,
+        tokenizers.AddedToken: lambda _: None,
+    },
+    allow_output_mutation=True,
+)
+def load_classifier() -> Pipeline:
     classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+    return classifier
+
+
+classifier = load_classifier()
 
 st.title("Zero-shot classification from tabular data")
 st.text(
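
The rest of app.py sits outside this hunk, so the following is only a sketch of how the cached classifier might be applied to an uploaded table. The "text" column name, the comma-separated label field, and the input widgets are assumptions, not part of this commit; the zero-shot pipeline call (candidate_labels, and the returned "labels"/"scores" lists sorted best-first) follows the standard transformers API.

import pandas as pd
import streamlit as st

# Assumes load_classifier() from the diff above; the cached pipeline is built once
# and reused across Streamlit reruns instead of being re-downloaded every time.
classifier = load_classifier()

# Hypothetical input widgets: a CSV with a "text" column and a list of candidate labels.
uploaded = st.file_uploader("Upload a CSV with a 'text' column", type="csv")
labels_raw = st.text_input("Candidate labels (comma-separated)", "positive, negative")

if uploaded is not None and labels_raw:
    df = pd.read_csv(uploaded)
    candidate_labels = [label.strip() for label in labels_raw.split(",")]
    # Each call returns a dict whose "labels" and "scores" are sorted from best to worst match.
    results = [classifier(text, candidate_labels=candidate_labels) for text in df["text"]]
    df["predicted_label"] = [r["labels"][0] for r in results]
    df["score"] = [r["scores"][0] for r in results]
    st.dataframe(df)

On the caching itself: torch parameters and tokenizers objects are not hashable by st.cache, so the hash_funcs entries that map them to lambda _: None tell Streamlit to skip hashing those types, and allow_output_mutation=True stops Streamlit from hashing (and warning about) the returned pipeline object. The net effect is that the spinner-wrapped inline load is replaced by a loader that runs only on the first rerun.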