42 changes: 21 additions & 21 deletions README.md
@@ -3,7 +3,6 @@
[![active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
[![ci](https://github.com/ramsey-devs/ramsey/actions/workflows/ci.yaml/badge.svg)](https://github.com/ramsey-devs/ramsey/actions/workflows/ci.yaml)
[![coverage](https://codecov.io/gh/ramsey-devs/ramsey/branch/main/graph/badge.svg?token=dn1xNBSalZ)](https://codecov.io/gh/ramsey-devs/ramsey)
[![quality](https://app.codacy.com/project/badge/Grade/ed13460537fd4ac099c8534b1d9a0202)](https://app.codacy.com/gh/ramsey-devs/ramsey/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade)
[![documentation](https://readthedocs.org/projects/ramsey/badge/?version=latest)](https://ramsey.readthedocs.io/en/latest/?badge=latest)
[![version](https://img.shields.io/pypi/v/ramsey.svg?colorB=black&style=flat)](https://pypi.org/project/ramsey/)

@@ -12,13 +11,10 @@
## About

Ramsey is a library for probabilistic deep learning using [JAX](https://github.com/google/jax),
[Flax](https://github.com/google/flax) and [NumPyro](https://github.com/pyro-ppl/numpyro). Its scope covers

- neural processes (vanilla, attentive, Markovian, convolutional, ...),
- neural Laplace and Fourier operator models,
- flow matching and denoising diffusion models,
- etc.

## Example usage
@@ -29,35 +25,44 @@

You can, for instance, construct a simple neural process like this:

```python
from flax import nnx

from ramsey import NP
from ramsey.nn import MLP  # just a flax.nnx module

def get_neural_process(in_features, out_features):
    dim = 128
    np = NP(
        latent_encoder=(
            MLP(in_features, [dim, dim], rngs=nnx.Rngs(0)),
            MLP(dim, [dim, dim * 2], rngs=nnx.Rngs(1))
        ),
        decoder=MLP(in_features, [dim, dim, out_features * 2], rngs=nnx.Rngs(2))
    )
    return np

neural_process = get_neural_process(1, 1)
```

The neural process above takes a decoder and a set of two latent encoders as arguments. All of these are typically `flax.nnx` MLPs, but
Ramsey is flexible enough that you can change them, for instance, to CNNs or RNNs.

Ramsey provides a unified interface where each method implements (at least) `__call__` and `loss`
functions to transform a set of inputs and compute a training loss, respectively:

```python
from jax import random as jr
from ramsey.data import sample_from_sine_function

data = sample_from_sine_function(jr.key(0))
x_context, y_context = data.x[:, :20, :], data.y[:, :20, :]
x_target, y_target = data.x, data.y

# make a prediction
pred = neural_process(
    x_context=x_context,
    y_context=y_context,
    x_target=x_target,
)

# compute the loss
loss = neural_process.loss(
    x_context=x_context,
    y_context=y_context,
    x_target=x_target,
    y_target=y_target
)
```
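The `loss` function slots directly into a standard JAX training loop. The sketch below illustrates that pattern on a toy linear model rather than Ramsey's actual `NP` class, so it stays self-contained; with Ramsey one would instead differentiate `neural_process.loss` with respect to the model's parameters (typically via `flax.nnx` together with an optimizer library such as `optax`).

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Mean-squared error of a linear model, standing in for neural_process.loss.
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=0.1):
    # One plain gradient-descent update, mirroring an optimizer step on the loss.
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

x = jnp.linspace(-1.0, 1.0, 32)
y = 2.0 * x + 0.5  # toy targets
params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
for _ in range(200):
    params = train_step(params, x, y)
```

The names `loss_fn` and `train_step` are illustrative, not part of Ramsey's API; the point is only that any model exposing a differentiable `loss` can be fitted this way.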


## Installation

To install from PyPI, call:
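The diff view truncates the installation section here. Judging by the PyPI badge at the top of the README, the package is published as `ramsey`, so the command is presumably:

```shell
pip install ramsey
```
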
6 changes: 4 additions & 2 deletions docs/conf.py
@@ -44,8 +44,10 @@

autosummary_generate = True
autodoc_typehints = 'none'
# typehints_fully_qualified = True
# always_document_param_types = True
# autodoc_inherit_docstrings = False
# typehints_document_rtype= False

html_theme = "sphinx_book_theme"
html_theme_options = {
41 changes: 21 additions & 20 deletions docs/index.rst
@@ -7,12 +7,10 @@

Ramsey is a library for probabilistic modelling using `JAX <https://github.com/google/jax>`_ ,
`Flax <https://github.com/google/flax>`_ and `NumPyro <https://github.com/pyro-ppl/numpyro>`_.
Its scope covers

- neural processes (vanilla, attentive, Markovian, convolutional, ...),
- neural Laplace and Fourier operator models,
- flow matching and denoising diffusion models,
- etc.

Example
@@ -25,48 +23,52 @@

You can, for instance, construct a simple neural process like this:

.. code-block:: python
   from flax import nnx

   from ramsey import NP
   from ramsey.nn import MLP  # just a flax.nnx module

   def get_neural_process(in_features, out_features):
       dim = 128
       np = NP(
           latent_encoder=(
               MLP(in_features, [dim, dim], rngs=nnx.Rngs(0)),
               MLP(dim, [dim, dim * 2], rngs=nnx.Rngs(1))
           ),
           decoder=MLP(in_features, [dim, dim, out_features * 2], rngs=nnx.Rngs(2)),
       )
       return np

   neural_process = get_neural_process(1, 1)

The neural process above takes a decoder and a set of two latent encoders as arguments.
All of these are typically ``flax.nnx`` MLPs, but Ramsey is flexible enough that you can
change them, for instance, to CNNs or RNNs.

Ramsey provides a unified interface where each method implements (at least) ``__call__`` and ``loss``
functions to transform a set of inputs and compute a training loss, respectively:

.. code-block:: python

   from jax import random as jr
   from ramsey.data import sample_from_sine_function

   data = sample_from_sine_function(jr.key(0))
   x_context, y_context = data.x[:, :20, :], data.y[:, :20, :]
   x_target, y_target = data.x, data.y

   # make a prediction
   pred = neural_process(
       x_context=x_context,
       y_context=y_context,
       x_target=x_target,
   )

   # compute the loss
   loss = neural_process.loss(
       x_context=x_context,
       y_context=y_context,
       x_target=x_target,
       y_target=y_target
   )


Why Ramsey
----------
@@ -119,7 +121,6 @@

Ramsey is licensed under the Apache 2.0 License.
:hidden:

🏠 Home <self>
📚 References <references>

.. toctree::
16 changes: 0 additions & 16 deletions docs/news.rst

This file was deleted.

114 changes: 59 additions & 55 deletions docs/notebooks/neural_processes.ipynb

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions docs/ramsey.experimental.rst
@@ -25,23 +25,23 @@

Covariance functions
ExponentiatedQuadratic
~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ExponentiatedQuadratic
   :members: __call__

.. autofunction:: exponentiated_quadratic

Linear
~~~~~~

.. autoclass:: Linear
   :members: __call__

.. autofunction:: linear

Periodic
~~~~~~~~~

.. autoclass:: Periodic
   :members: __call__

.. autofunction:: periodic
6 changes: 3 additions & 3 deletions docs/ramsey.rst
@@ -17,13 +17,13 @@

Neural processes
~~~~~~~~~~~~~~~~

.. autoclass:: NP
   :members: __call__, loss

.. autoclass:: ANP
   :members: __call__, loss

.. autoclass:: DANP
   :members: __call__, loss

Train functions
---------------