jbilcke-hf HF Staff committed on
Commit
9000726
·
1 Parent(s): 2bdf2d8

workaround for Finetrainers

Browse files
finetrainers/data/dataset.py CHANGED
@@ -970,9 +970,59 @@ def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
970
  image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
971
  return image
972
 
973
-
974
- def _preprocess_video(video: decord.VideoReader) -> torch.Tensor:
975
- video = video.get_batch(list(range(len(video))))
976
- video = video.permute(0, 3, 1, 2).contiguous()
977
- video = video.float() / 127.5 - 1.0
978
- return video
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
970
  image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
971
  return image
972
 
973
def _preprocess_video(video, max_frames=30, fallback_shape=(16, 3, 512, 768)) -> torch.Tensor:
    """Convert a video source into a float tensor normalized to [-1, 1].

    Three kinds of input are recognized:

    * a decord ``VideoReader`` (detected by type name, so decord itself does
      not have to be importable here) -- every frame is read;
    * a ``torchvision.io`` ``VideoReader`` (detected by type name) -- frames
      are pulled by iteration, at most ``max_frames`` of them;
    * a non-empty list of PIL-like images (objects exposing ``convert``) --
      every image is used.

    Args:
        video: The video source, one of the kinds listed above.
        max_frames: Upper bound on frames read from a torchvision reader.
            Defaults to 30, the previously hard-coded limit. Pass ``None``
            to read until the reader is exhausted. Ignored by the decord and
            PIL-list branches, which always read everything.
        fallback_shape: Shape of the all-zero tensor returned when no frame
            could be read or the input type is not recognized.

    Returns:
        A float tensor with frames on the first axis and channels on the
        second. Pixel values are mapped from [0, 255] to [-1, 1]. On failure
        an all-zero tensor of ``fallback_shape`` is returned instead.
    """
    import logging

    import numpy as np

    log = logging.getLogger(__name__)
    # Match on the type's repr rather than isinstance so that neither decord
    # nor torchvision needs to be imported just to run this check.
    type_repr = str(type(video))

    # decord VideoReader: get_batch yields frames channel-last.
    # NOTE(review): assumes decord's torch bridge is active so the batch is a
    # torch tensor -- confirm against the dataset setup code.
    if hasattr(video, "get_batch") and "decord" in type_repr:
        batch = video.get_batch(list(range(len(video))))
        # Explicit float conversion keeps the result independent of the
        # integer dtype the reader hands back.
        return batch.permute(0, 3, 1, 2).contiguous().float() / 127.5 - 1.0

    # torchvision VideoReader: an iterator of {"data": tensor, ...} dicts.
    if "torchvision.io.video_reader" in type_repr:
        frames = []
        try:
            # Rewind first; the reader may already have been partly consumed.
            video.seek(0)
            while max_frames is None or len(frames) < max_frames:
                try:
                    frames.append(next(video)["data"])
                except StopIteration:
                    break
        except Exception as e:
            # Best effort: keep whatever frames were decoded before the error.
            log.warning("Error iterating VideoReader: %s", e)

        if frames:
            # Frames are stacked as-is (no permute); only the batch axis is
            # added in front.
            return torch.stack(frames).float() / 127.5 - 1.0

        log.warning("Failed to get frames, creating dummy tensor")
        return torch.zeros(*fallback_shape).float()

    # List of PIL-like images: convert each to an RGB array, then to
    # channel-first.
    if isinstance(video, list) and len(video) > 0 and hasattr(video[0], "convert"):
        frames = [
            torch.from_numpy(np.array(img.convert("RGB"))).float()
            for img in video
        ]
        return torch.stack(frames).permute(0, 3, 1, 2).contiguous() / 127.5 - 1.0

    # Unknown input: keep the workaround's behaviour of returning a
    # placeholder, but make it visible in the logs.
    log.warning("Unknown video type: %s", type(video))
    return torch.zeros(*fallback_shape).float()
