BartPoint commited on
Commit
9e23608
·
1 Parent(s): b1b0ba3

Update app_multi.py

Browse files
Files changed (1) hide show
  1. app_multi.py +3 -1
app_multi.py CHANGED
@@ -108,7 +108,9 @@ for model_name in multi_cfg.get('models'):
108
  new_shape = net_g.enc_p.emb_phone.weight.shape
109
  if old_shape != new_shape:
110
  print(f"Resizing enc_p.emb_phone.weight: {old_shape} -> {new_shape}")
111
- cpt['weight']['enc_p.emb_phone.weight'] = cpt['weight']['enc_p.emb_phone.weight'][:new_shape[0], :new_shape[1]]
 
 
112
 
113
  del net_g.enc_q
114
 
 
108
  new_shape = net_g.enc_p.emb_phone.weight.shape
109
  if old_shape != new_shape:
110
  print(f"Resizing enc_p.emb_phone.weight: {old_shape} -> {new_shape}")
111
+ weight = cpt['weight']['enc_p.emb_phone.weight']
112
+ resized_weight = weight[:, :new_shape[1]].resize_(new_shape)
113
+ cpt['weight']['enc_p.emb_phone.weight'] = resized_weight
114
 
115
  del net_g.enc_q
116