seawolf2357 committed
Commit 46cdca3 · verified · 1 Parent(s): f3cb7e2

Update app.py

Files changed (1)
  app.py +57 -113
app.py CHANGED
@@ -33,7 +33,6 @@ if hf_token:
 else:
     print("Warning: HF_TOKEN not found in environment variables. You may encounter authentication issues.")
 
-
 def download_model():
     REPO_ID = 'Doubiiu/DynamiCrafter_1024'
     filename_list = ['model.ckpt']
@@ -45,11 +44,11 @@ def download_model():
         hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/dynamicrafter_1024_v1/', force_download=True)
 
 download_model()
-ckpt_path='checkpoints/dynamicrafter_1024_v1/model.ckpt'
-config_file='configs/inference_1024_v1.0.yaml'
+ckpt_path = 'checkpoints/dynamicrafter_1024_v1/model.ckpt'
+config_file = 'configs/inference_1024_v1.0.yaml'
 config = OmegaConf.load(config_file)
 model_config = config.pop("model", OmegaConf.create())
-model_config['params']['unet_config']['params']['use_checkpoint']=False
+model_config['params']['unet_config']['params']['use_checkpoint'] = False
 model = instantiate_from_config(model_config)
 assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
 model = load_model_checkpoint(model, ckpt_path)
@@ -67,11 +66,18 @@ flux_pipe = FluxPipeline.from_pretrained(
 )
 flux_pipe.enable_model_cpu_offload()
 
+def translate_prompt(prompt):
+    # Detect Korean input and translate it
+    if any('\u3131' <= char <= '\u318E' or '\uAC00' <= char <= '\uD7A3' for char in prompt):
+        translated = translator(prompt, max_length=512)[0]['translation_text']
+        return translated
+    return prompt
 
 def generate_image_from_text(prompt, seed=0):
+    translated_prompt = translate_prompt(prompt)
     generator = torch.Generator("cpu").manual_seed(seed)
     image = flux_pipe(
-        prompt,
+        translated_prompt,
         height=576,
         width=1024,
         guidance_scale=3.5,
@@ -83,158 +89,96 @@ def generate_image_from_text(prompt, seed=0):
 
 @spaces.GPU(duration=600)
 def infer(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, video_length=2):
-    # Detect Korean input and translate it
-    if any('\u3131' <= char <= '\u318E' or '\uAC00' <= char <= '\uD7A3' for char in prompt):
-        translated = translator(prompt, max_length=512)[0]['translation_text']
-        prompt = translated
-        print(f"Translated prompt: {prompt}")
-
+    translated_prompt = translate_prompt(prompt)
+    print(f"Translated prompt: {translated_prompt}")
     resolution = (576, 1024)
     save_fps = 8
     seed_everything(seed)
     transform = transforms.Compose([
-        transforms.Resize(min(resolution)),
+        transforms.Resize(min(resolution), antialias=True),
         transforms.CenterCrop(resolution),
-        ])
+    ])
     torch.cuda.empty_cache()
-    print('Start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
+    print('Start:', translated_prompt, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
     start = time.time()
     if steps > 60:
         steps = 60
-
     batch_size = 1
     channels = model.model.diffusion_model.out_channels
-    frames = int(video_length * save_fps)  # frame count derived from the video length
+    frames = int(video_length * save_fps)
     h, w = resolution[0] // 8, resolution[1] // 8
     noise_shape = [batch_size, channels, frames, h, w]
-
-    # set up the text conditioning
     with torch.no_grad(), torch.cuda.amp.autocast():
-        text_emb = model.get_learned_conditioning([prompt])
+        text_emb = model.get_learned_conditioning([translated_prompt])
         img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
         img_tensor = (img_tensor / 255. - 0.5) * 2
         image_tensor_resized = transform(img_tensor).unsqueeze(0)  # bchw
-
         z = get_latent_z(model, image_tensor_resized.unsqueeze(2))  # bc,1,hw
         img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
         cond_images = model.embedder(img_tensor.unsqueeze(0))  # blc
         img_emb = model.image_proj_model(cond_images)
         imtext_cond = torch.cat([text_emb, img_emb], dim=1)
-
         fs = torch.tensor([fs], dtype=torch.long, device=model.device)
         cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}
-
-        # run inference
         batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
-
     video_path = './output.mp4'
     save_videos(batch_samples, './', filenames=['output'], fps=save_fps)
     return video_path
 
-
-@spaces.GPU(duration=300)
-def infer_t2v(prompt, video_prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, video_length=2):
-    # generate an image from the text prompt
-    image = generate_image_from_text(prompt, seed)
-
-    # convert the image to a numpy array
-    image_np = np.array(image)
-
-    # call the existing infer function to generate the video
-    return infer(image_np, video_prompt, steps, cfg_scale, eta, fs, seed, video_length)
-
-i2v_examples = [
-    ['prompts/1024/astronaut04.png', 'a man in an astronaut suit playing a guitar', 30, 7.5, 1.0, 6, 123, 2],
-]
-
 css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height: 576px}"""
 
-def generate_only_image(prompt, seed=123):
-    # generate the image
-    image = generate_image_from_text(prompt, seed)
-
-    # convert to a PIL image and return
-    return Image.fromarray(np.array(image))
-
 with gr.Blocks(analytics_enabled=False, css=css) as dynamicrafter_iface:
     gr.Markdown("kAI Movie Studio")
-
-
-    with gr.Tab(label='Image(+Text) Generation'):
+    with gr.Tab(label='Image Generation'):
         with gr.Column():
             with gr.Row():
-                with gr.Column():
-                    img_input_text = gr.Text(label='Image Generation Prompt')
-                    img_seed = gr.Slider(label='Random Seed', minimum=0, maximum=10000, step=1, value=123)
-                    img_generate_btn = gr.Button("Generate Image")
-                with gr.Row():
-                    img_output_image = gr.Image(label="Generated Image")
-
+                img_input_text = gr.Text(label='Image Generation Prompt')
+                img_seed = gr.Slider(label='Random Seed', minimum=0, maximum=10000, step=1, value=123)
+                img_generate_btn = gr.Button("Generate Image")
+            with gr.Row():
+                img_output_image = gr.Image(label="Generated Image")
             img_generate_btn.click(
                 inputs=[img_input_text, img_seed],
                 outputs=[img_output_image],
-                fn=generate_only_image
-            )
-
-
+                fn=generate_image_from_text
+            )
     with gr.Tab(label='Image to Video Generation'):
         with gr.Column():
             with gr.Row():
-                with gr.Column():
-                    with gr.Row():
-                        i2v_input_image = gr.Image(label="Input Image", elem_id="input_img")
-                    with gr.Row():
-                        i2v_input_text = gr.Text(label='Prompts')
-                    with gr.Row():
-                        i2v_seed = gr.Slider(label='Random Seed', minimum=0, maximum=10000, step=1, value=123)
-                        i2v_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label='ETA', value=1.0, elem_id="i2v_eta")
-                        i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5, elem_id="i2v_cfg_scale")
-                    with gr.Row():
-                        i2v_steps = gr.Slider(minimum=1, maximum=50, step=1, elem_id="i2v_steps", label="Sampling steps", value=30)
-                        i2v_motion = gr.Slider(minimum=5, maximum=20, step=1, elem_id="i2v_motion", label="FPS", value=8)
-                    with gr.Row():
-                        i2v_video_length = gr.Slider(minimum=2, maximum=8, step=1, elem_id="i2v_video_length", label="Video Length (seconds)", value=2)
-                    i2v_end_btn = gr.Button("Generate")
-                with gr.Row():
-                    i2v_output_video = gr.Video(label="Generated Video", elem_id="output_vid", autoplay=True, show_share_button=True)
-
-            gr.Examples(examples=i2v_examples,
-                        inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed, i2v_video_length],
-                        outputs=[i2v_output_video],
-                        fn=infer,
-                        cache_examples=True,
+                video_input_image = gr.Image(label="Input Image for Video", tool="input")
+                video_prompt = gr.Text(label='Video Generation Prompt')
+                video_seed = gr.Slider(label='Random Seed', minimum=0, maximum=10000, step=1, value=123)
+                video_steps = gr.Slider(label="Sampling steps", minimum=1, maximum=50, step=1, value=30)
+                video_cfg_scale = gr.Slider(label='CFG Scale', minimum=1.0, maximum=15.0, step=0.5, value=7.5)
+                video_eta = gr.Slider(label='ETA', minimum=0.0, maximum=1.0, step=0.1, value=1.0)
+                video_fs = gr.Slider(label='FS', minimum=1, maximum=10, step=1, value=3)
+                video_length = gr.Slider(label="Video Length (seconds)", minimum=2, maximum=8, step=1, value=2)
+                video_generate_btn = gr.Button("Generate Video")
+            with gr.Row():
+                video_output = gr.Video(label="Generated Video", autoplay=True, show_share_button=True)
+            video_generate_btn.click(
+                inputs=[video_input_image, video_prompt, video_seed, video_steps, video_cfg_scale, video_eta, video_fs, video_length],
+                outputs=[video_output],
+                fn=infer
             )
-            i2v_end_btn.click(inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed, i2v_video_length],
-                              outputs=[i2v_output_video],
-                              fn=infer
-                              )
-
+
     with gr.Tab(label='Text to Video Generation'):
         with gr.Column():
             with gr.Row():
-                with gr.Column():
-                    with gr.Row():
-                        t2v_input_text = gr.Text(label='Image Generation Prompt')  # prompt for image generation
-                        t2v_video_prompt = gr.Text(label='Video Generation Prompt')  # prompt for video generation
-                    with gr.Row():
-                        t2v_seed = gr.Slider(label='Random Seed', minimum=0, maximum=10000, step=1, value=123)
-                        t2v_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label='ETA', value=1.0)
-                        t2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5)
-                    with gr.Row():
-                        t2v_steps = gr.Slider(minimum=1, maximum=50, step=1, label="Sampling steps", value=30)
-                        t2v_motion = gr.Slider(minimum=5, maximum=20, step=1, label="FPS", value=8)
-                    with gr.Row():
-                        t2v_video_length = gr.Slider(minimum=2, maximum=8, step=1, label="Video Length (seconds)", value=2)
-                    t2v_end_btn = gr.Button("Generate")
-                with gr.Row():
-                    t2v_output_video = gr.Video(label="Generated Video", autoplay=True, show_share_button=True)
-
-            t2v_end_btn.click(
-                inputs=[t2v_input_text, t2v_video_prompt, t2v_steps, t2v_cfg_scale, t2v_eta, t2v_motion, t2v_seed, t2v_video_length],
-                outputs=[t2v_output_video],
+                video_prompt = gr.Text(label='Video Generation Prompt')
+                video_seed = gr.Slider(label='Random Seed', minimum=0, maximum=10000, step=1, value=123)
+                video_steps = gr.Slider(label="Sampling steps", minimum=1, maximum=50, step=1, value=30)
+                video_cfg_scale = gr.Slider(label='CFG Scale', minimum=1.0, maximum=15.0, step=0.5, value=7.5)
+                video_eta = gr.Slider(label='ETA', minimum=0.0, maximum=1.0, step=0.1, value=1.0)
+                video_fs = gr.Slider(label='FS', minimum=1, maximum=10, step=1, value=3)
+                video_length = gr.Slider(label="Video Length (seconds)", minimum=2, maximum=8, step=1, value=2)
+                video_generate_btn = gr.Button("Generate Video")
+            with gr.Row():
+                video_output = gr.Video(label="Generated Video", autoplay=True, show_share_button=True)
+            video_generate_btn.click(
+                inputs=[video_prompt, video_seed, video_steps, video_cfg_scale, video_eta, video_fs, video_length],
+                outputs=[video_output],
                 fn=infer_t2v
-            )
-
-
+            )
 
-dynamicrafter_iface.queue(max_size=12).launch(show_api=True)
+dynamicrafter_iface.launch(show_api=True)
 
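For reference, the new translate_prompt helper gates translation on two Unicode ranges: Hangul compatibility jamo (U+3131 to U+318E) and precomposed Hangul syllables (U+AC00 to U+D7A3). Below is a minimal standalone sketch of that check; the translate callable is a hypothetical stand-in for the app's translator pipeline, which is defined outside this diff.

# Minimal sketch of the Korean-detection gate behind translate_prompt.
# `translate` is a hypothetical stand-in for the app's `translator` pipeline,
# i.e. translator(prompt, max_length=512)[0]['translation_text'].

def contains_korean(text: str) -> bool:
    # Hangul compatibility jamo (U+3131..U+318E) or Hangul syllables (U+AC00..U+D7A3)
    return any('\u3131' <= ch <= '\u318E' or '\uAC00' <= ch <= '\uD7A3' for ch in text)

def translate_if_korean(prompt: str, translate) -> str:
    return translate(prompt) if contains_korean(prompt) else prompt

if __name__ == '__main__':
    fake_translate = lambda s: '<translated>'  # placeholder translation callable
    print(translate_if_korean('우주비행사가 기타를 친다', fake_translate))      # -> <translated>
    print(translate_if_korean('an astronaut playing a guitar', fake_translate))  # unchanged

English-only prompts pass through untouched, so the FLUX and DynamiCrafter calls always receive English text, original or translated.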
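The noise_shape bookkeeping in infer also follows directly from constants visible in the diff: at save_fps = 8, the slider's 2-8 second range maps to 16-64 frames, and the 576x1024 input maps to a 72x128 latent grid because the autoencoder downsamples by a factor of 8. A small sketch of that arithmetic; the channel count of 4 is an assumed typical latent-diffusion value standing in for model.model.diffusion_model.out_channels.

# Sketch of the noise_shape arithmetic in infer(); channels=4 is an assumption
# standing in for model.model.diffusion_model.out_channels.

def noise_shape_for(video_length: int, save_fps: int = 8,
                    resolution=(576, 1024), channels: int = 4, batch_size: int = 1):
    frames = int(video_length * save_fps)           # seconds * fps -> frame count
    h, w = resolution[0] // 8, resolution[1] // 8   # VAE downsamples by 8
    return [batch_size, channels, frames, h, w]     # [b, c, t, h, w]

print(noise_shape_for(2))  # [1, 4, 16, 72, 128] for the 2-second default
print(noise_shape_for(8))  # [1, 4, 64, 72, 128] at the slider maximum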