Skip to content

Commit 174a421

Browse files
committed
Tidy up
1 parent 473fcf3 commit 174a421

File tree

1 file changed: 17 additions, 21 deletions

src/nf/nf_conv1d_layer_submodule.f90

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ pure module subroutine forward(self, input)
6262
integer :: iws, iwe
6363

6464
input_channels = size(input, dim=1)
65-
input_width = size(input, dim=2)
65+
input_width = size(input, dim=2)
6666

6767
! Loop over output positions.
6868
do j = 1, self % width
@@ -73,11 +73,11 @@ pure module subroutine forward(self, input)
7373

7474
! For each filter, compute the convolution (inner product over channels and kernel width).
7575
do concurrent (n = 1:self % filters)
76-
self % z(n, j) = sum(self % kernel(n, :, :) * input(:, iws:iwe))
76+
self % z(n, j) = sum(self % kernel(n,:,:) * input(:,iws:iwe))
7777
end do
7878

7979
! Add the bias for each filter.
80-
self % z(:, j) = self % z(:, j) + self % biases
80+
self % z(:,j) = self % z(:,j) + self % biases
8181
end do
8282

8383
! Apply the activation function.
@@ -103,18 +103,14 @@ pure module subroutine backward(self, input, gradient)
103103

104104
! Determine dimensions.
105105
input_channels = size(input, dim=1)
106-
input_width = size(input, dim=2)
107-
output_width = self % width ! Note: output_width = input_width - kernel_size + 1
106+
input_width = size(input, dim=2)
107+
output_width = self % width ! Note: output_width = input_width - kernel_size + 1
108108

109109
!--- Compute the local gradient gdz = (dL/dy) * sigma'(z) for each output.
110-
do j = 1, output_width
111-
gdz(:, j) = gradient(:, j) * self % activation % eval_prime(self % z(:, j))
112-
end do
110+
gdz = gradient * self % activation % eval_prime(self % z)
113111

114112
!--- Compute bias gradients: db(n) = sum_j gdz(n, j)
115-
do n = 1, self % filters
116-
db_local(n) = sum(gdz(n, :), dim=1)
117-
end do
113+
db_local = sum(gdz, dim=2)
118114

119115
!--- Initialize weight gradient and input gradient accumulators.
120116
dw_local = 0.0
@@ -124,16 +120,16 @@ pure module subroutine backward(self, input, gradient)
124120
! In the forward pass the window for output index j was:
125121
! iws = j, iwe = j + kernel_size - 1.
126122
do n = 1, self % filters
127-
do j = 1, output_width
128-
iws = j
129-
iwe = j + self % kernel_size - 1
130-
do k = 1, self % channels
131-
! Weight gradient: accumulate contribution from the input window.
132-
dw_local(n, k, :) = dw_local(n, k, :) + input(k, iws:iwe) * gdz(n, j)
133-
! Input gradient: propagate gradient back to the input window.
134-
self % gradient(k, iws:iwe) = self % gradient(k, iws:iwe) + self % kernel(n, k, :) * gdz(n, j)
135-
end do
136-
end do
123+
do j = 1, output_width
124+
iws = j
125+
iwe = j + self % kernel_size - 1
126+
do k = 1, self % channels
127+
! Weight gradient: accumulate contribution from the input window.
128+
dw_local(n,k,:) = dw_local(n,k,:) + input(k,iws:iwe) * gdz(n,j)
129+
! Input gradient: propagate gradient back to the input window.
130+
self % gradient(k,iws:iwe) = self % gradient(k,iws:iwe) + self % kernel(n,k,:) * gdz(n,j)
131+
end do
132+
end do
137133
end do
138134

139135
!--- Update stored gradients.

0 commit comments

Comments (0)