Jax and Flax Time Series Prediction Transformer. The goal of this repo is to learn JAX and Flax by implementing a deep learning model: in this case, a transformer for time series prediction. I chose to predict cryptocurrency prices due to the availability of the data, but the same approach could be applied to any other time series, such as stock prices or energy demand.
You can check out the project documentation for more information about the project and also for results and conclusions.
I highly recommend cloning and running the repo on a Linux machine, because the installation of JAX and related libraries is easier there, but with a bit of effort you may get it running on another compatible platform. To start, just clone the repo:
git clone https://github.com/rsanchezmo/jaxer.git
cd jaxer
Create a python venv and source it:
python3 -m venv venv
source venv/bin/activate
Install your desired JAX version (more info at https://jax.readthedocs.io/en/latest/installation.html). Note that JAX only provides two CUDA distributions (12.3 and 11.8). As this repo also depends on torch, you could install torch for CPU and JAX for GPU, since torch is only used for the dataloaders. However, I decided to install both torch and JAX with CUDA 11.8. For example, if CUDA 11.8 is already installed on Linux (make sure your CUDA installation is exported to PATH):
(venv) $ pip install --upgrade jax[cuda11_local] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
(venv) $ pip install torch --index-url https://download.pytorch.org/whl/cu118
If CUDA is not installed system-wide, you may use the pip-based installation and let pip pull the right CUDA libraries for your GPU. Make sure that running nvcc --version reports no local CUDA toolkit (if a release is reported, remove the CUDA exports from your .bashrc):
(venv) $ pip install --upgrade jax[cuda11_pip] -f https://storage.googleapis.com/jax-releases/jax_releases.html
(venv) $ pip install torch --index-url https://download.pytorch.org/whl/cu118
Then install the rest of the dependencies (listed in the requirements.txt file):
(venv) $ pip install .
You should first unzip the dataset in the dataset folder:
unzip data.zip
Training the model is made really easy: simply create a trainer with the experiment configuration and call the train_and_evaluate method:
import jaxer
from training_config import config

if __name__ == '__main__':
    trainer = jaxer.run.FlaxTrainer(config=config)
    trainer.train_and_evaluate()
An example configuration is provided in the training_config.py file. Here is an example configuration for training the model:
import jaxer
output_mode = 'mean'  # 'mean' or 'distribution' or 'discrete_grid'
seq_len = 100
d_model = 128
model_config = jaxer.config.ModelConfig(
    d_model=d_model,
    num_layers=4,
    head_layers=2,
    n_heads=2,
    dim_feedforward=4 * d_model,  # 4 * d_model
    dropout=0.05,
    max_seq_len=seq_len,
    flatten_encoder_output=False,
    fe_blocks=0,  # feature extractor is incremental, for instance input_shape, 128/2, 128 (d_model)
    use_time2vec=False,
    output_mode=output_mode,  # 'mean' or 'distribution' or 'discrete_grid'
    use_resblocks_in_head=False,
    use_resblocks_in_fe=True,
    average_encoder_output=False,
    norm_encoder_prev=True
)
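The use_time2vec flag toggles a Time2Vec-style temporal embedding. As a rough sketch of the underlying idea only (this is not the repo's actual implementation, and the function name is illustrative): the first component is linear in time to capture trends, while the remaining components are sinusoidal to capture periodicity.

```python
import math
import random

def time2vec(t, weights, biases):
    """Time2Vec-style embedding of a scalar time t:
    component 0 is linear, the rest are sine features."""
    out = [weights[0] * t + biases[0]]  # linear term (trend)
    for w, b in zip(weights[1:], biases[1:]):
        out.append(math.sin(w * t + b))  # periodic terms (seasonality)
    return out

random.seed(0)
k = 8  # embedding size (illustrative)
w = [random.uniform(-1, 1) for _ in range(k)]
b = [random.uniform(-1, 1) for _ in range(k)]
emb = time2vec(3.0, w, b)
print(len(emb))  # 8
```

In the real model the weights and biases would be learned parameters rather than random draws.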
dataset_config = jaxer.config.DatasetConfig(
    datapath='./data/datasets/data/',
    output_mode=output_mode,  # 'mean' or 'distribution' or 'discrete_grid'
    discrete_grid_levels=[-9e6, 0.0, 9e6],
    initial_date='2018-01-01',
    norm_mode="window_minmax",
    resolution='30m',
    tickers=['btc_usd', 'eth_usd', 'sol_usd'],
    indicators=None,
    seq_len=seq_len
)
synthetic_dataset_config = jaxer.config.SyntheticDatasetConfig(
    window_size=seq_len,
    output_mode=output_mode,  # 'mean' or 'distribution' or 'discrete_grid'
    normalizer_mode='window_minmax',  # 'window_meanstd' or 'window_minmax' or 'window_mean'
    add_noise=False,
    min_amplitude=0.1,
    max_amplitude=1.0,
    min_frequency=0.5,
    max_frequency=30,
    num_sinusoids=5
)
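Judging by the fields above, the synthetic dataset builds windows from a handful of random sinusoids with amplitudes and frequencies drawn from the configured ranges. A self-contained sketch of that idea (illustrative only, not the repo's actual generator):

```python
import math
import random

def synthetic_window(window_size=100, num_sinusoids=5,
                     min_amplitude=0.1, max_amplitude=1.0,
                     min_frequency=0.5, max_frequency=30.0, seed=0):
    """Sum of random sinusoids sampled on [0, 1), mimicking the
    SyntheticDatasetConfig amplitude/frequency ranges."""
    rng = random.Random(seed)
    # draw (amplitude, frequency, phase) for each sinusoid
    params = [(rng.uniform(min_amplitude, max_amplitude),
               rng.uniform(min_frequency, max_frequency),
               rng.uniform(0.0, 2.0 * math.pi))
              for _ in range(num_sinusoids)]
    ts = [i / window_size for i in range(window_size)]
    return [sum(a * math.sin(2.0 * math.pi * f * t + p) for a, f, p in params)
            for t in ts]

window = synthetic_window()
print(len(window))  # 100 samples per window
```

With add_noise enabled, the real generator would presumably add a random perturbation on top of the clean sum of sinusoids.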
pretrained_folder = "results/exp_synthetic_context"
pretrained_path_subfolder, pretrained_path_ckpt = jaxer.utils.get_best_model(pretrained_folder)
pretrained_model = (pretrained_folder, pretrained_path_subfolder, pretrained_path_ckpt)
config = jaxer.config.ExperimentConfig(
    model_config=model_config,
    pretrained_model=pretrained_model,
    log_dir="results",
    experiment_name="exp_both_pretrained_synthetic",
    num_epochs=1000,
    steps_per_epoch=500,  # for synthetic dataset only
    learning_rate=5e-4,
    lr_mode='cosine',  # 'cosine'
    warmup_epochs=15,
    dataset_mode='both',  # 'real' or 'synthetic' or 'both'
    dataset_config=dataset_config,
    synthetic_dataset_config=synthetic_dataset_config,
    batch_size=256,
    test_split=0.1,
    test_tickers=['btc_usd'],
    seed=0,
    save_weights=True,
    early_stopper=100
)
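The combination of lr_mode='cosine' and warmup_epochs typically means a linear warmup to the base learning rate followed by a cosine decay. A generic sketch of such a schedule, using the values from the config above (the exact shape jaxer uses may differ):

```python
import math

def lr_schedule(epoch, base_lr=5e-4, warmup_epochs=15, total_epochs=1000):
    """Linear warmup to base_lr, then cosine decay towards zero."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs  # linear ramp-up
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay

print(lr_schedule(0))    # small value during warmup
print(lr_schedule(14))   # reaches base_lr at the end of warmup
print(lr_schedule(999))  # close to zero at the final epoch
```

Warmup helps stabilize the first transformer updates, while the cosine tail lets the model settle into a minimum.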
An agent class is provided, so you can load a trained model and use it to predict on any data you want:
import jaxer
from torch.utils.data import DataLoader

if __name__ == '__main__':
    # load the agent with best model weights
    experiment = "exp_1"
    agent = jaxer.run.Agent(experiment=experiment, model_name=jaxer.utils.get_best_model(experiment))

    # create dataloaders
    if agent.config.dataset_mode == 'synthetic':
        dataset = jaxer.utils.SyntheticDataset(config=jaxer.config.SyntheticDatasetConfig.from_dict(agent.config.synthetic_dataset_config))
        test_dataloader = dataset.generator(batch_size=1, seed=100)
        train_dataloader = dataset.generator(batch_size=1, seed=200)
    else:
        dataset = jaxer.utils.Dataset(dataset_config=jaxer.config.DatasetConfig.from_dict(agent.config.dataset_config))
        train_ds, test_ds = dataset.get_train_test_split(test_size=agent.config.test_split,
                                                         test_tickers=agent.config.test_tickers)
        train_dataloader = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=jaxer.utils.jax_collate_fn)
        test_dataloader = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=jaxer.utils.jax_collate_fn)

    infer_test = False
    if infer_test:
        dataloader = test_dataloader
    else:
        dataloader = train_dataloader

    if agent.config.dataset_mode == 'synthetic':
        for i in range(30):
            x, y_true, normalizer, window_info = next(dataloader)
            y_pred = agent(x)
            jaxer.utils.plot_predictions(x=x, y_true=y_true, y_pred=y_pred, normalizer=normalizer,
                                         window_info=window_info, denormalize_values=True)
    else:
        for batch in dataloader:
            x, y_true, normalizer, window_info = batch
            y_pred = agent(x)
            jaxer.utils.plot_predictions(x=x, y_true=y_true, y_pred=y_pred, normalizer=normalizer,
                                         window_info=window_info[0], denormalize_values=True)
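The inference loop passes a normalizer and denormalize_values=True to plot_predictions, and the configs use 'window_minmax' normalization, which suggests each window is scaled by its own min and max before entering the model and scaled back for plotting. A minimal, hypothetical sketch of that kind of per-window normalization (names are illustrative, not the repo's API):

```python
def window_minmax(window):
    """Scale a single window to [0, 1] using its own min/max,
    returning the parameters needed to invert the scaling later."""
    lo, hi = min(window), max(window)
    span = hi - lo if hi != lo else 1.0  # guard against flat windows
    normed = [(v - lo) / span for v in window]
    return normed, (lo, span)

def denormalize(normed, params):
    """Invert window_minmax, recovering the original scale."""
    lo, span = params
    return [v * span + lo for v in normed]

prices = [42000.0, 42500.0, 41800.0, 43000.0]
normed, params = window_minmax(prices)
print(normed)                       # values in [0, 1]
print(denormalize(normed, params))  # back to the original scale
```

Per-window scaling lets the model see inputs in a fixed range regardless of the asset's absolute price level, which matters when mixing tickers like btc_usd and sol_usd.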
Rodrigo Sánchez Molina
- Email: [email protected]
- Linkedin: rsanchezm98
- Github: rsanchezmo
If you find Jaxer useful, please consider citing:
@misc{2024rsanchezmo,
  title = {Jaxer: Jax and Flax Time Series Prediction Transformer},
  author = {Rodrigo Sánchez Molina},
  year = {2024},
  howpublished = {https://github.com/rsanchezmo/jaxer}
}