Commit
·
5b70063
1
Parent(s):
2e36228
Upload loconet
Browse files- config.json +17 -0
- config_loconet.py +23 -0
- modeling_loconet.py +45 -0
- pytorch_model.bin +3 -0
config.json
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"adjust_attention": false,
|
3 |
+
"architectures": [
|
4 |
+
"loconet"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "config_loconet.LoCoNetConfig",
|
8 |
+
"AutoModel": "modeling_loconet.loconet"
|
9 |
+
},
|
10 |
+
"av": "speaker_temporal",
|
11 |
+
"av_layers": 3,
|
12 |
+
"clip_length": 200,
|
13 |
+
"model_type": "loconet",
|
14 |
+
"num_speakers": 3,
|
15 |
+
"torch_dtype": "float32",
|
16 |
+
"transformers_version": "4.28.1"
|
17 |
+
}
|
config_loconet.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
|
5 |
+
class LoCoNetConfig(PretrainedConfig):
    """Hugging Face configuration for the LoCoNet active-speaker-detection model.

    Hyper-parameters:
        num_speakers: number of candidate speakers scored per clip.
        clip_length: number of video frames per clip.
        av: audio-visual fusion mode (e.g. "speaker_temporal").
        av_layers: number of audio-visual fusion layers.
        adjust_attention: whether attention adjustment is enabled.

    Any remaining keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "loconet"

    def __init__(
        self,
        num_speakers: int = 3,
        clip_length: int = 200,
        av: str = "speaker_temporal",
        av_layers: int = 3,
        adjust_attention: bool = False,
        **kwargs,
    ):
        # Record the model-specific hyper-parameters first, then let the
        # base class consume the generic HF config kwargs.
        model_params = {
            "num_speakers": num_speakers,
            "clip_length": clip_length,
            "av": av,
            "av_layers": av_layers,
            "adjust_attention": adjust_attention,
        }
        for attr_name, attr_value in model_params.items():
            setattr(self, attr_name, attr_value)
        super().__init__(**kwargs)
|
modeling_loconet.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from config_loconet import LoCoNetConfig
|
2 |
+
from transformers import PreTrainedModel
|
3 |
+
from loconet_encoder import locoencoder
|
4 |
+
from loss_multi import lossAV, lossA, lossV
|
5 |
+
|
6 |
+
|
7 |
+
class loconet(PreTrainedModel):
    """LoCoNet active-speaker-detection model wrapped as a HF PreTrainedModel.

    Wraps a ``locoencoder`` backbone and (in training mode) three loss heads:
    a joint audio-visual loss plus auxiliary audio-only and visual-only losses.
    """

    config_class = LoCoNetConfig

    def __init__(self, config):
        super().__init__(config)

        self.model = locoencoder(config)
        # BUGFIX: the loss heads were imported but never instantiated, so
        # forward() raised AttributeError on self.lossAV/self.lossA/self.lossV
        # whenever labels were supplied.
        self.lossAV = lossAV()
        self.lossA = lossA()
        self.lossV = lossV()

    def forward(self, audioFeature, visualFeature, masks, labels=None):
        """Score each candidate speaker per frame.

        Args:
            audioFeature: audio features for the clip.
            visualFeature: per-speaker visual features; leading dims are
                (batch, speakers, time) — assumed, confirm against caller.
            masks: per-frame validity masks, same leading dims as labels.
            labels: optional per-frame speaking labels; when given, the
                returned dict also contains the training loss.

        Returns:
            ``{"loss": ..., "logits": outsAV}`` when labels are provided,
            otherwise ``{"logits": outsAV}``.
        """
        b, s, t = visualFeature.shape[:3]
        # Fold the speaker axis into the batch axis for the encoders.
        visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
        masks = masks.view(b * s, *masks.shape[2:])
        # BUGFIX: the original reshaped labels unconditionally, crashing the
        # inference path (labels=None) that the else-branch below serves.
        if labels is not None:
            labels = labels.view(b * s, *labels.shape[2:])

        audioEmbed = self.model.forward_audio_frontend(audioFeature)    # B, C, T, 4
        visualEmbed = self.model.forward_visual_frontend(visualFeature)
        # Replicate the shared audio stream once per candidate speaker so it
        # aligns with the speaker-folded visual batch.
        audioEmbed = audioEmbed.repeat(s, 1, 1)

        audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
        outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed, b, s)
        outsA = self.model.forward_audio_backend(audioEmbed)
        outsV = self.model.forward_visual_backend(visualEmbed)

        if labels is not None:
            labels = labels.reshape((-1))
            masks = masks.reshape((-1))
            nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels, masks)
            nlossA = self.lossA.forward(outsA, labels, masks)
            nlossV = self.lossV.forward(outsV, labels, masks)

            # The joint AV loss dominates; the single-modality losses act as
            # 0.4-weighted auxiliary objectives.
            nloss = nlossAV + 0.4 * nlossA + 0.4 * nlossV

            return {"loss": nloss, "logits": outsAV}
        else:
            return {"logits": outsAV}
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6918e8391d48c40cfd90b332687508bb4b2269879ba1303dacb5a26937ecda87
|
3 |
+
size 137464429
|