Skip to content

Commit 9fcb726

Browse files
authored
Release 0.3.3 (#49)
1 parent 8e9cc2d commit 9fcb726

38 files changed

+1191
-1050
lines changed

examples/__init__.py

Whitespace-only changes.

examples/csuite_example.ipynb

Lines changed: 229 additions & 229 deletions
Large diffs are not rendered by default.

examples/multi_investment_sales_attribution.ipynb

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{
22
"cells": [
33
{
4+
"attachments": {},
45
"cell_type": "markdown",
56
"metadata": {},
67
"source": [
@@ -20,6 +21,7 @@
2021
]
2122
},
2223
{
24+
"attachments": {},
2325
"cell_type": "markdown",
2426
"metadata": {},
2527
"source": [
@@ -55,21 +57,19 @@
5557
"from pytorch_lightning.callbacks import TQDMProgressBar\n",
5658
"from tensordict import TensorDict\n",
5759
"\n",
58-
"from causica.distributions import (\n",
59-
" ContinuousNoiseDist,\n",
60-
" SEMDistributionModule,\n",
61-
")\n",
60+
"from causica.distributions import ContinuousNoiseDist\n",
6261
"from causica.lightning.data_modules.basic_data_module import BasicDECIDataModule\n",
6362
"from causica.lightning.modules.deci_module import DECIModule\n",
64-
"from causica.sem.distribution_parameters_sem import DistributionParametersSEM\n",
63+
"from causica.sem.sem_distribution import SEMDistributionModule\n",
6564
"from causica.sem.structural_equation_model import ite\n",
6665
"from causica.training.auglag import AugLagLRConfig\n",
6766
"\n",
6867
"warnings.filterwarnings(\"ignore\")\n",
69-
"%matplotlib inline"
68+
"test_run = bool(os.environ.get(\"TEST_RUN\", False)) # used by testing to run the notebook as a script"
7069
]
7170
},
7271
{
72+
"attachments": {},
7373
"cell_type": "markdown",
7474
"metadata": {},
7575
"source": [
@@ -110,6 +110,7 @@
110110
]
111111
},
112112
{
113+
"attachments": {},
113114
"cell_type": "markdown",
114115
"metadata": {},
115116
"source": [
@@ -384,6 +385,7 @@
384385
]
385386
},
386387
{
388+
"attachments": {},
387389
"cell_type": "markdown",
388390
"metadata": {},
389391
"source": [
@@ -439,21 +441,27 @@
439441
"fig, axis = plt.subplots(1, 1, figsize=(8, 8))\n",
440442
"labels = {node: i for i, node in enumerate(true_adj.nodes)}\n",
441443
"\n",
442-
"layout = nx.nx_agraph.graphviz_layout(true_adj, prog=\"dot\")\n",
444+
"try:\n",
445+
" layout = nx.nx_agraph.graphviz_layout(true_adj, prog=\"dot\")\n",
446+
"except (ModuleNotFoundError, ImportError):\n",
447+
" layout = nx.layout.spring_layout(true_adj)\n",
448+
"\n",
443449
"for node, i in labels.items():\n",
444450
" axis.scatter(layout[node][0], layout[node][1], label=f\"{i}: {node}\")\n",
445451
"axis.legend()\n",
446452
"nx.draw_networkx(true_adj, pos=layout, with_labels=True, arrows=True, labels=labels, ax=axis)"
447453
]
448454
},
449455
{
456+
"attachments": {},
450457
"cell_type": "markdown",
451458
"metadata": {},
452459
"source": [
453460
"# Discover the Causal Graph"
454461
]
455462
},
456463
{
464+
"attachments": {},
457465
"cell_type": "markdown",
458466
"metadata": {},
459467
"source": [
@@ -478,6 +486,7 @@
478486
]
479487
},
480488
{
489+
"attachments": {},
481490
"cell_type": "markdown",
482491
"metadata": {},
483492
"source": [
@@ -497,6 +506,7 @@
497506
]
498507
},
499508
{
509+
"attachments": {},
500510
"cell_type": "markdown",
501511
"metadata": {},
502512
"source": [
@@ -523,6 +533,7 @@
523533
]
524534
},
525535
{
536+
"attachments": {},
526537
"cell_type": "markdown",
527538
"metadata": {},
528539
"source": [
@@ -542,6 +553,7 @@
542553
]
543554
},
544555
{
556+
"attachments": {},
545557
"cell_type": "markdown",
546558
"metadata": {},
547559
"source": [
@@ -553,6 +565,7 @@
553565
]
554566
},
555567
{
568+
"attachments": {},
556569
"cell_type": "markdown",
557570
"metadata": {},
558571
"source": [
@@ -596,7 +609,8 @@
596609
"\n",
597610
"trainer = pl.Trainer(\n",
598611
" accelerator=\"auto\",\n",
599-
" max_epochs=int(os.environ.get(\"MAX_EPOCH\", 2000)), # used by testing to run the notebook as a script\n",
612+
" max_epochs=2000,\n",
613+
" fast_dev_run=test_run,\n",
600614
" callbacks=[TQDMProgressBar(refresh_rate=19)],\n",
601615
" enable_checkpointing=False,\n",
602616
")"
@@ -668,6 +682,7 @@
668682
]
669683
},
670684
{
685+
"attachments": {},
671686
"cell_type": "markdown",
672687
"metadata": {},
673688
"source": [
@@ -725,6 +740,7 @@
725740
]
726741
},
727742
{
743+
"attachments": {},
728744
"cell_type": "markdown",
729745
"metadata": {},
730746
"source": [
@@ -733,6 +749,7 @@
733749
]
734750
},
735751
{
752+
"attachments": {},
736753
"cell_type": "markdown",
737754
"metadata": {},
738755
"source": [
@@ -742,6 +759,7 @@
742759
]
743760
},
744761
{
762+
"attachments": {},
745763
"cell_type": "markdown",
746764
"metadata": {},
747765
"source": [
@@ -769,7 +787,7 @@
769787
],
770788
"source": [
771789
"revenue_estimated_ate = {}\n",
772-
"num_samples = 20000\n",
790+
"num_samples = 10 if test_run else 20000\n",
773791
"sample_shape = torch.Size([num_samples])\n",
774792
"transform = data_module.normalizer.transform_modules[outcome]().inv\n",
775793
"\n",
@@ -791,6 +809,7 @@
791809
]
792810
},
793811
{
812+
"attachments": {},
794813
"cell_type": "markdown",
795814
"metadata": {},
796815
"source": [
@@ -825,6 +844,7 @@
825844
]
826845
},
827846
{
847+
"attachments": {},
828848
"cell_type": "markdown",
829849
"metadata": {},
830850
"source": [
@@ -857,13 +877,14 @@
857877
"source": [
858878
"revenue_estimated_ite = {}\n",
859879
"\n",
880+
"base_noise = sem.sample_to_noise(data_module.dataset_train)\n",
881+
"\n",
860882
"for treatment in treatment_columns:\n",
861-
" base_noise = sem.sample_to_noise(data_module.dataset_train)\n",
862-
" intervention_a = TensorDict({treatment: torch.tensor([1.0])}, batch_size=tuple())\n",
863-
" do_a_cfs = transform(sem.do(interventions=intervention_a).noise_to_sample(base_noise)[outcome])\n",
864-
" intervention_b = TensorDict({treatment: torch.tensor([0.0])}, batch_size=tuple())\n",
865-
" do_b_cfs = transform(sem.do(interventions=intervention_b).noise_to_sample(base_noise)[outcome])\n",
866-
" revenue_estimated_ite[treatment] = (do_a_cfs - do_b_cfs).cpu().detach().numpy()[:, 0]\n",
883+
" do_sem = sem.do(interventions=TensorDict({treatment: torch.tensor([1.0])}, batch_size=tuple()))\n",
884+
" do_a_cfs = transform(do_sem.noise_to_sample(base_noise)[outcome]).cpu().detach().numpy()[:, 0]\n",
885+
" do_sem = sem.do(interventions=TensorDict({treatment: torch.tensor([0.0])}, batch_size=tuple()))\n",
886+
" do_b_cfs = transform(do_sem.noise_to_sample(base_noise)[outcome]).cpu().detach().numpy()[:, 0]\n",
887+
" revenue_estimated_ite[treatment] = do_a_cfs - do_b_cfs\n",
867888
"\n",
868889
"revenue_estimated_ite"
869890
]

0 commit comments

Comments
 (0)