dar-tau commited on
Commit
4362d26
·
verified ·
1 Parent(s): 1fc5a3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -2,17 +2,17 @@ import os
2
  import gradio as gr
3
  import spaces
4
  import torch
 
5
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
6
 
7
-
8
  model_name = "TheBloke/OpenHermes-2.5-Mistral-7B-GPTQ"
9
  token = os.environ['hf_token']
10
-
11
  pipe = pipeline("text-generation", model=model_name, device="cuda")
12
-
13
-
14
  generate_kwargs = {'max_new_tokens': 20}
15
 
 
16
  system_prompt = '''You are given a partial input text for a chat interface. Propose auto-completion to the text. You have several roles:
17
  - Fight under-specification.
18
  - Complete text to save the user time.
@@ -37,13 +37,18 @@ Assistant: "girlfriend;mother;father;friend"
37
  You will now get a blank message from the user and then after your answer, the user will give you the text to complete.
38
  '''
39
 
 
40
  start_messages = [
41
  {'role': 'system', 'content': system_prompt},
42
  {'role': 'user', 'content': ' '},
43
  {'role': 'assistant', 'content': '<Waiting for text>'}
44
  ]
45
 
46
- torch.set_grad_enabled(False)
 
 
 
 
47
 
48
 
49
  def past_kv_to_device(past_kv, device):
@@ -60,7 +65,7 @@ def get_past_key_values(system_prompt):
60
  tokenized_test = tokenizer.apply_chat_template(test_messages, return_tensors='pt')
61
  assert (tokenized_test[:, :tokenized.shape[1]] == tokenized).all().cpu().item()
62
  past_key_values = model(tokenized.to(model.device)).past_key_values
63
- return past_kv_to_device(past_key_values, 'cpu')
64
 
65
  @spaces.GPU
66
  def generate(text, past_key_values):
@@ -76,6 +81,6 @@ def generate(text, past_key_values):
76
 
77
  if __name__ == "__main__":
78
  past_key_values = get_past_key_values(system_prompt)
79
- demo = gr.Interface(partial(generate, past_key_values=past_key_values),
80
  inputs="textbox", outputs="textbox")
81
  demo.launch()
 
2
  import gradio as gr
3
  import spaces
4
  import torch
5
+ from typing import Optional
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
+ from dataclasses import dataclass
8
 
9
+ torch.set_grad_enabled(False)
10
  model_name = "TheBloke/OpenHermes-2.5-Mistral-7B-GPTQ"
11
  token = os.environ['hf_token']
 
12
  pipe = pipeline("text-generation", model=model_name, device="cuda")
 
 
13
  generate_kwargs = {'max_new_tokens': 20}
14
 
15
+
16
  system_prompt = '''You are given a partial input text for a chat interface. Propose auto-completion to the text. You have several roles:
17
  - Fight under-specification.
18
  - Complete text to save the user time.
 
37
  You will now get a blank message from the user and then after your answer, the user will give you the text to complete.
38
  '''
39
 
40
+
41
  start_messages = [
42
  {'role': 'system', 'content': system_prompt},
43
  {'role': 'user', 'content': ' '},
44
  {'role': 'assistant', 'content': '<Waiting for text>'}
45
  ]
46
 
47
+
48
+ # functions
49
+ @dataclass
50
+ class PastKV:
51
+ past_key_values: Optional[torch.Tensor] = None
52
 
53
 
54
  def past_kv_to_device(past_kv, device):
 
65
  tokenized_test = tokenizer.apply_chat_template(test_messages, return_tensors='pt')
66
  assert (tokenized_test[:, :tokenized.shape[1]] == tokenized).all().cpu().item()
67
  past_key_values = model(tokenized.to(model.device)).past_key_values
68
+ return PastKV(past_kv_to_device(past_key_values, 'cpu'))
69
 
70
  @spaces.GPU
71
  def generate(text, past_key_values):
 
81
 
82
  if __name__ == "__main__":
83
  past_key_values = get_past_key_values(system_prompt)
84
+ demo = gr.Interface(partial(generate, past_key_values=past_key_values.past_key_values),
85
  inputs="textbox", outputs="textbox")
86
  demo.launch()