blanchon committed
Commit 118bda1
1 Parent(s): be01f5b
Files changed (1)
  1. app.py +140 -140
app.py CHANGED
@@ -1,142 +1,142 @@
-from typing import Any
-
-import gradio as gr
-import PIL
-import spaces
-import torch
-from hi_diffusers import HiDreamImagePipeline, HiDreamImageTransformer2DModel
-from hi_diffusers.schedulers.flash_flow_match import (
-    FlashFlowMatchEulerDiscreteScheduler,
-)
-from transformers import LlamaForCausalLM, PreTrainedTokenizerFast, AutoTokenizer
-
-# Constants
-MODEL_PREFIX: str = "HiDream-ai"
-LLAMA_MODEL_NAME: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-MODEL_PATH = "HiDream-ai/HiDream-I1-Dev"
-MODEL_CONFIGS: dict[str, Any] = {
-    "guidance_scale": 0.0,
-    "num_inference_steps": 28,
-    "shift": 6.0,
-    "scheduler": FlashFlowMatchEulerDiscreteScheduler,
-}
-
-# Model configurations
-# MODEL_CONFIGS: dict[str, dict] = {
-#     "full": {
-#         "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
-#         "guidance_scale": 5.0,
-#         "num_inference_steps": 50,
-#         "shift": 3.0,
-#         "scheduler": FlowUniPCMultistepScheduler,
-#     },
-#     "fast": {
-#         "path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
-#         "guidance_scale": 0.0,
-#         "num_inference_steps": 16,
-#         "shift": 3.0,
-#         "scheduler": FlashFlowMatchEulerDiscreteScheduler,
-#     },
-# }
-
-# Supported image sizes
-RESOLUTION_OPTIONS: list[str] = [
-    "1024 x 1024 (Square)",
-    "768 x 1360 (Portrait)",
-    "1360 x 768 (Landscape)",
-    "880 x 1168 (Portrait)",
-    "1168 x 880 (Landscape)",
-    "1248 x 832 (Landscape)",
-    "832 x 1248 (Portrait)",
-]
-
-
-tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME, use_fast=False)
-text_encoder = LlamaForCausalLM.from_pretrained(
-    LLAMA_MODEL_NAME,
-    output_hidden_states=True,
-    output_attentions=True,
-    torch_dtype=torch.bfloat16,
-).to("cuda")
-
-transformer = HiDreamImageTransformer2DModel.from_pretrained(
-    MODEL_PATH,
-    subfolder="transformer",
-    torch_dtype=torch.bfloat16,
-).to("cuda")
-
-scheduler = MODEL_CONFIGS["scheduler"](
-    num_train_timesteps=1000,
-    shift=MODEL_CONFIGS["shift"],
-    use_dynamic_shifting=False,
 )
 
