Spanicin committed on
Commit
bf35a83
·
verified ·
1 Parent(s): 2efc33b

Update src/facerender/modules/make_animation.py

Browse files
src/facerender/modules/make_animation.py CHANGED
@@ -99,98 +99,98 @@ def keypoint_transformation(kp_canonical, he, wo_exp=False):
99
  return {'value': kp_transformed}
100
 
101
 
102
- # def make_animation(source_image, source_semantics, target_semantics,
103
- # generator, kp_detector, he_estimator, mapping,
104
- # yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
105
- # use_exp=True):
106
- # with torch.no_grad():
107
- # predictions = []
108
- # device = 'cuda'
109
- # source_image = source_image.to(device)
110
- # source_semantics = source_semantics.to(device)
111
- # target_semantics = target_semantics.to(device)
112
-
113
- # kp_canonical = kp_detector(source_image)
114
- # he_source = mapping(source_semantics)
115
- # kp_source = keypoint_transformation(kp_canonical, he_source)
116
-
117
-
118
- # for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'):
119
- # target_semantics_frame = target_semantics[:, frame_idx]
120
- # he_driving = mapping(target_semantics_frame)
121
- # if yaw_c_seq is not None:
122
- # he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
123
- # if pitch_c_seq is not None:
124
- # he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
125
- # if roll_c_seq is not None:
126
- # he_driving['roll_in'] = roll_c_seq[:, frame_idx]
127
-
128
- # kp_driving = keypoint_transformation(kp_canonical, he_driving)
129
-
130
- # #kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
131
- # #kp_driving_initial=kp_driving_initial)
132
- # kp_norm = kp_driving
133
- # out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
134
- # '''
135
- # source_image_new = out['prediction'].squeeze(1)
136
- # kp_canonical_new = kp_detector(source_image_new)
137
- # he_source_new = he_estimator(source_image_new)
138
- # kp_source_new = keypoint_transformation(kp_canonical_new, he_source_new, wo_exp=True)
139
- # kp_driving_new = keypoint_transformation(kp_canonical_new, he_driving, wo_exp=True)
140
- # out = generator(source_image_new, kp_source=kp_source_new, kp_driving=kp_driving_new)
141
- # '''
142
- # predictions.append(out['prediction'])
143
- # torch.cuda.empty_cache()
144
- # predictions_ts = torch.stack(predictions, dim=1)
145
- # return predictions_ts
146
-
147
- import torch
148
- from torch.cuda.amp import autocast
149
-
150
  def make_animation(source_image, source_semantics, target_semantics,
151
- generator, kp_detector, he_estimator, mapping,
152
- yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
153
- use_exp=True):
154
-
155
- device='cuda'
156
- # Move inputs to GPU
157
- source_image = source_image.to(device)
158
- source_semantics = source_semantics.to(device)
159
- target_semantics = target_semantics.to(device)
160
-
161
- with torch.no_grad(): # No gradients needed
162
  predictions = []
 
 
 
 
 
163
  kp_canonical = kp_detector(source_image)
164
  he_source = mapping(source_semantics)
165
  kp_source = keypoint_transformation(kp_canonical, he_source)
166
 
167
- for frame_idx in tqdm(range(target_semantics.shape[1]), desc='Face Renderer:', unit='frame'):
 
168
  target_semantics_frame = target_semantics[:, frame_idx]
169
  he_driving = mapping(target_semantics_frame)
170
-
171
  if yaw_c_seq is not None:
172
  he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
173
  if pitch_c_seq is not None:
174
- he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
175
  if roll_c_seq is not None:
176
- he_driving['roll_in'] = roll_c_seq[:, frame_idx]
177
-
178
  kp_driving = keypoint_transformation(kp_canonical, he_driving)
 
 
 
179
  kp_norm = kp_driving
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
- # Use mixed precision for faster computation
182
- with autocast():
183
- out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
184
 
