SauravMaheshkar committed
Commit 87d2fe8 · unverified · Parent: 5e9da0f

feat: decode model output

Files changed (1): app.py (+4, -0)
app.py CHANGED
@@ -17,6 +17,9 @@ pipe = pipeline(
     torch_dtype=torch.bfloat16,
     device_map="auto",
 )
+pipe.model.generate = torch.compile(
+    pipe.model.generate, mode="reduce-overhead", fullgraph=True
+)
 
 
 class ChatState:
@@ -68,6 +71,7 @@ def invoke(history: HistoryType):
     response = pipe(input_text, do_sample=True, top_p=0.95, max_new_tokens=1024)[0][
         "generated_text"
     ]
+    response = response.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0]
     return response
 
 
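
For context on the first hunk: the commit also compiles the pipeline's generate method once at startup. Below is a minimal sketch of the same pattern, assuming a standard transformers text-generation pipeline; the model id is a placeholder, since the diff does not show which checkpoint the Space actually loads.

import torch
from transformers import pipeline

# Placeholder checkpoint id; the real one is not visible in this diff.
pipe = pipeline(
    "text-generation",
    model="some-org/some-chat-model",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Compile generate once up front. mode="reduce-overhead" targets small-batch
# latency by using CUDA graphs to cut per-call launch overhead, and
# fullgraph=True makes compilation fail loudly on graph breaks rather than
# silently falling back to eager execution.
pipe.model.generate = torch.compile(
    pipe.model.generate, mode="reduce-overhead", fullgraph=True
)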
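
The second hunk is the change the commit message names: the generated text still carries ChatML turn markers, and the new line slices out just the assistant's reply. A standalone sketch of that slicing, using an illustrative string rather than real model output:

# Illustrative ChatML-style generation (not actual output from the Space).
raw = (
    "<|im_start|>user\nWhat is 2 + 2?<|im_end|>\n"
    "<|im_start|>assistant\n2 + 2 = 4.<|im_end|>\n"
)

# Keep everything after the last assistant marker, then cut at the
# end-of-turn marker that closes the reply.
reply = raw.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0]
print(reply)  # prints: 2 + 2 = 4.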