Skip to content

Commit e5988bb

Browse files
dzelletensorflower-gardener
authored andcommitted
Add structured readout versions of binary and multiclass classification tasks.
PiperOrigin-RevId: 588115018
1 parent 68f1e28 commit e5988bb

File tree

4 files changed

+467
-26
lines changed

4 files changed

+467
-26
lines changed

tensorflow_gnn/api_def/runner-symbols.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ runner.ParameterServerStrategy
2626
runner.PassthruDatasetProvider
2727
runner.PassthruSampleDatasetsProvider
2828
runner.Predictions
29+
runner.NodeBinaryClassification
2930
runner.RootNodeBinaryClassification
3031
runner.RootNodeLabelFn
3132
runner.RootNodeMeanAbsoluteError
@@ -34,6 +35,7 @@ runner.RootNodeMeanAbsolutePercentageError
3435
runner.RootNodeMeanSquaredError
3536
runner.RootNodeMeanSquaredLogScaledError
3637
runner.RootNodeMeanSquaredLogarithmicError
38+
runner.NodeMulticlassClassification
3739
runner.RootNodeMulticlassClassification
3840
runner.RunResult
3941
runner.SampleTFRecordDatasetsProvider

tensorflow_gnn/runner/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,12 @@
9494
# in `distribute_test.py`.)
9595
#
9696
# Tasks (Classification)
97-
RootNodeBinaryClassification = classification.RootNodeBinaryClassification
98-
RootNodeMulticlassClassification = classification.RootNodeMulticlassClassification
9997
GraphBinaryClassification = classification.GraphBinaryClassification
10098
GraphMulticlassClassification = classification.GraphMulticlassClassification
99+
NodeBinaryClassification = classification.NodeBinaryClassification
100+
NodeMulticlassClassification = classification.NodeMulticlassClassification
101+
RootNodeBinaryClassification = classification.RootNodeBinaryClassification
102+
RootNodeMulticlassClassification = classification.RootNodeMulticlassClassification
101103
# Tasks (Link Prediction)
102104
DotProductLinkPrediction = link_prediction.DotProductLinkPrediction
103105
HadamardProductLinkPrediction = link_prediction.HadamardProductLinkPrediction

0 commit comments

Comments
 (0)