From e4e58e2b5fe0068ce39a762d317cab870eb04ef9 Mon Sep 17 00:00:00 2001
From: Alican Bozkurt
Date: Wed, 31 Jan 2018 09:38:32 -0500
Subject: [PATCH 1/8] add cdf and icdf to normal

---
 test/test_distributions.py          | 22 ++++++++++++++++++++++
 torch/distributions/distribution.py | 20 ++++++++++++++++++++
 torch/distributions/normal.py       |  8 ++++++++
 3 files changed, 50 insertions(+)

diff --git a/test/test_distributions.py b/test/test_distributions.py
index 8cbfe66419a94b..a2fd5e2f091e40 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -2309,6 +2309,28 @@ def test_variance_stddev(self):
             self.assertEqual(pytorch_dist.variance, scipy_dist.var(), allow_inf=True, message=pytorch_dist)
             self.assertEqual(pytorch_dist.stddev, scipy_dist.var() ** 0.5, message=pytorch_dist)

+    def test_cdf(self):
+        set_rng_seed(0)  # see Note [Randomized statistical tests]
+        for pytorch_dist, scipy_dist in self.distribution_pairs:
+            samples = pytorch_dist.sample((5,))
+            try:
+                self.assertEqual(pytorch_dist.cdf(samples),
+                                 scipy_dist.cdf(samples),
+                                 message=pytorch_dist)
+            except NotImplementedError:
+                pass
+
+    def test_icdf(self):
+        set_rng_seed(0)  # see Note [Randomized statistical tests]
+        for pytorch_dist, scipy_dist in self.distribution_pairs:
+            samples = Variable(torch.rand((5,) + pytorch_dist.batch_shape))
+            try:
+                self.assertEqual(pytorch_dist.icdf(samples),
+                                 scipy_dist.ppf(samples),
+                                 message=pytorch_dist)
+            except NotImplementedError:
+                pass
+

 class TestTransforms(TestCase):
     def setUp(self):
diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py
index 4201def601bd42..ffa6dfbfe2607c 100644
--- a/torch/distributions/distribution.py
+++ b/torch/distributions/distribution.py
@@ -101,6 +101,26 @@ def log_prob(self, value):
         """
         raise NotImplementedError

+    def cdf(self, value):
+        """
+        Returns the cumulative distribution function evaluated at
+        `value`.
+
+        Args:
+            value (Tensor or Variable):
+        """
+        raise NotImplementedError
+
+    def icdf(self, value):
+        """
+        Returns the inverse cumulative distribution function evaluated at
+        `value`.
+
+        Args:
+            value (Tensor or Variable):
+        """
+        raise NotImplementedError
+
     def enumerate_support(self):
         """
         Returns tensor containing all values supported by a discrete
diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py
index 1ddb10e86f8da3..53d7b2ebcd5d18 100644
--- a/torch/distributions/normal.py
+++ b/torch/distributions/normal.py
@@ -64,5 +64,13 @@ def log_prob(self, value):
         log_scale = math.log(self.scale) if isinstance(self.scale, Number) else self.scale.log()
         return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))

+    def cdf(self, value):
+        self._validate_log_prob_arg(value)
+        return 0.5 * (1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2)))
+
+    def icdf(self, value):
+        self._validate_log_prob_arg(value)
+        return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)
+
     def entropy(self):
         return 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale)
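As a quick illustration of the API this patch introduces -- a minimal editor's sketch, not part of the patch series, and it assumes the later `torch.tensor` constructor rather than the Variable-era API these patches target:

    import torch
    from torch.distributions import Normal

    d = Normal(torch.tensor([0.0, 2.0]), torch.tensor([1.0, 0.5]))
    x = d.sample((5,))
    p = d.cdf(x)        # 0.5 * (1 + erf((x - loc) / (scale * sqrt(2))))
    x_back = d.icdf(p)  # loc + scale * sqrt(2) * erfinv(2p - 1)
    print(torch.allclose(x, x_back, atol=1e-4))  # True, up to float32 error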
From d9263a69a144cc06d512e28270ed46a7abf551db Mon Sep 17 00:00:00 2001
From: vishwakftw
Date: Sat, 3 Feb 2018 10:23:14 +0530
Subject: [PATCH 2/8] New CDF and ICDF implementations

1. Cauchy
2. Exponential
3. Laplace (Only CDF)
4. Pareto
---
 torch/distributions/cauchy.py      | 8 ++++++++
 torch/distributions/exponential.py | 8 ++++++++
 torch/distributions/laplace.py     | 8 ++++++++
 torch/distributions/pareto.py      | 8 ++++++++
 4 files changed, 32 insertions(+)

diff --git a/torch/distributions/cauchy.py b/torch/distributions/cauchy.py
index 6a3600b637df9e..b3b7f06e4fdfb8 100644
--- a/torch/distributions/cauchy.py
+++ b/torch/distributions/cauchy.py
@@ -53,5 +53,13 @@ def log_prob(self, value):
         self._validate_log_prob_arg(value)
         return -math.log(math.pi) - self.scale.log() - (1 + ((value - self.loc) / self.scale)**2).log()

+    def cdf(self, value):
+        self._validate_log_prob_arg(value)
+        return torch.atan((value - self.loc) / self.scale) / math.pi + 0.5
+
+    def icdf(self, value):
+        self._validate_log_prob_arg(value)
+        return torch.tan(math.pi * (value - 0.5)) * self.scale + self.loc
+
     def entropy(self):
         return math.log(4 * math.pi) + self.scale.log()
diff --git a/torch/distributions/exponential.py b/torch/distributions/exponential.py
index d22cfc629a52cb..06a5b33c9aea52 100644
--- a/torch/distributions/exponential.py
+++ b/torch/distributions/exponential.py
@@ -49,5 +49,13 @@ def log_prob(self, value):
         self._validate_log_prob_arg(value)
         return self.rate.log() - self.rate * value

+    def cdf(self, value):
+        self._validate_log_prob_arg(value)
+        return 1 - torch.exp(-self.rate * value)
+
+    def icdf(self, value):
+        self._validate_log_prob_arg(value)
+        return -torch.log(1 - value) / self.rate
+
     def entropy(self):
         return 1.0 - torch.log(self.rate)
diff --git a/torch/distributions/laplace.py b/torch/distributions/laplace.py
index ea96dbbfb722b3..398bc141a5201f 100644
--- a/torch/distributions/laplace.py
+++ b/torch/distributions/laplace.py
@@ -55,5 +55,13 @@ def log_prob(self, value):
         self._validate_log_prob_arg(value)
         return -torch.log(2 * self.scale) - torch.abs(value - self.loc) / self.scale

+    def cdf(self, value):
+        self._validate_log_prob_arg(value)
+        term = torch.exp((value - self.loc) / self.scale)
+        result = value.new()
+        result[value < self.loc] = 0.5 * term
+        result[value >= self.loc] = 1 - 0.5 / term
+        return result
+
     def entropy(self):
         return 1 + torch.log(2 * self.scale)
diff --git a/torch/distributions/pareto.py b/torch/distributions/pareto.py
index 15854b9da0636e..3a466f1f8d34fe 100644
--- a/torch/distributions/pareto.py
+++ b/torch/distributions/pareto.py
@@ -59,5 +59,13 @@ def log_prob(self, value):
         self._validate_log_prob_arg(value)
         return torch.log(self.alpha / value) + self.alpha * (self.scale / value).log()

+    def cdf(self, value):
+        self._validate_log_prob_arg(value)
+        return 1 - (self.scale / value).pow(self.alpha)
+
+    def icdf(self, value):
+        self._validate_log_prob_arg(value)
+        return self.scale / (1 - value).pow(self.alpha.reciprocal())
+
     def entropy(self):
         return ((self.scale / self.alpha).log() + (1 + self.alpha.reciprocal()))
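Each of these closed forms is the exact algebraic inverse of its partner, so cdf and icdf should round-trip. A minimal check -- an editor's sketch under the same modern-API assumption as above, not part of the patch:

    import torch
    from torch.distributions import Cauchy, Exponential

    for d in (Cauchy(torch.zeros(3), torch.ones(3)),
              Exponential(torch.full((3,), 2.0))):
        # Keep probabilities away from the endpoints, where quantiles diverge.
        u = torch.rand(5, 3).clamp(1e-6, 1 - 1e-6)
        x = d.icdf(u)  # quantile function
        print(torch.allclose(d.cdf(x), u, atol=1e-4))  # True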
From ee55d13606655d04e40f8fc2c1614e927e64eaca Mon Sep 17 00:00:00 2001
From: vishwakftw
Date: Mon, 5 Feb 2018 00:09:27 +0530
Subject: [PATCH 3/8] Major: Add .cdf and .icdf methods for
 TransformedDistributions

Minor:
1. Convert Pareto and Gumbel to TransformedDistribution
2. Add .cdf and .icdf for Uniform
3. Temporarily remove .cdf from Laplace
---
 test/test_distributions.py                    | 36 +++++++++++--------
 torch/distributions/gumbel.py                 | 24 +++++--------
 torch/distributions/laplace.py                |  8 -----
 torch/distributions/pareto.py                 | 28 ++++-----------
 .../distributions/transformed_distribution.py | 21 +++++++++++
 torch/distributions/uniform.py                | 10 ++++++
 6 files changed, 68 insertions(+), 59 deletions(-)

diff --git a/test/test_distributions.py b/test/test_distributions.py
index a2fd5e2f091e40..7a82852466faef 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -939,10 +939,10 @@ def test_gamma_sample(self):
     def test_pareto(self):
         scale = Variable(torch.randn(2, 3).abs(), requires_grad=True)
         alpha = Variable(torch.randn(2, 3).abs(), requires_grad=True)
-        scale_1d = torch.randn(1).abs()
-        alpha_1d = torch.randn(1).abs()
-        self.assertEqual(Pareto(scale_1d, torch.Tensor([0.5])).mean, float('inf'), allow_inf=True)
-        self.assertEqual(Pareto(scale_1d, torch.Tensor([0.5])).variance, float('inf'), allow_inf=True)
+        scale_1d = Variable(torch.randn(1).abs(), requires_grad=True)
+        alpha_1d = Variable(torch.randn(1).abs(), requires_grad=True)
+        self.assertEqual(Pareto(scale_1d, 0.5).mean, float('inf'), allow_inf=True)
+        self.assertEqual(Pareto(scale_1d, 0.5).variance, float('inf'), allow_inf=True)
         self.assertEqual(Pareto(scale, alpha).sample().size(), (2, 3))
         self.assertEqual(Pareto(scale, alpha).sample((5,)).size(), (5, 2, 3))
         self.assertEqual(Pareto(scale_1d, alpha_1d).sample((1,)).size(), (1, 1))
@@ -970,8 +970,8 @@ def test_pareto_sample(self):
     def test_gumbel(self):
         loc = Variable(torch.randn(2, 3), requires_grad=True)
         scale = Variable(torch.randn(2, 3).abs(), requires_grad=True)
-        loc_1d = torch.randn(1)
-        scale_1d = torch.randn(1).abs()
+        loc_1d = Variable(torch.randn(1), requires_grad=True)
+        scale_1d = Variable(torch.randn(1).abs(), requires_grad=True)
         self.assertEqual(Gumbel(loc, scale).sample().size(), (2, 3))
         self.assertEqual(Gumbel(loc, scale).sample((5,)).size(), (5, 2, 3))
         self.assertEqual(Gumbel(loc_1d, scale_1d).sample().size(), (1,))
@@ -2233,6 +2233,10 @@ def setUp(self):
                 Binomial(10, simplex_tensor),
                 scipy.stats.binom(10 * np.ones(simplex_tensor.shape), simplex_tensor)
             ),
+            (
+                Cauchy(random_var, positive_var),
+                scipy.stats.cauchy(loc=random_var, scale=positive_var)
+            ),
             (
                 Dirichlet(positive_var),
                 scipy.stats.dirichlet(positive_var)
@@ -2298,10 +2302,14 @@ def setUp(self):

     def test_mean(self):
         for pytorch_dist, scipy_dist in self.distribution_pairs:
+            if isinstance(pytorch_dist, Cauchy):
+                continue
             self.assertEqual(pytorch_dist.mean, scipy_dist.mean(), allow_inf=True, message=pytorch_dist)

     def test_variance_stddev(self):
         for pytorch_dist, scipy_dist in self.distribution_pairs:
+            if isinstance(pytorch_dist, Cauchy):
+                continue
             if isinstance(pytorch_dist, (Multinomial, OneHotCategorical)):
                 self.assertEqual(pytorch_dist.variance, np.diag(scipy_dist.cov()), message=pytorch_dist)
                 self.assertEqual(pytorch_dist.stddev, np.diag(scipy_dist.cov()) ** 0.5, message=pytorch_dist)
@@ -2314,22 +2322,22 @@ def test_cdf(self):
         for pytorch_dist, scipy_dist in self.distribution_pairs:
             samples = pytorch_dist.sample((5,))
             try:
-                self.assertEqual(pytorch_dist.cdf(samples),
-                                 scipy_dist.cdf(samples),
-                                 message=pytorch_dist)
+                cdf = pytorch_dist.cdf(samples)
             except NotImplementedError:
-                pass
+                continue
+            print("Testing {}.cdf()".format(type(pytorch_dist).__name__))
+            self.assertEqual(cdf, scipy_dist.cdf(samples), message=pytorch_dist)

     def test_icdf(self):
         set_rng_seed(0)  # see Note [Randomized statistical tests]
         for pytorch_dist, scipy_dist in self.distribution_pairs:
             samples = Variable(torch.rand((5,) + pytorch_dist.batch_shape))
             try:
-                self.assertEqual(pytorch_dist.icdf(samples),
-                                 scipy_dist.ppf(samples),
-                                 message=pytorch_dist)
+                icdf = pytorch_dist.icdf(samples)
             except NotImplementedError:
-                pass
+                continue
+            print("Testing {}.icdf()".format(type(pytorch_dist).__name__))
+            self.assertEqual(icdf, scipy_dist.ppf(samples), message=pytorch_dist)


 class TestTransforms(TestCase):
diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py
index c4f8fcc03d617b..4e0f950ee4dfe4 100644
--- a/torch/distributions/gumbel.py
+++ b/torch/distributions/gumbel.py
@@ -2,13 +2,15 @@
 import math
 import torch
 from torch.distributions import constraints
-from torch.distributions.distribution import Distribution
+from torch.distributions.uniform import Uniform
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import AffineTransform, ExpTransform
 from torch.distributions.utils import _finfo, broadcast_all

 euler_constant = 0.57721566490153286060  # Euler Mascheroni Constant


-class Gumbel(Distribution):
+class Gumbel(TransformedDistribution):
     r"""
     Samples from a Gumbel Distribution.

@@ -23,7 +25,6 @@ class Gumbel(Distribution):
         loc (float or Tensor or Variable): Location parameter of the distribution
         scale (float or Tensor or Variable): Scale parameter of the distribution
     """
-    has_rsample = True
     params = {'loc': constraints.real, 'scale': constraints.positive}
     support = constraints.real

@@ -33,19 +34,10 @@ def __init__(self, loc, scale):
             batch_shape = torch.Size()
         else:
             batch_shape = self.scale.size()
-        super(Gumbel, self).__init__(batch_shape)
-
-    def rsample(self, sample_shape=torch.Size()):
-        shape = self._extended_shape(sample_shape)
-        uni_dist = self.scale.new(shape).uniform_(_finfo(self.scale).eps, 1)
-        # X ~ Uniform(0, 1)
-        # Y = loc - scale * ln (-ln (X)) ~ Gumbel(loc, scale)
-        return self.loc - self.scale * torch.log(-uni_dist.log())
-
-    def log_prob(self, value):
-        self._validate_log_prob_arg(value)
-        z = (value - self.loc) / self.scale
-        return -(self.scale.log() + z + torch.exp(-z))
+        base_dist = Uniform(torch.zeros_like(self.loc), 1)
+        transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-1),
+                      ExpTransform().inv, AffineTransform(loc=loc, scale=-scale)]
+        super(Gumbel, self).__init__(base_dist, transforms)

     @property
     def mean(self):
diff --git a/torch/distributions/laplace.py b/torch/distributions/laplace.py
index 398bc141a5201f..ea96dbbfb722b3 100644
--- a/torch/distributions/laplace.py
+++ b/torch/distributions/laplace.py
@@ -55,13 +55,5 @@ def log_prob(self, value):
         self._validate_log_prob_arg(value)
         return -torch.log(2 * self.scale) - torch.abs(value - self.loc) / self.scale

-    def cdf(self, value):
-        self._validate_log_prob_arg(value)
-        term = torch.exp((value - self.loc) / self.scale)
-        result = value.new()
-        result[value < self.loc] = 0.5 * term
-        result[value >= self.loc] = 1 - 0.5 / term
-        return result
-
     def entropy(self):
         return 1 + torch.log(2 * self.scale)
diff --git a/torch/distributions/pareto.py b/torch/distributions/pareto.py
index 3a466f1f8d34fe..108d1ba2729a25 100644
--- a/torch/distributions/pareto.py
+++ b/torch/distributions/pareto.py
@@ -4,11 +4,13 @@
 import torch
 from torch.distributions import constraints
-from torch.distributions.distribution import Distribution
+from torch.distributions.exponential import Exponential
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import AffineTransform, ExpTransform
 from torch.distributions.utils import broadcast_all


-class Pareto(Distribution):
+class Pareto(TransformedDistribution):
     r"""
     Samples from a Pareto Type 1 distribution.

@@ -23,7 +25,6 @@ class Pareto(Distribution):
         scale (float or Tensor or Variable): Scale parameter of the distribution
         alpha (float or Tensor or Variable): Shape parameter of the distribution
     """
-    has_rsample = True
     params = {'alpha': constraints.positive, 'scale': constraints.positive}

     def __init__(self, scale, alpha):
@@ -32,7 +33,9 @@ def __init__(self, scale, alpha):
             batch_shape = torch.Size()
         else:
             batch_shape = self.scale.size()
-        super(Pareto, self).__init__(batch_shape)
+        base_dist = Exponential(alpha)
+        transforms = [ExpTransform(), AffineTransform(loc=0, scale=scale)]
+        super(Pareto, self).__init__(base_dist, transforms)

     @property
     def mean(self):
@@ -50,22 +53,5 @@ def variance(self):
     def support(self):
         return constraints.greater_than(self.scale)

-    def rsample(self, sample_shape=torch.Size()):
-        shape = self._extended_shape(sample_shape)
-        exp_dist = self.alpha.new(shape).exponential_()
-        return self.scale * torch.exp(exp_dist / self.alpha)
-
-    def log_prob(self, value):
-        self._validate_log_prob_arg(value)
-        return torch.log(self.alpha / value) + self.alpha * (self.scale / value).log()
-
-    def cdf(self, value):
-        self._validate_log_prob_arg(value)
-        return 1 - (self.scale / value).pow(self.alpha)
-
-    def icdf(self, value):
-        self._validate_log_prob_arg(value)
-        return self.scale / (1 - value).pow(self.alpha.reciprocal())
-
     def entropy(self):
         return ((self.scale / self.alpha).log() + (1 + self.alpha.reciprocal()))
diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py
index 2db3b568c21880..ccc0ba4c5f2d14 100644
--- a/torch/distributions/transformed_distribution.py
+++ b/torch/distributions/transformed_distribution.py
@@ -85,3 +85,24 @@ def log_prob(self, value):
             log_prob += _sum_rightmost(self.base_dist.log_prob(y),
                                        event_dim - len(self.base_dist.event_shape))
         return log_prob
+
+    def cdf(self, value):
+        """
+        Computes the cumulative distribution function by inverting the
+        transform(s) and computing the cdf of the base distribution.
+        """
+        self.base_dist._validate_log_prob_arg(value)
+        for transform in self.transforms[::-1]:
+            value = transform.inv(value)
+        return self.base_dist.cdf(value)
+
+    def icdf(self, value):
+        """
+        Computes the inverse cumulative distribution function by applying the
+        icdf of the base distribution and then the transform(s).
+        """
+        self.base_dist._validate_log_prob_arg(value)
+        value = self.base_dist.icdf(value)
+        for transform in self.transforms:
+            value = transform(value)
+        return value
diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py
index 1d233a3ad5d644..9750511af02ec3 100644
--- a/torch/distributions/uniform.py
+++ b/torch/distributions/uniform.py
@@ -63,5 +63,15 @@ def log_prob(self, value):
         ub = value.lt(self.high).type_as(self.low)
         return torch.log(lb.mul(ub)) - torch.log(self.high - self.low)

+    def cdf(self, value):
+        self._validate_log_prob_arg(value)
+        result = (value - self.low) / (self.high - self.low)
+        return result
+
+    def icdf(self, value):
+        self._validate_log_prob_arg(value)
+        result = value * (self.high - self.low) + self.low
+        return result
+
     def entropy(self):
         return torch.log(self.high - self.low)
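The payoff of routing cdf/icdf through TransformedDistribution is that any composed distribution gets them for free. A minimal sketch (an editor's illustration, not part of the patch; modern tensor API assumed) builds a log-normal as an exp-transformed normal and checks that the methods defer to the base distribution through the transform:

    import torch
    from torch.distributions import Normal
    from torch.distributions.transformed_distribution import TransformedDistribution
    from torch.distributions.transforms import ExpTransform

    base = Normal(torch.zeros(3), torch.ones(3))
    log_normal = TransformedDistribution(base, [ExpTransform()])

    x = log_normal.sample((4,))
    # cdf walks the transforms backwards, then calls the base cdf:
    print(torch.allclose(log_normal.cdf(x), base.cdf(x.log())))    # True
    # icdf calls the base icdf, then walks the transforms forwards:
    u = torch.rand(4, 3).clamp(1e-6, 1 - 1e-6)
    print(torch.allclose(log_normal.icdf(u), base.icdf(u).exp()))  # True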
From f76114a25a4e6c894d86dd933bb16caf248782c3 Mon Sep 17 00:00:00 2001
From: vishwakftw
Date: Mon, 5 Feb 2018 11:44:19 +0530
Subject: [PATCH 4/8] Add SciPy / NumPy independent tests for cdf and icdf
 invertibility

---
 test/test_distributions.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/test/test_distributions.py b/test/test_distributions.py
index 7a82852466faef..2a711eab8b47e1 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -1158,6 +1158,20 @@ def test_beta_sample(self):
         x = Beta(Tensor([1e-6]), Tensor([1e-6])).sample()[0]
         self.assertTrue(np.isfinite(x) and x > 0, 'Invalid Beta.sample(): {}'.format(x))

+    def test_cdf_icdf_inverse(self):
+        for Dist, params in EXAMPLES:
+            for i, param in enumerate(params):
+                dist = Dist(**param)
+                samples = dist.sample(sample_shape=(20,))
+                try:
+                    cdf = dist.cdf(samples)
+                    actual = dist.icdf(cdf)
+                except NotImplementedError:
+                    continue
+                self.assertEqual(actual, samples,
+                                 message='{} example {}/{},\
+                                 icdf(cdf(x)) != x'.format(Dist.__name__, i + 1, len(params)))
+
     def test_valid_parameter_broadcasting(self):
         # Test correct broadcasting of parameter sizes for distributions that have multiple
         # parameters.
From d0bc72c6bd1176f6a620b1b91612f06ce434076c Mon Sep 17 00:00:00 2001
From: vishwakftw
Date: Mon, 5 Feb 2018 12:39:14 +0530
Subject: [PATCH 5/8] Addition of NumPy / SciPy independent tests for
 .log_prob using derivative of .cdf

---
 test/test_distributions.py | 37 +++++++++++++++++++++++++++----------
 1 file changed, 27 insertions(+), 10 deletions(-)

diff --git a/test/test_distributions.py b/test/test_distributions.py
index 2a711eab8b47e1..b1deeb20127cbb 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -181,8 +181,8 @@ def is_all_nan(tensor):
             'scale': Variable(torch.randn(1).abs(), requires_grad=True),
         },
         {
-            'loc': torch.Tensor([1.0, 0.0]),
-            'scale': torch.Tensor([1e-5, 1e-5]),
+            'loc': Variable(torch.Tensor([1.0, 0.0]), requires_grad=True),
+            'scale': Variable(torch.Tensor([1e-5, 1e-5]), requires_grad=True),
         },
     ]),
     Example(LogNormal, [
@@ -195,8 +195,8 @@ def is_all_nan(tensor):
             'scale': Variable(torch.randn(1).abs(), requires_grad=True),
         },
         {
-            'loc': torch.Tensor([1.0, 0.0]),
-            'scale': torch.Tensor([1e-5, 1e-5]),
+            'loc': Variable(torch.Tensor([1.0, 0.0]), requires_grad=True),
+            'scale': Variable(torch.Tensor([1e-5, 1e-5]), requires_grad=True),
         },
     ]),
     Example(Normal, [
@@ -209,8 +209,8 @@ def is_all_nan(tensor):
             'scale': Variable(torch.randn(1).abs(), requires_grad=True),
         },
         {
-            'loc': torch.Tensor([1.0, 0.0]),
-            'scale': torch.Tensor([1e-5, 1e-5]),
+            'loc': Variable(torch.Tensor([1.0, 0.0]), requires_grad=True),
+            'scale': Variable(torch.Tensor([1e-5, 1e-5]), requires_grad=True),
         },
     ]),
     Example(OneHotCategorical, [
@@ -270,8 +270,8 @@ def is_all_nan(tensor):
             'high': Variable(torch.ones(1), requires_grad=True),
         },
         {
-            'low': torch.Tensor([1.0, 1.0]),
-            'high': torch.Tensor([2.0, 3.0]),
+            'low': variable([1.0, 1.0]),
+            'high': variable([2.0, 3.0]),
         },
     ]),
 ]
@@ -1159,6 +1159,7 @@ def test_beta_sample(self):
         self.assertTrue(np.isfinite(x) and x > 0, 'Invalid Beta.sample(): {}'.format(x))

     def test_cdf_icdf_inverse(self):
+        # Tests the invertibility property on the distributions
         for Dist, params in EXAMPLES:
             for i, param in enumerate(params):
                 dist = Dist(**param)
@@ -1172,6 +1173,24 @@ def test_cdf_icdf_inverse(self):
                 self.assertEqual(actual, samples,
                                  message='{} example {}/{},\
                                  icdf(cdf(x)) != x'.format(Dist.__name__, i + 1, len(params)))

+    def test_cdf_log_prob(self):
+        # Tests if the differentiation of the CDF gives the PDF at a given value
+        for Dist, params in EXAMPLES:
+            for i, param in enumerate(params):
+                dist = Dist(**param)
+                samples = dist.sample(sample_shape=(20,))
+                if not samples.requires_grad:
+                    continue
+                try:
+                    cdfs = dist.cdf(samples)
+                    pdfs = dist.log_prob(samples).exp()
+                except NotImplementedError:
+                    continue
+                cdfs_derivative = grad(cdfs.sum(), [samples])[0]
+                self.assertEqual(cdfs_derivative, pdfs,
+                                 message='{} example {}/{}, d(cdf)/dx != pdf(x)'.format(Dist.__name__, i + 1,
+                                                                                        len(params)))
+
     def test_valid_parameter_broadcasting(self):
         # Test correct broadcasting of parameter sizes for distributions that have multiple
         # parameters.
@@ -2339,7 +2358,6 @@ def test_cdf(self):
                 cdf = pytorch_dist.cdf(samples)
             except NotImplementedError:
                 continue
-            print("Testing {}.cdf()".format(type(pytorch_dist).__name__))
             self.assertEqual(cdf, scipy_dist.cdf(samples), message=pytorch_dist)

     def test_icdf(self):
@@ -2350,7 +2368,6 @@ def test_icdf(self):
                 icdf = pytorch_dist.icdf(samples)
             except NotImplementedError:
                 continue
-            print("Testing {}.icdf()".format(type(pytorch_dist).__name__))
             self.assertEqual(icdf, scipy_dist.ppf(samples), message=pytorch_dist)
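The identity this patch tests, d/dx cdf(x) = pdf(x), also works as a standalone autograd computation outside the test harness; a minimal sketch (an editor's illustration, not part of the patch; modern tensor API assumed):

    import torch
    from torch.autograd import grad
    from torch.distributions import Exponential

    d = Exponential(torch.tensor(2.0))
    x = d.sample((10,)).requires_grad_(True)
    # Differentiating the elementwise cdf w.r.t. its input recovers the density.
    pdf_from_cdf = grad(d.cdf(x).sum(), [x])[0]
    print(torch.allclose(pdf_from_cdf, d.log_prob(x).exp()))  # True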
From a9d74ce69be1cbc99353d2a421b7b36faeb99c1f Mon Sep 17 00:00:00 2001
From: vishwakftw
Date: Tue, 6 Feb 2018 09:50:13 +0530
Subject: [PATCH 6/8] Bug fixes

1. Fix the size issue with Gumbel as a transformed distribution
2. Add the scalar params test for Gumbel
---
 test/test_distributions.py                    | 22 ++++++++++++++-----
 torch/distributions/gumbel.py                 |  8 ++++---
 .../distributions/transformed_distribution.py |  2 +-
 3 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/test/test_distributions.py b/test/test_distributions.py
index adc2145c9e9a91..a8542b0a4278df 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -51,7 +51,7 @@
                                             SigmoidTransform,
                                             StickBreakingTransform,
                                             identity_transform)
-from torch.distributions.utils import _finfo, probs_to_logits
+from torch.distributions.utils import _finfo, probs_to_logits, softmax

 TEST_NUMPY = True
 try:
@@ -1173,8 +1173,8 @@ def test_cdf_icdf_inverse(self):
                 except NotImplementedError:
                     continue
                 self.assertEqual(actual, samples,
-                                 message='{} example {}/{},\
-                                 icdf(cdf(x)) != x'.format(Dist.__name__, i + 1, len(params)))
+                                 message='{} example {}/{}, icdf(cdf(x)) != x'.format(Dist.__name__, i + 1,
+                                                                                      len(params)))

     def test_cdf_log_prob(self):
         # Tests if the differentiation of the CDF gives the PDF at a given value
@@ -1774,6 +1774,16 @@ def test_pareto_shape_scalar_params(self):
         self.assertEqual(pareto.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
         self.assertEqual(pareto.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

+    def test_gumbel_shape_scalar_params(self):
+        gumbel = Gumbel(1, 1)
+        self.assertEqual(gumbel._batch_shape, torch.Size())
+        self.assertEqual(gumbel._event_shape, torch.Size())
+        self.assertEqual(gumbel.sample().size(), torch.Size(SCALAR_SHAPE))
+        self.assertEqual(gumbel.sample((3, 2)).size(), torch.Size((3, 2)))
+        self.assertRaises(ValueError, gumbel.log_prob, self.scalar_sample)
+        self.assertEqual(gumbel.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
+        self.assertEqual(gumbel.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
+
     def test_normal_shape_scalar_params(self):
         normal = Normal(0, 1)
@@ -2312,7 +2322,7 @@ def setUp(self):
         positive_var2 = Variable(torch.Tensor(20,).normal_()).exp()
         random_var = Variable(torch.Tensor(20,).normal_())
         random_tensor = torch.Tensor(20,).normal_()
-        simplex_tensor = random_tensor.exp() / random_tensor.exp().sum()
+        simplex_tensor = softmax(random_tensor)
         self.distribution_pairs = [
             (
                 Bernoulli(simplex_tensor),
@@ -2336,7 +2346,7 @@ def setUp(self):
             ),
             (
                 Exponential(positive_var),
-                scipy.stats.expon(scale=1. / positive_var)
+                scipy.stats.expon(scale=positive_var.reciprocal())
             ),
             (
                 FisherSnedecor(positive_var, 4 + positive_var2),  # var for df2<=4 is undefined
@@ -2344,7 +2354,7 @@ def setUp(self):
             ),
             (
                 Gamma(positive_var, positive_var2),
-                scipy.stats.gamma(positive_var, scale=1 / positive_var2)
+                scipy.stats.gamma(positive_var, scale=positive_var2.reciprocal())
             ),
             (
                 Geometric(simplex_tensor),
diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py
index 4e0f950ee4dfe4..8b8dc0ecee7f2e 100644
--- a/torch/distributions/gumbel.py
+++ b/torch/distributions/gumbel.py
@@ -30,13 +30,15 @@ class Gumbel(TransformedDistribution):

     def __init__(self, loc, scale):
         self.loc, self.scale = broadcast_all(loc, scale)
+        finfo = _finfo(self.loc)
         if isinstance(loc, Number) and isinstance(scale, Number):
             batch_shape = torch.Size()
+            base_dist = Uniform(finfo.tiny, 1 - finfo.eps)
         else:
             batch_shape = self.scale.size()
-        base_dist = Uniform(torch.zeros_like(self.loc), 1)
-        transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-1),
-                      ExpTransform().inv, AffineTransform(loc=loc, scale=-scale)]
+            base_dist = Uniform(self.loc.new(self.loc.size()).fill_(finfo.tiny), 1 - finfo.eps)
+        transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
+                      ExpTransform().inv, AffineTransform(loc=loc, scale=-self.scale)]
         super(Gumbel, self).__init__(base_dist, transforms)

     @property
diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py
index ccc0ba4c5f2d14..2b463323d0f296 100644
--- a/torch/distributions/transformed_distribution.py
+++ b/torch/distributions/transformed_distribution.py
@@ -74,6 +74,7 @@ def log_prob(self, value):
         Scores the sample by inverting the transform(s) and computing the score
         using the score of the base distribution and the log abs det jacobian
         """
+        self.base_dist._validate_log_prob_arg(value)
         event_dim = len(self.event_shape)
         log_prob = 0.0
         y = value
@@ -101,7 +102,6 @@ def icdf(self, value):
         Computes the inverse cumulative distribution function by applying the
         icdf of the base distribution and then the transform(s).
         """
-        self.base_dist._validate_log_prob_arg(value)
         value = self.base_dist.icdf(value)
         for transform in self.transforms:
             value = transform(value)

From 24c24616a917d34938866607fd937cf980e004f3 Mon Sep 17 00:00:00 2001
From: vishwakftw
Date: Tue, 6 Feb 2018 21:15:29 +0530
Subject: [PATCH 7/8] Remove batch_shape calculation after making Pareto
 distribution a TransformedDistribution

---
 torch/distributions/pareto.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/torch/distributions/pareto.py b/torch/distributions/pareto.py
index 108d1ba2729a25..56dd2a7ca443d5 100644
--- a/torch/distributions/pareto.py
+++ b/torch/distributions/pareto.py
@@ -29,10 +29,6 @@ class Pareto(TransformedDistribution):

     def __init__(self, scale, alpha):
         self.scale, self.alpha = broadcast_all(scale, alpha)
-        if isinstance(scale, Number) and isinstance(alpha, Number):
-            batch_shape = torch.Size()
-        else:
-            batch_shape = self.scale.size()
         base_dist = Exponential(alpha)
         transforms = [ExpTransform(), AffineTransform(loc=0, scale=scale)]
         super(Pareto, self).__init__(base_dist, transforms)
From 2932ffe51f0c59eaa09c478bfc811be74c8a2403 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Tue, 6 Feb 2018 09:03:51 -0800
Subject: [PATCH 8/8] Fix out-of-range test parameters

---
 test/test_distributions.py | 28 ++++++++++++++++----------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/test/test_distributions.py b/test/test_distributions.py
index a8542b0a4278df..1cc58abcb519f8 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -247,17 +247,17 @@ def is_all_nan(tensor):
     Example(TransformedDistribution, [
         {
             'base_distribution': Normal(Variable(torch.randn(2, 3), requires_grad=True),
-                                        Variable(torch.randn(2, 3), requires_grad=True)),
+                                        Variable(torch.randn(2, 3).abs(), requires_grad=True)),
             'transforms': [],
         },
         {
             'base_distribution': Normal(Variable(torch.randn(2, 3), requires_grad=True),
-                                        Variable(torch.randn(2, 3), requires_grad=True)),
+                                        Variable(torch.randn(2, 3).abs(), requires_grad=True)),
             'transforms': ExpTransform(),
         },
         {
             'base_distribution': Normal(Variable(torch.randn(2, 3), requires_grad=True),
-                                        Variable(torch.randn(2, 3), requires_grad=True)),
+                                        Variable(torch.randn(2, 3).abs(), requires_grad=True)),
             'transforms': [AffineTransform(Variable(torch.randn(1)), Variable(torch.randn(1))),
                            ExpTransform()],
         },
@@ -1172,16 +1172,20 @@ def test_cdf_icdf_inverse(self):
                 actual = dist.icdf(cdf)
             except NotImplementedError:
                 continue
-                self.assertEqual(actual, samples,
-                                 message='{} example {}/{}, icdf(cdf(x)) != x'.format(Dist.__name__, i + 1,
-                                                                                      len(params)))
+                rel_error = torch.abs(actual - samples) / (1e-10 + torch.abs(samples))
+                self.assertLess(rel_error.max(), 1e-4, msg='\n'.join([
+                    '{} example {}/{}, icdf(cdf(x)) != x'.format(Dist.__name__, i + 1, len(params)),
+                    'x = {}'.format(samples),
+                    'cdf(x) = {}'.format(cdf),
+                    'icdf(cdf(x)) = {}'.format(actual),
+                ]))

     def test_cdf_log_prob(self):
         # Tests if the differentiation of the CDF gives the PDF at a given value
         for Dist, params in EXAMPLES:
             for i, param in enumerate(params):
                 dist = Dist(**param)
-                samples = dist.sample(sample_shape=(20,))
+                samples = dist.sample()
                 if not samples.requires_grad:
                     continue
                 try:
@@ -1190,9 +1194,13 @@ def test_cdf_log_prob(self):
                 except NotImplementedError:
                     continue
                 cdfs_derivative = grad(cdfs.sum(), [samples])[0]
-                self.assertEqual(cdfs_derivative, pdfs,
-                                 message='{} example {}/{}, d(cdf)/dx != pdf(x)'.format(Dist.__name__, i + 1,
-                                                                                        len(params)))
+                self.assertEqual(cdfs_derivative, pdfs, message='\n'.join([
+                    '{} example {}/{}, d(cdf)/dx != pdf(x)'.format(Dist.__name__, i + 1, len(params)),
+                    'x = {}'.format(samples),
+                    'cdf = {}'.format(cdfs),
+                    'pdf = {}'.format(pdfs),
+                    'grad(cdf) = {}'.format(cdfs_derivative),
+                ]))

     def test_valid_parameter_broadcasting(self):
         # Test correct broadcasting of parameter sizes for distributions that have multiple
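Taken together, the series leaves the continuous distributions above with a cdf/icdf pair, which among other things enables inverse-transform sampling. A closing sketch (an editor's illustration, not part of the patches; modern tensor API assumed):

    import torch
    from torch.distributions import Normal

    d = Normal(torch.tensor(1.0), torch.tensor(3.0))
    # Push uniform draws through the quantile function to sample from d.
    u = torch.rand(10000).clamp(1e-6, 1 - 1e-6)
    x = d.icdf(u)
    print(x.mean().item(), x.std().item())  # approximately 1.0 and 3.0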