Skip to content

Commit 66bd6b1

Browse files
authored
Merge pull request #97 from SymbolicML/fix-map-part-2
More improvements to `map` and `similar` on `QuantityArray`
2 parents 8745a58 + 2c64dc4 commit 66bd6b1

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

src/arrays.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,9 @@ end
196196

197197
Base.similar(A::QuantityArray) = QuantityArray(similar(ustrip(A)), dimension(A), quantity_type(A))
198198
Base.similar(A::QuantityArray, ::Type{S}) where {S} = QuantityArray(similar(ustrip(A), S), dimension(A), quantity_type(A))
199+
for (type, _, _) in ABSTRACT_QUANTITY_TYPES
200+
@eval Base.similar(A::QuantityArray, ::Type{S}) where {S<:$type} = QuantityArray(similar(ustrip(A), value_type(S)), dimension(A), S)
201+
end
199202

200203
# Unfortunately this mess of `similar` is required to avoid ambiguous methods.
201204
# c.f. base/abstractarray.jl
@@ -211,8 +214,10 @@ end
211214

212215
# `_similar_for` in Base does not account for changed dimensions, so
213216
# we need to overload it for QuantityArray.
214-
Base._similar_for(c::QuantityArray, ::Type{T}, itr, ::Base.HasShape, axs) where {T} =
217+
Base._similar_for(c::QuantityArray, ::Type{T}, itr, ::Base.HasShape, axs) where {T<:UnionAbstractQuantity} =
215218
QuantityArray(similar(ustrip(c), value_type(T), axs), dimension(materialize_first(itr))::dim_type(T), T)
219+
Base._similar_for(c::QuantityArray, ::Type{T}, itr, ::Base.HasShape, axs) where {T} =
220+
similar(ustrip(c), T, axs)
216221

217222
# These methods are not yet implemented, but the default implementation is dangerous,
218223
# as it may cause a stack overflow, so we raise a more helpful error instead.
@@ -223,8 +228,10 @@ Base._similar_for(::QuantityArray, ::Type{T}, _, ::Base.HasLength, ::Integer) wh
223228

224229
# In earlier Julia, `Base._similar_for` has different signatures.
225230
@static if hasmethod(Base._similar_for, Tuple{Array,Type,Any,Base.HasShape})
226-
@eval Base._similar_for(c::QuantityArray, ::Type{T}, itr, ::Base.HasShape) where {T} =
231+
@eval Base._similar_for(c::QuantityArray, ::Type{T}, itr, ::Base.HasShape) where {T<:UnionAbstractQuantity} =
227232
QuantityArray(similar(ustrip(c), value_type(T), axes(itr)), dimension(materialize_first(itr))::dim_type(T), T)
233+
@eval Base._similar_for(c::QuantityArray, ::Type{T}, itr, ::Base.HasShape) where {T} =
234+
similar(ustrip(c), T, axes(itr))
228235
end
229236
@static if hasmethod(Base._similar_for, Tuple{Array,Type,Any,Base.HasLength})
230237
@eval Base._similar_for(::QuantityArray, ::Type{T}, _, ::Base.HasLength) where {T} =

test/unittests.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,14 @@ end
983983
@test dimension(new_qa) == dimension(qa)
984984
@test isa(ustrip(new_qa), Array{Float32,2})
985985

986+
if Q !== GenericQuantity
987+
new_qa = similar(qa, typeof(GenericQuantity{Float16}(u"km/s")))
988+
@test eltype(new_qa) <: GenericQuantity{Float16}
989+
@test dim_type(new_qa) == dim_type(qa)
990+
@test dimension(new_qa) == dimension(qa)
991+
@test isa(ustrip(new_qa), Array{Float16,2})
992+
end
993+
986994
new_qa = similar(qa, axes(ones(6, 8)))
987995
@test size(new_qa) == (6, 8)
988996
@test eltype(new_qa) <: Q{Float64}
@@ -1104,6 +1112,10 @@ end
11041112
@test prod(qa) == 8.0u"m^3"
11051113
@inferred prod(qa)
11061114

1115+
# Map to non-quantity output:
1116+
@test map(x -> ustrip(x), qa) == fill(2.0, 3)
1117+
@test map(x -> cos(x/dimension(x)), qa) == fill(cos(2.0), 3)
1118+
11071119
# Test that we can use a function that returns a different type
11081120
if Q === RealQuantity
11091121
qa = fill(RealQuantity(2.0u"m"), 3)
@@ -1237,7 +1249,7 @@ end
12371249
# There is no easy way to test whether it actually ran,
12381250
# so we create a fake array type that has a custom `sizehint!`
12391251
# which tells us it actually ran.
1240-
@eval begin
1252+
isdefined(Main, :MyCustomArray) || @eval begin
12411253
mutable struct MyCustomArray{T,N} <: AbstractArray{T,N}
12421254
data::Array{T,N}
12431255
sizehint_called::Bool

0 commit comments

Comments
 (0)