Emil25 commited on
Commit
1425e9d
·
verified ·
1 Parent(s): 6d9443c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -36
app.py CHANGED
@@ -5,6 +5,7 @@ import nltk
5
  from nltk.tokenize import sent_tokenize, word_tokenize
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from nltk.data import find
 
8
 
9
 
10
  # Настройка конфигурации страницы Streamlit
@@ -14,21 +15,7 @@ st.set_page_config(
14
  )
15
 
16
 
17
- def download_nltk_data():
18
- try:
19
- # Проверяем, установлены ли данные
20
- find('tokenizers/punkt')
21
- find('tokenizers/punkt_tab')
22
- print("Данные уже загружены.")
23
- except LookupError:
24
- # Если данные не найдены, загружаем их
25
- print("Загрузка данных NLTK...")
26
- nltk.download('punkt')
27
- nltk.download('punkt_tab')
28
-
29
-
30
- # Загрузка модели и токенизатора
31
- @st.cache_data()
32
  def get_model():
33
  # Загрузка модели
34
  model = AutoModelForCausalLM.from_pretrained('model')
@@ -37,7 +24,7 @@ def get_model():
37
  return model, tokenizer
38
 
39
 
40
- # Генерация отзыва
41
  def gen_review(input_text):
42
  model, tokenizer = get_model()
43
  input_ids = tokenizer.encode(input_text, return_tensors='pt')
@@ -55,36 +42,73 @@ def gen_review(input_text):
55
  return tokenizer.decode(output[0], skip_special_tokens=True)
56
 
57
 
58
- def capitalize_and_punctuate(text):
59
- download_nltk_data()
60
- # Разделяем текст на предложения
61
- sentences = sent_tokenize(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- # Проверка последнего предложения
 
 
 
 
 
 
 
 
 
 
 
 
64
  last_sentence = sentences[-1]
65
  if not last_sentence.endswith('.'):
66
  sentences.pop()
67
-
68
- # Обрабатываем оставшиеся предложения
69
  corrected_sentences = []
70
  for sentence in sentences:
71
  words = word_tokenize(sentence)
72
-
73
- # Делаем первую букву первого слова заглавной
74
  if len(words) > 0:
75
  words[0] = words[0].capitalize()
76
-
77
- # Собираем обратно предложение
78
  corrected_sentence = ' '.join(words)
79
  corrected_sentences.append(corrected_sentence)
80
-
81
- # Объединяем все предложения в единый текст
82
  final_text = ' '.join(corrected_sentences)
83
-
84
  return final_text
85
 
86
 
87
- # Главная функция
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def main():
89
  if 'btn_predict' not in st.session_state:
90
  st.session_state['btn_predict'] = False
@@ -98,11 +122,10 @@ def main():
98
 
99
  if st.button('Generate'):
100
  with st.spinner('Генерация отзыва...'):
101
- generated_text = gen_review(input_text)
102
- generated_text = capitalize_and_punctuate(generated_text)
103
- st.success("Готово!")
104
- st.text(generated_text)
105
-
106
 
107
  if __name__ == "__main__":
108
  main()
 
5
  from nltk.tokenize import sent_tokenize, word_tokenize
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from nltk.data import find
8
+ import functools
9
 
10
 
11
  # Настройка конфигурации страницы Streamlit
 
15
  )
16
 
17
 
18
+ @functools.lru_cache(maxsize=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def get_model():
20
  # Загрузка модели
21
  model = AutoModelForCausalLM.from_pretrained('model')
 
24
  return model, tokenizer
25
 
26
 
27
+ @functools.lru_cache(maxsize=None)
28
  def gen_review(input_text):
29
  model, tokenizer = get_model()
30
  input_ids = tokenizer.encode(input_text, return_tensors='pt')
 
42
  return tokenizer.decode(output[0], skip_special_tokens=True)
43
 
44
 
45
+ def correct_sentence(sentence):
46
+ """Функция для исправления предложений."""
47
+ words = word_tokenize(sentence)
48
+
49
+ # Делаем первую букву первого слова заглавной
50
+ if len(words) > 0:
51
+ words[0] = words[0].capitalize()
52
+
53
+ # Собираем обратно предложение
54
+ corrected_sentence = ' '.join(words)
55
+ return corrected_sentence
56
+
57
+
58
+ def process_reviews(reviews):
59
+ """Функция для обработки списка отзывов."""
60
+ corrected_reviews = []
61
+ for review in reviews:
62
+ sentences = sent_tokenize(review)
63
+ corrected_sentences = [correct_sentence(sentence) for sentence in sentences]
64
+ corrected_reviews.append(' '.join(corrected_sentences))
65
+ return corrected_reviews
66
+
67
 
68
+ def load_nltk_data():
69
+ try:
70
+ find('tokenizers/punkt')
71
+ find('tokenizers/punkt_tab')
72
+ print("Данные уже загружены.")
73
+ except LookupError:
74
+ print("Загрузка данных NLTK...")
75
+ nltk.download(['punkt', 'punkt_tab'])
76
+
77
+
78
+ def preprocess_input(input_text):
79
+ input_text = input_text.split(":")[-1].strip()
80
+ sentences = sent_tokenize(input_text)
81
  last_sentence = sentences[-1]
82
  if not last_sentence.endswith('.'):
83
  sentences.pop()
 
 
84
  corrected_sentences = []
85
  for sentence in sentences:
86
  words = word_tokenize(sentence)
 
 
87
  if len(words) > 0:
88
  words[0] = words[0].capitalize()
 
 
89
  corrected_sentence = ' '.join(words)
90
  corrected_sentences.append(corrected_sentence)
 
 
91
  final_text = ' '.join(corrected_sentences)
 
92
  return final_text
93
 
94
 
95
+ def generate_review(input_text):
96
+ model, tokenizer = get_model()
97
+ input_ids = tokenizer.encode(input_text, return_tensors='pt')
98
+ output = model.generate(
99
+ input_ids,
100
+ max_length=300,
101
+ num_return_sequences=1,
102
+ no_repeat_ngram_size=2,
103
+ do_sample=True,
104
+ top_p=0.95,
105
+ top_k=60,
106
+ temperature=0.9,
107
+ eos_token_id=tokenizer.eos_token_id,
108
+ )
109
+ return tokenizer.decode(output[0], skip_special_tokens=True)
110
+
111
+
112
  def main():
113
  if 'btn_predict' not in st.session_state:
114
  st.session_state['btn_predict'] = False
 
122
 
123
  if st.button('Generate'):
124
  with st.spinner('Генерация отзыва...'):
125
+ processed_input = preprocess_input(input_text)
126
+ generated_text = generate_review(processed_input)
127
+ st.success("Готово!")
128
+ st.text(generated_text)
 
129
 
130
  if __name__ == "__main__":
131
  main()