-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMultiLayerNet.py
More file actions
41 lines (29 loc) · 1.08 KB
/
MultiLayerNet.py
File metadata and controls
41 lines (29 loc) · 1.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from typing import List, Union

import torch
import torch.nn as nn
# Type alias: a list of integers (e.g. a sequence of layer sizes).
Vector = List[int]
class Network(nn.Module):
    """Multi-layer perceptron built from a list of layer widths.

    Each pair of consecutive widths is joined by an ``nn.Linear`` layer.
    An ``nn.ReLU`` is placed between linear layers, and there is no
    activation after the final one. All layers live in ``self.net``, an
    ``nn.Sequential``.
    """

    def __init__(self, net_dims: List[Union[int, str]]) -> None:
        """Constructor for multi-layer perceptron pytorch class.

        params:
            net_dims: width of each layer, input width first and output width
                last, e.g. ``[in_features, hidden, ..., out_features]``.
                String entries are ignored, so the widths on either side of a
                string are connected directly. With fewer than two integer
                widths no layers are created and the network is the identity.
        """
        super().__init__()
        # Previously a string entry always crashed: it reached nn.Linear as a
        # size on the next iteration, or indexed past the end when it was last.
        # Skipping strings realises what that branch was evidently aiming for.
        # NOTE(review): the meaning of string markers is not visible here;
        # they are treated purely as placeholders.
        dims = [d for d in net_dims if not isinstance(d, str)]
        pairs = list(zip(dims[:-1], dims[1:]))

        layers = []
        for idx, (n_in, n_out) in enumerate(pairs):
            layers.append(nn.Linear(n_in, n_out))
            # Activation only between layers; the output layer stays linear.
            if idx < len(pairs) - 1:
                layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass data through the network model.

        params:
            x: input tensor; its last dimension must equal the first integer
               entry of ``net_dims`` (the first layer's ``in_features``)
        returns:
            network output with every size-1 dimension removed
        """
        # NOTE(review): torch.squeeze drops *all* size-1 dims, including a batch
        # dimension of 1. Kept as-is because callers may depend on this shape.
        return torch.squeeze(self.net(x))