import streamlit as st
import os
import py7zr
import requests
from huggingface_hub import HfApi
import torch
from torch.utils.data import DataLoader
import shutil
from pathlib import Path
from typing import Optional
import sys
import io
# Import the denoising code (assuming it's in a file called denoising_model.py)
from denoising_model import DenoisingModel, DenoiseDataset, get_optimal_threads
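# Expected interface of denoising_model, inferred from how it is used below (not defined in this file):
#   DenoisingModel        - a torch.nn.Module that maps noisy patches to denoised patches
#   DenoiseDataset        - a Dataset built from a noisy-image directory and a target-image directory
#   get_optimal_threads() - returns a sensible default for the DataLoader worker count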
class StreamCapture:
    """Stdout replacement that keeps a log of captured text and mirrors it into the Streamlit UI."""

    def __init__(self):
        self.logs = []

    def write(self, text):
        self.logs.append(text)
        if text.strip():  # skip bare newlines so the UI is not flooded with empty messages
            st.warning(text)

    def flush(self):
        pass
def download_and_extract_7z(url: str, extract_to: str = '.') -> Optional[str]:
    """Downloads a 7z archive and extracts it. Returns an error message on failure, None on success."""
    try:
        st.warning(f"Downloading file from {url}...")
        response = requests.get(url, stream=True)
        response.raise_for_status()

        archive_path = os.path.join(extract_to, 'dataset.7z')
        with open(archive_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

        st.warning("Extracting 7z archive...")
        with py7zr.SevenZipFile(archive_path, mode='r') as z:
            z.extractall(extract_to)

        # Move the extracted directories to the names the training code expects
        output_images_path = os.path.join(extract_to, 'output_images')
        if os.path.exists(output_images_path):
            source_noisy = os.path.join(output_images_path, 'images_noisy')
            source_target = os.path.join(output_images_path, 'images_target')
            if os.path.exists('noisy_images'):
                shutil.rmtree('noisy_images')
            if os.path.exists('target_images'):
                shutil.rmtree('target_images')
            shutil.move(source_noisy, 'noisy_images')
            shutil.move(source_target, 'target_images')

        # Clean up the now-empty extraction directory and the downloaded archive
        if os.path.exists(output_images_path):
            shutil.rmtree(output_images_path)
        os.remove(archive_path)

        st.warning("Download and extraction completed successfully.")
        return None
    except Exception as e:
        return f"Error: {str(e)}"
def upload_to_huggingface(file_path: str, repo_id: str, path_in_repo: str):
    """Uploads a file to a Hugging Face Space repository."""
    try:
        # HfApi picks up the access token from the HF_TOKEN environment variable set in main()
        api = HfApi()
        api.upload_file(
            path_or_fileobj=file_path,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type="space"
        )
        st.warning(f"Successfully uploaded {file_path} to {repo_id}")
    except Exception as e:
        st.error(f"Error uploading to Hugging Face: {str(e)}")
def train_model_with_upload(epochs, batch_size, learning_rate, save_interval, num_workers, repo_id):
    """Training loop that periodically saves checkpoints and uploads them to Hugging Face."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    st.warning(f"Using device: {device}")

    # Create a temporary directory for checkpoints
    checkpoint_dir = "temp_checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)

    try:
        dataset = DenoiseDataset('noisy_images', 'target_images')
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available()
        )

        model = DenoisingModel().to(device)
        model.train()  # ensure training mode (relevant if the model uses dropout or batch norm)
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        for epoch in range(epochs):
            st.warning(f"Starting epoch {epoch+1}/{epochs}")
            for batch_idx, (noisy_patches, target_patches) in enumerate(dataloader):
                noisy_patches = noisy_patches.to(device)
                target_patches = target_patches.to(device)

                outputs = model(noisy_patches)
                loss = criterion(outputs, target_patches)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if (batch_idx + 1) % 10 == 0:
                    st.warning(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx+1}], Loss: {loss.item():.6f}")

                if (batch_idx + 1) % save_interval == 0:
                    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch{epoch+1}_batch{batch_idx+1}.pth")
                    torch.save(model.state_dict(), checkpoint_path)
                    # Upload the checkpoint to Hugging Face
                    upload_to_huggingface(
                        checkpoint_path,
                        repo_id,
                        f"checkpoints/checkpoint_epoch{epoch+1}_batch{batch_idx+1}.pth"
                    )

        # Save and upload the final model
        final_model_path = os.path.join(checkpoint_dir, "final_model.pth")
        torch.save(model.state_dict(), final_model_path)
        upload_to_huggingface(final_model_path, repo_id, "model/final_model.pth")
    finally:
        # Clean up the temporary checkpoint directory
        if os.path.exists(checkpoint_dir):
            shutil.rmtree(checkpoint_dir)
def main():
    st.title("Image Denoising Model Training")

    # Redirect stdout to capture print statements
    sys.stdout = StreamCapture()

    # Input for Hugging Face token
    hf_token = st.text_input("Enter your Hugging Face token:", type="password")
    if hf_token:
        os.environ["HF_TOKEN"] = hf_token

    # Input for repository ID
    repo_id = st.text_input("Enter your Hugging Face repository ID (username/repo):")

    # Download and extract dataset button
    if st.button("Download and Extract Dataset"):
        url = "https://huggingface.co/spaces/vericudebuget/ok4231/resolve/main/output_images.7z"
        error = download_and_extract_7z(url)
        if error:
            st.error(error)

    # Training parameters
    col1, col2 = st.columns(2)
    with col1:
        epochs = st.number_input("Number of epochs", min_value=1, value=10)
        batch_size = st.number_input("Batch size", min_value=1, value=4)
        learning_rate = st.number_input("Learning rate", min_value=0.0001, value=0.001, format="%.4f")
    with col2:
        save_interval = st.number_input("Save interval (batches)", min_value=1, value=1000)
        num_workers = st.number_input("Number of workers", min_value=1, value=get_optimal_threads())

    # Start training button
    if st.button("Start Training"):
        if not hf_token:
            st.error("Please enter your Hugging Face token")
            return
        if not repo_id:
            st.error("Please enter your repository ID")
            return
        if not os.path.exists("noisy_images") or not os.path.exists("target_images"):
            st.error("Dataset not found. Please download and extract it first.")
            return
        train_model_with_upload(epochs, batch_size, learning_rate, save_interval, num_workers, repo_id)


if __name__ == "__main__":
    main()
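# A minimal way to launch this app locally (assuming the file is saved as app.py and the
# dependencies streamlit, py7zr, requests, huggingface_hub, and torch are installed):
#   streamlit run app.py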