openfree committed on
Commit
8a835ac
·
verified ·
1 Parent(s): 62ec1ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +272 -11
app.py CHANGED
@@ -16,6 +16,13 @@ from einops import rearrange
16
  from scipy.io import wavfile
17
  from transformers import pipeline
18
 
 
 
 
 
 
 
 
19
  # 환경 변수 설정으로 torch.load 체크 우회 (임시 해결책)
20
  os.environ["TRANSFORMERS_ALLOW_UNSAFE_DESERIALIZATION"] = "1"
21
 
@@ -45,7 +52,29 @@ from mmaudio.model.networks import MMAudio, get_my_mmaudio
45
  from mmaudio.model.sequence_config import SequenceConfig
46
  from mmaudio.model.utils.features_utils import FeaturesUtils
47
 
48
- # ControlNet 모델 로드
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  try:
50
  from controlnet_union import ControlNetModel_Union
51
  from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
@@ -94,18 +123,14 @@ except Exception as e:
94
  logging.error(f"Failed to load outpainting models: {str(e)}")
95
  OUTPAINT_MODEL_LOADED = False
96
 
97
- # MMAudio 모델 설정
98
  if torch.cuda.is_available():
99
- device = torch.device("cuda")
100
- torch.backends.cuda.matmul.allow_tf32 = True
101
- torch.backends.cudnn.allow_tf32 = True
102
- torch.backends.cudnn.benchmark = True
103
  else:
104
  device = torch.device("cpu")
 
105
 
106
- dtype = torch.bfloat16
107
-
108
- # MMAudio 모델 초기화
109
  try:
110
  model_mmaudio: ModelConfig = all_model_cfg['large_44k_v2']
111
  model_mmaudio.download_if_needed()
@@ -155,7 +180,7 @@ VIDEO_API_URL = "http://211.233.58.201:7875"
155
  # 로깅 설정
156
  logging.basicConfig(level=logging.INFO)
157
 
158
- # Image size presets
159
  IMAGE_PRESETS = {
160
  "커스텀": {"width": 1024, "height": 1024},
161
  "1:1 정사각형": {"width": 1024, "height": 1024},
@@ -172,6 +197,7 @@ IMAGE_PRESETS = {
172
  "LinkedIn 배너": {"width": 1584, "height": 396},
173
  }
174
 
 
175
  def update_dimensions(preset):
176
  if preset in IMAGE_PRESETS:
177
  return IMAGE_PRESETS[preset]["width"], IMAGE_PRESETS[preset]["height"]
@@ -431,6 +457,113 @@ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int
431
  duration_sec=seq_cfg.duration)
432
  return video_save_path
433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
  # CSS
435
  css = """
436
  :root {
@@ -456,7 +589,7 @@ css = """
456
  padding: 20px !important;
457
  margin-bottom: 20px !important;
458
  }
