Spaces:
Runtime error
Runtime error
File size: 4,802 Bytes
8578371 6d1d75b 8578371 6d1d75b 8578371 6d1d75b 8578371 6d1d75b 8578371 6d1d75b 8578371 6d1d75b 8578371 6d1d75b 8578371 6d1d75b 8578371 6d1d75b 8578371 6d1d75b 8578371 6d1d75b 8578371 6d1d75b 8578371 6d1d75b 8578371 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import torch
from diffusers import DiffusionPipeline, DDPMScheduler
from accelerate import Accelerator
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import TrainingArguments
import gradio as gr
# ---------------------------------------------------------------------------
# Configuration (edit these values to suit your run)
# ---------------------------------------------------------------------------
# Base model checkpoint and training dataset on the Hugging Face Hub.
pretrained_model_name_or_path: str = "black-forest-labs/FLUX.1-dev"
dataset_name: str = "DucHaiten/anime-SDXL"  # use whichever dataset you want

# Optimization schedule.
learning_rate: float = 1e-5
num_train_epochs: int = 2  # adjust as needed
train_batch_size: int = 1  # keep the batch small for free Spaces hardware
gradient_accumulation_steps: int = 4  # adjust as needed

# Data resolution and output location.
image_resize: int = 128  # side length (pixels) images are resized to; adjust as needed
output_dir: str = "flux-anime"
# Load the pretrained pipeline in half precision.
# NOTE(review): FLUX.1-dev is a flow-matching transformer model. Its pipeline
# exposes a `transformer` (not a `unet`) and ships a flow-matching scheduler,
# so swapping in DDPMScheduler and later using `pipeline.unet` presumably does
# not match this checkpoint. Confirm which base model is intended.
pipeline = DiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path, torch_dtype=torch.float16
)
# Replace the scheduler with DDPM, built from the original scheduler's config.
pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)

# xformers attention is an optional speed/memory optimization. diffusers
# raises ModuleNotFoundError when the xformers package is missing and
# ValueError when CUDA is unavailable. Fall back to the default attention
# instead of aborting the whole script.
try:
    pipeline.enable_xformers_memory_efficient_attention()
except (ModuleNotFoundError, ValueError) as err:
    print(f"xformers memory-efficient attention not enabled: {err}")
# Download the configured dataset and keep only its "train" split.
dataset = load_dataset(path=dataset_name)["train"]
def preprocess_function(examples):
    """Convert a batch of raw rows into VAE-ready pixel arrays plus prompts.

    Intended for ``dataset.map(..., batched=True)``: ``examples`` maps each
    column name to the list of values for the batch. The "image" column is
    presumably PIL images (``.convert``/``.resize`` are called on them) and
    "text" holds the captions. Confirm against the dataset card.

    Returns:
        A new dict containing only:
        - "pixel_values": one float32 array per image, shape
          (3, image_resize, image_resize), scaled to [-1, 1].
        - "prompt": the captions from the "text" column, in order.
    """
    import numpy as np  # local import: only this preprocessing step needs numpy

    pixel_values = []
    for image in examples["image"]:
        rgb = image.convert("RGB").resize((image_resize, image_resize))
        # HWC uint8 -> CHW float in [-1, 1], the input range diffusers' VAEs
        # expect (the diffusers training scripts use Normalize([0.5], [0.5])).
        # A CLIP image processor would instead resize/normalize for CLIP.
        array = np.asarray(rgb, dtype=np.float32) / 127.5 - 1.0
        pixel_values.append(array.transpose(2, 0, 1))

    # Return a fresh dict rather than the mutated input. Dataset.map keeps any
    # column listed in remove_columns if the function returns it, so echoing
    # "image"/"text" back would silently keep them.
    return {"pixel_values": pixel_values, "prompt": list(examples["text"])}
# Preprocess every example, dropping the original columns. Then expose
# "pixel_values" as torch tensors so each batch from processed_dataset.iter()
# supports .to(...). The other columns (e.g. "prompt") stay as Python objects.
processed_dataset = dataset.map(
    preprocess_function,
    batched=True,
    num_proc=4,
    remove_columns=dataset.column_names,
).with_format("torch", columns=["pixel_values"], output_all_columns=True)
# Accelerator handles fp16 autocast, loss scaling and gradient accumulation.
accelerator = Accelerator(
    gradient_accumulation_steps=gradient_accumulation_steps,
    mixed_precision="fp16",
)

# The VAE only encodes images and is never trained. Freeze it and move it to
# the training device instead of handing it to accelerator.prepare().
pipeline.vae.requires_grad_(False)
pipeline.vae.to(accelerator.device)

# Trainable weights must be fp32. With mixed_precision="fp16", Accelerate uses
# a GradScaler, and GradScaler refuses to unscale fp16 gradients. Autocast
# still runs the forward pass in half precision.
# NOTE(review): FLUX pipelines expose their denoiser as `pipeline.transformer`,
# not `pipeline.unet`. Confirm which model this script is meant to train.
pipeline.unet.to(torch.float32)

optimizer = torch.optim.AdamW(
    pipeline.unet.parameters(),
    lr=learning_rate,
)

