Spanicin committed
Commit 0dcabf1 · verified · 1 parent: 923c799

Update src/facerender/animate.py

Files changed (1): src/facerender/animate.py (+55 −30)
src/facerender/animate.py CHANGED
@@ -119,45 +119,70 @@ class AnimateFromCoeff():
         optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
 
         return checkpoint['epoch']
-
+
+    from torch.cuda.amp import autocast
     def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop'):
 
-        source_image=x['source_image'].type(torch.FloatTensor)
-        source_semantics=x['source_semantics'].type(torch.FloatTensor)
-        target_semantics=x['target_semantics_list'].type(torch.FloatTensor)
-        source_image=source_image.to(self.device)
-        source_semantics=source_semantics.to(self.device)
-        target_semantics=target_semantics.to(self.device)
-        if 'yaw_c_seq' in x:
-            yaw_c_seq = x['yaw_c_seq'].type(torch.FloatTensor)
-            yaw_c_seq = x['yaw_c_seq'].to(self.device)
-        else:
-            yaw_c_seq = None
-        if 'pitch_c_seq' in x:
-            pitch_c_seq = x['pitch_c_seq'].type(torch.FloatTensor)
-            pitch_c_seq = x['pitch_c_seq'].to(self.device)
-        else:
-            pitch_c_seq = None
-        if 'roll_c_seq' in x:
-            roll_c_seq = x['roll_c_seq'].type(torch.FloatTensor)
-            roll_c_seq = x['roll_c_seq'].to(self.device)
-        else:
-            roll_c_seq = None
-
+        # source_image=x['source_image'].type(torch.FloatTensor)
+        # source_semantics=x['source_semantics'].type(torch.FloatTensor)
+        # target_semantics=x['target_semantics_list'].type(torch.FloatTensor)
+        # source_image=source_image.to(self.device)
+        # source_semantics=source_semantics.to(self.device)
+        # target_semantics=target_semantics.to(self.device)
+        # if 'yaw_c_seq' in x:
+        #     yaw_c_seq = x['yaw_c_seq'].type(torch.FloatTensor)
+        #     yaw_c_seq = x['yaw_c_seq'].to(self.device)
+        # else:
+        #     yaw_c_seq = None
+        # if 'pitch_c_seq' in x:
+        #     pitch_c_seq = x['pitch_c_seq'].type(torch.FloatTensor)
+        #     pitch_c_seq = x['pitch_c_seq'].to(self.device)
+        # else:
+        #     pitch_c_seq = None
+        # if 'roll_c_seq' in x:
+        #     roll_c_seq = x['roll_c_seq'].type(torch.FloatTensor)
+        #     roll_c_seq = x['roll_c_seq'].to(self.device)
+        # else:
+        #     roll_c_seq = None
+
+        # frame_num = x['frame_num']
+
+        # predictions_video = make_animation(source_image, source_semantics, target_semantics,
+        #                                    self.generator, self.kp_extractor, self.he_estimator, self.mapping,
+        #                                    yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp = True)
+
+        # predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
+        # predictions_video = predictions_video[:frame_num]
+
+        # video = []
+        # for idx in range(predictions_video.shape[0]):
+        #     image = predictions_video[idx]
+        #     image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
+        #     video.append(image)
+        # result = img_as_ubyte(video)
+
+
+
+
+
+        source_image = x['source_image'].to(self.device).type(torch.FloatTensor)
+        source_semantics = x['source_semantics'].to(self.device).type(torch.FloatTensor)
+        target_semantics = x['target_semantics_list'].to(self.device).type(torch.FloatTensor)
+        yaw_c_seq = x.get('yaw_c_seq', None).to(self.device).type(torch.FloatTensor) if 'yaw_c_seq' in x else None
+        pitch_c_seq = x.get('pitch_c_seq', None).to(self.device).type(torch.FloatTensor) if 'pitch_c_seq' in x else None
+        roll_c_seq = x.get('roll_c_seq', None).to(self.device).type(torch.FloatTensor) if 'roll_c_seq' in x else None
         frame_num = x['frame_num']
 
+        with autocast():
+            predictions_video = make_animation(source_image, source_semantics, target_semantics,
                                            self.generator, self.kp_extractor, self.he_estimator, self.mapping,
                                            yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp = True)
 
-        predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
+        predictions_video = predictions_video.reshape((-1,) + predictions_video.shape[2:])
         predictions_video = predictions_video[:frame_num]
 
-        video = []
-        for idx in range(predictions_video.shape[0]):
-            image = predictions_video[idx]
-            image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
-            video.append(image)
+        # Create video
+        video = [np.transpose(img.data.cpu().numpy(), [1, 2, 0]).astype(np.float32) for img in predictions_video]
         result = img_as_ubyte(video)
 
         ### the generated video is 256x256, so we keep the aspect ratio,
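
Review note on the new tensor preparation: `.type(torch.FloatTensor)` always returns a CPU tensor, so chaining it after `.to(self.device)` silently moves each tensor back off the GPU; the original order (cast first, then move) did not have this problem, and the `x.get(..., None)` fallback is redundant once the `'... in x'` guard is in place. A minimal sketch of the intended preparation, reusing the diff's names (`x`, `self.device`) and assuming float32 inputs; this is a suggested correction, not part of the commit:

    import torch

    # Cast and move in one call; .to() accepts a device and a dtype together.
    source_image = x['source_image'].to(self.device, dtype=torch.float32)
    source_semantics = x['source_semantics'].to(self.device, dtype=torch.float32)
    target_semantics = x['target_semantics_list'].to(self.device, dtype=torch.float32)
    # The 'in x' guard already handles absence, so plain indexing is enough.
    yaw_c_seq = x['yaw_c_seq'].to(self.device, dtype=torch.float32) if 'yaw_c_seq' in x else None
    pitch_c_seq = x['pitch_c_seq'].to(self.device, dtype=torch.float32) if 'pitch_c_seq' in x else None
    roll_c_seq = x['roll_c_seq'].to(self.device, dtype=torch.float32) if 'roll_c_seq' in x else None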
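Review note on the autocast change: as rendered, the `from torch.cuda.amp import autocast` sits in the class body between two methods, and names bound at class scope are not visible from inside a method, so the bare `with autocast():` call would raise NameError at run time; the import normally belongs at module scope. A short sketch of mixed-precision inference as `torch.cuda.amp` documents it (the `model` and `inp` names are placeholders, not from the commit):

    import torch
    from torch.cuda.amp import autocast  # module scope, visible inside methods

    with torch.no_grad():        # inference only, no gradient bookkeeping
        with autocast():         # CUDA forward pass runs in mixed precision
            out = model(inp)
    out = out.float()            # cast half-precision outputs back to float32
                                 # before any CPU/numpy post-processing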
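Review note on the frame conversion: the new list comprehension is behavior-preserving; each predicted frame goes from a (C, H, W) tensor to an (H, W, C) float32 NumPy array, which `img_as_ubyte` then scales to uint8. Under autocast the frames may come back as float16, which the `.astype(np.float32)` absorbs, but skimage's `img_as_ubyte` still requires float values in [-1, 1]. A minimal sketch of the per-frame path (the `pred` name is a placeholder):

    import numpy as np
    from skimage import img_as_ubyte

    # (C, H, W) torch tensor -> (H, W, C) float32 array, values expected in [0, 1]
    frame = np.transpose(pred.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
    frame_u8 = img_as_ubyte(frame)   # scales [0, 1] floats to [0, 255] uint8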