blanchon committed on
Commit 0e8d10f · 1 Parent(s): 230858b

change to app.py

Files changed (1)
  1. gradio_demo.py +0 -178
gradio_demo.py DELETED
@@ -1,178 +0,0 @@
- import gradio as gr
- import spaces
- import torch
- from hi_diffusers import HiDreamImagePipeline, HiDreamImageTransformer2DModel
- from hi_diffusers.schedulers.flash_flow_match import (
-     FlashFlowMatchEulerDiscreteScheduler,
- )
- from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
- from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
-
- # Constants
- MODEL_PREFIX: str = "HiDream-ai"
- LLAMA_MODEL_NAME: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-
- # Model configurations
- MODEL_CONFIGS: dict[str, dict] = {
-     "dev": {
-         "path": f"{MODEL_PREFIX}/HiDream-I1-Dev",
-         "guidance_scale": 0.0,
-         "num_inference_steps": 28,
-         "shift": 6.0,
-         "scheduler": FlashFlowMatchEulerDiscreteScheduler,
-     },
-     "full": {
-         "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
-         "guidance_scale": 5.0,
-         "num_inference_steps": 50,
-         "shift": 3.0,
-         "scheduler": FlowUniPCMultistepScheduler,
-     },
-     "fast": {
-         "path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
-         "guidance_scale": 0.0,
-         "num_inference_steps": 16,
-         "shift": 3.0,
-         "scheduler": FlashFlowMatchEulerDiscreteScheduler,
-     },
- }
-
- # Supported image sizes
- RESOLUTION_OPTIONS: list[str] = [
-     "1024 × 1024 (Square)",
-     "768 × 1360 (Portrait)",
-     "1360 × 768 (Landscape)",
-     "880 × 1168 (Portrait)",
-     "1168 × 880 (Landscape)",
-     "1248 × 832 (Landscape)",
-     "832 × 1248 (Portrait)",
- ]
-
- # Model cache
- loaded_models: dict[str, HiDreamImagePipeline] = {}
-
-
- def parse_resolution(res_str: str) -> tuple[int, int]:
-     """Parse a resolution label like '1024 × 1024 (Square)' into (1024, 1024)."""
-     # Drop the "(Square)"/"(Portrait)" suffix before converting; otherwise
-     # int() raises ValueError on options such as "1024 × 1024 (Square)".
-     dims = res_str.split("(")[0].replace("×", "x").replace(" ", "")
-     return tuple(map(int, dims.split("x")))
-
-
- def load_models(model_type: str) -> HiDreamImagePipeline:
-     """Load and initialize the HiDream model pipeline for a given model type."""
-     config = MODEL_CONFIGS[model_type]
-     pretrained_model = config["path"]
-
-     tokenizer = PreTrainedTokenizerFast.from_pretrained(
-         LLAMA_MODEL_NAME, use_fast=False
-     )
-     text_encoder = LlamaForCausalLM.from_pretrained(
-         LLAMA_MODEL_NAME,
-         output_hidden_states=True,
-         output_attentions=True,
-         torch_dtype=torch.bfloat16,
-     ).to("cuda")
-
-     transformer = HiDreamImageTransformer2DModel.from_pretrained(
-         pretrained_model,
-         subfolder="transformer",
-         torch_dtype=torch.bfloat16,
-     ).to("cuda")
-
-     scheduler = config["scheduler"](
-         num_train_timesteps=1000,
-         shift=config["shift"],
-         use_dynamic_shifting=False,
-     )
-
-     pipe = HiDreamImagePipeline.from_pretrained(
-         pretrained_model,
-         scheduler=scheduler,
-         tokenizer_4=tokenizer,
-         text_encoder_4=text_encoder,
-         torch_dtype=torch.bfloat16,
-     ).to("cuda", torch.bfloat16)
-
-     pipe.transformer = transformer
-     return pipe
-
-
- # Preload default model
- print("🔧 Preloading default model (full)...")
- loaded_models["full"] = load_models("full")
- print("✅ Model loaded.")
-
-
- @spaces.GPU(duration=90)
- def generate_image(
-     model_type: str,
-     prompt: str,
-     resolution: str,
-     seed: int,
- ) -> tuple[object, int]:
-     """Generate an image using the HiDream pipeline."""
-     if model_type not in loaded_models:
-         print(f"📦 Lazy-loading model {model_type}...")
-         loaded_models[model_type] = load_models(model_type)
-
-     pipe: HiDreamImagePipeline = loaded_models[model_type]
-     config = MODEL_CONFIGS[model_type]
-
-     if seed == -1:
-         seed = torch.randint(0, 1_000_000, (1,)).item()
-
-     # The resolution labels read as width × height, so unpack in that order.
-     width, height = parse_resolution(resolution)
-     generator = torch.Generator("cuda").manual_seed(seed)
-
-     image = pipe(
-         prompt=prompt,
-         height=height,
-         width=width,
-         guidance_scale=config["guidance_scale"],
-         num_inference_steps=config["num_inference_steps"],
-         generator=generator,
-     ).images[0]
-
-     torch.cuda.empty_cache()
-     return image, seed
-
-
- # Gradio UI
- with gr.Blocks(title="HiDream Image Generator") as demo:
-     gr.Markdown("## 🌈 HiDream Image Generator")
-
-     with gr.Row():
-         with gr.Column():
-             model_type = gr.Radio(
-                 choices=list(MODEL_CONFIGS.keys()),
-                 value="full",
-                 label="Model Type",
-                 info="Choose between full, fast or dev variants",
-             )
-
-             prompt = gr.Textbox(
-                 label="Prompt",
-                 placeholder="e.g. A futuristic city with floating cars at sunset",
-                 lines=3,
-             )
-
-             resolution = gr.Radio(
-                 choices=RESOLUTION_OPTIONS,
-                 value=RESOLUTION_OPTIONS[0],
-                 label="Resolution",
-             )
-
-             seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
-             generate_btn = gr.Button("Generate Image", variant="primary")
-             seed_used = gr.Number(label="Seed Used", interactive=False)
-
-         with gr.Column():
-             output_image = gr.Image(label="Generated Image", type="pil")
-
-     generate_btn.click(
-         fn=generate_image,
-         inputs=[model_type, prompt, resolution, seed],
-         outputs=[output_image, seed_used],
-     )
-
- if __name__ == "__main__":
-     demo.launch()
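
As a quick sanity check of parse_resolution in the deleted file, the labelled options from RESOLUTION_OPTIONS round-trip to integer pairs once the "(Square)"-style suffix is stripped (a hypothetical REPL session):

>>> parse_resolution("1024 × 1024 (Square)")
(1024, 1024)
>>> parse_resolution("768 × 1360 (Portrait)")
(768, 1360)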
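
For reference, a minimal sketch of how the deleted module's helpers could be driven head-less, without the Gradio UI. It assumes the hi_diffusers package, a CUDA device, and access to the HiDream-ai and Meta-Llama weights named in the diff; the "fast" variant, the seed, the prompt, and the output filename are illustrative choices, not values from the original file.

# Sketch only: exercise load_models / parse_resolution / the pipeline directly.
# Assumes hi_diffusers and the model weights referenced above are available.
import torch

pipe = load_models("fast")                    # 16-step distilled variant
config = MODEL_CONFIGS["fast"]
width, height = parse_resolution("1024 × 1024 (Square)")

generator = torch.Generator("cuda").manual_seed(42)   # illustrative seed
image = pipe(
    prompt="A futuristic city with floating cars at sunset",
    height=height,
    width=width,
    guidance_scale=config["guidance_scale"],          # 0.0 for fast/dev
    num_inference_steps=config["num_inference_steps"],
    generator=generator,
).images[0]
image.save("hidream_fast.png")                        # hypothetical filename

The same per-variant config dict drives both this sketch and the Gradio handler above, which is why lazy-loading a different variant only requires another load_models(model_type) call.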