Spanicin commited on
Commit
15bb65a
·
verified ·
1 Parent(s): 4d380a5

Update src/facerender/animate.py

Browse files
Files changed (1) hide show
  1. src/facerender/animate.py +18 -18
src/facerender/animate.py CHANGED
@@ -112,25 +112,25 @@ class AnimateFromCoeff():
112
  he_estimator.load_state_dict(he_state_dict)
113
  if discriminator is not None:
114
  try:
115
- discriminator =adjust_state_dict(checkpoint['discriminator'],discriminator)
116
- discriminator.load_state_dict(discriminator)
117
  except:
118
  print ('No discriminator in the state-dict. Dicriminator will be randomly initialized')
119
  if optimizer_generator is not None:
120
- optimizer_generator =adjust_state_dict(checkpoint['optimizer_generator'],optimizer_generator)
121
- optimizer_generator.load_state_dict(optimizer_generator)
122
  if optimizer_discriminator is not None:
123
  try:
124
- optimizer_discriminator = adjust_state_dict(checkpoint['optimizer_discriminator'],optimizer_discriminator)
125
- optimizer_discriminator.load_state_dict(optimizer_discriminator)
126
  except RuntimeError as e:
127
  print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized')
128
  if optimizer_kp_detector is not None:
129
- optimizer_kp_detector = adjust_state_dict(checkpoint['optimizer_kp_detector'],optimizer_kp_detector)
130
- optimizer_kp_detector.load_state_dict(optimizer_kp_detector)
131
  if optimizer_he_estimator is not None:
132
- optimizer_he_estimator = adjust_state_dict(checkpoint['optimizer_he_estimator'],optimizer_he_estimator)
133
- optimizer_he_estimator.load_state_dict(optimizer_he_estimator)
134
 
135
  return checkpoint['epoch']
136
 
@@ -149,17 +149,17 @@ class AnimateFromCoeff():
149
  return new_state_dict
150
 
151
  if mapping is not None:
152
- mapping = adjust_state_dict(checkpoint['mapping'],mapping)
153
- mapping.load_state_dict(mapping)
154
  if discriminator is not None:
155
- discriminator = adjust_state_dict(checkpoint['discriminator'],discriminator)
156
- discriminator.load_state_dict(discriminator)
157
  if optimizer_mapping is not None:
158
- optimizer_mapping = adjust_state_dict(checkpoint['optimizer_mapping'],optimizer_mapping)
159
- optimizer_mapping.load_state_dict(optimizer_mapping)
160
  if optimizer_discriminator is not None:
161
- optimizer_discriminator= adjust_state_dict(checkpoint['optimizer_discriminator'],optimizer_discriminator)
162
- optimizer_discriminator.load_state_dict(optimizer_discriminator)
163
 
164
  return checkpoint['epoch']
165
 
 
112
  he_estimator.load_state_dict(he_state_dict)
113
  if discriminator is not None:
114
  try:
115
+ discriminator_dict =adjust_state_dict(checkpoint['discriminator'],discriminator)
116
+ discriminator.load_state_dict(discriminator_dict)
117
  except:
118
  print ('No discriminator in the state-dict. Dicriminator will be randomly initialized')
119
  if optimizer_generator is not None:
120
+ optimizer_generator_dict =adjust_state_dict(checkpoint['optimizer_generator'],optimizer_generator)
121
+ optimizer_generator.load_state_dict(optimizer_generator_dict)
122
  if optimizer_discriminator is not None:
123
  try:
124
+ optimizer_discriminator_dict = adjust_state_dict(checkpoint['optimizer_discriminator'],optimizer_discriminator)
125
+ optimizer_discriminator.load_state_dict(optimizer_discriminator_dict)
126
  except RuntimeError as e:
127
  print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized')
128
  if optimizer_kp_detector is not None:
129
+ optimizer_kp_detector_dict = adjust_state_dict(checkpoint['optimizer_kp_detector'],optimizer_kp_detector)
130
+ optimizer_kp_detector.load_state_dict(optimizer_kp_detector_dict)
131
  if optimizer_he_estimator is not None:
132
+ optimizer_he_estimator_dict = adjust_state_dict(checkpoint['optimizer_he_estimator'],optimizer_he_estimator)
133
+ optimizer_he_estimator.load_state_dict(optimizer_he_estimator_dict)
134
 
135
  return checkpoint['epoch']
136
 
 
149
  return new_state_dict
150
 
151
  if mapping is not None:
152
+ mapping_dict = adjust_state_dict(checkpoint['mapping'],mapping)
153
+ mapping.load_state_dict(mapping_dict)
154
  if discriminator is not None:
155
+ discriminator_dict = adjust_state_dict(checkpoint['discriminator'],discriminator)
156
+ discriminator.load_state_dict(discriminator_dict)
157
  if optimizer_mapping is not None:
158
+ optimizer_mapping_dict = adjust_state_dict(checkpoint['optimizer_mapping'],optimizer_mapping)
159
+ optimizer_mapping.load_state_dict(optimizer_mapping_dict)
160
  if optimizer_discriminator is not None:
161
+ optimizer_discriminator_dict = adjust_state_dict(checkpoint['optimizer_discriminator'],optimizer_discriminator)
162
+ optimizer_discriminator.load_state_dict(optimizer_discriminator_dict)
163
 
164
  return checkpoint['epoch']
165