s-a-malik committed
Commit 3dc5f5e · 1 Parent(s): f838d5b
Files changed (1)
  1. app.py +7 -52
app.py CHANGED
@@ -10,6 +10,12 @@ import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+# TODO Sentence level highlighting instead (prediction after every word is not what it was trained on). Also solves token-level highlighting issues.
+# TODO log prob output scaling highlighting instead?
+# TODO make it look nicer
+# TODO streaming output (need custom generation function because of probes)
+# TODO add options to switch between models, SLT/TBG, layers?
+# TODO full semantic entropy calculation
 
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
@@ -40,7 +46,6 @@ if torch.cuda.is_available():
     tokenizer.use_default_system_prompt = False
 
     # load the probe data
-    # TODO compare accuracy and SE probe in different tabs/sections
     with open("./model/20240625-131035_demo.pkl", "rb") as f:
         probe_data = pkl.load(f)
     # take the NQ open one
@@ -52,7 +57,6 @@ if torch.cuda.is_available():
 else:
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
-
 @spaces.GPU
 def generate(
     message: str,
@@ -62,7 +66,7 @@ def generate(
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
-) -> Iterator[str]:
+) -> Tuple[str, str]:
     conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
@@ -74,55 +78,6 @@ def generate(
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
 
-
-    # streamer = CustomStreamer(skip_prompt=True, timeout=10.0)
-
-    # def generate_with_states():
-    #     with torch.no_grad():
-    #         model.generate(
-    #             input_ids=input_ids,
-    #             max_new_tokens=max_new_tokens,
-    #             do_sample=True,
-    #             top_p=top_p,
-    #             top_k=top_k,
-    #             temperature=temperature,
-    #             repetition_penalty=repetition_penalty,
-    #             output_hidden_states=True,
-    #             return_dict_in_generate=True,
-    #             streamer=streamer
-    #         )
-
-    # thread = Thread(target=generate_with_states)
-    # thread.start()
-
-    # se_highlighted_text = ""
-    # acc_highlighted_text = ""
-    # for token_id in streamer:
-    #     print
-    #     hidden_states = streamer.hidden_states_queue.get()
-    #     if hidden_states is streamer.stop_signal:
-    #         break
-
-    #     # Semantic Uncertainty Probe
-    #     token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden_states]).numpy()  # (num_layers, hidden_size)
-    #     se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)
-    #     se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
-
-    #     # Accuracy Probe
-    #     acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
-    #     acc_probe_pred = (1 - acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][1]) * 2 - 1
-
-    #     # decode latest token
-    #     new_text = tokenizer.decode(token_id)
-    #     print(new_text, se_probe_pred, acc_probe_pred)
-
-    #     se_new_highlighted_text = highlight_text(new_text, se_probe_pred)
-    #     acc_new_highlighted_text = highlight_text(new_text, acc_probe_pred)
-    #     se_highlighted_text += f" {se_new_highlighted_text}"
-    #     acc_highlighted_text += f" {acc_new_highlighted_text}"
-
-    #     yield se_highlighted_text, acc_highlighted_text
-
     #### Generate without threading
     generation_kwargs = dict(
         input_ids=input_ids,
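
For context, a minimal sketch of the non-threaded path this commit settles on: generation runs to completion with hidden states captured, both probes score every generated token, and the handler returns the two highlighted strings at once, which is what the new Tuple[str, str] signature implies. The helper name generate_with_probes is hypothetical; se_probe, acc_probe, se_layer_range, acc_layer_range and highlight_text come from app.py and are assumed to behave as in the removed streaming code.

from typing import Tuple

import torch


def generate_with_probes(model, tokenizer, input_ids, se_probe, acc_probe,
                         se_layer_range, acc_layer_range, highlight_text,
                         **generation_kwargs) -> Tuple[str, str]:
    # Run generation to completion, keeping hidden states for every step.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            output_hidden_states=True,
            return_dict_in_generate=True,
            **generation_kwargs,
        )

    se_highlighted_text, acc_highlighted_text = "", ""
    generated_ids = outputs.sequences[0, input_ids.shape[1]:]

    # outputs.hidden_states holds one tuple of per-layer tensors per generated step.
    for token_id, step_hidden in zip(generated_ids.tolist(), outputs.hidden_states):
        # Last-position hidden state of every layer -> (num_layers, hidden_size).
        token_embeddings = torch.stack(
            [layer[0, -1, :].cpu() for layer in step_hidden]
        ).numpy()

        # Semantic-uncertainty probe, rescaled from [0, 1] to [-1, 1].
        se_feats = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(1, -1)
        se_pred = se_probe.predict_proba(se_feats)[0][1] * 2 - 1

        # Accuracy probe, flipped and rescaled as in the removed streaming code.
        acc_feats = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(1, -1)
        acc_pred = (1 - acc_probe.predict_proba(acc_feats)[0][1]) * 2 - 1

        new_text = tokenizer.decode(token_id)
        se_highlighted_text += f" {highlight_text(new_text, se_pred)}"
        acc_highlighted_text += f" {highlight_text(new_text, acc_pred)}"

    return se_highlighted_text, acc_highlighted_text

The Gradio handler would then pass the same sampling arguments shown in generation_kwargs above (max_new_tokens, do_sample, top_p, top_k, temperature, repetition_penalty) and return the resulting pair of strings.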
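
The first TODO (sentence-level highlighting) could plausibly look like the sketch below: buffer decoded tokens, run the probe once per finished sentence on that sentence's final-token hidden states, and colour the whole sentence with that single score. This is only one possible reading of the TODO; highlight_by_sentence and its arguments are illustrative, not part of the commit.

import re


def highlight_by_sentence(token_texts, token_embeddings, probe, layer_range, highlight_text):
    # token_texts: decoded text of each generated token, in order.
    # token_embeddings: array of shape (num_tokens, num_layers, hidden_size).
    highlighted, sentence = "", ""
    for i, text in enumerate(token_texts):
        sentence += text
        at_boundary = re.search(r"[.!?]\s*$", sentence) or i == len(token_texts) - 1
        if at_boundary and sentence.strip():
            # One probe call per sentence, on the hidden states of its final token.
            feats = token_embeddings[i, layer_range[0]:layer_range[1]].reshape(1, -1)
            score = probe.predict_proba(feats)[0][1] * 2 - 1
            highlighted += f" {highlight_text(sentence.strip(), score)}"
            sentence = ""
    return highlighted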