dar-tau committed on
Commit
14c86d4
·
verified ·
1 Parent(s): aeef19d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -47,7 +47,7 @@ Assistant: girlfriend;mother;father;friend
47
  # setup
48
  torch.set_grad_enabled(False)
49
  model_name = "TheBloke/OpenHermes-2.5-Mistral-7B-GPTQ"
50
- pipe = pipeline("text-generation", model=model_name, device='cpu')
51
  generate_kwargs = {'max_new_tokens': 20}
52
 
53
  def past_kv_to_device(past_kv, device, dtype):
@@ -57,7 +57,7 @@ def detach_past_kv(past_kv):
57
  return tuple((k.cpu().detach().numpy().tolist(), v.cpu().detach().numpy().tolist()) for k, v in past_kv)
58
 
59
 
60
- # @spaces.GPU
61
  def set_past_key_values():
62
  model, tokenizer = pipe.model, pipe.tokenizer
63
  tokenized = tokenizer.encode(
@@ -90,7 +90,6 @@ def generate(text, past_key_values):
90
  if __name__ == "__main__":
91
  with torch.no_grad():
92
  past_key_values = set_past_key_values()
93
- # pipe.model = pipe.model.cpu()
94
  demo = gr.Interface(
95
  partial(generate, past_key_values=past_key_values),
96
  inputs="textbox", outputs="textbox"
 
47
  # setup
48
  torch.set_grad_enabled(False)
49
  model_name = "TheBloke/OpenHermes-2.5-Mistral-7B-GPTQ"
50
+ pipe = pipeline("text-generation", model=model_name, device='cuda')
51
  generate_kwargs = {'max_new_tokens': 20}
52
 
53
  def past_kv_to_device(past_kv, device, dtype):
 
57
  return tuple((k.cpu().detach().numpy().tolist(), v.cpu().detach().numpy().tolist()) for k, v in past_kv)
58
 
59
 
60
+ @spaces.GPU
61
  def set_past_key_values():
62
  model, tokenizer = pipe.model, pipe.tokenizer
63
  tokenized = tokenizer.encode(
 
90
  if __name__ == "__main__":
91
  with torch.no_grad():
92
  past_key_values = set_past_key_values()
 
93
  demo = gr.Interface(
94
  partial(generate, past_key_values=past_key_values),
95
  inputs="textbox", outputs="textbox"