JamalAG committed
Commit 9d8e92e
1 Parent(s): 0221f0f

Update app.py

Files changed (1): app.py (+18, -8)
app.py CHANGED
@@ -1,23 +1,33 @@
 import streamlit as st
-from transformers import pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
 
-# Load the conversational pipeline
-conversational_pipeline = pipeline("conversational")
+# Load DialoGPT model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
+model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
 
 # Streamlit app header
 st.set_page_config(page_title="Conversational Model Demo", page_icon="🤖")
 st.header("Conversational Model Demo")
 
+# Initialize chat history
+chat_history_ids = None
+
 # Input for user message
 user_message = st.text_input("You:", "")
 
 if st.button("Send"):
-    # Format the conversation for the conversational pipeline
-    conversation_history = [{"role": "system", "content": "You are an AI assistant."},
-                            {"role": "user", "content": user_message}]
+    # Encode the new user input, add the eos_token and return a tensor in PyTorch
+    new_user_input_ids = tokenizer.encode(user_message + tokenizer.eos_token, return_tensors='pt')
+
+    # Append the new user input tokens to the chat history
+    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if chat_history_ids is not None else new_user_input_ids
+
+    # Generate a response while limiting the total chat history to 1000 tokens
+    chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
 
-    # Get the model's response
-    model_response = conversational_pipeline(conversation_history)[0]['generated_text']
+    # Pretty print last output tokens from the bot
+    model_response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
 
     # Display the model's response
     st.text_area("Model:", model_response, height=100)
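Note on the new version: Streamlit re-runs app.py from top to bottom on every interaction, so chat_history_ids = None at module level resets the history on each turn, and the multi-turn context that the torch.cat call is meant to accumulate is never actually retained between messages. A minimal sketch of how the history could be persisted across reruns with st.session_state (an adaptation for illustration, not part of this commit):

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

# Keep the accumulated history in session state so reruns don't reset it
if "chat_history_ids" not in st.session_state:
    st.session_state.chat_history_ids = None

user_message = st.text_input("You:", "")

if st.button("Send"):
    new_user_input_ids = tokenizer.encode(user_message + tokenizer.eos_token, return_tensors="pt")
    # Append the new input to the stored history (if any)
    bot_input_ids = (
        torch.cat([st.session_state.chat_history_ids, new_user_input_ids], dim=-1)
        if st.session_state.chat_history_ids is not None
        else new_user_input_ids
    )
    st.session_state.chat_history_ids = model.generate(
        bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
    )
    # Decode only the tokens generated after the prompt
    model_response = tokenizer.decode(
        st.session_state.chat_history_ids[:, bot_input_ids.shape[-1]:][0],
        skip_special_tokens=True,
    )
    st.text_area("Model:", model_response, height=100)

For the same reason, the from_pretrained calls could be wrapped in a cached loader (for example Streamlit's st.cache_resource) so the model and tokenizer are not reloaded on every rerun.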