Spaces:
Runtime error
Runtime error
Commit
·
84fa2e9
1
Parent(s):
8bb7965
change threading options for onnx inference
Browse files
app.py
CHANGED
@@ -87,6 +87,10 @@ hide_streamlit_style = """
|
|
87 |
"""
|
88 |
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
|
89 |
|
|
|
|
|
|
|
|
|
90 |
|
91 |
@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
|
92 |
def create_model_dir(chkpt, model_dir):
|
@@ -180,6 +184,9 @@ if select_task=='README':
|
|
180 |
if select_task == 'Detect Sentiment':
|
181 |
t1=time.time()
|
182 |
tokenizer_sentiment,sentiment_session = sentiment_task_selected(task=select_task)
|
|
|
|
|
|
|
183 |
t2 = time.time()
|
184 |
st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
|
185 |
|
@@ -210,7 +217,9 @@ if select_task == 'Detect Sentiment':
|
|
210 |
|
211 |
if select_task=='Zero Shot Classification':
|
212 |
t1=time.time()
|
213 |
-
tokenizer_zs,
|
|
|
|
|
214 |
t2 = time.time()
|
215 |
st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
|
216 |
|
@@ -225,7 +234,7 @@ if select_task=='Zero Shot Classification':
|
|
225 |
|
226 |
if response1:
|
227 |
start = time.time()
|
228 |
-
df_output = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=
|
229 |
_tokenizer=tokenizer_zs)
|
230 |
end = time.time()
|
231 |
st.write("")
|
|
|
87 |
"""
|
88 |
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
|
89 |
|
90 |
+
options = ort.SessionOptions()
|
91 |
+
options.intra_op_num_threads=1
|
92 |
+
options.inter_op_num_threads=1
|
93 |
+
|
94 |
|
95 |
@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
|
96 |
def create_model_dir(chkpt, model_dir):
|
|
|
184 |
if select_task == 'Detect Sentiment':
|
185 |
t1=time.time()
|
186 |
tokenizer_sentiment,sentiment_session = sentiment_task_selected(task=select_task)
|
187 |
+
##below 2 steps are slower as caching is not enabled
|
188 |
+
# tokenizer_sentiment = AutoTokenizer.from_pretrained(sent_mdl_dir)
|
189 |
+
# sentiment_session = ort.InferenceSession(f"{sent_onnx_mdl_dir}/{sent_onnx_mdl_name}")
|
190 |
t2 = time.time()
|
191 |
st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
|
192 |
|
|
|
217 |
|
218 |
if select_task=='Zero Shot Classification':
|
219 |
t1=time.time()
|
220 |
+
tokenizer_zs,session_zs = zs_task_selected(task=select_task)
|
221 |
+
# tokenizer_zs= AutoTokenizer.from_pretrained(zs_mdl_dir)
|
222 |
+
# session_zs = ort.InferenceSession(f"{zs_onnx_mdl_dir}/{zs_onnx_mdl_name}")
|
223 |
t2 = time.time()
|
224 |
st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
|
225 |
|
|
|
234 |
|
235 |
if response1:
|
236 |
start = time.time()
|
237 |
+
df_output = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=session_zs,
|
238 |
_tokenizer=tokenizer_zs)
|
239 |
end = time.time()
|
240 |
st.write("")
|