459
- #generate-btn, #video-btn, #outpaint-btn, #preview-btn, #audio-btn {
460
  background: linear-gradient(135deg, #ff9a9e, #fad0c4) !important;
461
  font-size: 1.1rem !important;
462
  padding: 12px 24px !important;
@@ -652,6 +785,110 @@ with demo:
652
 
653
  if not MMAUDIO_MODEL_LOADED:
654
  gr.Markdown("⚠️ MMAudio 모델을 로드하지 못했습니다. 이 기능은 사용할 수 없습니다.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655
 
656
  # 이벤트 연결 - 첫 번째 탭
657
  size_preset.change(update_dimensions, [size_preset], [width, height])
@@ -689,5 +926,29 @@ with demo:
689
  [audio_video_input, audio_prompt, audio_negative_prompt, audio_seed, audio_steps, audio_cfg, audio_duration],
690
  [output_video_with_audio]
691
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
692
 
693
  demo.launch()
 
16
  from scipy.io import wavfile
17
  from transformers import pipeline
18
 
19
+ # 비디오 배경제거를 위한 추가 import
20
+ from transformers import AutoModelForImageSegmentation
21
+ from torchvision import transforms
22
+ from moviepy import VideoFileClip, vfx, concatenate_videoclips, ImageSequenceClip
23
+ import time
24
+ from concurrent.futures import ThreadPoolExecutor
25
+
26
  # 환경 변수 설정으로 torch.load 체크 우회 (임시 해결책)
27
  os.environ["TRANSFORMERS_ALLOW_UNSAFE_DESERIALIZATION"] = "1"
28
 
 
52
  from mmaudio.model.sequence_config import SequenceConfig
53
  from mmaudio.model.utils.features_utils import FeaturesUtils
54
 
55
# Keep the pre-existing global configuration: float32 matmul precision is set
# to "medium" and the device string is chosen from CUDA availability.
torch.set_float32_matmul_precision("medium")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the BiRefNet segmentation models (full and lite variants) used by the
# video background-removal tab. Any failure is logged and recorded in
# BIREFNET_MODEL_LOADED instead of aborting start-up; in that case
# `birefnet`, `birefnet_lite` and `transform_image` may be left undefined,
# so callers must check the flag first.
try:
    birefnet = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet",
        trust_remote_code=True,
    )
    birefnet.to(device)

    birefnet_lite = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet_lite",
        trust_remote_code=True,
    )
    birefnet_lite.to(device)

    # Preprocessing applied to every frame before segmentation: fixed
    # 768x768 resize, tensor conversion, then per-channel normalization.
    transform_image = transforms.Compose(
        [
            transforms.Resize((768, 768)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
except Exception as e:
    logging.error(f"Failed to load BiRefNet models: {str(e)}")
    BIREFNET_MODEL_LOADED = False
else:
    BIREFNET_MODEL_LOADED = True
76
+
77
+ # ControlNet 모델 로드 (기존 코드)
78
  try:
79
  from controlnet_union import ControlNetModel_Union
80
  from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
 
123
  logging.error(f"Failed to load outpainting models: {str(e)}")
124
  OUTPAINT_MODEL_LOADED = False
125
 
126
# MMAudio device/dtype selection (pre-existing code).
# `device` is bound explicitly on BOTH branches so it is always a
# torch.device: previously only the CPU branch rebound it, which left the
# CUDA path holding whatever value an earlier statement had assigned.
# bfloat16 is used on CUDA and float32 on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    dtype = torch.bfloat16
else:
    device = torch.device("cpu")
    dtype = torch.float32
132
 
133
+ # MMAudio 모델 초기화 (기존 코드)
 
 
134
  try:
135
  model_mmaudio: ModelConfig = all_model_cfg['large_44k_v2']
136
  model_mmaudio.download_if_needed()
 
180
  # 로깅 설정
181
  logging.basicConfig(level=logging.INFO)
182
 
183
+ # Image size presets (기존 코드)
184
  IMAGE_PRESETS = {
185
  "커스텀": {"width": 1024, "height": 1024},
186
  "1:1 정사각형": {"width": 1024, "height": 1024},
 
197
  "LinkedIn 배너": {"width": 1584, "height": 396},
198
  }
199
 
200
+ # 기존 함수들 모두 유지
201
  def update_dimensions(preset):
202
  if preset in IMAGE_PRESETS:
203
  return IMAGE_PRESETS[preset]["width"], IMAGE_PRESETS[preset]["height"]
 
457
  duration_sec=seq_cfg.duration)
458
  return video_save_path
459
 
460
+ # 비디오 배경제거 관련 함수들
461
def process_bg_image(image, bg, fast_mode=False):
    """Composite a single image onto a new background using a BiRefNet mask.

    Args:
        image: PIL image to process.
        bg: The replacement background. A string starting with "#" is parsed
            as a "#RRGGBB" hex color, a PIL image is used directly, and
            anything else is handed to ``Image.open``.
        fast_mode: When true, ``birefnet_lite`` is used instead of ``birefnet``.

    Returns:
        The composited PIL image, or ``image`` unchanged when the BiRefNet
        models were not loaded.
    """
    if not BIREFNET_MODEL_LOADED:
        return image

    target_size = image.size
    segmenter = birefnet_lite if fast_mode else birefnet
    batch = transform_image(image).unsqueeze(0).to(device)

    # The last element of the model output, passed through a sigmoid, is
    # used as the compositing mask.
    with torch.no_grad():
        outputs = segmenter(batch)
        probabilities = outputs[-1].sigmoid().cpu()
    mask_tensor = probabilities[0].squeeze()
    # The mask comes out at the transform's resolution; scale it back to
    # the size of the input image.
    mask = transforms.ToPILImage()(mask_tensor).resize(target_size)

    if isinstance(bg, str) and bg.startswith("#"):
        red, green, blue = (int(bg[start:start + 2], 16) for start in (1, 3, 5))
        background = Image.new("RGBA", target_size, (red, green, blue, 255))
    else:
        source = bg if isinstance(bg, Image.Image) else Image.open(bg)
        background = source.convert("RGBA").resize(target_size)

    return Image.composite(image, background, mask)
486
+
487
def process_video_frame(frame, bg_type, bg, fast_mode, bg_frame_index, background_frames, color):
    """Replace the background of one video frame.

    Args:
        frame: Frame array accepted by ``Image.fromarray``.
        bg_type: "색상" (solid color), "이미지" (image) or "비디오" (video).
            Any other value leaves the frame content untouched.
        bg: Background image (path or PIL image); used when bg_type is "이미지".
        fast_mode: Forwarded to ``process_bg_image``.
        bg_frame_index: Index into ``background_frames``; used when bg_type
            is "비디오".
        background_frames: Sequence of background frame arrays, or None when
            bg_type is not "비디오".
        color: Hex color string; used when bg_type is "색상".

    Returns:
        ``(processed_frame_array, bg_frame_index)``. The index is advanced by
        one only in the "비디오" branch. If anything raises, the error is
        printed and the input ``frame`` is returned with the index unchanged.
    """
    try:
        pil_image = Image.fromarray(frame)
        if bg_type == "색상":
            processed_image = process_bg_image(pil_image, color, fast_mode)
        elif bg_type == "이미지":
            processed_image = process_bg_image(pil_image, bg, fast_mode)
        elif bg_type == "비디오":
            # The background clip can end up a frame or two shorter than the
            # input. Reuse its last frame instead of raising IndexError,
            # which would drop this frame to the unprocessed fallback below
            # and leave the output sequence with mixed frame formats.
            last_index = len(background_frames) - 1
            background_frame = background_frames[min(bg_frame_index, last_index)]
            bg_frame_index += 1
            background_image = Image.fromarray(background_frame)
            processed_image = process_bg_image(pil_image, background_image, fast_mode)
        else:
            processed_image = pil_image
        return np.array(processed_image), bg_frame_index
    except Exception as e:
        # Best effort: one bad frame must not abort the whole video.
        print(f"Error processing frame: {e}")
        return frame, bg_frame_index
506
+
507
def _read_background_frames(bg_video, target_duration, fps, video_handling):
    """Return the background video's frames, fitted to cover target_duration.

    A background shorter than ``target_duration`` is either stretched to that
    duration ("slow_down") or repeated until it is long enough (any other
    value, i.e. "loop"). A background that is already long enough is used
    as-is. The clip is closed once its frames have been read.
    """
    source_clip = VideoFileClip(bg_video)
    try:
        fitted_clip = source_clip
        if source_clip.duration < target_duration:
            if video_handling == "slow_down":
                # moviepy 2.x effect API (the file already relies on 2.x via
                # `from moviepy import ...` and `with_audio`). `final_duration`
                # lets moviepy derive the speed factor, so the clip is really
                # slowed down to the target length.
                fitted_clip = source_clip.with_effects(
                    [vfx.MultiplySpeed(final_duration=target_duration)]
                )
            else:  # video_handling == "loop"
                repeats = int(target_duration / source_clip.duration + 1)
                fitted_clip = concatenate_videoclips([source_clip] * repeats)
        return list(fitted_clip.iter_frames(fps=fps))
    finally:
        source_clip.close()


@spaces.GPU
def process_video_bg(vid, bg_type="색상", bg_image=None, bg_video=None, color="#00FF00",
                     fps=0, video_handling="slow_down", fast_mode=True, max_workers=10):
    """Replace the background of a video, streaming progress to the UI.

    This is a generator. Every yielded 3-tuple maps to the Gradio outputs
    ``(stream_image, output_bg_video, time_textbox)``.

    Args:
        vid: Path of the input video.
        bg_type: "색상" (solid color), "이미지" (image) or "비디오" (video).
        bg_image: Background image; used when bg_type is "이미지".
        bg_video: Background video path; used when bg_type is "비디오".
        color: Hex background color; used when bg_type is "색상".
        fps: Output frame rate; 0 keeps the input video's frame rate.
        video_handling: "slow_down" or "loop" - how a background video that
            is shorter than the input is fitted to it.
        fast_mode: Forwarded to the per-frame processing.
        max_workers: Thread-pool size for per-frame processing.

    Yields:
        Progress updates per processed frame, then the last frame together
        with the path of the rendered .mp4. On failure the video output is
        cleared and the error is reported in the text output.
    """
    if not BIREFNET_MODEL_LOADED:
        yield gr.update(visible=False), gr.update(visible=True), "BiRefNet 모델을 로드하지 못했습니다."
        yield None, None, "BiRefNet 모델을 로드하지 못했습니다."
        return

    # Bound before the try block so the error handler can always report the
    # elapsed time.
    start_time = time.time()
    video = None
    try:
        video = VideoFileClip(vid)
        if fps == 0:
            fps = video.fps

        audio = video.audio
        frames = list(video.iter_frames(fps=fps))

        processed_frames = []
        yield gr.update(visible=True), gr.update(visible=False), "처리 시작... 경과 시간: 0초"

        if bg_type == "비디오":
            background_frames = _read_background_frames(bg_video, video.duration, fps, video_handling)
        else:
            background_frames = None

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Each frame carries its own background index, so results do not
            # depend on the order in which the workers finish.
            futures = [
                executor.submit(process_video_frame, frame, bg_type, bg_image, fast_mode,
                                index, background_frames, color)
                for index, frame in enumerate(frames)
            ]
            # Collect in submission order to keep the frames in sequence.
            for index, future in enumerate(futures):
                result, _ = future.result()
                processed_frames.append(result)
                elapsed_time = time.time() - start_time
                yield result, None, f"프레임 {index + 1}/{len(frames)} 처리 중... 경과 시간: {elapsed_time:.2f}초"

        processed_video = ImageSequenceClip(processed_frames, fps=fps)
        processed_video = processed_video.with_audio(audio)

        # delete=False: the file must outlive this function so Gradio can
        # serve it.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
            temp_filepath = temp_file.name
            processed_video.write_videofile(temp_filepath, codec="libx264")

        elapsed_time = time.time() - start_time
        yield gr.update(visible=False), gr.update(visible=True), f"처리 완료! 경과 시간: {elapsed_time:.2f}초"
        yield processed_frames[-1], temp_filepath, f"처리 완료! 경과 시간: {elapsed_time:.2f}초"

    except Exception as e:
        print(f"Error: {e}")
        elapsed_time = time.time() - start_time
        yield gr.update(visible=False), gr.update(visible=True), f"비디오 처리 오류: {e}. 경과 시간: {elapsed_time:.2f}초"
        # The second slot is a video component: clear it rather than handing
        # it the error text; the message goes to the text output.
        yield None, None, f"비디오 처리 오류: {e}. 경과 시간: {elapsed_time:.2f}초"
    finally:
        # Closed only here, after rendering, because the output's audio
        # track comes from this clip.
        if video is not None:
            video.close()
566
+
567
  # CSS
568
  css = """
569
  :root {
 
589
  padding: 20px !important;
590
  margin-bottom: 20px !important;
591
  }
592
+ #generate-btn, #video-btn, #outpaint-btn, #preview-btn, #audio-btn, #bg-remove-btn {
593
  background: linear-gradient(135deg, #ff9a9e, #fad0c4) !important;
594
  font-size: 1.1rem !important;
595
  padding: 12px 24px !important;
 
785
 
786
  if not MMAUDIO_MODEL_LOADED:
787
  gr.Markdown("⚠️ MMAudio 모델을 로드하지 못했습니다. 이 기능은 사용할 수 없습니다.")
788
+
789
+ # 네 번째 탭: 비디오 배경제거/합성
790
+ with gr.Tab("비디오 배경제거/합성", elem_classes="tabitem"):
791
+ with gr.Row(equal_height=True):
792
+ # 입력 컬럼
793
+ with gr.Column(scale=1):
794
+ with gr.Group(elem_classes="panel-box"):
795
+ gr.Markdown("### 🎥 비디오 업로드")
796
+
797
+ bg_video_input = gr.Video(
798
+ label="입력 비디오",
799
+ interactive=True
800
+ )
801
+
802
+ with gr.Group(elem_classes="panel-box"):
803
+ gr.Markdown("### 🎨 배경 설정")
804
+
805
+ bg_type = gr.Radio(
806
+ ["색상", "이미지", "비디오"],
807
+ label="배경 유형",
808
+ value="색상",
809
+ interactive=True
810
+ )
811
+
812
+ color_picker = gr.ColorPicker(
813
+ label="배경 색상",
814
+ value="#00FF00",
815
+ visible=True,
816
+ interactive=True
817
+ )
818
+
819
+ bg_image_input = gr.Image(
820
+ label="배경 이미지",
821
+ type="filepath",
822
+ visible=False,
823
+ interactive=True
824
+ )
825
+
826
+ bg_video_bg = gr.Video(
827
+ label="배경 비디오",
828
+ visible=False,
829
+ interactive=True
830
+ )
831
+
832
+ with gr.Column(visible=False) as video_handling_options:
833
+ video_handling_radio = gr.Radio(
834
+ ["slow_down", "loop"],
835
+ label="비디오 처리 방식",
836
+ value="slow_down",
837
+ interactive=True,
838
+ info="slow_down: 배경 비디오를 느리게 재생, loop: 배경 비디오를 반복"
839
+ )
840
+
841
+ with gr.Group(elem_classes="panel-box"):
842
+ gr.Markdown("### ⚙️ 처리 설정")
843
+
844
+ fps_slider = gr.Slider(
845
+ minimum=0,
846
+ maximum=60,
847
+ step=1,
848
+ value=0,
849
+ label="출력 FPS (0 = 원본 FPS 유지)",
850
+ interactive=True
851
+ )
852
+
853
+ fast_mode_checkbox = gr.Checkbox(
854
+ label="빠른 모드 (BiRefNet_lite 사용)",
855
+ value=True,
856
+ interactive=True
857
+ )
858
+
859
+ max_workers_slider = gr.Slider(
860
+ minimum=1,
861
+ maximum=32,
862
+ step=1,
863
+ value=10,
864
+ label="최대 워커 수",
865
+ info="병렬로 처리할 프레임 수",
866
+ interactive=True
867
+ )
868
+
869
+ bg_remove_btn = gr.Button("🎬 배경 변경", variant="primary", elem_id="bg-remove-btn")
870
+
871
+ if not BIREFNET_MODEL_LOADED:
872
+ gr.Markdown("⚠️ BiRefNet 모델을 로드하지 못했습니다. 이 기능은 사용할 수 없습니다.")
873
+
874
+ # 출력 컬럼
875
+ with gr.Column(scale=1):
876
+ with gr.Group(elem_classes="panel-box"):
877
+ gr.Markdown("### 🎬 처리 결과")
878
+
879
+ stream_image = gr.Image(label="실시간 스트리밍", visible=False)
880
+ output_bg_video = gr.Video(label="최종 비디오")
881
+ time_textbox = gr.Textbox(label="경과 시간", interactive=False)
882
+
883
+ gr.Markdown("""
884
+ ### ℹ️ 사용 방법
885
+ 1. 비디오를 업로드하세요
886
+ 2. 원하는 배경 유형을 선택하세요
887
+ 3. 설정을 조정하고 '배경 변경' 버튼을 클릭하세요
888
+
889
+ **참고**: GPU 제한으로 한 번에 약 200프레임까지 처리 가능합니다.
890
+ 긴 비디오는 작은 조각으로 나누어 처리하세요.
891
+ """)
892
 
893
  # 이벤트 연결 - 첫 번째 탭
894
  size_preset.change(update_dimensions, [size_preset], [width, height])
 
926
  [audio_video_input, audio_prompt, audio_negative_prompt, audio_seed, audio_steps, audio_cfg, audio_duration],
927
  [output_video_with_audio]
928
  )
929
+
930
+ # 이벤트 연결 - 네 번째 탭
931
def update_bg_visibility(bg_type):
    """Return visibility updates for the four background-setting widgets.

    The updates are ordered as (color picker, background image, background
    video, video-handling options). Exactly the widgets belonging to the
    selected background type are shown; an unrecognized type hides all four.
    """
    show_color = bg_type == "색상"
    show_image = bg_type == "이미지"
    # The background-video input and its handling options appear together.
    show_video = bg_type == "비디오"
    return (
        gr.update(visible=show_color),
        gr.update(visible=show_image),
        gr.update(visible=show_video),
        gr.update(visible=show_video),
    )
940
+
941
+ bg_type.change(
942
+ update_bg_visibility,
943
+ inputs=bg_type,
944
+ outputs=[color_picker, bg_image_input, bg_video_bg, video_handling_options]
945
+ )
946
+
947
+ bg_remove_btn.click(
948
+ process_video_bg,
949
+ inputs=[bg_video_input, bg_type, bg_image_input, bg_video_bg, color_picker,
950
+ fps_slider, video_handling_radio, fast_mode_checkbox, max_workers_slider],
951
+ outputs=[stream_image, output_bg_video, time_textbox]
952
+ )
953
 
954
  demo.launch()