radames commited on
Commit
100e61a
·
1 Parent(s): b5e95b8

type hints and Pydantic

Browse files
server/connection_manager.py CHANGED
@@ -1,12 +1,12 @@
1
- from typing import Dict, Union
2
  from uuid import UUID
3
  import asyncio
4
  from fastapi import WebSocket
5
  from starlette.websockets import WebSocketState
6
  import logging
7
- from types import SimpleNamespace
 
8
 
9
- Connections = Dict[UUID, Dict[str, Union[WebSocket, asyncio.Queue]]]
10
 
11
 
12
  class ServerFullException(Exception):
@@ -44,13 +44,13 @@ class ConnectionManager:
44
  def check_user(self, user_id: UUID) -> bool:
45
  return user_id in self.active_connections
46
 
47
- async def update_data(self, user_id: UUID, new_data: SimpleNamespace):
48
  user_session = self.active_connections.get(user_id)
49
  if user_session:
50
  queue = user_session["queue"]
51
  await queue.put(new_data)
52
 
53
- async def get_latest_data(self, user_id: UUID) -> SimpleNamespace:
54
  user_session = self.active_connections.get(user_id)
55
  if user_session:
56
  queue = user_session["queue"]
@@ -58,6 +58,7 @@ class ConnectionManager:
58
  return await queue.get()
59
  except asyncio.QueueEmpty:
60
  return None
 
61
 
62
  def delete_user(self, user_id: UUID):
63
  user_session = self.active_connections.pop(user_id, None)
@@ -86,7 +87,7 @@ class ConnectionManager:
86
  await websocket.close()
87
  self.delete_user(user_id)
88
 
89
- async def send_json(self, user_id: UUID, data: Dict):
90
  try:
91
  websocket = self.get_websocket(user_id)
92
  if websocket:
@@ -94,18 +95,22 @@ class ConnectionManager:
94
  except Exception as e:
95
  logging.error(f"Error: Send json: {e}")
96
 
97
- async def receive_json(self, user_id: UUID) -> Dict:
98
  try:
99
  websocket = self.get_websocket(user_id)
100
  if websocket:
101
  return await websocket.receive_json()
 
102
  except Exception as e:
103
  logging.error(f"Error: Receive json: {e}")
 
104
 
105
- async def receive_bytes(self, user_id: UUID) -> bytes:
106
  try:
107
  websocket = self.get_websocket(user_id)
108
  if websocket:
109
  return await websocket.receive_bytes()
 
110
  except Exception as e:
111
  logging.error(f"Error: Receive bytes: {e}")
 
 
 
1
  from uuid import UUID
2
  import asyncio
3
  from fastapi import WebSocket
4
  from starlette.websockets import WebSocketState
5
  import logging
6
+ from typing import Any, TypeVar
7
+ from util import ParamsModel
8
 
9
+ Connections = dict[UUID, dict[str, WebSocket | asyncio.Queue]]
10
 
11
 
12
  class ServerFullException(Exception):
 
44
  def check_user(self, user_id: UUID) -> bool:
45
  return user_id in self.active_connections
46
 
47
+ async def update_data(self, user_id: UUID, new_data: ParamsModel):
48
  user_session = self.active_connections.get(user_id)
49
  if user_session:
50
  queue = user_session["queue"]
51
  await queue.put(new_data)
52
 
53
+ async def get_latest_data(self, user_id: UUID) -> ParamsModel | None:
54
  user_session = self.active_connections.get(user_id)
55
  if user_session:
56
  queue = user_session["queue"]
 
58
  return await queue.get()
59
  except asyncio.QueueEmpty:
60
  return None
61
+ return None
62
 
63
  def delete_user(self, user_id: UUID):
64
  user_session = self.active_connections.pop(user_id, None)
 
87
  await websocket.close()
88
  self.delete_user(user_id)
89
 
90
+ async def send_json(self, user_id: UUID, data: dict):
91
  try:
92
  websocket = self.get_websocket(user_id)
93
  if websocket:
 
95
  except Exception as e:
96
  logging.error(f"Error: Send json: {e}")
97
 
98
+ async def receive_json(self, user_id: UUID) -> dict | None:
99
  try:
100
  websocket = self.get_websocket(user_id)
101
  if websocket:
102
  return await websocket.receive_json()
103
+ return None
104
  except Exception as e:
105
  logging.error(f"Error: Receive json: {e}")
106
+ return None
107
 
