Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 736666d

Browse files
Adam Procterdiyessi
authored andcommitted
Fix sum reference to handle corner cases with +-inf (#3412)
* Fix sum reference to handle corner cases with +-inf * Review comments, and try to make Windows happy * Update GPU unit_test.manifest * More template grindery, to make macOS happy
1 parent 5fb9fd1 commit 736666d

File tree

3 files changed

+79
-8
lines changed

3 files changed

+79
-8
lines changed

src/ngraph/runtime/gpu/unit_test.manifest

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ all_2x2x3_eliminate_dims_1_2
100100
all_2x2x3_eliminate_dims_0_1_2
101101
all_dynamic
102102

103+
# Corner-case tests for sum with infs and -infs.
104+
sum_inf
105+
103106
# GPU backend uses floats to implement these ops for int32
104107
floor_int32
105108
divide_int32

src/ngraph/runtime/reference/sum.hpp

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,29 @@ namespace ngraph
2727
{
2828
namespace reference
2929
{
30+
// Windows doesn't seem to like it if we directly use std::isfinite on integer types,
31+
// so we will roll our own thing here.
32+
template <typename T>
33+
typename std::enable_if<std::is_floating_point<T>::value, bool>::type is_finite(T x)
34+
{
35+
return std::isfinite(x);
36+
}
37+
38+
template <typename T>
39+
typename std::enable_if<std::is_same<T, bfloat16>::value ||
40+
std::is_same<T, float16>::value,
41+
bool>::type
42+
is_finite(T x)
43+
{
44+
return std::isfinite(static_cast<float>(x));
45+
}
46+
47+
template <typename T>
48+
typename std::enable_if<std::is_integral<T>::value, bool>::type is_finite(T x)
49+
{
50+
return true;
51+
}
52+
3053
template <typename T>
3154
void sum(const T* arg,
3255
T* out,
@@ -35,25 +58,34 @@ namespace ngraph
3558
const AxisSet& reduction_axes)
3659
{
3760
CoordinateTransform output_transform(out_shape);
38-
std::vector<T> c(shape_size(out_shape));
61+
std::vector<T> cs(shape_size(out_shape));
3962

4063
for (const Coordinate& output_coord : output_transform)
4164
{
4265
out[output_transform.index(output_coord)] = 0;
43-
c[output_transform.index(output_coord)] = 0;
66+
cs[output_transform.index(output_coord)] = 0;
4467
}
4568

4669
CoordinateTransform input_transform(in_shape);
4770

4871
for (const Coordinate& input_coord : input_transform)
4972
{
5073
Coordinate output_coord = reduce(input_coord, reduction_axes);
51-
T y = arg[input_transform.index(input_coord)] -
52-
c[output_transform.index(output_coord)];
53-
T t = out[output_transform.index(output_coord)] + y;
54-
c[output_transform.index(output_coord)] =
55-
(t - out[output_transform.index(output_coord)]) - y;
56-
out[output_transform.index(output_coord)] = t;
74+
75+
T x = arg[input_transform.index(input_coord)];
76+
T& z = out[output_transform.index(output_coord)];
77+
78+
if (is_finite(x) && is_finite(z))
79+
{
80+
T& c = cs[output_transform.index(output_coord)];
81+
T t = z + (x - c);
82+
c = (t - z) - (x - c);
83+
z = t;
84+
}
85+
else
86+
{
87+
z = z + x;
88+
}
5789
}
5890
}
5991
}

test/backend/sum.in.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,3 +740,39 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_dynamic)
740740
ASSERT_TRUE(test::all_close_f(results, expected_results[i], MIN_FLOAT_TOLERANCE_BITS));
741741
}
742742
}
743+
744+
NGRAPH_TEST(${BACKEND_NAME}, sum_inf)
745+
{
746+
Shape shape{7, 4};
747+
auto A = make_shared<op::Parameter>(element::f32, shape);
748+
auto f = make_shared<Function>(make_shared<op::Sum>(A, AxisSet{1}), ParameterVector{A});
749+
750+
auto infi = std::numeric_limits<float>::infinity();
751+
752+
auto backend = runtime::Backend::create("${BACKEND_NAME}");
753+
754+
// Create some tensors for input/output
755+
auto a = backend->create_tensor(element::f32, shape);
756+
copy_data(a,
757+
test::NDArray<float, 2>({{-infi, 0, 0, infi},
758+
{infi, 100, -100, -infi},
759+
{infi, 0, 100, infi},
760+
{-infi, -100, 0, -infi},
761+
{infi, infi, infi, infi},
762+
{infi, infi, infi, -infi},
763+
{infi, std::nanf(""), 42, infi}})
764+
.get_vector());
765+
auto result = backend->create_tensor(element::f32, Shape{7});
766+
767+
auto handle = backend->compile(f);
768+
handle->call_with_validate({result}, {a});
769+
auto r = read_vector<float>(result);
770+
ASSERT_EQ(r.size(), 7);
771+
EXPECT_TRUE(isnan(r[0]));
772+
EXPECT_TRUE(isnan(r[1]));
773+
EXPECT_TRUE(r[2] > 0 && isinf(r[2]));
774+
EXPECT_TRUE(r[3] < 0 && isinf(r[3]));
775+
EXPECT_TRUE(r[4] > 0 && isinf(r[4]));
776+
EXPECT_TRUE(isnan(r[5]));
777+
EXPECT_TRUE(isnan(r[6]));
778+
}

0 commit comments

Comments
 (0)