IvaElen commited on
Commit
6b51df2
·
1 Parent(s): 6eea4a4

Update pages/GPT.py

Browse files
Files changed (1) hide show
  1. pages/GPT.py +6 -0
pages/GPT.py CHANGED
@@ -5,6 +5,7 @@ import transformers
5
  import random
6
  import textwrap
7
 
 
8
  def load_model():
9
  model_finetuned = transformers.AutoModelWithLMHead.from_pretrained(
10
  'tinkoff-ai/ruDialoGPT-small',
@@ -17,6 +18,7 @@ def load_model():
17
 
18
  def preprocess_text(text_input, tokenizer):
19
  prompt = tokenizer.encode(text_input, return_tensors='pt')
 
20
 
21
  def predict_sentiment(model, prompt, temp, num_generate):
22
  with torch.inference_mode():
@@ -31,6 +33,7 @@ def predict_sentiment(model, prompt, temp, num_generate):
31
  no_repeat_ngram_size=3,
32
  num_return_sequences=num_generate,
33
  ).cpu().numpy()
 
34
  return result
35
 
36
  st.title('Text generation with dreambook')
@@ -38,8 +41,11 @@ st.title('Text generation with dreambook')
38
  model, tokenizer = load_model()
39
 
40
  text_input = st.text_input("Enter some text about movie")
 
41
  max_len = st.slider('Length of sequence', 0, 500, 250)
 
42
  temp = st.slider('Temperature', 1, 30, 1)
 
43
  if st.button('Generate a random number of sequences'):
44
  num_generate = random.randint(1,5)
45
  st.write(f'Number of sequences: {num_generate}')
 
5
  import random
6
  import textwrap
7
 
8
+ @st.cache
9
  def load_model():
10
  model_finetuned = transformers.AutoModelWithLMHead.from_pretrained(
11
  'tinkoff-ai/ruDialoGPT-small',
 
18
 
19
  def preprocess_text(text_input, tokenizer):
20
  prompt = tokenizer.encode(text_input, return_tensors='pt')
21
+ return prompt
22
 
23
  def predict_sentiment(model, prompt, temp, num_generate):
24
  with torch.inference_mode():
 
33
  no_repeat_ngram_size=3,
34
  num_return_sequences=num_generate,
35
  ).cpu().numpy()
36
+ print(result)
37
  return result
38
 
39
  st.title('Text generation with dreambook')
 
41
  model, tokenizer = load_model()
42
 
43
  text_input = st.text_input("Enter some text about movie")
44
+ print(text_input)
45
  max_len = st.slider('Length of sequence', 0, 500, 250)
46
+ print(max_len)
47
  temp = st.slider('Temperature', 1, 30, 1)
48
+ print(temp)
49
  if st.button('Generate a random number of sequences'):
50
  num_generate = random.randint(1,5)
51
  st.write(f'Number of sequences: {num_generate}')