Commit 8d5fa1d by slush0 (parent: e0349b7)

Improved switching between models.

Files changed (1):
  chat.py (+9, -0)
chat.py CHANGED
@@ -5,6 +5,7 @@ import traceback
 import gradio as gr
 import chat_client
 import json
+import re
 
 CHAT_URL='ws://chat.petals.ml/api/v2/generate'
 #CHAT_URL='ws://localhost:8000/api/v2/generate'
@@ -38,6 +39,12 @@ def _generate(state, prompt, model, context, output, endseq, max_length,
     print('prompt', prompt)
     eos = "</s>\n" if "bloomz" in model else "\n\n"
 
+    if state['model'] != model and output:
+        # If the connection is resumed, output is truncated in generate().
+        # So this happens when the user changes the model.
+        context = output
+        output = ''
+
     if state['model'] != model or \
         state['client'] == None or state['client'].is_session() == False:
 
@@ -48,6 +55,7 @@ def _generate(state, prompt, model, context, output, endseq, max_length,
     except Exception:
         print(traceback.format_exc())
         raise gr.Error(traceback.format_exc())
+
     else:
         context = ''
 
@@ -63,6 +71,7 @@ def _generate(state, prompt, model, context, output, endseq, max_length,
         prompt2 = prompt.replace("\n\n", eos) + "</s>\n"
     else:
         context = context.replace("</s>", eos)
+        context = re.sub(r"\n\n+", "\n\n", context)
         prompt2 = prompt.replace("</s>", eos) + "\n\n"
 
     prompt2 = f"{context}Human: {prompt2}AI:"