File size: 3,422 Bytes
9a8789a
cb92d2b
 
9a8789a
cb92d2b
 
9a8789a
cb92d2b
 
 
 
 
 
 
592470d
cb92d2b
9a8789a
 
4b58964
bdf4b6f
8a96a46
 
e5edfc8
cb92d2b
9a8789a
a39f171
9a8789a
a39f171
 
 
9a8789a
 
 
 
 
 
 
 
 
 
 
cb92d2b
 
 
 
 
decd923
cb92d2b
 
 
 
 
 
 
 
31dbff3
592470d
31dbff3
 
 
cb92d2b
 
 
31dbff3
592470d
23d11db
31dbff3
 
cb92d2b
 
31dbff3
592470d
23d11db
31dbff3
 
cb92d2b
 
592470d
 
23d11db
592470d
 
cb92d2b
 
 
 
 
 
31dbff3
 
592470d
31dbff3
 
 
 
 
 
592470d
31dbff3
 
 
 
8a96a46
 
23d11db
8a96a46
 
 
 
 
23d11db
8a96a46
 
 
4b58964
 
 
 
 
 
a39f171
bdf4b6f
a39f171
 
bdf4b6f
a39f171
e5edfc8
 
 
 
 
 
592470d
cb92d2b
9a8789a
488b360
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import argparse
import os
from typing import Annotated

from pydantic import BaseModel, ValidationError, field_validator


class Args(BaseModel):
    """Validated application settings built from the parsed command-line arguments.

    Fields without a default must be supplied by the caller; the module
    builds instances via ``Args.model_validate(vars(parser.parse_args()))``.
    """

    host: str
    port: int
    reload: bool
    max_queue_size: int
    timeout: float
    safety_checker: bool
    torch_compile: bool
    taesd: bool
    pipeline: str
    # Declared before ssl_keyfile on purpose: the ssl_keyfile validator reads
    # it from ``info.data``, which only holds fields validated earlier.
    ssl_certfile: str | None
    ssl_keyfile: str | None
    sfast: bool
    onediff: bool = False
    compel: bool = False
    debug: bool = False
    pruna: bool = False

    def pretty_print(self) -> None:
        """Print every field as ``name: value``, framed by blank lines."""
        print("\n")
        for field, value in self.model_dump().items():
            print(f"{field}: {value}")
        print("\n")

    @field_validator("ssl_keyfile")
    @classmethod
    def validate_ssl_keyfile(cls, v: str | None, info) -> str | None:
        """Require ``ssl_certfile`` and ``ssl_keyfile`` to be supplied together.

        Args:
            v: The candidate ``ssl_keyfile`` value.
            info: Pydantic validation info; ``info.data`` holds the fields
                that were validated before this one.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: If only one of the certificate / key files is given.
        """
        ssl_certfile = info.data.get("ssl_certfile")
        if ssl_certfile and not v:
            raise ValueError(
                "If ssl_certfile is provided, ssl_keyfile must also be provided"
            )
        # A key without a certificate cannot form a TLS setup either. Only
        # check when ssl_certfile validated successfully (i.e. is present in
        # info.data); otherwise its own error is already being reported.
        if v and not ssl_certfile and "ssl_certfile" in info.data:
            raise ValueError(
                "If ssl_keyfile is provided, ssl_certfile must also be provided"
            )
        return v


_TRUTHY = frozenset({"1", "true", "yes", "on"})


def _env_bool(name: str, default: bool = False) -> bool:
    """Return the boolean value of environment variable *name*.

    ``1``, ``true``, ``yes`` and ``on`` (case-insensitive, surrounding
    whitespace ignored) count as true; any other non-empty value is false.
    An unset or blank variable yields *default*.
    """
    raw = os.environ.get(name, "").strip()
    if not raw:
        return default
    return raw.lower() in _TRUTHY


# Environment-derived defaults; explicit CLI flags below override them.
# A blank numeric variable falls back to the default instead of crashing
# int()/float() on an empty string.
MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE") or 0)
TIMEOUT = float(os.environ.get("TIMEOUT") or 0)
SAFETY_CHECKER = _env_bool("SAFETY_CHECKER")
TORCH_COMPILE = _env_bool("TORCH_COMPILE")
USE_TAESD = _env_bool("USE_TAESD")
default_host = os.getenv("HOST", "0.0.0.0")
default_port = int(os.getenv("PORT") or "7860")

parser = argparse.ArgumentParser(description="Run the app")
parser.add_argument("--host", type=str, default=default_host, help="Host address")
parser.add_argument("--port", type=int, default=default_port, help="Port number")
parser.add_argument("--reload", action="store_true", help="Reload code on change")
parser.add_argument(
    "--max-queue-size",
    dest="max_queue_size",
    type=int,
    default=MAX_QUEUE_SIZE,
    help="Max Queue Size",
)
parser.add_argument("--timeout", type=float, default=TIMEOUT, help="Timeout")
# Switches whose default may come from the environment use
# BooleanOptionalAction: with plain store_true, an env-enabled default could
# never be turned off from the command line. ``--no-<flag>`` now does that,
# while ``--<flag>`` keeps its previous meaning.
parser.add_argument(
    "--safety-checker",
    dest="safety_checker",
    action=argparse.BooleanOptionalAction,
    default=SAFETY_CHECKER,
    help="Safety Checker",
)
parser.add_argument(
    "--torch-compile",
    dest="torch_compile",
    action=argparse.BooleanOptionalAction,
    default=TORCH_COMPILE,
    help="Torch Compile",
)
parser.add_argument(
    "--taesd",
    dest="taesd",
    action=argparse.BooleanOptionalAction,
    default=USE_TAESD,
    help="Use Tiny Autoencoder",
)
parser.add_argument(
    "--pipeline",
    type=str,
    default="txt2img",
    help="Pipeline to use",
)
parser.add_argument(
    "--ssl-certfile",
    dest="ssl_certfile",
    type=str,
    default=None,
    help="SSL certfile",
)
parser.add_argument(
    "--ssl-keyfile",
    dest="ssl_keyfile",
    type=str,
    default=None,
    help="SSL keyfile",
)
# Opt-in switches with no environment override: plain store_true suffices.
parser.add_argument(
    "--debug",
    action="store_true",
    default=False,
    help="Debug",
)
parser.add_argument(
    "--compel",
    action="store_true",
    default=False,
    help="Compel",
)
parser.add_argument(
    "--sfast",
    action="store_true",
    default=False,
    help="Enable Stable Fast",
)
parser.add_argument(
    "--onediff",
    action="store_true",
    default=False,
    help="Enable OneDiff",
)
parser.add_argument(
    "--pruna",
    action="store_true",
    default=False,
    help="Enable Pruna",
)

# Executed at import time: parse sys.argv[1:] and validate the result into
# ``config``. A cross-field validation failure (e.g. an SSL certificate
# without its key) is reported through argparse as a usage error (exit
# status 2) rather than as a raw pydantic traceback.
try:
    config = Args.model_validate(vars(parser.parse_args()))
except ValidationError as exc:
    parser.error(str(exc))
config.pretty_print()