vladyur commited on
Commit
833aa0a
·
1 Parent(s): 2499488

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -3
app.py CHANGED
@@ -3,7 +3,9 @@ import torch
3
  import tokenizers
4
  import streamlit as st
5
  import re
 
6
  from PIL import Image
 
7
 
8
 
9
  @st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None, re.Pattern: lambda _: None}, allow_output_mutation=True, suppress_st_warning=True)
@@ -37,22 +39,30 @@ def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, lengt
37
  return list(map(tokenizer.decode, out))[0]
38
 
39
 
40
- model, tokenizer = get_model('sberbank-ai/rugpt3medium_based_on_gpt2', 'korzh-medium_best_eval_loss.bin')
 
41
 
42
  # st.title("NeuroKorzh")
43
 
44
  image = Image.open('korzh.jpg')
45
  st.image(image, caption='NeuroKorzh')
46
 
 
 
47
  st.markdown("\n")
48
 
49
- text = st.text_input(label='Starting point for text generation', value='Что делать, Макс?')
50
  button = st.button('Go')
51
 
52
  if button:
53
  #try:
54
  with st.spinner("Generation in progress"):
55
- result = predict(text, model, tokenizer)
 
 
 
 
 
56
 
57
  #st.subheader('Max Korzh:')
58
  #lines = result.split('\n')
 
3
  import tokenizers
4
  import streamlit as st
5
  import re
6
+
7
  from PIL import Image
8
+ from huggingface_hub import hf_hub_download
9
 
10
 
11
  @st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None, re.Pattern: lambda _: None}, allow_output_mutation=True, suppress_st_warning=True)
 
39
  return list(map(tokenizer.decode, out))[0]
40
 
41
 
42
+ medium_model, medium_tokenizer = get_model('sberbank-ai/rugpt3medium_based_on_gpt2', 'korzh-medium_best_eval_loss.bin')
43
+ large_model, large_tokenizer = get_model('sberbank-ai/rugpt3large_based_on_gpt2', 'korzh-large_best_eval_loss.bin')
44
 
45
  # st.title("NeuroKorzh")
46
 
47
  image = Image.open('korzh.jpg')
48
  st.image(image, caption='NeuroKorzh')
49
 
50
+ option = st.selectbox('Model to be used', ('medium', 'large'))
51
+
52
  st.markdown("\n")
53
 
54
+ text = st.text_area(label='Starting point for text generation', value='Что делать, Макс?', height=200)
55
  button = st.button('Go')
56
 
57
  if button:
58
  #try:
59
  with st.spinner("Generation in progress"):
60
+ if option == 'medium':
61
+ result = predict(text, medium_model, medium_tokenizer)
62
+ elif option == 'large':
63
+ result = predict(text, large_model, large_tokenizer)
64
+ else:
65
+ raise st.error('Error in selectbox')
66
 
67
  #st.subheader('Max Korzh:')
68
  #lines = result.split('\n')