danseith commited on
Commit
802cf91
1 Parent(s): d5cc744

Added multinomial prediction sampling.

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -20,7 +20,7 @@ def add_mask(text, size=1):
20
 
21
 
22
  class TempScalePipe(FillMaskPipeline):
23
- def postprocess(self, model_outputs, top_k=5, target_ids=None):
24
  # Cap top_k if there are targets
25
  if target_ids is not None and target_ids.shape[0] < top_k:
26
  top_k = target_ids.shape[0]
@@ -30,8 +30,10 @@ class TempScalePipe(FillMaskPipeline):
30
  masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1)
31
  # Fill mask pipeline supports only one ${mask_token} per sample
32
 
33
- logits = outputs[0, masked_index, :] / 1e3
34
  probs = logits.softmax(dim=-1)
 
 
35
  if target_ids is not None:
36
  probs = probs[..., target_ids]
37
 
@@ -98,7 +100,8 @@ from transformers import pipeline, Pipeline
98
  # textbox = gr.Textbox(label="Type language here", lines=5)
99
  #
100
  demo = gr.Interface(
101
- fn=unmask,
 
102
  inputs=textbox,
103
  outputs="label",
104
  examples=[example],
 
20
 
21
 
22
  class TempScalePipe(FillMaskPipeline):
23
+ def postprocess(self, model_outputs, top_k=3, target_ids=None):
24
  # Cap top_k if there are targets
25
  if target_ids is not None and target_ids.shape[0] < top_k:
26
  top_k = target_ids.shape[0]
 
30
  masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1)
31
  # Fill mask pipeline supports only one ${mask_token} per sample
32
 
33
+ logits = outputs[0, masked_index, :] / 1e1
34
  probs = logits.softmax(dim=-1)
35
+ indices = torch.multinomial(probs, num_samples=3)
36
+ probs = probs[indices]
37
  if target_ids is not None:
38
  probs = probs[..., target_ids]
39
 
 
100
  # textbox = gr.Textbox(label="Type language here", lines=5)
101
  #
102
  demo = gr.Interface(
103
+ unmask,
104
+ [gr.Slider(minimum=0, maximum=15, value=8, step=1, label="Guidance scale")],
105
  inputs=textbox,
106
  outputs="label",
107
  examples=[example],