1 | 1 | {
2 | 2 | "cells": [
3 | 3 | {
| 4 | + "attachments": {}, |
4 | 5 | "cell_type": "markdown",
5 | 6 | "metadata": {},
6 | 7 | "source": [

20 | 21 | ]
21 | 22 | },
22 | 23 | {
| 24 | + "attachments": {}, |
23 | 25 | "cell_type": "markdown",
24 | 26 | "metadata": {},
25 | 27 | "source": [

55 | 57 | "from pytorch_lightning.callbacks import TQDMProgressBar\n",
|
56 | 58 | "from tensordict import TensorDict\n",
|
57 | 59 | "\n",
|
58 |
| - "from causica.distributions import (\n", |
59 |
| - " ContinuousNoiseDist,\n", |
60 |
| - " SEMDistributionModule,\n", |
61 |
| - ")\n", |
| 60 | + "from causica.distributions import ContinuousNoiseDist\n", |
62 | 61 | "from causica.lightning.data_modules.basic_data_module import BasicDECIDataModule\n",
|
63 | 62 | "from causica.lightning.modules.deci_module import DECIModule\n",
|
64 |
| - "from causica.sem.distribution_parameters_sem import DistributionParametersSEM\n", |
| 63 | + "from causica.sem.sem_distribution import SEMDistributionModule\n", |
65 | 64 | "from causica.sem.structural_equation_model import ite\n",
|
66 | 65 | "from causica.training.auglag import AugLagLRConfig\n",
|
67 | 66 | "\n",
|
68 | 67 | "warnings.filterwarnings(\"ignore\")\n",
|
69 |
| - "%matplotlib inline" |
| 68 | + "test_run = bool(os.environ.get(\"TEST_RUN\", False)) # used by testing to run the notebook as a script" |
70 | 69 | ]
|
71 | 70 | },
|
72 | 71 | {
|
| 72 | + "attachments": {}, |
73 | 73 | "cell_type": "markdown",
|
74 | 74 | "metadata": {},
|
75 | 75 | "source": [
|
|
110 | 110 | ]
111 | 111 | },
112 | 112 | {
| 113 | + "attachments": {}, |
113 | 114 | "cell_type": "markdown",
114 | 115 | "metadata": {},
115 | 116 | "source": [

384 | 385 | ]
385 | 386 | },
386 | 387 | {
| 388 | + "attachments": {}, |
387 | 389 | "cell_type": "markdown",
388 | 390 | "metadata": {},
389 | 391 | "source": [

439 | 441 | "fig, axis = plt.subplots(1, 1, figsize=(8, 8))\n",
|
440 | 442 | "labels = {node: i for i, node in enumerate(true_adj.nodes)}\n",
|
441 | 443 | "\n",
|
442 |
| - "layout = nx.nx_agraph.graphviz_layout(true_adj, prog=\"dot\")\n", |
| 444 | + "try:\n", |
| 445 | + " layout = nx.nx_agraph.graphviz_layout(true_adj, prog=\"dot\")\n", |
| 446 | + "except (ModuleNotFoundError, ImportError):\n", |
| 447 | + " layout = nx.layout.spring_layout(true_adj)\n", |
| 448 | + "\n", |
443 | 449 | "for node, i in labels.items():\n",
|
444 | 450 | " axis.scatter(layout[node][0], layout[node][1], label=f\"{i}: {node}\")\n",
|
445 | 451 | "axis.legend()\n",
|
446 | 452 | "nx.draw_networkx(true_adj, pos=layout, with_labels=True, arrows=True, labels=labels, ax=axis)"
|
447 | 453 | ]
|
448 | 454 | },
|
449 | 455 | {
|
| 456 | + "attachments": {}, |
450 | 457 | "cell_type": "markdown",
|
451 | 458 | "metadata": {},
|
452 | 459 | "source": [
|
453 | 460 | "# Discover the Causal Graph"
|
454 | 461 | ]
|
455 | 462 | },
|
456 | 463 | {
|
| 464 | + "attachments": {}, |
457 | 465 | "cell_type": "markdown",
|
458 | 466 | "metadata": {},
|
459 | 467 | "source": [
|
|
478 | 486 | ]
479 | 487 | },
480 | 488 | {
| 489 | + "attachments": {}, |
481 | 490 | "cell_type": "markdown",
482 | 491 | "metadata": {},
483 | 492 | "source": [

497 | 506 | ]
498 | 507 | },
499 | 508 | {
| 509 | + "attachments": {}, |
500 | 510 | "cell_type": "markdown",
501 | 511 | "metadata": {},
502 | 512 | "source": [

523 | 533 | ]
524 | 534 | },
525 | 535 | {
| 536 | + "attachments": {}, |
526 | 537 | "cell_type": "markdown",
527 | 538 | "metadata": {},
528 | 539 | "source": [

542 | 553 | ]
543 | 554 | },
544 | 555 | {
| 556 | + "attachments": {}, |
545 | 557 | "cell_type": "markdown",
546 | 558 | "metadata": {},
547 | 559 | "source": [

553 | 565 | ]
554 | 566 | },
555 | 567 | {
| 568 | + "attachments": {}, |
556 | 569 | "cell_type": "markdown",
557 | 570 | "metadata": {},
558 | 571 | "source": [

596 | 609 | "\n",
|
597 | 610 | "trainer = pl.Trainer(\n",
|
598 | 611 | " accelerator=\"auto\",\n",
|
599 |
| - " max_epochs=int(os.environ.get(\"MAX_EPOCH\", 2000)), # used by testing to run the notebook as a script\n", |
| 612 | + " max_epochs=2000,\n", |
| 613 | + " fast_dev_run=test_run,\n", |
600 | 614 | " callbacks=[TQDMProgressBar(refresh_rate=19)],\n",
|
601 | 615 | " enable_checkpointing=False,\n",
|
602 | 616 | ")"
|
|
668 | 682 | ]
669 | 683 | },
670 | 684 | {
| 685 | + "attachments": {}, |
671 | 686 | "cell_type": "markdown",
672 | 687 | "metadata": {},
673 | 688 | "source": [

725 | 740 | ]
726 | 741 | },
727 | 742 | {
| 743 | + "attachments": {}, |
728 | 744 | "cell_type": "markdown",
729 | 745 | "metadata": {},
730 | 746 | "source": [

733 | 749 | ]
734 | 750 | },
735 | 751 | {
| 752 | + "attachments": {}, |
736 | 753 | "cell_type": "markdown",
737 | 754 | "metadata": {},
738 | 755 | "source": [

742 | 759 | ]
743 | 760 | },
744 | 761 | {
| 762 | + "attachments": {}, |
745 | 763 | "cell_type": "markdown",
746 | 764 | "metadata": {},
747 | 765 | "source": [

769 | 787 | ],
770 | 788 | "source": [
771 | 789 | "revenue_estimated_ate = {}\n",
772 | | - "num_samples = 20000\n", |
| 790 | + "num_samples = 10 if test_run else 20000\n", |
773 | 791 | "sample_shape = torch.Size([num_samples])\n",
774 | 792 | "transform = data_module.normalizer.transform_modules[outcome]().inv\n",
775 | 793 | "\n",

791 | 809 | ]
792 | 810 | },
793 | 811 | {
| 812 | + "attachments": {}, |
794 | 813 | "cell_type": "markdown",
795 | 814 | "metadata": {},
796 | 815 | "source": [

825 | 844 | ]
826 | 845 | },
827 | 846 | {
| 847 | + "attachments": {}, |
828 | 848 | "cell_type": "markdown",
829 | 849 | "metadata": {},
830 | 850 | "source": [

857 | 877 | "source": [
|
858 | 878 | "revenue_estimated_ite = {}\n",
|
859 | 879 | "\n",
|
| 880 | + "base_noise = sem.sample_to_noise(data_module.dataset_train)\n", |
| 881 | + "\n", |
860 | 882 | "for treatment in treatment_columns:\n",
|
861 |
| - " base_noise = sem.sample_to_noise(data_module.dataset_train)\n", |
862 |
| - " intervention_a = TensorDict({treatment: torch.tensor([1.0])}, batch_size=tuple())\n", |
863 |
| - " do_a_cfs = transform(sem.do(interventions=intervention_a).noise_to_sample(base_noise)[outcome])\n", |
864 |
| - " intervention_b = TensorDict({treatment: torch.tensor([0.0])}, batch_size=tuple())\n", |
865 |
| - " do_b_cfs = transform(sem.do(interventions=intervention_b).noise_to_sample(base_noise)[outcome])\n", |
866 |
| - " revenue_estimated_ite[treatment] = (do_a_cfs - do_b_cfs).cpu().detach().numpy()[:, 0]\n", |
| 883 | + " do_sem = sem.do(interventions=TensorDict({treatment: torch.tensor([1.0])}, batch_size=tuple()))\n", |
| 884 | + " do_a_cfs = transform(do_sem.noise_to_sample(base_noise)[outcome]).cpu().detach().numpy()[:, 0]\n", |
| 885 | + " do_sem = sem.do(interventions=TensorDict({treatment: torch.tensor([0.0])}, batch_size=tuple()))\n", |
| 886 | + " do_b_cfs = transform(do_sem.noise_to_sample(base_noise)[outcome]).cpu().detach().numpy()[:, 0]\n", |
| 887 | + " revenue_estimated_ite[treatment] = do_a_cfs - do_b_cfs\n", |
867 | 888 | "\n",
|
868 | 889 | "revenue_estimated_ite"
|
869 | 890 | ]
|