silencer107 commited on
Commit
0a9e6a7
·
verified ·
1 Parent(s): 708f19a

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +8 -10
src/pipeline.py CHANGED
@@ -8,21 +8,19 @@ from diffusers import FluxPipeline, AutoencoderTiny
8
 
9
  Pipeline = None
10
 
11
- pipeline = FluxPipeline.from_pretrained(
12
- "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
13
- )
14
- pipeline.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16)
15
- pipeline.enable_sequential_cpu_offload()
 
16
 
17
  for _ in range(2):
18
- empty_cache()
19
  pipeline(prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
20
- return pipeline
21
-
22
 
 
23
 
24
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
25
- empty_cache()
26
  generator = Generator("cuda").manual_seed(request.seed)
27
  image=pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
28
- return(image)
 
8
 
9
  Pipeline = None
10
 
11
+
12
+ def load_pipeline() -> Pipeline:
13
+ pipeline = FluxPipeline.from_pretrained(
14
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
15
+ pipeline.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16)
16
+ pipeline.enable_sequential_cpu_offload()
17
 
18
  for _ in range(2):
 
19
  pipeline(prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
 
 
20
 
21
+ return pipeline
22
 
23
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
 
24
  generator = Generator("cuda").manual_seed(request.seed)
25
  image=pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
26
+ return(image)