cocktailpeanut commited on
Commit
64bcce2
·
1 Parent(s): 8c936a5
Files changed (2) hide show
  1. app.py +3 -1
  2. requirements.txt +2 -1
app.py CHANGED
@@ -4,18 +4,20 @@ from PIL import Image, ImageDraw, ImageFont
4
  from src.condition import Condition
5
  from diffusers.pipelines import FluxPipeline
6
  import numpy as np
 
7
 
8
  from src.generate import seed_everything, generate
9
 
10
  pipe = None
11
 
 
12
 
13
  def init_pipeline():
14
  global pipe
15
  pipe = FluxPipeline.from_pretrained(
16
  "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
17
  )
18
- pipe = pipe.to("cuda")
19
  pipe.load_lora_weights(
20
  "Yuanshi/OminiControl",
21
  weight_name=f"omini/subject_512.safetensors",
 
4
  from src.condition import Condition
5
  from diffusers.pipelines import FluxPipeline
6
  import numpy as np
7
+ import devicetorch
8
 
9
  from src.generate import seed_everything, generate
10
 
11
  pipe = None
12
 
13
+ device = devicetorch.get(torch)
14
 
15
  def init_pipeline():
16
  global pipe
17
  pipe = FluxPipeline.from_pretrained(
18
  "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
19
  )
20
+ pipe = pipe.to(device)
21
  pipe.load_lora_weights(
22
  "Yuanshi/OminiControl",
23
  weight_name=f"omini/subject_512.safetensors",
requirements.txt CHANGED
@@ -3,4 +3,5 @@ diffusers
3
  peft
4
  opencv-python
5
  protobuf
6
- sentencepiece
 
 
3
  peft
4
  opencv-python
5
  protobuf
6
+ sentencepiece
7
+ gradio