TenzinGayche committed on
Commit
a079f79
·
verified ·
1 Parent(s): eae63f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -0
app.py CHANGED
@@ -37,6 +37,7 @@ def generate(
37
  # Clear the stop event before starting a new generation
38
  stop_event.clear()
39
 
 
40
  conversation = chat_history.copy()
41
  conversation.append({"role": "user", "content": message})
42
 
@@ -46,6 +47,7 @@ def generate(
46
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
47
  input_ids = input_ids.to(model.device)
48
 
 
49
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
50
  generate_kwargs = dict(
51
  {"input_ids": input_ids},
@@ -53,6 +55,7 @@ def generate(
53
  max_new_tokens=max_new_tokens,
54
  )
55
 
 
56
  t = Thread(target=model.generate, kwargs=generate_kwargs)
57
  t.start()
58
 
@@ -63,6 +66,11 @@ def generate(
63
  outputs.append(text)
64
  yield "".join(outputs)
65
 
 
 
 
 
 
66
  # Define a function to stop the generation
67
  def stop_generation():
68
  stop_event.set()
 
37
  # Clear the stop event before starting a new generation
38
  stop_event.clear()
39
 
40
+ # Append the user's message to the conversation history
41
  conversation = chat_history.copy()
42
  conversation.append({"role": "user", "content": message})
43
 
 
47
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
48
  input_ids = input_ids.to(model.device)
49
 
50
+ # Create a streamer to get the generated response
51
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
52
  generate_kwargs = dict(
53
  {"input_ids": input_ids},
 
55
  max_new_tokens=max_new_tokens,
56
  )
57
 
58
+ # Run generation in a background thread
59
  t = Thread(target=model.generate, kwargs=generate_kwargs)
60
  t.start()
61
 
 
66
  outputs.append(text)
67
  yield "".join(outputs)
68
 
69
+ # After generation, append the assistant's response to the chat history
70
+ assistant_response = "".join(outputs)
71
+ chat_history.append({"role": "assistant", "content": assistant_response})
72
+
73
+
74
  # Define a function to stop the generation
75
  def stop_generation():
76
  stop_event.set()