Skip to content

Commit d7488f4

Browse files
committed
Refactor standard_gamma and implement CUDA gamma sampling
1 parent d2f71cb commit d7488f4

File tree

15 files changed

+145
-94
lines changed

15 files changed

+145
-94
lines changed

aten/src/ATen/Declarations.cwrap

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3966,26 +3966,6 @@
39663966
kwarg_only: True
39673967
- THTensor* self
39683968
]]
3969-
[[
3970-
name: _standard_gamma
3971-
types:
3972-
- floating_point
3973-
backends:
3974-
- CPU
3975-
return: argument 0
3976-
variants:
3977-
- method
3978-
- function
3979-
options:
3980-
- cname: standard_gamma
3981-
arguments:
3982-
- arg: THTensor* output
3983-
output: True
3984-
- arg: THGenerator* generator
3985-
default: nullptr
3986-
kwarg_only: True
3987-
- THTensor* self
3988-
]]
39893969
[[
39903970
name: _dirichlet_grad
39913971
types:

aten/src/ATen/native/Distributions.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#include "ATen/CheckGenerator.h"
99
#include "ATen/Generator.h"
1010

11+
#include <ATen/native/Distributions.cuh>
12+
1113
#include "TH/THRandom.h"
1214

1315
namespace at {
@@ -155,6 +157,24 @@ namespace dist {
155157
return gen_->generator;
156158
}
157159

160+
template <typename scalar>
161+
struct GammaOp {
162+
static void apply(Tensor& ret, const Tensor& alpha, THGenerator *generator) {
163+
CPU_tensor_apply2<scalar, double>(ret, alpha,
164+
[generator](scalar& ret_val, const double& alpha){
165+
dist::baseSampler<float> standard_uniform([generator] () {
166+
return THRandom_standard_uniform(generator);
167+
});
168+
dist::baseSampler<float> standard_normal([generator] () {
169+
return THRandom_normal(generator, 0.0, 1.0);
170+
});
171+
auto sample = dist::sample_gamma<float>(alpha, standard_uniform, standard_normal);
172+
ret_val = std::max(std::numeric_limits<scalar>::min(), (scalar) sample);
173+
}
174+
);
175+
}
176+
};
177+
158178
template <typename scalar>
159179
struct PoissonOp {
160180
static int64_t sample_poisson(double lambda, THGenerator *generator) {
@@ -227,5 +247,12 @@ Tensor _s_poisson_cpu(const Tensor& lambda, Generator *gen) {
227247
return ret;
228248
}
229249

250+
// CPU implementation of standard_gamma.
//
// Returns a tensor with the type and sizes of `alpha` whose elements are
// filled by dist::GammaOp using the generator resolved from `gen`.
Tensor _s_gamma_cpu(const Tensor& alpha, Generator *gen) {
  // GammaOp reads the shape parameters as double, so convert them up front.
  auto alpha_as_double = alpha.toType(ScalarType::Double);
  Tensor samples = alpha.type().zeros(alpha.sizes());
  dispatch_floating_types<void, dist::GammaOp>(
      samples.type(), "gamma", samples, alpha_as_double, dist::get_generator(gen));
  return samples;
}
256+
230257
} // at::native
231258
} // at
aten/src/ATen/native/Distributions.cuh

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#include "ATen/Config.h"
2+
#include <functional>
3+
#if AT_CUDA_ENABLED()
4+
#include <nvfunctional>
5+
#endif
6+
7+
namespace at {
8+
namespace native {
9+
namespace dist {
10+
11+
// this wraps sampling primitives to expose a common interface
12+
// Adapter that gives every sampling primitive the same `sample()` interface,
// so generic code (e.g. sample_gamma) can draw random numbers without knowing
// which RNG produces them.
//
// precision_t is the value type returned by each draw. The wrapped callable
// is stored by value in the public `sampler` member.
//
// NOTE(review): the variant is selected by AT_CUDA_ENABLED(), i.e. by build
// configuration rather than by whether the including translation unit is
// being compiled as device code -- confirm this is what host-only sources
// that include this header expect.
template<typename precision_t>
struct baseSampler {
#if AT_CUDA_ENABLED()
  // CUDA build: device-callable wrapper around an nvstd::function.
  using callable_type = nvstd::function<precision_t(void)>;

  callable_type sampler;

  __device__ baseSampler(callable_type source)
    : sampler(source) {}

  // Draws one value by invoking the wrapped callable.
  __device__ precision_t sample() { return sampler(); }
#else
  // Non-CUDA build: plain std::function wrapper.
  using callable_type = std::function<precision_t(void)>;

  callable_type sampler;

  baseSampler(callable_type source)
    : sampler(source) {}

  // Draws one value by invoking the wrapped callable.
  precision_t sample() { return sampler(); }
#endif
};
28+
29+
template<typename precision_t>
30+
#if AT_CUDA_ENABLED()
31+
__host__ __device__
32+
#endif
33+
precision_t sample_gamma(precision_t alpha, baseSampler<precision_t>& standard_uniform, baseSampler<precision_t>& standard_normal) {
34+
precision_t scale = 1.0;
35+
36+
// Boost alpha for higher acceptance probability.
37+
if (alpha < 1.0) {
38+
scale *= ::pow(1 - standard_uniform.sample(), 1.0 / alpha);
39+
alpha += 1.0;
40+
}
41+
42+
// This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
43+
// doi:10.1145/358407.358414
44+
const precision_t d = alpha - 1.0 / 3.0;
45+
const precision_t c = 1.0 / ::sqrt(9.0 * d);
46+
for (;;) {
47+
precision_t x, y;
48+
do {
49+
x = standard_normal.sample();
50+
y = 1.0 + c * x;
51+
} while (y <= 0);
52+
const precision_t v = y * y * y;
53+
const precision_t u = 1 - standard_uniform.sample();
54+
const precision_t xx = x * x;
55+
if (u < 1.0 - 0.0331 * xx * xx)
56+
return scale * d * v;
57+
if (::log(u) < 0.5 * xx + d * (1.0 - v + ::log(v)))
58+
return scale * d * v;
59+
}
60+
}
61+
} // dist
62+
} // native
63+
} // at

aten/src/ATen/native/cuda/Distributions.cu

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
#include <curand_kernel.h>
66
#include <curand_philox4x32_x.h>
77
#include <utility>
8+
#include <functional>
9+
#include <nvfunctional>
10+
11+
#include "ATen/native/Distributions.cuh"
812

913
#include <TH/THAtomic.h>
1014

@@ -26,6 +30,26 @@ namespace dist {
2630
return std::make_pair(gen_->initial_seed, offset);
2731
}
2832

33+
// Element-wise standard Gamma sampler for CUDA tensors.
//
// For every element of `alpha` (read as float) one sample is drawn with
// sample_gamma<float> and written to the matching element of `ret`.
template <typename scalar>
struct GammaOpCUDA {
  // ret:   output tensor; its elements are written as `scalar`.
  // alpha: shape parameters; its elements are read as float.
  // seeds: (seed, offset) pair handed to curand_init.
  static void apply(Tensor& ret, const Tensor& alpha, std::pair<uint64_t, uint64_t> seeds) {
    at::cuda::CUDA_tensor_apply2<scalar, float>(ret, alpha,
      [seeds] __device__ (scalar& ret_val, const float& alpha, bool early_exit) {
        // Philox state initialised from seeds.first (seed), the flat thread
        // index (subsequence) and seeds.second (offset).
        curandStatePhilox4_32_10_t rng;
        const auto subsequence = blockIdx.x * blockDim.x + threadIdx.x;
        curand_init(seeds.first, subsequence, seeds.second, &rng);

        // Both samplers draw from the same local state.
        baseSampler<float> standard_uniform([&rng] __device__ () {
          return curand_uniform(&rng);
        });
        baseSampler<float> standard_normal([&rng] __device__ () {
          return curand_normal(&rng);
        });

        const float drawn = sample_gamma<float>(alpha, standard_uniform, standard_normal);
        auto converted = scalar_cast<scalar>(drawn);
        // Clamp from below with THCNumerics<scalar>::min().
        ret_val = ::max(THCNumerics<scalar>::min(), (scalar) converted);
      }
    );
  }
};
52+
2953
template <typename scalar>
3054
struct PoissonOpCUDA {
3155
static void apply(Tensor& ret, const Tensor& lambda, std::pair<uint64_t, uint64_t> seeds) {
@@ -48,5 +72,12 @@ Tensor _s_poisson_cuda(const Tensor& lambda, Generator* gen) {
4872
return ret;
4973
}
5074

75+
// CUDA implementation of standard_gamma.
//
// Returns a tensor with the type and sizes of `alpha` whose elements are
// filled by dist::GammaOpCUDA, seeded via dist::next_philox_seed(gen).
Tensor _s_gamma_cuda(const Tensor& alpha, Generator* gen) {
  // GammaOpCUDA reads the shape parameters as float, so convert them up front.
  auto alpha_as_float = alpha.toType(ScalarType::Float);
  Tensor samples = alpha.type().tensor(alpha.sizes());
  dispatch_floating_types<void, dist::GammaOpCUDA>(
      samples.type(), "gamma", samples, alpha_as_float, dist::next_philox_seed(gen));
  return samples;
}
81+
5182
} // at::native
5283
} // at

aten/src/ATen/native/native_functions.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,12 @@
392392
CPU: _s_poisson_cpu
393393
CUDA: _s_poisson_cuda
394394

395+
- func: standard_gamma(Tensor self, Generator* generator=nullptr) -> Tensor
396+
variants: function
397+
dispatch:
398+
CPU: _s_gamma_cpu
399+
CUDA: _s_gamma_cuda
400+
395401
- func: _cudnn_rnn_flatten_weight(TensorList weight_arr, int64_t weight_stride0, int64_t input_size, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, bool bidirectional) -> Tensor
396402
variants: function
397403

aten/src/TH/THRandom.cpp

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -290,35 +290,6 @@ double THRandom_exponential(THGenerator *_generator, double lambda)
290290
return(-1. / lambda * log(1-uniform_double(_generator)));
291291
}
292292

293-
double THRandom_standard_gamma(THGenerator *_generator, double alpha) {
294-
double scale = 1.0;
295-
296-
// Boost alpha for higher acceptance probability.
297-
if(alpha < 1.0) {
298-
scale *= pow(1 - uniform_double(_generator), 1.0 / alpha);
299-
alpha += 1.0;
300-
}
301-
302-
// This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
303-
// doi:10.1145/358407.358414
304-
const double d = alpha - 1.0 / 3.0;
305-
const double c = 1.0 / sqrt(9.0 * d);
306-
for(;;) {
307-
double x, y;
308-
do {
309-
x = THRandom_normal(_generator, 0.0, 1.0);
310-
y = 1.0 + c * x;
311-
} while(y <= 0);
312-
const double v = y * y * y;
313-
const double u = 1 - uniform_double(_generator);
314-
const double xx = x * x;
315-
if(u < 1.0 - 0.0331 * xx * xx)
316-
return scale * d * v;
317-
if(log(u) < 0.5 * xx + d * (1.0 - v + log(v)))
318-
return scale * d * v;
319-
}
320-
}
321-
322293
double THRandom_cauchy(THGenerator *_generator, double median, double sigma)
323294
{
324295
return(median + sigma * tan(M_PI*(uniform_double(_generator)-0.5)));

aten/src/TH/THRandom.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,6 @@ TH_API double THRandom_normal(THGenerator *_generator, double mean, double stdv)
6363
*/
6464
TH_API double THRandom_exponential(THGenerator *_generator, double lambda);
6565

66-
/** Generates a random number from a standard Gamma distribution.
67-
The Gamma density is proportional to $x^{alpha-1} exp(-x)$
68-
The shape parameter alpha (a.k.a. k) is a positive real number.
69-
*/
70-
TH_API double THRandom_standard_gamma(THGenerator *_generator, double alpha);
71-
7266
/** Returns a random number from a Cauchy distribution.
7367
The Cauchy density is $p(x) = sigma/(pi*(sigma^2 + (x-median)^2))$
7468
*/

aten/src/TH/generic/THTensorRandom.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -138,16 +138,6 @@ void THTensor_(exponential)(THTensor *self, THGenerator *_generator, double lamb
138138
TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_exponential(_generator, lambda););
139139
}
140140

141-
void THTensor_(standard_gamma)(THTensor *self, THGenerator *_generator, THTensor *alpha)
142-
{
143-
std::lock_guard<std::mutex> lock(_generator->mutex);
144-
THTensor_(resizeAs)(self, alpha);
145-
TH_TENSOR_APPLY2(real, self, real, alpha, {
146-
const real sample = THRandom_standard_gamma(_generator, *alpha_data);
147-
*self_data = sample > 0 ? sample : TH_REAL_MIN;
148-
});
149-
}
150-
151141
#undef TH_REAL_MIN
152142

153143
void THTensor_(cauchy)(THTensor *self, THGenerator *_generator, double median, double sigma)

aten/src/TH/generic/THTensorRandom.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ TH_API void THTensor_(normal_means)(THTensor *self, THGenerator *gen, THTensor *
1818
TH_API void THTensor_(normal_stddevs)(THTensor *self, THGenerator *gen, double mean, THTensor *stddevs);
1919
TH_API void THTensor_(normal_means_stddevs)(THTensor *self, THGenerator *gen, THTensor *means, THTensor *stddevs);
2020
TH_API void THTensor_(exponential)(THTensor *self, THGenerator *_generator, double lambda);
21-
TH_API void THTensor_(standard_gamma)(THTensor *self, THGenerator *_generator, THTensor *alpha);
2221
TH_API void THTensor_(cauchy)(THTensor *self, THGenerator *_generator, double median, double sigma);
2322
TH_API void THTensor_(logNormal)(THTensor *self, THGenerator *_generator, double mean, double stdv);
2423
TH_API void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTensor *prob_dist, int n_sample, int with_replacement);

test/test_distributions.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ def test_poisson_sample(self):
712712
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
713713
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
714714
def test_poisson_gpu_sample(self):
715-
set_rng_seed(0)
715+
set_rng_seed(1)
716716
for rate in [0.12, 0.9, 4.0]:
717717
self._check_sampler_discrete(Poisson(torch.Tensor([rate]).cuda()),
718718
scipy.stats.poisson(rate),
@@ -1089,6 +1089,17 @@ def test_gamma_sample(self):
10891089
scipy.stats.gamma(alpha, scale=1.0 / beta),
10901090
'Gamma(concentration={}, rate={})'.format(alpha, beta))
10911091

1092+
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_gamma_gpu_sample(self):
    """Compare Gamma samples drawn on the GPU against scipy.stats.gamma.

    Runs over a grid of (alpha, beta) values; scipy is parameterised with
    ``scale = 1 / beta``.
    """
    set_rng_seed(0)
    for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
        a, b = torch.Tensor([alpha]).cuda(), torch.Tensor([beta]).cuda()
        # The message uses the same parameter names as the CPU gamma test
        # (concentration / rate) so failures read consistently.
        self._check_sampler_sampler(Gamma(a, b),
                                    scipy.stats.gamma(alpha, scale=1.0 / beta),
                                    'Gamma(concentration={}, rate={})'.format(alpha, beta),
                                    failure_rate=1e-4)
1102+
10921103
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
10931104
def test_pareto(self):
10941105
scale = Variable(torch.randn(2, 3).abs(), requires_grad=True)

0 commit comments

Comments
 (0)