ironjr commited on
Commit
3a36c2e
·
verified ·
1 Parent(s): 99c6089

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -5
app.py CHANGED
@@ -94,13 +94,23 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
94
  print(device)
95
 
96
 
97
- model_dict = {
98
- 'Blazing Drive V13md': 'ironjr/BlazingDriveV13md',
99
- # 'Real Cartoon Pixar V5': 'ironjr/RealCartoon-PixarV5',
100
- }
 
 
 
 
 
 
 
 
 
 
101
 
102
  models = {
103
- k: SemanticDrawPipeline(device, sd_version='1.5', hf_key=v).cuda()
104
  for k, v in model_dict.items()
105
  }
106
 
 
94
  print(device)
95
 
96
 
97
+ if opt.model is None:
98
+ model_dict = {
99
+ # 'Blazing Drive V11m': 'ironjr/BlazingDriveV11m',
100
+ # 'Real Cartoon Pixar V5': 'ironjr/RealCartoon-PixarV5',
101
+ 'Kohaku V2.1': 'KBlueLeaf/kohaku-v2.1',
102
+ # 'Realistic Vision V5.1': 'ironjr/RealisticVisionV5-1',
103
+ # 'Stable Diffusion V1.5': 'runwayml/stable-diffusion-v1-5',
104
+ }
105
+ else:
106
+ if opt.model.endswith('.safetensors'):
107
+ opt.model = os.path.abspath(os.path.join('checkpoints', opt.model))
108
+ model_dict = {os.path.splitext(os.path.basename(opt.model))[0]: opt.model}
109
+
110
+ dtype = torch.float32 if device == 'cpu' else torch.float16
111
 
112
  models = {
113
+ k: SemanticDrawPipeline(device, dtype=dtype, sd_version='1.5', hf_key=v, has_i2t=False)
114
  for k, v in model_dict.items()
115
  }
116