Spaces:
Running
on
A100
Running
on
A100
File size: 2,320 Bytes
cb92d2b 9a8789a d6fedfa 2ab3299 cb92d2b 2ab3299 100e61a 9a8789a 100e61a 9a8789a 100e61a 9a8789a 100e61a 2ab3299 9a8789a 2ab3299 9a8789a 2ab3299 9a8789a 2ab3299 cb92d2b 1123781 cb92d2b 100e61a cb92d2b d6fedfa 3207814 d6fedfa d446912 d6fedfa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
from importlib import import_module
from typing import Any, TypeVar
from PIL import Image
import io
from pydantic import BaseModel
# Used only for type checking the pipeline class: a TypeVar whose bound is
# type[Any], i.e. it ranges over class objects rather than instances.
# NOTE(review): not referenced in the visible part of this file — confirm it is
# still used elsewhere before removing.
TPipeline = TypeVar("TPipeline", bound=type[Any])
class ParamsModel(BaseModel):
    """Shared pydantic base class for pipeline parameter models.

    Subclasses declare their parameter fields. The configuration below
    permits undeclared (extra) attributes so that dynamic values such as an
    ``image`` can be attached, and allows field types pydantic does not know
    how to validate natively.
    """

    model_config = dict(
        arbitrary_types_allowed=True,
        # Keep extra attributes, e.g. a dynamically attached 'image' field.
        extra="allow",
    )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ParamsModel":
        """Validate ``data`` and build an instance of ``cls`` from it.

        Args:
            data: Mapping of field names to raw values.

        Returns:
            A validated instance of the (sub)class this is called on.
        """
        instance = cls.model_validate(data)
        return instance

    def to_dict(self) -> dict[str, Any]:
        """Serialize this model into a plain dictionary via ``model_dump()``."""
        dumped = self.model_dump()
        return dumped
def get_pipeline_class(pipeline_name: str) -> type:
    """
    Dynamically imports and returns the Pipeline class from a specified module.

    The module is looked up as ``pipelines.<pipeline_name>``.

    Args:
        pipeline_name: The name of the pipeline module to import

    Returns:
        The Pipeline class from the specified module

    Raises:
        ValueError: If the module or Pipeline class isn't found, or if the
            pipeline module exists but one of its own imports is missing
            (the original ModuleNotFoundError is chained as ``__cause__``).
        TypeError: If Pipeline is not a class
    """
    module_path = f"pipelines.{pipeline_name}"
    try:
        module = import_module(module_path)
    except ModuleNotFoundError as err:
        missing = err.name
        # ``err.name`` is the module that could not be found. If it is the
        # pipeline module itself or one of its parent packages, the pipeline
        # genuinely does not exist. Otherwise the pipeline module was found
        # but failed on one of its own imports, and reporting "module not
        # found" would hide the real problem.
        if (
            missing is None
            or module_path == missing
            or module_path.startswith(missing + ".")
        ):
            raise ValueError(f"Pipeline {pipeline_name} module not found") from err
        raise ValueError(
            f"Pipeline {pipeline_name} could not be imported: "
            f"missing dependency '{missing}'"
        ) from err

    pipeline_class = getattr(module, "Pipeline", None)
    if pipeline_class is None:
        raise ValueError(f"'Pipeline' class not found in module '{pipeline_name}'.")

    # Type check to ensure we're returning a class
    if not isinstance(pipeline_class, type):
        raise TypeError(f"'Pipeline' in module '{pipeline_name}' is not a class")

    return pipeline_class
def bytes_to_pil(image_bytes: bytes) -> Image.Image:
    """Open in-memory encoded image data as a PIL image.

    Args:
        image_bytes: Encoded image data (any format PIL can identify).

    Returns:
        A PIL image backed by an in-memory buffer over ``image_bytes``.
        Per Pillow's documentation, ``Image.open`` is lazy, so pixel data is
        decoded on first access and decode errors may surface only then.
    """
    buffer = io.BytesIO(image_bytes)
    return Image.open(buffer)
# Image modes Pillow's JPEG encoder can write directly. Any other mode
# (e.g. RGBA, LA, P) makes ``Image.save(format="JPEG")`` raise OSError.
_JPEG_WRITABLE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})


def pil_to_frame(image: Image.Image, *, quality: int = 80) -> bytes:
    """Encode ``image`` as JPEG and wrap it in a ``--frame`` multipart part.

    The result is ``--frame``, a ``Content-Type: image/jpeg`` header, a
    ``Content-Length`` header, a blank line, the JPEG bytes and a trailing
    CRLF. Presumably it is concatenated into a multipart (MJPEG-style)
    response with boundary ``frame``; confirm against the caller.

    Args:
        image: Image to encode. Modes JPEG cannot store (e.g. RGBA or
            palette images) are converted to RGB first; any alpha channel is
            dropped, not composited.
        quality: JPEG quality passed to Pillow (default 80).

    Returns:
        The complete multipart part as bytes.
    """
    if image.mode not in _JPEG_WRITABLE_MODES:
        # convert() returns a new image, so the caller's image is untouched.
        image = image.convert("RGB")

    with io.BytesIO() as buffer:
        image.save(
            buffer,
            format="JPEG",
            quality=quality,
            optimize=True,
            progressive=True,
        )
        jpeg_bytes = buffer.getvalue()

    return b"".join(
        (
            b"--frame\r\n",
            b"Content-Type: image/jpeg\r\n",
            f"Content-Length: {len(jpeg_bytes)}\r\n\r\n".encode(),
            jpeg_bytes,
            b"\r\n",
        )
    )
def is_firefox(user_agent: str) -> bool:
    """Tell whether a User-Agent string mentions Firefox.

    This is a plain, case-sensitive substring search for ``"Firefox"``.

    Args:
        user_agent: The User-Agent string to inspect.

    Returns:
        True if ``"Firefox"`` occurs anywhere in ``user_agent``.
    """
    firefox_token = "Firefox"
    found = firefox_token in user_agent
    return found
|