finetrainers/trainer/sft_trainer/trainer.py CHANGED
@@ -325,8 +325,21 @@ class SFTTrainer:
325
  resume_from_checkpoint = self.args.resume_from_checkpoint
326
  if resume_from_checkpoint == "latest":
327
  resume_from_checkpoint = -1
 
 
 
328
  if resume_from_checkpoint is not None:
329
- self.checkpointer.load(resume_from_checkpoint)
 
 
 
 
 
 
 
 
 
 
330
 
331
  def _train(self) -> None:
332
  logger.info("Starting training")
 
325
  resume_from_checkpoint = self.args.resume_from_checkpoint
326
  if resume_from_checkpoint == "latest":
327
  resume_from_checkpoint = -1
328
+
329
+ # Store the load result
330
+ load_successful = False
331
  if resume_from_checkpoint is not None:
332
+ load_successful = self.checkpointer.load(resume_from_checkpoint)
333
+
334
+ # If loading succeeded and we have a specific checkpoint path
335
+ if load_successful and isinstance(resume_from_checkpoint, str) and resume_from_checkpoint != "latest":
336
+ try:
337
+ step = int(resume_from_checkpoint.split("_")[-1])
338
+ self.state.train_state.step = step
339
+ logger.info(f"Explicitly setting training step to {step} based on checkpoint path")
340
+ except (ValueError, IndexError):
341
+ logger.warning(f"Could not parse step number from checkpoint path: {resume_from_checkpoint}")
342
+
343
 
344
  def _train(self) -> None:
345
  logger.info("Starting training")
vms/ui/app_ui.py CHANGED
@@ -146,7 +146,7 @@ class AppUI:
146
  # Sidebar for navigation
147
  with gr.Sidebar(position="left", open=True):
148
  gr.Markdown("# 🎞️ Video Model Studio")
149
- self.components["current_project_btn"] = gr.Button("📂 Current Project", variant="primary")
150
  self.components["system_monitoring_btn"] = gr.Button("🌡️ System Monitoring")
151
 
152
  # Main content area with tabs
@@ -156,7 +156,7 @@ class AppUI:
156
  self.main_tabs = main_tabs
157
 
158
  # Project View Tab
159
- with gr.Tab("📁 Current Project", id=0) as project_view:
160
  # Create project tabs
161
  with gr.Tabs() as project_tabs:
162
  # Store reference to project tabs component
@@ -551,20 +551,20 @@ class AppUI:
551
  if is_training:
552
  # Active training detected
553
  start_btn_props = {"interactive": False, "variant": "secondary", "value": "🚀 Start new training"}
554
- resume_btn_props = {"interactive": False, "variant": "secondary", "value": "🛰️ Start from latest checkpoint"}
555
  stop_btn_props = {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}
556
  delete_btn_props = {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}
557
  else:
558
  # No active training
559
  start_btn_props = {"interactive": True, "variant": "primary", "value": "🚀 Start new training"}
560
- resume_btn_props = {"interactive": has_checkpoints, "variant": "primary", "value": "🛰️ Start from latest checkpoint"}
561
  stop_btn_props = {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}
