wondervictor commited on
Commit
72cbce3
·
verified ·
1 Parent(s): dab917a
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -3,6 +3,8 @@ import time
3
  import os
4
  os.environ['PYTORCH_JIT'] = "0"
5
  os.system('mim install mmcv==2.0.1')
 
 
6
  # import spaces
7
  import sys
8
  import argparse
@@ -12,7 +14,7 @@ from functools import partial
12
 
13
  # import spaces
14
 
15
- from mmengine.runner import Runner
16
  from mmengine.dataset import Compose
17
  from mmengine.runner.amp import autocast
18
  from mmengine.config import Config, DictAction, ConfigDict
@@ -60,9 +62,10 @@ def generate_image_embeddings(prompt_image,
60
  projector,
61
  device='cuda:0'):
62
  prompt_image = prompt_image.convert('RGB')
63
- inputs = vision_processor(images=prompt_image,
64
  return_tensors="pt",
65
- padding=True)
 
66
  inputs = inputs.to(device)
67
  image_outputs = vision_encoder(**inputs)
68
  img_feats = image_outputs.image_embeds.view(1, -1)
@@ -283,10 +286,11 @@ if __name__ == '__main__':
283
 
284
  cfg.load_from = checkpoint
285
 
286
- if 'runner_type' not in cfg:
287
- runner = Runner.from_cfg(cfg)
288
- else:
289
- runner = RUNNERS.build(cfg)
 
290
  # runner.test()
291
  runner.call_hook('before_run')
292
  runner.load_or_resume()
 
3
  import os
4
  os.environ['PYTORCH_JIT'] = "0"
5
  os.system('mim install mmcv==2.0.1')
6
+ os.system("pip install 'numpy<2.0'")
7
+
8
  # import spaces
9
  import sys
10
  import argparse
 
14
 
15
  # import spaces
16
 
17
+ # from mmengine.runner import Runner
18
  from mmengine.dataset import Compose
19
  from mmengine.runner.amp import autocast
20
  from mmengine.config import Config, DictAction, ConfigDict
 
62
  projector,
63
  device='cuda:0'):
64
  prompt_image = prompt_image.convert('RGB')
65
+ inputs = vision_processor(images=[prompt_image],
66
  return_tensors="pt",
67
+ padding=False)
68
+
69
  inputs = inputs.to(device)
70
  image_outputs = vision_encoder(**inputs)
71
  img_feats = image_outputs.image_embeds.view(1, -1)
 
286
 
287
  cfg.load_from = checkpoint
288
 
289
+ # if 'runner_type' not in cfg:
290
+ # runner = Runner.from_cfg(cfg)
291
+ # else:
292
+ #
293
+ runner = RUNNERS.build(cfg)
294
  # runner.test()
295
  runner.call_hook('before_run')
296
  runner.load_or_resume()