mimbres commited on
Commit
24273ee
1 Parent(s): 8f5b8ef

Update model_helper.py

Browse files
Files changed (1) hide show
  1. model_helper.py +3 -3
model_helper.py CHANGED
@@ -22,7 +22,7 @@ from model.ymt3 import YourMT3
22
 
23
 
24
 
25
- def load_model_checkpoint(args=None):
26
  parser = argparse.ArgumentParser(description="YourMT3")
27
  # General
28
  parser.add_argument('exp_id', type=str, help='A unique identifier for the experiment is used to resume training. The "@" symbol can be used to load a specific checkpoint.')
@@ -104,7 +104,7 @@ def load_model_checkpoint(args=None):
104
  print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}")
105
 
106
  # Use GPU if available
107
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
108
 
109
  # Model
110
  model = YourMT3(
@@ -120,7 +120,7 @@ def load_model_checkpoint(args=None):
120
  state_dict = checkpoint['state_dict']
121
  new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
122
  model.load_state_dict(new_state_dict, strict=False)
123
- return model.eval()
124
 
125
 
126
  def transcribe(model, audio_info):
 
22
 
23
 
24
 
25
+ def load_model_checkpoint(args=None, device='cpu'):
26
  parser = argparse.ArgumentParser(description="YourMT3")
27
  # General
28
  parser.add_argument('exp_id', type=str, help='A unique identifier for the experiment is used to resume training. The "@" symbol can be used to load a specific checkpoint.')
 
104
  print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}")
105
 
106
  # Use GPU if available
107
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
108
 
109
  # Model
110
  model = YourMT3(
 
120
  state_dict = checkpoint['state_dict']
121
  new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
122
  model.load_state_dict(new_state_dict, strict=False)
123
+ return model.eval() # load checkpoint on cpu first
124
 
125
 
126
  def transcribe(model, audio_info):