Spaces:
mikitona
/
Running on Zero

mikitona commited on
Commit
de3f501
·
verified ·
1 Parent(s): c4b24c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -11
app.py CHANGED
@@ -39,19 +39,35 @@ parser.add_argument("--decoder_tile_size", type=int, default=64)
39
  parser.add_argument("--load_8bit_llava", action='store_true', default=True)
40
  args = parser.parse_args()
41
 
42
- if torch.cuda.device_count() > 0:
 
 
 
 
 
43
  SUPIR_device = 'cuda:0'
 
 
 
 
44
 
45
- # Load SUPIR
46
- model, default_setting = create_SUPIR_model(args.opt, SUPIR_sign='Q', load_default_setting=True)
47
- if args.loading_half_params:
48
- model = model.half()
49
- if args.use_tile_vae:
50
- model.init_tile_vae(encoder_tile_size=args.encoder_tile_size, decoder_tile_size=args.decoder_tile_size)
51
- model = model.to(SUPIR_device)
52
- model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
53
- model.current_model = 'v0-Q'
54
- ckpt_Q, ckpt_F = load_QF_ckpt(args.opt)
 
 
 
 
 
 
 
55
 
56
  def check_upload(input_image):
57
  if input_image is None:
 
39
  parser.add_argument("--load_8bit_llava", action='store_true', default=True)
40
  args = parser.parse_args()
41
 
42
+ use_llava = not args.no_llava
43
+
44
+ if torch.cuda.device_count() >= 2:
45
+ SUPIR_device = 'cuda:0'
46
+ LLaVA_device = 'cuda:1'
47
+ elif torch.cuda.device_count() == 1:
48
  SUPIR_device = 'cuda:0'
49
+ LLaVA_device = 'cuda:0'
50
+ else:
51
+ raise ValueError('Currently support CUDA only.')
52
+
53
 
54
+ # Load SUPIR
55
+ model, default_setting = create_SUPIR_model(args.opt, SUPIR_sign='Q', load_default_setting=True)
56
+ if args.loading_half_params:
57
+ model = model.half()
58
+ if args.use_tile_vae:
59
+ model.init_tile_vae(encoder_tile_size=args.encoder_tile_size, decoder_tile_size=args.decoder_tile_size)
60
+ model = model.to(SUPIR_device)
61
+ model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
62
+ model.current_model = 'v0-Q'
63
+ ckpt_Q, ckpt_F = load_QF_ckpt(args.opt)
64
+
65
+ # load LLaVA
66
+ if use_llava:
67
+ llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=LLaVA_device, load_8bit=args.load_8bit_llava, load_4bit=False)
68
+ else:
69
+ llava_agent = None
70
+
71
 
72
  def check_upload(input_image):
73
  if input_image is None: