Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -3,7 +3,9 @@ import torch
|
|
3 |
import tokenizers
|
4 |
import streamlit as st
|
5 |
import re
|
|
|
6 |
from PIL import Image
|
|
|
7 |
|
8 |
|
9 |
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None, re.Pattern: lambda _: None}, allow_output_mutation=True, suppress_st_warning=True)
|
@@ -37,22 +39,30 @@ def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, lengt
|
|
37 |
return list(map(tokenizer.decode, out))[0]
|
38 |
|
39 |
|
40 |
-
|
|
|
41 |
|
42 |
# st.title("NeuroKorzh")
|
43 |
|
44 |
image = Image.open('korzh.jpg')
|
45 |
st.image(image, caption='NeuroKorzh')
|
46 |
|
|
|
|
|
47 |
st.markdown("\n")
|
48 |
|
49 |
-
text = st.
|
50 |
button = st.button('Go')
|
51 |
|
52 |
if button:
|
53 |
#try:
|
54 |
with st.spinner("Generation in progress"):
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
#st.subheader('Max Korzh:')
|
58 |
#lines = result.split('\n')
|
|
|
3 |
import tokenizers
|
4 |
import streamlit as st
|
5 |
import re
|
6 |
+
|
7 |
from PIL import Image
|
8 |
+
from huggingface_hub import hf_hub_download
|
9 |
|
10 |
|
11 |
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None, re.Pattern: lambda _: None}, allow_output_mutation=True, suppress_st_warning=True)
|
|
|
39 |
return list(map(tokenizer.decode, out))[0]
|
40 |
|
41 |
|
42 |
# Load both fine-tuned checkpoints once at startup so either model can
# serve a request without reloading.
# NOTE(review): assumes get_model is the @st.cache-decorated loader
# defined above this block — confirm against the full file.
medium_model, medium_tokenizer = get_model('sberbank-ai/rugpt3medium_based_on_gpt2', 'korzh-medium_best_eval_loss.bin')
large_model, large_tokenizer = get_model('sberbank-ai/rugpt3large_based_on_gpt2', 'korzh-large_best_eval_loss.bin')
|
44 |
|
45 |
# --- Page layout ---------------------------------------------------------
# st.title("NeuroKorzh")

# Header image for the app.
image = Image.open('korzh.jpg')
st.image(image, caption='NeuroKorzh')

# Which fine-tuned checkpoint to run the generation with.
option = st.selectbox('Model to be used', ('medium', 'large'))

st.markdown("\n")

# Prompt the generator continues from; the default is a Russian phrase
# ("What to do, Max?") matching the app's theme.
text = st.text_area(label='Starting point for text generation', value='Что делать, Макс?', height=200)
button = st.button('Go')
|
56 |
|
57 |
# Run generation when the user clicks "Go", dispatching to the model
# selected in the selectbox.
if button:
    with st.spinner("Generation in progress"):
        if option == 'medium':
            result = predict(text, medium_model, medium_tokenizer)
        elif option == 'large':
            result = predict(text, large_model, large_tokenizer)
        else:
            # BUG FIX: the original did `raise st.error(...)`, but
            # st.error() returns a DeltaGenerator, not an exception, so
            # the `raise` itself would crash with "exceptions must derive
            # from BaseException" and hide the message. Show the error in
            # the UI first, then raise a real exception to abort the run.
            st.error('Error in selectbox')
            raise RuntimeError('unexpected model option: %r' % option)
|
66 |
|
67 |
#st.subheader('Max Korzh:')
|
68 |
#lines = result.split('\n')
|