Filipstrozik
Add initial implementation of EllipseRCNN model and dataset utilities
afc2161
# README.md
## Symmetric Kullback-Leibler Divergence Loss
This directory provides an implementation of a **Symmetric Kullback-Leibler (KL) Divergence Loss** tailored for tensors
representing ellipses in matrix form. The loss function is designed to measure the difference between two elliptical
shapes and is particularly useful in optimization and generative modeling tasks.
## Loss Calculation
### **Kullback-Leibler Divergence**
For two ellipses with matrix forms $A_1$ and $A_2$, each interpreted as a 2D Gaussian with mean $\mu_i$ and covariance $C_i$, the KL divergence is calculated as:
$$ D_{KL}(A_1 \parallel A_2) = \frac{1}{2} \left( \text{Tr}(C_2^{-1}C_1) + (\mu_1 - \mu_2)^T C_2^{-1} (\mu_1 - \mu_2) - 2 + \log\left(\frac{\det(C_2)}{\det(C_1)}\right) \right) $$
Where:
- $C_1$, $C_2$: Covariance matrices extracted from $A_1$, $A_2$.
- $\mu_1$, $\mu_2$: Centers (means) of the ellipses, computed from the conic representation.
- $\text{Tr}$: Trace operator.
- $C_2^{-1}$: Inverse of the covariance matrix of $A_2$.
- $\det(C_1)$, $\det(C_2)$: Determinants of the covariance matrices.
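As a minimal sketch of how a center can be read off a conic, assume $A$ is the symmetric $3 \times 3$ matrix of the conic $x^T A x = 0$ in homogeneous coordinates (an assumption; this README does not spell out the convention). The center $\mu$ then solves the $2 \times 2$ linear system formed by the upper-left block of $A$ and the negated first two entries of its last column. The helper name `conic_center` is hypothetical, and the code uses plain Python rather than PyTorch to stay dependency-free:

```python
def conic_center(A):
    """Return the center (cx, cy) of the conic x^T A x = 0.

    A is assumed to be a symmetric 3x3 matrix given as nested lists.
    """
    a, b, c = A[0][0], A[0][1], A[1][1]   # upper-left 2x2 block [[a, b], [b, c]]
    d, e = A[0][2], A[1][2]               # first two entries of the last column
    det = a * c - b * b                   # non-zero for ellipses
    # Solve [[a, b], [b, c]] @ center = -[d, e] with the closed-form 2x2 inverse.
    cx = (-c * d + b * e) / det
    cy = (b * d - a * e) / det
    return cx, cy


# Circle of radius 5 centered at (3, -2): (x - 3)^2 + (y + 2)^2 - 25 = 0
circle = [
    [1.0, 0.0, -3.0],
    [0.0, 1.0, 2.0],
    [-3.0, 2.0, -12.0],
]
center = conic_center(circle)
```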
A regularization term $\epsilon$ is added to ensure numerical stability when computing inverses and determinants.
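The formula above is the closed-form KL divergence between two 2D Gaussians, which is where the constant $2$ (the dimensionality) comes from. The following is a sketch of that computation in plain Python for $2 \times 2$ covariances, not the repository's PyTorch code. The function name `kl_divergence_2d` is hypothetical. Adding $\epsilon$ to the diagonal of each covariance is an assumption, since the README does not say how the term is applied.

```python
import math


def _regularize(cov, eps):
    """Add eps to the diagonal of a 2x2 matrix (assumed form of the regularization)."""
    return [[cov[0][0] + eps, cov[0][1]], [cov[1][0], cov[1][1] + eps]]


def _det2(m):
    """Determinant of a 2x2 matrix."""
    return m[0][0] * m[1][1] - m[0][1] * m[1][0]


def _inv2(m):
    """Inverse of a 2x2 matrix."""
    det = _det2(m)
    return [[m[1][1] / det, -m[0][1] / det], [-m[1][0] / det, m[0][0] / det]]


def kl_divergence_2d(mu1, cov1, mu2, cov2, eps=1e-9):
    """D_KL(N(mu1, cov1) || N(mu2, cov2)) for 2D Gaussians."""
    c1, c2 = _regularize(cov1, eps), _regularize(cov2, eps)
    c2_inv = _inv2(c2)
    # Tr(C2^{-1} C1) = sum_ij (C2^{-1})_ij (C1)_ji
    trace = sum(c2_inv[i][j] * c1[j][i] for i in range(2) for j in range(2))
    # (mu1 - mu2)^T C2^{-1} (mu1 - mu2)
    diff = [mu1[0] - mu2[0], mu1[1] - mu2[1]]
    mahalanobis = sum(
        diff[i] * c2_inv[i][j] * diff[j] for i in range(2) for j in range(2)
    )
    log_det_ratio = math.log(_det2(c2) / _det2(c1))
    return 0.5 * (trace + mahalanobis - 2.0 + log_det_ratio)


identity = [[1.0, 0.0], [0.0, 1.0]]
wide = [[4.0, 0.0], [0.0, 4.0]]
forward = kl_divergence_2d((0.0, 0.0), identity, (0.0, 0.0), wide)
backward = kl_divergence_2d((0.0, 0.0), wide, (0.0, 0.0), identity)
```

Here `forward` and `backward` differ, which illustrates that the plain KL divergence depends on argument order.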
### **Symmetric KL Divergence**
The symmetric version of the KL divergence combines the calculations in both directions:
$$ D_{KL}^{\text{sym}}(A_1, A_2) = \frac{1}{2} \left( D_{KL}(A_1 \parallel A_2) + D_{KL}(A_2 \parallel A_1) \right) $$
This makes the comparison symmetric in its arguments, so the loss does not depend on which ellipse is treated as the prediction and which as the target.
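When the two directions are summed, the log-determinant terms cancel, which leaves a simpler closed form. With $d = \mu_1 - \mu_2$:

$$ D_{KL}^{\text{sym}}(A_1, A_2) = \frac{1}{4} \left( \text{Tr}(C_2^{-1}C_1) + \text{Tr}(C_1^{-1}C_2) + d^T (C_1^{-1} + C_2^{-1}) d - 4 \right) $$

The sketch below evaluates this simplified expression in plain Python for $2 \times 2$ covariances. The function name `symmetric_kl_2d` and the diagonal $\epsilon$ regularization are assumptions for illustration, not the repository's implementation.

```python
def _inv2(m):
    """Inverse of a 2x2 matrix."""
    det = m[0][0] * m[1][1] - m[0][1] * m[1][0]
    return [[m[1][1] / det, -m[0][1] / det], [-m[1][0] / det, m[0][0] / det]]


def _trace_of_product(a, b):
    """Tr(a @ b) for 2x2 matrices."""
    return sum(a[i][j] * b[j][i] for i in range(2) for j in range(2))


def symmetric_kl_2d(mu1, cov1, mu2, cov2, eps=1e-9):
    """Symmetric KL between two 2D Gaussians; log-det terms cancel analytically."""
    # Assumed regularization: add eps to the diagonal of each covariance.
    c1 = [[cov1[0][0] + eps, cov1[0][1]], [cov1[1][0], cov1[1][1] + eps]]
    c2 = [[cov2[0][0] + eps, cov2[0][1]], [cov2[1][0], cov2[1][1] + eps]]
    c1_inv, c2_inv = _inv2(c1), _inv2(c2)
    traces = _trace_of_product(c2_inv, c1) + _trace_of_product(c1_inv, c2)
    # d^T (C1^{-1} + C2^{-1}) d
    d = [mu1[0] - mu2[0], mu1[1] - mu2[1]]
    mahalanobis = sum(
        d[i] * (c1_inv[i][j] + c2_inv[i][j]) * d[j]
        for i in range(2)
        for j in range(2)
    )
    return 0.25 * (traces + mahalanobis - 4.0)


identity = [[1.0, 0.0], [0.0, 1.0]]
wide = [[4.0, 0.0], [0.0, 4.0]]
ab = symmetric_kl_2d((0.0, 0.0), identity, (0.0, 0.0), wide)
ba = symmetric_kl_2d((0.0, 0.0), wide, (0.0, 0.0), identity)
```

Swapping the arguments leaves the value unchanged, so `ab` and `ba` agree.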
## Features of the Loss
- **Shape-Only Comparison**: Option to ignore translation and compute the divergence purely from the shapes (covariance
matrices), which drops the $(\mu_1 - \mu_2)$ term.
- **NaN Handling**: Replaces NaN values with a specified constant, ensuring robust loss evaluation.
- **Normalization**: An optional normalization step that rescales the divergence for certain applications.
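The NaN handling can be pictured as a small post-processing step on a batch of per-pair divergences. This is a sketch under stated assumptions: the name `reduce_divergences`, the parameter `nan_fill`, its default value, and the mean reduction are all hypothetical. Normalization is left out because this README does not specify how the rescaling is done.

```python
import math


def reduce_divergences(values, nan_fill=10.0):
    """Replace NaN entries with a constant, then average (assumed reduction)."""
    cleaned = [nan_fill if math.isnan(v) else v for v in values]
    return sum(cleaned) / len(cleaned)


# One degenerate pair produced NaN; it is replaced by nan_fill before averaging.
batch = [1.0, float("nan"), 3.0]
loss = reduce_divergences(batch, nan_fill=5.0)
```

Replacing NaN with a finite constant keeps a single degenerate ellipse pair from turning the whole batch loss into NaN.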
## Usage
The loss is encapsulated in the `SymmetricKLDLoss` class, which integrates seamlessly into PyTorch-based workflows.