Skip to content

Commit 424eff9

Browse files
authored
Add Dotu (#174)
* Add Dotu * Update mul.jl * Update test_muladd.jl
1 parent d37a626 commit 424eff9

File tree

3 files changed

+52
-7
lines changed

3 files changed

+52
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "1.3.1"
4+
version = "1.4"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/mul.jl

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -356,14 +356,43 @@ struct Dot{StyleA,StyleB,ATyp,BTyp}
356356
B::BTyp
357357
end
358358

359-
@inline Dot(A::ATyp,B::BTyp) where {ATyp,BTyp} = Dot{typeof(MemoryLayout(ATyp)), typeof(MemoryLayout(BTyp)), ATyp, BTyp}(A, B)
359+
"""
360+
Dotu(A, B)
361+
362+
is a lazy version of `BLAS.dotu(A, B)`, designed to support
363+
materializing based on `MemoryLayout`.
364+
"""
365+
struct Dotu{StyleA,StyleB,ATyp,BTyp}
366+
A::ATyp
367+
B::BTyp
368+
end
369+
370+
371+
372+
for Dt in (:Dot, :Dotu)
373+
@eval begin
374+
@inline $Dt(A::ATyp,B::BTyp) where {ATyp,BTyp} = $Dt{typeof(MemoryLayout(ATyp)), typeof(MemoryLayout(BTyp)), ATyp, BTyp}(A, B)
375+
@inline materialize(d::$Dt) = copy(instantiate(d))
376+
@inline eltype(D::$Dt) = promote_type(eltype(D.A), eltype(D.B))
377+
end
378+
end
360379
@inline copy(d::Dot) = invoke(LinearAlgebra.dot, Tuple{AbstractArray,AbstractArray}, d.A, d.B)
361-
@inline materialize(d::Dot) = copy(instantiate(d))
380+
@inline copy(d::Dotu{<:AbstractStridedLayout,<:AbstractStridedLayout,<:AbstractVector{T},<:AbstractVector{T}}) where T <: BlasComplex = BLAS.dotu(d.A, d.B)
381+
@inline copy(d::Dotu{<:AbstractStridedLayout,<:AbstractStridedLayout,<:AbstractVector{T},<:AbstractVector{T}}) where T <: BlasReal = BLAS.dot(d.A, d.B)
382+
@inline copy(d::Dotu) = LinearAlgebra._dot_nonrecursive(d.A, d.B)
383+
362384
@inline Dot(M::Mul{<:DualLayout,<:Any,<:AbstractMatrix,<:AbstractVector}) = Dot(M.A', M.B)
363-
@inline mulreduce(M::Mul{<:DualLayout,<:Any,<:AbstractMatrix,<:AbstractVector}) = Dot(M)
364-
@inline eltype(D::Dot) = promote_type(eltype(D.A), eltype(D.B))
385+
@inline Dotu(M::Mul{<:DualLayout,<:Any,<:AbstractMatrix,<:AbstractVector}) = Dotu(transpose(M.A), M.B)
386+
387+
@inline _dot_or_dotu(::typeof(transpose), ::Type{<:Complex}, M) = Dotu(M)
388+
@inline _dot_or_dotu(_, _, M) = Dot(M)
389+
@inline mulreduce(M::Mul{<:DualLayout,<:Any,<:AbstractMatrix,<:AbstractVector}) = _dot_or_dotu(dualadjoint(M.A), eltype(M.A), M)
390+
391+
365392

366393
dot(a, b) = materialize(Dot(a, b))
394+
dotu(a, b) = materialize(Dotu(a, b))
395+
367396
@inline LinearAlgebra.dot(a::LayoutArray, b::LayoutArray) = dot(a,b)
368397
@inline LinearAlgebra.dot(a::LayoutArray, b::AbstractArray) = dot(a,b)
369398
@inline LinearAlgebra.dot(a::AbstractArray, b::LayoutArray) = dot(a,b)

test/test_muladd.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -675,12 +675,28 @@ Random.seed!(0)
675675
@test M[1,1] M[CartesianIndex(1,1)] M[1] (b*b')[1,1]
676676
end
677677

678-
@testset "Dot" begin
678+
@testset "Dot/Dotu" begin
679679
a = randn(5)
680680
b = randn(5)
681-
@test ArrayLayouts.dot(a,b) == mul(a',b)
681+
c = randn(5) + im*randn(5)
682+
d = randn(5) + im*randn(5)
683+
684+
@test ArrayLayouts.dot(a,b) == ArrayLayouts.dotu(a,b) == mul(a',b)
682685
@test ArrayLayouts.dot(a,b) dot(a,b)
683686
@test eltype(Dot(a,1:5)) == Float64
687+
688+
@test ArrayLayouts.dot(c,d) == mul(c',d)
689+
@test ArrayLayouts.dotu(c,d) == mul(transpose(c),d)
690+
@test ArrayLayouts.dot(c,d) dot(c,d)
691+
@test ArrayLayouts.dotu(c,d) BLAS.dotu(c,d)
692+
693+
@test ArrayLayouts.dot(c,b) == mul(c',b)
694+
@test ArrayLayouts.dotu(c,b) == mul(transpose(c),b)
695+
@test ArrayLayouts.dot(c,b) dot(c,b)
696+
697+
@test ArrayLayouts.dot(a,d) == mul(a',d)
698+
@test ArrayLayouts.dotu(a,d) == mul(transpose(a),d)
699+
@test ArrayLayouts.dot(a,d) dot(a,d)
684700
end
685701

686702
@testset "adjtrans muladd" begin

0 commit comments

Comments
 (0)