修复使用gpu时报错的bug
Browse files- csc_model.py +3 -0
csc_model.py
CHANGED
@@ -302,6 +302,9 @@ class ReaLiseForCSC(BertPreTrainedModel):
|
|
302 |
bert_hiddens = self.bert(input_ids, attention_mask=attention_mask)[0]
|
303 |
|
304 |
pho_embeddings = self.pho_embeddings(pho_idx)
|
|
|
|
|
|
|
305 |
pho_embeddings = torch.nn.utils.rnn.pack_padded_sequence(
|
306 |
input=pho_embeddings,
|
307 |
lengths=pho_lens,
|
|
|
302 |
bert_hiddens = self.bert(input_ids, attention_mask=attention_mask)[0]
|
303 |
|
304 |
pho_embeddings = self.pho_embeddings(pho_idx)
|
305 |
+
|
306 |
+
if torch.is_tensor(pho_lens):
|
307 |
+
pho_lens = pho_lens.tolist()
|
308 |
pho_embeddings = torch.nn.utils.rnn.pack_padded_sequence(
|
309 |
input=pho_embeddings,
|
310 |
lengths=pho_lens,
|