Skip to content

Commit dbe16f8

Browse files
committed
Refactor standard_gamma and implement CUDA gamma sampling
1 parent 5a5afa5 commit dbe16f8

File tree

15 files changed

+145
-93
lines changed

15 files changed

+145
-93
lines changed

aten/src/ATen/Declarations.cwrap

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3968,26 +3968,6 @@
39683968
kwarg_only: True
39693969
- THTensor* self
39703970
]]
3971-
[[
3972-
name: _standard_gamma
3973-
types:
3974-
- floating_point
3975-
backends:
3976-
- CPU
3977-
return: argument 0
3978-
variants:
3979-
- method
3980-
- function
3981-
options:
3982-
- cname: standard_gamma
3983-
arguments:
3984-
- arg: THTensor* output
3985-
output: True
3986-
- arg: THGenerator* generator
3987-
default: nullptr
3988-
kwarg_only: True
3989-
- THTensor* self
3990-
]]
39913971
[[
39923972
name: tensor
39933973
return: THTensor*

aten/src/ATen/native/Distributions.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#include "ATen/CheckGenerator.h"
99
#include "ATen/Generator.h"
1010

11+
#include <ATen/native/Distributions.cuh>
12+
1113
#include "TH/THRandom.h"
1214

1315
namespace at {
@@ -155,6 +157,24 @@ namespace dist {
155157
return gen_->generator;
156158
}
157159

160+
template <typename scalar>
161+
struct GammaOp {
162+
static void apply(Tensor& ret, const Tensor& alpha, THGenerator *generator) {
163+
CPU_tensor_apply2<scalar, double>(ret, alpha,
164+
[generator](scalar& ret_val, const double& alpha){
165+
dist::baseSampler<float> standard_uniform([generator] () {
166+
return THRandom_standard_uniform(generator);
167+
});
168+
dist::baseSampler<float> standard_normal([generator] () {
169+
return THRandom_normal(generator, 0.0, 1.0);
170+
});
171+
auto sample = dist::sample_gamma<float>(alpha, standard_uniform, standard_normal);
172+
ret_val = std::max(std::numeric_limits<scalar>::min(), (scalar) sample);
173+
}
174+
);
175+
}
176+
};
177+
158178
template <typename scalar>
159179
struct PoissonOp {
160180
static int64_t sample_poisson(double lambda, THGenerator *generator) {
@@ -227,5 +247,12 @@ Tensor _s_poisson_cpu(const Tensor& lambda, Generator *gen) {
227247
return ret;
228248
}
229249

250+
// CPU implementation of standard_gamma: returns a tensor with alpha's type and
// sizes holding one standard Gamma sample per element of `alpha`.
// `gen` is resolved to a TH generator through dist::get_generator.
Tensor _s_gamma_cpu(const Tensor& alpha, Generator *gen) {
  Tensor samples = alpha.type().zeros(alpha.sizes());
  // GammaOp reads the shape parameters as double, so convert up front.
  auto alpha_as_double = alpha.toType(ScalarType::Double);
  dispatch_floating_types<void, dist::GammaOp>(
      samples.type(), "gamma", samples, alpha_as_double, dist::get_generator(gen));
  return samples;
}
256+
230257
} // at::native
231258
} // at
aten/src/ATen/native/Distributions.cuh

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#include "ATen/Config.h"
2+
#include <functional>
3+
#if AT_CUDA_ENABLED()
4+
#include <nvfunctional>
5+
#endif
6+
7+
namespace at {
8+
namespace native {
9+
namespace dist {
10+
11+
// Adapter that exposes any sampling primitive through a single `sample()`
// call, so that samplers backed by different random sources share one
// interface.
//
// When AT_CUDA_ENABLED() is true the callable is held in an nvstd::function
// and both the constructor and sample() are __device__ functions; otherwise a
// std::function is used and the members are ordinary host functions.
//
// NOTE(review): the guard is a build-wide configuration macro, not a
// per-translation-unit check such as __CUDACC__, so host-only sources in a
// CUDA-enabled build also see the __device__ variant -- confirm this is
// intended.
template<typename precision_t>
struct baseSampler {
#if AT_CUDA_ENABLED()
  // Stored callable; invoked once per sample() call.
  nvstd::function<precision_t(void)> sampler;

  __device__ baseSampler(nvstd::function<precision_t(void)> fn)
    : sampler(fn) {}

  // Draws one value by invoking the wrapped callable.
  __device__ precision_t sample() {
    return sampler();
  }
#else
  // Stored callable; invoked once per sample() call.
  std::function<precision_t(void)> sampler;

  baseSampler(std::function<precision_t(void)> fn)
    : sampler(fn) {}

  // Draws one value by invoking the wrapped callable.
  precision_t sample() {
    return sampler();
  }
#endif
};
28+
29+
template<typename precision_t>
30+
#if AT_CUDA_ENABLED()
31+
__host__ __device__
32+
#endif
33+
precision_t sample_gamma(precision_t alpha, baseSampler<precision_t>& standard_uniform, baseSampler<precision_t>& standard_normal) {
34+
precision_t scale = 1.0;
35+
36+
// Boost alpha for higher acceptance probability.
37+
if (alpha < 1.0) {
38+
scale *= ::pow(1 - standard_uniform.sample(), 1.0 / alpha);
39+
alpha += 1.0;
40+
}
41+
42+
// This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
43+
// doi:10.1145/358407.358414
44+
const precision_t d = alpha - 1.0 / 3.0;
45+
const precision_t c = 1.0 / ::sqrt(9.0 * d);
46+
for (;;) {
47+
precision_t x, y;
48+
do {
49+
x = standard_normal.sample();
50+
y = 1.0 + c * x;
51+
} while (y <= 0);
52+
const precision_t v = y * y * y;
53+
const precision_t u = 1 - standard_uniform.sample();
54+
const precision_t xx = x * x;
55+
if (u < 1.0 - 0.0331 * xx * xx)
56+
return scale * d * v;
57+
if (::log(u) < 0.5 * xx + d * (1.0 - v + ::log(v)))
58+
return scale * d * v;
59+
}
60+
}
61+
} // dist
62+
} // native
63+
} // at

aten/src/ATen/native/cuda/Distributions.cu

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
#include <curand_kernel.h>
66
#include <curand_philox4x32_x.h>
77
#include <utility>
8+
#include <functional>
9+
#include <nvfunctional>
10+
11+
#include "ATen/native/Distributions.cuh"
812

913
#include <TH/THAtomic.h>
1014

@@ -26,6 +30,26 @@ namespace dist {
2630
return std::make_pair(gen_->initial_seed, offset);
2731
}
2832

33+
// Element-wise standard Gamma sampler for CUDA tensors.
//
// ret:   output tensor with elements of type `scalar`; overwritten.
// alpha: shape parameters, read element-wise as float.
// seeds: (seed, offset) pair handed to curand_init.
template <typename scalar>
struct GammaOpCUDA {
  static void apply(Tensor& ret, const Tensor& alpha, std::pair<uint64_t, uint64_t> seeds) {
    at::cuda::CUDA_tensor_apply2<scalar, float>(ret, alpha,
      [seeds] __device__ (scalar& ret_val, const float& alpha, bool early_exit) {
        // The flat thread index is passed as the Philox subsequence.
        // NOTE(review): the generator state is re-initialised on every
        // invocation of this functor; if CUDA_tensor_apply2 lets one thread
        // handle several elements they would see identical streams -- confirm
        // against its implementation.
        const auto thread_id = blockIdx.x * blockDim.x + threadIdx.x;
        curandStatePhilox4_32_10_t rng;
        curand_init(seeds.first, thread_id, seeds.second, &rng);

        baseSampler<float> standard_uniform([&rng] __device__ () {
          return curand_uniform(&rng);
        });
        baseSampler<float> standard_normal([&rng] __device__ () {
          return curand_normal(&rng);
        });

        const float raw = sample_gamma<float>(alpha, standard_uniform, standard_normal);
        const scalar sample = scalar_cast<scalar>(raw);
        // Clamp to the smallest positive value of `scalar` so the result is
        // never exactly zero.
        ret_val = ::max(THCNumerics<scalar>::min(), sample);
      }
    );
  }
};
52+
2953
template <typename scalar>
3054
struct PoissonOpCUDA {
3155
static void apply(Tensor& ret, const Tensor& lambda, std::pair<uint64_t, uint64_t> seeds) {
@@ -48,5 +72,12 @@ Tensor _s_poisson_cuda(const Tensor& lambda, Generator* gen) {
4872
return ret;
4973
}
5074

75+
// CUDA implementation of standard_gamma: returns a tensor with alpha's type
// and sizes holding one standard Gamma sample per element of `alpha`.
// The Philox (seed, offset) pair comes from dist::next_philox_seed(gen).
Tensor _s_gamma_cuda(const Tensor& alpha, Generator* gen) {
  Tensor samples = alpha.type().tensor(alpha.sizes());
  // GammaOpCUDA reads the shape parameters as float, so convert up front.
  auto alpha_as_float = alpha.toType(ScalarType::Float);
  dispatch_floating_types<void, dist::GammaOpCUDA>(
      samples.type(), "gamma", samples, alpha_as_float, dist::next_philox_seed(gen));
  return samples;
}
81+
5182
} // at::native
5283
} // at

aten/src/ATen/native/native_functions.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,3 +318,9 @@
318318
dispatch:
319319
CPU: _s_poisson_cpu
320320
CUDA: _s_poisson_cuda
321+
322+
- func: standard_gamma(Tensor self, Generator* generator=nullptr) -> Tensor
323+
variants: function
324+
dispatch:
325+
CPU: _s_gamma_cpu
326+
CUDA: _s_gamma_cuda

aten/src/TH/THRandom.c

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -281,35 +281,6 @@ double THRandom_exponential(THGenerator *_generator, double lambda)
281281
return(-1. / lambda * log(1-uniform_double(_generator)));
282282
}
283283

284-
double THRandom_standard_gamma(THGenerator *_generator, double alpha) {
285-
double scale = 1.0;
286-
287-
// Boost alpha for higher acceptance probability.
288-
if(alpha < 1.0) {
289-
scale *= pow(1 - uniform_double(_generator), 1.0 / alpha);
290-
alpha += 1.0;
291-
}
292-
293-
// This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
294-
// doi:10.1145/358407.358414
295-
const double d = alpha - 1.0 / 3.0;
296-
const double c = 1.0 / sqrt(9.0 * d);
297-
for(;;) {
298-
double x, y;
299-
do {
300-
x = THRandom_normal(_generator, 0.0, 1.0);
301-
y = 1.0 + c * x;
302-
} while(y <= 0);
303-
const double v = y * y * y;
304-
const double u = 1 - uniform_double(_generator);
305-
const double xx = x * x;
306-
if(u < 1.0 - 0.0331 * xx * xx)
307-
return scale * d * v;
308-
if(log(u) < 0.5 * xx + d * (1.0 - v + log(v)))
309-
return scale * d * v;
310-
}
311-
}
312-
313284
double THRandom_cauchy(THGenerator *_generator, double median, double sigma)
314285
{
315286
return(median + sigma * tan(M_PI*(uniform_double(_generator)-0.5)));

aten/src/TH/THRandom.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,6 @@ TH_API double THRandom_normal(THGenerator *_generator, double mean, double stdv)
6868
*/
6969
TH_API double THRandom_exponential(THGenerator *_generator, double lambda);
7070

71-
/** Generates a random number from a standard Gamma distribution.
72-
The Gamma density is proportional to $x^{alpha-1} exp(-x)$
73-
The shape parameter alpha (a.k.a. k) is a positive real number.
74-
*/
75-
TH_API double THRandom_standard_gamma(THGenerator *_generator, double alpha);
76-
7771
/** Returns a random number from a Cauchy distribution.
7872
The Cauchy density is $p(x) = sigma/(pi*(sigma^2 + (x-median)^2))$
7973
*/

aten/src/TH/generic/THTensorRandom.c

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,6 @@ void THTensor_(exponential)(THTensor *self, THGenerator *_generator, double lamb
128128
TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_exponential(_generator, lambda););
129129
}
130130

131-
void THTensor_(standard_gamma)(THTensor *self, THGenerator *gen, THTensor *alpha)
132-
{
133-
THTensor_(resizeAs)(self, alpha);
134-
TH_TENSOR_APPLY2(real, self, real, alpha, {
135-
const real sample = THRandom_standard_gamma(gen, *alpha_data);
136-
*self_data = sample > 0 ? sample : TH_REAL_MIN;
137-
});
138-
}
139-
140131
#undef TH_REAL_MIN
141132

142133
void THTensor_(cauchy)(THTensor *self, THGenerator *_generator, double median, double sigma)

aten/src/TH/generic/THTensorRandom.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ TH_API void THTensor_(normal_means)(THTensor *self, THGenerator *gen, THTensor *
1818
TH_API void THTensor_(normal_stddevs)(THTensor *self, THGenerator *gen, double mean, THTensor *stddevs);
1919
TH_API void THTensor_(normal_means_stddevs)(THTensor *self, THGenerator *gen, THTensor *means, THTensor *stddevs);
2020
TH_API void THTensor_(exponential)(THTensor *self, THGenerator *_generator, double lambda);
21-
TH_API void THTensor_(standard_gamma)(THTensor *self, THGenerator *_generator, THTensor *alpha);
2221
TH_API void THTensor_(cauchy)(THTensor *self, THGenerator *_generator, double median, double sigma);
2322
TH_API void THTensor_(logNormal)(THTensor *self, THGenerator *_generator, double mean, double stdv);
2423
TH_API void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTensor *prob_dist, int n_sample, int with_replacement);

test/test_distributions.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ def test_poisson_sample(self):
599599
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
600600
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
601601
def test_poisson_gpu_sample(self):
602-
set_rng_seed(0)
602+
set_rng_seed(1)
603603
for rate in [0.12, 0.9, 4.0]:
604604
self._check_sampler_discrete(Poisson(torch.Tensor([rate]).cuda()),
605605
scipy.stats.poisson(rate),
@@ -832,6 +832,17 @@ def test_gamma_sample(self):
832832
scipy.stats.gamma(alpha, scale=1.0 / beta),
833833
'Gamma(concentration={}, rate={})'.format(alpha, beta))
834834

835+
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_gamma_gpu_sample(self):
    """Compare CUDA Gamma samples against scipy.stats.gamma for a grid of parameters."""
    set_rng_seed(0)
    for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
        # Both parameters live on the GPU so that sampling runs on the CUDA path.
        a, b = torch.Tensor([alpha]).cuda(), torch.Tensor([beta]).cuda()
        # The label uses the same concentration/rate wording as the CPU
        # sampling test so failure messages read consistently.
        self._check_sampler_sampler(Gamma(a, b),
                                    scipy.stats.gamma(alpha, scale=1.0 / beta),
                                    'Gamma(concentration={}, rate={})'.format(alpha, beta),
                                    failure_rate=1e-4)
845+
835846
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
836847
def test_pareto(self):
837848
scale = Variable(torch.randn(2, 3).abs(), requires_grad=True)

tools/autograd/derivatives.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -644,8 +644,8 @@
644644
self: not_implemented("_sparse_mask")
645645
mask: not_implemented("_sparse_mask")
646646

647-
- name: _standard_gamma(Tensor self, Generator generator)
648-
self: grad * self._standard_gamma_grad(output)
647+
- name: standard_gamma(Tensor self, Generator generator)
648+
self: grad * self._standard_gamma_grad(result)
649649

650650
- name: _standard_gamma_grad(Tensor self, Tensor output)
651651
self: not_implemented("_standard_gamma_grad")

torch/csrc/Module.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,6 @@ IMPLEMENT_STATELESS(bmm)
298298
// TODO: this doesn't implement options that return numbers!
299299
IMPLEMENT_STATELESS(multinomial)
300300
IMPLEMENT_STATELESS(normal)
301-
IMPLEMENT_STATELESS(_standard_gamma)
302301
IMPLEMENT_STATELESS(_dirichlet_grad)
303302
IMPLEMENT_STATELESS(bernoulli)
304303
IMPLEMENT_STATELESS(range)
@@ -719,7 +718,6 @@ static PyMethodDef TorchMethods[] = {
719718
{"bmm", (PyCFunction)THPModule_bmm, METH_VARARGS | METH_KEYWORDS, NULL},
720719
{"multinomial", (PyCFunction)THPModule_multinomial, METH_VARARGS | METH_KEYWORDS, NULL},
721720
{"normal", (PyCFunction)THPModule_normal, METH_VARARGS | METH_KEYWORDS, NULL},
722-
{"_standard_gamma", (PyCFunction)THPModule__standard_gamma, METH_VARARGS | METH_KEYWORDS, NULL},
723721
{"_dirichlet_grad", (PyCFunction)THPModule__dirichlet_grad, METH_VARARGS | METH_KEYWORDS, NULL},
724722
{"bernoulli", (PyCFunction)THPModule_bernoulli, METH_VARARGS | METH_KEYWORDS, NULL},
725723
{"rand", (PyCFunction)THPModule_rand, METH_VARARGS | METH_KEYWORDS, NULL},

torch/csrc/generic/methods/TensorRandom.cwrap

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -210,26 +210,6 @@
210210
default: 1
211211
]]
212212

213-
[[
214-
name: _standard_gamma
215-
types:
216-
- floating_point
217-
backends:
218-
- CPU
219-
return: argument 0
220-
variants:
221-
- function
222-
options:
223-
- cname: standard_gamma
224-
arguments:
225-
- arg: THTensor* output
226-
output: True
227-
- arg: THGenerator* generator
228-
default: THPGenerator_TH_CData(THPDefaultGenerator)
229-
kwarg_only: True
230-
- THTensor* alpha
231-
]]
232-
233213
[[
234214
name: _dirichlet_grad
235215
types:

torch/distributions/dirichlet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
from torch.autograd.function import once_differentiable
66
from torch.distributions import constraints
77
from torch.distributions.distribution import Distribution
8+
from torch.distributions.gamma import _standard_gamma
89
from torch.distributions.utils import _finfo, broadcast_all
910

1011

1112
def _dirichlet_sample_nograd(concentration):
12-
probs = torch._C._standard_gamma(concentration)
13+
probs = _standard_gamma(concentration)
1314
probs /= probs.sum(-1, True)
1415
eps = _finfo(probs).eps
1516
return probs.clamp_(min=eps, max=1 - eps)

torch/distributions/gamma.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010

1111
def _standard_gamma(concentration):
1212
if not isinstance(concentration, Variable):
13-
return torch._C._standard_gamma(concentration)
14-
return concentration._standard_gamma()
13+
return torch._C._VariableFunctions.standard_gamma(Variable(concentration)).data
14+
return torch._C._VariableFunctions.standard_gamma(concentration)
1515

1616

1717
class Gamma(Distribution):

0 commit comments

Comments
 (0)