Skip to content

Commit 3ce473d

Browse files
committed
implemented static map, reduce, etc, fixed #6
1 parent 8f60f51 commit 3ce473d

16 files changed

+1369
-907
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AbstractTensors"
22
uuid = "a8e43f4a-99b7-5565-8bf1-0165161caaea"
33
authors = ["Michael Reed"]
4-
version = "0.6.4"
4+
version = "0.6.5"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/FixedVector.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
2+
struct FixedVector{N,T} <: TupleVector{N,T}
3+
v::Vector{T}
4+
function FixedVector{N,T}(a::Vector) where {N,T}
5+
if length(a) != N
6+
throw(DimensionMismatch("Dimensions $(size(a)) don't match static size $S"))
7+
end
8+
new{N,T}(a)
9+
end
10+
function FixedVector{N,T}(::UndefInitializer) where {N,T}
11+
new{N,T}(Vector{T}(undef,N))
12+
end
13+
end
14+
15+
@inline FixedVector{N}(a::Vector{T}) where {N,T} = FixedVector{N,T}(a)
16+
17+
@generated function FixedVector{N,T}(x::NTuple{N,Any}) where {N,T}
18+
exprs = [:(a[$i] = x[$i]) for i = 1:N]
19+
return quote
20+
$(Expr(:meta, :inline))
21+
a = FixedVector{N,T}(undef)
22+
@inbounds $(Expr(:block, exprs...))
23+
return a
24+
end
25+
end
26+
27+
@inline FixedVector{N,T}(x::Tuple) where {N,T} = FixedVector{N,T}(x)
28+
@inline FixedVector{N}(x::NTuple{N,T}) where {N,T} = FixedVector{N,T}(x)
29+
30+
# Overide some problematic default behaviour
31+
@inline Base.convert(::Type{SA}, sa::FixedVector) where {SA<:FixedVector} = SA(sa.v)
32+
@inline Base.convert(::Type{SA}, sa::SA) where {SA<:FixedVector} = sa
33+
34+
# Back to Array (unfortunately need both convert and construct to overide other methods)
35+
@inline Base.Array(sa::FixedVector) = Vector(sa.v)
36+
@inline Base.Array{T}(sa::FixedVector{N,T}) where {N,T} = Vector{T}(sa.v)
37+
@inline Base.Array{T,1}(sa::FixedVector{N,T}) where {N,T} = Vector{T}(sa.v)
38+
39+
@inline Base.convert(::Type{Array}, sa::FixedVector) = sa.v
40+
@inline Base.convert(::Type{Array{T}}, sa::FixedVector{N,T}) where {N,T} = sa.v
41+
@inline Base.convert(::Type{Array{T,1}}, sa::FixedVector{N,T}) where {N,T} = sa.v
42+
43+
@propagate_inbounds Base.getindex(a::FixedVector, i::Int) = getindex(a.v, i)
44+
@propagate_inbounds Base.setindex!(a::FixedVector, v, i::Int) = setindex!(a.v, v, i)
45+
46+
Base.dataids(sa::FixedVector) = Base.dataids(sa.v)
47+
48+
function Base.promote_rule(::Type{<:FixedVector{N,T}}, ::Type{<:FixedVector{N,U}}) where {N,T,U}
49+
FixedVector{N,promote_type(T,U)}
50+
end