108
+ async def receive_bytes(self, user_id: UUID) -> bytes | None:
109
  try:
110
  websocket = self.get_websocket(user_id)
111
  if websocket:
112
  return await websocket.receive_bytes()
113
+ return None
114
  except Exception as e:
115
  logging.error(f"Error: Receive bytes: {e}")
116
+ return None
server/main.py CHANGED
@@ -10,30 +10,53 @@ import logging
10
  from config import config, Args
11
  from connection_manager import ConnectionManager, ServerFullException
12
  import uuid
 
13
  import time
14
- from types import SimpleNamespace
15
- from util import pil_to_frame, bytes_to_pil, is_firefox, get_pipeline_class
16
  from device import device, torch_dtype
17
  import asyncio
18
  import os
19
  import time
20
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  THROTTLE = 1.0 / 120
24
 
25
 
26
  class App:
27
- def __init__(self, config: Args, pipeline):
28
  self.args = config
29
- self.pipeline = pipeline
30
  self.app = FastAPI()
31
  self.conn_manager = ConnectionManager()
 
32
  if self.args.safety_checker:
33
  self.safety_checker = SafetyChecker(device=device.type)
34
  self.init_app()
35
 
36
- def init_app(self):
37
  self.app.add_middleware(
38
  CORSMiddleware,
39
  allow_origins=["*"],
@@ -43,7 +66,7 @@ class App:
43
  )
44
 
45
  @self.app.websocket("/api/ws/{user_id}")
46
- async def websocket_endpoint(user_id: uuid.UUID, websocket: WebSocket):
47
  try:
48
  await self.conn_manager.connect(
49
  user_id, websocket, self.args.max_queue_size
@@ -55,9 +78,9 @@ class App:
55
  await self.conn_manager.disconnect(user_id)
56
  logging.info(f"User disconnected: {user_id}")
57
 
58
- async def handle_websocket_data(user_id: uuid.UUID):
59
  if not self.conn_manager.check_user(user_id):
60
- return HTTPException(status_code=404, detail="User not found")
61
  last_time = time.time()
62
  try:
63
  while True:
@@ -75,19 +98,29 @@ class App:
75
  await self.conn_manager.disconnect(user_id)
76
  return
77
  data = await self.conn_manager.receive_json(user_id)
 
 
 
78
  if data["status"] == "next_frame":
79
- info = pipeline.Info()
80
- params = await self.conn_manager.receive_json(user_id)
81
- params = pipeline.InputParams(**params)
82
- params = SimpleNamespace(**params.dict())
 
 
 
83
  if info.input_mode == "image":
84
  image_data = await self.conn_manager.receive_bytes(user_id)
85
- if len(image_data) == 0:
86
  await self.conn_manager.send_json(
87
  user_id, {"status": "send_frame"}
88
  )
89
  continue
90
- params.image = bytes_to_pil(image_data)
 
 
 
 
91
 
92
  await self.conn_manager.update_data(user_id, params)
93
  await self.conn_manager.send_json(user_id, {"status": "wait"})
@@ -97,29 +130,32 @@ class App:
97
  await self.conn_manager.disconnect(user_id)
98
 
99
  @self.app.get("/api/queue")
100
- async def get_queue_size():
101
  queue_size = self.conn_manager.get_user_count()
102
  return JSONResponse({"queue_size": queue_size})
103
 
104
  @self.app.get("/api/stream/{user_id}")
105
- async def stream(user_id: uuid.UUID, request: Request):
106
  try:
107
-
108
- async def generate():
109
- last_params = SimpleNamespace()
110
  while True:
111
  last_time = time.time()
112
  await self.conn_manager.send_json(
113
  user_id, {"status": "send_frame"}
114
  )
115
  params = await self.conn_manager.get_latest_data(user_id)
116
- if params.__dict__ == last_params.__dict__ or params is None:
 
 
 
117
  await asyncio.sleep(THROTTLE)
118
  continue
119
- last_params: SimpleNamespace = params
120
- image = pipeline.predict(params)
 
121
 
122
- if self.args.safety_checker:
123
  image, has_nsfw_concept = self.safety_checker(image)
124
  if has_nsfw_concept:
125
  image = None
@@ -141,23 +177,24 @@ class App:
141
  )
142
  except Exception as e:
143
  logging.error(f"Streaming Error: {e}, {user_id} ")
144
- return HTTPException(status_code=404, detail="User not found")
145
 
146
  # route to setup frontend
