Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.31.1"
version = "0.31.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
6 changes: 6 additions & 0 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@
return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vns)
end

vector_length(vi::ThreadSafeVarInfo) = vector_length(vi.varinfo)
vector_getrange(vi::ThreadSafeVarInfo, vn::VarName) = vector_getrange(vi.varinfo, vn)

Check warning on line 182 in src/threadsafe.jl

View check run for this annotation

Codecov / codecov/patch

src/threadsafe.jl#L181-L182

Added lines #L181 - L182 were not covered by tests
function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName})
return vector_getranges(vi.varinfo, vns)
end

function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler)
return set_retained_vns_del_by_spl!(vi.varinfo, spl)
end
Expand Down
84 changes: 79 additions & 5 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,15 @@
end
VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...)

"""
vector_length(varinfo::VarInfo)

Return the length of the vector representation of `varinfo`.
"""
vector_length(varinfo::VarInfo) = length(varinfo.metadata)
vector_length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata)

Check warning on line 211 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L210-L211

Added lines #L210 - L211 were not covered by tests
vector_length(md::Metadata) = sum(length, md.ranges)

unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x)

# TODO: deprecate.
Expand Down Expand Up @@ -626,7 +635,72 @@
Return the indices of `vns` in the metadata of `vi` corresponding to `vn`.
"""
function getranges(vi::VarInfo, vns::Vector{<:VarName})
return mapreduce(vn -> getrange(vi, vn), vcat, vns; init=Int[])
return map(Base.Fix1(getrange, vi), vns)

Check warning on line 638 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L638

Added line #L638 was not covered by tests
end

"""
vector_getrange(varinfo::VarInfo, varname::VarName)

Return the range corresponding to `varname` in the vector representation of `varinfo`.
"""
vector_getrange(vi::VarInfo, vn::VarName) = getrange(vi.metadata, vn)
function vector_getrange(vi::TypedVarInfo, vn::VarName)
offset = 0
for md in values(vi.metadata)

Check warning on line 649 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L647-L649

Added lines #L647 - L649 were not covered by tests
# First, we need to check if `vn` is in `md`.
# In this case, we can just return the corresponding range + offset.
haskey(md, vn) && return getrange(md, vn) .+ offset

Check warning on line 652 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L652

Added line #L652 was not covered by tests
# Otherwise, we need to get the cumulative length of the ranges in `md`
# and add it to the offset.
offset += sum(length, md.ranges)
end

Check warning on line 656 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L655-L656

Added lines #L655 - L656 were not covered by tests
# If we reach this point, `vn` is not in `vi.metadata`.
throw(KeyError(vn))

Check warning on line 658 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L658

Added line #L658 was not covered by tests
end

"""
vector_getranges(varinfo::VarInfo, varnames::Vector{<:VarName})

Return the range corresponding to `varname` in the vector representation of `varinfo`.
"""
function vector_getranges(varinfo::VarInfo, varname::Vector{<:VarName})
return map(Base.Fix1(vector_getrange, varinfo), varname)
end
# Specialized version for `TypedVarInfo`.
function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName})
# TODO: Does it help if we _don't_ convert to a vector here?
metadatas = collect(values(varinfo.metadata))
# Extract the offsets.
offsets = cumsum(map(vector_length, metadatas))
# Extract the ranges from each metadata.
ranges = Vector{UnitRange{Int}}(undef, length(vns))
# Need to keep track of which ones we've seen.
not_seen = fill(true, length(vns))
for (i, metadata) in enumerate(metadatas)
vns_metadata = filter(Base.Fix1(haskey, metadata), vns)
# If none of the variables exist in the metadata, we return an empty array.
isempty(vns_metadata) && continue
# Otherwise, we extract the ranges.
offset = i == 1 ? 0 : offsets[i - 1]
for vn in vns_metadata
r_vn = getrange(metadata, vn)
# Get the index, so we return in the same order as `vns`.
# NOTE: There might be duplicates in `vns`, so we need to handle that.
indices = findall(==(vn), vns)
for idx in indices
not_seen[idx] = false
ranges[idx] = r_vn .+ offset
end
end
end
# Raise key error if any of the variables were not found.
if any(not_seen)
inds = findall(not_seen)
# Just use a `convert` to get the same type as the input; don't want to confuse by overly
# specilizing the types in the error message.
throw(KeyError(convert(typeof(vns), vns[inds])))
end
return ranges
end

"""
Expand Down Expand Up @@ -1314,13 +1388,13 @@

function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f)
# TODO: Use inplace versions to avoid allocations
yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, vn))
yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(md, vn))
# Determine the new range.
start = first(getrange(vi, vn))
start = first(getrange(md, vn))
# NOTE: `length(yvec)` should never be longer than `getrange(vi, vn)`.
setrange!(vi, vn, start:(start + length(yvec) - 1))
setrange!(md, vn, start:(start + length(yvec) - 1))
# Set the new value.
setval!(vi, yvec, vn)
setval!(md, yvec, vn)
acclogp!!(vi, -logjac)
return vi
end
Expand Down
2 changes: 2 additions & 0 deletions src/varnamedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,8 @@ function replace_raw_storage(vnv::VarNamedVector, ::Val{space}, vals) where {spa
return replace_raw_storage(vnv, vals)
end

vector_length(vnv::VarNamedVector) = length(vnv.vals) - num_inactive(vnv)

"""
unflatten(vnv::VarNamedVector, vals::AbstractVector)

Expand Down
42 changes: 42 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -813,4 +813,46 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
@test DynamicPPL.istrans(varinfo2, vn)
end
end

# NOTE: It is not yet clear if this is something we want from all varinfo types.
# Hence, we only test the `VarInfo` types here.
@testset "vector_getranges for `VarInfo`" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
vns = DynamicPPL.TestUtils.varnames(model)
nt = DynamicPPL.TestUtils.rand_prior_true(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(
model, nt, vns; include_threadsafe=true
)
# Only keep `VarInfo` types.
varinfos = filter(
Base.Fix2(isa, DynamicPPL.VarInfoOrThreadSafeVarInfo), varinfos
)
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
x = values_as(varinfo, Vector)

# Let's just check all the subsets of `vns`.
@testset "$(convert(Vector{Any},vns_subset))" for vns_subset in
combinations(vns)
ranges = DynamicPPL.vector_getranges(varinfo, vns_subset)
@test length(ranges) == length(vns_subset)
for (r, vn) in zip(ranges, vns_subset)
@test x[r] == DynamicPPL.tovec(varinfo[vn])
end
end

# Let's try some failure cases.
@test DynamicPPL.vector_getranges(varinfo, VarName[]) == UnitRange{Int}[]
# Non-existent variables.
@test_throws KeyError DynamicPPL.vector_getranges(
varinfo, [VarName{gensym("vn")}()]
)
@test_throws KeyError DynamicPPL.vector_getranges(
varinfo, [VarName{gensym("vn")}(), VarName{gensym("vn")}()]
)
# Duplicate variables.
ranges_duplicated = DynamicPPL.vector_getranges(varinfo, repeat(vns, 2))
@test x[reduce(vcat, ranges_duplicated)] == repeat(x, 2)
end
end
end
end
Loading