Emil25 commited on
Commit
8fc09a5
·
verified ·
1 Parent(s): 566821d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -97
app.py CHANGED
@@ -4,7 +4,6 @@ import numpy as np
4
  import nltk
5
  from nltk.tokenize import sent_tokenize, word_tokenize
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
- import functools
8
 
9
 
10
  # Настройка конфигурации страницы Streamlit
@@ -14,11 +13,14 @@ st.set_page_config(
14
  )
15
 
16
 
17
- @functools.lru_cache(maxsize=None)
 
 
 
 
 
 
18
  def get_model():
19
- """
20
- Кэшируемая функция для загрузки модели и токенизатора.
21
- """
22
  # Загрузка модели
23
  model = AutoModelForCausalLM.from_pretrained('model')
24
  # Загрузка токенизатора
@@ -26,91 +28,13 @@ def get_model():
26
  return model, tokenizer
27
 
28
 
29
- def correct_sentence(sentence):
30
- """
31
- Функция для исправления предложений.
32
- Делает первую букву заглавной.
33
- """
34
- words = word_tokenize(sentence)
35
- if len(words) > 0:
36
- words[0] = words[0].capitalize()
37
- corrected_sentence = ' '.join(words)
38
- return corrected_sentence
39
-
40
-
41
- def process_reviews(reviews):
42
- """
43
- Функция для обработки списка отзывов.
44
- Исправляет каждое предложение в отзывах.
45
- """
46
- corrected_reviews = []
47
- for review in reviews:
48
- sentences = sent_tokenize(review)
49
- corrected_sentences = [correct_sentence(sentence) for sentence in sentences]
50
- corrected_reviews.append(' '.join(corrected_sentences))
51
- return corrected_reviews
52
-
53
-
54
- def load_nltk_data():
55
- """
56
- Функция для загрузки данных NLTK.
57
- """
58
- nltk.download('punkt')
59
- nltk.download('punkt_tab')
60
-
61
-
62
- def preprocess_input(input_text):
63
- """
64
- Функция для предварительной обработки входного текста.
65
- Убирает лишние символы, исправляет предложения.
66
- """
67
- # Удаляем лишние символы и извлекаем текст после двоеточия
68
- print(f" ДО {input_text}")
69
- input_text = input_text.split(":")[-1].strip()
70
- print(f" ПОСЛЕ {input_text}")
71
-
72
- # Токенизация предложений
73
- sentences = sent_tokenize(input_text)
74
-
75
- # Проверка на наличие предложений
76
- if len(sentences) == 0:
77
- return ""
78
-
79
- # Удаляем последнее предложение, если оно не заканчивается на точку
80
- last_sentence = sentences[-1]
81
- if not last_sentence.endswith('.'):
82
- sentences.pop()
83
-
84
- corrected_sentences = []
85
-
86
- # Исправление предложений
87
- for sentence in sentences:
88
- words = word_tokenize(sentence)
89
- if len(words) > 0:
90
- # Капитализация первого слова
91
- words[0] = words[0].capitalize()
92
- corrected_sentence = ' '.join(words)
93
- corrected_sentences.append(corrected_sentence)
94
-
95
- # Объединение исправленных предложений в финальный текст
96
- final_text = ' '.join(corrected_sentences)
97
-
98
- # Добавление точки в конце, если финальный текст не пустой
99
- if final_text and not final_text.endswith('.'):
100
- final_text += '.'
101
-
102
- return final_text
103
-
104
-
105
- def generate_review(input_text):
106
- """
107
- Функция для генерации отзыва на основе входного текста.
108
- """
109
  model, tokenizer = get_model()
110
  input_ids = tokenizer.encode(input_text, return_tensors='pt')
