Skip to content

Commit c80301d

Browse files
committed
bug fixes
1 parent 2a017f0 commit c80301d

File tree

4 files changed

+12
-12
lines changed

4 files changed

+12
-12
lines changed

GCN/GCNModel.fs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ let gcnLayer in_features out_features hasBias (adj:TorchTensor) =
1010
let bias = if hasBias then Parameter(randName(),Float32Tensor.empty([|out_features|],requiresGrad=true)) |> Some else None
1111
let parms = [| yield weight; if hasBias then yield bias.Value|]
1212
Init.kaiming_uniform(weight.Tensor) |> ignore
13+
if hasBias then Init.uniform(bias.Value.Tensor,0.,1.0) |> ignore
1314

1415
Model.create(parms,fun wts t ->
1516
use support = t.mm(wts.[0])
@@ -22,9 +23,7 @@ let gcnLayer in_features out_features hasBias (adj:TorchTensor) =
2223
let create nfeat nhid nclass dropout adj =
2324
let gc1 = gcnLayer nfeat nhid true adj
2425
let gc2 = gcnLayer nhid nclass true adj
25-
// let relu = ReLU()
26-
// let logm = LogSoftmax(1L)
27-
let drp = if dropout then Dropout() |> M else Model.nop
26+
let drp = if dropout > 0.0 then Dropout(dropout) |> M else Model.nop
2827

2928
fwd3 gc1 gc2 drp (fun t g1 g2 drp ->
3029
use t = gc1.forward(t)

GCN/TorchSharp.Fun.fs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,10 @@ let inline (->>) m1 m2 = compose (M m1) (None,M m2)
162162
let inline (=>>) m1 (n,m2) = compose (M m1) (Some n, M m2)
163163

164164
module Tensor =
165+
//Note: ensure 't matches tensor datatype otherwise ToArray might crash the app (i.e. exception cannot be caught)
165166
let private _getData<'t> (t:TorchTensor) =
166167
let s = t.Data<'t>()
167-
let xs = Array.zeroCreate s.Length
168-
for i in 0 .. s.Length-1 do
169-
xs.[i] <- s.[i]
170-
171-
//s.ToArray()
172-
xs
168+
s.ToArray()
173169

174170
let getData<'t> (t:TorchTensor) =
175171
if t.device_type <> DeviceType.CPU then

GCN/Train.fs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ let run (datafolder,no_cuda,fastmode,epochs,dropout,lr,hidden,seed,weight_decay)
1818

1919
let nclass = labels.max().ToInt64() + 1L
2020

21-
let model = GCNModel.create features.shape.[1] (int64 hidden) nclass true adj
21+
let model = GCNModel.create features.shape.[1] (int64 hidden) nclass dropout adj
2222
let loss = NN.Functions.nll_loss()
2323

2424
if cuda then
@@ -29,10 +29,13 @@ let run (datafolder,no_cuda,fastmode,epochs,dropout,lr,hidden,seed,weight_decay)
2929
let train epoch =
3030
let t = DateTime.Now
3131
model.Module.Train()
32+
let parms = model.Module.parameters()
3233
optimizer.zero_grad()
3334
let output = model.forward(features)
3435
let loss_train = loss.Invoke(output.[ idx_train], labels.[idx_train])
36+
let ls = float loss_train
3537
let acc_train = Utils.accuracy(output.[idx_train], labels.[idx_train])
38+
printfn $"training - loss: {ls}, acc: {acc_train}"
3639
loss_train.backward()
3740
optimizer.step()
3841

GCN/Utils.fs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,12 @@ let sparse_mx_to_torch_sparse_tensor (m:Matrix<float32>) =
3535
let idxs = Seq.append rows cols |> Seq.toArray
3636
let idx1 = idxs |> Int64Tensor.from |> fun x -> x.view(2L,-1L)
3737
let vals = coo |> Seq.map(fun (r,c,v) -> v) |> Seq.toArray |> Float32Tensor.from
38-
Float32Tensor.sparse(idx1,vals,[|int64 m.RowCount; int64 m.ColumnCount|])
38+
let t = Float32Tensor.sparse(idx1,vals,[|int64 m.RowCount; int64 m.ColumnCount|])
39+
let dt = TorchSharp.Fun.Tensor.getData<float32>(t.to_dense())
40+
t
3941

4042
let accuracy(output:TorchTensor, labels:TorchTensor) =
41-
let predsData = TorchSharp.Fun.Tensor.getData<int64>(output)
43+
let predsData = TorchSharp.Fun.Tensor.getData<float32>(output)
4244
let preds = predsData |> Array.chunkBySize (int output.shape.[1]) |> Array.map maxIdx
4345
let lbls = TorchSharp.Fun.Tensor.getData<int64>(labels)
4446
let correct = Array.zip preds lbls |> Array.filter (fun (a,b) -> a = b) |> Array.length |> float

0 commit comments

Comments
 (0)