Spaces:
Running
on
A100
Running
on
A100
File size: 3,422 Bytes
9a8789a cb92d2b 9a8789a cb92d2b 9a8789a cb92d2b 592470d cb92d2b 9a8789a 4b58964 bdf4b6f 8a96a46 e5edfc8 cb92d2b 9a8789a a39f171 9a8789a a39f171 9a8789a cb92d2b decd923 cb92d2b 31dbff3 592470d 31dbff3 cb92d2b 31dbff3 592470d 23d11db 31dbff3 cb92d2b 31dbff3 592470d 23d11db 31dbff3 cb92d2b 592470d 23d11db 592470d cb92d2b 31dbff3 592470d 31dbff3 592470d 31dbff3 8a96a46 23d11db 8a96a46 23d11db 8a96a46 4b58964 a39f171 bdf4b6f a39f171 bdf4b6f a39f171 e5edfc8 592470d cb92d2b 9a8789a 488b360 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import argparse
import os
from typing import Annotated

from pydantic import BaseModel, ValidationError, field_validator
class Args(BaseModel):
    """Validated runtime configuration for the app.

    Built at import time from the parsed command-line namespace via
    ``Args.model_validate(vars(parser.parse_args()))``.
    """

    host: str
    port: int
    reload: bool
    max_queue_size: int
    timeout: float
    safety_checker: bool
    torch_compile: bool
    taesd: bool
    pipeline: str
    # TLS material; the two paths must be supplied together
    # (enforced by validate_ssl_keyfile below).
    ssl_certfile: str | None
    ssl_keyfile: str | None
    sfast: bool
    # Optional switches that fall back to False when absent from the input.
    onediff: bool = False
    compel: bool = False
    debug: bool = False
    pruna: bool = False

    def pretty_print(self) -> None:
        """Print every field as ``name: value``, surrounded by blank lines."""
        print("\n")
        for field, value in self.model_dump().items():
            print(f"{field}: {value}")
        print("\n")

    @field_validator("ssl_keyfile")
    @classmethod
    def validate_ssl_keyfile(cls, v: str | None, info) -> str | None:
        """Require ssl_certfile and ssl_keyfile to be provided together.

        ``ssl_certfile`` is declared before ``ssl_keyfile``, so its already
        validated value is available in ``info.data`` when this runs.
        Empty strings count as "not provided".

        Returns:
            The keyfile value, unchanged.

        Raises:
            ValueError: if only one of the two paths is provided.
        """
        ssl_certfile = info.data.get("ssl_certfile")
        if ssl_certfile and not v:
            raise ValueError(
                "If ssl_certfile is provided, ssl_keyfile must also be provided"
            )
        # A private key without its certificate cannot configure TLS either;
        # reject it here instead of letting it fail later at server start-up.
        if v and not ssl_certfile:
            raise ValueError(
                "If ssl_keyfile is provided, ssl_certfile must also be provided"
            )
        return v
# Values accepted (case-insensitively) as "on" for boolean env flags.
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})


def env_flag(name: str, default: bool = False) -> bool:
    """Read a boolean feature flag from environment variable *name*.

    ``1``, ``true``, ``yes`` and ``on`` (any letter case, surrounding
    whitespace ignored) mean True; any other value means False. An unset
    variable yields *default*.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    # Case-insensitive so that e.g. SAFETY_CHECKER=true or =1 is honoured,
    # not only the literal string "True".
    return raw.strip().lower() in _TRUTHY_ENV_VALUES


# Defaults that may be supplied through the environment; CLI flags override.
MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
TIMEOUT = float(os.environ.get("TIMEOUT", 0))
SAFETY_CHECKER = env_flag("SAFETY_CHECKER")
TORCH_COMPILE = env_flag("TORCH_COMPILE")
USE_TAESD = env_flag("USE_TAESD")
default_host = os.getenv("HOST", "0.0.0.0")
default_port = int(os.getenv("PORT", "7860"))

parser = argparse.ArgumentParser(description="Run the app")
parser.add_argument("--host", type=str, default=default_host, help="Host address")
parser.add_argument("--port", type=int, default=default_port, help="Port number")
parser.add_argument("--reload", action="store_true", help="Reload code on change")
parser.add_argument(
    "--max-queue-size",
    dest="max_queue_size",
    type=int,
    default=MAX_QUEUE_SIZE,
    help="Max Queue Size",
)
parser.add_argument("--timeout", type=float, default=TIMEOUT, help="Timeout")
# Env-backed switches use BooleanOptionalAction: `--flag` still enables them,
# and `--no-flag` lets a single run turn off a feature the environment enabled
# (impossible with store_true once the default is True).
parser.add_argument(
    "--safety-checker",
    dest="safety_checker",
    action=argparse.BooleanOptionalAction,
    default=SAFETY_CHECKER,
    help="Safety Checker",
)
parser.add_argument(
    "--torch-compile",
    dest="torch_compile",
    action=argparse.BooleanOptionalAction,
    default=TORCH_COMPILE,
    help="Torch Compile",
)
parser.add_argument(
    "--taesd",
    dest="taesd",
    action=argparse.BooleanOptionalAction,
    default=USE_TAESD,
    help="Use Tiny Autoencoder",
)
parser.add_argument(
    "--pipeline",
    type=str,
    default="txt2img",
    help="Pipeline to use",
)
parser.add_argument(
    "--ssl-certfile",
    dest="ssl_certfile",
    type=str,
    default=None,
    help="SSL certfile",
)
parser.add_argument(
    "--ssl-keyfile",
    dest="ssl_keyfile",
    type=str,
    default=None,
    help="SSL keyfile",
)
parser.add_argument(
    "--debug",
    action="store_true",
    default=False,
    help="Debug",
)
parser.add_argument(
    "--compel",
    action="store_true",
    default=False,
    help="Compel",
)
parser.add_argument(
    "--sfast",
    action="store_true",
    default=False,
    help="Enable Stable Fast",
)
parser.add_argument(
    "--onediff",
    action="store_true",
    default=False,
    help="Enable OneDiff",
)
parser.add_argument(
    "--pruna",
    action="store_true",
    default=False,
    help="Enable Pruna",
)
try:
    config = Args.model_validate(vars(parser.parse_args()))
except ValidationError as err:
    # Report invalid option combinations (e.g. a TLS certfile without its
    # keyfile) as a normal CLI usage error -- usage line plus message on
    # stderr, exit status 2 -- instead of a raw import-time traceback.
    # parser.error() never returns, so `config` is always bound below.
    parser.error(str(err))
config.pretty_print()
|