# Prepare the model and optimizer together. The wrapped optimizer unscales
# the fp16-scaled gradients before stepping, and it skips step()/zero_grad()
# on non-sync micro-batches so gradient accumulation actually accumulates.
pipeline.unet, optimizer = accelerator.prepare(pipeline.unet, optimizer)
# Hugging Face TrainingArguments describing this run.
# NOTE(review): in the code visible here, `training_args` is never passed to a
# Trainer or read again. Its settings (including push_to_hub) therefore have
# no effect on the hand-written training loop.
training_args = TrainingArguments(
    # Output and logging destinations.
    output_dir=output_dir,
    logging_dir="./logs",
    report_to="tensorboard",
    # Optimization schedule, mirroring the configuration constants.
    learning_rate=learning_rate,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    fp16=True,
    # Upload the model to the Hugging Face Hub.
    push_to_hub=True,
)
# Training progress bar: one tick per batch. Dataset.iter() keeps the final
# partial batch (drop_last_batch defaults to False), so each epoch yields
# ceil(len / batch_size) batches. -(-a // b) is integer ceiling division.
progress_bar = tqdm(
    range(num_train_epochs * -(-len(processed_dataset) // train_batch_size))
)
# --- Gradio components ---
# Latest training status. train_step() writes it from the training loop and
# the UI's Refresh button reads it. Gradio components cannot be changed from
# outside an event handler, so the loop publishes its values here instead.
training_status = {"loss": None, "epoch": None, "progress": 0.0, "image": None}


def _read_training_status():
    """Return the latest (loss, epoch, progress, image) for the UI outputs."""
    return (
        training_status["loss"],
        training_status["epoch"],
        training_status["progress"],
        training_status["image"],
    )


def train_step(step, epoch, loss):
    """Record one training step's metrics and, every 100 steps, a sample image.

    Args:
        step: Batch index within the current epoch.
        epoch: Current epoch index.
        loss: Scalar loss value for this step.

    Returns:
        (loss, epoch, progress), where progress is the fraction of all
        training batches completed so far.
    """
    # progress_bar.n counts batches across all epochs; `step` restarts at 0
    # every epoch and would make the fraction reset each epoch.
    progress = progress_bar.n / len(progress_bar)
    training_status.update(loss=loss, epoch=epoch, progress=progress)
    if step % 100 == 0:
        with torch.no_grad():
            # Change the prompt to match your dataset.
            image = pipeline(
                "anime style image of a girl with blue hair"
            ).images[0]
        training_status["image"] = image
    return loss, epoch, progress


with gr.Blocks() as interface:
    # Change the title to match your dataset.
    gr.Markdown(
        "## Fine-tuning FLUX untuk Anime"
    )
    loss_textbox = gr.Textbox(label="Loss")
    epoch_textbox = gr.Textbox(label="Epoch")
    # Gradio has no ProgressBar component; show the completed fraction instead.
    progress_bar_gradio = gr.Number(label="Progress")
    output_image = gr.Image(label="Generated Image")
    refresh_button = gr.Button("Refresh")
    refresh_button.click(
        fn=_read_training_status,
        outputs=[loss_textbox, epoch_textbox, progress_bar_gradio, output_image],
    )

# By default launch() blocks the calling thread, so the training loop would
# never start. prevent_thread_lock serves the UI from a background thread and
# returns immediately.
interface.launch(server_name="0.0.0.0", prevent_thread_lock=True)
# --- Training loop ---
for epoch in range(num_train_epochs):
    pipeline.unet.train()
    for step, batch in enumerate(
        processed_dataset.iter(batch_size=train_batch_size)
    ):
        with accelerator.accumulate(pipeline.unet):
            # Encode images to latents on the training device. as_tensor also
            # accepts plain nested lists in case no torch format is set. The
            # VAE is not being trained, so no autograd graph is recorded.
            pixel_values = torch.as_tensor(batch["pixel_values"]).to(
                accelerator.device, dtype=torch.float16
            )
            with torch.no_grad():
                latents = pipeline.vae.encode(pixel_values).latent_dist.sample()
            latents = latents * pipeline.vae.config.scaling_factor

            # Forward diffusion: noise each sample at a random timestep.
            # torch.randint already returns int64 timesteps.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0,
                pipeline.scheduler.config.num_train_timesteps,
                (latents.shape[0],),
                device=latents.device,
            )
            noisy_latents = pipeline.scheduler.add_noise(
                latents, noise, timesteps
            )

            # NOTE(review): the third argument is the raw prompt strings from
            # batch["prompt"]. A denoiser normally expects encoded text
            # embeddings here. Also, FLUX is trained with a flow-matching
            # target rather than noise prediction. Confirm the intended model
            # and conditioning before trusting this loss.
            model_pred = pipeline.unet(
                noisy_latents, timesteps, batch["prompt"]
            ).sample
            loss = torch.nn.functional.mse_loss(
                model_pred.float(), noise.float(), reduction="mean"
            )
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

        progress_bar.update(1)
        # Publish this step's metrics to the Gradio UI.
        train_step(step, epoch, loss.item())
# Save the fine-tuned pipeline once every process has finished training.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    # Strip any Accelerate/DDP wrapper so save_pretrained sees the bare model.
    pipeline.unet = accelerator.unwrap_model(pipeline.unet)
    pipeline.save_pretrained(output_dir)

# Keep the process (and with it the Gradio server) alive after training
# instead of exiting.
interface.block_thread()
|