dar-tau commited on
Commit
7eb4c2f
·
verified ·
1 Parent(s): 2b202a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -57,12 +57,14 @@ start_messages = [
57
 
58
  # past_key_values = PastKV()
59
 
 
60
  def past_kv_to_device(past_kv, device):
61
- return tuple((k.to(device).detach(), v.to(device).detach()) for k, v in past_kv)
62
 
63
  def detach_past_kv(past_kv):
64
  return tuple((k.cpu().detach().numpy().tolist(), v.cpu().detach().numpy().tolist()) for k, v in past_kv)
65
 
 
66
  @spaces.GPU
67
  def set_past_key_values():
68
  model, tokenizer = pipe.model, pipe.tokenizer
@@ -91,6 +93,6 @@ if __name__ == "__main__":
91
  with torch.no_grad():
92
  past_key_values = set_past_key_values()
93
  print(f'{past_key_values=}')
94
- demo = gr.Interface(generate,
95
- inputs=["textbox", gr.State(past_key_values)], outputs="textbox")
96
  demo.launch()
 
57
 
58
  # past_key_values = PastKV()
59
 
60
+
61
  def past_kv_to_device(past_kv, device):
62
+ return tuple((torch.tensor(k).to(device).detach(), torch.tensor(v).to(device).detach()) for k, v in past_kv)
63
 
64
  def detach_past_kv(past_kv):
65
  return tuple((k.cpu().detach().numpy().tolist(), v.cpu().detach().numpy().tolist()) for k, v in past_kv)
66
 
67
+
68
  @spaces.GPU
69
  def set_past_key_values():
70
  model, tokenizer = pipe.model, pipe.tokenizer
 
93
  with torch.no_grad():
94
  past_key_values = set_past_key_values()
95
  print(f'{past_key_values=}')
96
+ demo = gr.Interface(partial(generate, past_key_values=past_kv_to_device(past_key_values, pipe.model.device)),
97
+ inputs="textbox", outputs="textbox")
98
  demo.launch()