Skip to content

Commit 2a017f0

Browse files
committed
updated testing
1 parent 6105f96 commit 2a017f0

File tree

4 files changed

+27
-15
lines changed

4 files changed

+27
-15
lines changed

GCN/GCNModel.fs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ open TorchSharp.Fun
66
let inline (!>) (x:^a) : ^b = ((^a or ^b) : (static member op_Implicit : ^a -> ^b) x)
77

88
let gcnLayer in_features out_features hasBias (adj:TorchTensor) =
9-
let weight = Parameter(randName(),Float32Tensor.empty([|in_features; out_features|]))
10-
let bias = if hasBias then Parameter(randName(),Float32Tensor.empty([|out_features|])) |> Some else None
9+
let weight = Parameter(randName(),Float32Tensor.empty([|in_features; out_features|],requiresGrad=true))
10+
let bias = if hasBias then Parameter(randName(),Float32Tensor.empty([|out_features|],requiresGrad=true)) |> Some else None
1111
let parms = [| yield weight; if hasBias then yield bias.Value|]
1212
Init.kaiming_uniform(weight.Tensor) |> ignore
1313

1414
Model.create(parms,fun wts t ->
15-
let support = t.mm(wts.[0])
15+
use support = t.mm(wts.[0])
1616
let output = adj.mm(support)
1717
if hasBias then
1818
output.add(wts.[1])
@@ -22,14 +22,14 @@ let gcnLayer in_features out_features hasBias (adj:TorchTensor) =
2222
let create nfeat nhid nclass dropout adj =
2323
let gc1 = gcnLayer nfeat nhid true adj
2424
let gc2 = gcnLayer nhid nclass true adj
25-
let relu = ReLU()
26-
let logm = LogSoftmax(1L)
25+
// let relu = ReLU()
26+
// let logm = LogSoftmax(1L)
2727
let drp = if dropout then Dropout() |> M else Model.nop
2828

2929
fwd3 gc1 gc2 drp (fun t g1 g2 drp ->
3030
use t = gc1.forward(t)
31-
use t = relu.forward(t)
31+
use t = Functions.ReLU(t)
3232
use t = drp.forward(t)
3333
use t = gc2.forward(t)
34-
let t = logm.forward(t)
34+
let t = Functions.LogSoftmax(t, dimension=1L)
3535
t)

GCN/Program.fs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ module Defs =
4343
let parse args =
4444
let parser = ArgumentParser.Create<Args>(programName = "gcn.exe")
4545
let args = parser.Parse(args)
46-
let datafolder = args.GetResult (Args.Datafolder, defaultValue = @"C:\Users\fwaris\Downloads\pygcn-master\data\cora")
46+
let datafolder = args.GetResult (Args.Datafolder, defaultValue = @"C:\s\Repos\gcn\data\cora")
4747
let no_cuda = args.GetResult (Args.No_CUDA, defaultValue=no_cuda)
4848
let fastmode = args.GetResult (Args.Fastmode, defaultValue=fastmode)
4949
let epochs = args.GetResult (Args.Epochs, defaultValue=epochs)
@@ -55,7 +55,11 @@ module Defs =
5555
datafolder,no_cuda,fastmode,epochs,dropout,lr,hidden,seed,weight_decay
5656

5757
[<EntryPoint>]
58-
let main args =
59-
let runParms = Defs.parse args
58+
let main args =
59+
let runParms = Defs.parse args
60+
try
6061
Train.run runParms
61-
0
62+
with ex ->
63+
printfn "%s" ex.Message
64+
System.Console.ReadLine() |> ignore
65+
0

GCN/TorchSharp.Fun.fs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,18 @@ let inline (=>>) m1 (n,m2) = compose (M m1) (Some n, M m2)
164164
module Tensor =
165165
let private _getData<'t> (t:TorchTensor) =
166166
let s = t.Data<'t>()
167-
s.ToArray()
167+
let xs = Array.zeroCreate s.Length
168+
for i in 0 .. s.Length-1 do
169+
xs.[i] <- s.[i]
170+
171+
//s.ToArray()
172+
xs
168173

169174
let getData<'t> (t:TorchTensor) =
170175
if t.device_type <> DeviceType.CPU then
171-
use t = t.clone()
172-
use t = t.cpu()
173-
_getData<'t> t
176+
//use t1 = t.clone()
177+
use t2 = t.cpu()
178+
_getData<'t> t2
174179
else
175180
_getData<'t> t
176181

GCN/Train.fs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,10 @@ let run (datafolder,no_cuda,fastmode,epochs,dropout,lr,hidden,seed,weight_decay)
5959

6060
let t_total = DateTime.Now
6161
for i in 1 .. epochs-1 do
62+
printfn $"epoch {i}"
6263
train i
6364
printfn "Optimization done"
6465
printfn $"Time elapsed: {(DateTime.Now - t_total).TotalMinutes} minutes"
6566

67+
test()
68+

0 commit comments

Comments
 (0)