This project serves as an example ML / DL project for tabular data generation. It takes at the basis the approach of CTGAN and TVAE and improves the framework in follwoing context:
- The data transformer is saved and reused for the data, if not requested differently
- Framework rewritten in PyTorch Lightning -> more flexible and easily manageble
- Implemented the Tensorboard pluggin to visualize the training processes
An example dataset will be used for the data generation pipeline. The dataset is taken from one of the most popular kaggle challanges: Housing Prices competition. Download and put it into the folder data.
Create environment as follows:
conda env create -f environment.yml
After the environment is installed you can activate it:
conda activate gen-tab-env
You need to firstly setup the information in conf/generation_config.yaml considering the path tou your data.
Afterwards you can run a training of the variational autoencoder as follows
python -m scripts.train --model-type=[model]
while acceptable [model] is either vae (default) or gan. Or use the created setup for VSCode Debug (see the image below).
See --help for more information
Usage: python -m scripts.train [OPTIONS]
Options:
--model-type TEXT Type of model to train. Either 'vae' (default) or 'gan'.
--resume-from TEXT If you want to resume training from a checkpoint,
provide the path to the checkpoint.
--help Show this message and exit.
To run tesorboard visualizing your algorithms use Tensorboard Extension in VSCode or run
python -m tensorboard.main --logdir [logdir name] --samples_per_plugin "images=100"
If you want to specialize specific blocks of your architecture, including their histomgrams, weights and gradients please edit them in tb_settings.yaml. Add them to parameters_to_visualize.
The training workflow is implemented using PyTorch Lightning. Below is a description of the individual modules:
conf/conf/generation_config.yaml– Describes the model hyperparameters, the data path, and the checkpoint path.conf/tb_settings.yaml– Specifies the settings for logging with TensorBoard.
docs/– Contains important documents, such as the CTGAN License.scripts/scripts/train.py– Script for training the models.scripts/generate.py– Script for data generation. Set the path to the model, as well as the name of generated file inconf/generation_config.yaml.
src/src/[model]/– Implementation of the training procedures and architecture for each model.src/[model]/loop.py– PyTorch Lightning loop. Each loop inherits fromL.LightningModuleto leverage all the functionalities provided by Lightning for training and fromModelInterfacefor TensorBoard logging and checkpoint management.src/[model]/model/– Model-specific blocks and functions (e.g., Encoder, Decoder, Discriminator, Generator, etc.).
src/factory.py– Generalizes model creation using the factory design pattern. For more details, see this article.src/interface.py– Provides a model interface with general logging functions used across all models.src/visualizations.py– Containsmatplotlibvisualizations used in the repository.src/tools.py– Includes various utility functions used throughout the repository.
If you use this code in your research, consider citing the following abstract. This implementation of data generation forms the core of the data generation in this abstract.
@INPROCEEDINGS{10657766,
author={Lavronenko, K. and Paiva, L. Lopes de and Mueller, F. and Schulz, V. and Naunheim, S.},
booktitle={2024 IEEE Nuclear Science Symposium (NSS), Medical Imaging Conference (MIC) and Room Temperature Semiconductor Detector Conference (RTSD)},
title={Towards artificial data generation for accelerated PET detector ML-based timing calibration using GANs},
year={2024},
pages={1-2},
keywords={Temperature measurement;Semiconductor device measurement;Temperature distribution;Correlation;Semiconductor detectors;Supervised learning;Position measurement},
doi={10.1109/NSS/MIC/RTSD57108.2024.10657766}
}



