Skip to content

Commit a61aa9d

Browse files
committed
Initial SMS-AF and CT-AF push
1 parent 56c4665 commit a61aa9d

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

41 files changed

+5521
-116
lines changed

metaaf/callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def on_val_batch_end(self, out, aux, data_batch, cur_batch, cur_epoch):
194194
base_name = os.path.join(epoch_dir, f"{self.num_logged}")
195195
sf.write(f"{base_name}_out.wav", np.array(out[batch_idx, :, 0]), self.fs)
196196

197-
for (k, v) in data_batch["signals"].items():
197+
for k, v in data_batch["signals"].items():
198198
sf.write(f"{base_name}_{k}.wav", np.array(v[batch_idx, :, 0]), self.fs)
199199

200200
batch_idx += 1

metaaf/complex_groupnorm.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

metaaf/complex_gru.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
complex_sigmoid,
2929
complex_tanh,
3030
)
31+
from metaaf.complex_norm import CLNorm
3132
import types
3233
from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union
3334

@@ -56,6 +57,7 @@ def __init__(
5657
w_i_init: Optional[hk.initializers.Initializer] = None,
5758
w_h_init: Optional[hk.initializers.Initializer] = None,
5859
b_init: Optional[hk.initializers.Initializer] = None,
60+
use_norm=False,
5961
name: Optional[str] = None,
6062
):
6163
super().__init__(name=name)
@@ -65,6 +67,12 @@ def __init__(
6567
self.b_init = b_init or complex_zeros
6668
self.sig = complex_sigmoid
6769

70+
self.use_norm = use_norm
71+
if self.use_norm:
72+
self.i_norm = CLNorm(axis=-1, create_scale=True, create_offset=True)
73+
self.zrh_norm = CLNorm(axis=-1, create_scale=True, create_offset=True)
74+
self.ah_norm = CLNorm(axis=-1, create_scale=True, create_offset=True)
75+
6876
def __call__(self, inputs, state):
6977
if inputs.ndim not in (1, 2):
7078
raise ValueError("GRU input must be rank-1 or rank-2.")
@@ -82,15 +90,22 @@ def __call__(self, inputs, state):
8290
b_z, b_a = jnp.split(b, indices_or_sections=[2 * hidden_size], axis=0)
8391

8492
gates_x = jnp.matmul(inputs, w_i)
93+
if self.use_norm:
94+
gates_x = self.i_norm(gates_x)
8595

8696
zr_x, a_x = jnp.split(gates_x, indices_or_sections=[2 * hidden_size], axis=-1)
8797
zr_h = jnp.matmul(state, w_h_z)
98+
if self.use_norm:
99+
zr_h = self.zrh_norm(zr_h)
88100

89101
zr = zr_x + zr_h + jnp.broadcast_to(b_z, zr_h.shape)
90102
z, r = jnp.split(self.sig(zr), indices_or_sections=2, axis=-1)
91103

92104
a_h = jnp.matmul(r * state, w_h_a)
93105

106+
if self.use_norm:
107+
a_h = self.ah_norm(a_h)
108+
94109
a = complex_tanh(a_x + a_h + jnp.broadcast_to(b_a, a_h.shape))
95110

96111
next_state = (1 - z) * state + z * a
@@ -120,7 +135,7 @@ def make_deep_initial_state(params, **kwargs):
120135
n_layers = kwargs["n_layers"]
121136

122137
def single_layer_initial_state():
123-
state = jnp.zeros([h_size], dtype=np.dtype("complex64"))
138+
state = jnp.zeros([h_size], params.dtype) # dtype=np.dtype("complex64"))
124139
state = add_batch(state, b_size)
125140
return state
126141

metaaf/complex_norm.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import jax.numpy as jnp
2+
import haiku as hk
3+
4+
5+
class CGN(hk.Module):
    """Group normalization applied separately to real and imaginary parts.

    Two ``hk.GroupNorm`` modules are built from the same constructor
    arguments: one handles ``jnp.real(x)``, the other ``jnp.imag(x)``. The
    two normalized halves are recombined into a complex array and divided
    by ``sqrt(2)``.

    Args:
        groups: forwarded to both ``hk.GroupNorm`` modules.
        create_scale: forwarded to both ``hk.GroupNorm`` modules.
        create_offset: forwarded to both ``hk.GroupNorm`` modules.
        eps: forwarded to both ``hk.GroupNorm`` modules.
        name: name of this wrapper module.
    """

    def __init__(
        self, groups=6, create_scale=True, create_offset=True, eps=1e-5, name=None
    ):
        super().__init__(name=name)
        norm_kwargs = dict(
            groups=groups,
            create_scale=create_scale,
            create_offset=create_offset,
            eps=eps,
        )
        # The real-part module is constructed first, then the imaginary one.
        self.gn_real = hk.GroupNorm(**norm_kwargs)
        self.gn_imag = hk.GroupNorm(**norm_kwargs)

    def __call__(self, x):
        """Normalize ``x`` part-wise and return the rescaled complex result."""
        real_part = jnp.real(x)
        imag_part = jnp.imag(x)

        normed_real = self.gn_real(real_part)
        normed_imag = self.gn_imag(imag_part)

        combined = normed_real + 1j * normed_imag
        return combined / jnp.sqrt(2)
31+
32+
33+
class CLNorm(hk.Module):
    """Layer normalization applied separately to real and imaginary parts.

    Two ``hk.LayerNorm`` modules are built from identical constructor
    arguments (including ``name``): ``n_real`` normalizes ``jnp.real(x)``
    and ``n_imag`` normalizes ``jnp.imag(x)``. The outputs are recombined
    into a complex array and divided by ``sqrt(2)``.

    Args:
        axis: forwarded to both ``hk.LayerNorm`` modules.
        create_scale: forwarded to both ``hk.LayerNorm`` modules.
        create_offset: forwarded to both ``hk.LayerNorm`` modules.
        eps: forwarded to both ``hk.LayerNorm`` modules.
        scale_init: forwarded to both ``hk.LayerNorm`` modules.
        offset_init: forwarded to both ``hk.LayerNorm`` modules.
        use_fast_variance: forwarded to both ``hk.LayerNorm`` modules.
        name: name of this module; also forwarded to both inner modules.
        param_axis: forwarded to both ``hk.LayerNorm`` modules.
    """

    def __init__(
        self,
        axis,
        create_scale,
        create_offset,
        eps=1e-05,
        scale_init=None,
        offset_init=None,
        use_fast_variance=False,
        name=None,
        param_axis=None,
    ):
        super().__init__(name=name)
        layer_norm_kwargs = dict(
            axis=axis,
            create_scale=create_scale,
            create_offset=create_offset,
            eps=eps,
            scale_init=scale_init,
            offset_init=offset_init,
            use_fast_variance=use_fast_variance,
            name=name,
            param_axis=param_axis,
        )
        # The real-part module is constructed first, then the imaginary one.
        self.n_real = hk.LayerNorm(**layer_norm_kwargs)
        self.n_imag = hk.LayerNorm(**layer_norm_kwargs)

    def __call__(self, x):
        """Normalize ``x`` part-wise and return the rescaled complex result."""
        real_part = jnp.real(x)
        imag_part = jnp.imag(x)

        normed_real = self.n_real(real_part)
        normed_imag = self.n_imag(imag_part)

        combined = normed_real + 1j * normed_imag
        return combined / jnp.sqrt(2)

metaaf/complex_utils.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,33 @@
33
import haiku as hk
44

55

6-
def complex_zeros(shape, _):
7-
return jnp.zeros(shape, dtype=jnp.complex64)
6+
def complex_zeros(shape, dtype):
    """Return an all-zero complex array of the given shape.

    A zero array of ``dtype`` is used as the real part and, multiplied by
    ``1j``, as the imaginary part, so the result is complex even when
    ``dtype`` is a real floating type.
    """
    zeros = jnp.zeros(shape, dtype=dtype)
    return zeros + 1j * zeros
88

99

1010
# see https://openreview.net/attachment?id=H1T2hmZAb&name=pdf
1111
def complex_variance_scaling(shape, dtype):
    """Initialize a complex array with variance-scaled magnitude and random phase.

    Two float32 arrays are drawn with ``hk.initializers.VarianceScaling()``
    and combined into a magnitude ``sqrt(real**2 + imag**2)``. A float32
    phase is then drawn uniformly from ``[-pi, pi)`` and the result is
    ``magnitude * exp(1j * phase)``.

    Args:
        shape: shape of the array to create.
        dtype: accepted but not used by this function; all draws are float32.
    """
    # Draw order is real, imag, then phase.
    re_draw = hk.initializers.VarianceScaling()(shape, dtype=jnp.float32)
    im_draw = hk.initializers.VarianceScaling()(shape, dtype=jnp.float32)
    magnitude = jnp.sqrt(re_draw**2 + im_draw**2)

    uniform_phase = hk.initializers.RandomUniform(minval=-jnp.pi, maxval=jnp.pi)
    phase = uniform_phase(shape, dtype=jnp.float32)

    return magnitude * jnp.exp(1j * phase)
21+
22+
23+
def complex_xavier(shape, dtype):
24+
real = hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal")(
25+
shape, dtype=jnp.float32
26+
)
27+
imag = hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal")(
28+
shape, dtype=jnp.float32
29+
)
30+
31+
eps = 1e-11
32+
mag = jnp.sqrt(real**2 + imag**2 + eps)
1633
angle = hk.initializers.RandomUniform(minval=-jnp.pi, maxval=jnp.pi)(
1734
shape, dtype=jnp.float32
1835
)

0 commit comments

Comments
 (0)