File size: 3,589 Bytes
c38b9d4
 
 
 
 
 
 
6b7f048
 
c38b9d4
 
 
 
 
8b4a0aa
 
 
 
c38b9d4
8fc09a5
 
 
 
 
 
22eee74
 
c38b9d4
 
 
 
 
6b7f048
c38b9d4
 
8fc09a5
 
1425e9d
 
 
 
293fc6f
1425e9d
 
 
 
 
 
 
 
 
 
 
8fc09a5
 
84c93e4
8fc09a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ef3cc8
 
8fc09a5
 
 
 
c38b9d4
8b4a0aa
c38b9d4
 
 
 
 
 
6b7f048
8fc09a5
c38b9d4
 
8fc09a5
25a0fee
8fc09a5
 
 
 
25a0fee
c54be94
c38b9d4
8fc09a5
c38b9d4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import streamlit as st
import pandas as pd
import numpy as np
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from transformers import AutoModelForCausalLM, AutoTokenizer


# Настройка конфигурации страницы Streamlit
st.set_page_config(
    page_title="Generate reviews",
    initial_sidebar_state="expanded"
)

# Заголовок приложения
st.title("Генератор отзывов на основе ИИ")
st.write("Создайте уникальные текстовые отзывы о различных местах на основе категорий, рейтинга и ключевых слов.")


def download_nltk_data():
    nltk.download('punkt')
    nltk.download('punkt_tab')


# Загрузка модели и токенизатора
# @st.cache_data()
@st.cache_resource
def get_model():
    # Загрузка модели
    model = AutoModelForCausalLM.from_pretrained('model')
    # Загрузка токенизатора
    tokenizer = AutoTokenizer.from_pretrained('model')
    return model, tokenizer


# Генерация отзыва
def gen_review(input_text):
    model, tokenizer = get_model()
    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    output = model.generate(
        input_ids,
        max_length=200,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        do_sample=True,
        top_p=0.95,
        top_k=60,
        temperature=0.9,
        eos_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)


def capitalize_and_punctuate(text):    
    # Разделяем текст на предложения
    text = text.split(":")[-1].strip()
    sentences = sent_tokenize(text)

    # Проверка последнего предложения
    last_sentence = sentences[-1]
    if not last_sentence.endswith('.'):
        sentences.pop()

    # Обрабатываем оставшиеся предложения
    corrected_sentences = []
    for sentence in sentences:
        words = word_tokenize(sentence)

        # Делаем первую букву первого слова заглавной
        if len(words) > 0:
            words[0] = words[0].capitalize()

        # Собираем обратно предложение
        corrected_sentence = ' '.join(words)
        corrected_sentences.append(corrected_sentence)

    # Объединяем все предложения в единый текст
    final_text = ' '.join(corrected_sentences)
    final_text = final_text.replace(' .', '.')
    
    return final_text


# Главная функция
def main():
    
    if 'btn_predict' not in st.session_state:
        st.session_state['btn_predict'] = False

    category = st.text_input("Категория:", value="Кондитерская")
    rating = st.slider("Рейтинг", 1, 5, 1)
    key_words = st.text_input("Ключевые слова", value="десерт, торт, цена")

    # Ввод новых параметров
    input_text = f"Категория: {category}; Рейтинг: {rating}; Ключевые слова: {key_words} -> Отзыв:"

    if st.button('Generate'):
        with st.spinner('Генерация отзыва...'):
            generated_text = gen_review(input_text)
            generated_text = capitalize_and_punctuate(generated_text)        
        st.success("Готово!")
        st.text(generated_text)


if __name__ == "__main__":
    download_nltk_data()
    main()