qyoo commited on
Commit
08e5866
·
1 Parent(s): d2007a2
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -120,7 +120,7 @@ def change_model_fn(model_name: str) -> None:
120
  torch_dtype=torch.bfloat16,
121
  feature_extractor=None,
122
  safety_checker=None
123
- )
124
  pipeline = ConceptrolIPAdapterPlus(pipe, "", adapter_name, device, num_tokens=16)
125
  globals()["pipeline"] = pipeline
126
  else:
@@ -202,6 +202,7 @@ def generate(
202
  ) -> np.ndarray:
203
  global pipeline
204
  change_model_fn(model_name)
 
205
  if isinstance(pipeline, FluxConceptrolPipeline):
206
  images = pipeline(
207
  prompt=prompt,
 
120
  torch_dtype=torch.bfloat16,
121
  feature_extractor=None,
122
  safety_checker=None
123
+ ).to(device)
124
  pipeline = ConceptrolIPAdapterPlus(pipe, "", adapter_name, device, num_tokens=16)
125
  globals()["pipeline"] = pipeline
126
  else:
 
202
  ) -> np.ndarray:
203
  global pipeline
204
  change_model_fn(model_name)
205
+ print(image)
206
  if isinstance(pipeline, FluxConceptrolPipeline):
207
  images = pipeline(
208
  prompt=prompt,