File size: 4,802 Bytes
8578371
 
 
 
 
 
6d1d75b
8578371
 
 
6d1d75b
8578371
6d1d75b
 
 
 
 
8578371
 
6d1d75b
 
 
8578371
 
 
 
 
 
 
 
6d1d75b
 
 
 
8578371
6d1d75b
 
 
8578371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d1d75b
 
 
8578371
 
 
6d1d75b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8578371
 
6d1d75b
 
 
8578371
6d1d75b
 
 
8578371
 
 
6d1d75b
 
 
 
 
 
8578371
6d1d75b
 
 
 
 
 
 
 
 
 
 
8578371
 
 
 
 
6d1d75b
 
 
8578371
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import torch
from diffusers import DiffusionPipeline, DDPMScheduler
from accelerate import Accelerator
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import TrainingArguments
import gradio as gr

# Configuration
# Module-level settings read by the loading, preprocessing and training code
# further down in this script.
pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
dataset_name = "DucHaiten/anime-SDXL"  # Use whichever dataset you prefer
learning_rate = 1e-5
num_train_epochs = 2  # Adjust as needed
train_batch_size = 1  # Keep the batch size small on free Spaces hardware
gradient_accumulation_steps = 4  # Adjust as needed
output_dir = "flux-anime"
image_resize = 128  # Side length images are resized to (square); adjust as needed

# Load the model and scheduler
# The pipeline is loaded with float16 weights; its scheduler is then replaced
# by a DDPMScheduler built from the original scheduler's config.
# NOTE(review): the rest of this script trains `pipeline.unet`. FLUX pipelines
# in diffusers are presumably transformer-based (`pipeline.transformer`) —
# confirm that this checkpoint actually exposes a `unet` and that a DDPM
# scheduler is appropriate for it.
pipeline = DiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path, torch_dtype=torch.float16
)
pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
# NOTE(review): presumably needs the optional xformers package (and a GPU) to
# be available — confirm for the target environment.
pipeline.enable_xformers_memory_efficient_attention()

# Load the dataset (only the "train" split is used)
dataset = load_dataset(dataset_name)["train"]

# Batch preprocessing function
def preprocess_function(examples):
    """Add model inputs to one batch of dataset rows.

    Every image in ``examples["image"]`` is converted to RGB and resized to
    ``image_resize`` x ``image_resize``. The resized images are run through
    ``pipeline.feature_extractor`` and the resulting ``pixel_values`` are
    stored on the batch; the captions in ``examples["text"]`` are copied to
    ``examples["prompt"]``.

    Args:
        examples: Batch mapping with at least ``"image"`` and ``"text"``
            columns.

    Returns:
        The same ``examples`` mapping, updated in place with
        ``"pixel_values"`` and ``"prompt"``.
    """
    target_size = (image_resize, image_resize)

    resized = []
    for picture in examples["image"]:
        resized.append(picture.convert("RGB").resize(target_size))

    captions = list(examples["text"])

    features = pipeline.feature_extractor(images=resized, return_tensors="pt")
    examples["pixel_values"] = features.pixel_values
    examples["prompt"] = captions
    return examples

# Process the dataset
# Runs preprocess_function over the dataset in batches with 4 worker
# processes; remove_columns is given every original column name.
# NOTE(review): preprocess_function reads the module-level `pipeline`, so each
# worker process presumably needs its own access to it — confirm this is
# workable with num_proc=4.
processed_dataset = dataset.map(
    preprocess_function,
    batched=True,
    num_proc=4,
    remove_columns=dataset.column_names,
)

# Initialize the accelerator (fp16 mixed precision, gradient accumulation)
accelerator = Accelerator(
    gradient_accumulation_steps=gradient_accumulation_steps,
    mixed_precision="fp16",
)
# NOTE(review): processed_dataset is a datasets Dataset here, not a torch
# DataLoader — confirm accelerator.prepare() does anything useful with it.
pipeline.unet, pipeline.vae, processed_dataset = accelerator.prepare(
    pipeline.unet, pipeline.vae, processed_dataset
)

# Optimizer
# Only the UNet parameters are optimized.
# NOTE(review): the optimizer is created after accelerator.prepare() and is
# never passed through it — confirm gradient accumulation and fp16 loss
# scaling behave as intended with an unprepared optimizer.
optimizer = torch.optim.AdamW(
    pipeline.unet.parameters(),
    lr=learning_rate,
)

# Training arguments
# NOTE(review): training_args is not referenced anywhere else in the visible
# script (no Trainer is constructed) — confirm whether it is still needed.
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    num_train_epochs=num_train_epochs,
    learning_rate=learning_rate,
    fp16=True,
    logging_dir="./logs",
    report_to="tensorboard",
    push_to_hub=True,  # Push the model to the Hugging Face Hub
)

