
Commit c252e57

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 6d87fa2 commit c252e57

File tree

4 files changed: +108, -92 lines


source/lib/src/gpu/tabulate.cu

Lines changed: 50 additions & 43 deletions
@@ -654,16 +654,19 @@ __global__ void tabulate_fusion_se_t_tebd_fifth_order_polynomial(
       locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);

       FPTYPE var[6];
-      load_polynomial_params(var, table, table_idx, thread_idx, last_layer_size);
+      load_polynomial_params(var, table, table_idx, thread_idx,
+                             last_layer_size);

-      FPTYPE res = var[0] +
-                   (var[1] +
-                    (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx;
+      FPTYPE res =
+          var[0] +
+          (var[1] +
+           (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) *
+              xx;

       // Store result preserving the nt_i x nt_j structure
       out[block_idx * nnei_i * nnei_j * last_layer_size +
-          ii * nnei_j * last_layer_size +
-          jj * last_layer_size + thread_idx] = res;
+          ii * nnei_j * last_layer_size + jj * last_layer_size + thread_idx] =
+          res;
     }
   }
 }
@@ -698,11 +701,12 @@ __global__ void tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial(
         load_polynomial_params(var, table, table_idx, mm, last_layer_size);

         FPTYPE dres_dxx = var[1] + 2.0 * var[2] * xx + 3.0 * var[3] * xx * xx +
-                          4.0 * var[4] * xx * xx * xx + 5.0 * var[5] * xx * xx * xx * xx;
+                          4.0 * var[4] * xx * xx * xx +
+                          5.0 * var[5] * xx * xx * xx * xx;

-        FPTYPE dy_val = dy[block_idx * nnei_i * nnei_j * last_layer_size +
-                           ii * nnei_j * last_layer_size +
-                           jj * last_layer_size + mm];
+        FPTYPE dy_val =
+            dy[block_idx * nnei_i * nnei_j * last_layer_size +
+               ii * nnei_j * last_layer_size + jj * last_layer_size + mm];
         grad_sum += dy_val * dres_dxx;
       }

@@ -734,21 +738,24 @@ __global__ void tabulate_fusion_se_t_tebd_grad_grad_fifth_order_polynomial(
   for (int ii = 0; ii < nnei_i; ii++) {
     for (int jj = 0; jj < nnei_j; jj++) {
       FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
-      FPTYPE dz_dy_dem_x_val = dz_dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
+      FPTYPE dz_dy_dem_x_val =
+          dz_dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];

       int table_idx = 0;
       locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);

       FPTYPE var[6];
-      load_polynomial_params(var, table, table_idx, thread_idx, last_layer_size);
+      load_polynomial_params(var, table, table_idx, thread_idx,
+                             last_layer_size);

       FPTYPE dres_dxx = var[1] + 2.0 * var[2] * xx + 3.0 * var[3] * xx * xx +
-                        4.0 * var[4] * xx * xx * xx + 5.0 * var[5] * xx * xx * xx * xx;
+                        4.0 * var[4] * xx * xx * xx +
+                        5.0 * var[5] * xx * xx * xx * xx;

       // Store result preserving the nt_i x nt_j structure
       dz_dy[block_idx * nnei_i * nnei_j * last_layer_size +
-            ii * nnei_j * last_layer_size +
-            jj * last_layer_size + thread_idx] = dz_dy_dem_x_val * dres_dxx;
+            ii * nnei_j * last_layer_size + jj * last_layer_size + thread_idx] =
+          dz_dy_dem_x_val * dres_dxx;
     }
   }
 }
@@ -1088,9 +1095,10 @@ void tabulate_fusion_se_t_tebd_grad_gpu(FPTYPE* dy_dem_x,
   DPErrcheck(gpuDeviceSynchronize());
   DPErrcheck(gpuMemset(dy_dem_x, 0, sizeof(FPTYPE) * nloc * nnei_i * nnei_j));
   tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial<FPTYPE, MM, KK>
-      <<<nloc, KK * WARP_SIZE>>>(
-          dy_dem_x, table, em_x, em, dy, table_info[0], table_info[1],
-          table_info[2], table_info[3], table_info[4], nnei_i, nnei_j, last_layer_size);
+      <<<nloc, KK * WARP_SIZE>>>(dy_dem_x, table, em_x, em, dy, table_info[0],
+                                 table_info[1], table_info[2], table_info[3],
+                                 table_info[4], nnei_i, nnei_j,
+                                 last_layer_size);
   DPErrcheck(gpuGetLastError());
   DPErrcheck(gpuDeviceSynchronize());
 }
@@ -1111,13 +1119,14 @@ void tabulate_fusion_se_t_tebd_grad_grad_gpu(FPTYPE* dz_dy,
   }
   DPErrcheck(gpuGetLastError());
   DPErrcheck(gpuDeviceSynchronize());
-  DPErrcheck(gpuMemset(dz_dy, 0, sizeof(FPTYPE) * nloc * nnei_i * nnei_j * last_layer_size));
+  DPErrcheck(gpuMemset(
+      dz_dy, 0, sizeof(FPTYPE) * nloc * nnei_i * nnei_j * last_layer_size));

   tabulate_fusion_se_t_tebd_grad_grad_fifth_order_polynomial<FPTYPE, MM, KK>
-      <<<nloc, last_layer_size>>>(
-          dz_dy, table, em_x, em, dz_dy_dem_x,
-          table_info[0], table_info[1], table_info[2], table_info[3], table_info[4],
-          nnei_i, nnei_j, last_layer_size);
+      <<<nloc, last_layer_size>>>(dz_dy, table, em_x, em, dz_dy_dem_x,
+                                  table_info[0], table_info[1], table_info[2],
+                                  table_info[3], table_info[4], nnei_i, nnei_j,
+                                  last_layer_size);
   DPErrcheck(gpuGetLastError());
   DPErrcheck(gpuDeviceSynchronize());
 }
@@ -1381,27 +1390,25 @@ template void tabulate_fusion_se_r_grad_grad_gpu<double>(
     const int last_layer_size);

 // Template instantiations for SE_T_TEBD GPU functions
-template void tabulate_fusion_se_t_tebd_gpu<float>(
-    float* out,
-    const float* table,
-    const float* table_info,
-    const float* em_x,
-    const float* em,
-    const int nloc,
-    const int nnei_i,
-    const int nnei_j,
-    const int last_layer_size);
+template void tabulate_fusion_se_t_tebd_gpu<float>(float* out,
+                                                    const float* table,
+                                                    const float* table_info,
+                                                    const float* em_x,
+                                                    const float* em,
+                                                    const int nloc,
+                                                    const int nnei_i,
+                                                    const int nnei_j,
+                                                    const int last_layer_size);

-template void tabulate_fusion_se_t_tebd_gpu<double>(
-    double* out,
-    const double* table,
-    const double* table_info,
-    const double* em_x,
-    const double* em,
-    const int nloc,
-    const int nnei_i,
-    const int nnei_j,
-    const int last_layer_size);
+template void tabulate_fusion_se_t_tebd_gpu<double>(double* out,
+                                                     const double* table,
+                                                     const double* table_info,
+                                                     const double* em_x,
+                                                     const double* em,
+                                                     const int nloc,
+                                                     const int nnei_i,
+                                                     const int nnei_j,
+                                                     const int last_layer_size);

 template void tabulate_fusion_se_t_tebd_grad_gpu<float>(
     float* dy_dem_x,
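
Note: the reflowed kernels above only change line wrapping; the arithmetic is unchanged. As a reference, here is a minimal standalone C++ sketch of the fifth-order polynomial evaluation (Horner form) and the analytic derivative used in the gradient kernels. The coefficient values and sizes below are made up for illustration; they are not the library's tabulated data.

// sketch_fifth_order.cc -- illustrative only, hypothetical coefficients
#include <cstdio>

template <typename FPTYPE>
FPTYPE fifth_order(const FPTYPE* var, FPTYPE xx) {
  // Horner form of var[0] + var[1]*xx + ... + var[5]*xx^5,
  // matching the expression in the forward kernel.
  return var[0] +
         (var[1] +
          (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) *
             xx;
}

template <typename FPTYPE>
FPTYPE fifth_order_deriv(const FPTYPE* var, FPTYPE xx) {
  // Analytic derivative d/dxx of the polynomial above,
  // as used to accumulate grad_sum in the gradient kernels.
  return var[1] + 2.0 * var[2] * xx + 3.0 * var[3] * xx * xx +
         4.0 * var[4] * xx * xx * xx + 5.0 * var[5] * xx * xx * xx * xx;
}

int main() {
  const double var[6] = {1.0, 0.5, -0.25, 0.125, -0.0625, 0.03125};  // made up
  const double xx = 0.3;
  std::printf("res = %f, dres_dxx = %f\n", fifth_order(var, xx),
              fifth_order_deriv(var, xx));
  return 0;
}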

source/lib/src/tabulate.cc

Lines changed: 23 additions & 20 deletions
@@ -582,8 +582,7 @@ void deepmd::tabulate_fusion_se_t_tebd_cpu(FPTYPE* out,

           // Store result preserving the nt_i x nt_j structure
           out[ii * nnei_i * nnei_j * last_layer_size +
-              jj * nnei_j * last_layer_size +
-              kk * last_layer_size + mm] = res;
+              jj * nnei_j * last_layer_size + kk * last_layer_size + mm] = res;
         }
       }
     }
@@ -626,11 +625,12 @@ void deepmd::tabulate_fusion_se_t_tebd_grad_cpu(FPTYPE* dy_dem_x,
           FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * mm + 5];

           FPTYPE dres_dxx = a1 + 2.0 * a2 * xx + 3.0 * a3 * xx * xx +
-                            4.0 * a4 * xx * xx * xx + 5.0 * a5 * xx * xx * xx * xx;
+                            4.0 * a4 * xx * xx * xx +
+                            5.0 * a5 * xx * xx * xx * xx;

-          FPTYPE dy_val = dy[ii * nnei_i * nnei_j * last_layer_size +
-                             jj * nnei_j * last_layer_size +
-                             kk * last_layer_size + mm];
+          FPTYPE dy_val =
+              dy[ii * nnei_i * nnei_j * last_layer_size +
+                 jj * nnei_j * last_layer_size + kk * last_layer_size + mm];
           grad_sum += dy_val * dres_dxx;
         }

@@ -641,16 +641,17 @@ void deepmd::tabulate_fusion_se_t_tebd_grad_cpu(FPTYPE* dy_dem_x,
 }

 template <typename FPTYPE>
-void deepmd::tabulate_fusion_se_t_tebd_grad_grad_cpu(FPTYPE* dz_dy,
-                                                     const FPTYPE* table,
-                                                     const FPTYPE* table_info,
-                                                     const FPTYPE* em_x,
-                                                     const FPTYPE* em,
-                                                     const FPTYPE* dz_dy_dem_x,
-                                                     const int nloc,
-                                                     const int nnei_i,
-                                                     const int nnei_j,
-                                                     const int last_layer_size) {
+void deepmd::tabulate_fusion_se_t_tebd_grad_grad_cpu(
+    FPTYPE* dz_dy,
+    const FPTYPE* table,
+    const FPTYPE* table_info,
+    const FPTYPE* em_x,
+    const FPTYPE* em,
+    const FPTYPE* dz_dy_dem_x,
+    const int nloc,
+    const int nnei_i,
+    const int nnei_j,
+    const int last_layer_size) {
   memset(dz_dy, 0, sizeof(FPTYPE) * nloc * nnei_i * nnei_j * last_layer_size);
   const FPTYPE lower = table_info[0];
   const FPTYPE upper = table_info[1];
@@ -667,7 +668,8 @@ void deepmd::tabulate_fusion_se_t_tebd_grad_grad_cpu(FPTYPE* dz_dy,
         locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx,
                        table_idx);

-        FPTYPE dz_dy_dem_x_val = dz_dy_dem_x[ii * nnei_i * nnei_j + jj * nnei_j + kk];
+        FPTYPE dz_dy_dem_x_val =
+            dz_dy_dem_x[ii * nnei_i * nnei_j + jj * nnei_j + kk];

         for (int mm = 0; mm < last_layer_size; mm++) {
           FPTYPE a1 = table[table_idx * last_layer_size * 6 + 6 * mm + 1];
@@ -677,11 +679,12 @@ void deepmd::tabulate_fusion_se_t_tebd_grad_grad_cpu(FPTYPE* dz_dy,
           FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * mm + 5];

           FPTYPE dres_dxx = a1 + 2.0 * a2 * xx + 3.0 * a3 * xx * xx +
-                            4.0 * a4 * xx * xx * xx + 5.0 * a5 * xx * xx * xx * xx;
+                            4.0 * a4 * xx * xx * xx +
+                            5.0 * a5 * xx * xx * xx * xx;

           dz_dy[ii * nnei_i * nnei_j * last_layer_size +
-                jj * nnei_j * last_layer_size +
-                kk * last_layer_size + mm] = dz_dy_dem_x_val * dres_dxx;
+                jj * nnei_j * last_layer_size + kk * last_layer_size + mm] =
+              dz_dy_dem_x_val * dres_dxx;
         }
       }
     }
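
Note: both the CPU code above and the GPU kernels store results into a flat buffer laid out as [nloc][nnei_i][nnei_j][last_layer_size]; the re-wrapped index expressions compute the same row-major offset as before. A small self-contained sketch of that offset computation, with hypothetical sizes, for readers checking the indexing:

// sketch_flat_index.cc -- illustrative only, hypothetical sizes
#include <cstddef>
#include <cstdio>

// Row-major offset into out[nloc][nnei_i][nnei_j][last_layer_size],
// matching ii * nnei_i * nnei_j * lls + jj * nnei_j * lls + kk * lls + mm
// in the kernels (ii indexes nloc, or block_idx on the GPU).
inline std::size_t flat_index(std::size_t ii, std::size_t jj, std::size_t kk,
                              std::size_t mm, std::size_t nnei_i,
                              std::size_t nnei_j,
                              std::size_t last_layer_size) {
  return ii * nnei_i * nnei_j * last_layer_size +
         jj * nnei_j * last_layer_size + kk * last_layer_size + mm;
}

int main() {
  // Hypothetical sizes, loosely mirroring the 4 x 4 neighbor grid in the test.
  const std::size_t nnei_i = 4, nnei_j = 4, last_layer_size = 8;
  std::printf("offset(1, 2, 3, 5) = %zu\n",
              flat_index(1, 2, 3, 5, nnei_i, nnei_j, last_layer_size));
  return 0;
}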

source/op/pt/tabulate_multi_device.cc

Lines changed: 32 additions & 28 deletions
@@ -369,16 +369,18 @@ void TabulateFusionSeTTebdForward(const torch::Tensor& table_tensor,
   // compute
   if (device == "GPU") {
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-    deepmd::tabulate_fusion_se_t_tebd_gpu(descriptor, table, table_info, em_x, em,
-                                          nloc, nnei_i, nnei_j, last_layer_size);
+    deepmd::tabulate_fusion_se_t_tebd_gpu(descriptor, table, table_info, em_x,
+                                          em, nloc, nnei_i, nnei_j,
+                                          last_layer_size);
 #else
     throw std::runtime_error(
         "The input tensor is on the GPU, but the GPU support for the "
         "customized OP library is not enabled.");
 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   } else if (device == "CPU") {
-    deepmd::tabulate_fusion_se_t_tebd_cpu(descriptor, table, table_info, em_x, em,
-                                          nloc, nnei_i, nnei_j, last_layer_size);
+    deepmd::tabulate_fusion_se_t_tebd_cpu(descriptor, table, table_info, em_x,
+                                          em, nloc, nnei_i, nnei_j,
+                                          last_layer_size);
   }
 }

@@ -414,28 +416,29 @@ void TabulateFusionSeTTebdGradForward(const torch::Tensor& table_tensor,
   if (device == "GPU") {
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
     deepmd::tabulate_fusion_se_t_tebd_grad_gpu(dy_dem_x, table, table_info,
-                                               em_x, em, dy, nloc, nnei_i, nnei_j,
-                                               last_layer_size);
+                                               em_x, em, dy, nloc, nnei_i,
+                                               nnei_j, last_layer_size);
 #else
     throw std::runtime_error(
         "The input tensor is on the GPU, but the GPU support for the "
         "customized OP library is not enabled.");
 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   } else if (device == "CPU") {
     deepmd::tabulate_fusion_se_t_tebd_grad_cpu(dy_dem_x, table, table_info,
-                                               em_x, em, dy, nloc, nnei_i, nnei_j,
-                                               last_layer_size);
+                                               em_x, em, dy, nloc, nnei_i,
+                                               nnei_j, last_layer_size);
   }
 }

 template <typename FPTYPE>
-void TabulateFusionSeTTebdGradGradForward(const torch::Tensor& table_tensor,
-                                          const torch::Tensor& table_info_tensor,
-                                          const torch::Tensor& em_x_tensor,
-                                          const torch::Tensor& em_tensor,
-                                          const torch::Tensor& dz_dy_dem_x_tensor,
-                                          const torch::Tensor& descriptor_tensor,
-                                          torch::Tensor& dz_dy_tensor) {
+void TabulateFusionSeTTebdGradGradForward(
+    const torch::Tensor& table_tensor,
+    const torch::Tensor& table_info_tensor,
+    const torch::Tensor& em_x_tensor,
+    const torch::Tensor& em_tensor,
+    const torch::Tensor& dz_dy_dem_x_tensor,
+    const torch::Tensor& descriptor_tensor,
+    torch::Tensor& dz_dy_tensor) {
   // Check input shape
   if (dz_dy_dem_x_tensor.dim() != 3) {
     throw std::invalid_argument("Dim of dz_dy_dem_x should be 3");
@@ -458,9 +461,9 @@ void TabulateFusionSeTTebdGradGradForward(const torch::Tensor& table_tensor,
   // compute
   if (device == "GPU") {
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-    deepmd::tabulate_fusion_se_t_tebd_grad_grad_gpu(dz_dy, table, table_info, em_x,
-                                                    em, dz_dy_dem_x, nloc,
-                                                    nnei_i, nnei_j, last_layer_size);
+    deepmd::tabulate_fusion_se_t_tebd_grad_grad_gpu(
+        dz_dy, table, table_info, em_x, em, dz_dy_dem_x, nloc, nnei_i, nnei_j,
+        last_layer_size);
 #else
     throw std::runtime_error(
         "The input tensor is on the GPU, but the GPU support for the "
@@ -470,9 +473,9 @@ void TabulateFusionSeTTebdGradGradForward(const torch::Tensor& table_tensor,
         "In the process of model compression, the size of the "
         "last layer of embedding net must be less than 1024!");
   } else if (device == "CPU") {
-    deepmd::tabulate_fusion_se_t_tebd_grad_grad_cpu(dz_dy, table, table_info, em_x,
-                                                    em, dz_dy_dem_x, nloc,
-                                                    nnei_i, nnei_j, last_layer_size);
+    deepmd::tabulate_fusion_se_t_tebd_grad_grad_cpu(
+        dz_dy, table, table_info, em_x, em, dz_dy_dem_x, nloc, nnei_i, nnei_j,
+        last_layer_size);
   }
 }

@@ -1112,13 +1115,14 @@ class TabulateFusionSeTTebdOp
     auto options = torch::TensorOptions()
                        .dtype(table_tensor.dtype())
                        .device(table_tensor.device());
-    torch::Tensor descriptor_tensor = torch::empty(
-        {em_tensor.size(0), em_tensor.size(1), em_tensor.size(2), last_layer_size},
-        options);
+    torch::Tensor descriptor_tensor =
+        torch::empty({em_tensor.size(0), em_tensor.size(1), em_tensor.size(2),
+                      last_layer_size},
+                     options);
     // compute
     TabulateFusionSeTTebdForward<FPTYPE>(table_tensor, table_info_tensor,
-                                         em_x_tensor, em_tensor, last_layer_size,
-                                         descriptor_tensor);
+                                         em_x_tensor, em_tensor,
+                                         last_layer_size, descriptor_tensor);
     // save data
     ctx->save_for_backward({table_tensor, table_info_tensor, em_x_tensor,
                             em_tensor, descriptor_tensor});
@@ -1202,8 +1206,8 @@ std::vector<torch::Tensor> tabulate_fusion_se_t_tebd(
     const torch::Tensor& em_x_tensor,
     const torch::Tensor& em_tensor,
     int64_t last_layer_size) {
-  return TabulateFusionSeTTebdOp::apply(table_tensor, table_info_tensor,
-                                        em_x_tensor, em_tensor, last_layer_size);
+  return TabulateFusionSeTTebdOp::apply(
+      table_tensor, table_info_tensor, em_x_tensor, em_tensor, last_layer_size);
 }

 std::vector<torch::Tensor> tabulate_fusion_se_r(
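
Note: this file exposes the kernels through a torch::autograd::Function subclass (TabulateFusionSeTTebdOp), whose forward() stashes its inputs with ctx->save_for_backward() so backward() can reuse them. As rough orientation only, a toy self-contained example of that generic C++ autograd pattern; it is not the DeePMD-kit operator, just the same wiring on a trivial square function:

// sketch_autograd_function.cc -- illustrative pattern only
#include <torch/torch.h>
#include <iostream>

struct SquareOp : public torch::autograd::Function<SquareOp> {
  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
                               const torch::Tensor& x) {
    ctx->save_for_backward({x});  // stash inputs needed by backward()
    return x * x;
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto x = saved[0];
    // d(x^2)/dx = 2x, chained with the incoming gradient
    return {grad_outputs[0] * 2.0 * x};
  }
};

int main() {
  auto x = torch::rand({4, 4}, torch::requires_grad());
  auto y = SquareOp::apply(x).sum();
  y.backward();
  std::cout << x.grad() << std::endl;
  return 0;
}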

source/tests/pt/test_tabulate_fusion_se_t_tebd.py

Lines changed: 3 additions & 1 deletion
@@ -233,7 +233,9 @@ def setUp(self) -> None:
             dtype=dtype,
             device=env.DEVICE,
         ).reshape(4, 4)
-        self.em_tensor = self.em_x_tensor.reshape(4, 4, 1)  # SE_T_TEBD uses angular information, so 1D
+        self.em_tensor = self.em_x_tensor.reshape(
+            4, 4, 1
+        )  # SE_T_TEBD uses angular information, so 1D
         self.table_info_tensor.requires_grad = False
         self.table_tensor.requires_grad = False
         self.em_x_tensor.requires_grad = True

0 commit comments
