radames commited on
Commit
2ab3299
·
1 Parent(s): 100e61a

simplify type annotations and remove unused TypeVars

Browse files
Files changed (3) hide show
  1. server/connection_manager.py +1 -1
  2. server/main.py +1 -1
  3. server/util.py +17 -4
server/connection_manager.py CHANGED
@@ -3,7 +3,7 @@ import asyncio
3
  from fastapi import WebSocket
4
  from starlette.websockets import WebSocketState
5
  import logging
6
- from typing import Any, TypeVar
7
  from util import ParamsModel
8
 
9
  Connections = dict[UUID, dict[str, WebSocket | asyncio.Queue]]
 
3
  from fastapi import WebSocket
4
  from starlette.websockets import WebSocketState
5
  import logging
6
+ from typing import Any
7
  from util import ParamsModel
8
 
9
  Connections = dict[UUID, dict[str, WebSocket | asyncio.Queue]]
server/main.py CHANGED
@@ -12,7 +12,7 @@ from connection_manager import ConnectionManager, ServerFullException
12
  import uuid
13
  from uuid import UUID
14
  import time
15
- from typing import Any, Protocol, TypeVar, runtime_checkable, cast
16
  from util import pil_to_frame, bytes_to_pil, is_firefox, get_pipeline_class, ParamsModel
17
  from device import device, torch_dtype
18
  import asyncio
 
12
  import uuid
13
  from uuid import UUID
14
  import time
15
+ from typing import Any, Protocol, runtime_checkable
16
  from util import pil_to_frame, bytes_to_pil, is_firefox, get_pipeline_class, ParamsModel
17
  from device import device, torch_dtype
18
  import asyncio
server/util.py CHANGED
@@ -1,12 +1,12 @@
1
  from importlib import import_module
2
- from typing import Any, TypeVar, Generic, TypeVar
3
  from PIL import Image
4
  import io
5
- from pydantic import BaseModel, create_model, Field
6
 
7
 
 
8
  TPipeline = TypeVar("TPipeline", bound=type[Any])
9
- T = TypeVar('T')
10
 
11
 
12
  class ParamsModel(BaseModel):
@@ -22,7 +22,20 @@ class ParamsModel(BaseModel):
22
  return self.model_dump()
23
 
24
 
25
- def get_pipeline_class(pipeline_name: str) -> TPipeline:
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  try:
27
  module = import_module(f"pipelines.{pipeline_name}")
28
  except ModuleNotFoundError:
 
1
  from importlib import import_module
2
+ from typing import Any, TypeVar, type_check_only
3
  from PIL import Image
4
  import io
5
+ from pydantic import BaseModel
6
 
7
 
8
+ # Used only for type checking the pipeline class
9
  TPipeline = TypeVar("TPipeline", bound=type[Any])
 
10
 
11
 
12
  class ParamsModel(BaseModel):
 
22
  return self.model_dump()
23
 
24
 
25
+ def get_pipeline_class(pipeline_name: str) -> type:
26
+ """
27
+ Dynamically imports and returns the Pipeline class from a specified module.
28
+
29
+ Args:
30
+ pipeline_name: The name of the pipeline module to import
31
+
32
+ Returns:
33
+ The Pipeline class from the specified module
34
+
35
+ Raises:
36
+ ValueError: If the module or Pipeline class isn't found
37
+ TypeError: If Pipeline is not a class
38
+ """
39
  try:
40
  module = import_module(f"pipelines.{pipeline_name}")
41
  except ModuleNotFoundError: