LanguageBind committed on
Commit
358828b
·
verified ·
1 Parent(s): b9b33a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -13
app.py CHANGED
@@ -98,12 +98,10 @@ def img2b64(image_path):
98
  data_uri = f"data:image/jpeg;base64,{b64}"
99
  return data_uri
100
 
101
- @spaces.GPU(duration=900)
102
  @spaces.GPU
103
  def initialize_models(args):
104
  os.makedirs("tmp", exist_ok=True)
105
  # Paths
106
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
107
 
108
  quantization_config = BitsAndBytesConfig(
109
  load_in_4bit=True,
@@ -114,16 +112,16 @@ def initialize_models(args):
114
  # Load main model and task head
115
  model = UnivaQwen2p5VLForConditionalGeneration.from_pretrained(
116
  args.model_path,
117
- torch_dtype=torch.bfloat16,
118
  attn_implementation="sdpa",
119
  quantization_config=quantization_config if args.nf4 else None,
120
- ).to(device)
121
  task_head = nn.Sequential(
122
  nn.Linear(3584, 10240),
123
  nn.SiLU(),
124
  nn.Dropout(0.3),
125
  nn.Linear(10240, 2)
126
- ).to(device)
127
  task_head.load_state_dict(torch.load(os.path.join(args.model_path, 'task_head_final.pt')))
128
  task_head.eval()
129
 
@@ -137,20 +135,20 @@ def initialize_models(args):
137
  args.flux_path,
138
  subfolder="text_encoder_2",
139
  quantization_config=quantization_config,
140
- torch_dtype=torch.bfloat16,
141
  )
142
  pipe = FluxPipeline.from_pretrained(
143
  args.flux_path,
144
  transformer=model.denoise_tower.denoiser,
145
  text_encoder_2=text_encoder_2,
146
- torch_dtype=torch.bfloat16,
147
- ).to(device)
148
  else:
149
  pipe = FluxPipeline.from_pretrained(
150
  args.flux_path,
151
  transformer=model.denoise_tower.denoiser,
152
- torch_dtype=torch.bfloat16,
153
- ).to(device)
154
  if args.offload:
155
  pipe.enable_model_cpu_offload()
156
  pipe.enable_vae_slicing()
@@ -162,8 +160,8 @@ def initialize_models(args):
162
  siglip_processor = SiglipImageProcessor.from_pretrained(args.siglip_path)
163
  siglip_model = SiglipVisionModel.from_pretrained(
164
  args.siglip_path,
165
- torch_dtype=torch.bfloat16,
166
- ).to(device)
167
 
168
  return {
169
  'model': model,
@@ -174,12 +172,23 @@ def initialize_models(args):
174
  'text_encoders': text_encoders,
175
  'siglip_processor': siglip_processor,
176
  'siglip_model': siglip_model,
177
- 'device': device,
178
  }
179
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  args = parse_args()
182
  state = initialize_models(args)
 
183
 
184
  @spaces.GPU
185
  def process_large_image(raw_img):
 
98
  data_uri = f"data:image/jpeg;base64,{b64}"
99
  return data_uri
100
 
 
101
  @spaces.GPU
102
  def initialize_models(args):
103
  os.makedirs("tmp", exist_ok=True)
104
  # Paths
 
105
 
106
  quantization_config = BitsAndBytesConfig(
107
  load_in_4bit=True,
 
112
  # Load main model and task head
113
  model = UnivaQwen2p5VLForConditionalGeneration.from_pretrained(
114
  args.model_path,
115
+ torch_dtype=torch.float32,
116
  attn_implementation="sdpa",
117
  quantization_config=quantization_config if args.nf4 else None,
118
+ )
119
  task_head = nn.Sequential(
120
  nn.Linear(3584, 10240),
121
  nn.SiLU(),
122
  nn.Dropout(0.3),
123
  nn.Linear(10240, 2)
124
+ )
125
  task_head.load_state_dict(torch.load(os.path.join(args.model_path, 'task_head_final.pt')))
126
  task_head.eval()
127
 
 
135
  args.flux_path,
136
  subfolder="text_encoder_2",
137
  quantization_config=quantization_config,
138
+ torch_dtype=torch.float32,
139
  )
140
  pipe = FluxPipeline.from_pretrained(
141
  args.flux_path,
142
  transformer=model.denoise_tower.denoiser,
143
  text_encoder_2=text_encoder_2,
144
+ torch_dtype=torch.float32,
145
+ )
146
  else:
147
  pipe = FluxPipeline.from_pretrained(
148
  args.flux_path,
149
  transformer=model.denoise_tower.denoiser,
150
+ torch_dtype=torch.float32,
151
+ )
152
  if args.offload:
153
  pipe.enable_model_cpu_offload()
154
  pipe.enable_vae_slicing()
 
160
  siglip_processor = SiglipImageProcessor.from_pretrained(args.siglip_path)
161
  siglip_model = SiglipVisionModel.from_pretrained(
162
  args.siglip_path,
163
+ torch_dtype=torch.float32,
164
+ )
165
 
166
  return {
167
  'model': model,
 
172
  'text_encoders': text_encoders,
173
  'siglip_processor': siglip_processor,
174
  'siglip_model': siglip_model,
175
+
176
  }
177
 
178
+ @spaces.GPU(duration=600)
179
+ def to_device(state):
180
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
181
+ state['model'] = state['model'].to(device, dtype=torch.bfloat16)
182
+ state['task_head'] = state['task_head'].to(device, dtype=torch.bfloat16)
183
+ state['pipe'] = state['pipe'].to(device, dtype=torch.bfloat16)
184
+ state['text_encoders'] = state['text_encoders'].to(device, dtype=torch.bfloat16)
185
+ state['siglip_model'] = state['siglip_model'].to(device, dtype=torch.bfloat16)
186
+ state['device'] = device
187
+ return state
188
 
189
  args = parse_args()
190
  state = initialize_models(args)
191
+ state = to_device(state)
192
 
193
  @spaces.GPU
194
  def process_large_image(raw_img):