elapt1c committed on
Commit ecd8dab · verified · 1 Parent(s): 0e9df1a

Update app.py

Files changed (1)
  1. app.py +41 -43
app.py CHANGED
@@ -24,70 +24,68 @@ def create_formatted_history(history_messages: List[dict]) -> List[Tuple[str, st
     formatted_history = []
     user_messages = []
     assistant_messages = []
-
     for message in history_messages:
         if message["role"] == "user":
             user_messages.append(message["content"])
         elif message["role"] == "assistant":
             assistant_messages.append(message["content"])
-
         if user_messages and assistant_messages:
             formatted_history.append(
                 ("".join(user_messages), "".join(assistant_messages))
             )
             user_messages = []
             assistant_messages = []
-
     # Append any remaining messages
     if user_messages:
         formatted_history.append(("".join(user_messages), None))
     elif assistant_messages:
         formatted_history.append((None, "".join(assistant_messages)))
-
     return formatted_history
 
-def chat(message: str, state: List[Dict[str, str]]) -> Generator[Tuple[List[Tuple[str, str]], List[Dict[str, str]]], None, None]:
-    history_messages = state
-    if history_messages == None:
-        history_messages = []
-        history_messages.append({"role": "system", "content": "A helpful assistant."})
-
-    history_messages.append({"role": "user", "content": message})
-    history_messages.append({"role": "assistant", "content": ""})
-
-    # Tokenize user input and prepare input tensor
-    input_ids = tokenizer.encode(message, return_tensors='pt').to(device)
-
-    if input_ids.size(-1) == 0:
-        response_message = "Input was empty after tokenization. Please try again."
-    else:
-        # Generate tokens one by one
-        with torch.no_grad():
-            for _ in range(100):  # Limit generation to 100 tokens
-                outputs = model(input_ids)
-                next_token_logits = outputs.logits[:, -1, :]
-                next_token_id = torch.argmax(next_token_logits, dim=-1)
-                input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
-
-                # Decode and append the latest token
-                decoded_token = tokenizer.decode(next_token_id)
-                history_messages[-1]["content"] += decoded_token
-
-                # Stop if the model generates the end-of-sequence token
-                if next_token_id.item() == tokenizer.eos_token_id:
-                    break
-
-        response_message = history_messages[-1]["content"]
-
-    formatted_history = create_formatted_history(history_messages)
-    yield formatted_history, history_messages
+class ConversationHistory:
+    def __init__(self):
+        self.messages: List[Tuple[str, str]] = []  # Stores conversation history
+
+    def append(self, user_message: str, assistant_message: str):
+        self.messages.append((user_message, assistant_message))
+
+    def get_formatted_history(self):
+        return create_formatted_history(create_history_messages(self.messages))
+
+def chat(message: str, conversation_history: ConversationHistory) -> Generator[Tuple[List[Tuple[str, str]], ConversationHistory], None, None]:
+    # Update history
+    conversation_history.append(message, "")
+
+    # Tokenize user input and prepare input tensor
+    input_ids = tokenizer.encode(message, return_tensors='pt').to(device)
+    if input_ids.size(-1) == 0:
+        response_message = "Input was empty after tokenization. Please try again."
+    else:
+        # Generate tokens one by one
+        with torch.no_grad():
+            for _ in range(100):  # Limit generation to 100 tokens
+                outputs = model(input_ids)
+                next_token_logits = outputs.logits[:, -1, :]
+                next_token_id = torch.argmax(next_token_logits, dim=-1)
+                input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
+                # Decode and append the latest token
+                decoded_token = tokenizer.decode(next_token_id)
+                conversation_history.messages[-1] = (conversation_history.messages[-1][0], decoded_token)
+                # Stop if the model generates the end-of-sequence token
+                if next_token_id.item() == tokenizer.eos_token_id:
+                    break
+        response_message = conversation_history.messages[-1][1]
+
+    # Yield formatted history and updated conversation history
+    yield conversation_history.get_formatted_history(), conversation_history
 
 chatbot = gr.Chatbot(label="Chat")
+conversation_history = ConversationHistory()  # Initialize conversation history
+
 iface = gr.Interface(
     fn=chat,
-    inputs=[gr.Textbox(placeholder="Hello! How are you?", label="Message"), "state"],
-    outputs=[chatbot, "state"],
+    inputs=[gr.Textbox(placeholder="Hello! How are you?", label="Message")],
+    outputs=[chatbot, conversation_history],
     allow_flagging="never",
 )
-
-iface.queue().launch()
+iface.queue().launch()
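
For reference, a minimal sketch of how create_formatted_history pairs role-tagged messages into Gradio chat tuples. It is shown as a doctest-style example rather than a live import, since importing app.py would also launch the Gradio app:

# Behavior sketch for create_formatted_history (defined in app.py above).
history_messages = [
    {"role": "system", "content": "A helpful assistant."},  # system turns match no branch
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "How are you?"},  # trailing user turn with no reply yet
]
# create_formatted_history(history_messages) returns:
#     [("Hi", "Hello!"), ("How are you?", None)]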
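The new get_formatted_history calls a create_history_messages helper that is not shown in this hunk. A plausible sketch of such a helper (hypothetical, not the author's code) would invert the tuple format back into role-tagged dicts:

# Hypothetical helper assumed by get_formatted_history; not part of this diff.
from typing import Dict, List, Tuple

def create_history_messages(pairs: List[Tuple[str, str]]) -> List[Dict[str, str]]:
    messages: List[Dict[str, str]] = []
    for user_message, assistant_message in pairs:
        if user_message:
            messages.append({"role": "user", "content": user_message})
        if assistant_message:
            messages.append({"role": "assistant", "content": assistant_message})
    return messages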
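The generation loop in chat is plain greedy decoding: take the argmax of the last position's logits, append that token, and stop at EOS. A self-contained sketch of the same technique, where "gpt2" is a placeholder model and not necessarily the one app.py loads:

# Standalone greedy-decoding sketch; "gpt2" is a placeholder model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

input_ids = tokenizer.encode("Hello! How are you?", return_tensors="pt").to(device)
generated = ""
with torch.no_grad():
    for _ in range(100):  # cap at 100 new tokens
        next_token_logits = model(input_ids).logits[:, -1, :]    # logits for the next position
        next_token_id = torch.argmax(next_token_logits, dim=-1)  # greedy choice, shape (1,)
        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
        generated += tokenizer.decode(next_token_id)              # accumulate decoded text
        if next_token_id.item() == tokenizer.eos_token_id:        # stop at end-of-sequence
            break
print(generated)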
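For comparison, the "state" wiring removed above follows Gradio's usual session-state pattern, sketched here with a toy echo function standing in for chat:

# Conventional Gradio "state" threading, for comparison with the removed wiring;
# echo_bot is a toy stand-in for the real chat function.
import gradio as gr

def echo_bot(message, history):
    history = history or []
    history.append((message, f"You said: {message}"))
    return history, history  # chatbot display, updated state

iface = gr.Interface(
    fn=echo_bot,
    inputs=[gr.Textbox(label="Message"), "state"],
    outputs=[gr.Chatbot(label="Chat"), "state"],
    allow_flagging="never",
)
iface.queue().launch()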