test2 / modules /chatbot.py
AIdeaText's picture
Update modules/chatbot.py
6a9fc93 verified
raw
history blame
2.3 kB
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch
def initialize_chatbot(model_name="facebook/mbart-large-50-many-to-many-mmt"):
    """Load the seq2seq model and tokenizer used by the chatbot.

    Args:
        model_name: Hugging Face model identifier (or local path) handed to
            both ``from_pretrained`` calls. Defaults to the multilingual
            mBART-50 many-to-many checkpoint, so existing callers that pass
            no argument get the same model as before.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Load both pieces from the same checkpoint so the vocabulary always
    # matches the model weights.
    model = MBartForConditionalGeneration.from_pretrained(model_name)
    tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
    return model, tokenizer
# Short UI language codes mapped to the language tokens mBART-50 expects.
# Codes not listed here are passed through unchanged, so callers may also
# supply a full mBART code (e.g. "de_DE") directly.
_MBART_LANG_CODES = {
    "es": "es_XX",
    "en": "en_XX",
    "fr": "fr_XX",
}


def get_chatbot_response(model, tokenizer, prompt, src_lang, tgt_lang=None):
    """Generate a text response for *prompt* with the given model/tokenizer.

    Args:
        model: Model exposing ``generate(**inputs, ...)``.
        tokenizer: Tokenizer matching *model*; its ``src_lang`` attribute is
            overwritten by this call.
        prompt: User message to encode.
        src_lang: Language of *prompt*, either a short code ("es", "en",
            "fr") or a full mBART-50 code such as "es_XX".
        tgt_lang: Language of the generated text, in the same formats as
            *src_lang*. Defaults to the source language.

    Returns:
        The first decoded sequence, with special tokens stripped.
    """
    source_code = _MBART_LANG_CODES.get(src_lang, src_lang)
    if tgt_lang is None:
        target_code = source_code
    else:
        target_code = _MBART_LANG_CODES.get(tgt_lang, tgt_lang)

    tokenizer.src_lang = source_code
    encoded_input = tokenizer(prompt, return_tensors="pt")
    generated_tokens = model.generate(
        **encoded_input,
        max_length=100,
        # The many-to-many mBART-50 checkpoint selects its output language
        # from the first generated token, so it has to be forced explicitly.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_code),
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
def display_chatbot_interface(lang_code):
    """Render the chat UI and handle one user/assistant exchange.

    Draws the header, replays the messages kept in ``st.session_state``,
    and, when the user submits a prompt, generates a reply with
    ``get_chatbot_response`` and passes the updated history to
    ``store_chat_history``.

    Args:
        lang_code: UI language key ('es', 'en' or 'fr'). Any other value
            falls back to the English interface strings; the value itself
            is still forwarded unchanged to ``get_chatbot_response``.
    """
    # NOTE(review): `st` (presumably Streamlit) and `store_chat_history` are
    # not imported in the visible part of this file - confirm they are
    # brought into scope elsewhere.
    translations = {
        'es': {
            'title': "AIdeaText - Chatbot Multilingüe",
            'input_placeholder': "Escribe tu mensaje aquí...",
            'send_button': "Enviar",
        },
        'en': {
            'title': "AIdeaText - Multilingual Chatbot",
            'input_placeholder': "Type your message here...",
            'send_button': "Send",
        },
        'fr': {
            'title': "AIdeaText - Chatbot Multilingue",
            'input_placeholder': "Écrivez votre message ici...",
            'send_button': "Envoyer",
        },
    }

    # Use the English strings instead of raising KeyError when the
    # requested language has no translation table.
    t = translations.get(lang_code, translations['en'])

    st.header(t['title'])

    # Load the model and tokenizer once per session and keep them in the
    # session state so later reruns reuse them.
    if 'chatbot' not in st.session_state:
        st.session_state.chatbot, st.session_state.tokenizer = initialize_chatbot()

    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input(t['input_placeholder']):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            response = get_chatbot_response(
                st.session_state.chatbot,
                st.session_state.tokenizer,
                prompt,
                lang_code,
            )
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})

        # Save the conversation to the database.
        store_chat_history(st.session_state.username, st.session_state.messages)