147
  @self.app.get("/api/settings")
148
- async def settings():
149
- info_schema = pipeline.Info.schema()
150
- info = pipeline.Info()
151
- if info.page_content:
 
152
  page_content = markdown2.markdown(info.page_content)
153
 
154
- input_params = pipeline.InputParams.schema()
155
  return JSONResponse(
156
  {
157
  "info": info_schema,
158
  "input_params": input_params,
159
  "max_queue_size": self.args.max_queue_size,
160
- "page_content": page_content if info.page_content else "",
161
  }
162
  )
163
 
@@ -169,17 +206,35 @@ class App:
169
  )
170
 
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  print(f"Device: {device}")
173
  print(f"torch_dtype: {torch_dtype}")
 
174
  pipeline_class = get_pipeline_class(config.pipeline)
175
- pipeline = pipeline_class(config, device, torch_dtype)
176
- app = App(config, pipeline).app
 
177
 
178
  if __name__ == "__main__":
179
  import uvicorn
180
 
 
 
181
  uvicorn.run(
182
- "main:app",
183
  host=config.host,
184
  port=config.port,
185
  reload=config.reload,
 
10
  from config import config, Args
11
  from connection_manager import ConnectionManager, ServerFullException
12
  import uuid
13
+ from uuid import UUID
14
  import time
15
+ from typing import Any, Protocol, TypeVar, runtime_checkable, cast
16
+ from util import pil_to_frame, bytes_to_pil, is_firefox, get_pipeline_class, ParamsModel
17
  from device import device, torch_dtype
18
  import asyncio
19
  import os
20
  import time
21
  import torch
22
+ from pydantic import BaseModel, create_model
23
+
24
+ @runtime_checkable
25
+ class BasePipeline(Protocol):
26
+ class Info:
27
+ @classmethod
28
+ def schema(cls) -> dict[str, Any]:
29
+ ...
30
+ page_content: str | None
31
+ input_mode: str
32
+
33
+ class InputParams(ParamsModel):
34
+ @classmethod
35
+ def schema(cls) -> dict[str, Any]:
36
+ ...
37
+
38
+ def dict(self) -> dict[str, Any]:
39
+ ...
40
+
41
+ def predict(self, params: ParamsModel) -> Image.Image | None:
42
+ ...
43
 
44
 
45
  THROTTLE = 1.0 / 120
46
 
47
 
48
  class App:
49
+ def __init__(self, config: Args, pipeline_instance: BasePipeline):
50
  self.args = config
51
+ self.pipeline = pipeline_instance
52
  self.app = FastAPI()
53
  self.conn_manager = ConnectionManager()
54
+ self.safety_checker: SafetyChecker | None = None
55
  if self.args.safety_checker:
56
  self.safety_checker = SafetyChecker(device=device.type)
57
  self.init_app()
58
 
59
+ def init_app(self) -> None:
60
  self.app.add_middleware(
61
  CORSMiddleware,
62
  allow_origins=["*"],
 
66
  )
67
 
68
  @self.app.websocket("/api/ws/{user_id}")
69
+ async def websocket_endpoint(user_id: UUID, websocket: WebSocket) -> None:
70
  try:
71
  await self.conn_manager.connect(
72
  user_id, websocket, self.args.max_queue_size
 
78
  await self.conn_manager.disconnect(user_id)
79
  logging.info(f"User disconnected: {user_id}")
80
 
81
+ async def handle_websocket_data(user_id: UUID) -> None:
82
  if not self.conn_manager.check_user(user_id):
83
+ raise HTTPException(status_code=404, detail="User not found")
84
  last_time = time.time()
85
  try:
86
  while True:
 
98
  await self.conn_manager.disconnect(user_id)
99
  return
100
  data = await self.conn_manager.receive_json(user_id)
101
+ if data is None:
102
+ continue
103
+
104
  if data["status"] == "next_frame":
105
+ info = self.pipeline.Info()
106
+ params_data = await self.conn_manager.receive_json(user_id)
107
+ if params_data is None:
108
+ continue
109
+
110
+ params = self.pipeline.InputParams.model_validate(params_data)
111
+
112
  if info.input_mode == "image":
113
  image_data = await self.conn_manager.receive_bytes(user_id)
114
+ if image_data is None or len(image_data) == 0:
115
  await self.conn_manager.send_json(
116
  user_id, {"status": "send_frame"}
117
  )
118
  continue
119
+
120
+ # Create a new Pydantic model with the image field
121
+ params_dict = params.model_dump()
122
+ params_dict["image"] = bytes_to_pil(image_data)
123
+ params = self.pipeline.InputParams.model_validate(params_dict)
124
 
125
  await self.conn_manager.update_data(user_id, params)
126
  await self.conn_manager.send_json(user_id, {"status": "wait"})
 
130
  await self.conn_manager.disconnect(user_id)
131
 
132
  @self.app.get("/api/queue")
133
+ async def get_queue_size() -> JSONResponse:
134
  queue_size = self.conn_manager.get_user_count()
135
  return JSONResponse({"queue_size": queue_size})
136
 
137
  @self.app.get("/api/stream/{user_id}")
138
+ async def stream(user_id: UUID, request: Request) -> StreamingResponse:
139
  try:
140
+ async def generate() -> bytes:
141
+ last_params: ParamsModel | None = None
 
142
  while True:
143
  last_time = time.time()
144
  await self.conn_manager.send_json(
145
  user_id, {"status": "send_frame"}
146
  )
147
  params = await self.conn_manager.get_latest_data(user_id)
148
+
149
+ if (params is None or
150
+ (last_params is not None and
151
+ params.model_dump() == last_params.model_dump())):
152
  await asyncio.sleep(THROTTLE)
153
  continue
154
+
155
+ last_params = params
156
+ image = self.pipeline.predict(params)
157
 
158
+ if self.args.safety_checker and self.safety_checker is not None and image is not None:
159
  image, has_nsfw_concept = self.safety_checker(image)
160
  if has_nsfw_concept:
161
  image = None
 
177
  )
178
  except Exception as e:
179
  logging.error(f"Streaming Error: {e}, {user_id} ")
180
+ raise HTTPException(status_code=404, detail="User not found")
181
 
182
  # route to setup frontend
183
  @self.app.get("/api/settings")
184
+ async def settings() -> JSONResponse:
185
+ info_schema = self.pipeline.Info.schema()
186
+ info = self.pipeline.Info()
187
+ page_content = ""
188
+ if hasattr(info, 'page_content') and info.page_content:
189
  page_content = markdown2.markdown(info.page_content)
190
 
191
+ input_params = self.pipeline.InputParams.schema()
192
  return JSONResponse(
193
  {
194
  "info": info_schema,
195
  "input_params": input_params,
196
  "max_queue_size": self.args.max_queue_size,
197
+ "page_content": page_content,
198
  }
199
  )
200
 
 
206
  )
