Emil25 commited on
Commit
6d9443c
·
verified ·
1 Parent(s): 0b757d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -59
app.py CHANGED
@@ -5,7 +5,6 @@ import nltk
5
  from nltk.tokenize import sent_tokenize, word_tokenize
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from nltk.data import find
8
- import functools
9
 
10
 
11
  # Настройка конфигурации страницы Streamlit
@@ -15,7 +14,21 @@ st.set_page_config(
15
  )
16
 
17
 
18
- @functools.lru_cache(maxsize=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def get_model():
20
  # Загрузка модели
21
  model = AutoModelForCausalLM.from_pretrained('model')
@@ -24,7 +37,7 @@ def get_model():
24
  return model, tokenizer
25
 
26
 
27
- @functools.lru_cache(maxsize=None)
28
  def gen_review(input_text):
29
  model, tokenizer = get_model()
30
  input_ids = tokenizer.encode(input_text, return_tensors='pt')
@@ -42,76 +55,39 @@ def gen_review(input_text):
42
  return tokenizer.decode(output[0], skip_special_tokens=True)
43
 
44
 
45
- def correct_sentence(sentence):
46
- """Функция для исправления предложений."""
47
- words = word_tokenize(sentence)
48
-
49
- # Делаем первую букву первого слова заглавной
50
- if len(words) > 0:
51
- words[0] = words[0].capitalize()
52
-
53
- # Собираем обратно предложение
54
- corrected_sentence = ' '.join(words)
55
- return corrected_sentence
56
-
57
-
58
- def process_reviews(reviews):
59
- """Функция для обработки списка отзывов."""
60
- corrected_reviews = []
61
- for review in reviews:
62
- sentences = sent_tokenize(review)
63
- corrected_sentences = [correct_sentence(sentence) for sentence in sentences]
64
- corrected_reviews.append(' '.join(corrected_sentences))
65
- return corrected_reviews
66
-
67
 
68
- def load_nltk_data():
69
- try:
70
- find('tokenizers/punkt')
71
- find('tokenizers/punkt_tab')
72
- print("Данные уже загружены.")
73
- except LookupError:
74
- print("Загрузка данных NLTK...")
75
- nltk.download(['punkt', 'punkt_tab'])
76
-
77
-
78
- def preprocess_input(input_text):
79
- sentences = sent_tokenize(input_text)
80
  last_sentence = sentences[-1]
81
  if not last_sentence.endswith('.'):
82
  sentences.pop()
 
 
83
  corrected_sentences = []
84
  for sentence in sentences:
85
  words = word_tokenize(sentence)
 
 
86
  if len(words) > 0:
87
  words[0] = words[0].capitalize()
 
 
88
  corrected_sentence = ' '.join(words)
89
  corrected_sentences.append(corrected_sentence)
90
- final_text = ' '.join(corrected_sentences)
91
- return final_text
92
 
 
 
93
 
94
- def generate_review(input_text):
95
- model, tokenizer = get_model()
96
- input_ids = tokenizer.encode(input_text, return_tensors='pt')
97
- output = model.generate(
98
- input_ids,
99
- max_length=300,
100
- num_return_sequences=1,
101
- no_repeat_ngram_size=2,
102
- do_sample=True,
103
- top_p=0.95,
104
- top_k=60,
105
- temperature=0.9,
106
- eos_token_id=tokenizer.eos_token_id,
107
- )
108
- return tokenizer.decode(output[0], skip_special_tokens=True)
109
 
110
 
 
111
  def main():
112
  if 'btn_predict' not in st.session_state:
113
  st.session_state['btn_predict'] = False
114
- load_nltk_data()
115
 
116
  category = st.text_input("Категория:", value="Кондитерская")
117
  rating = st.slider("Рейтинг", 1, 5, 1)
@@ -122,10 +98,11 @@ def main():
122
 
123
  if st.button('Generate'):
124
  with st.spinner('Генерация отзыва...'):
125
- processed_input = preprocess_input(input_text)
126
- generated_text = generate_review(processed_input)
127
- st.success("Готово!")
128
- st.text(generated_text)
 
129
 
130
  if __name__ == "__main__":
131
  main()
 
5
  from nltk.tokenize import sent_tokenize, word_tokenize
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from nltk.data import find
 
8
 
9
 
10
  # Настройка конфигурации страницы Streamlit
 
14
  )
15
 
16
 
17
+ def download_nltk_data():
18
+ try:
19
+ # Проверяем, установлены ли данные
20
+ find('tokenizers/punkt')
21
+ find('tokenizers/punkt_tab')
22
+ print("Данные уже загружены.")
23
+ except LookupError:
24
+ # Если данные не найдены, загружаем их
25
+ print("Загрузка данных NLTK...")
26
+ nltk.download('punkt')
27
+ nltk.download('punkt_tab')
28
+
29
+
30
+ # Загрузка модели и токенизатора
31
+ @st.cache_data()
32
  def get_model():
33
  # Загрузка модели
34
  model = AutoModelForCausalLM.from_pretrained('model')
 
37
  return model, tokenizer
38
 
39
 
40
+ # Генерация отзыва
41
  def gen_review(input_text):
42
  model, tokenizer = get_model()
43
  input_ids = tokenizer.encode(input_text, return_tensors='pt')
 
55
  return tokenizer.decode(output[0], skip_special_tokens=True)
56
 
57
 
58
+ def capitalize_and_punctuate(text):
59
+ download_nltk_data()
60
+ # Разделяем текст на предложения
61
+ sentences = sent_tokenize(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
+ # Проверка последнего предложения
 
 
 
 
 
 
 
 
 
 
 
64
  last_sentence = sentences[-1]
65
  if not last_sentence.endswith('.'):
66
  sentences.pop()
67
+
68
+ # Обрабатываем оставшиеся предложения
69
  corrected_sentences = []
70
  for sentence in sentences:
71
  words = word_tokenize(sentence)
72
+
73
+ # Делаем первую букву первого слова заглавной
74
  if len(words) > 0:
75
  words[0] = words[0].capitalize()
76
+
77
+ # Собираем обратно предложение
78
  corrected_sentence = ' '.join(words)
79
  corrected_sentences.append(corrected_sentence)
 
 
80
 
81
+ # Объединяем все предложения в единый текст
82
+ final_text = ' '.join(corrected_sentences)
83
 
84
+ return final_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
 
87
+ # Главная функция
88
  def main():
89
  if 'btn_predict' not in st.session_state:
90
  st.session_state['btn_predict'] = False
 
91
 
92
  category = st.text_input("Категория:", value="Кондитерская")
93
  rating = st.slider("Рейтинг", 1, 5, 1)
 
98
 
99
  if st.button('Generate'):
100
  with st.spinner('Генерация отзыва...'):
101
+ generated_text = gen_review(input_text)
102
+ generated_text = capitalize_and_punctuate(generated_text)
103
+ st.success("Готово!")
104
+ st.text(generated_text)
105
+
106
 
107
  if __name__ == "__main__":
108
  main()