562
  delete_btn_props = {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"}
563
  else:
564
  # Use button states from recovery, adding the new resume button
565
  start_btn_props = ui_updates.get("start_btn", {"interactive": True, "variant": "primary", "value": "🚀 Start new training"})
566
  resume_btn_props = {"interactive": has_checkpoints and not self.training.is_training_running(),
567
- "variant": "primary", "value": "🛰️ Start from latest checkpoint"}
568
  stop_btn_props = ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"})
569
  delete_btn_props = ui_updates.get("delete_checkpoints_btn", {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"})
570
 
 
146
  # Sidebar for navigation
147
  with gr.Sidebar(position="left", open=True):
148
  gr.Markdown("# 🎞️ Video Model Studio")
149
+ self.components["current_project_btn"] = gr.Button("📂 New Project", variant="primary")
150
  self.components["system_monitoring_btn"] = gr.Button("🌡️ System Monitoring")
151
 
152
  # Main content area with tabs
 
156
  self.main_tabs = main_tabs
157
 
158
  # Project View Tab
159
+ with gr.Tab("📁 New Project", id=0) as project_view:
160
  # Create project tabs
161
  with gr.Tabs() as project_tabs:
162
  # Store reference to project tabs component
 
551
  if is_training:
552
  # Active training detected
553
  start_btn_props = {"interactive": False, "variant": "secondary", "value": "🚀 Start new training"}
554
+ resume_btn_props = {"interactive": False, "variant": "secondary", "value": "🛸 Start from latest checkpoint"}
555
  stop_btn_props = {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}
556
  delete_btn_props = {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}
557
  else:
558
  # No active training
559
  start_btn_props = {"interactive": True, "variant": "primary", "value": "🚀 Start new training"}
560
+ resume_btn_props = {"interactive": has_checkpoints, "variant": "primary", "value": "🛸 Start from latest checkpoint"}
561
  stop_btn_props = {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}
562
  delete_btn_props = {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"}
563
  else:
564
  # Use button states from recovery, adding the new resume button
565
  start_btn_props = ui_updates.get("start_btn", {"interactive": True, "variant": "primary", "value": "🚀 Start new training"})
566
  resume_btn_props = {"interactive": has_checkpoints and not self.training.is_training_running(),
567
+ "variant": "primary", "value": "🛸 Start from latest checkpoint"}
568
  stop_btn_props = ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"})
569
  delete_btn_props = ui_updates.get("delete_checkpoints_btn", {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"})
570
 
vms/ui/project/tabs/train_tab.py CHANGED
@@ -187,8 +187,8 @@ class TrainTab(BaseTab):
187
  # Add description of the training buttons
188
  self.components["training_buttons_info"] = gr.Markdown("""
189
  ## βš—οΈ Train your model on your dataset
190
- - **Start new training**: Begins training from scratch (clears previous checkpoints)
191
- - **Start from latest checkpoint**: Continues training from the most recent checkpoint
192
  """)
193
 
194
  with gr.Row():
@@ -204,7 +204,7 @@ class TrainTab(BaseTab):
204
 
205
  # Add new button for continuing from checkpoint
206
  self.components["resume_btn"] = gr.Button(
207
- "πŸ›°οΈ Start from latest checkpoint",
208
  variant="primary",
209
  interactive=has_checkpoints and not ASK_USER_TO_DUPLICATE_SPACE
210
  )
@@ -972,7 +972,7 @@ class TrainTab(BaseTab):
972
  )
973
 
974
  resume_btn = gr.Button(
975
- value="Start from latest checkpoint",
976
  interactive=has_checkpoints and not is_training,
977
  variant="primary" if not is_training else "secondary"
978
  )
 
187
  # Add description of the training buttons
188
  self.components["training_buttons_info"] = gr.Markdown("""
189
  ## βš—οΈ Train your model on your dataset
190
+ - **πŸš€ Start new training**: Begins training from scratch (clears previous checkpoints)
191
+ - **πŸ›Έ Start from latest checkpoint**: Continues training from the most recent checkpoint
192
  """)
193
 
194
  with gr.Row():
 
204
 
205
  # Add new button for continuing from checkpoint
206
  self.components["resume_btn"] = gr.Button(
207
+ "🛸 Start from latest checkpoint",
208
  variant="primary",
209
  interactive=has_checkpoints and not ASK_USER_TO_DUPLICATE_SPACE
210
  )
 
972
  )
973
 
974
  resume_btn = gr.Button(
975
+ value="🛸 Start from latest checkpoint",
976
  interactive=has_checkpoints and not is_training,
977
  variant="primary" if not is_training else "secondary"
978
  )