frankaging commited on
Commit
1644e6b
·
1 Parent(s): bddba98
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -32,6 +32,8 @@ class Steer(pv.SourcelessIntervention):
32
  self.embed_dim, kwargs["latent_dim"], bias=False
33
  )
34
  def forward(self, base, source=None, subspaces=None):
 
 
35
  steering_vec = torch.tensor(subspaces["mag"]) * \
36
  self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
37
  return base + steering_vec
@@ -97,6 +99,7 @@ def generate(
97
  yield "[Truncated prior text]\n"
98
 
99
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
100
  generate_kwargs = {
101
  "base": {"input_ids": input_ids},
102
  "unit_locations": None,
@@ -107,7 +110,7 @@ def generate(
107
  "idx": int(subspaces_list[0]["idx"]),
108
  "mag": int(subspaces_list[0]["internal_mag"])
109
  }
110
- ] if subspaces_list else [],
111
  "streamer": streamer,
112
  "do_sample": True
113
  }
 
32
  self.embed_dim, kwargs["latent_dim"], bias=False
33
  )
34
  def forward(self, base, source=None, subspaces=None):
35
+ if subspaces is None:
36
+ return base
37
  steering_vec = torch.tensor(subspaces["mag"]) * \
38
  self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
39
  return base + steering_vec
 
99
  yield "[Truncated prior text]\n"
100
 
101
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
102
+ print(subspaces_list)
103
  generate_kwargs = {
104
  "base": {"input_ids": input_ids},
105
  "unit_locations": None,
 
110
  "idx": int(subspaces_list[0]["idx"]),
111
  "mag": int(subspaces_list[0]["internal_mag"])
112
  }
113
+ ] if subspaces_list else None,
114
  "streamer": streamer,
115
  "do_sample": True
116
  }