Commit e34ec25

Added logit support

1 parent 4ae92a0

2 files changed: +42 -0 lines changed

src/pointwise_ops.cpp

Lines changed: 37 additions & 0 deletions
@@ -1387,6 +1387,41 @@ using c10::DeviceType;
     return grad_input;
 }
 
+// {"schema": "aten::logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & logit_out(const Tensor & self, ::std::optional<double> eps, Tensor & out)
+{
+    GUARD;
+    Tensor self_c = self.contiguous(), out_c = out.contiguous();
+    dlprim::Tensor X = todp(self_c);
+    dlprim::Tensor Y = todp(out_c);
+    auto q = getExecutionContext(self);
+    if(eps) {
+        double e = *eps;
+        dlprim::core::pointwise_operation({X},{Y},{e},
+            "dtype z = min(1-w0,max(w0,x0)); "
+            "y0 = log(z / (1-z)); ",q);
+    }
+    else {
+        dlprim::core::pointwise_operation({X},{Y},{},
+            "y0 = log(x0 / (1-x0));",q);
+    }
+    if(!out.is_contiguous())
+        out.copy_(out_c);
+
+    sync_if_needed(self.device());
+    return out;
+}
+// {"schema": "aten::logit(Tensor self, float? eps=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor logit(const Tensor & self, ::std::optional<double> eps)
+{
+    Tensor self_c = self.contiguous();
+    dlprim::Tensor X = todp(self_c);
+
+    torch::Tensor result = new_tensor_as(X.shape(),self);
+    logit_out(self_c,eps,result);
+    return result;
+}
+
 
 
 #if 0
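For reference, logit is the inverse of the sigmoid: logit(x) = log(x / (1 - x)); when eps is given, the input is first clamped to [eps, 1 - eps], which is what the min/max on w0 does in the kernel string. A minimal Python sketch of the same element-wise computation, checked against torch.logit on CPU (todp, GUARD, and dlprim::core are backend internals, so plain torch stands in for them here):

import torch

def logit_reference(x, eps=None):
    # Mirror the two kernel branches: optional clamp, then log(z / (1 - z)).
    z = x if eps is None else torch.clamp(x, min=eps, max=1 - eps)
    return torch.log(z / (1 - z))

x = torch.rand(4, 3, 5)   # values in [0, 1)
assert torch.allclose(logit_reference(x, eps=0.1), torch.logit(x, eps=0.1))
assert torch.allclose(logit_reference(x), torch.logit(x))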
@@ -1472,6 +1507,8 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
     m.impl("aten::gelu.out",&ptdlprim::gelu_out);
     m.impl("aten::gelu_backward.grad_input",&ptdlprim::gelu_backward_out);
     m.impl("aten::lerp.Scalar_out",&ptdlprim::lerp_out);
+    m.impl("aten::logit.out",&ptdlprim::logit_out);
+    m.impl("aten::logit",&ptdlprim::logit);
 
 }

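With these registrations, the dispatcher routes aten::logit and aten::logit.out on PrivateUse1 tensors to the functions above. A hedged usage sketch — the pytorch_ocl module name and the "ocl" device string are assumptions about how this backend is loaded and named in a given build:

import torch
import pytorch_ocl  # assumed: module that registers the dlprim OpenCL backend

x = torch.rand(4, 3, 5, device="ocl:0")  # assumed PrivateUse1 device name
y = torch.logit(x, eps=0.1)              # dispatches to ptdlprim::logit
torch.logit(x, eps=0.1, out=y)           # dispatches to ptdlprim::logit_out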
tests/test_op.py

Lines changed: 5 additions & 0 deletions
@@ -267,6 +267,11 @@ def test_all(device):
     print("LayerNorm Aff")
     test_fwd_bwd_op([([2,3,4,30],-1)],torch.nn.LayerNorm((4,30),elementwise_affine=True),device,paramgen = torch.randn)
 
+    print("Test logit eps")
+    test_fwd([([4,3,5],-1)],lambda x:torch.logit(x,eps=0.1),device)
+
+    print("Test logit")
+    test_fwd([([4,3,5],-1)],lambda x:torch.logit(torch.clamp(x,min=0.1,max=0.9)),device)
 
 def test_concat(dev):
     print("Test concat")
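The second case clamps the input to [0.1, 0.9] by hand because without eps, logit returns -inf/+inf at 0 and 1 and NaN outside (0, 1), which would make the CPU-versus-device comparison in test_fwd flaky. A quick illustration with plain torch:

import torch

x = torch.tensor([0.0, 0.5, 1.0])
print(torch.logit(x))           # tensor([-inf, 0., inf])
print(torch.logit(x, eps=0.1))  # finite: input clamped to [0.1, 0.9] first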