@@ -47,9 +47,10 @@ from tripy.flat_ir.ops.base import BaseFlatIROp
4747class ThetaOp (BaseFlatIROp ):
4848 dim: int
4949
50- # `to_mlir()` is the trickiest bit. As the name implies, the method is meant to lower the
51- # `FlatIR` operator into MLIR. To figure out which MLIR operators to use, refer to
52- # the 'MLIR Python API Guide' (linked below).
50+ # `to_mlir()` is the trickiest bit. As the name implies, the method is
51+ # meant to lower the `FlatIR` operator into MLIR. To figure out which
52+ # MLIR operators to use, refer to the 'MLIR Python API Guide'
53+ # (linked below).
5354 def to_mlir (self , operands ):
5455 out_type = self .outputs[0 ].to_mlir()
5556 theta_dim = ir.IntegerAttr.get(type = ir.IntegerType.get_signless(64 ), value = self .dim)
@@ -116,29 +117,31 @@ from tripy.frontend.trace.ops.base import BaseTraceOp
116117import tripy.frontend.trace.ops.utils as op_utils
117118
118119
119- # Just like with `FlatIR` operators, all `Trace` operators are implemented as `dataclass`es.
120- # As before, we want `repr=False` here.
120+ # Just like with `FlatIR` operators, all `Trace` operators are implemented
121+ # as `dataclass`es. As before, we want `repr=False` here.
121122@dataclass (repr = False )
122123class Theta (BaseTraceOp ):
123- # Notice that we do *not* need to define a constructor and can rely on the default
124- # implementation provided by `dataclass`.
124+ # Notice that we do *not* need to define a constructor and can rely on
125+ # the default implementation provided by `dataclass`.
125126 dim: int
126127 dtype: datatype.dtype
127128
128129 # `infer_rank()` populates the rank of the output `TraceTensor`s.
129- # Here we use one of the predefined policies to set the output rank to the same as the shape (i.e. the length)
130- # of the shape operand.
130+ # Here we use one of the predefined policies to set the output rank
131+ # to the same as the shape (i.e. the length) of the shape operand.
131132 infer_rank = op_utils.InferRankPolicies.same_as_shape_of_shape_input()
132133
133134 # *Optional* `infer_dtypes()` populates the data types of the
134135 # output `TraceTensor`s. The default implementation copies the input
135- # data types if they are all the same, so you may not need to implement this.
136+ # data types if they are all the same, so you may not need to implement
137+ # this.
136138 def infer_dtypes (self ):
137139 self .outputs[0 ].dtype = self .dtype
138140
139141 # *Optional* `infer_devices()` populates the devices of the
140142 # output `TraceTensor`s. The default implementation copies the input
141- # devices if they are all the same, so you may not need to implement this either.
143+ # devices if they are all the same, so you may not need to implement
144+ # this either.
142145 def infer_devices (self ):
143146 self .outputs[0 ].device = device(" gpu" )
144147
@@ -177,30 +180,35 @@ from tripy import export
177180import tripy.frontend.utils as frontend_utils
178181from tripy.types import ShapeLike
179182
180- # We can use the `export.public_api()` decorator to automatically export this function into the
181- # top-level module. This means it will be accessible as `tripy.theta`.
183+ # We can use the `export.public_api()` decorator to automatically export this
184+ # function into the top-level module. This means it will be accessible as
185+ # `tripy.theta`.
182186#
183- # This decorator also controls how the API is exposed in the documentation - the `document_under`
184- # option determines where in the documentation hierarchy this API will show up.
187+ # This decorator also controls how the API is exposed in the documentation -
188+ # the `document_under` option determines where in the documentation hierarchy
189+ # this API will show up.
185190#
186- # If we needed to provide any special autodoc options, we could use the `autodoc_options` parameter.
191+ # If we needed to provide any special autodoc options, we could use the
192+ # `autodoc_options` parameter.
187193@export.public_api (document_under = " tensor_operations" )
188194
189- # The `convert_to_tensors` decorator automatically converts compatible arguments,
190- # like `TensorLike` or `ShapeLike`s, into tensors.
195+ # The `convert_to_tensors` decorator automatically converts compatible
196+ # arguments, like `TensorLike` or `ShapeLike`s, into tensors.
191197@frontend_utils.convert_to_tensors ()
192198def theta (shape : ShapeLike, dim : int = 0 , dtype : datatype.dtype = datatype.float32) -> " tripy.Tensor" :
193- # For any public facing interfaces, we have documentation requirements which you can read
194- # about in the 'Docs README' (linked below). The docstring we've implemented here
195- # adheres to all of these requirements. Non-compliant docstrings will, in most cases,
196- # cause test failures; however, you should still manually ensure you're writing high-quality
197- # docstrings.
199+ # For any public facing interfaces, we have documentation requirements which
200+ # you can read about in the 'Docs README' (linked below). The docstring
201+ # we've implemented here adheres to all of these requirements. Non-compliant
202+ # docstrings will, in most cases, cause test failures; however, you should
203+ # still manually ensure you're writing high-quality docstrings.
198204 #
199- # The examples in docstrings are run as part of our tests, so you should also add
200- # assertions to make sure things are functionally correct. In this case, we check
201- # that the `output` we create in the code example is what we expect.
205+ # The examples in docstrings are run as part of our tests, so you should
206+ # also add assertions to make sure things are functionally correct. In this
207+ # case, we check that the `output` we create in the code example is what we
208+ # expect.
202209 """
203- Fills an output tensor with consecutive values starting from zero along the given dimension.
210+ Fills an output tensor with consecutive values starting from zero
211+ along the given dimension.
204212
205213 Args:
206214 shape: The desired shape.
@@ -217,12 +225,15 @@ def theta(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float
217225
218226 output = tp.theta([3])
219227
220- assert np.array_equal(cp.from_dlpack(output).get(), np.arange(0, 3, dtype=np.float32))
228+ assert np.array_equal(
229+ cp.from_dlpack(output).get(), np.arange(0, 3, dtype=np.float32)
230+ )
221231 """
222232
223- # Next we build the trace operator. The `build()` function is also responsible for constructing
224- # the output frontend Tensors. All of the arguments that follow the inputs
225- # are forwarded directly to the constructor of the `Trace` operator.
233+ # Next we build the trace operator. The `build()` function is also
234+ # responsible for constructing the output frontend Tensors. All of the
235+ # arguments that follow the inputs are forwarded directly to the
236+ # constructor of the `Trace` operator.
226237 return Theta.build([shape], dim, dtype)
227238
228239```
0 commit comments