atharvprajod commited on
Commit
55d0647
·
verified ·
1 Parent(s): 3ce1c74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -2
app.py CHANGED
@@ -24,6 +24,21 @@ if tokenizer.pad_token is None:
24
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
25
  gpt2.to(device)
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def simple_position_config(model_type, component, layer):
28
  config = IntervenableConfig(
29
  model_type=model_type,
@@ -118,13 +133,13 @@ async def process_inputs(base_sentence: str, rival_sentence: str):
118
  base_sentence = pn.widgets.TextInput(
119
  name="Base Sentence",
120
  placeholder="Enter the base sentence",
121
- value="The cat is on the mat.",
122
  )
123
 
124
  rival_sentence = pn.widgets.TextInput(
125
  name="Rival Sentence",
126
  placeholder="Enter the rival sentence",
127
- value="The dog is on the mat.",
128
  )
129
 
130
  input_widgets = pn.Column(
 
24
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
25
  gpt2.to(device)
26
 
27
+ # Monkey patch the embed_to_distrib function to use the correct attribute
28
+ def patched_embed_to_distrib(model, embed, log=True, logits=True):
29
+ if "gpt2" in model.config.architectures[0].lower():
30
+ with torch.inference_mode():
31
+ vocab = torch.matmul(embed, model.transformer.wte.weight.t())
32
+ if logits:
33
+ return vocab
34
+ if log:
35
+ return torch.log_softmax(vocab, dim=-1)
36
+ return torch.softmax(vocab, dim=-1)
37
+ else:
38
+ return pv.embed_to_distrib(model, embed, log, logits)
39
+
40
+ pv.embed_to_distrib = patched_embed_to_distrib
41
+
42
  def simple_position_config(model_type, component, layer):
43
  config = IntervenableConfig(
44
  model_type=model_type,
 
133
  base_sentence = pn.widgets.TextInput(
134
  name="Base Sentence",
135
  placeholder="Enter the base sentence",
136
+ value="Jane got some weird looks because she wore sunglasses outside at 4 PM.",
137
  )
138
 
139
  rival_sentence = pn.widgets.TextInput(
140
  name="Rival Sentence",
141
  placeholder="Enter the rival sentence",
142
+ value="Jane got some weird looks because she wore sunglasses outside at 4 AM.",
143
  )
144
 
145
  input_widgets = pn.Column(