import streamlit as st
import os
import shutil
import sys
from typing import Optional

import py7zr
import requests
import torch
from huggingface_hub import HfApi
from torch.utils.data import DataLoader

from denoising_model import DenoisingModel, DenoiseDataset, get_optimal_threads


class StreamCapture:
    """File-like object that mirrors writes into the Streamlit UI."""

    def __init__(self):
        self.logs = []

    def write(self, text):
        self.logs.append(text)
        # Skip newline-only writes so redirected stdout does not flood the UI with empty warnings.
        if text.strip():
            st.warning(text)

    def flush(self):
        pass

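# Typical use (see main() below): redirect stdout so print() output shows up in the app, e.g.
#     sys.stdout = StreamCapture()
#     print("visible as a Streamlit warning")

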
def download_and_extract_7z(url: str, extract_to: str = '.') -> Optional[str]:
    """Downloads a 7z archive, extracts it, and moves the image folders into place.

    Returns None on success, or an error message string on failure.
    """
    try:
        st.warning(f"Downloading file from {url}...")
        response = requests.get(url, stream=True)
        response.raise_for_status()

        # Stream the archive to disk in 8 KB chunks.
        archive_path = os.path.join(extract_to, 'dataset.7z')
        with open(archive_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

        st.warning("Extracting 7z archive...")
        with py7zr.SevenZipFile(archive_path, mode='r') as z:
            z.extractall(extract_to)

        # The archive contains an 'output_images' folder holding the noisy and target sets.
        output_images_path = os.path.join(extract_to, 'output_images')
        if os.path.exists(output_images_path):
            source_noisy = os.path.join(output_images_path, 'images_noisy')
            source_target = os.path.join(output_images_path, 'images_target')

            # Replace any previously extracted copies.
            if os.path.exists('noisy_images'):
                shutil.rmtree('noisy_images')
            if os.path.exists('target_images'):
                shutil.rmtree('target_images')

            shutil.move(source_noisy, 'noisy_images')
            shutil.move(source_target, 'target_images')

        # Remove the now-empty container folder and the downloaded archive.
        if os.path.exists(output_images_path):
            shutil.rmtree(output_images_path)

        os.remove(archive_path)
        st.warning("Download and extraction completed successfully.")
        return None

    except Exception as e:
        return f"Error: {str(e)}"

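# Example (mirrors the button handler in main()):
#     err = download_and_extract_7z(
#         "https://huggingface.co/spaces/vericudebuget/ok4231/resolve/main/output_images.7z"
#     )
#     if err:
#         st.error(err)
# On success the working directory contains 'noisy_images/' and 'target_images/'.

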
def upload_to_huggingface(file_path: str, repo_id: str, path_in_repo: str):
    """Uploads a file to a Hugging Face Space repository."""
    try:
        # HfApi resolves the token from the HF_TOKEN environment variable set in main().
        api = HfApi()
        api.upload_file(
            path_or_fileobj=file_path,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type="space"
        )
        st.warning(f"Successfully uploaded {file_path} to {repo_id}")
    except Exception as e:
        st.error(f"Error uploading to Hugging Face: {str(e)}")

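# Counterpart sketch for fetching an uploaded file back from the Space. hf_hub_download is the
# standard huggingface_hub helper; the filename below assumes the layout used by
# train_model_with_upload() ("model/final_model.pth").
#     from huggingface_hub import hf_hub_download
#     local_path = hf_hub_download(repo_id=repo_id, filename="model/final_model.pth", repo_type="space")

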
def train_model_with_upload(epochs, batch_size, learning_rate, save_interval, num_workers, repo_id):
    """Training loop that periodically uploads checkpoints to Hugging Face."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    st.warning(f"Using device: {device}")

    # Checkpoints are written locally, uploaded, and cleaned up at the end.
    checkpoint_dir = "temp_checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)

    try:
        dataset = DenoiseDataset('noisy_images', 'target_images')
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available()
        )

        model = DenoisingModel().to(device)
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        for epoch in range(epochs):
            st.warning(f"Starting epoch {epoch+1}/{epochs}")
            for batch_idx, (noisy_patches, target_patches) in enumerate(dataloader):
                noisy_patches = noisy_patches.to(device)
                target_patches = target_patches.to(device)

                # Forward pass: predict clean patches and compute the reconstruction loss.
                outputs = model(noisy_patches)
                loss = criterion(outputs, target_patches)

                # Backward pass and parameter update.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if (batch_idx + 1) % 10 == 0:
                    st.warning(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx+1}], Loss: {loss.item():.6f}")

                # Save and upload an intermediate checkpoint every `save_interval` batches.
                if (batch_idx + 1) % save_interval == 0:
                    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch{epoch+1}_batch{batch_idx+1}.pth")
                    torch.save(model.state_dict(), checkpoint_path)
                    upload_to_huggingface(
                        checkpoint_path,
                        repo_id,
                        f"checkpoints/checkpoint_epoch{epoch+1}_batch{batch_idx+1}.pth"
                    )

        # Save and upload the final model weights.
        final_model_path = os.path.join(checkpoint_dir, "final_model.pth")
        torch.save(model.state_dict(), final_model_path)
        upload_to_huggingface(final_model_path, repo_id, "model/final_model.pth")

    finally:
        # Remove local checkpoints whether training succeeded or failed.
        if os.path.exists(checkpoint_dir):
            shutil.rmtree(checkpoint_dir)

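# A minimal inference-side sketch (not called by this app): load saved weights back into the
# network defined in denoising_model.py. `weights_path` is assumed to point at a .pth file
# produced by torch.save(model.state_dict(), ...) above.
def load_denoising_model(weights_path: str, device: torch.device) -> DenoisingModel:
    model = DenoisingModel().to(device)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()  # disable training-only behaviour (e.g. dropout) for inference
    return model

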
def main():
    st.title("Image Denoising Model Training")

    # Mirror anything printed to stdout into the Streamlit UI.
    sys.stdout = StreamCapture()

    hf_token = st.text_input("Enter your Hugging Face token:", type="password")
    if hf_token:
        os.environ["HF_TOKEN"] = hf_token

    repo_id = st.text_input("Enter your Hugging Face repository ID (username/repo):")

    if st.button("Download and Extract Dataset"):
        url = "https://huggingface.co/spaces/vericudebuget/ok4231/resolve/main/output_images.7z"
        error = download_and_extract_7z(url)
        if error:
            st.error(error)

    # Training hyperparameters.
    col1, col2 = st.columns(2)
    with col1:
        epochs = st.number_input("Number of epochs", min_value=1, value=10)
        batch_size = st.number_input("Batch size", min_value=1, value=4)
        learning_rate = st.number_input("Learning rate", min_value=0.0001, value=0.001, format="%.4f")

    with col2:
        save_interval = st.number_input("Save interval (batches)", min_value=1, value=1000)
        num_workers = st.number_input("Number of workers", min_value=1, value=get_optimal_threads())

    if st.button("Start Training"):
        if not hf_token:
            st.error("Please enter your Hugging Face token")
            return
        if not repo_id:
            st.error("Please enter your repository ID")
            return
        if not os.path.exists("noisy_images") or not os.path.exists("target_images"):
            st.error("Dataset not found. Please download and extract it first.")
            return

        train_model_with_upload(epochs, batch_size, learning_rate, save_interval, num_workers, repo_id)


if __name__ == "__main__":
    main()