diff --git a/.gitignore b/.gitignore index 0f84bed..86a3ec9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,7 @@ +coverage/ *.jl.*.cov *.jl.cov *.jl.mem /Manifest.toml + +.DS_Store \ No newline at end of file diff --git a/Project.toml b/Project.toml index efa7a7b..921e065 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StackViews = "cae243ae-269e-4f55-b966-ac2d0dc13c15" Term = "22787eb5-b846-44ae-b979-8e399b8463ab" [compat] diff --git a/README.md b/README.md index f650d21..e72d5dc 100644 --- a/README.md +++ b/README.md @@ -6,22 +6,78 @@ ## Design -A typical example of `Trajectory`: +The relationship of several concepts provided in this package: -![](https://user-images.githubusercontent.com/5612003/167291629-0e2d4f0f-7c54-460c-a94f-9eb4148cdca0.png) +``` +┌───────────────────────────────────┐ +│ Trajectory │ +│ ┌───────────────────────────────┐ │ +│ │ AbstractTraces │ │ +│ │ ┌───────────────┐ │ │ +│ │ :trace_A => │ AbstractTrace │ │ │ +│ │ └───────────────┘ │ │ +│ │ │ │ +│ │ ┌───────────────┐ │ │ +│ │ :trace_B => │ AbstractTrace │ │ │ +│ │ └───────────────┘ │ │ +│ │ ... ... │ │ +│ └───────────────────────────────┘ │ +│ ┌───────────┐ │ +│ │ Sampler │ │ +│ └───────────┘ │ +│ ┌────────────┐ │ +│ │ Controller │ │ +│ └────────────┘ │ +└───────────────────────────────────┘ +``` + +## `Trajectory` + +A `Trajectory` contains 3 parts: -Exported APIs are: +- A `container` to store data. (Usually an `AbstractTraces`) +- A `sampler` to determine how to sample a batch from `container` +- A `controller` to decide when to sample a new batch from the `container` + +Typical usage: ```julia -push!(trajectory; [trace_name=value]...) -append!(trajectory; [trace_name=value]...) 
+julia> t = Trajectory(Traces(a=Int[], b=Bool[]), BatchSampler(3), InsertSampleRatioController(1.0, 3)); + +julia> for i in 1:5 + push!(t, (a=i, b=iseven(i))) + end -for sample in trajectory - # consume samples from the trajectory -end +julia> for batch in t + println(batch) + end +(a = [4, 5, 1], b = Bool[1, 0, 0]) +(a = [3, 2, 4], b = Bool[0, 1, 1]) +(a = [4, 1, 2], b = Bool[1, 0, 1]) ``` -A wide variety of `container`s, `sampler`s, and `controler`s are provided. For the full list, please read the doc. +**Traces** + +- `Traces` +- `MultiplexTraces` +- `CircularArraySARTTraces` +- `Episode` +- `Episodes` + +**Samplers** + +- `BatchSampler` +- `MetaSampler` +- `MultiBatchSampler` + +**Controllers** + +- `InsertSampleRatioController` +- `InsertSampleController` +- `AsyncInsertSampleRatioController` + + +Please refer to the tests for common usage. (TODO: generate docs and add links to the above data structures) ## Acknowledgement diff --git a/src/Trajectories.jl b/src/Trajectories.jl index be19971..9aef96d 100644 --- a/src/Trajectories.jl +++ b/src/Trajectories.jl @@ -1,11 +1,11 @@ module Trajectories +include("patch.jl") + +include("traces.jl") include("samplers.jl") include("controllers.jl") -include("traces.jl") -include("episodes.jl") include("trajectory.jl") -include("rendering.jl") include("common/common.jl") end diff --git a/src/common/CircularArraySARTTraces.jl b/src/common/CircularArraySARTTraces.jl index f140b00..f29093a 100644 --- a/src/common/CircularArraySARTTraces.jl +++ b/src/common/CircularArraySARTTraces.jl @@ -1,16 +1,15 @@ export CircularArraySARTTraces const CircularArraySARTTraces = Traces{ - SART, + SSAART, <:Tuple{ + <:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}}, + <:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}}, <:Trace{<:CircularArrayBuffer}, <:Trace{<:CircularArrayBuffer}, - <:Trace{<:CircularArrayBuffer}, - <:Trace{<:CircularArrayBuffer} } } - function CircularArraySARTTraces(; capacity::Int, state=Int => (), @@ -23,32 +22,10 @@ function 
CircularArraySARTTraces(; reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal + MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) + + MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) + Traces( - state=CircularArrayBuffer{state_eltype}(state_size..., capacity + 1), # !!! state is one step longer - action=CircularArrayBuffer{action_eltype}(action_size..., capacity + 1), # !!! action is one step longer reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity), terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity), ) end - -function Random.rand(s::BatchSampler, t::CircularArraySARTTraces) - inds = rand(s.rng, 1:length(t), s.batch_size) - inds′ = inds .+ 1 - ( - state=t[:state][inds], - action=t[:action][inds], - reward=t[:reward][inds], - terminal=t[:terminal][inds], - next_state=t[:state][inds′], - next_action=t[:state][inds′] - ) |> s.transformer -end - -function Base.push!(t::CircularArraySARTTraces, x::NamedTuple{SA}) - if length(t[:state]) == length(t[:terminal]) + 1 - pop!(t[:state]) - pop!(t[:action]) - end - push!(t[:state], x[:state]) - push!(t[:action], x[:action]) -end diff --git a/src/common/CircularArraySLARTTraces.jl b/src/common/CircularArraySLARTTraces.jl index 83e5d0d..121168b 100644 --- a/src/common/CircularArraySLARTTraces.jl +++ b/src/common/CircularArraySLARTTraces.jl @@ -1,17 +1,16 @@ export CircularArraySLARTTraces const CircularArraySLARTTraces = Traces{ - SLART, + SSLLAART, <:Tuple{ + <:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}}, + <:MultiplexTraces{LL,<:Trace{<:CircularArrayBuffer}}, + <:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}}, <:Trace{<:CircularArrayBuffer}, <:Trace{<:CircularArrayBuffer}, - <:Trace{<:CircularArrayBuffer}, - <:Trace{<:CircularArrayBuffer}, - <:Trace{<:CircularArrayBuffer} } } - function CircularArraySLARTTraces(; capacity::Int, state=Int => (), @@ -26,37 +25,11 @@ function 
CircularArraySLARTTraces(; reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal + MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) + + MultiplexTraces{LL}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) + + MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) + Traces( - state=CircularArrayBuffer{state_eltype}(state_size..., capacity + 1), # !!! state is one step longer - legal_actions_mask=CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1), # !!! legal_actions_mask is one step longer - action=CircularArrayBuffer{action_eltype}(action_size..., capacity + 1), # !!! action is one step longer reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity), terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity), ) -end - -function sample(s::BatchSampler, t::CircularArraySLARTTraces) - inds = rand(s.rng, 1:length(t), s.batch_size) - inds′ = inds .+ 1 - ( - state=t[:state][inds], - legal_actions_mask=t[:legal_actions_mask][inds], - action=t[:action][inds], - reward=t[:reward][inds], - terminal=t[:terminal][inds], - next_state=t[:state][inds′], - next_legal_actions_mask=t[:legal_actions_mask][inds′], - next_action=t[:state][inds′] - ) |> s.transformer -end - -function Base.push!(t::CircularArraySLARTTraces, x::NamedTuple{SLA}) - if length(t[:state]) == length(t[:terminal]) + 1 - pop!(t[:state]) - pop!(t[:legal_actions_mask]) - pop!(t[:action]) - end - push!(t[:state], x[:state]) - push!(t[:legal_actions_mask], x[:legal_actions_mask]) - push!(t[:action], x[:action]) -end +end \ No newline at end of file diff --git a/src/common/common.jl b/src/common/common.jl index 271b149..ef647dc 100644 --- a/src/common/common.jl +++ b/src/common/common.jl @@ -1,12 +1,11 @@ using CircularArrayBuffers -const SA = (:state, :action) -const SLA = (:state, :legal_actions_mask, :action) +const SS = 
(:state, :next_state) +const LL = (:legal_actions_mask, :next_legal_actions_mask) +const AA = (:action, :next_action) const RT = (:reward, :terminal) -const SART = (:state, :action, :reward, :terminal) -const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action) -const SLART = (:state, :legal_actions_mask, :action, :reward, :terminal) -const SLARTSLA = (:state, :legal_actions_mask, :action, :reward, :terminal, :next_state, :next_legal_actions_mask, :next_action) +const SSAART = (SS..., AA..., RT...) +const SSLLAART = (SS..., LL..., AA..., RT...) include("sum_tree.jl") include("CircularArraySARTTraces.jl") diff --git a/src/episodes.jl b/src/episodes.jl deleted file mode 100644 index dcd5712..0000000 --- a/src/episodes.jl +++ /dev/null @@ -1,107 +0,0 @@ -export Episode, Episodes - -using MLUtils: batch - -""" - Episode(traces) - -An `Episode` is a wrapper around [`Traces`](@ref). You can use `(e::Episode)[]` -to check/update whether the episode reaches a terminal or not. -""" -struct Episode{T} - traces::T - is_done::Ref{Bool} -end - -Base.getindex(e::Episode, s::Symbol) = getindex(e.traces, s) -Base.keys(e::Episode) = keys(e.traces) - -Base.getindex(e::Episode) = getindex(e.is_done) -Base.setindex!(e::Episode, x::Bool) = setindex!(e.is_done, x) - -Base.length(e::Episode) = length(e.traces) - -Episode(t::Traces) = Episode(t, Ref(false)) - -function Base.push!(t::Episode, x) - if t.is_done[] - throw(ArgumentError("The episode is already flagged as done!")) - else - push!(t.traces, x) - end -end - -function Base.append!(t::Episode, x) - if t.is_done[] - throw(ArgumentError("The episode is already flagged as done!")) - else - append!(t.traces, x) - end -end - -function Base.pop!(t::Episode) - pop!(t.traces) - t.is_done[] = false -end - -Base.popfirst!(t::Episode) = popfirst!(t.traces) - -function Base.empty!(t::Episode) - empty!(t.traces) - t.is_done[] = false -end - -##### - -""" - Episodes(init) - -A container for multiple [`Episode`](@ref)s. 
`init` is a parameterness function which return an [`Episode`](@ref). -""" -struct Episodes - init::Any - episodes::Vector{Episode} - inds::Vector{Tuple{Int,Int}} -end - -Base.length(e::Episodes) = length(e.inds) - -function Base.push!(e::Episodes, x::Episode) - push!(e.episodes, x) - for i in 1:length(x) - push!(e.inds, (length(e.episodes), i)) - end -end - -function Base.append!(e::Episodes, xs::AbstractVector{<:Episode}) - for x in xs - push!(e, x) - end -end - -function Base.push!(e::Episodes, x) - if isempty(e.episodes) || e.episodes[end][] - episode = e.init() - push!(episode, x) - push!(e.episodes, episode) - else - push!(e.episodes[end], x) - push!(e.inds, (length(e.episodes), length(e.episodes[end]))) - end -end - -function Base.append!(e::Episodes, x) - n_pre = length(e.episodes[end]) - append!(e.episodes[end], x) - n_post = length(e.episodes[end]) - for i in n_pre:n_post - push!(e.inds, (lengthe.episodes, i)) - end -end - -## - -function sample(s::BatchSampler, e::Episodes) - inds = rand(s.rng, 1:length(t), s.batch_size) - batch([@view(s.episodes[e.inds[i][1]][e.inds[i][2]]) for i in inds]) |> s.transformer -end \ No newline at end of file diff --git a/src/patch.jl b/src/patch.jl new file mode 100644 index 0000000..9b08b8f --- /dev/null +++ b/src/patch.jl @@ -0,0 +1,3 @@ +import MLUtils + +MLUtils.batch(x::AbstractArray{<:Number}) = x \ No newline at end of file diff --git a/src/samplers.jl b/src/samplers.jl index dab24b7..d47a0d2 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -1,5 +1,7 @@ export BatchSampler, MetaSampler, MultiBatchSampler +using MLUtils: batch + using Random abstract type AbstractSampler end @@ -17,7 +19,12 @@ Uniformly sample a batch of examples for each trace. See also [`sample`](@ref). 
""" -BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=identity) = BatchSampler(batch_size, rng, identity) +BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=batch) = BatchSampler(batch_size, rng, transformer) + +function sample(s::BatchSampler, t::AbstractTraces) + inds = rand(s.rng, 1:length(t), s.batch_size) + map(s.transformer, t[inds]) +end """ MetaSampler(::NamedTuple) @@ -29,15 +36,13 @@ Used internally for algorithms that sample multiple times per epoch. MetaSampler(policy = BatchSampler(10), critic = BatchSampler(100)) """ -struct MetaSampler{names, T} <: AbstractSampler - samplers::NamedTuple{names, T} +struct MetaSampler{names,T} <: AbstractSampler + samplers::NamedTuple{names,T} end MetaSampler(; kw...) = MetaSampler(NamedTuple(kw)) -function sample(s::MetaSampler, t) - (;[(k, sample(v, t)) for (k,v) in pairs(s.samplers)]...) -end +sample(s::MetaSampler, t) = map(x -> sample(x, t), s.samplers) """ @@ -49,7 +54,7 @@ Wraps a sampler. When sampled, will sample n batches using sampler. 
Useful in co MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3), critic = MultiBatchSampler(BatchSampler(100), 5)) """ -struct MultiBatchSampler{S <: AbstractSampler} <: AbstractSampler +struct MultiBatchSampler{S<:AbstractSampler} <: AbstractSampler sampler::S n::Int end diff --git a/src/traces.jl b/src/traces.jl index 02a96ab..fd9f20a 100644 --- a/src/traces.jl +++ b/src/traces.jl @@ -1,88 +1,321 @@ -export Trace, Traces, sample +export Trace, Traces, MultiplexTraces, Episode, Episodes + +import MacroTools: @forward +import StackViews: StackView + +##### + +abstract type AbstractTrace{E} <: AbstractVector{E} end + +Base.convert(::Type{AbstractTrace}, x::AbstractTrace) = x + +Base.summary(io::IO, t::AbstractTrace) = print(io, "$(length(t))-element $(nameof(typeof(t)))") + +##### +struct Trace{T,E} <: AbstractTrace{E} + parent::T +end + +Base.summary(io::IO, t::Trace{T}) where {T} = print(io, "$(length(t))-element $(nameof(typeof(t))){$T}") + +function Trace(x::T) where {T<:AbstractArray} + E = eltype(x) + N = ndims(x) - 1 + P = typeof(x) + I = Tuple{ntuple(_ -> Base.Slice{Base.OneTo{Int}}, Val(ndims(x) - 1))...,Int} + Trace{T,SubArray{E,N,P,I,true}}(x) +end + +Base.convert(::Type{AbstractTrace}, x::AbstractArray) = Trace(x) + +Base.size(x::Trace) = (size(x.parent, ndims(x.parent)),) +Base.getindex(s::Trace, I) = Base.maybeview(s.parent, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...) +Base.setindex!(s::Trace, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...) + +@forward Trace.parent Base.parent, Base.pushfirst!, Base.push!, Base.append!, Base.prepend!, Base.pop!, Base.popfirst!, Base.empty! + +##### """ - Trace(data) +For each concrete `AbstractTraces`, we have the following assumption: -A wrapper of arbitrary container. Generally we assume the `data` is an -`AbstractVector` like object. When an `AbstractArray` is given, we view it as a -vector of sub-arrays along the last dimension. 
+1. Every inner trace is an `AbstractVector` +1. Support partial updating +1. Return *View* by default when getting elements. """ -struct Trace{T} - x::T +abstract type AbstractTraces{names,T} <: AbstractVector{NamedTuple{names,T}} end + +function Base.show(io::IO, ::MIME"text/plain", t::AbstractTraces{names,T}) where {names,T} + s = nameof(typeof(t)) + println(io, "$s with $(length(names)) entries:") + for n in names + println(io, " :$n => $(summary(t[n]))") + end end -Base.length(t::Trace) = length(t.x) -Base.length(t::Trace{<:AbstractArray}) = size(t.x, ndims(t.x)) +Base.keys(t::AbstractTraces{names}) where {names} = names +Base.haskey(t::AbstractTraces{names}, k::Symbol) where {names} = k in names -Base.lastindex(t::Trace) = length(t) -Base.firstindex(t::Trace) = 1 +##### -Base.convert(::Type{Trace}, x) = Trace(x) +""" + MultiplexTraces{names}(trace) -Base.getindex(t::Trace{<:AbstractVector}, I...) = getindex(t.x, I...) -Base.view(t::Trace{<:AbstractVector}, I...) = view(t.x, I...) +A special [`AbstractTraces`](@ref) which has exactly two traces of the same +length. And those two traces share the header and tail part. -Base.getindex(t::Trace{<:AbstractArray}, I...) = getindex(t.x, ntuple(_ -> :, ndims(t.x) - 1)..., I...) -Base.view(t::Trace{<:AbstractArray}, I...) = view(t.x, ntuple(_ -> :, ndims(t.x) - 1)..., I...) +For example, if a `trace` contains elements between 0 and 9, then the first +`trace_A` is a view of elements from 0 to 8 and the second one is a view from 1 +to 9. -Base.push!(t::Trace, x) = push!(t.x, x) -Base.append!(t::Trace, x) = append!(t.x, x) +``` + ┌─────trace_A───┐ +trace 0 1 2 3 4 5 6 7 8 9 + └────trace_B────┘ +``` -Base.pop!(t::Trace) = pop!(t.x) -Base.popfirst!(t::Trace) = popfirst!(t.x) -Base.empty!(t::Trace) = empty!(t.x) +This is quite common in RL to represent `states` and `next_states`. 
+""" +struct MultiplexTraces{names,T,E} <: AbstractTraces{names,Tuple{E,E}} + trace::T +end -## +function MultiplexTraces{names}(t) where {names} + if length(names) != 2 + throw(ArgumentError("MultiplexTraces has exactly two sub traces, got $(length(names)) trace names")) + end + trace = convert(AbstractTrace, t) + MultiplexTraces{names,typeof(trace),eltype(trace)}(trace) +end -function sample(s::BatchSampler, t::Trace) - inds = rand(s.rng, 1:length(t), s.batch_size) - t[inds] |> s.transformer +function Base.getindex(t::MultiplexTraces{names}, k::Symbol) where {names} + a, b = names + if k == a + convert(AbstractTrace, t.trace[1:end-1]) + elseif k == b + convert(AbstractTrace, t.trace[2:end]) + else + throw(ArgumentError("unknown trace name: $k")) + end +end + +Base.getindex(t::MultiplexTraces{names}, I::Int) where {names} = NamedTuple{names}((t.trace[I], t.trace[I+1])) +Base.getindex(t::MultiplexTraces{names}, I::AbstractArray{Int}) where {names} = NamedTuple{names}((t.trace[I], t.trace[I.+1])) + +Base.size(t::MultiplexTraces) = (max(0, length(t.trace) - 1),) + +@forward MultiplexTraces.trace Base.parent, Base.pop!, Base.popfirst!, Base.empty! + +for f in (:push!, :pushfirst!, :append!, :prepend!) + @eval function Base.$f(t::MultiplexTraces{names}, x::NamedTuple{ks,Tuple{Ts}}) where {names,ks,Ts} + k, v = first(ks), first(x) + if k in names + $f(t.trace, v) + else + throw(ArgumentError("unknown trace name: $k")) + end + end end ##### """ - Traces(;kw...) + Episode(traces) -A container of several named-[`Trace`](@ref)s. Each element in the `kw` will be converted into a `Trace`. +An `Episode` is a wrapper around [`Traces`](@ref). You can use `(e::Episode)[]` +to check/update whether the episode reaches a terminal or not. """ -struct Traces{names,T} - traces::NamedTuple{names,T} - function Traces(; kw...) 
- traces = map(x -> convert(Trace, x), values(kw)) - new{keys(traces),typeof(values(traces))}(traces) +struct Episode{T,names,E} <: AbstractTraces{names,E} + traces::T + is_terminated::Ref{Bool} +end + +Episode(t::AbstractTraces{names,T}) where {names,T} = Episode{typeof(t),names,T}(t, Ref(false)) + +@forward Episode.traces Base.getindex, Base.setindex!, Base.size + +Base.getindex(e::Episode) = getindex(e.is_terminated) +Base.setindex!(e::Episode, x::Bool) = setindex!(e.is_terminated, x) + +for f in (:push!, :append!) + @eval function Base.$f(t::Episode, x) + if t.is_terminated[] + throw(ArgumentError("The episode is already flagged as done!")) + else + $f(t.traces, x) + end end end -Base.keys(t::Traces) = keys(t.traces) -Base.haskey(t::Traces, s::Symbol) = haskey(t.traces, s) -Base.getindex(t::Traces, x) = getindex(t.traces, x) -Base.length(t::Traces) = mapreduce(length, min, t.traces) +function Base.pop!(t::Episode) + pop!(t.traces) + t.is_terminated[] = false +end + +Base.pushfirst!(t::Episode, x) = pushfirst!(t.traces, x) +Base.prepend!(t::Episode, x) = prepend!(t.traces, x) +Base.popfirst!(t::Episode) = popfirst!(t.traces) -Base.push!(t::Traces; kw...) = push!(t, values(kw)) +function Base.empty!(t::Episode) + empty!(t.traces) + t.is_terminated[] = false +end -function Base.push!(t::Traces, x::NamedTuple) - for k in keys(x) - push!(t[k], x[k]) +##### + +""" + Episodes(init) + +A container for multiple [`Episode`](@ref)s. `init` is a parameterless function which returns an empty [`Episode`](@ref). 
+""" +struct Episodes{names,E,T} <: AbstractTraces{names,E} + init::Any + episodes::Vector{T} + inds::Vector{Tuple{Int,Int}} +end + +function Episodes(init) + x = init() + T = typeof(x) + @assert x isa Episode + @assert length(x) == 0 + names, E = eltype(x).parameters + Episodes{names,E,T}(init, [x], Tuple{Int,Int}[]) +end + +Base.size(e::Episodes) = size(e.inds) + +Base.setindex!(e::Episodes, is_terminated::Bool) = setindex!(e.episodes[end], is_terminated) + +Base.getindex(e::Episodes) = getindex(e.episodes[end]) + +function Base.getindex(e::Episodes, I::Int) + i, j = e.inds[I] + e.episodes[i][j] +end + +function Base.getindex(e::Episodes{names}, I) where {names} + NamedTuple{names}( + StackView( + map(I) do i + x, y = e.inds[i] + e.episodes[x][n][y] + end + ) + for n in names + ) +end + +function Base.getindex(e::Episodes, I::Symbol) + @warn "The returned trace is a vector of partitions instead of a continuous view" maxlog = 1 + map(x -> x[I], e.episodes) +end + +function Base.push!(e::Episodes, x::Episode) + # !!! note we do not check whether the last Episode is terminated or not here + push!(e.episodes, x) + for i in 1:length(x) + push!(e.inds, (length(e.episodes), i)) + end +end + +function Base.append!(e::Episodes, xs::AbstractVector{<:Episode}) + # !!! note we do not check whether each Episode is terminated or not here + for x in xs + push!(e, x) + end +end + +function Base.push!(e::Episodes, x::NamedTuple) + if isempty(e.episodes) || e.episodes[end][] + episode = e.init() + push!(episode, x) + push!(e, episode) + else + n_pre = length(e.episodes[end]) + push!(e.episodes[end], x) + n_post = length(e.episodes[end]) + # this is to support partial inserting + if n_post - n_pre == 1 + push!(e.inds, (length(e.episodes), length(e.episodes[end]))) + end end end -Base.append!(t::Traces; kw...) = append!(t, values(kw)) +##### +struct Traces{names,T,N,E} <: AbstractTraces{names,E} + traces::T + inds::NamedTuple{names,NTuple{N,Int}} +end + + +function Traces(; kw...) 
+ data = map(x -> convert(AbstractTrace, x), values(kw)) + names = keys(data) + inds = NamedTuple(k => i for (i, k) in enumerate(names)) + Traces{names,typeof(data),length(names),typeof(values(data))}(data, inds) +end + -function Base.append!(t::Traces, x::NamedTuple) - for k in keys(x) - append!(t[k], x[k]) +function Base.getindex(ts::Traces, s::Symbol) + t = ts.traces[ts.inds[s]] + if t isa AbstractTrace + t + else + t[s] end end -Base.pop!(t::Traces) = map(pop!, t.traces) -Base.popfirst!(t::Traces) = map(popfirst!, t.traces) -Base.empty!(t::Traces) = map(empty!, t.traces) +Base.getindex(t::Traces{names}, i) where {names} = NamedTuple{names}(map(k -> t[k][i], names)) + +function Base.:(+)(t1::AbstractTraces{k1,T1}, t2::AbstractTraces{k2,T2}) where {k1,k2,T1,T2} + ks = (k1..., k2...) + ts = (t1, t2) + inds = (; (k => 1 for k in k1)..., (k => 2 for k in k2)...) + Traces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts, inds) +end + +function Base.:(+)(t1::AbstractTraces{k1,T1}, t2::Traces{k2,T,N,T2}) where {k1,T1,k2,T,N,T2} + ks = (k1..., k2...) + ts = (t1, t2.traces...) + inds = merge(NamedTuple(k => 1 for k in k1), map(v -> v + 1, t2.inds)) + Traces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts, inds) +end + -## -function sample(s::BatchSampler, t::Traces) - inds = rand(s.rng, 1:length(t), s.batch_size) - map(t.traces) do x - x[inds] - end |> s.transformer -end \ No newline at end of file +function Base.:(+)(t1::Traces{k1,T,N,T1}, t2::AbstractTraces{k2,T2}) where {k1,T,N,T1,k2,T2} + ks = (k1..., k2...) + ts = (t1.traces..., t2) + inds = merge(t1.inds, (; (k => length(ts) for k in k2)...)) + Traces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts, inds) +end + +function Base.:(+)(t1::Traces{k1,T1,N1,E1}, t2::Traces{k2,T2,N2,E2}) where {k1,T1,N1,E1,k2,T2,N2,E2} + ks = (k1..., k2...) + ts = (t1.traces..., t2.traces...) 
+ inds = merge(t1.inds, map(x -> x + length(t1.traces), t2.inds)) + Traces{ks,typeof(ts),length(ks),Tuple{E1.types...,E2.types...}}(ts, inds) +end + +Base.size(t::Traces) = (mapreduce(length, min, t.traces),) + +for f in (:push!, :pushfirst!, :append!, :prepend!) + @eval function Base.$f(ts::Traces, xs::NamedTuple) + for (k, v) in pairs(xs) + t = ts.traces[ts.inds[k]] + if t isa AbstractTrace + $f(t, v) + else + $f(t, (; k => v)) + end + end + end +end + +for f in (:pop!, :popfirst!, :empty!) + @eval function Base.$f(ts::Traces) + for t in ts.traces + $f(t) + end + end +end diff --git a/src/trajectory.jl b/src/trajectory.jl index fac1f13..890631a 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -1,7 +1,9 @@ -export Trajectory +export Trajectory, TrajectoryStyle, SyncTrajectoryStyle, AsyncTrajectoryStyle using Base.Threads +struct AsyncTrajectoryStyle end +struct SyncTrajectoryStyle end """ Trajectory(container, sampler, controller) @@ -53,8 +55,15 @@ Base.@kwdef struct Trajectory{C,S,T} end end +TrajectoryStyle(::Trajectory) = SyncTrajectoryStyle() +TrajectoryStyle(::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}) = AsyncTrajectoryStyle() -Base.push!(t::Trajectory; kw...) = push!(t, values(kw)) +Base.bind(::Trajectory, ::Task) = nothing + +function Base.bind(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, task::Task) # `::Task` avoids a dispatch ambiguity with bind(::Trajectory, ::Task) + bind(t.controller.ch_in, task) + bind(t.controller.ch_out, task) +end function Base.push!(t::Trajectory, x) n_pre = length(t.container) @@ -72,8 +81,6 @@ end Base.push!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, args...; kw...) = put!(t.controller.ch_in, CallMsg(Base.push!, args, kw)) Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, args...; kw...) = put!(t.controller.ch_in, CallMsg(Base.append!, args, kw)) -Base.append!(t::Trajectory; kw...) 
= append!(t, values(kw)) - function Base.append!(t::Trajectory, x) n_pre = length(t.container) append!(t.container, x) diff --git a/test/common.jl b/test/common.jl new file mode 100644 index 0000000..444c3fc --- /dev/null +++ b/test/common.jl @@ -0,0 +1,105 @@ +@testset "sum_tree" begin + t = SumTree(8) + + for i in 1:4 + push!(t, i) + end + + @test length(t) == 4 + @test size(t) == (4,) + + for i in 5:16 + push!(t, i) + end + + @test length(t) == 8 + @test size(t) == (8,) + @test t == 9:16 + + t[:] .= 1 + @test t == ones(8) + @test all([get(t, v)[1] == i for (i, v) in enumerate(0.5:1.0:8)]) + + empty!(t) + @test length(t) == 0 +end + +@testset "CircularArraySARTTraces" begin + t = CircularArraySARTTraces(; + capacity=3, + state=Float32 => (2, 3), + action=Float32 => (2,), + reward=Float32 => (), + terminal=Bool => () + ) + + @test t isa CircularArraySARTTraces + + push!(t, (state=ones(Float32, 2, 3), action=ones(Float32, 2))) + @test length(t) == 0 + + push!(t, (reward=1.0f0, terminal=false)) + @test length(t) == 0 # next_state and next_action is still missing + + push!(t, (next_state=ones(Float32, 2, 3) * 2, next_action=ones(Float32, 2) * 2)) + @test length(t) == 1 + + @test t[1] == ( + state=ones(Float32, 2, 3), + next_state=ones(Float32, 2, 3) * 2, + action=ones(Float32, 2), + next_action=ones(Float32, 2) * 2, + reward=1.0f0, + terminal=false, + ) + + push!(t, (reward=2.0f0, terminal=false)) + push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 3)) + + @test length(t) == 2 + + push!(t, (reward=3.0f0, terminal=false)) + push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 4)) + + @test length(t) == 3 + + push!(t, (reward=4.0f0, terminal=false)) + push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 5)) + + @test length(t) == 3 + @test t[1] == ( + state=ones(Float32, 2, 3) * 2, + next_state=ones(Float32, 2, 3) * 3, + action=ones(Float32, 2) * 2, + next_action=ones(Float32, 2) * 3, + reward=2.0f0, + terminal=false, + ) + 
@test t[end] == ( + state=ones(Float32, 2, 3) * 4, + next_state=ones(Float32, 2, 3) * 5, + action=ones(Float32, 2) * 4, + next_action=ones(Float32, 2) * 5, + reward=4.0f0, + terminal=false, + ) + + batch = t[1:3] + @test size(batch.state) == (2, 3, 3) + @test size(batch.action) == (2, 3) + @test batch.reward == [2.0, 3.0, 4.0] + @test batch.terminal == Bool[0, 0, 0] +end + +@testset "CircularArraySLARTTraces" begin + t = CircularArraySLARTTraces(; + capacity=3, + state=Float32 => (2, 3), + legal_actions_mask=Bool => (5,), + action=Int => (), + reward=Float32 => (), + terminal=Bool => () + ) + + @test t isa CircularArraySLARTTraces +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index b587989..ec304ca 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,11 @@ using Trajectories +using CircularArrayBuffers using Test @testset "Trajectories.jl" begin include("traces.jl") + include("common.jl") + include("samplers.jl") include("trajectories.jl") include("samplers.jl") end diff --git a/test/samplers.jl b/test/samplers.jl index 4243373..5edb7f2 100644 --- a/test/samplers.jl +++ b/test/samplers.jl @@ -1,4 +1,28 @@ -using Trajectories, Test +@testset "BatchSampler" begin + sz = 32 + s = BatchSampler(sz) + t = Traces( + state=rand(3, 4, 5), + action=rand(1:4, 5), + ) + + b = Trajectories.sample(s, t) + + @test keys(b) == (:state, :action) + @test size(b.state) == (3, 4, sz) + @test size(b.action) == (sz,) + + e = Episodes() do + Episode(Traces(state=rand(2, 3, 0), action=rand(0))) + end + + push!(e, Episode(Traces(state=rand(2, 3, 2), action=rand(2)))) + push!(e, Episode(Traces(state=rand(2, 3, 3), action=rand(3)))) + + @test length(e) == 5 + @test size(e[2:4].state) == (2, 3, 3) + @test_broken size(e[2:4].action) == (3,) +end @testset "MetaSampler" begin t = Trajectory( @@ -6,11 +30,11 @@ using Trajectories, Test a=Int[], b=Bool[] ), - sampler = MetaSampler(policy = BatchSampler(3), critic = BatchSampler(5)), - controller = 
InsertSampleController(10, 0) + sampler=MetaSampler(policy=BatchSampler(3), critic=BatchSampler(5)), + controller=InsertSampleController(10, 0) ) - append!(t; a=[1, 2, 3, 4], b=[false, true, false, true]) + append!(t, (a=[1, 2, 3, 4], b=[false, true, false, true])) batches = [] @@ -19,7 +43,7 @@ using Trajectories, Test end @test length(batches) == 10 - @test length(batches[1][:policy][:a]) == 3 && length(batches[1][:critic][:b]) == 5 + @test length(batches[1][:policy][:a]) == 3 && length(batches[1][:critic][:b]) == 5 end @testset "MultiBatchSampler" begin @@ -28,11 +52,11 @@ end a=Int[], b=Bool[] ), - sampler = MetaSampler(policy = BatchSampler(3), critic = MultiBatchSampler(BatchSampler(5), 2)), - controller = InsertSampleController(10, 0) + sampler=MetaSampler(policy=BatchSampler(3), critic=MultiBatchSampler(BatchSampler(5), 2)), + controller=InsertSampleController(10, 0) ) - append!(t; a=[1, 2, 3, 4], b=[false, true, false, true]) + append!(t, (a=[1, 2, 3, 4], b=[false, true, false, true])) batches = [] @@ -41,7 +65,7 @@ end end @test length(batches) == 10 - @test length(batches[1][:policy][:a]) == 3 + @test length(batches[1][:policy][:a]) == 3 @test length(batches[1][:critic]) == 2 # we sampled 2 batches for critic @test length(batches[1][:critic][1][:b]) == 5 #each batch is 5 samples end @@ -60,7 +84,7 @@ end n = 100 insert_task = @async for i in 1:n - append!(t; a=[i, i, i, i], b=[false, true, false, true]) + append!(t, (a=[i, i, i, i], b=[false, true, false, true])) end s = 0 diff --git a/test/traces.jl b/test/traces.jl index cfcd72e..82380b4 100644 --- a/test/traces.jl +++ b/test/traces.jl @@ -1,57 +1,183 @@ -@testset "Trace 1d" begin - t = Trace([]) +@testset "Traces" begin + t = Traces(; + a=[1, 2], + b=Bool[0, 1] + ) + + @test length(t) == 2 + + push!(t, (; a=3, b=true)) + + @test t[:a][end] == 3 + @test t[:b][end] == true + + append!(t, (a=[4, 5], b=[false, false])) + @test length(t[:a]) == 5 + @test t[:b][end-1:end] == [false, false] + + @test t[1] == 
(a=1, b=false) + + t_12 = t[1:2] + @test t_12.a == [1, 2] + @test t_12.b == [false, true] + + t_12_view = t[1:2] + t_12_view.a[1] = 0 + @test t[:a][1] == 0 + + pop!(t) + @test length(t) == 4 + + popfirst!(t) + @test length(t) == 3 + + empty!(t) + @test length(t) == 0 +end + +@testset "MultiplexTraces" begin + t = MultiplexTraces{(:state, :next_state)}(Int[]) + + @test length(t) == 0 + + push!(t, (; state=1)) + push!(t, (; next_state=2)) + + @test t[:state] == [1] + @test t[:next_state] == [2] + @test t[1] == (state=1, next_state=2) + + append!(t, (; state=[3, 4])) + + @test t[:state] == [1, 2, 3] + @test t[:next_state] == [2, 3, 4] + @test t[end] == (state=3, next_state=4) + + pop!(t) + t[end] == (state=2, next_state=3) + empty!(t) + @test length(t) == 0 +end + +@testset "MergedTraces" begin + t1 = Traces(a=Int[]) + t2 = Traces(b=Bool[]) + + t3 = t1 + t2 + @test t3[:a] === t1[:a] + @test t3[:b] === t2[:b] + + push!(t3, (; a=1, b=false)) + @test length(t3) == 1 + @test t3[1] == (a=1, b=false) + + append!(t3, (; a=[2, 3], b=[false, true])) + @test length(t3) == 3 + + @test t3[:a][1:3] == [1, 2, 3] + + t3_view = t3[1:3] + t3_view[:a][1] = 0 + @test t3[:a][1] == 0 + + pop!(t3) + @test length(t3) == 2 + + empty!(t3) + @test length(t3) == 0 + + t4 = MultiplexTraces{(:m, :n)}(Float64[]) + t5 = t4 + t2 + t1 + + push!(t5, (m=1.0, n=1.0, a=1, b=1)) + @test length(t5) == 1 + + push!(t5, (m=2.0, a=2, b=0)) + + @test t5[end] == (m=1.0, n=2.0, b=false, a=2) + + t6 = Traces(aa=Int[]) + t7 = Traces(bb=Bool[]) + t8 = (t1 + t2) + (t6 + t7) + + empty!(t8) + push!(t8, (a=1, b=false, aa=1, bb=false)) + append!(t8, (a=[2, 3], b=[true, true], aa=[2, 3], bb=[true, true])) + + @test length(t8) == 3 + + t8_view = t8[2:3] + t8_view.a[1] = 0 + @test t8[:a][2] == 0 +end + +@testset "Episode" begin + t = Episode( + Traces( + state=Int[], + action=Float64[] + ) + ) + @test length(t) == 0 - push!(t, 1) + push!(t, (state=1, action=1.0)) @test length(t) == 1 - @test t[1] == 1 - append!(t, [2, 3]) + 
append!(t, (state=[2, 3], action=[2.0, 3.0])) @test length(t) == 3 - @test @view(t[2:3]) == [2, 3] + + @test t[:state] == [1, 2, 3] + @test t[end-1:end] == ( + state=[2, 3], + action=[2.0, 3.0] + ) + + t[] = true # seal + @test_throws ArgumentError push!(t, (state=4, action=4.0)) pop!(t) @test length(t) == 2 - s = BatchSampler(2) - @test size(sample(s, t)) == (2,) + push!(t, (state=4, action=4.0)) + @test length(t) == 3 + t[] = true # seal empty!(t) - @test length(t) == 0 + @test length(t) == 0 end -@testset "Trace 2d" begin - t = Trace([ - 1 2 3 - 4 5 6 - ]) - @test length(t) == 3 - @test t[1] == [1, 4] - @test @view(t[2:3]) == [2 3; 5 6] +@testset "Episodes" begin + t = Episodes() do + Episode(Traces(state=Float64[], action=Int[])) + end - s = BatchSampler(5) - @test size(sample(s, t)) == (2, 5) -end + @test length(t) == 0 -@testset "Traces" begin - t = Traces(; - a=[1, 2], - b=Bool[0, 1] - ) + push!(t, (state=1.0, action=1)) + + @test length(t) == 1 + @test t[1] == (state=1.0, action=1) - @test keys(t) == (:a, :b) - @test haskey(t, :a) - @test t[:a] isa Trace + t[] = true # seal - push!(t; a=3, b=true) - @test t[:a][end] == 3 - @test t[:b][end] == true + push!(t, (state=2.0, action=2)) + @test length(t) == 2 - append!(t; a=[4, 5], b=[false, false]) - @test length(t[:a]) == 5 - @test t[:b][end-1:end] == [false, false] + @test t[end] == (state=2.0, action=2) + + # https://github.com/JuliaArrays/StackViews.jl/issues/3 + @test_broken t[1:2] == (state=[1.0, 2.0], action=[1, 2]) + + push!(t, (state=3.0, action=3)) + t[] = true # seal + + @test_broken size(t[:state]) == (3,) + + push!(t, Episode(Traces(state=[4.0, 5.0, 6.0], action=[4, 5, 6]))) + @test t[] == false - s = BatchSampler(5) - @test size(sample(s, t)[:a]) == (5,) + t[] = true + @test length(t) == 6 end \ No newline at end of file diff --git a/test/trajectories.jl b/test/trajectories.jl index b7a104c..9bbb82e 100644 --- a/test/trajectories.jl +++ b/test/trajectories.jl @@ -16,7 +16,7 @@ @test length(batches) 
== 0 # threshold not reached yet - append!(t; a=[1, 2, 3], b=[false, true, false]) + append!(t, (a=[1, 2, 3], b=[false, true, false])) for batch in t push!(batches, batch) @@ -24,7 +24,7 @@ @test length(batches) == 0 # threshold not reached yet - push!(t; a=4, b=true) + push!(t, (a=4, b=true)) for batch in t push!(batches, batch) @@ -32,7 +32,7 @@ @test length(batches) == 1 # 4 inserted, threshold is 4, ratio is 0.25 - append!(t; a=[5, 6, 7], b=[true, true, true]) + append!(t, (a=[5, 6, 7], b=[true, true, true])) for batch in t push!(batches, batch) @@ -40,7 +40,7 @@ @test length(batches) == 1 # 7 inserted, threshold is 4, ratio is 0.25 - push!(t; a=8, b=true) + push!(t, (a=8, b=true)) for batch in t push!(batches, batch) @@ -50,7 +50,7 @@ n = 100 for i in 1:n - append!(t; a=[i, i, i, i], b=[false, true, false, true]) + append!(t, (a=[i, i, i, i], b=[false, true, false, true])) end s = 0 @@ -74,7 +74,7 @@ end n = 100 insert_task = @async for i in 1:n - append!(t; a=[i, i, i, i], b=[false, true, false, true]) + append!(t, (a=[i, i, i, i], b=[false, true, false, true])) end s = 0