Spaces:
Sleeping
Sleeping
import gradio as gr | |
import spaces | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import re | |
# Название модели | |
model_name = "t-bank-ai/ruDialoGPT-medium" | |
# Загрузка токенизатора и модели | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
model = AutoModelForCausalLM.from_pretrained(model_name) | |
# Инициализация истории диалога | |
chat_history = [] | |
# Функция генерации с историей | |
# Декоратор ZeroGPU для выделения GPU на 60 секунд | |
def generate_response(prompt): | |
global chat_history | |
# Формирование контекста с предыдущими сообщениями | |
dialogue_context = "" | |
for i, (sender, message) in enumerate(chat_history[-6:]): # Берем последние 6 реплик для контекста | |
prefix = "@@ПЕРВЫЙ@@" if sender == "Ты" else "@@ВТОРОЙ@@" | |
dialogue_context += f"{prefix} {message} " | |
# Добавляем текущий запрос с меткой | |
dialogue_context += f"@@ПЕРВЫЙ@@ {prompt} @@ВТОРОЙ@@" | |
# Преобразуем в тензор и отправляем в модель | |
inputs = tokenizer(dialogue_context, return_tensors="pt") | |
generated_token_ids = model.generate( | |
**inputs, | |
top_k=10, | |
top_p=0.95, | |
num_beams=3, | |
num_return_sequences=1, | |
do_sample=True, | |
no_repeat_ngram_size=2, | |
temperature=1.2, | |
repetition_penalty=1.2, | |
length_penalty=1.0, | |
eos_token_id=50257, | |
max_new_tokens=40 | |
) | |
# Декодируем и очищаем ответ от тегов | |
response = tokenizer.decode(generated_token_ids[0], skip_special_tokens=True) | |
cleaned_response = re.sub(r'@@ПЕРВЫЙ@@|@@ВТОРОЙ@@', '', response).strip() | |
# Добавляем текущий запрос и ответ в историю для отображения в чате | |
chat_history.append(("Ты", prompt)) # Реплика пользователя | |
chat_history.append(("Бот", cleaned_response)) # Ответ бота | |
return chat_history | |
# Интерфейс Gradio с отключённым live-режимом (будет кнопка отправки) | |
iface = gr.Interface( | |
fn=generate_response, | |
inputs="text", | |
outputs="chatbot", # Вывод в виде чата | |
title="ruDialoGPT Chatbot с Историей", | |
live=False # Отключаем live-режим для кнопки отправки | |
) | |
iface.launch() |