Spaces:
Build error
Build error
File size: 1,256 Bytes
546a9ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn.init as init
from torch.nn.parameter import Parameter
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from torch.nn.utils.rnn import pack_padded_sequence as pack
def set_dropout_prob(p):
    """Store ``p`` in the module-level variable ``dropout_p``.

    The value is written as-is; no range check is performed here.
    """
    # Equivalent to declaring ``global dropout_p`` and assigning to it.
    globals()["dropout_p"] = p
def set_seq_dropout(option):
    """Store ``option`` in the module-level flag ``do_seq_dropout``.

    ``option`` is expected to be ``True`` or ``False``; it is stored
    unchanged.
    """
    # Equivalent to declaring ``global do_seq_dropout`` and assigning to it.
    globals()["do_seq_dropout"] = option
def seq_dropout(x, p=0, training=False):
    """Apply one dropout mask per (batch, feature) pair across dimension 1.

    A single Bernoulli mask of shape ``(x.size(0), x.size(2))`` is sampled
    and broadcast over dimension 1, so every position along that dimension
    drops the same features. Kept entries are scaled by ``1 / (1 - p)``.

    Args:
        x: 3-D tensor laid out as batch * len * input_size.
        p: probability of zeroing a feature.
        training: when falsy, ``x`` is returned untouched.

    Returns:
        ``x`` itself when dropout is inactive (``training`` falsy or
        ``p == 0``); otherwise a new tensor with the same shape as ``x``.
    """
    if not training or p == 0:
        return x
    if p == 1:
        # 1 / (1 - p) would divide by zero. Everything is dropped, so return
        # zeros; multiplying (rather than allocating) keeps the result
        # attached to x's autograd graph.
        return x * 0

    keep_prob = 1 - p
    # new_ones inherits dtype/device from x; .data keeps the mask out of
    # the autograd graph, so it does not require grad.
    keep_probs = keep_prob * x.data.new_ones(x.size(0), x.size(2))
    mask = (1.0 / keep_prob) * torch.bernoulli(keep_probs)
    # Insert the length dimension and broadcast the same mask along it.
    return mask.unsqueeze(1).expand_as(x) * x
def dropout(x, p=0, training=False):
    """Apply dropout to ``x``, using sequence dropout for 3-D inputs if enabled.

    Args:
        x: tensor of shape (batch * len * input_size) or any other shape.
        p: dropout probability, forwarded unchanged.
        training: whether dropout is active, forwarded unchanged.

    Returns:
        The result of ``seq_dropout`` when the module-level flag
        ``do_seq_dropout`` is truthy and ``x`` has exactly three dimensions;
        otherwise the result of ``F.dropout``.
    """
    # The flag only exists after set_seq_dropout() has been called; treat a
    # missing flag as "disabled" instead of failing with NameError.
    use_seq_dropout = globals().get("do_seq_dropout", False)
    if use_seq_dropout and x.dim() == 3:  # x is (batch * len * input_size)
        return seq_dropout(x, p=p, training=training)
    return F.dropout(x, p=p, training=training)
|