# Training loop progress bar
# One tick per batch. The training loop iterates the dataset with
# Dataset.iter(batch_size=...), which also yields a final partial batch, so
# each epoch has ceil(len / batch_size) steps. Round up per epoch instead of
# flooring the overall product, otherwise the total undercounts whenever the
# dataset size is not a multiple of the batch size.
progress_bar = tqdm(
    range(
        num_train_epochs
        * ((len(processed_dataset) + train_batch_size - 1) // train_batch_size)
    )
)

# --- Gradio components ---
# Latest training state. train_step() writes it from the training loop and the
# UI components below poll it, because component values cannot be pushed to
# the browser from outside a Gradio event handler.
_ui_state = {"loss": "", "epoch": "", "progress": 0.0, "image": None}

with gr.Blocks() as interface:
    gr.Markdown(
        "## Fine-tuning FLUX untuk Anime"
    )  # Change the title to match your dataset
    # A callable `value` combined with `every=1` makes Gradio re-evaluate the
    # callable once per second while a client is connected.
    loss_textbox = gr.Textbox(
        label="Loss", value=lambda: str(_ui_state["loss"]), every=1
    )
    epoch_textbox = gr.Textbox(
        label="Epoch", value=lambda: str(_ui_state["epoch"]), every=1
    )
    # Gradio has no ProgressBar component; a read-only 0..1 slider stands in.
    progress_bar_gradio = gr.Slider(
        label="Progress",
        minimum=0,
        maximum=1,
        interactive=False,
        value=lambda: _ui_state["progress"],
        every=1,
    )
    output_image = gr.Image(
        label="Generated Image", value=lambda: _ui_state["image"], every=1
    )


def train_step(step, epoch, loss):
    """Record the latest training status for the Gradio UI.

    Every 100 steps (including step 0) a sample image is generated with the
    current pipeline and published to the UI as well.

    Args:
        step: Index of the current batch within the epoch.
        epoch: Index of the current epoch.
        loss: Loss value of the current step.

    Returns:
        Tuple ``(loss, epoch, progress)`` where ``progress`` is
        ``step / len(progress_bar)``.
    """
    progress = step / len(progress_bar)
    _ui_state["loss"] = loss
    _ui_state["epoch"] = epoch
    # Clamp only what is displayed so the slider never leaves its range.
    _ui_state["progress"] = min(progress, 1.0)
    if step % 100 == 0:
        with torch.no_grad():
            image = pipeline(
                "anime style image of a girl with blue hair"
            ).images[
                0
            ]  # Change the prompt to match your dataset
        _ui_state["image"] = image
    return loss, epoch, progress


# Polling with `every` needs the queue on older Gradio releases; enabling it
# is harmless on newer ones.
interface.queue()
# launch() blocks the main thread by default when run as a script, which
# would stop the training loop below from ever starting.
# NOTE: with prevent_thread_lock=True the server stops when the script exits;
# call interface.block_thread() after training to keep the UI alive.
interface.launch(server_name="0.0.0.0", prevent_thread_lock=True)

# ------------------------

# Training loop: standard noise-prediction objective. Each batch of images is
# encoded to latents, noised at a random timestep, and the UNet is trained to
# predict the added noise (MSE loss).
for epoch in range(num_train_epochs):
    pipeline.unet.train()
    for step, batch in enumerate(
        processed_dataset.iter(batch_size=train_batch_size)
    ):
        with accelerator.accumulate(pipeline.unet):
            # Dataset.iter() yields plain Python lists unless a tensor format
            # was set on the dataset, so build the tensor explicitly and
            # place it on the device the prepared models live on.
            pixel_values = torch.as_tensor(
                batch["pixel_values"],
                dtype=torch.float16,
                device=accelerator.device,
            )
            # The VAE is not optimized (only UNet parameters are given to the
            # optimizer), so encode without tracking gradients.
            with torch.no_grad():
                latents = pipeline.vae.encode(pixel_values).latent_dist.sample()
                latents = latents * pipeline.vae.config.scaling_factor
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # torch.randint returns int64 by default, so no cast is needed.
            timesteps = torch.randint(
                0,
                pipeline.scheduler.config.num_train_timesteps,
                (bsz,),
                device=latents.device,
            )
            noisy_latents = pipeline.scheduler.add_noise(
                latents, noise, timesteps
            )

            # NOTE(review): the raw prompt strings are passed as the
            # conditioning argument; the model presumably expects
            # text-encoder embeddings here — confirm and encode the prompts
            # with the pipeline's tokenizer/text encoder if so.
            model_pred = pipeline.unet(
                noisy_latents, timesteps, batch["prompt"]
            ).sample

            loss = torch.nn.functional.mse_loss(
                model_pred.float(), noise.float(), reduction="mean"
            )
            accelerator.backward(loss)
            # The optimizer was not wrapped by accelerator.prepare(), so it
            # does not skip steps by itself during accumulation. Step only
            # when accelerate signals that an accumulation window is complete.
            if accelerator.sync_gradients:
                optimizer.step()
                optimizer.zero_grad()
        progress_bar.update(1)

        # Update the Gradio components
        train_step(step, epoch, loss.item())

    # Save the model at the end of every epoch
    pipeline.save_pretrained(output_dir)