GlowCheese committed on
Commit
27888bb
·
1 Parent(s): 8eff58f
Files changed (2) hide show
  1. bert.py +1 -1
  2. constants.py +2 -2
bert.py CHANGED
@@ -59,7 +59,7 @@ class BertSelfAttention(nn.Module):
59
 
60
  def forward(self, hidden_states, attention_mask):
61
  """
62
- hidden_states: [bs, seq_len, hidden_state]
63
  attention_mask: [bs, 1, 1, seq_len]
64
  output: [bs, seq_len, hidden_state]
65
  """
 
59
 
60
  def forward(self, hidden_states, attention_mask):
61
  """
62
+ hidden_states: [bs, seq_len, hidden_size]
63
  attention_mask: [bs, 1, 1, seq_len]
64
  output: [bs, seq_len, hidden_state]
65
  """
constants.py CHANGED
@@ -28,8 +28,8 @@ STSB_DEV = os.path.join(DATA_DIR, 'stsb-dev.parquet')
28
  # Training-specific constants
29
  SEED=11711
30
  NUM_CPU_CORES=4
31
- EPOCHS=1
32
- USE_GPU=False
33
  BATCH_SIZE_CSE=8
34
  BATCH_SIZE_SST=64
35
  BATCH_SIZE_CFIMDB=8
 
28
  # Training-specific constants
29
  SEED=11711
30
  NUM_CPU_CORES=4
31
+ EPOCHS=10
32
+ USE_GPU=True
33
  BATCH_SIZE_CSE=8
34
  BATCH_SIZE_SST=64
35
  BATCH_SIZE_CFIMDB=8