Skip to content

Commit f20772b

Browse files
authored
Release 0.3.6 (#76)
1 parent e97c5c8 commit f20772b

File tree

21 files changed

+1701
-864
lines changed

21 files changed

+1701
-864
lines changed

.github/workflows/pypi-publish.yml

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
on:
2-
push:
3-
branches: [ "main" ]
4-
paths-ignore:
5-
- ".github/workflows/*"
6-
- ".devcontainer/*"
7-
- ".gitignore"
8-
- ".pre-commit-config.yaml"
2+
release:
3+
types: [released]
94
jobs:
105
pypi-publish:
116
name: Upload release to PyPI

poetry.lock

Lines changed: 678 additions & 757 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "causica"
3-
version = "0.3.5"
3+
version = "0.3.6"
44
description = ""
55
readme = "README.md"
66
authors = []

src/causica/data_generation/samplers/noise_dist_sampler.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from causica.distributions import JointNoiseModule
99
from causica.distributions.noise import NoiseModule, UnivariateNormalNoiseModule
1010
from causica.distributions.noise.bernoulli import BernoulliNoiseModule
11+
from causica.distributions.noise.categorical import CategoricalNoiseModule
1112

1213

1314
class NoiseModuleSampler(Sampler[NoiseModule]):
@@ -64,5 +65,21 @@ def __init__(self, base_logits_dist: td.Distribution, dim: int = 1):
6465
def sample(
6566
self,
6667
) -> NoiseModule:
67-
base_logits = self.base_logits_dist.sample().item()
68+
base_logits = self.base_logits_dist.sample()
6869
return BernoulliNoiseModule(dim=self.dim, init_base_logits=base_logits)
70+
71+
72+
class CategoricalNoiseModuleSampler(NoiseModuleSampler):
73+
"""Sample a CategoricalNoiseModule, with num_classes classes. This does not actually sample but returns the noise."""
74+
75+
def __init__(self, base_logits_dist: td.Distribution | None, num_classes: int = 2):
76+
super().__init__()
77+
assert num_classes >= 2
78+
self.num_classes = num_classes
79+
self.base_logits_dist = base_logits_dist
80+
81+
def sample(
82+
self,
83+
) -> NoiseModule:
84+
init_base_logits = self.base_logits_dist.sample() if self.base_logits_dist else None
85+
return CategoricalNoiseModule(num_classes=self.num_classes, init_base_logits=init_base_logits)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from causica.datasets.causica_dataset_format.load import (
2+
CAUSICA_DATASETS_PATH,
3+
CounterfactualWithEffects,
4+
DataEnum,
5+
InterventionWithEffects,
6+
Variable,
7+
VariablesMetadata,
8+
get_group_idxs,
9+
get_group_names,
10+
get_group_variable_names,
11+
get_name_to_idx,
12+
load_data,
13+
tensordict_from_variables_metadata,
14+
tensordict_to_tensor,
15+
)
16+
from causica.datasets.causica_dataset_format.save import save_data, save_dataset

0 commit comments

Comments
 (0)