@@ -1387,6 +1387,41 @@ using c10::DeviceType;
         return grad_input;
     }
 
+    // {"schema": "aten::logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+    Tensor & logit_out(const Tensor & self, ::std::optional<double> eps, Tensor & out)
+    {
+        GUARD;
+        Tensor self_c = self.contiguous(), out_c = out.contiguous();
+        dlprim::Tensor X = todp(self_c);
+        dlprim::Tensor Y = todp(out_c);
+        auto q = getExecutionContext(self);
+        if(eps) {
+            // clamp input to [eps, 1-eps], then apply logit(z) = log(z / (1 - z))
+            double e = *eps;
+            dlprim::core::pointwise_operation({X},{Y},{e},
+                    "dtype z = min(1-w0,max(w0,x0)); "
+                    "y0 = log(z / (1-z));",q);
+        }
+        else {
+            dlprim::core::pointwise_operation({X},{Y},{},
+                    "y0 = log(x0 / (1-x0));",q);
+        }
+        if(!out.is_contiguous())
+            out.copy_(out_c);
+
+        sync_if_needed(self.device());
+        return out;
+    }
+    // {"schema": "aten::logit(Tensor self, float? eps=None) -> Tensor", "dispatch": "True", "default": "False"}
+    Tensor logit(const Tensor & self, ::std::optional<double> eps)
+    {
+        Tensor self_c = self.contiguous();
+        dlprim::Tensor X = todp(self_c);
+
+        torch::Tensor result = new_tensor_as(X.shape(),self);
+        logit_out(self_c,eps,result);
+        return result;
+    }
+
 
 
 #if 0
@@ -1472,6 +1507,8 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
       m.impl("aten::gelu.out",&ptdlprim::gelu_out);
       m.impl("aten::gelu_backward.grad_input",&ptdlprim::gelu_backward_out);
       m.impl("aten::lerp.Scalar_out",&ptdlprim::lerp_out);
+      m.impl("aten::logit.out",&ptdlprim::logit_out);
+      m.impl("aten::logit",&ptdlprim::logit);
 
 }
 
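
For reference, a minimal standalone sketch (plain C++, not part of the patch; logit_ref and the sample values are illustrative only) of the formula the two kernel strings above are meant to compute: clamp the input to [eps, 1-eps] when eps is given, then logit(z) = log(z / (1 - z)).

#include <algorithm>
#include <cmath>
#include <cstdio>

// Hypothetical reference helper mirroring the eps branch of logit_out above.
static double logit_ref(double x, double eps)
{
    double z = std::min(1.0 - eps, std::max(eps, x)); // clamp to [eps, 1-eps]
    return std::log(z / (1.0 - z));                   // logit(z) = log(z / (1 - z))
}

int main()
{
    std::printf("%f\n", logit_ref(0.25, 1e-6)); // ~ -1.098612, i.e. log(1/3)
    std::printf("%f\n", logit_ref(1.5,  1e-6)); // input clamped to 1-eps -> large positive value
    return 0;
}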