adamcasson committed on
Commit
8731543
·
1 Parent(s): 6da8e24

sample from normal dist for R and X denoisers

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -32,7 +32,7 @@ def mask_spans(
32
  # n = 1
33
  mu = max(1, int(len(tokens) * r))
34
  start = max(
35
- 0, len(tokens) - random.randint(1, int(2 * mu))
36
  ) # max to handle start < 0
37
  encoder_inputs += tokens[:start] + [sentinel_id]
38
  encoder_mask += ([1] * len(tokens[:start])) + [0]
@@ -47,8 +47,10 @@ def mask_spans(
47
  start = 0
48
  end = 0
49
  while start < len(tokens):
50
- # uniform random span length
51
- length = random.randint(1, int(2 * mu))
 
 
52
  end = min(start + length, len(tokens))
53
 
54
  # randomly decide if span should be masked
 
32
  # n = 1
33
  mu = max(1, int(len(tokens) * r))
34
  start = max(
35
+ 0, len(tokens) - random.randint(1, int(2 * mu)) # sample from uniform distribution for S denoisers
36
  ) # max to handle start < 0
37
  encoder_inputs += tokens[:start] + [sentinel_id]
38
  encoder_mask += ([1] * len(tokens[:start])) + [0]
 
47
  start = 0
48
  end = 0
49
  while start < len(tokens):
50
+ # for R and X denoisers, sample random span length from normal distribution bounded from 1 to 2 * mu.
51
+ # std of 0.25 * mu is arbitrary, not specified in paper but makes a sane looking distribution
52
+ # at extreme ends of span length means (from 3 to 64).
53
+ length = max(1, min(int(2 * mu), int(np.round(np.random.normal(mu, 0.25 * mu)))))
54
  end = min(start + length, len(tokens))
55
 
56
  # randomly decide if span should be masked