Spaces:
Running
Running
import streamlit as st | |
from transformers import GPT2Tokenizer, GPT2LMHeadModel | |
model_name = 'KhantKyaw/Chat_GPT-2'
special_tokens_dict = {
    'bos_token': '<BOS>',
    'eos_token': '<EOS>',
    'sep_token': '<SEP>',
    'pad_token': '<PAD>',
}


@st.cache_resource
def _load_tokenizer_and_model(name):
    """Load the tokenizer and model for ``name`` and register the special tokens.

    Streamlit re-executes this script on every user interaction, so the
    load is wrapped in ``st.cache_resource`` to avoid calling
    ``from_pretrained`` again on each rerun.

    Args:
        name: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple with the special tokens added and the
        model's token embeddings resized to the tokenizer length.
    """
    loaded_tokenizer = GPT2Tokenizer.from_pretrained(name)
    loaded_model = GPT2LMHeadModel.from_pretrained(name)
    # Add the special tokens first, then resize the embeddings so the model
    # matches the (possibly larger) tokenizer vocabulary.
    loaded_tokenizer.add_special_tokens(special_tokens_dict)
    loaded_model.resize_token_embeddings(len(loaded_tokenizer))
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_tokenizer_and_model(model_name)
def generate_response(input_text):
    """Generate a chatbot reply for ``input_text``.

    The prompt is tokenized as-is (no ``[Bot]`` marker is appended to it),
    a continuation is sampled from the module-level ``model``, and the
    decoded text is post-processed to isolate the bot's part.

    Args:
        input_text: The user's message.

    Returns:
        The decoded text following the first ``[Bot]`` marker when the marker
        is present; otherwise the full decoded text, unchanged.
    """
    inputs = tokenizer(input_text, return_tensors="pt")

    output_sequences = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        # NOTE(review): per the transformers docs, max_length bounds prompt
        # plus generated tokens; consider max_new_tokens for long prompts.
        max_length=100,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        no_repeat_ngram_size=2,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
    )

    full_generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)

    # partition() reports whether the marker was found. The previous
    # find() + len() arithmetic evaluated to 4 on a miss and silently cut
    # the first four characters off the reply.
    _, marker, bot_response = full_generated_text.partition('[Bot]')
    if not marker:
        return full_generated_text
    return bot_response
# The tokenizer and model are already created near the top of this file, where
# the special tokens are added and the embeddings resized. They are not loaded
# again here: a second from_pretrained() would double the load cost and rebind
# `tokenizer`/`model` to objects without that special-token setup.

st.title("Chat_GPT-2 Bot")

# Initialize chat history on the first run of a session.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Streamlit reruns the script on each interaction, so replay the stored
# history to keep earlier messages on screen.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# React to user input.
if prompt := st.chat_input("What is up?"):
    # Show the user's message and record it in the history.
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    response = generate_response(prompt)

    # Show the assistant's reply and record it in the history.
    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})