src/SOneTo.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
2+
struct SOneTo{n} <: AbstractUnitRange{Int} end
3+
4+
@pure SOneTo(n::Int) = SOneTo{n}()
5+
function SOneTo{n}(r::AbstractUnitRange) where n
6+
((first(r) == 1) & (last(r) == n)) && return SOneTo{n}()
7+
8+
errmsg(r) = throw(DimensionMismatch("$r is inconsistent with SOneTo{$n}")) # avoid GC frame
9+
errmsg(r)
10+
end
11+
Base.Tuple(::SOneTo{N}) where N = ntuple(identity, Val(N))
12+
13+
@pure Base.axes(s::SOneTo) = (s,)
14+
@pure Base.size(s::SOneTo{n}) where n = (n,)
15+
@pure Base.length(s::SOneTo{n}) where n = n
16+
17+
# The axes of a Slice'd SOneTo use the SOneTo itself
18+
Base.axes(S::Base.Slice{<:SOneTo}) = (S.indices,)
19+
Base.unsafe_indices(S::Base.Slice{<:SOneTo}) = (S.indices,)
20+
Base.axes1(S::Base.Slice{<:SOneTo}) = S.indices
21+
22+
@propagate_inbounds function Base.getindex(s::SOneTo, i::Int)
23+
@boundscheck checkbounds(s, i)
24+
return i
25+
end
26+
@propagate_inbounds function Base.getindex(s::SOneTo, s2::SOneTo)
27+
@boundscheck checkbounds(s, s2)
28+
return s2
29+
end
30+
31+
@pure Base.first(::SOneTo) = 1
32+
@pure Base.last(::SOneTo{n}) where n = n::Int
33+
@pure Base.iterate(::SOneTo{n}) where n = n::Int < 1 ? nothing : (1, 1)
34+
@pure function Base.iterate(::SOneTo{n}, s::Int) where {n}
35+
if s < n::Int
36+
s2 = s + 1
37+
return (s2, s2)
38+
else
39+
return nothing
40+
end
41+
end
42+
43+
function Base.getproperty(::SOneTo{n}, s::Symbol) where {n}
44+
if s === :start
45+
return 1
46+
elseif s === :stop
47+
return n::Int
48+
else
49+
error("type SOneTo has no property $s")
50+
end
51+
end
52+
53+
Base.show(io::IO, ::SOneTo{n}) where {n} = print(io, "SOneTo(", n::Int, ")")
54+
Base.@pure function Base.checkindex(::Type{Bool}, ::SOneTo{n1}, ::SOneTo{n2}) where {n1, n2}
55+
return n1::Int >= n2::Int
56+
end
57+
58+
Base.promote_rule(a::Type{Base.OneTo{T}}, ::Type{SOneTo{n}}) where {T,n} =
59+
Base.OneTo{promote_type(T, Int)}

src/Values.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
2+
# SArray.jl
3+
4+
struct Values{N,T} <: TupleVector{N,T}
5+
v::NTuple{N,T}
6+
Values{N,T}(x::NTuple{N,T}) where {N,T} = new{N,T}(x)
7+
Values{N,T}(x::NTuple{N,Any}) where {N,T} = new{N,T}(convert_ntuple(T, x))
8+
end
9+
10+
@pure @generated function (::Type{Values{N,T}})(x::Tuple) where {T, N}
11+
return quote
12+
@_inline_meta
13+
Values{N,T}(x)
14+
end
15+
end
16+
17+
@inline Values(a::TupleVector{N}) where N = Values{N}(Tuple(a))
18+
@propagate_inbounds Base.getindex(v::Values, i::Int) = v.v[i]
19+
@inline Tuple(v::Values) = v.v
20+
Base.dataids(::Values) = ()
21+
22+
# See #53
23+
Base.cconvert(::Type{Ptr{T}}, a::Values) where {T} = Base.RefValue(a)
24+
Base.unsafe_convert(::Type{Ptr{T}}, a::Base.RefValue{SA}) where {N,T,SA<:Values{N,T}} = Ptr{T}(Base.unsafe_convert(Ptr{Values{N,T}}, a))
25+
26+
# SVector.jl
27+
28+
@inline Values(x::NTuple{N,Any}) where N = Values{N}(x)
29+
@inline Values{N}(x::NTuple{N,T}) where {N,T} = Values{N,T}(x)
30+
@inline Values{N}(x::T) where {N,T<:Tuple} = Values{N,promote_tuple_eltype(T)}(x)
31+
32+
# Some more advanced constructor-like functions
33+
@pure @inline Base.zeros(::Type{Values{N}}) where N = zeros(Values{N,Float64})
34+
@pure @inline Base.ones(::Type{Values{N}}) where N = ones(Values{N,Float64})
35+
36+
# Converting a CartesianIndex to an SVector
37+
Base.convert(::Type{Values}, I::CartesianIndex) = Values(I.I)
38+
Base.convert(::Type{Values{N}}, I::CartesianIndex{N}) where {N} = Values{N}(I.I)
39+
Base.convert(::Type{Values{N,T}}, I::CartesianIndex{N}) where {N,T} = Values{N,T}(I.I)
40+
41+
@pure Base.promote_rule(::Type{Values{N,T}}, ::Type{CartesianIndex{N}}) where {N,T} = Values{N,promote_type(T,Int)}
42+

