|
from inference.core.env import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, LAMBDA |
|
from inference.core.models.classification_base import ( |
|
ClassificationBaseOnnxRoboflowInferenceModel, |
|
) |
|
|
|
|
|
class VitClassification(ClassificationBaseOnnxRoboflowInferenceModel):
    """VitClassification handles classification inference
    for Vision Transformer (ViT) models using ONNX.

    Inherits:
        ClassificationBaseOnnxRoboflowInferenceModel: Base class for ONNX Roboflow Inference.
        ClassificationMixin: Mixin class providing classification-specific methods.
            NOTE(review): not listed directly in this class's bases; presumably
            reached through ClassificationBaseOnnxRoboflowInferenceModel — confirm.

    Attributes:
        multiclass (bool): A flag that specifies if the model should handle
            multiclass classification. Taken from the "MULTICLASS" entry of
            ``self.environment``, defaulting to False when that entry is absent.
    """

    def __init__(self, *args, **kwargs):
        """Initializes the VitClassification instance.

        Args:
            *args: Positional arguments, forwarded unchanged to the base class.
            **kwargs: Keyword arguments, forwarded unchanged to the base class.
        """
        super().__init__(*args, **kwargs)
        # `self.environment` is not assigned in this class, so it must be provided
        # by the base class; presumably a dict-like of model environment
        # settings — confirm against ClassificationBaseOnnxRoboflowInferenceModel.
        self.multiclass = self.environment.get("MULTICLASS", False)

    @property
    def weights_file(self) -> str:
        """Name of the ONNX weights file to load for this model.

        Returns ``"weights.onnx"`` only when AWS_ACCESS_KEY_ID and
        AWS_SECRET_ACCESS_KEY are both set (truthy) AND the LAMBDA flag is
        truthy. In every other case it returns ``"best.onnx"``.

        Returns:
            str: File name of the weights file (a bare name, not a full path).
        """
        # NOTE(review): why the AWS-credentialed Lambda deployment uses a different
        # weights file name is not visible here; presumably artifacts are stored
        # under a different name in that environment — confirm.
        if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY and LAMBDA:
            return "weights.onnx"
        return "best.onnx"
|
|