
Commit b9548a7

SevKod authored and hbredin committed
feat(model): add segmentation model based on self-supervised representation (#1362)
Co-authored-by: Hervé BREDIN <[email protected]>
1 parent 6740db2 · commit b9548a7

File tree

4 files changed: +255 -6 lines changed

CHANGELOG.md

Lines changed: 5 additions & 4 deletions
```diff
@@ -32,15 +32,16 @@
 
 ### Features and improvements
 
+- feat(task): add [powerset](https://www.isca-speech.org/archive/interspeech_2023/plaquet23_interspeech.html) support to `SpeakerDiarization` task
 - feat(task): add support for multi-task models
+- feat(task): add support for label scope in speaker diarization task
+- feat(task): add support for missing classes in multi-label segmentation task
+- feat(model): add segmentation model based on torchaudio self-supervised representation
 - feat(pipeline): send pipeline to device with `pipeline.to(device)`
-- feat(pipeline): make `segmentation_batch_size` and `embedding_batch_size` mutable in `SpeakerDiarization` pipeline (they now default to `1`)
-- feat(task): add [powerset](https://arxiv.org/PLACEHOLDER) support to `SpeakerDiarization` task
 - feat(pipeline): add `return_embeddings` option to `SpeakerDiarization` pipeline
+- feat(pipeline): make `segmentation_batch_size` and `embedding_batch_size` mutable in `SpeakerDiarization` pipeline (they now default to `1`)
 - feat(pipeline): add progress hook to pipelines
 - feat(pipeline): check version compatibility at load time
-- feat(task): add support for label scope in speaker diarization task
-- feat(task): add support for missing classes in multi-label segmentation task
 - improve(task): load metadata as tensors rather than pyannote.core instances
 - improve(task): improve error message on missing specifications
 
```
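Among the pipeline changes above, `pipeline.to(device)` is the new one-call way to move an entire pipeline to a device. A minimal sketch of how that might be used (the checkpoint name below is hypothetical, for illustration only):

```python
import torch

from pyannote.audio import Pipeline

# Hypothetical checkpoint name, for illustration only.
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization")

# New in this release: send the whole pipeline to a device in one call.
pipeline.to(torch.device("cuda"))
```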

Lines changed: 13 additions & 0 deletions
```yaml
# @package _group_
_target_: pyannote.audio.models.segmentation.SSeRiouSS
wav2vec: WAVLM_BASE
wav2vec_layer: -1
lstm:
  hidden_size: 128
  num_layers: 4
  bidirectional: true
  monolithic: true
  dropout: 0.5
linear:
  hidden_size: 128
  num_layers: 2
```
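The `# @package _group_` directive marks this as a Hydra config group entry, with `_target_` naming the class to build. A minimal sketch of how such a config could be instantiated on its own, assuming the standard Hydra tooling consumes it and that the YAML above is saved locally as `SSeRiouSS.yaml` (a hypothetical path):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Load the config; the "# @package _group_" line is just a comment here.
cfg = OmegaConf.load("SSeRiouSS.yaml")  # hypothetical local path

# `_target_` tells Hydra which class to construct; remaining keys become kwargs.
model = instantiate(cfg)  # -> pyannote.audio.models.segmentation.SSeRiouSS
```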
pyannote/audio/models/segmentation/SSeRiouSS.py

Lines changed: 234 additions & 0 deletions
```python
# MIT License
#
# Copyright (c) 2023- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from pyannote.core.utils.generators import pairwise

from pyannote.audio.core.model import Model
from pyannote.audio.core.task import Task
from pyannote.audio.utils.params import merge_dict


class SSeRiouSS(Model):
    """Self-Supervised Representation for Speaker Segmentation

    wav2vec > LSTM > Feed forward > Classifier

    Parameters
    ----------
    sample_rate : int, optional
        Audio sample rate. Defaults to 16kHz (16000).
    num_channels : int, optional
        Number of channels. Defaults to mono (1).
    wav2vec : dict or str, optional
        Defaults to "WAVLM_BASE".
    wav2vec_layer : int, optional
        Index of layer to use as input to the LSTM.
        Defaults (-1) to use average of all layers (with learnable weights).
    lstm : dict, optional
        Keyword arguments passed to the LSTM layer.
        Defaults to {"hidden_size": 128, "num_layers": 4, "bidirectional": True},
        i.e. four bidirectional layers with 128 units each.
        Set "monolithic" to False to split monolithic multi-layer LSTM into multiple mono-layer LSTMs.
        This may prove useful for probing LSTM internals.
    linear : dict, optional
        Keyword arguments used to initialize linear layers.
        Defaults to {"hidden_size": 128, "num_layers": 2},
        i.e. two linear layers with 128 units each.
    """

    WAV2VEC_DEFAULTS = "WAVLM_BASE"

    LSTM_DEFAULTS = {
        "hidden_size": 128,
        "num_layers": 4,
        "bidirectional": True,
        "monolithic": True,
        "dropout": 0.0,
    }
    LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2}

    def __init__(
        self,
        wav2vec: Union[dict, str] = None,
        wav2vec_layer: int = -1,
        lstm: dict = None,
        linear: dict = None,
        sample_rate: int = 16000,
        num_channels: int = 1,
        task: Optional[Task] = None,
    ):
        super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task)

        if isinstance(wav2vec, str):
            # `wav2vec` is one of the supported pipelines from torchaudio (e.g. "WAVLM_BASE")
            if hasattr(torchaudio.pipelines, wav2vec):
                bundle = getattr(torchaudio.pipelines, wav2vec)
                if sample_rate != bundle._sample_rate:
                    raise ValueError(
                        f"Expected {bundle._sample_rate}Hz, found {sample_rate}Hz."
                    )
                wav2vec_dim = bundle._params["encoder_embed_dim"]
                wav2vec_num_layers = bundle._params["encoder_num_layers"]
                self.wav2vec = bundle.get_model()

            # `wav2vec` is a path to a self-supervised representation checkpoint
            else:
                _checkpoint = torch.load(wav2vec)
                wav2vec = _checkpoint.pop("config")
                self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec)
                state_dict = _checkpoint.pop("state_dict")
                self.wav2vec.load_state_dict(state_dict)
                wav2vec_dim = wav2vec["encoder_embed_dim"]
                wav2vec_num_layers = wav2vec["encoder_num_layers"]

        # `wav2vec` is a config dictionary understood by `wav2vec2_model`
        # this branch is typically used by Model.from_pretrained(...)
        elif isinstance(wav2vec, dict):
            self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec)
            wav2vec_dim = wav2vec["encoder_embed_dim"]
            wav2vec_num_layers = wav2vec["encoder_num_layers"]

        if wav2vec_layer < 0:
            self.wav2vec_weights = nn.Parameter(
                data=torch.ones(wav2vec_num_layers), requires_grad=True
            )

        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        lstm["batch_first"] = True
        linear = merge_dict(self.LINEAR_DEFAULTS, linear)

        self.save_hyperparameters("wav2vec", "wav2vec_layer", "lstm", "linear")

        monolithic = lstm["monolithic"]
        if monolithic:
            multi_layer_lstm = dict(lstm)
            del multi_layer_lstm["monolithic"]
            self.lstm = nn.LSTM(wav2vec_dim, **multi_layer_lstm)

        else:
            num_layers = lstm["num_layers"]
            if num_layers > 1:
                self.dropout = nn.Dropout(p=lstm["dropout"])

            one_layer_lstm = dict(lstm)
            one_layer_lstm["num_layers"] = 1
            one_layer_lstm["dropout"] = 0.0
            del one_layer_lstm["monolithic"]

            self.lstm = nn.ModuleList(
                [
                    nn.LSTM(
                        wav2vec_dim
                        if i == 0
                        else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1),
                        **one_layer_lstm,
                    )
                    for i in range(num_layers)
                ]
            )

        if linear["num_layers"] < 1:
            return

        lstm_out_features: int = self.hparams.lstm["hidden_size"] * (
            2 if self.hparams.lstm["bidirectional"] else 1
        )
        self.linear = nn.ModuleList(
            [
                nn.Linear(in_features, out_features)
                for in_features, out_features in pairwise(
                    [
                        lstm_out_features,
                    ]
                    + [self.hparams.linear["hidden_size"]]
                    * self.hparams.linear["num_layers"]
                )
            ]
        )

    def build(self):
        if self.hparams.linear["num_layers"] > 0:
            in_features = self.hparams.linear["hidden_size"]
        else:
            in_features = self.hparams.lstm["hidden_size"] * (
                2 if self.hparams.lstm["bidirectional"] else 1
            )

        if isinstance(self.specifications, tuple):
            raise ValueError("SSeRiouSS model does not support multi-tasking.")

        if self.specifications.powerset:
            out_features = self.specifications.num_powerset_classes
        else:
            out_features = len(self.specifications.classes)

        self.classifier = nn.Linear(in_features, out_features)
        self.activation = self.default_activation()

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        """Pass forward

        Parameters
        ----------
        waveforms : (batch, channel, sample)

        Returns
        -------
        scores : (batch, frame, classes)
        """

        num_layers = (
            None if self.hparams.wav2vec_layer < 0 else self.hparams.wav2vec_layer
        )

        with torch.no_grad():
            outputs, _ = self.wav2vec.extract_features(
                waveforms.squeeze(1), num_layers=num_layers
            )

        if num_layers is None:
            outputs = torch.stack(outputs, dim=-1) @ F.softmax(
                self.wav2vec_weights, dim=0
            )
        else:
            outputs = outputs[-1]

        if self.hparams.lstm["monolithic"]:
            outputs, _ = self.lstm(outputs)
        else:
            for i, lstm in enumerate(self.lstm):
                outputs, _ = lstm(outputs)
                if i + 1 < self.hparams.lstm["num_layers"]:
                    outputs = self.dropout(outputs)

        if self.hparams.linear["num_layers"] > 0:
            for linear in self.linear:
                outputs = F.leaky_relu(linear(outputs))

        return self.activation(self.classifier(outputs))
```
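When `wav2vec_layer` is negative, `forward` combines the outputs of all wav2vec transformer layers with learnable, softmax-normalized weights instead of picking a single layer. A self-contained sketch of that combination (shapes are made up for illustration):

```python
import torch
import torch.nn.functional as F

# Made-up dimensions: 12 layers, batch of 2, 100 frames, 768 features.
layer_outputs = [torch.randn(2, 100, 768) for _ in range(12)]

# One learnable scalar per layer, as in SSeRiouSS when wav2vec_layer < 0.
weights = torch.nn.Parameter(torch.ones(12))

# stack -> (batch, frame, feature, layer); matmul against the softmaxed
# weights collapses the layer axis -> (batch, frame, feature)
combined = torch.stack(layer_outputs, dim=-1) @ F.softmax(weights, dim=0)
assert combined.shape == (2, 100, 768)
```

The softmax keeps the per-layer weights positive and summing to one, so the result stays a convex mixture of layer representations throughout training.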

pyannote/audio/models/segmentation/__init__.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -1,6 +1,6 @@
 # MIT License
 #
-# Copyright (c) 2020 CNRS
+# Copyright (c) 2020- CNRS
 #
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
@@ -21,5 +21,6 @@
 # SOFTWARE.
 
 from .PyanNet import PyanNet
+from .SSeRiouSS import SSeRiouSS
 
-__all__ = ["PyanNet"]
+__all__ = ["PyanNet", "SSeRiouSS"]
```
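With the export in place, the new model is importable alongside `PyanNet`. A minimal construction sketch (this downloads the WAVLM_BASE bundle from torchaudio on first use; note that the classifier head is only created by `build()` once a task supplies its specifications):

```python
from pyannote.audio.models.segmentation import SSeRiouSS

# Build the wav2vec + LSTM trunk with the same values as the YAML config above.
model = SSeRiouSS(
    wav2vec="WAVLM_BASE",
    wav2vec_layer=-1,
    lstm={"hidden_size": 128, "num_layers": 4, "dropout": 0.5},
    linear={"hidden_size": 128, "num_layers": 2},
)

print(sum(p.numel() for p in model.parameters()))  # trunk parameter count
```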
