import gradio as gr
import os.path as osp
import torch
from PIL import Image
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from semanticist.engine.trainer_utils import instantiate_from_config
from semanticist.stage1.diffuse_slot import DiffuseSlot
from semanticist.stage2.gpt import GPT_models
from semanticist.stage2.generate import generate
from semanticist.utils.datasets import vae_transforms
from imagenet_classes import imagenet_classes

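# Two-stage pipeline: the stage-2 GPT prior generates "slot" tokens for a chosen
# ImageNet class, and the stage-1 DiffuseSlot tokenizer decodes them into an image.
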
transform = vae_transforms('test')  # test-split preprocessing; not used by the class-conditional demo below


def norm_ip(img, low, high):
    """Clamp `img` to [low, high] in place, then rescale it to [0, 1]."""
    img.clamp_(min=low, max=high)
    img.sub_(low).div_(max(high - low, 1e-5))

def norm_range(t, value_range):
    """Normalize `t` in place using `value_range` if given, else its own min/max."""
    if value_range is not None:
        norm_ip(t, value_range[0], value_range[1])
    else:
        norm_ip(t, float(t.min()), float(t.max()))

def convert_np(img):
    """Convert a float CHW image tensor (values in [0, 1]) to an HWC uint8 numpy array."""
    ndarr = img.mul(255).add_(0.5).clamp_(0, 255)\
            .permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    return ndarr

def convert_PIL(img):
    """Convert a float CHW image tensor (values in [0, 1]) to a PIL image."""
    return Image.fromarray(convert_np(img))

def norm_slots(slots):
    """Standardize each slot vector to zero mean and unit std over its feature dimension."""
    mean = torch.mean(slots, dim=-1, keepdim=True)
    std = torch.std(slots, dim=-1, keepdim=True)
    return (slots - mean) / std

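# Checkpoint-loading helpers: these strip the '_orig_mod.' prefix added by
# torch.compile and accept either torch or safetensors checkpoints.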
def load_state_dict(state_dict, model):
    """Helper to load a state dict with proper prefix handling."""
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    # Remove '_orig_mod' prefix if present
    state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
    missing, unexpected = model.load_state_dict(
        state_dict, strict=False
    )
    # print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")

def load_safetensors(path, model):
    """Helper to load a safetensors checkpoint."""
    from safetensors.torch import safe_open
    with safe_open(path, framework="pt", device="cpu") as f:
        state_dict = {k: f.get_tensor(k) for k in f.keys()}
    load_state_dict(state_dict, model)

def load_checkpoint(ckpt_path, model):
    """Load weights from a checkpoint directory, a .safetensors file, or a torch checkpoint."""
    if ckpt_path is None or not osp.exists(ckpt_path):
        return

    if osp.isdir(ckpt_path):
        # ckpt_path is something like 'path/to/models/step10/'
        model_path = osp.join(ckpt_path, "model.safetensors")
        if osp.exists(model_path):
            load_safetensors(model_path, model)
    else:
        # ckpt_path is something like 'path/to/models/step10.pt'
        if ckpt_path.endswith(".safetensors"):
            load_safetensors(ckpt_path, model)
        else:
            state_dict = torch.load(ckpt_path, map_location="cpu")
            load_state_dict(state_dict, model)

    print(f"Loaded checkpoint from {ckpt_path}")

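# Select the compute device, then download both checkpoints from the Hugging Face
# Hub: the stage-2 AR prior (GPT) and the stage-1 tokenizer, each set to eval mode.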
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Is CUDA available: {torch.cuda.is_available()}")
if device == 'cuda':
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Note: cache_dir points at a machine-specific path; drop or change it to use the
# default Hugging Face cache on other machines.
ckpt_path = hf_hub_download(repo_id='tennant/semanticist', filename="semanticist_ar_gen_L.pkl", cache_dir='/mnt/ceph_rbd/mnt_pvc_vid_data/zbc/cache/')
config_path = 'configs/autoregressive_xl.yaml'

cfg = OmegaConf.load(config_path)
params = cfg.trainer.params

ae_model = instantiate_from_config(params.ae_model).to(device)
ae_model_path = hf_hub_download(repo_id='tennant/semanticist', filename="semanticist_tok_XL.pkl", cache_dir='/mnt/ceph_rbd/mnt_pvc_vid_data/zbc/cache/')
load_checkpoint(ae_model_path, ae_model)
ae_model.eval()

gpt_model = GPT_models[params.gpt_model.target](**params.gpt_model.params).to(device)
load_checkpoint(ckpt_path, gpt_model)
gpt_model.eval()

def viz_diff_slots(model, slots, nums, cfg=1.0, return_figs=False):
    """Decode `slots` once per entry in `nums`, using the tokenizer's nested drop
    mask to reconstruct from only that many tokens; returns uint8 numpy images."""
    n_slots_inf = []
    for num_slots_to_inference in nums:
        drop_mask = model.nested_sampler(slots.shape[0], device, num_slots_to_inference)
        recon_n = model.sample(slots, drop_mask=drop_mask, cfg=cfg)
        n_slots_inf.append(recon_n)
    return [convert_np(recon[0]) for recon in n_slots_inf]

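# Example: given slot tokens `slots`, viz_diff_slots(ae_model, slots, [1, 4, 16], cfg=4.0)
# returns one HWC uint8 image per requested token count.
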
num_slots = params.ae_model.params.num_slots
slot_dim = params.ae_model.params.slot_dim
dtype = torch.bfloat16
# The AR prior is trained to generate only 32 tokens; the tokenizer's remaining
# slots are padded with its null condition in generate_from_class below.
num_slots_to_gen = 32

# Generate slot tokens for an ImageNet class with the AR prior.
def generate_from_class(class_id, cfg_scale):
    with torch.no_grad():
        dtype = torch.float
        num_slots_to_gen = 32
        with torch.autocast(device, dtype=dtype):
            slots_gen = generate(
                gpt_model,
                torch.tensor([class_id]).to(device),
                num_slots_to_gen,
                cfg_scale=cfg_scale,
                cfg_schedule="linear"
            )
        # The AR prior only emits 32 tokens; pad the remaining tokenizer slots
        # with the tokenizer's null condition so the decoder sees a full set.
        if num_slots_to_gen < num_slots:
            null_slots = ae_model.dit.null_cond.expand(slots_gen.shape[0], -1, -1)
            null_slots = null_slots[:, num_slots_to_gen:, :]
            slots_gen = torch.cat([slots_gen, null_slots], dim=1)
    return slots_gen
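
# Illustrative headless use (not executed by the demo):
#   slots = generate_from_class(1, cfg_scale=4.0)               # class 1 = goldfish
#   image = viz_diff_slots(ae_model, slots, [32], cfg=4.0)[0]   # HWC uint8 numpy array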

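# Gradio UI: the left column selects the class and sampling options; the right
# column shows the generated image and, optionally, a gallery of reconstructions
# from increasing numbers of tokens.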
with gr.Blocks() as demo:
    with gr.Row():
        # First column - Input and configs
        with gr.Column(scale=1):
            gr.Markdown("## Input")
            
            # ImageNet class selection
            class_choices = [f"{i}: {name}" for i, name in enumerate(imagenet_classes)]
            
            # Dropdown for class selection
            class_dropdown = gr.Dropdown(
                choices=class_choices,
                label="Select ImageNet Class",
                value=class_choices[0] if class_choices else None
            )
            
            # Option to enter class ID directly
            class_id_input = gr.Number(
                label="Or enter class ID directly (0-999)",
                value=0,
                minimum=0,
                maximum=999,
                step=1
            )
            
            with gr.Group():
                gr.Markdown("### Configuration")
                show_gallery = gr.Checkbox(label="Show Gallery", value=True)
                slider = gr.Slider(minimum=0.1, maximum=20.0, value=4.0, label="CFG value")
                labels_input = gr.Textbox(
                    label="Number of tokens to reconstruct (comma-separated)", 
                    value="1, 2, 4, 8, 16", 
                    placeholder="Enter comma-separated numbers for the number of slots to use"
                )
        
        # Second column - Output (conditionally rendered)
        with gr.Column(scale=1):
            gr.Markdown("## Output")
            
            # Container for conditional rendering
            with gr.Group(visible=True) as gallery_container:
                gallery = gr.Gallery(label="Result Gallery", columns=3, height="auto", show_label=True)
            
            # Always visible output image
            output_image = gr.Image(label="Generated Image", type="numpy")
    
    # Handle form submission
    submit_btn = gr.Button("Generate")
    
    # Define the processing logic
    def update_outputs(class_selection, class_id, show_gallery_value, slider_value, labels_text):
        # Determine which class to use - either from the dropdown or the direct input
        if class_selection:
            # Extract the class ID from the dropdown selection ("id: name")
            selected_class_id = int(class_selection.split(":")[0])
        else:
            selected_class_id = int(class_id)

        try:
            # Parse the comma-separated token counts from the text input
            labels = [int(x.strip()) for x in (labels_text or "").split(",") if x.strip()]
        except ValueError:
            labels = []
        if not labels:
            # Default token counts if none were provided or parsing failed
            labels = [1, 4, 16, 64, 256]

        while len(labels) < 3:
            labels.append(256)

        # Generate slot tokens for the selected class, then decode the full image
        slots_gen = generate_from_class(selected_class_id, cfg_scale=slider_value)
        recon = viz_diff_slots(ae_model, slots_gen, [32], cfg=slider_value)[0]

        if not show_gallery_value:
            # Hide the gallery and return only the generated image
            return gr.update(visible=False), [], recon
        else:
            # Decode partial reconstructions for the gallery and pair them with their labels
            model_decompose = viz_diff_slots(ae_model, slots_gen, labels, cfg=slider_value)
            gallery_images = [
                (recon, f'Generated from class {selected_class_id}'),
            ] + [(img, f'Gen. with {label} tokens') for img, label in zip(model_decompose, labels)]
            return gr.update(visible=True), gallery_images, recon
    
    # Connect the inputs and outputs
    submit_btn.click(
        fn=update_outputs,
        inputs=[class_dropdown, class_id_input, show_gallery, slider, labels_input],
        outputs=[gallery_container, gallery, output_image]
    )
    
    # Also update when checkbox changes
    show_gallery.change(
        fn=lambda value: gr.update(visible=value),
        inputs=[show_gallery],
        outputs=[gallery_container]
    )

    # Add examples
    examples = [
        # ["0: tench, Tinca tinca", 0, True, 4.0, "1,2,4,8,16"],
        ["1: goldfish", 1, True, 4.0, "1,2,4,8,16"],
        # ["2: great white shark, white shark", 2, True, 4.0, "1,2,4,8,16"],
    ]
    
    gr.Examples(
        examples=examples,
        inputs=[class_dropdown, class_id_input, show_gallery, slider, labels_input],
        outputs=[gallery_container, gallery, output_image],
        fn=update_outputs,
        cache_examples=False
    )

# Launch the demo
if __name__ == "__main__":
    demo.launch()