-pipe = HiDreamImagePipeline.from_pretrained(
-    MODEL_PATH,
-    scheduler=scheduler,
-    tokenizer_4=tokenizer,
-    text_encoder_4=text_encoder,
-    torch_dtype=torch.bfloat16,
-).to("cuda", torch.bfloat16)
-
-pipe.transformer = transformer
-
-
-@spaces.GPU(duration=90)
-def generate_image(
-    prompt: str,
-    resolution: str,
-    seed: int,
-) -> tuple[PIL.Image.Image, int]:
-    if seed == -1:
-        seed = torch.randint(0, 1_000_000, (1,)).item()
-
-    height, width = tuple(map(int, resolution.replace(" ", "").split("x")))
-    generator = torch.Generator("cuda").manual_seed(seed)
-
-    image = pipe(
-        prompt=prompt,
-        height=height,
-        width=width,
-        guidance_scale=MODEL_CONFIGS["guidance_scale"],
-        num_inference_steps=MODEL_CONFIGS["num_inference_steps"],
-        generator=generator,
-    ).images[0]
-
-    torch.cuda.empty_cache()
-    return image, seed
-
-
-# Gradio UI
-with gr.Blocks(title="HiDream Image Generator") as demo:
-    gr.Markdown("## 🌈 HiDream Image Generator")
-
-    with gr.Row():
-        with gr.Column():
-            prompt = gr.Textbox(
-                label="Prompt",
-                placeholder="e.g. A futuristic city with floating cars at sunset",
-                lines=3,
-            )
-
-            resolution = gr.Radio(
-                choices=RESOLUTION_OPTIONS,
-                value=RESOLUTION_OPTIONS[0],
-                label="Resolution",
-            )
-
-            seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
-            generate_btn = gr.Button("Generate Image", variant="primary")
-            seed_used = gr.Number(label="Seed Used", interactive=False)
-
-        with gr.Column():
-            output_image = gr.Image(label="Generated Image", type="pil")
-
-    generate_btn.click(
-        fn=generate_image,
-        inputs=[prompt, resolution, seed],
-        outputs=[output_image, seed_used],
-    )
-
-if __name__ == "__main__":
-    demo.launch()
+from typing import Any
+
+import gradio as gr
+import PIL
+import spaces
+import torch
+from hi_diffusers import HiDreamImagePipeline, HiDreamImageTransformer2DModel
+from hi_diffusers.schedulers.flash_flow_match import (
+    FlashFlowMatchEulerDiscreteScheduler,
+)
+from transformers import AutoTokenizer, LlamaForCausalLM
+
+# Constants
+MODEL_PREFIX: str = "HiDream-ai"
+LLAMA_MODEL_NAME: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+MODEL_PATH = "HiDream-ai/HiDream-I1-Dev"
+MODEL_CONFIGS: dict[str, Any] = {
+    "guidance_scale": 0.0,
+    "num_inference_steps": 28,
+    "shift": 6.0,
+    "scheduler": FlashFlowMatchEulerDiscreteScheduler,
+}
+
+# Model configurations
+# MODEL_CONFIGS: dict[str, dict] = {
+#     "full": {
+#         "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
+#         "guidance_scale": 5.0,
+#         "num_inference_steps": 50,
+#         "shift": 3.0,
+#         "scheduler": FlowUniPCMultistepScheduler,
+#     },
+#     "fast": {
+#         "path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
+#         "guidance_scale": 0.0,
+#         "num_inference_steps": 16,
+#         "shift": 3.0,
+#         "scheduler": FlashFlowMatchEulerDiscreteScheduler,
+#     },
+# }
+
+# Supported image sizes
+RESOLUTION_OPTIONS: list[str] = [
+    "1024 x 1024 (Square)",
+    "768 x 1360 (Portrait)",
+    "1360 x 768 (Landscape)",
+    "880 x 1168 (Portrait)",
+    "1168 x 880 (Landscape)",
+    "1248 x 832 (Landscape)",
+    "832 x 1248 (Portrait)",
+]
+
+
+tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME, use_fast=False)
+text_encoder = LlamaForCausalLM.from_pretrained(
+    LLAMA_MODEL_NAME,
+    output_hidden_states=True,
+    output_attentions=True,
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+transformer = HiDreamImageTransformer2DModel.from_pretrained(
+    MODEL_PATH,
+    subfolder="transformer",
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+scheduler = MODEL_CONFIGS["scheduler"](
+    num_train_timesteps=1000,
+    shift=MODEL_CONFIGS["shift"],
+    use_dynamic_shifting=False,
+)
+
+pipe = HiDreamImagePipeline.from_pretrained(
+    MODEL_PATH,
+    scheduler=scheduler,
+    tokenizer_4=tokenizer,
+    text_encoder_4=text_encoder,
+    torch_dtype=torch.bfloat16,
+).to("cuda", torch.bfloat16)
+
+pipe.transformer = transformer
+
+
+@spaces.GPU(duration=90)
+def generate_image(
+    prompt: str,
+    resolution: str,
+    seed: int,
+) -> tuple[PIL.Image.Image, int]:
+    if seed == -1:
+        seed = torch.randint(0, 1_000_000, (1,)).item()
+
+    height, width = tuple(map(int, resolution.replace(" ", "").split("x")))
+    generator = torch.Generator("cuda").manual_seed(seed)
+
+    image = pipe(
+        prompt=prompt,
+        height=height,
+        width=width,
+        guidance_scale=MODEL_CONFIGS["guidance_scale"],
+        num_inference_steps=MODEL_CONFIGS["num_inference_steps"],
+        generator=generator,
+    ).images[0]
+
+    torch.cuda.empty_cache()
+    return image, seed
+
+
+# Gradio UI
+with gr.Blocks(title="HiDream Image Generator") as demo:
+    gr.Markdown("## 🌈 HiDream Image Generator")
+
+    with gr.Row():
+        with gr.Column():
+            prompt = gr.Textbox(
+                label="Prompt",
+                placeholder="e.g. A futuristic city with floating cars at sunset",
+                lines=3,
+            )
+
+            resolution = gr.Radio(
+                choices=RESOLUTION_OPTIONS,
+                value=RESOLUTION_OPTIONS[0],
+                label="Resolution",
+            )
+
+            seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
+            generate_btn = gr.Button("Generate Image", variant="primary")
+            seed_used = gr.Number(label="Seed Used", interactive=False)
+
+        with gr.Column():
+            output_image = gr.Image(label="Generated Image", type="pil")
+
+    generate_btn.click(
+        fn=generate_image,
+        inputs=[prompt, resolution, seed],
+        outputs=[output_image, seed_used],
     )
 
+if __name__ == "__main__":
+    demo.launch()