doevent committed
Commit 129447d · verified · 1 Parent(s): 95de845

Update app.py

Files changed (1)
  1. app.py +30 -9
app.py CHANGED
@@ -5,11 +5,22 @@ import spaces
 import torch
 from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel, FluxPipeline
 from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
+from huggingface_hub import hf_hub_download
 
 dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-pipe = FluxPipeline.from_pretrained("sayakpaul/FLUX.1-merged", torch_dtype=torch.bfloat16).to(device)
+repo_name = "ByteDance/Hyper-SD"
+ckpt_name = "Hyper-FLUX.1-dev-8steps-lora.safetensors"
+hyper_lora = hf_hub_download(repo_name, ckpt_name)
+
+pipe = FluxPipeline.from_pretrained(base_model_id, token="xxx")
+pipe.load_lora_weights(hf_hub_download(repo_name, ckpt_name))
+pipe.fuse_lora(lora_scale=0.125)
+pipe.to("cuda", dtype=torch.float16)
+
+
+# pipe = FluxPipeline.from_pretrained("sayakpaul/FLUX.1-merged", torch_dtype=torch.bfloat16).to(device)
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
@@ -20,14 +31,24 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidan
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     generator = torch.Generator().manual_seed(seed)
-    image = pipe(
-        prompt = prompt,
-        width = width,
-        height = height,
-        num_inference_steps = num_inference_steps,
-        generator = generator,
-        guidance_scale=guidance_scale
-    ).images[0]
+    # image = pipe(
+    #     prompt = prompt,
+    #     width = width,
+    #     height = height,
+    #     num_inference_steps = num_inference_steps,
+    #     generator = generator,
+    #     guidance_scale=guidance_scale
+    # ).images[0]
+
+    image = pipe(prompt=prompt,
+                 num_inference_steps=num_inference_steps,
+                 guidance_scale=guidance_scale,
+                 height=height,
+                 width=width,
+                 max_sequence_length=256,
+                 num_inference_steps=8,
+                 generator = generator,
+                 guidance_scale=guidance_scale).images[0]
     return image, seed
 
 examples = [
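
For context, a minimal self-contained sketch of what this change sets up: download the ByteDance/Hyper-SD 8-step LoRA, fuse it into a FLUX pipeline at scale 0.125, and run an 8-step generation. The base model id, the HF_TOKEN environment variable, the example prompt, and the guidance value below are illustrative assumptions, not part of the commit (the committed code references an undefined base_model_id and a redacted token). The duplicated num_inference_steps and guidance_scale keywords from the committed call are also dropped here, since Python rejects a repeated keyword argument in a call as a syntax error.

    # Sketch under the assumptions above; not the exact committed code.
    import os
    import torch
    from diffusers import FluxPipeline
    from huggingface_hub import hf_hub_download

    repo_name = "ByteDance/Hyper-SD"                        # LoRA repo from the commit
    ckpt_name = "Hyper-FLUX.1-dev-8steps-lora.safetensors"  # 8-step LoRA checkpoint from the commit

    # Assumed base model; the diff uses an undefined `base_model_id`.
    base_model_id = "black-forest-labs/FLUX.1-dev"

    pipe = FluxPipeline.from_pretrained(base_model_id, token=os.getenv("HF_TOKEN"))
    pipe.load_lora_weights(hf_hub_download(repo_name, ckpt_name))  # attach the Hyper-SD LoRA
    pipe.fuse_lora(lora_scale=0.125)                               # fuse at the scale used in the commit
    pipe.to("cuda", dtype=torch.float16)

    generator = torch.Generator().manual_seed(42)
    image = pipe(
        prompt="a photo of an astronaut riding a horse",  # example prompt; the app passes user input
        num_inference_steps=8,      # fixed 8 steps, matching the 8-step LoRA
        guidance_scale=3.5,         # example value; the app passes its slider value
        height=1024,
        width=1024,
        max_sequence_length=256,
        generator=generator,
    ).images[0]
    image.save("out.png")

The fused lora_scale=0.125 and the fixed step count of 8 come straight from the commit and match the checkpoint name, trading the previous merged-model setup for the distilled Hyper-SD schedule.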