Boese0601 commited on
Commit
f3bd7d5
·
verified ·
1 Parent(s): 1c193eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -66
app.py CHANGED
@@ -9,44 +9,15 @@ from safetensors.torch import load_file
9
  from omegaconf import OmegaConf
10
 
11
  from image_datasets.dataset import image_resize
12
- def tensor_to_pil_image(in_image):
13
- tensor = in_image.squeeze(0)
14
- tensor = (tensor + 1) / 2
15
- tensor = tensor * 255
16
- numpy_array = tensor.permute(1, 2, 0).byte().numpy()
17
- pil_image = Image.fromarray(numpy_array)
18
- return pil_image
19
- # from src.flux.xflux_pipeline import XFluxSampler
20
  args = OmegaConf.load("inference_configs/inference.yaml")
21
- # is_schnell = args.model_name == "flux-schnell"
22
- # sampler = None
23
  device = torch.device("cuda")
24
  dtype = torch.bfloat16
25
- # dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
26
- # vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
27
- # t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype)
28
- # clip = load_clip("cpu").to(device, dtype=dtype)
29
- #test push
30
  @spaces.GPU
31
  def generate(image: Image.Image, edit_prompt: str):
32
  from src.flux.xflux_pipeline import XFluxSampler
33
 
34
 
35
-
36
-
37
- # vae.requires_grad_(False)
38
- # t5.requires_grad_(False)
39
- # clip.requires_grad_(False)
40
-
41
- # model_path = hf_hub_download(
42
- # repo_id="Boese0601/ByteMorpher",
43
- # filename="dit.safetensors",
44
- # use_auth_token=os.getenv("HF_TOKEN")
45
- # )
46
- # state_dict = load_file(model_path)
47
- # dit.load_state_dict(state_dict)
48
- # dit.eval()
49
- # dit.to(device, dtype=dtype)
50
 
51
  sampler = XFluxSampler(
52
  device = device,
@@ -56,42 +27,7 @@ def generate(image: Image.Image, edit_prompt: str):
56
  image_encoder=None,
57
  improj=None
58
  )
59
- # global sampler
60
- # device = torch.device("cuda")
61
- # dtype = torch.bfloat16
62
-
63
- # if sampler is None:
64
- # dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
65
- # vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
66
- # t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype)
67
- # clip = load_clip("cpu").to(device, dtype=dtype)
68
-
69
- # vae.requires_grad_(False)
70
- # t5.requires_grad_(False)
71
- # clip.requires_grad_(False)
72
-
73
- # model_path = hf_hub_download(
74
- # repo_id="Boese0601/ByteMorpher",
75
- # filename="dit.safetensors",
76
- # use_auth_token=os.getenv("HF_TOKEN")
77
- # )
78
- # state_dict = load_file(model_path)
79
- # dit.load_state_dict(state_dict)
80
- # dit.eval()
81
-
82
- # sampler = XFluxSampler(
83
- # clip=clip,
84
- # t5=t5,
85
- # ae=vae,
86
- # model=dit,
87
- # device=device,
88
- # ip_loaded=False,
89
- # spatial_condition=False,
90
- # clip_image_processor=None,
91
- # image_encoder=None,
92
- # improj=None
93
- # )
94
-
95
  img = image_resize(image, 512)
96
  w, h = img.size
97
  img = img.resize(((w // 32) * 32, (h // 32) * 32))
 
9
  from omegaconf import OmegaConf
10
 
11
  from image_datasets.dataset import image_resize
12
+
 
 
 
 
 
 
 
13
  args = OmegaConf.load("inference_configs/inference.yaml")
 
 
14
  device = torch.device("cuda")
15
  dtype = torch.bfloat16
 
 
 
 
 
16
  @spaces.GPU
17
  def generate(image: Image.Image, edit_prompt: str):
18
  from src.flux.xflux_pipeline import XFluxSampler
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  sampler = XFluxSampler(
23
  device = device,
 
27
  image_encoder=None,
28
  improj=None
29
  )
30
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  img = image_resize(image, 512)
32
  w, h = img.size
33
  img = img.resize(((w // 32) * 32, (h // 32) * 32))