SaviAnna commited on
Commit
1c8cc34
·
verified ·
1 Parent(s): 17a60a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -27,7 +27,7 @@ def load_custom_model():
27
  # Функция для генерации текста
28
  def generate_text(model, tokenizer, prompt, max_len, temperature):
29
  input_ids = tokenizer.encode(prompt, return_tensors='pt')
30
-
31
  # Генерация текста
32
  output = model.generate(
33
  input_ids,
@@ -38,6 +38,7 @@ def generate_text(model, tokenizer, prompt, max_len, temperature):
38
  repetition_penalty=1.2, # Штраф за повторение слов или фраз
39
  no_repeat_ngram_size=4, # Запрет на повторение n-грамм (например, биграмм)
40
  do_sample=True, # Включение сэмплинга для большей разнообразности
 
41
  pad_token_id=tokenizer.eos_token_id
42
  )
43
 
@@ -64,12 +65,12 @@ def main():
64
 
65
  # Блок для генерации текста с кастомной моделью
66
  st.subheader("Custom Model Text Generation")
67
- prompt_custom = st.text_area("Введите фразу для генерации с кастомной моделью:", value="Когда-то давно")
68
- generate_button_custom = st.button("Сгенерировать текст с кастомной моделью")
69
 
70
  if generate_button_custom:
71
  generated_text_custom = generate_text(model_custom, tokenizer_custom, prompt_custom, max_len, temperature)
72
- st.subheader("Результат генерации с кастомной моделью:")
73
  st.write(generated_text_custom)
74
 
75
  if __name__ == "__main__":
 
27
  # Функция для генерации текста
28
  def generate_text(model, tokenizer, prompt, max_len, temperature):
29
  input_ids = tokenizer.encode(prompt, return_tensors='pt')
30
+ attention_mask = (input_ids != tokenizer.pad_token_id).long()
31
  # Генерация текста
32
  output = model.generate(
33
  input_ids,
 
38
  repetition_penalty=1.2, # Штраф за повторение слов или фраз
39
  no_repeat_ngram_size=4, # Запрет на повторение n-грамм (например, биграмм)
40
  do_sample=True, # Включение сэмплинга для большей разнообразности
41
+ attention_mask=attention_mask
42
  pad_token_id=tokenizer.eos_token_id
43
  )
44
 
 
65
 
66
  # Блок для генерации текста с кастомной моделью
67
  st.subheader("Custom Model Text Generation")
68
+ prompt_custom = st.text_area("Enter a phrase to generate with the updated model:", value="Когда-то давно")
69
+ generate_button_custom = st.button("Generate!")
70
 
71
  if generate_button_custom:
72
  generated_text_custom = generate_text(model_custom, tokenizer_custom, prompt_custom, max_len, temperature)
73
+ st.subheader("Result:")
74
  st.write(generated_text_custom)
75
 
76
  if __name__ == "__main__":