Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/DAT/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,19 @@ function Base.materialize(bc::Broadcast.Broadcasted{XStyle})
args2 = map(arg -> arg isa Broadcast.Broadcasted ? Base.materialize(arg) : arg, bc.args)
args2 = map(to_yax, args2)
# determine output type by calling `eltype` on a dummy function call
dummy_args = map(a -> first(a.data), args2)
outtype = typeof(bc.f(dummy_args...))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

applying f here is for a reason, isn't? e.g. this could determine if the output is just a number or an Array, an f that reduces along a given dimension or not, like cumsum.

not sure...

intypes = (eltype.(args2)...,)
@debug intypes
outtypes = Base.return_types(bc.f, intypes)
outtype = Union{outtypes...}
@debug outtype
return xmap(XFunction(bc.f; inplace=false), args2..., output=XOutput(; outtype))
end
function Base.materialize!(bc::Broadcast.Broadcasted{XStyle})
args2 = map(arg -> arg isa Broadcast.Broadcasted ? Base.materialize(arg) : arg, bc.args)
args2 = map(to_yax, args2)
dummy_args = map(a -> first(a.data), args2)
outtype = typeof(bc.f(dummy_args...))
intypes = (eltype.(args2)...,)
@debug intypes
outtypes = Base.return_types(bc.f, intypes)
outtype = promote_type{outtypes...}
return xmap(XFunction(bc.f; inplace=true), args2..., output=XOutput(; outtype))
end
13 changes: 13 additions & 0 deletions test/DAT/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,17 @@ a, b, c = sample_arrays()
xscalar = a .* 3 .+ 1
@test all(xscalar[:] .== 4.0)
@test isa(a .+ b, YAXArray)
end

@testset "missing handling" begin
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we also test the promote_type? as well as forcing the outtype? just to make sure things work in all cases.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we could also test the promote_type.
How would you force the outtype in a broadcast operation?
What would be the syntax?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree outtype does not exist for broadcast but only for xmap. However, a quick test for something a function like f(x) = rand() > 0.5 ? Float64(1) : Float32(1) and then f.(cube) would be nice to add.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this coming from an actual use case or is this good to have in case?
For Union{Float32, Float64} we run into problems with zero because this is not defined for these general unions. zero is used in DiskArrayEngine.create_userfunction to provide an init value.
The Union{Missing, T} works, because this is special cased for zero and one in Julia Base.

am = YAXArray([missing 1 ; 1 2])
aeq = am .== am
@test eltype(aeq) == Union{Missing, Bool}
@test ismissing(aeq[1,1])
@test aeq[1,2]
aeq2 = similar(aeq)
aeq2 .= am .== am
@test eltype(aeq2) == Union{Missing, Bool}
@test ismissing(aeq2[1,1])
@test aeq2[2,2]
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ include("Datasets/datasets.jl")

include("DAT/PickAxisArray.jl")
include("DAT/MovingWindow.jl")
include("DAT/broadcast.jl")
include("DAT/tablestats.jl")
include("DAT/mapcube.jl")
include("DAT/xmap.jl")
Expand Down
Loading