Files changed (1) hide show
  1. models.py +242 -0
models.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from pydantic import BaseModel, Field, create_model
3
+ from typing import Any, Optional
4
+ from typing_extensions import Literal
5
+ from inflection import underscore
6
+ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
7
+ from modules.shared import sd_upscalers, opts, parser
8
+ from typing import Dict, List
9
+
10
+ API_NOT_ALLOWED = [
11
+ "self",
12
+ "kwargs",
13
+ "sd_model",
14
+ "outpath_samples",
15
+ "outpath_grids",
16
+ "sampler_index",
17
+ "do_not_save_samples",
18
+ "do_not_save_grid",
19
+ "extra_generation_params",
20
+ "overlay_images",
21
+ "do_not_reload_embeddings",
22
+ "seed_enable_extras",
23
+ "prompt_for_display",
24
+ "sampler_noise_scheduler_override",
25
+ "ddim_discretize"
26
+ ]
27
+
28
+ class ModelDef(BaseModel):
29
+ """Assistance Class for Pydantic Dynamic Model Generation"""
30
+
31
+ field: str
32
+ field_alias: str
33
+ field_type: Any
34
+ field_value: Any
35
+ field_exclude: bool = False
36
+
37
+
38
+ class PydanticModelGenerator:
39
+ """
40
+ Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
41
+ source_data is a snapshot of the default values produced by the class
42
+ params are the names of the actual keys required by __init__
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ model_name: str = None,
48
+ class_instance = None,
49
+ additional_fields = None,
50
+ ):
51
+ def field_type_generator(k, v):
52
+ # field_type = str if not overrides.get(k) else overrides[k]["type"]
53
+ # print(k, v.annotation, v.default)
54
+ field_type = v.annotation
55
+
56
+ return Optional[field_type]
57
+
58
+ def merge_class_params(class_):
59
+ all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
60
+ parameters = {}
61
+ for classes in all_classes:
62
+ parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
63
+ return parameters
64
+
65
+
66
+ self._model_name = model_name
67
+ self._class_data = merge_class_params(class_instance)
68
+
69
+ self._model_def = [
70
+ ModelDef(
71
+ field=underscore(k),
72
+ field_alias=k,
73
+ field_type=field_type_generator(k, v),
74
+ field_value=v.default
75
+ )
76
+ for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
77
+ ]
78
+
79
+ for fields in additional_fields:
80
+ self._model_def.append(ModelDef(
81
+ field=underscore(fields["key"]),
82
+ field_alias=fields["key"],
83
+ field_type=fields["type"],
84
+ field_value=fields["default"],
85
+ field_exclude=fields["exclude"] if "exclude" in fields else False))
86
+
87
+ def generate_model(self):
88
+ """
89
+ Creates a pydantic BaseModel
90
+ from the json and overrides provided at initialization
91
+ """
92
+ fields = {
93
+ d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
94
+ }
95
+ DynamicModel = create_model(self._model_name, **fields)
96
+ DynamicModel.__config__.allow_population_by_field_name = True
97
+ DynamicModel.__config__.allow_mutation = True
98
+ return DynamicModel
99
+
100
+ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
101
+ "StableDiffusionProcessingTxt2Img",
102
+ StableDiffusionProcessingTxt2Img,
103
+ [{"key": "sampler_index", "type": str, "default": "Euler"}]
104
+ ).generate_model()
105
+
106
+ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
107
+ "StableDiffusionProcessingImg2Img",
108
+ StableDiffusionProcessingImg2Img,
109
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
110
+ ).generate_model()
111
+
112
+ class TextToImageResponse(BaseModel):
113
+ images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
114
+ parameters: dict
115
+ info: str
116
+
117
+ class ImageToImageResponse(BaseModel):
118
+ images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
119
+ parameters: dict
120
+ info: str
121
+
122
+ class ExtrasBaseRequest(BaseModel):
123
+ resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
124
+ show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
125
+ gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
126
+ codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
127
+ codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
128
+ upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
129
+ upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
130
+ upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
131
+ upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?")
132
+ upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
133
+ upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
134
+ extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
135
+ upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
136
+
137
+ class ExtraBaseResponse(BaseModel):
138
+ html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
139
+
140
+ class ExtrasSingleImageRequest(ExtrasBaseRequest):
141
+ image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
142
+
143
+ class ExtrasSingleImageResponse(ExtraBaseResponse):
144
+ image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
145
+
146
+ class FileData(BaseModel):
147
+ data: str = Field(title="File data", description="Base64 representation of the file")
148
+ name: str = Field(title="File name")
149
+
150
+ class ExtrasBatchImagesRequest(ExtrasBaseRequest):
151
+ imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
152
+
153
+ class ExtrasBatchImagesResponse(ExtraBaseResponse):
154
+ images: List[str] = Field(title="Images", description="The generated images in base64 format.")
155
+
156
+ class PNGInfoRequest(BaseModel):
157
+ image: str = Field(title="Image", description="The base64 encoded PNG image")
158
+
159
+ class PNGInfoResponse(BaseModel):
160
+ info: str = Field(title="Image info", description="A string with all the info the image had")
161
+
162
+ class ProgressRequest(BaseModel):
163
+ skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
164
+
165
+ class ProgressResponse(BaseModel):
166
+ progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
167
+ eta_relative: float = Field(title="ETA in secs")
168
+ state: dict = Field(title="State", description="The current state snapshot")
169
+ current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
170
+
171
+ class InterrogateRequest(BaseModel):
172
+ image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
173
+ model: str = Field(default="clip", title="Model", description="The interrogate model used.")
174
+
175
+ class InterrogateResponse(BaseModel):
176
+ caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
177
+
178
+ fields = {}
179
+ for key, metadata in opts.data_labels.items():
180
+ value = opts.data.get(key)
181
+ optType = opts.typemap.get(type(metadata.default), type(value))
182
+
183
+ if (metadata is not None):
184
+ fields.update({key: (Optional[optType], Field(
185
+ default=metadata.default ,description=metadata.label))})
186
+ else:
187
+ fields.update({key: (Optional[optType], Field())})
188
+
189
+ OptionsModel = create_model("Options", **fields)
190
+
191
+ flags = {}
192
+ _options = vars(parser)['_option_string_actions']
193
+ for key in _options:
194
+ if(_options[key].dest != 'help'):
195
+ flag = _options[key]
196
+ _type = str
197
+ if _options[key].default is not None: _type = type(_options[key].default)
198
+ flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
199
+
200
+ FlagsModel = create_model("Flags", **flags)
201
+
202
+ class SamplerItem(BaseModel):
203
+ name: str = Field(title="Name")
204
+ aliases: List[str] = Field(title="Aliases")
205
+ options: Dict[str, str] = Field(title="Options")
206
+
207
+ class UpscalerItem(BaseModel):
208
+ name: str = Field(title="Name")
209
+ model_name: Optional[str] = Field(title="Model Name")
210
+ model_path: Optional[str] = Field(title="Path")
211
+ model_url: Optional[str] = Field(title="URL")
212
+
213
+ class SDModelItem(BaseModel):
214
+ title: str = Field(title="Title")
215
+ model_name: str = Field(title="Model Name")
216
+ hash: str = Field(title="Hash")
217
+ filename: str = Field(title="Filename")
218
+ config: str = Field(title="Config file")
219
+
220
+ class HypernetworkItem(BaseModel):
221
+ name: str = Field(title="Name")
222
+ path: Optional[str] = Field(title="Path")
223
+
224
+ class FaceRestorerItem(BaseModel):
225
+ name: str = Field(title="Name")
226
+ cmd_dir: Optional[str] = Field(title="Path")
227
+
228
+ class RealesrganItem(BaseModel):
229
+ name: str = Field(title="Name")
230
+ path: Optional[str] = Field(title="Path")
231
+ scale: Optional[int] = Field(title="Scale")
232
+
233
+ class PromptStyleItem(BaseModel):
234
+ name: str = Field(title="Name")
235
+ prompt: Optional[str] = Field(title="Prompt")
236
+ negative_prompt: Optional[str] = Field(title="Negative Prompt")
237
+
238
+ class ArtistItem(BaseModel):
239
+ name: str = Field(title="Name")
240
+ score: float = Field(title="Score")
241
+ category: str = Field(title="Category")
242
+