erwold committed on
Commit
ac06db6
·
1 Parent(s): 660497c

Initial Commit

Browse files
Files changed (2) hide show
  1. app.py +37 -17
  2. requirements.txt +2 -1
app.py CHANGED
@@ -12,6 +12,7 @@ import logging
12
  import sys
13
 
14
  from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
 
15
  import spaces
16
 
17
  # 设置日志
@@ -25,6 +26,27 @@ logging.basicConfig(
25
  logger = logging.getLogger(__name__)
26
 
27
  MODEL_ID = "Djrango/Qwen2vl-Flux"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # Add aspect ratio options
30
  ASPECT_RATIOS = {
@@ -61,33 +83,31 @@ class FluxInterface:
61
  torch.cuda.max_memory_allocated = lambda *args, **kwargs: 0 # 忽略已分配内存的限制
62
 
63
  # Load FLUX components
64
- tokenizer = CLIPTokenizer.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer")
65
- text_encoder = CLIPTextModel.from_pretrained(self.MODEL_ID, subfolder="flux/text_encoder").to(self.dtype).to(self.device)
66
- text_encoder_two = T5EncoderModel.from_pretrained(self.MODEL_ID, subfolder="flux/text_encoder_2").to(self.dtype).to(self.device)
67
- tokenizer_two = T5TokenizerFast.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer_2")
68
 
69
  # Load VAE and transformer
70
- vae = AutoencoderKL.from_pretrained(self.MODEL_ID, subfolder="flux/vae").to(self.dtype).to(self.device)
71
- transformer = FluxTransformer2DModel.from_pretrained(self.MODEL_ID, subfolder="flux/transformer").to(self.dtype).to(self.device)
72
- scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(self.MODEL_ID, subfolder="flux/scheduler", shift=1)
73
 
74
  # Load Qwen2VL components
75
- qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(self.MODEL_ID, subfolder="qwen2-vl").to(self.dtype).to(self.device)
76
 
77
- # Load connector
78
  connector = Qwen2Connector().to(self.dtype).to(self.device)
79
- connector_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/connector.pt"
80
- connector_state = torch.hub.load_state_dict_from_url(connector_path, map_location='cpu')
81
- # Move state dict to dtype before loading
82
  connector_state = {k: v.to(self.dtype) for k, v in connector_state.items()}
83
  connector.load_state_dict(connector_state)
84
  connector = connector.to(self.device)
85
-
86
- # Load T5 embedder
87
  self.t5_context_embedder = nn.Linear(4096, 3072).to(self.dtype).to(self.device)
88
- t5_embedder_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/t5_embedder.pt"
89
- t5_embedder_state = torch.hub.load_state_dict_from_url(t5_embedder_path, map_location='cpu')
90
- # Move state dict to dtype before loading
91
  t5_embedder_state = {k: v.to(self.dtype) for k, v in t5_embedder_state.items()}
92
  self.t5_context_embedder.load_state_dict(t5_embedder_state)
93
  self.t5_context_embedder = self.t5_context_embedder.to(self.device)
 
12
  import sys
13
 
14
  from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
15
+ from huggingface_hub import snapshot_download
16
  import spaces
17
 
18
  # 设置日志
 
26
  logger = logging.getLogger(__name__)
27
 
28
  MODEL_ID = "Djrango/Qwen2vl-Flux"
29
+ MODEL_CACHE_DIR = "model_cache"
30
+
31
+ # 预下载所有模型
32
+ def download_models():
33
+ logger.info("Starting model download...")
34
+ try:
35
+ # 下载完整模型仓库
36
+ snapshot_download(
37
+ repo_id=MODEL_ID,
38
+ local_dir=MODEL_CACHE_DIR,
39
+ local_dir_use_symlinks=False
40
+ )
41
+
42
+ logger.info("Model download completed successfully")
43
+ except Exception as e:
44
+ logger.error(f"Error downloading models: {str(e)}")
45
+ raise
46
+
47
+ # 在脚本开始时下载模型
48
+ if not os.path.exists(MODEL_CACHE_DIR):
49
+ download_models()
50
 
51
  # Add aspect ratio options
52
  ASPECT_RATIOS = {
 
83
  torch.cuda.max_memory_allocated = lambda *args, **kwargs: 0 # 忽略已分配内存的限制
84
 
85
  # Load FLUX components
86
+ tokenizer = CLIPTokenizer.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/tokenizer"))
87
+ text_encoder = CLIPTextModel.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/text_encoder")).to(self.dtype).to(self.device)
88
+ text_encoder_two = T5EncoderModel.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/text_encoder_2")).to(self.dtype).to(self.device)
89
+ tokenizer_two = T5TokenizerFast.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/tokenizer_2"))
90
 
91
  # Load VAE and transformer
92
+ vae = AutoencoderKL.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/vae")).to(self.dtype).to(self.device)
93
+ transformer = FluxTransformer2DModel.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/transformer")).to(self.dtype).to(self.device)
94
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/scheduler"), shift=1)
95
 
96
  # Load Qwen2VL components
97
+ qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(os.path.join(MODEL_CACHE_DIR, "qwen2-vl")).to(self.dtype).to(self.device)
98
 
99
+ # 加载 connector
100
  connector = Qwen2Connector().to(self.dtype).to(self.device)
101
+ connector_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/connector.pt")
102
+ connector_state = torch.load(connector_path, map_location='cpu')
 
103
  connector_state = {k: v.to(self.dtype) for k, v in connector_state.items()}
104
  connector.load_state_dict(connector_state)
105
  connector = connector.to(self.device)
106
+
107
+ # 加载 T5 embedder
108
  self.t5_context_embedder = nn.Linear(4096, 3072).to(self.dtype).to(self.device)
109
+ t5_embedder_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/t5_embedder.pt")
110
+ t5_embedder_state = torch.load(t5_embedder_path, map_location='cpu')
 
111
  t5_embedder_state = {k: v.to(self.dtype) for k, v in t5_embedder_state.items()}
112
  self.t5_context_embedder.load_state_dict(t5_embedder_state)
113
  self.t5_context_embedder = self.t5_context_embedder.to(self.device)
requirements.txt CHANGED
@@ -11,4 +11,5 @@ numpy>=1.24.0
11
  # Utilities
12
  protobuf==4.23.4
13
  sentencepiece==0.2.0
14
- gradio==5.6.0
 
 
11
  # Utilities
12
  protobuf==4.23.4
13
  sentencepiece==0.2.0
14
+ gradio==5.6.0
15
+ huggingface-hub