ginipick commited on
Commit
7393451
·
verified ·
1 Parent(s): 997f240

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -1,7 +1,15 @@
1
- # 가장 상단에 패치를 적용 (모델 로딩 전에 실행)
2
- from transformers import PreTrainedModel
 
 
3
  PreTrainedModel.tie_weights = lambda self: None
4
 
 
 
 
 
 
 
5
  import os
6
  import cv2
7
  import numpy as np
@@ -14,7 +22,6 @@ from typing import Tuple, Optional
14
 
15
  from PIL import Image
16
  from gradio_imageslider import ImageSlider
17
- from transformers import AutoModelForImageSegmentation
18
  from torchvision import transforms
19
 
20
  import requests
@@ -81,7 +88,7 @@ usage_to_weights_file = {
81
 
82
  # 초기 모델 로딩 (기본: General)
83
  birefnet = AutoModelForImageSegmentation.from_pretrained(
84
- '/'.join(('zhengpeng7', usage_to_weights_file['General'])),
85
  trust_remote_code=True
86
  )
87
  birefnet.to(device)
 
1
+ # --- 패치 적용: 모델 로딩 전에 실행 ---
2
+ from transformers import AutoModelForImageSegmentation, PreTrainedModel
3
+
4
+ # tie_weights를 빈 함수로 오버라이드 (이미 존재하는 경우)
5
  PreTrainedModel.tie_weights = lambda self: None
6
 
7
+ # 모델의 config 클래스에 dummy get_text_config 메서드를 추가
8
+ def dummy_get_text_config(self, decoder=True):
9
+ return type("DummyTextConfig", (), {"tie_word_embeddings": False})()
10
+ AutoModelForImageSegmentation.config_class.get_text_config = dummy_get_text_config
11
+ # --- 패치 종료 ---
12
+
13
  import os
14
  import cv2
15
  import numpy as np
 
22
 
23
  from PIL import Image
24
  from gradio_imageslider import ImageSlider
 
25
  from torchvision import transforms
26
 
27
  import requests
 
88
 
89
  # 초기 모델 로딩 (기본: General)
90
  birefnet = AutoModelForImageSegmentation.from_pretrained(
91
+ '/'.join(('zhengpeng7', usage_to_weights_file['General'])),
92
  trust_remote_code=True
93
  )
94
  birefnet.to(device)