Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -24,6 +24,21 @@ if tokenizer.pad_token is None:
|
|
24 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
25 |
gpt2.to(device)
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
def simple_position_config(model_type, component, layer):
|
28 |
config = IntervenableConfig(
|
29 |
model_type=model_type,
|
@@ -118,13 +133,13 @@ async def process_inputs(base_sentence: str, rival_sentence: str):
|
|
118 |
base_sentence = pn.widgets.TextInput(
|
119 |
name="Base Sentence",
|
120 |
placeholder="Enter the base sentence",
|
121 |
-
value="
|
122 |
)
|
123 |
|
124 |
rival_sentence = pn.widgets.TextInput(
|
125 |
name="Rival Sentence",
|
126 |
placeholder="Enter the rival sentence",
|
127 |
-
value="
|
128 |
)
|
129 |
|
130 |
input_widgets = pn.Column(
|
|
|
24 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
25 |
gpt2.to(device)
|
26 |
|
27 |
+
# Monkey patch the embed_to_distrib function to use the correct attribute
|
28 |
+
def patched_embed_to_distrib(model, embed, log=True, logits=True):
|
29 |
+
if "gpt2" in model.config.architectures[0].lower():
|
30 |
+
with torch.inference_mode():
|
31 |
+
vocab = torch.matmul(embed, model.transformer.wte.weight.t())
|
32 |
+
if logits:
|
33 |
+
return vocab
|
34 |
+
if log:
|
35 |
+
return torch.log_softmax(vocab, dim=-1)
|
36 |
+
return torch.softmax(vocab, dim=-1)
|
37 |
+
else:
|
38 |
+
return pv.embed_to_distrib(model, embed, log, logits)
|
39 |
+
|
40 |
+
pv.embed_to_distrib = patched_embed_to_distrib
|
41 |
+
|
42 |
def simple_position_config(model_type, component, layer):
|
43 |
config = IntervenableConfig(
|
44 |
model_type=model_type,
|
|
|
133 |
base_sentence = pn.widgets.TextInput(
|
134 |
name="Base Sentence",
|
135 |
placeholder="Enter the base sentence",
|
136 |
+
value="Jane got some weird looks because she wore sunglasses outside at 4 PM.",
|
137 |
)
|
138 |
|
139 |
rival_sentence = pn.widgets.TextInput(
|
140 |
name="Rival Sentence",
|
141 |
placeholder="Enter the rival sentence",
|
142 |
+
value="Jane got some weird looks because she wore sunglasses outside at 4 AM.",
|
143 |
)
|
144 |
|
145 |
input_widgets = pn.Column(
|