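# Router implementations for the Poly tuner. A router maps task ids to
# normalized mixing weights over "skills" (modular adapters), producing one
# weight vector per parameter split.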
from abc import ABC, abstractmethod

import torch
from torch import nn
from torch.distributions.relaxed_bernoulli import RelaxedBernoulli

from .config import PolyConfig

# Small constant that keeps the weight normalization in PolyRouter.forward
# from dividing by zero.
EPS = 1e-12
|
|
def get_router(poly_config: PolyConfig) -> nn.Module:
    """Instantiate the router selected by ``poly_config.poly_type``."""
    if poly_config.poly_type == "poly":
        return PolyRouter(poly_config)
    else:
        raise ValueError(
            f"Unsupported poly_type: {poly_config.poly_type}. "
            "Currently, only the following types are supported: `poly`."
        )
|
|
class Router(nn.Module, ABC):
    """Interface that all routers must implement: parameter re-initialization
    and a forward pass mapping task ids to routing weights."""

    @abstractmethod
    def reset(self): ...

    @abstractmethod
    def forward(self, task_ids: torch.Tensor, input_ids: torch.Tensor): ...
|
|
class PolyRouter(Router):
    # A simplified implementation of the Poly router from Ponti et al., 2023
    # (https://arxiv.org/abs/2202.13914).

    def __init__(self, poly_config: PolyConfig):
        super().__init__()

        self.poly_type = poly_config.poly_type
        self.n_tasks = poly_config.n_tasks
        self.n_skills = poly_config.n_skills
        self.n_splits = poly_config.n_splits

        # One routing logit per (task, split, skill). The tensor is allocated
        # uninitialized here; reset() must be called before use.
        self.module_logits = nn.Parameter(torch.empty((self.n_tasks, self.n_splits * self.n_skills)))

    def reset(self):
        torch.nn.init.uniform_(self.module_logits, -1e-3, 1e-3)
|
    def forward(self, task_ids: torch.Tensor, input_ids: torch.Tensor):
        if task_ids is None:
            raise ValueError("task_ids should not be None.")
        if task_ids.max().item() >= self.n_tasks:
            raise ValueError(f"Only {self.n_tasks} tasks available. Found task id = {task_ids.max().item()}")

        # Move task ids to the device holding the routing parameters.
        task_ids = task_ids.to(self.module_logits.device)

        # Gather the logits for the requested tasks and expose the
        # (split, skill) axes: (batch, n_splits, n_skills).
        module_logits = self.module_logits[task_ids]
        module_logits = module_logits.view(-1, self.n_splits, self.n_skills)

        if self.training:
            # During training, draw differentiable soft activations via the
            # relaxed-Bernoulli (binary Concrete) reparameterization.
            module_logits = RelaxedBernoulli(temperature=1.0, logits=module_logits).rsample()
        else:
            # At inference time, use the deterministic activation probabilities.
            module_logits = torch.sigmoid(module_logits)

        # Normalize over skills so the weights within each split sum to (almost) one.
        module_weights = module_logits / (module_logits.sum(dim=-1, keepdim=True) + EPS)

        return module_weights
|
|
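# ---------------------------------------------------------------------------
# Illustrative usage (a sketch, not part of the module). It assumes a
# PolyConfig exposing `poly_type`, `n_tasks`, `n_skills`, and `n_splits`;
# construction details may differ in your version of the config.
#
#     config = PolyConfig(poly_type="poly", n_tasks=4, n_skills=8, n_splits=1)
#     router = get_router(config)
#     router.reset()  # required: module_logits is allocated uninitialized
#
#     task_ids = torch.tensor([0, 3])
#     weights = router(task_ids, input_ids=None)  # input_ids is unused here
#     # weights has shape (len(task_ids), n_splits, n_skills) == (2, 1, 8),
#     # and each row sums to ~1 over the skill dimension.
# ---------------------------------------------------------------------------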