eienmojiki committed
Commit f549a1f · 1 Parent(s): 8ece4a0
Files changed (1)
  1. app.py +63 -4
app.py CHANGED
@@ -1,7 +1,66 @@
+import os
+import random
+from typing import Callable, Dict, Optional, Tuple
+
 import gradio as gr
+import numpy as np
+import PIL.Image
+import spaces
+import torch
 
-def greet(name):
-    return "Hello " + name + "!!"
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, StableDiffusionXLPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
 
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+MODEL = "eienmojiki/Starry-XL-v5.2"
+HF_TOKEN = os.getenv("HF_TOKEN")
+MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", "512"))
+MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
+MAX_SEED = np.iinfo(np.int32).max
+
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    return seed
+
+def seed_everything(seed: int) -> torch.Generator:
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    generator = torch.Generator()
+    generator.manual_seed(seed)
+    return generator
+
+def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Callable]:
+    scheduler_factory_map = {
+        "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(
+            scheduler_config, use_karras_sigmas=True
+        ),
+        "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(
+            scheduler_config, use_karras_sigmas=True
+        ),
+        "DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(
+            scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
+        ),
+        "Euler": lambda: EulerDiscreteScheduler.from_config(scheduler_config),
+        "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(scheduler_config),
+        "DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
+    }
+    return scheduler_factory_map.get(name, lambda: None)()
+
+def load_pipeline(model_name):
+    pipe = StableDiffusionXLPipeline.from_pretrained(
+        model_name,
+        torch_dtype=torch.float16,
+        custom_pipeline="lpw_stable_diffusion_xl",
+        use_safetensors=True,
+        add_watermarker=False,
+        use_auth_token=HF_TOKEN,
+    )
+    pipe.to(device)
+    return pipe