Spaces:
Build error
Build error
danseith
commited on
Commit
•
802cf91
1
Parent(s):
d5cc744
Added multinomial prediction sampling.
Browse files
app.py
CHANGED
@@ -20,7 +20,7 @@ def add_mask(text, size=1):
|
|
20 |
|
21 |
|
22 |
class TempScalePipe(FillMaskPipeline):
|
23 |
-
def postprocess(self, model_outputs, top_k=
|
24 |
# Cap top_k if there are targets
|
25 |
if target_ids is not None and target_ids.shape[0] < top_k:
|
26 |
top_k = target_ids.shape[0]
|
@@ -30,8 +30,10 @@ class TempScalePipe(FillMaskPipeline):
|
|
30 |
masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1)
|
31 |
# Fill mask pipeline supports only one ${mask_token} per sample
|
32 |
|
33 |
-
logits = outputs[0, masked_index, :] /
|
34 |
probs = logits.softmax(dim=-1)
|
|
|
|
|
35 |
if target_ids is not None:
|
36 |
probs = probs[..., target_ids]
|
37 |
|
@@ -98,7 +100,8 @@ from transformers import pipeline, Pipeline
|
|
98 |
# textbox = gr.Textbox(label="Type language here", lines=5)
|
99 |
#
|
100 |
demo = gr.Interface(
|
101 |
-
|
|
|
102 |
inputs=textbox,
|
103 |
outputs="label",
|
104 |
examples=[example],
|
|
|
20 |
|
21 |
|
22 |
class TempScalePipe(FillMaskPipeline):
|
23 |
+
def postprocess(self, model_outputs, top_k=3, target_ids=None):
|
24 |
# Cap top_k if there are targets
|
25 |
if target_ids is not None and target_ids.shape[0] < top_k:
|
26 |
top_k = target_ids.shape[0]
|
|
|
30 |
masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1)
|
31 |
# Fill mask pipeline supports only one ${mask_token} per sample
|
32 |
|
33 |
+
logits = outputs[0, masked_index, :] / 1e1
|
34 |
probs = logits.softmax(dim=-1)
|
35 |
+
indices = torch.multinomial(probs, num_samples=3)
|
36 |
+
probs = probs[indices]
|
37 |
if target_ids is not None:
|
38 |
probs = probs[..., target_ids]
|
39 |
|
|
|
100 |
# textbox = gr.Textbox(label="Type language here", lines=5)
|
101 |
#
|
102 |
demo = gr.Interface(
|
103 |
+
unmask,
|
104 |
+
[gr.Slider(minimum=0, maximum=15, value=8, step=1, label="Guidance scale")],
|
105 |
inputs=textbox,
|
106 |
outputs="label",
|
107 |
examples=[example],
|