@@ -356,14 +356,43 @@ struct Dot{StyleA,StyleB,ATyp,BTyp}
356
356
B:: BTyp
357
357
end
358
358
359
- @inline Dot (A:: ATyp ,B:: BTyp ) where {ATyp,BTyp} = Dot {typeof(MemoryLayout(ATyp)), typeof(MemoryLayout(BTyp)), ATyp, BTyp} (A, B)
359
+ """
360
+ Dotu(A, B)
361
+
362
+ is a lazy version of `BLAS.dotu(A, B)`, designed to support
363
+ materializing based on `MemoryLayout`.
364
+ """
365
+ struct Dotu{StyleA,StyleB,ATyp,BTyp}
366
+ A:: ATyp
367
+ B:: BTyp
368
+ end
369
+
370
+
371
+
372
+ for Dt in (:Dot , :Dotu )
373
+ @eval begin
374
+ @inline $ Dt (A:: ATyp ,B:: BTyp ) where {ATyp,BTyp} = $ Dt {typeof(MemoryLayout(ATyp)), typeof(MemoryLayout(BTyp)), ATyp, BTyp} (A, B)
375
+ @inline materialize (d:: $Dt ) = copy (instantiate (d))
376
+ @inline eltype (D:: $Dt ) = promote_type (eltype (D. A), eltype (D. B))
377
+ end
378
+ end
360
379
@inline copy (d:: Dot ) = invoke (LinearAlgebra. dot, Tuple{AbstractArray,AbstractArray}, d. A, d. B)
361
- @inline materialize (d:: Dot ) = copy (instantiate (d))
380
+ @inline copy (d:: Dotu{<:AbstractStridedLayout,<:AbstractStridedLayout,<:AbstractVector{T},<:AbstractVector{T}} ) where T <: BlasComplex = BLAS. dotu (d. A, d. B)
381
+ @inline copy (d:: Dotu{<:AbstractStridedLayout,<:AbstractStridedLayout,<:AbstractVector{T},<:AbstractVector{T}} ) where T <: BlasReal = BLAS. dot (d. A, d. B)
382
+ @inline copy (d:: Dotu ) = LinearAlgebra. _dot_nonrecursive (d. A, d. B)
383
+
362
384
@inline Dot (M:: Mul{<:DualLayout,<:Any,<:AbstractMatrix,<:AbstractVector} ) = Dot (M. A' , M. B)
363
- @inline mulreduce (M:: Mul{<:DualLayout,<:Any,<:AbstractMatrix,<:AbstractVector} ) = Dot (M)
364
- @inline eltype (D:: Dot ) = promote_type (eltype (D. A), eltype (D. B))
385
+ @inline Dotu (M:: Mul{<:DualLayout,<:Any,<:AbstractMatrix,<:AbstractVector} ) = Dotu (transpose (M. A), M. B)
386
+
387
+ @inline _dot_or_dotu (:: typeof (transpose), :: Type{<:Complex} , M) = Dotu (M)
388
+ @inline _dot_or_dotu (_, _, M) = Dot (M)
389
+ @inline mulreduce (M:: Mul{<:DualLayout,<:Any,<:AbstractMatrix,<:AbstractVector} ) = _dot_or_dotu (dualadjoint (M. A), eltype (M. A), M)
390
+
391
+
365
392
366
393
dot (a, b) = materialize (Dot (a, b))
394
+ dotu (a, b) = materialize (Dotu (a, b))
395
+
367
396
@inline LinearAlgebra. dot (a:: LayoutArray , b:: LayoutArray ) = dot (a,b)
368
397
@inline LinearAlgebra. dot (a:: LayoutArray , b:: AbstractArray ) = dot (a,b)
369
398
@inline LinearAlgebra. dot (a:: AbstractArray , b:: LayoutArray ) = dot (a,b)
0 commit comments