185
- predictions.append(out['prediction'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
- # Optional: Explicitly synchronize (use only if necessary)
188
- torch.cuda.synchronize()
189
 
190
- # Stack predictions into a single tensor
191
- predictions_ts = torch.stack(predictions, dim=1)
192
 
193
- return predictions_ts
194
 
195
 
196
 
 
99
  return {'value': kp_transformed}
100
 
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  def make_animation(source_image, source_semantics, target_semantics,
103
+ generator, kp_detector, he_estimator, mapping,
104
+ yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
105
+ use_exp=True):
106
+ with torch.no_grad():
 
 
 
 
 
 
 
107
  predictions = []
108
+ device = 'cuda'
109
+ source_image = source_image.to(device)
110
+ source_semantics = source_semantics.to(device)
111
+ target_semantics = target_semantics.to(device)
112
+
113
  kp_canonical = kp_detector(source_image)
114
  he_source = mapping(source_semantics)
115
  kp_source = keypoint_transformation(kp_canonical, he_source)
116
 
117
+
118
+ for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'):
119
  target_semantics_frame = target_semantics[:, frame_idx]
120
  he_driving = mapping(target_semantics_frame)
 
121
  if yaw_c_seq is not None:
122
  he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
123
  if pitch_c_seq is not None:
124
+ he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
125
  if roll_c_seq is not None:
126
+ he_driving['roll_in'] = roll_c_seq[:, frame_idx]
127
+
128
  kp_driving = keypoint_transformation(kp_canonical, he_driving)
129
+
130
+ #kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
131
+ #kp_driving_initial=kp_driving_initial)
132
  kp_norm = kp_driving
133
+ out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
134
+ '''
135
+ source_image_new = out['prediction'].squeeze(1)
136
+ kp_canonical_new = kp_detector(source_image_new)
137
+ he_source_new = he_estimator(source_image_new)
138
+ kp_source_new = keypoint_transformation(kp_canonical_new, he_source_new, wo_exp=True)
139
+ kp_driving_new = keypoint_transformation(kp_canonical_new, he_driving, wo_exp=True)
140
+ out = generator(source_image_new, kp_source=kp_source_new, kp_driving=kp_driving_new)
141
+ '''
142
+ predictions.append(out['prediction'])
143
+ torch.cuda.empty_cache()
144
+ predictions_ts = torch.stack(predictions, dim=1)
145
+ return predictions_ts
146
 
147
+ # import torch
148
+ # from torch.cuda.amp import autocast
 
149
 
150
+ # def make_animation(source_image, source_semantics, target_semantics,
151
+ # generator, kp_detector, he_estimator, mapping,
152
+ # yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
153
+ # use_exp=True):
154
+
155
+ # device='cuda'
156
+ # # Move inputs to GPU
157
+ # source_image = source_image.to(device)
158
+ # source_semantics = source_semantics.to(device)
159
+ # target_semantics = target_semantics.to(device)
160
+
161
+ # with torch.no_grad(): # No gradients needed
162
+ # predictions = []
163
+ # kp_canonical = kp_detector(source_image)
164
+ # he_source = mapping(source_semantics)
165
+ # kp_source = keypoint_transformation(kp_canonical, he_source)
166
+
167
+ # for frame_idx in tqdm(range(target_semantics.shape[1]), desc='Face Renderer:', unit='frame'):
168
+ # target_semantics_frame = target_semantics[:, frame_idx]
169
+ # he_driving = mapping(target_semantics_frame)
170
+
171
+ # if yaw_c_seq is not None:
172
+ # he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
173
+ # if pitch_c_seq is not None:
174
+ # he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
175
+ # if roll_c_seq is not None:
176
+ # he_driving['roll_in'] = roll_c_seq[:, frame_idx]
177
+
178
+ # kp_driving = keypoint_transformation(kp_canonical, he_driving)
179
+ # kp_norm = kp_driving
180
+
181
+ # # Use mixed precision for faster computation
182
+ # with autocast():
183
+ # out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
184
+
185
+ # predictions.append(out['prediction'])
186
 
187
+ # # Optional: Explicitly synchronize (use only if necessary)
188
+ # torch.cuda.synchronize()
189
 
190
+ # # Stack predictions into a single tensor
191
+ # predictions_ts = torch.stack(predictions, dim=1)
192
 
193
+ # return predictions_ts
194
 
195
 
196