我正在尝试编写一个结构来计算梯度(遵循https://www.youtube.com/watch?v=rZS2LGiurKY)这是我到目前为止所拥有的:
struct GRAD{F <: Array{Float64,2}, ∇F <:Array{Float64,2}}
f::F
∇f::∇F
end
begin
import Base: +,*,-,^,/,convert,promote_rule,size,reshape,promote
# addition rule
+(x::GRAD,y::GRAD) = GRAD(x.f+y.f,x.∇f+y.∇f)
-(x::GRAD,y::GRAD) = GRAD(x.f-y.f,x.∇f-y.∇f)
# multiplying by scalar
*(y::Real,x::GRAD) = GRAD(x.f.*y,x.∇f.*y)
*(x::GRAD,y::Real) = *(y::Real,x::GRAD)
# product rule
*(x::GRAD,y::GRAD) = GRAD(x.f.*y.f,x.f.*y.∇f+ x.∇f.*y.f)
convert(::Type{GRAD},x::Array) = GRAD(x,zero(x))
size(x::GRAD) = size(x.f)
Base.promote_rule(::Type{GRAD{F,∇F}}, x::Type{<:Array}) = GRAD # bug is here!!
end
A = rand(5,5)
r = rand(5,1)
b = rand(5,1)
g = GRAD(r, zeros(5,1) + [1 for i=1:5])
我想计算A*g
(应该是A*ones()
)的梯度,但是当我这样做时
> A*g
MethodError: no method matching *(::Array{Float64,2}, ::Main.workspace2861.GRAD{Array{Float64,2},Array{Float64,2}})
Closest candidates are:
*(::Any, ::Any, !Matched::Any, !Matched::Any...) at operators.jl:538
*(!Matched::Real, ::Main.workspace2861.GRAD) at /var/folders/2s/p1vy6rx91lsfh9ltgzz6j_lmb6r7gr/T/Unexpected invention.jl#==#c23631c4-0646-11eb-13be-3b5fa3514823:6
*(::Union{StridedArray{T, 2}, LinearAlgebra.Adjoint{var"#s828",var"#s827"} where var"#s827"<:Union{StridedArray{T, 2}, LinearAlgebra.LowerTriangular{T,S} where S<:AbstractArray{T,2}, LinearAlgebra.UnitLowerTriangular{T,S} where S<:AbstractArray{T,2}, LinearAlgebra.UnitUpperTriangular{T,S} where S<:AbstractArray{T,2}, LinearAlgebra.UpperTriangular{T,S} where S<:AbstractArray{T,2}} where var"#s828", LinearAlgebra.LowerTriangular{T,S} where S<:AbstractArray{T,2}, LinearAlgebra.Transpose{var"#s826",var"#s825"} where var"#s825"<:Union{StridedArray{T, 2}, LinearAlgebra.LowerTriangular{T,S} where S<:AbstractArray{T,2}, LinearAlgebra.UnitLowerTriangular{T,S} where S<:AbstractArray{T,2}, LinearAlgebra.UnitUpperTriangular{T,S} where S<:AbstractArray{T,2}, LinearAlgebra.UpperTriangular{T,S} where S<:AbstractArray{T,2}} where var"#s826", LinearAlgebra.UnitLowerTriangular{T,S} where S<:AbstractArray{T,2}, LinearAlgebra.UnitUpperTriangular{T,S} where S<:AbstractArray{T,2}, LinearAlgebra.UpperTriangular{T,S} where S<:AbstractArray{T,2}} where T, !Matched::LinearAlgebra.Adjoint{var"#s828",var"#s827"} where var"#s827"<:SparseArrays.AbstractSparseMatrixCSC where var"#s828") at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/SparseArrays/src/linalg.jl:147
但convert(GRAD, A) * g
我得到了正确的结果。
我究竟做错了什么?