|
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast |
|
import torch |
|
|
|
def initialize_chatbot(model_name="facebook/mbart-large-50-many-to-many-mmt"):
    """Load the mBART-50 seq2seq model and its matching tokenizer.

    Args:
        model_name: Hugging Face Hub id (or local path) of the checkpoint.
            Defaults to the many-to-many mBART-50 checkpoint. The same id is
            used for both the model and the tokenizer so the two can never
            be loaded from different checkpoints.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = MBartForConditionalGeneration.from_pretrained(model_name)
    tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
    return model, tokenizer
|
|
|
# mBART-50 identifies languages with fairseq-style codes such as "es_XX",
# while the UI layer works with short ISO 639-1 codes ("es").
MBART50_LANG_CODES = {
    "en": "en_XX",
    "es": "es_XX",
    "fr": "fr_XX",
}


def to_mbart_lang_code(lang_code):
    """Return the mBART-50 language code for *lang_code*.

    Short codes listed in ``MBART50_LANG_CODES`` are expanded (``"es"`` ->
    ``"es_XX"``). Anything else, including codes that are already in
    mBART-50 form such as ``"de_DE"``, is returned unchanged.
    """
    return MBART50_LANG_CODES.get(lang_code, lang_code)


def get_chatbot_response(model, tokenizer, prompt, src_lang, max_length=100):
    """Generate a reply to *prompt* in the same language as the prompt.

    Args:
        model: An mBART seq2seq model (as returned by ``initialize_chatbot``).
        tokenizer: The matching ``MBart50TokenizerFast``; its ``src_lang``
            attribute is overwritten by this call.
        prompt: User text. Empty or whitespace-only prompts yield ``""``
            without running the model.
        src_lang: Language of the prompt. Accepts a short code (``"es"``) or
            an mBART-50 code (``"es_XX"``).
        max_length: Upper bound on the generated sequence length in tokens.

    Returns:
        The decoded text of the first generated sequence.

    NOTE(review): the default checkpoint is a translation model. With source
    and target set to the same language, the "reply" is essentially a
    re-generation of the prompt, not a conversational answer.
    """
    if not prompt or not prompt.strip():
        return ""

    lang = to_mbart_lang_code(src_lang)
    # The tokenizer prepends the language token selected by src_lang; it must
    # be a real mBART-50 code, otherwise it maps to an unknown token.
    tokenizer.src_lang = lang

    # Truncate so over-long prompts don't exceed the model's input limit, and
    # move the tensors to wherever the model lives (CPU or GPU).
    encoded_input = tokenizer(prompt, return_tensors="pt", truncation=True)
    encoded_input = encoded_input.to(model.device)

    generated_tokens = model.generate(
        **encoded_input,
        # The many-to-many checkpoint chooses its output language from the
        # first decoder token; without forcing it, the language is undefined.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(lang),
        max_length=max_length,
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
|
|
|
def display_chatbot_interface(lang_code):
    """Render the multilingual chatbot page.

    The page:
      * draws a localized header;
      * lazily loads the model/tokenizer into ``st.session_state``;
      * replays the stored conversation;
      * when the user submits a message, generates a reply, appends both
        turns to ``st.session_state.messages`` and persists the history
        with ``store_chat_history``.

    Args:
        lang_code: UI language code, one of ``'es'``, ``'en'`` or ``'fr'``.
            Unsupported codes fall back to English.

    NOTE(review): ``st`` (streamlit) and ``store_chat_history`` are not
    imported or defined in the visible code. Confirm they are provided
    elsewhere in the module.
    """
    translations = {
        'es': {
            'title': "AIdeaText - Chatbot Multilingüe",
            'input_placeholder': "Escribe tu mensaje aquí...",
            'send_button': "Enviar",
        },
        'en': {
            'title': "AIdeaText - Multilingual Chatbot",
            'input_placeholder': "Type your message here...",
            'send_button': "Send",
        },
        'fr': {
            'title': "AIdeaText - Chatbot Multilingue",
            'input_placeholder': "Écrivez votre message ici...",
            'send_button': "Envoyer",
        },
    }

    # Fall back to English instead of raising KeyError. The code is also
    # forwarded to the model below, so normalize it once here to keep the UI
    # language and the generation language consistent.
    if lang_code not in translations:
        lang_code = 'en'
    t = translations[lang_code]

    st.header(t['title'])

    # Load the model once and keep it in session state so it is not reloaded
    # on every render of this page.
    if 'chatbot' not in st.session_state:
        st.session_state.chatbot, st.session_state.tokenizer = initialize_chatbot()

    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Redraw the whole stored conversation; the page is rebuilt from session
    # state on each render.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    prompt = st.chat_input(t['input_placeholder'])
    if not prompt:
        return

    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        response = get_chatbot_response(
            st.session_state.chatbot,
            st.session_state.tokenizer,
            prompt,
            lang_code,
        )
        st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})

    # Persist only after a new exchange, so a plain page render without input
    # does not trigger a write.
    # NOTE(review): assumes st.session_state.username was set by a login flow
    # elsewhere; accessing it raises if it is missing. Confirm.
    store_chat_history(st.session_state.username, st.session_state.messages)