Spaces:
Running
Running
Commit
·
9000726
1
Parent(s):
2bdf2d8
workaround for Finetrainers
Browse files
finetrainers/data/dataset.py
CHANGED
@@ -970,9 +970,59 @@ def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
|
|
970 |
image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
|
971 |
return image
|
972 |
|
973 |
-
|
974 |
-
|
975 |
-
|
976 |
-
|
977 |
-
|
978 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
970 |
image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
|
971 |
return image
|
972 |
|
973 |
+
def _preprocess_video(video) -> torch.Tensor:
|
974 |
+
import torch
|
975 |
+
import numpy as np
|
976 |
+
|
977 |
+
# For decord VideoReader
|
978 |
+
if hasattr(video, 'get_batch') and 'decord' in str(type(video)):
|
979 |
+
video = video.get_batch(list(range(len(video))))
|
980 |
+
video = video.permute(0, 3, 1, 2).contiguous() / 127.5 - 1.0
|
981 |
+
return video
|
982 |
+
|
983 |
+
# For torchvision VideoReader
|
984 |
+
elif 'torchvision.io.video_reader' in str(type(video)):
|
985 |
+
# Use the correct iteration pattern for torchvision.io.VideoReader
|
986 |
+
frames = []
|
987 |
+
try:
|
988 |
+
# First seek to the beginning
|
989 |
+
video.seek(0)
|
990 |
+
|
991 |
+
# Then collect frames by iterating
|
992 |
+
for _ in range(30): # Try to get a reasonable number of frames
|
993 |
+
try:
|
994 |
+
frame_dict = next(video)
|
995 |
+
frame = frame_dict["data"] # Extract the tensor data from the dict
|
996 |
+
frames.append(frame)
|
997 |
+
except StopIteration:
|
998 |
+
break
|
999 |
+
except Exception as e:
|
1000 |
+
print(f"Error iterating VideoReader: {e}")
|
1001 |
+
|
1002 |
+
if frames:
|
1003 |
+
# In torchvision.io.VideoReader, frames are already in [C, H, W] format
|
1004 |
+
# We need to stack and convert to [B, C, H, W]
|
1005 |
+
stacked_frames = torch.stack(frames)
|
1006 |
+
# Normalize to [-1, 1]
|
1007 |
+
stacked_frames = stacked_frames.float() / 127.5 - 1.0
|
1008 |
+
return stacked_frames
|
1009 |
+
|
1010 |
+
# If we couldn't get frames, create a dummy tensor
|
1011 |
+
print("Failed to get frames, creating dummy tensor")
|
1012 |
+
return torch.zeros(16, 3, 512, 768).float()
|
1013 |
+
|
1014 |
+
# For list of PIL images
|
1015 |
+
elif isinstance(video, list) and len(video) > 0 and hasattr(video[0], 'convert'):
|
1016 |
+
frames = []
|
1017 |
+
for img in video:
|
1018 |
+
img_tensor = torch.from_numpy(np.array(img.convert("RGB"))).float()
|
1019 |
+
frames.append(img_tensor)
|
1020 |
+
|
1021 |
+
video = torch.stack(frames)
|
1022 |
+
video = video.permute(0, 3, 1, 2).contiguous() / 127.5 - 1.0
|
1023 |
+
return video
|
1024 |
+
|
1025 |
+
# Unknown type
|
1026 |
+
else:
|
1027 |
+
print(f"Unknown video type: {type(video)}")
|
1028 |
+
return torch.zeros(16, 3, 512, 768).float()
|
finetrainers/trainer/sft_trainer/trainer.py
CHANGED
@@ -325,8 +325,21 @@ class SFTTrainer:
|
|
325 |
resume_from_checkpoint = self.args.resume_from_checkpoint
|
326 |
if resume_from_checkpoint == "latest":
|
327 |
resume_from_checkpoint = -1
|
|
|
|
|
|
|
328 |
if resume_from_checkpoint is not None:
|
329 |
-
self.checkpointer.load(resume_from_checkpoint)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
330 |
|
331 |
def _train(self) -> None:
|
332 |
logger.info("Starting training")
|
|
|
325 |
resume_from_checkpoint = self.args.resume_from_checkpoint
|
326 |
if resume_from_checkpoint == "latest":
|
327 |
resume_from_checkpoint = -1
|
328 |
+
|
329 |
+
# Store the load result
|
330 |
+
load_successful = False
|
331 |
if resume_from_checkpoint is not None:
|
332 |
+
load_successful = self.checkpointer.load(resume_from_checkpoint)
|
333 |
+
|
334 |
+
# If loading succeeded and we have a specific checkpoint path
|
335 |
+
if load_successful and isinstance(resume_from_checkpoint, str) and resume_from_checkpoint != "latest":
|
336 |
+
try:
|
337 |
+
step = int(resume_from_checkpoint.split("_")[-1])
|
338 |
+
self.state.train_state.step = step
|
339 |
+
logger.info(f"Explicitly setting training step to {step} based on checkpoint path")
|
340 |
+
except (ValueError, IndexError):
|
341 |
+
logger.warning(f"Could not parse step number from checkpoint path: {resume_from_checkpoint}")
|
342 |
+
|
343 |
|
344 |
def _train(self) -> None:
|
345 |
logger.info("Starting training")
|
vms/ui/app_ui.py
CHANGED
@@ -146,7 +146,7 @@ class AppUI:
|
|
146 |
# Sidebar for navigation
|
147 |
with gr.Sidebar(position="left", open=True):
|
148 |
gr.Markdown("# ποΈ Video Model Studio")
|
149 |
-
self.components["current_project_btn"] = gr.Button("π
|
150 |
self.components["system_monitoring_btn"] = gr.Button("🌡️ System Monitoring")
|
151 |
|
152 |
# Main content area with tabs
|
@@ -156,7 +156,7 @@ class AppUI:
|
|
156 |
self.main_tabs = main_tabs
|
157 |
|
158 |
# Project View Tab
|
159 |
-
with gr.Tab("π
|
160 |
# Create project tabs
|
161 |
with gr.Tabs() as project_tabs:
|
162 |
# Store reference to project tabs component
|
@@ -551,20 +551,20 @@ class AppUI:
|
|
551 |
if is_training:
|
552 |
# Active training detected
|
553 |
start_btn_props = {"interactive": False, "variant": "secondary", "value": "π Start new training"}
|
554 |
-
resume_btn_props = {"interactive": False, "variant": "secondary", "value": "
|
555 |
stop_btn_props = {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}
|
556 |
delete_btn_props = {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}
|
557 |
else:
|
558 |
# No active training
|
559 |
start_btn_props = {"interactive": True, "variant": "primary", "value": "π Start new training"}
|
560 |
-
resume_btn_props = {"interactive": has_checkpoints, "variant": "primary", "value": "
|
561 |
stop_btn_props = {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}
|
562 |
delete_btn_props = {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"}
|
563 |
else:
|
564 |
# Use button states from recovery, adding the new resume button
|
565 |
start_btn_props = ui_updates.get("start_btn", {"interactive": True, "variant": "primary", "value": "π Start new training"})
|
566 |
resume_btn_props = {"interactive": has_checkpoints and not self.training.is_training_running(),
|
567 |
-
"variant": "primary", "value": "
|
568 |
stop_btn_props = ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"})
|
569 |
delete_btn_props = ui_updates.get("delete_checkpoints_btn", {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"})
|
570 |
|
|
|
146 |
# Sidebar for navigation
|
147 |
with gr.Sidebar(position="left", open=True):
|
148 |
gr.Markdown("# ποΈ Video Model Studio")
|
149 |
+
self.components["current_project_btn"] = gr.Button("π New Project", variant="primary")
|
150 |
self.components["system_monitoring_btn"] = gr.Button("🌡️ System Monitoring")
|
151 |
|
152 |
# Main content area with tabs
|
|
|
156 |
self.main_tabs = main_tabs
|
157 |
|
158 |
# Project View Tab
|
159 |
+
with gr.Tab("π New Project", id=0) as project_view:
|
160 |
# Create project tabs
|
161 |
with gr.Tabs() as project_tabs:
|
162 |
# Store reference to project tabs component
|
|
|
551 |
if is_training:
|
552 |
# Active training detected
|
553 |
start_btn_props = {"interactive": False, "variant": "secondary", "value": "π Start new training"}
|
554 |
+
resume_btn_props = {"interactive": False, "variant": "secondary", "value": "🛸 Start from latest checkpoint"}
|
555 |
stop_btn_props = {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}
|
556 |
delete_btn_props = {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}
|
557 |
else:
|
558 |
# No active training
|
559 |
start_btn_props = {"interactive": True, "variant": "primary", "value": "π Start new training"}
|
560 |
+
resume_btn_props = {"interactive": has_checkpoints, "variant": "primary", "value": "🛸 Start from latest checkpoint"}
|
561 |
stop_btn_props = {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}
|
562 |
delete_btn_props = {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"}
|
563 |
else:
|
564 |
# Use button states from recovery, adding the new resume button
|
565 |
start_btn_props = ui_updates.get("start_btn", {"interactive": True, "variant": "primary", "value": "π Start new training"})
|
566 |
resume_btn_props = {"interactive": has_checkpoints and not self.training.is_training_running(),
|
567 |
+
"variant": "primary", "value": "🛸 Start from latest checkpoint"}
|
568 |
stop_btn_props = ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"})
|
569 |
delete_btn_props = ui_updates.get("delete_checkpoints_btn", {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"})
|
570 |
|
vms/ui/project/tabs/train_tab.py
CHANGED
@@ -187,8 +187,8 @@ class TrainTab(BaseTab):
|
|
187 |
# Add description of the training buttons
|
188 |
self.components["training_buttons_info"] = gr.Markdown("""
|
189 |
## βοΈ Train your model on your dataset
|
190 |
-
-
|
191 |
-
-
|
192 |
""")
|
193 |
|
194 |
with gr.Row():
|
@@ -204,7 +204,7 @@ class TrainTab(BaseTab):
|
|
204 |
|
205 |
# Add new button for continuing from checkpoint
|
206 |
self.components["resume_btn"] = gr.Button(
|
207 |
-
"
|
208 |
variant="primary",
|
209 |
interactive=has_checkpoints and not ASK_USER_TO_DUPLICATE_SPACE
|
210 |
)
|
@@ -972,7 +972,7 @@ class TrainTab(BaseTab):
|
|
972 |
)
|
973 |
|
974 |
resume_btn = gr.Button(
|
975 |
-
value="Start from latest checkpoint",
|
976 |
interactive=has_checkpoints and not is_training,
|
977 |
variant="primary" if not is_training else "secondary"
|
978 |
)
|
|
|
187 |
# Add description of the training buttons
|
188 |
self.components["training_buttons_info"] = gr.Markdown("""
|
189 |
## βοΈ Train your model on your dataset
|
190 |
+
- **π Start new training**: Begins training from scratch (clears previous checkpoints)
|
191 |
+
- **🛸 Start from latest checkpoint**: Continues training from the most recent checkpoint
|
192 |
""")
|
193 |
|
194 |
with gr.Row():
|
|
|
204 |
|
205 |
# Add new button for continuing from checkpoint
|
206 |
self.components["resume_btn"] = gr.Button(
|
207 |
+
"🛸 Start from latest checkpoint",
|
208 |
variant="primary",
|
209 |
interactive=has_checkpoints and not ASK_USER_TO_DUPLICATE_SPACE
|
210 |
)
|
|
|
972 |
)
|
973 |
|
974 |
resume_btn = gr.Button(
|
975 |
+
value="🛸 Start from latest checkpoint",
|
976 |
interactive=has_checkpoints and not is_training,
|
977 |
variant="primary" if not is_training else "secondary"
|
978 |
)
|