207
 
208
 
209
+ # def create_app(config):
210
+ # print(f"Device: {device}")
211
+ # print(f"torch_dtype: {torch_dtype}")
212
+
213
+ # # Create pipeline once
214
+ # pipeline_class = get_pipeline_class(config.pipeline)
215
+ # pipeline_instance = pipeline_class(config, device, torch_dtype)
216
+
217
+ # # Pass the existing pipeline instance to App
218
+ # app = App(config, pipeline_instance).app
219
+ # return app
220
+
221
+
222
+ # Create app instance at module level
223
  print(f"Device: {device}")
224
  print(f"torch_dtype: {torch_dtype}")
225
+
226
  pipeline_class = get_pipeline_class(config.pipeline)
227
+ pipeline_instance = pipeline_class(config, device, torch_dtype)
228
+ app = App(config, pipeline_instance).app # This creates the FastAPI app instance
229
+
230
 
231
  if __name__ == "__main__":
232
  import uvicorn
233
 
234
+ # app = create_app(config) # Create the app once
235
+
236
  uvicorn.run(
237
+ app,
238
  host=config.host,
239
  port=config.port,
240
  reload=config.reload,
server/pipelines/img2imgFlux.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from optimum.quanto import freeze, qfloat8, quantize
4
+ from transformers.modeling_utils import PreTrainedModel
5
+
6
+ from diffusers import (
7
+ FlowMatchEulerDiscreteScheduler,
8
+ AutoencoderKL,
9
+ AutoencoderTiny,
10
+ FluxImg2ImgPipeline,
11
+ FluxPipeline,
12
+ )
13
+
14
+ from diffusers import (
15
+ FluxImg2ImgPipeline,
16
+ FluxPipeline,
17
+ FluxTransformer2DModel,
18
+ GGUFQuantizationConfig,
19
+ )
20
+
21
+ try:
22
+ import intel_extension_for_pytorch as ipex # type: ignore
23
+ except:
24
+ pass
25
+
26
+ import psutil
27
+ from config import Args
28
+ from pydantic import BaseModel, Field
29
+ from PIL import Image
30
+ from pathlib import Path
31
+ import math
32
+ import gc
33
+
34
+
35
+ # model_path = "black-forest-labs/FLUX.1-dev"
36
+ model_path = "black-forest-labs/FLUX.1-schnell"
37
+ base_model_path = "black-forest-labs/FLUX.1-schnell"
38
+ taesd_path = "madebyollin/taef1"
39
+ subfolder = "transformer"
40
+ transformer_path = model_path
41
+ models_path = Path("models")
42
+
43
+ default_prompt = "close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm summilux"
44
+ default_negative_prompt = "blurry, low quality, render, 3D, oversaturated"
45
+ page_content = """
46
+ <h1 class="text-3xl font-bold">Real-Time FLUX</h1>
47
+
48
+ """
49
+
50
+
51
+ def flush():
52
+ torch.cuda.empty_cache()
53
+ gc.collect()
54
+
55
+
56
+ class Pipeline:
57
+ class Info(BaseModel):
58
+ name: str = "img2img"
59
+ title: str = "Image-to-Image SDXL"
60
+ description: str = "Generates an image from a text prompt"
61
+ input_mode: str = "image"
62
+ page_content: str = page_content
63
+
64
+ class InputParams(BaseModel):
65
+ prompt: str = Field(
66
+ default_prompt,
67
+ title="Prompt",
68
+ field="textarea",
69
+ id="prompt",
70
+ )
71
+ seed: int = Field(
72
+ 2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
73
+ )
74
+ steps: int = Field(
75
+ 1, min=1, max=15, title="Steps", field="range", hide=True, id="steps"
76
+ )
77
+ width: int = Field(
78
+ 256, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
79
+ )
80
+ height: int = Field(
81
+ 256, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
82
+ )
83
+ strength: float = Field(
84
+ 0.5,
85
+ min=0.25,
86
+ max=1.0,
87
+ step=0.001,
88
+ title="Strength",
89
+ field="range",
90
+ hide=True,
91
+ id="strength",
92
+ )
93
+ guidance: float = Field(
94
+ 3.5,
95
+ min=0,
96
+ max=20,
97
+ step=0.001,
98
+ title="Guidance",
99
+ hide=True,
100
+ field="range",
101
+ id="guidance",
102
+ )
103
+
104
+ def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
105
+ # ckpt_path = (
106
+ # "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
107
+ # )
108
+ print("Loading model")
109
+ # ckpt_path: str = "https://huggingface.co/city96/FLUX.1-schnell-gguf/blob/main/flux1-schnell-Q6_K.gguf"
110
+ ckpt_path: str = "https://huggingface.co/city96/FLUX.1-schnell-gguf/blob/main/flux1-schnell-Q4_K_S.gguf"
111
+ transformer = FluxTransformer2DModel.from_single_file(
112
+ ckpt_path,
113
+ quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
114
+ torch_dtype=torch.bfloat16,
115
+ )
116
+
117
+ # else:
118
+ pipe = FluxImg2ImgPipeline.from_pretrained(
119
+ # "black-forest-labs/FLUX.1-dev",
120
+ "black-forest-labs/FLUX.1-Schnell",
121
+ transformer=transformer,
122
+ torch_dtype=torch.bfloat16,
123
+ )
124
+ if args.taesd:
125
+ pipe.vae = AutoencoderTiny.from_pretrained(
126
+ taesd_path, torch_dtype=torch.bfloat16, use_safetensors=True
127
+ )
128
+ # pipe.enable_model_cpu_offload()
129
+ pipe = pipe.to(device)
130
+
131
+ # pipe.enable_model_cpu_offload()
132
+
133
+ self.pipe = pipe
134
+ self.pipe.set_progress_bar_config(disable=True)
135
+
136
+ # vae = AutoencoderKL.from_pretrained(
137
+ # base_model_path, subfolder="vae", torch_dtype=torch_dtype
138
+ # )
139
+
140
+ def predict(self, params: "Pipeline.InputParams") -> Image.Image:
141
+ generator = torch.manual_seed(params.seed)
142
+ steps = params.steps
143
+ strength = params.strength
144
+ prompt = params.prompt
145
+ guidance = params.guidance
146
+
147
+ results = self.pipe(
148
+ image=params.image,
149
+ prompt=prompt,
150
+ generator=generator,
151
+ strength=strength,
152
+ num_inference_steps=steps,
153
+ guidance_scale=guidance,
154
+ width=params.width,
155
+ height=params.height,
156
+ )
157
+ return results.images[0]
server/requirements.txt CHANGED
@@ -15,9 +15,16 @@ xformers; sys_platform != 'darwin' or platform_machine != 'arm64'
15
  markdown2
16
  safetensors
17
  stable_fast @ https://github.com/chengzeyi/stable-fast/releases/download/nightly/stable_fast-1.0.5.dev20241127+torch230cu121-cp310-cp310-manylinux2014_x86_64.whl ; sys_platform != 'darwin' or platform_machine != 'arm64'
18
- oneflow @ https://github.com/siliconflow/oneflow_releases/releases/download/community_cu121/oneflow-0.9.1.dev20241114%2Bcu121-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl ; sys_platform != 'darwin' or platform_machine != 'arm64'
19
- onediff @ git+https://github.com/siliconflow/onediff.git@main#egg=onediff ; sys_platform != 'darwin' or platform_machine != 'arm64'
20
  setuptools
21
  mpmath==1.3.0
22
  numpy==1.*
23
- controlnet-aux
 
 
 
 
 
 
 
 
15
  markdown2
16
  safetensors
17
  stable_fast @ https://github.com/chengzeyi/stable-fast/releases/download/nightly/stable_fast-1.0.5.dev20241127+torch230cu121-cp310-cp310-manylinux2014_x86_64.whl ; sys_platform != 'darwin' or platform_machine != 'arm64'
18
+ #oneflow @ https://github.com/siliconflow/oneflow_releases/releases/download/community_cu121/oneflow-0.9.1.dev20241114%2Bcu121-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl ; sys_platform != 'darwin' or platform_machine != 'arm64'
19
+ #onediff @ git+https://github.com/siliconflow/onediff.git@main#egg=onediff ; sys_platform != 'darwin' or platform_machine != 'arm64'
20
  setuptools
21
  mpmath==1.3.0
22
  numpy==1.*
23
+ controlnet-aux
24
+ sentencepiece==0.2.0
25
+ optimum-quanto
26
+ gguf==0.13.0
27
+ pydantic>=2.7.0
28
+ types-Pillow
29
+ mypy
30
+ python-dotenv
server/util.py CHANGED
@@ -1,10 +1,28 @@
1
  from importlib import import_module
2
- from types import ModuleType
3
  from PIL import Image
4
  import io
 
5
 
6
 
7
- def get_pipeline_class(pipeline_name: str) -> ModuleType:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  try:
9
  module = import_module(f"pipelines.{pipeline_name}")
10
  except ModuleNotFoundError:
@@ -15,6 +33,10 @@ def get_pipeline_class(pipeline_name: str) -> ModuleType:
15
  if pipeline_class is None:
16
  raise ValueError(f"'Pipeline' class not found in module '{pipeline_name}'.")
17
 
 
 
 
 
18
  return pipeline_class
19
 
20
 
 
1
  from importlib import import_module
2
+ from typing import Any, TypeVar, Generic, TypeVar
3
  from PIL import Image
4
  import io
5
+ from pydantic import BaseModel, create_model, Field
6
 
7
 
8
+ TPipeline = TypeVar("TPipeline", bound=type[Any])
9
+ T = TypeVar('T')
10
+
11
+
12
+ class ParamsModel(BaseModel):
13
+ """Base model for pipeline parameters."""
14
+
15
+ @classmethod
16
+ def from_dict(cls, data: dict[str, Any]) -> 'ParamsModel':
17
+ """Create a model instance from dictionary data."""
18
+ return cls.model_validate(data)
19
+
20
+ def to_dict(self) -> dict[str, Any]:
21
+ """Convert model to dictionary."""
22
+ return self.model_dump()
23
+
24
+
25
+ def get_pipeline_class(pipeline_name: str) -> TPipeline:
26
  try:
27
  module = import_module(f"pipelines.{pipeline_name}")
28
  except ModuleNotFoundError:
 
33
  if pipeline_class is None:
34
  raise ValueError(f"'Pipeline' class not found in module '{pipeline_name}'.")
35
 
36
+ # Type check to ensure we're returning a class
37
+ if not isinstance(pipeline_class, type):
38
+ raise TypeError(f"'Pipeline' in module '{pipeline_name}' is not a class")
39
+
40
  return pipeline_class
41
 
42