
Reducing redundancy for primitive functions #250

@cscherrer

"Primitive" here is a term I've been using for functions that use GeneralizedGenerated.jl to generate a function based on a Model and usually some other values. For each of these, there's a source____ function that builds the AST, for example sourceLogdensity and sourceRand.

For example, logdensity is built from

```julia
function sourceLogdensity()
    function(_m::Model)
        proc(_m, st :: Assign)     = :($(st.x) = $(st.rhs))
        proc(_m, st :: Return)     = nothing
        proc(_m, st :: LineNumber) = nothing
        function proc(_m, st :: Sample)
            x = st.x
            rhs = st.rhs
            @q begin
                _ℓ += logdensity($rhs, $x)
                $x = Soss.predict($rhs, $x)
            end
        end

        wrap(kernel) = @q begin
            _ℓ = 0.0
            $kernel
            return _ℓ
        end

        buildSource(_m, proc, wrap) |> MacroTools.flatten
    end
end
```

and rand is built from

```julia
function sourceRand()
    function(_m::Model)
        proc(_m, st::Assign)  = :($(st.x) = $(st.rhs))
        proc(_m, st::Sample)  = :($(st.x) = rand(_rng, $(st.rhs)))
        proc(_m, st::Return)  = :(return $(st.rhs))
        proc(_m, st::LineNumber) = nothing

        vals = map(x -> Expr(:(=), x, x), parameters(_m))

        wrap(kernel) = @q begin
            _rng -> begin
                $kernel
                $(Expr(:tuple, vals...))
            end
        end

        buildSource(_m, proc, wrap) |> MacroTools.flatten
    end
end
```

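There's clearly a lot of commonality between these: each one defines a proc method for every statement type (Assign, Sample, Return, LineNumber), plus a wrap function that adds surrounding context, and then calls buildSource(_m, proc, wrap) |> MacroTools.flatten. The many calls to @gg are similarly repetitive: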

```
chad@albatross ~/g/Soss.jl (dev)> rg @gg
src/importance.jl
167:@gg M function _importanceSample(_::Type{M}, p::Model, _pargs, q::Model, _qargs, _data) where M <: TypeLevel{Module}

src/simulate.jl
124:@gg M function _simulate(_::Type{M}, _m::Model, _args, trace_assignments::Val{V}) where {V, M <: TypeLevel{Module}}
131:@gg M function _simulate(_::Type{M}, _m::Model, _args::NamedTuple{()}, trace_assignments::Val{V}) where {V, M <: TypeLevel{Module}}

src/particles.jl
150:@gg M function _particles(_::Type{M}, _m::Model, _args, _n::Val{_N}) where {M <: TypeLevel{Module},_N}
156:@gg M function _particles(_::Type{M}, _m::Model, _args::NamedTuple{()}, _n::Val{_N}) where {M <: TypeLevel{Module},_N}

src/primitives/likelihood-weighting.jl
38:@gg M function _weightedSample(_::Type{M}, _m::Model, _args, _data) where M <: TypeLevel{Module}

src/primitives/rand.jl
59:@gg M function _rand(_::Type{M}, _m::Model, _args) where M <: TypeLevel{Module}
65:@gg M function _rand(_::Type{M}, _m::Model, _args::NamedTuple{()}) where M <: TypeLevel{Module}

src/primitives/logdensity.jl
43:@gg M function _logdensity(_::Type{M}, _m::Model, _args, _data, _pars) where M <: TypeLevel{Module}

src/primitives/xform.jl
148:@gg M function _xform(_::Type{M}, _m::Model{Asub,B}, _args::A, _data) where {M <: TypeLevel{Module}, Asub, A,B}

src/primitives/entropy.jl
55:@gg M function _entropy(_::Type{M}, _m::Model, _args, _n::Val{_N}) where {M <: TypeLevel{Module},_N}
61:@gg M function _entropy(_::Type{M}, _m::Model, _args::NamedTuple{()}, _n::Val{_N}) where {M <: TypeLevel{Module},_N}

src/symbolic/symbolic.jl
143:@gg M function _symlogdensity(_::Type{M}, _m::Model, ::Type{T}) where {T, M <: TypeLevel{Module}}

src/primitives/basemeasure.jl
40:@gg M function _basemeasure(_::Type{M}, _m::Model, _args, _data, _pars) where M <: TypeLevel{Module}
chad@albatross ~/g/Soss.jl (dev)>
```

This makes me wonder: can we put all of this under a common higher-order function? Maybe something like

```julia
@gg M function makeprimitive(::Type{M}, _m::Model, f, post, args...)
```

where f takes the place of proc (since that name isn't very descriptive anyway), and args... can hold whatever other arguments are passed. post is a function Expr -> Expr, playing a role much like wrap above; in many cases it might just add some surrounding context.
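To make the f/post split concrete, here is a minimal plain-Julia sketch, without @gg, GeneralizedGenerated.jl, or a real Soss Model. The statement types are hypothetical stand-ins for Soss's Assign, Sample and Return, the _m argument is dropped, and the logdensity version leaves out the Soss.predict step. sourceprimitive and the *_f / *_post functions are illustrative names only, not Soss API.

```julia
# Hypothetical stand-ins for Soss's statement types (Soss defines its own;
# these exist only so the sketch runs on its own).
abstract type Statement end
struct Assign <: Statement
    x::Symbol
    rhs::Any
end
struct Sample <: Statement
    x::Symbol
    rhs::Any
end
struct Return <: Statement
    rhs::Any
end

# Hypothetical higher-order builder: map `f` over the statements, drop the
# ones that produce `nothing`, and let `post` (Expr -> Expr) add context.
function sourceprimitive(f, post)
    return function (stmts::Vector{<:Statement})
        kernel = Expr(:block, filter(!isnothing, map(f, stmts))...)
        return post(kernel)
    end
end

# logdensity-style primitive (simplified: no Soss.predict step)
logdensity_f(st::Assign) = :($(st.x) = $(st.rhs))
logdensity_f(st::Sample) = :(_ℓ += logdensity($(st.rhs), $(st.x)))
logdensity_f(st::Return) = nothing
logdensity_post(kernel) = quote
    _ℓ = 0.0
    $kernel
    return _ℓ
end

# rand-style primitive: draw each sampled variable using `_rng`
rand_f(st::Assign) = :($(st.x) = $(st.rhs))
rand_f(st::Sample) = :($(st.x) = rand(_rng, $(st.rhs)))
rand_f(st::Return) = :(return $(st.rhs))
rand_post(kernel) = :(_rng -> $kernel)

logdensity_source = sourceprimitive(logdensity_f, logdensity_post)
rand_source = sourceprimitive(rand_f, rand_post)

# Usage: generate (but don't evaluate) the source for a two-statement model.
stmts = Statement[Assign(:μ, 0.0), Sample(:x, :(Normal(μ, 1)))]
println(logdensity_source(stmts))
```

In this toy version each primitive reduces to a handful of per-statement methods plus one Expr -> Expr function; the hard parts the issue raises (arguments, and what is known when) are not addressed here.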

Some challenges:

  • The way args is used varies a lot across functions: compare _logdensity(_, _m, _args, _data, _pars) with _particles(_, _m, _args, _n::Val{_N}).
  • In the past I've found it very tricky to manage exactly what is known when. In some cases we need to know values at AST generation time; in other cases, only types.
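The "what is known when" problem can be illustrated with Base Julia's @generated (this sketch does not use GeneralizedGenerated.jl or its API). Inside the generator every argument is bound to its type, so a value is only available if a type carries it: singleton function types give back f and post via .instance, and a Val{K} argument (like _n::Val{_N} in _particles) exposes K, while an ordinary Int or String contributes only its type. makeprimitive here is a hypothetical toy, and accum, valparam and init_and_return are illustrative names. One caveat from the Julia manual: a generator may only call functions defined before the generated function itself, which is why f and post are defined first.

```julia
# Hypothetical helper: extract K from the *type* Val{K}.
valparam(::Type{Val{K}}) where {K} = K

# Hypothetical per-argument processor ("f"). At generation time it gets the
# argument's type T, never its runtime value.
function accum(i, T)
    if T <: Val
        K = valparam(T)                      # value lives in the type: inline it
        return :(_acc += $K)
    elseif T <: AbstractString
        return :(_acc += length(args[$i]))   # value read at run time
    else
        return :(_acc += args[$i])           # value read at run time
    end
end

# Hypothetical "post": Expr -> Expr, adds setup and a return.
init_and_return(kernel) = quote
    _acc = 0
    $kernel
    return _acc
end

# Hypothetical makeprimitive: inside the generator, f, post and args are all
# types; f and post are singleton function types, so .instance recovers them.
@generated function makeprimitive(f, post, args...)
    fn = f.instance
    wrap = post.instance
    kernel = Expr(:block, (fn(i, args[i]) for i in 1:length(args))...)
    return wrap(kernel)
end

println(makeprimitive(accum, init_and_return, 10, "abc", Val(2)))  # prints 15
```

Here the 2 inside Val(2) is baked into the generated code as a literal, while 10 and "abc" are only read when the generated method runs.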

If building new primitives becomes easier, that will encourage people to use this functionality. I think there's great potential here. Things do get tricky at this degree of abstraction, so we need to be sure we can completely represent what we already have without losing performance.