src/Variables.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
2+
# MArray.jl
3+
4+
mutable struct Variables{N,T} <: TupleVector{N,T}
5+
v::NTuple{N,T}
6+
Variables{N,T}(x::NTuple{N,T}) where {N,T} = new{N,T}(x)
7+
Variables{N,T}(x::NTuple{N,Any}) where {N,T} = new{N,T}(convert_ntuple(T, x))
8+
Variables{N,T}(::UndefInitializer) where {N,T} = new{N,T}()
9+
end
10+
11+
@inline Variables(a::TupleVector{N}) where N = Variables{N}(Tuple(a))
12+
@generated function (::Type{Variables{N,T}})(x::Tuple) where {N,T}
13+
return quote
14+
$(Expr(:meta, :inline))
15+
Variables{N,T}(x)
16+
end
17+
end
18+
@generated function (::Type{Variables{N}})(x::T) where {N,T<:Tuple}
19+
return quote
20+
$(Expr(:meta, :inline))
21+
Variables{N,promote_tuple_eltype(T)}(x)
22+
end
23+
end
24+
25+
@propagate_inbounds function Base.getindex(v::Variables, i::Int)
26+
@boundscheck checkbounds(v,i)
27+
T = eltype(v)
28+
if isbitstype(T)
29+
return GC.@preserve v unsafe_load(Base.unsafe_convert(Ptr{T}, pointer_from_objref(v)), i)
30+
end
31+
v.v[i]
32+
end
33+
@propagate_inbounds function Base.setindex!(v::Variables, val, i::Int)
34+
@boundscheck checkbounds(v,i)
35+
T = eltype(v)
36+
if isbitstype(T)
37+
GC.@preserve v unsafe_store!(Base.unsafe_convert(Ptr{T}, pointer_from_objref(v)), convert(T, val), i)
38+
else
39+
# This one is unsafe (#27)
40+
# unsafe_store!(Base.unsafe_convert(Ptr{Ptr{Nothing}}, pointer_from_objref(v.data)), pointer_from_objref(val), i)
41+
error("setindex!() with non-isbitstype eltype is not supported by TupleVectors. Consider using FixedVector.")
42+
end
43+
return val
44+
end
45+
46+
@inline Base.Tuple(v::Variables) = v.v
47+
Base.dataids(ma::Variables) = (UInt(pointer(ma)),)
48+
49+
@inline function Base.unsafe_convert(::Type{Ptr{T}}, a::Variables{N,T}) where {N,T}
50+
Base.unsafe_convert(Ptr{T}, pointer_from_objref(a))
51+
end
52+
53+
function Base.promote_rule(::Type{<:Variables{N,T}}, ::Type{<:Variables{N,U}}) where {N,T,U}
54+
Variables{N,promote_type(T,U)}
55+
end
56+
57+
# MVector.jl
58+
59+
@inline Variables(x::NTuple{N,Any}) where N = Variables{N}(x)
60+
@inline Variables{N}(x::NTuple{N,T}) where {N,T} = Variables{N,T}(x)
61+
@inline Variables{N}(x::NTuple{N,Any}) where N = Variables{N, promote_tuple_eltype(typeof(x))}(x)
62+
63+
# Some more advanced constructor-like functions
64+
@inline Base.zeros(::Type{Variables{N}}) where N = zeros(Variables{N,Float64})
65+
@inline Base.ones(::Type{Variables{N}}) where N = ones(Variables{N,Float64})

