blanchon committed
Commit 8feabfd · 1 Parent(s): 0e8d10f

change to app.py

Files changed (1)
app.py +178 -0
app.py ADDED
@@ -0,0 +1,178 @@
+ import gradio as gr
+ import spaces
+ import torch
+ from hi_diffusers import HiDreamImagePipeline, HiDreamImageTransformer2DModel
+ from hi_diffusers.schedulers.flash_flow_match import (
+     FlashFlowMatchEulerDiscreteScheduler,
+ )
+ from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
+ from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
+
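+ # hi_diffusers is HiDream's diffusers-style package; it provides the pipeline,
+ # the transformer, and the flow-matching schedulers used below.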
+ # Constants
+ MODEL_PREFIX: str = "HiDream-ai"
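+ # Note: Meta-Llama-3.1-8B-Instruct is gated on the Hugging Face Hub; the Space
+ # needs granted access and a valid HF token to download it.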
+ LLAMA_MODEL_NAME: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+
+ # Model configurations
+ MODEL_CONFIGS: dict[str, dict] = {
+     "dev": {
+         "path": f"{MODEL_PREFIX}/HiDream-I1-Dev",
+         "guidance_scale": 0.0,
+         "num_inference_steps": 28,
+         "shift": 6.0,
+         "scheduler": FlashFlowMatchEulerDiscreteScheduler,
+     },
+     "full": {
+         "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
+         "guidance_scale": 5.0,
+         "num_inference_steps": 50,
+         "shift": 3.0,
+         "scheduler": FlowUniPCMultistepScheduler,
+     },
+     "fast": {
+         "path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
+         "guidance_scale": 0.0,
+         "num_inference_steps": 16,
+         "shift": 3.0,
+         "scheduler": FlashFlowMatchEulerDiscreteScheduler,
+     },
+ }
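+ # dev and fast run without classifier-free guidance (guidance_scale=0.0) and
+ # with fewer steps; full uses CFG at 5.0 over 50 steps.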
+
+ # Supported image sizes
+ RESOLUTION_OPTIONS: list[str] = [
+     "1024 × 1024 (Square)",
+     "768 × 1360 (Portrait)",
+     "1360 × 768 (Landscape)",
+     "880 × 1168 (Portrait)",
+     "1168 × 880 (Landscape)",
+     "1248 × 832 (Landscape)",
+     "832 × 1248 (Portrait)",
+ ]
+
+ # Model cache
+ loaded_models: dict[str, HiDreamImagePipeline] = {}
+
+
+ def parse_resolution(res_str: str) -> tuple[int, int]:
+     """Parse a label like '768 × 1360 (Portrait)' into (768, 1360)."""
+     # Drop the trailing "(...)" tag first, otherwise int() fails on it.
+     dims = res_str.split("(")[0].replace("×", "x").replace(" ", "").split("x")
+     return int(dims[0]), int(dims[1])
+
+
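+ # HiDream's pipeline takes a Llama model as its fourth text encoder
+ # (tokenizer_4 / text_encoder_4); the LLM's hidden states serve as prompt
+ # features, hence output_hidden_states=True below.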
+ def load_models(model_type: str) -> HiDreamImagePipeline:
+     """Load and initialize the HiDream model pipeline for a given model type."""
+     config = MODEL_CONFIGS[model_type]
+     pretrained_model = config["path"]
+
+     tokenizer = PreTrainedTokenizerFast.from_pretrained(
+         LLAMA_MODEL_NAME, use_fast=False
+     )
+     text_encoder = LlamaForCausalLM.from_pretrained(
+         LLAMA_MODEL_NAME,
+         output_hidden_states=True,
+         output_attentions=True,
+         torch_dtype=torch.bfloat16,
+     ).to("cuda")
+
+     transformer = HiDreamImageTransformer2DModel.from_pretrained(
+         pretrained_model,
+         subfolder="transformer",
+         torch_dtype=torch.bfloat16,
+     ).to("cuda")
+
+     scheduler = config["scheduler"](
+         num_train_timesteps=1000,
+         shift=config["shift"],
+         use_dynamic_shifting=False,
+     )
+
+     pipe = HiDreamImagePipeline.from_pretrained(
+         pretrained_model,
+         scheduler=scheduler,
+         tokenizer_4=tokenizer,
+         text_encoder_4=text_encoder,
+         torch_dtype=torch.bfloat16,
+     ).to("cuda", torch.bfloat16)
+
+     pipe.transformer = transformer
+     return pipe
+
+
+ # Preload default model
+ print("🔧 Preloading default model (full)...")
+ loaded_models["full"] = load_models("full")
+ print("✅ Model loaded.")
+
+
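+ # On ZeroGPU Spaces, @spaces.GPU attaches a GPU for the duration of each call
+ # (capped here at 90 seconds).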
+ @spaces.GPU(duration=90)
+ def generate_image(
+     model_type: str,
+     prompt: str,
+     resolution: str,
+     seed: int,
+ ) -> tuple[object, int]:
+     """Generate image using HiDream pipeline."""
+     if model_type not in loaded_models:
+         print(f"📦 Lazy-loading model {model_type}...")
+         loaded_models[model_type] = load_models(model_type)
+
+     pipe: HiDreamImagePipeline = loaded_models[model_type]
+     config = MODEL_CONFIGS[model_type]
+
+     if seed == -1:
+         seed = torch.randint(0, 1_000_000, (1,)).item()
+
+     # Resolution labels read width × height (matching their Portrait/Landscape
+     # tags), so unpack in that order.
+     width, height = parse_resolution(resolution)
+     generator = torch.Generator("cuda").manual_seed(seed)
+
+     image = pipe(
+         prompt=prompt,
+         height=height,
+         width=width,
+         guidance_scale=config["guidance_scale"],
+         num_inference_steps=config["num_inference_steps"],
+         generator=generator,
+     ).images[0]
+
+     torch.cuda.empty_cache()
+     return image, seed
+
+
+ # Gradio UI
+ with gr.Blocks(title="HiDream Image Generator") as demo:
+     gr.Markdown("## 🌈 HiDream Image Generator")
+
+     with gr.Row():
+         with gr.Column():
+             model_type = gr.Radio(
+                 choices=list(MODEL_CONFIGS.keys()),
+                 value="full",
+                 label="Model Type",
+                 info="Choose between full, fast or dev variants",
+             )
+
+             prompt = gr.Textbox(
+                 label="Prompt",
+                 placeholder="e.g. A futuristic city with floating cars at sunset",
+                 lines=3,
+             )
+
+             resolution = gr.Radio(
+                 choices=RESOLUTION_OPTIONS,
+                 value=RESOLUTION_OPTIONS[0],
+                 label="Resolution",
+             )
+
+             seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
+             generate_btn = gr.Button("Generate Image", variant="primary")
+             seed_used = gr.Number(label="Seed Used", interactive=False)
+
+         with gr.Column():
+             output_image = gr.Image(label="Generated Image", type="pil")
+
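+     # The inputs list maps positionally onto generate_image's
+     # (model_type, prompt, resolution, seed) parameters.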
+     generate_btn.click(
+         fn=generate_image,
+         inputs=[model_type, prompt, resolution, seed],
+         outputs=[output_image, seed_used],
+     )
+
+ if __name__ == "__main__":
+     demo.launch()