alfredplpl committed on
Commit
518d5a1
1 Parent(s): 452bb51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -0
app.py CHANGED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Reference: https://huggingface.co/spaces/black-forest-labs/FLUX.1-schnell/blob/main/app.py
2
+ import spaces
3
+ import gradio as gr
4
+ import numpy as np
5
+ import random
6
+ import torch
7
+ import torch
8
+ from diffusers import Transformer2DModel, PixArtSigmaPipeline, AutoencoderKL, DPMSolverMultistepScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverSDEScheduler
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, QuantoConfig, EetqConfig
10
+
11
# Runtime configuration shared by the model setup below and by infer().
device = "cuda"
weight_dtype = torch.float32  # dtype requested when loading the VAE
weight_dtype_te = torch.bfloat16  # dtype requested when loading the text encoder
MAX_SEED = np.iinfo(np.int32).max  # largest seed offered by the UI / drawn at random

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=weight_dtype)
scheduler = DPMSolverMultistepScheduler()

# The pipeline is assembled without a tokenizer or text encoder of its own:
# infer() computes prompt embeddings with the causal LM loaded further down
# and passes them in directly.
#
# NOTE(review): `transformer` is never assigned anywhere in the visible
# source (Transformer2DModel is imported but not instantiated), so this call
# raises NameError at import time. The checkpoint to load is not shown here;
# it must be defined before this point -- TODO confirm the intended model.
pipe = PixArtSigmaPipeline(
    vae=vae,
    tokenizer=None,
    text_encoder=None,
    transformer=transformer,
    scheduler=scheduler,
)

pipe.to(device)

# Tokenizer and causal LM used by infer() to turn prompts into embeddings.
tokenizer = AutoTokenizer.from_pretrained("cyberagent/calm2-7b")
text_encoder = AutoModelForCausalLM.from_pretrained(
    "cyberagent/calm2-7b",
    torch_dtype=weight_dtype_te,
    device_map=device,
)
34
+
35
def _encode_text(text, max_length=512):
    """Return ``(last_hidden_state, attention_mask)`` for *text*.

    The text is tokenized to exactly ``max_length`` tokens (padded or
    truncated), moved to ``device`` and run through ``text_encoder`` with
    hidden states enabled; the final hidden state is used as the embedding.
    """
    ids = tokenizer(
        text, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt",
    ).to(device)
    outputs = text_encoder(ids.input_ids, output_hidden_states=True, attention_mask=ids.attention_mask)
    return outputs.hidden_states[-1], ids.attention_mask


@spaces.GPU()
def infer(prompt, seed=42, randomize_seed=False, width=512, height=512, num_inference_steps=20, progress=gr.Progress(track_tqdm=True)):
    """Generate one image for *prompt* and return it with the seed used.

    Args:
        prompt: Text describing the image to generate.
        seed: Seed for the random generator; ignored when ``randomize_seed``
            is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width forwarded to the pipeline.
        height: Output height forwarded to the pipeline.
        num_inference_steps: Number of denoising steps forwarded to the
            pipeline.
        progress: Not read in the body; presumably present so Gradio mirrors
            tqdm progress in the UI -- confirm against the Gradio docs.

    Returns:
        A ``(image, seed)`` tuple: the first image produced by the pipeline
        and the seed that was actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    # The pipeline has no text encoder of its own, so both the prompt and the
    # empty negative prompt are embedded here and handed over as tensors.
    with torch.no_grad():
        pos_emb, pos_mask = _encode_text(prompt)
        neg_emb, neg_mask = _encode_text("")

    image = pipe(
        negative_prompt=None,
        prompt_embeds=pos_emb,
        negative_prompt_embeds=neg_emb,
        prompt_attention_mask=pos_mask,
        negative_prompt_attention_mask=neg_mask,
        max_sequence_length=512,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=generator,
        guidance_scale=4.5,
    ).images[0]
    return image, seed
65
+
66
# Example prompts offered in the UI (Japanese; shown to the user verbatim).
examples = [
    "芝生の上にあるピザ",
    "東京の桜と建物。満開の桜の木が並び、ピンク色の花びらが風に舞っている。桜の背景には東京の高層ビルや伝統的な建物が調和して立っている。春の陽光が全体を明るく照らし、桜と建物が美しく映えている。都市の活気と自然の美しさが融合した風景。",
    "パリは燃えているか",
]

# Stylesheet for the page: centers the main column and caps its width.
css = """
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""
78
+
79
# Gradio UI: a prompt box with a run button, the result image, an
# "advanced settings" accordion and a set of example prompts, all inside one
# centered column styled by `css`.
with gr.Blocks(css=css) as demo:

    with gr.Column(elem_id="col-container"):
        # Page title and tagline (plain string; no interpolation is needed).
        gr.Markdown("# CommonArt β\n商用利用できる透明性の高い日本語画像生成AI\n")

        with gr.Row():

            prompt = gr.Text(
                label="テキスト",
                show_label=False,
                max_lines=1,
                placeholder="生成したいものを日本語や英語で説明してください",
                container=False,
            )

            run_button = gr.Button("生成", scale=0)

        result = gr.Image(label="生成結果", show_label=False)

        with gr.Accordion("詳細設定", open=False):

            seed = gr.Slider(
                label="シード値",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="ランダム", value=True)

            with gr.Row():

                width = gr.Slider(
                    label="幅",
                    minimum=256,
                    maximum=768,
                    step=64,
                    value=512,
                )

                height = gr.Slider(
                    label="高さ",
                    minimum=256,
                    maximum=768,
                    step=64,
                    value=512,
                )

            with gr.Row():

                num_inference_steps = gr.Slider(
                    label="推論回数",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=20,
                )

        # Examples pass only the prompt, so infer() runs with its own
        # defaults for every other argument.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples="lazy",
        )

    # Clicking the button and submitting the text box both run infer(); the
    # seed actually used is written back to the seed slider.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, seed],
    )

demo.launch()