Add a utility function to sample using `scan` · Issue #80 · aesara-devs/aemcmc · GitHub
Add a utility function to sample using scan #80
Open
@rlouf

Description

Using the sampling steps built by AeMCMC in a scan loop is not straightforward:

from typing import List

import aesara
import aesara.tensor as at
import aemcmc
from aesara.tensor.var import TensorVariable


sample_steps, sample_updates, initial_values = aemcmc.construct_sampler(
    {Y_rv: y_tt}, srng
)

# Placeholder: the list of random variables to sample, to be supplied by the caller
to_sample_rvs: List[TensorVariable]
inputs = [initial_values[rv] for rv in to_sample_rvs]
outputs = [sample_steps[rv] for rv in to_sample_rvs]

def step_fn(*values):
    from aesara.compile.function.pfunc import rebuild_collect_shared

    vv_to_values = {inputs[i]: val for i, val in enumerate(values)}

    _, new_values, [_, new_updates, _, _] = rebuild_collect_shared(
        outputs, inputs=inputs, replace=vv_to_values, updates=sample_updates
    )

    return new_values, new_updates

n_samples = at.iscalar("n_samples")
# Use a new name for the scan result so the per-step `outputs` list is not shadowed
samples, updates = aesara.scan(step_fn, outputs_info=inputs, n_steps=n_samples)

sample_fn = aesara.function(inputs + [n_samples], samples, updates=updates)

This is verbose, but easily generalizable. We should implement a utility function, e.g. aemcmc.sampling_loop, which, given the outputs of construct_sampler and a number of iterations n_samples, returns a graph that generates n_samples samples.
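
A minimal sketch of what such a helper could look like, assuming it simply wraps the pattern shown above. The name sampling_loop comes from the proposal; the argument order, the explicit to_sample_rvs argument, and the return value are assumptions made for illustration, not an existing AeMCMC API.

```python
from typing import Any, Dict, List, Sequence, Tuple


def sampling_loop(
    sample_steps: Dict[Any, Any],
    sample_updates: Dict[Any, Any],
    initial_values: Dict[Any, Any],
    to_sample_rvs: Sequence[Any],
    n_samples: Any,
) -> Tuple[List[Any], Any]:
    """Hypothetical helper: build a scan graph that draws `n_samples` samples.

    The first three arguments are the outputs of `aemcmc.construct_sampler`;
    `to_sample_rvs` lists the random variables to sample (an assumed argument).
    """
    # Imported lazily, as in the original snippet, so that defining the helper
    # does not itself require Aesara to be importable.
    import aesara
    from aesara.compile.function.pfunc import rebuild_collect_shared

    inputs = [initial_values[rv] for rv in to_sample_rvs]
    step_outputs = [sample_steps[rv] for rv in to_sample_rvs]

    def step_fn(*values):
        # Map each initial-value variable to the value carried by scan
        replacements = dict(zip(inputs, values))
        _, new_values, [_, new_updates, _, _] = rebuild_collect_shared(
            step_outputs,
            inputs=inputs,
            replace=replacements,
            updates=sample_updates,
        )
        return new_values, new_updates

    samples, updates = aesara.scan(
        step_fn, outputs_info=inputs, n_steps=n_samples
    )

    # scan returns a bare variable when there is a single output
    if not isinstance(samples, (list, tuple)):
        samples = [samples]

    return list(samples), updates
```

With a helper of this shape, the example above would reduce to calling sampling_loop(sample_steps, sample_updates, initial_values, to_sample_rvs, n_samples) and passing the returned samples and updates to aesara.function.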