src/abstractvector.jl

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
2+
Base.axes(::TupleVector{N}) where N = _axes(Val(N))
3+
@pure function _axes(::Val{sizes}) where {sizes}
4+
map(SOneTo, (sizes,))
5+
end
6+
Base.axes(rv::LinearAlgebra.Adjoint{<:Any,<:Values}) = (SOneTo(1), axes(rv.parent)...)
7+
Base.axes(rv::LinearAlgebra.Transpose{<:Any,<:Values}) = (SOneTo(1), axes(rv.parent)...)
8+
9+
# Base.strides is intentionally not defined for SArray, see PR #658 for discussion
10+
Base.strides(a::Variables) = Base.size_to_strides(1, size(a)...)
11+
Base.strides(a::FixedVector) = strides(a.v)
12+
13+
Base.IndexStyle(::Type{T}) where {T<:TupleVector} = IndexLinear()
14+
15+
similar_type(::SA) where {SA<:TupleVector} = similar_type(SA,eltype(SA))
16+
similar_type(::Type{SA}) where {SA<:TupleVector} = similar_type(SA,eltype(SA))
17+
18+
similar_type(::SA,::Type{T}) where {SA<:TupleVector{N},T} where N = similar_type(SA,T,Val(N))
19+
similar_type(::Type{SA},::Type{T}) where {SA<:TupleVector{N},T} where N = similar_type(SA,T,Val(N))
20+
21+
similar_type(::A,n::Val) where {A<:AbstractArray} = similar_type(A,eltype(A),n)
22+
similar_type(::Type{A},n::Val) where {A<:AbstractArray} = similar_type(A,eltype(A),n)
23+
24+
similar_type(::A,::Type{T},n::Val) where {A<:AbstractArray,T} = similar_type(A,T,n)
25+
26+
# We should be able to deal with SOneTo axes
27+
similar_type(s::SOneTo) = similar_type(typeof(s))
28+
similar_type(::Type{SOneTo{n}}) where n = similar_type(SOneTo{n}, Int, Val(n))
29+
30+
# Default types
31+
# Generally, use SArray
32+
similar_type(::Type{A},::Type{T},n::Val) where {A<:AbstractArray,T} = default_similar_type(T,n)
33+
default_similar_type(::Type{T},::Val{N}) where {N,T} = Values{N,T}
34+
35+
similar_type(::Type{SA},::Type{T},n::Val) where {SA<:Variables,T} = mutable_similar_type(T,n)
36+
37+
mutable_similar_type(::Type{T},::Val{N}) where {N,T} = Variables{N,T}
38+
39+
similar_type(::Type{<:FixedVector},::Type{T},n::Val) where T = sizedarray_similar_type(T,n)
40+
# Should FixedVector also be used for normal Array?
41+
#similar_type(::Type{<:Array},::Type{T},n::Val) where T = sizedarray_similar_type(T,n)
42+
43+
sizedarray_similar_type(::Type{T},::Val{N}) where {N,T} = FixedVector{N,T}
44+
45+
Base.similar(::SA) where {SA<:TupleVector} = similar(SA,eltype(SA))
46+
Base.similar(::Type{SA}) where {SA<:TupleVector} = similar(SA,eltype(SA))
47+
48+
Base.similar(::SA,::Type{T}) where {SA<:TupleVector{N},T} where N = similar(SA,T,Val(N))
49+
Base.similar(::Type{SA},::Type{T}) where {SA<:TupleVector{N},T} where N = similar(SA,T,Val(N))
50+
51+
# Cases where a Val is given as the dimensions
52+
Base.similar(::A,n::Val) where A<:AbstractArray = similar(A,eltype(A),n)
53+
Base.similar(::Type{A},n::Val) where A<:AbstractArray = similar(A,eltype(A),n)
54+
55+
Base.similar(::A,::Type{T},n::Val) where {A<:AbstractArray,T} = similar(A,T,n)
56+
57+
# defaults to built-in mutable types
58+
Base.similar(::Type{A},::Type{T},n::Val) where {A<:AbstractArray,T} = mutable_similar_type(T,n)(undef)
59+
60+
# both FixedVector and Array return FixedVector
61+
Base.similar(::Type{SA},::Type{T},n::Val) where {SA<:FixedVector,T} = sizedarray_similar_type(T,n)(undef)
62+
Base.similar(::Type{A},::Type{T},n::Val) where {A<:Array,T} = sizedarray_similar_type(T,n)(undef)
63+
64+
# Support tuples of mixtures of `SOneTo`s alongside the normal `Integer` and `OneTo` options
65+
# by simply converting them to either a tuple of Ints or a Val, re-dispatching to either one
66+
# of the above methods (in the case of Val) or a base fallback (in the case of Ints).
67+
const HeterogeneousShape = Union{Integer, Base.OneTo, SOneTo}
68+
69+
Base.similar(A::AbstractArray, ::Type{T}, shape::Tuple{HeterogeneousShape, Vararg{HeterogeneousShape}}) where {T} = similar(A, T, homogenize_shape(shape))
70+
Base.similar(::Type{A}, shape::Tuple{HeterogeneousShape, Vararg{HeterogeneousShape}}) where {A<:AbstractArray} = similar(A, homogenize_shape(shape))
71+
# Use an Array for TupleVectors if we don't have a statically-known size
72+
Base.similar(::Type{A}, shape::Tuple{Int, Vararg{Int}}) where {A<:TupleVector} = Array{eltype(A)}(undef, shape)
73+
74+
homogenize_shape(::Tuple{}) = ()
75+
homogenize_shape(shape::Tuple{Vararg{SOneTo}}) = Val(prod(map(last, shape)))
76+
homogenize_shape(shape::Tuple{Vararg{HeterogeneousShape}}) = map(last, shape)
77+
78+
79+
@inline Base.copy(a::TupleVector) = typeof(a)(Tuple(a))
80+
@inline Base.copy(a::FixedVector) = typeof(a)(copy(a.v))
81+
82+
@inline Base.reverse(v::Values) = typeof(v)(_reverse(v))
83+
84+
@generated function _reverse(v::Values{N,T}) where {N,T}
85+
return Expr(:tuple, (:(v[$i]) for i = N:(-1):1)...)
86+
end
87+
88+
#--------------------------------------------------
89+
# Concatenation
90+
@inline Base.vcat(a::TupleVectorLike) = a
91+
@inline Base.vcat(a::TupleVectorLike{N}, b::TupleVectorLike{M}) where {N,M} = _vcat(Val(N), Val(M), a, b)
92+
@inline Base.vcat(a::TupleVectorLike, b::TupleVectorLike, c::TupleVectorLike...) = vcat(vcat(a,b), vcat(c...))
93+
94+
@generated function _vcat(::Val{Sa}, ::Val{Sb}, a::TupleVectorLike, b::TupleVectorLike) where {Sa, Sb}
95+
96+
# TODO cleanup?
97+
Snew = Sa + Sb
98+
exprs = vcat([:(a[$i]) for i = 1:Sa],
99+
[:(b[$i]) for i = 1:Sb])
100+
return quote
101+
@_inline_meta
102+
@inbounds return similar_type(a, promote_type(eltype(a), eltype(b)), Val($Snew))(tuple($(exprs...)))
103+
end
104+
end
105+
106+
#=@inline hcat(a::StaticVector) = similar_type(a, Size(Size(a)[1],1))(a)
107+
@inline hcat(a::StaticMatrixLike) = a
108+
@inline hcat(a::StaticVecOrMatLike, b::StaticVecOrMatLike) = _hcat(Size(a), Size(b), a, b)
109+
@inline hcat(a::StaticVecOrMatLike, b::StaticVecOrMatLike, c::StaticVecOrMatLike...) = hcat(hcat(a,b), hcat(c...))
110+
111+
@generated function _hcat(::Size{Sa}, ::Size{Sb}, a::StaticVecOrMatLike, b::StaticVecOrMatLike) where {Sa, Sb}
112+
if Sa[1] != Sb[1]
113+
return :(throw(DimensionMismatch("Tried to hcat arrays of size $Sa and $Sb")))
114+
end
115+
116+
exprs = vcat([:(a[$i]) for i = 1:prod(Sa)],
117+
[:(b[$i]) for i = 1:prod(Sb)])
118+
119+
Snew = (Sa[1], Size(Sa)[2] + Size(Sb)[2])
120+
121+
return quote
122+
@_inline_meta
123+
@inbounds return similar_type(a, promote_type(eltype(a), eltype(b)), Size($Snew))(tuple($(exprs...)))
124+
end
125+
end=#

0 commit comments

Comments
 (0)