111
  output = model.generate(
112
  input_ids,
113
- max_length=200,
114
  num_return_sequences=1,
115
  no_repeat_ngram_size=2,
116
  do_sample=True,
@@ -122,29 +46,54 @@ def generate_review(input_text):
122
  return tokenizer.decode(output[0], skip_special_tokens=True)
123
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def main():
126
- """
127
- Основная функция приложения Streamlit.
128
- """
129
  if 'btn_predict' not in st.session_state:
130
  st.session_state['btn_predict'] = False
131
 
132
- # Ввод данных пользователем
133
  category = st.text_input("Категория:", value="Кондитерская")
134
  rating = st.slider("Рейтинг", 1, 5, 1)
135
  key_words = st.text_input("Ключевые слова", value="десерт, торт, цена")
136
 
137
- # Формируем входной текст
138
  input_text = f"Категория: {category}; Рейтинг: {rating}; Ключевые слова: {key_words} -> Отзыв:"
139
 
140
- if st.button('Generate'): # Кнопка для генерации отзыва
141
  with st.spinner('Генерация отзыва...'):
142
- processed_input = preprocess_input(input_text)
143
- generated_text = generate_review(processed_input)
144
- st.success("Готово!")
145
- st.text(generated_text)
146
 
147
 
148
  if __name__ == "__main__":
149
- load_nltk_data()
150
  main()
 
4
  import nltk
5
  from nltk.tokenize import sent_tokenize, word_tokenize
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
7
 
8
 
9
  # Настройка конфигурации страницы Streamlit
 
13
  )
14
 
15
 
16
+ def download_nltk_data():
17
+ nltk.download('punkt')
18
+ nltk.download('punkt_tab')
19
+
20
+
21
+ # Загрузка модели и токенизатора
22
+ @st.cache_data()
23
  def get_model():
 
 
 
24
  # Загрузка модели
25
  model = AutoModelForCausalLM.from_pretrained('model')
26
  # Загрузка токенизатора
 
28
  return model, tokenizer
29
 
30
 
31
+ # Генерация отзыва
32
+ def gen_review(input_text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  model, tokenizer = get_model()
34
  input_ids = tokenizer.encode(input_text, return_tensors='pt')
35
  output = model.generate(
36
  input_ids,
37
+ max_length=300,
38
  num_return_sequences=1,
39
  no_repeat_ngram_size=2,
40
  do_sample=True,
 
46
  return tokenizer.decode(output[0], skip_special_tokens=True)
47
 
48
 
49
+ def capitalize_and_punctuate(text):
50
+ # Разделяем текст на предложения
51
+ sentences = sent_tokenize(text)
52
+
53
+ # Проверка последнего предложения
54
+ last_sentence = sentences[-1]
55
+ if not last_sentence.endswith('.'):
56
+ sentences.pop()
57
+
58
+ # Обрабатываем оставшиеся предложения
59
+ corrected_sentences = []
60
+ for sentence in sentences:
61
+ words = word_tokenize(sentence)
62
+
63
+ # Делаем первую букву первого слова заглавной
64
+ if len(words) > 0:
65
+ words[0] = words[0].capitalize()
66
+
67
+ # Собираем обратно предложение
68
+ corrected_sentence = ' '.join(words)
69
+ corrected_sentences.append(corrected_sentence)
70
+
71
+ # Объединяем все предложения в единый текст
72
+ final_text = ' '.join(corrected_sentences)
73
+
74
+ return final_text
75
+
76
+
77
+ # Главная функция
78
  def main():
 
 
 
79
  if 'btn_predict' not in st.session_state:
80
  st.session_state['btn_predict'] = False
81
 
 
82
  category = st.text_input("Категория:", value="Кондитерская")
83
  rating = st.slider("Рейтинг", 1, 5, 1)
84
  key_words = st.text_input("Ключевые слова", value="десерт, торт, цена")
85
 
86
+ # Ввод новых параметров
87
  input_text = f"Категория: {category}; Рейтинг: {rating}; Ключевые слова: {key_words} -> Отзыв:"
88
 
89
+ if st.button('Generate'):
90
  with st.spinner('Генерация отзыва...'):
91
+ generated_text = gen_review(input_text)
92
+ generated_text = capitalize_and_punctuate(generated_text)
93
+ st.success("Готово!")
94
+ st.text(generated_text)
95
 
96
 
97
  if __name__ == "__main__":
98
+ download_nltk_data()
99
  main()