diff --git a/examples/fib_prim.py b/examples/fib_prim.py index 6111a5f..ed8ac4b 100644 --- a/examples/fib_prim.py +++ b/examples/fib_prim.py @@ -1,7 +1,6 @@ -from lyrt import native, from_prim +from lyrt import from_prim, native from lyrt.prim import Int - p1 = Int[32](1) p2 = Int[32](2) p35 = Int[32](35) diff --git a/examples/tensor_prim.py b/examples/tensor_prim.py index 7a6c2ad..a311dcc 100644 --- a/examples/tensor_prim.py +++ b/examples/tensor_prim.py @@ -1,11 +1,10 @@ from lyrt import from_prim -from lyrt.prim import Float, Vector, Matrix, Tensor +from lyrt.prim import Float, Matrix, Tensor, Vector # ゼロ初期化 v = Vector[Float[32], 4].zeros() m = Matrix[Float[32], 3, 4].zeros() - t = Tensor[Float[32], 2, 3, 4].zeros() print(from_prim(v)) @@ -16,24 +15,28 @@ v2 = Vector[Float[32], 4]([0.3, -1.2, 4.5, 2.1]) -m2 = Matrix[Float[32], 3, 4]([ - [1.1, -0.7, 3.3, 0.0], - [2.4, 5.6, -2.2, 1.9], - [0.5, 4.8, -1.1, 3.7], -]) - -t2 = Tensor[Float[32], 2, 3, 4]([ +m2 = Matrix[Float[32], 3, 4]( [ - [0.9, -1.3, 2.2, 4.4], - [3.1, 0.8, -0.5, 6.6], - [7.7, -2.4, 1.2, 0.3], - ], + [1.1, -0.7, 3.3, 0.0], + [2.4, 5.6, -2.2, 1.9], + [0.5, 4.8, -1.1, 3.7], + ] +) + +t2 = Tensor[Float[32], 2, 3, 4]( [ - [5.5, 2.6, -3.3, 1.4], - [8.8, -0.9, 4.0, 2.2], - [6.1, 3.3, -1.7, 9.9], - ], -]) + [ + [0.9, -1.3, 2.2, 4.4], + [3.1, 0.8, -0.5, 6.6], + [7.7, -2.4, 1.2, 0.3], + ], + [ + [5.5, 2.6, -3.3, 1.4], + [8.8, -0.9, 4.0, 2.2], + [6.1, 3.3, -1.7, 9.9], + ], + ] +) print(from_prim(v2)) print(from_prim(m2)) @@ -43,17 +46,21 @@ v3 = Vector[Float[32], ...]([1.0, 2.0, 3.0, 4.0, 5.0]) -m3 = Matrix[Float[32], ...]([ - [1.0, 2.0, 3.0], - [4.0, 5.0, 6.0], -]) +m3 = Matrix[Float[32], ...]( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] +) -t3 = Tensor[Float[32], ...]([ +t3 = Tensor[Float[32], ...]( [ - [1.0, 2.0], - [3.0, 4.0], - ], -]) + [ + [1.0, 2.0], + [3.0, 4.0], + ], + ] +) print(from_prim(v3)) print(from_prim(m3)) @@ -62,4 +69,5 @@ # 0階テンソルはスカラーと同値 t4 = Tensor[Float[32], ...](3.14) assert from_prim(t4) == Float[32](3.14) + print(from_prim(t4)) diff --git a/src/lyrt/__init__.pyi b/src/lyrt/__init__.pyi index 307845d..826954a 100644 --- a/src/lyrt/__init__.pyi +++ b/src/lyrt/__init__.pyi @@ -6,11 +6,22 @@ __all__ = ["prim", "native", "to_prim", "from_prim"] T = TypeVar("T") PrimT = TypeVar("PrimT", bound=prim.Prim[prim.Int] | prim.Prim[prim.Float]) -PrimFunc = TypeVar("PrimFunc", bound=Callable[..., prim.Prim[prim.Int]] | Callable[..., prim.Prim[prim.Float]]) +PrimFunc = TypeVar( + "PrimFunc", + bound=Callable[..., prim.Prim[prim.Int]] | Callable[..., prim.Prim[prim.Float]], +) def native( *, gc: Literal["none", "shadow-stack", "rc"] = "none", ) -> Callable[[PrimFunc], PrimFunc]: ... def to_prim(value: object, prim_type: type[PrimT]) -> PrimT: ... -def from_prim(prim_value: prim.Prim[prim.Int] | prim.Prim[prim.Float] | prim.Vector | prim.Matrix | prim.Tensor) -> object: ... +def from_prim( + prim_value: ( + prim.Prim[prim.Int] + | prim.Prim[prim.Float] + | prim.Vector + | prim.Matrix + | prim.Tensor + ), +) -> object: ... diff --git a/src/lyrt/prim/__init__.pyi b/src/lyrt/prim/__init__.pyi index 9962f2d..a237da0 100644 --- a/src/lyrt/prim/__init__.pyi +++ b/src/lyrt/prim/__init__.pyi @@ -23,8 +23,8 @@ class Prim(Generic[PrimT]): def __neg__(self: PrimT) -> PrimT: ... def __lt__(self: PrimT, other: PrimT) -> Int[1]: ... def __le__(self: PrimT, other: PrimT) -> Int[1]: ... - def __eq__(self: PrimT, other: PrimT) -> Int[1]: ... # type: ignore - def __ne__(self: PrimT, other: PrimT) -> Int[1]: ... # type: ignore + def __eq__(self: PrimT, other: PrimT) -> Int[1]: ... # type: ignore + def __ne__(self: PrimT, other: PrimT) -> Int[1]: ... # type: ignore def __gt__(self: PrimT, other: PrimT) -> Int[1]: ... def __ge__(self: PrimT, other: PrimT) -> Int[1]: ... @@ -46,14 +46,12 @@ class Int(PrimInt[Int]): """Integer primitive type""" def __class_getitem__(cls, key: int) -> Type[Int]: ... - def __init__(self, value: int) -> None: ... class Float(PrimFloat[Float]): """Floating-point primitive type""" def __class_getitem__(cls, key: int) -> Type[Float]: ... - def __init__(self, value: int | float) -> None: ... type NumberLike = Prim[Int | Float] | int | float @@ -62,9 +60,7 @@ Shape = TypeVarTuple("Shape") class Vector: def __class_getitem__(cls, key: tuple[Prim[Int | Float], int]) -> Type[Vector]: ... - def __init__(self, value: List[NumberLike]) -> None: ... - @classmethod def zeros(cls) -> Vector: ... @classmethod @@ -73,10 +69,10 @@ class Vector: def full(cls, value: NumberLike) -> Vector: ... class Matrix: - def __class_getitem__(cls, key: tuple[Prim[Int | Float], int, int]) -> Type[Matrix]: ... - + def __class_getitem__( + cls, key: tuple[Prim[Int | Float], int, int] + ) -> Type[Matrix]: ... def __init__(self, value: List[List[NumberLike]]) -> None: ... - @classmethod def zeros(cls) -> Matrix: ... @classmethod @@ -85,11 +81,12 @@ class Matrix: def full(cls, value: NumberLike) -> Matrix: ... class Tensor: - def __class_getitem__(cls, key: tuple[Prim[Int | Float], Unpack[Shape]]) -> Type[Tensor]: ... + def __class_getitem__( + cls, key: tuple[Prim[Int | Float], Unpack[Shape]] + ) -> Type[Tensor]: ... type _NestedNumberLike = NumberLike | List[_NestedNumberLike] def __init__(self, value: _NestedNumberLike) -> None: ... - @classmethod def zeros(cls) -> Tensor: ... @classmethod diff --git a/src/lython/visitors/_base.py b/src/lython/visitors/_base.py index b927036..3438d81 100644 --- a/src/lython/visitors/_base.py +++ b/src/lython/visitors/_base.py @@ -359,7 +359,9 @@ def get_primitive_type_from_spec(self, base_type: str, bits: int) -> ir.Type: return ir.IntegerType.get_signless(bits, context=self.ctx) elif base_type == "Float": if bits not in FLOAT_VALID_BITS: - raise ValueError(f"Float bit width must be one of {sorted(FLOAT_VALID_BITS)}, got {bits}") + raise ValueError( + f"Float bit width must be one of {sorted(FLOAT_VALID_BITS)}, got {bits}" + ) if bits == 16: return ir.F16Type.get(context=self.ctx) elif bits == 32: @@ -400,11 +402,15 @@ def annotation_to_primitive_type( if base_type in PRIMITIVE_BASE_TYPES: # Get the bit width from the slice - if isinstance(annotation.slice, ast.Constant) and isinstance(annotation.slice.value, int): + if isinstance(annotation.slice, ast.Constant) and isinstance( + annotation.slice.value, int + ): bits = annotation.slice.value return self.get_primitive_type_from_spec(base_type, bits) else: - raise ValueError(f"Primitive type {base_type} requires an integer bit width") + raise ValueError( + f"Primitive type {base_type} requires an integer bit width" + ) return None diff --git a/src/lython/visitors/expr.py b/src/lython/visitors/expr.py index 64062ba..f962acb 100644 --- a/src/lython/visitors/expr.py +++ b/src/lython/visitors/expr.py @@ -602,7 +602,9 @@ def _handle_prim_constructor( # Get the bit width from the subscript assert isinstance(node.func, ast.Subscript) - if not isinstance(node.func.slice, ast.Constant) or not isinstance(node.func.slice.value, int): + if not isinstance(node.func.slice, ast.Constant) or not isinstance( + node.func.slice.value, int + ): raise ValueError(f"{base_type} requires an integer bit width") bits = node.func.slice.value diff --git a/src/lython/visitors/stmt.py b/src/lython/visitors/stmt.py index ac20fdf..c3387f6 100644 --- a/src/lython/visitors/stmt.py +++ b/src/lython/visitors/stmt.py @@ -136,7 +136,9 @@ def hoge(n: int) -> int: ir.StringAttr.get(arg.arg, self.ctx) for arg in node.args.args ] arg_names_attr = ( - ir.ArrayAttr.get(arg_name_attrs, context=self.ctx) # pyright: ignore[reportUnknownMemberType] + ir.ArrayAttr.get( + arg_name_attrs, context=self.ctx + ) # pyright: ignore[reportUnknownMemberType] if arg_name_attrs else None ) @@ -403,7 +405,9 @@ def _visit_method_def( ir.StringAttr.get(arg.arg, self.ctx) for arg in node.args.args ] arg_names_attr = ( - ir.ArrayAttr.get(arg_name_attrs, context=self.ctx) # pyright: ignore[reportUnknownMemberType] + ir.ArrayAttr.get( + arg_name_attrs, context=self.ctx + ) # pyright: ignore[reportUnknownMemberType] if arg_name_attrs else None ) @@ -692,9 +696,15 @@ def visit_If(self, node: ast.If) -> None: assert self.current_block is not None parent_region = self.current_block.region - true_block = parent_region.blocks.append() # pyright: ignore[reportUnknownMemberType] - false_block = parent_region.blocks.append() # pyright: ignore[reportUnknownMemberType] - merge_block = parent_region.blocks.append() # pyright: ignore[reportUnknownMemberType] + true_block = ( + parent_region.blocks.append() + ) # pyright: ignore[reportUnknownMemberType] + false_block = ( + parent_region.blocks.append() + ) # pyright: ignore[reportUnknownMemberType] + merge_block = ( + parent_region.blocks.append() + ) # pyright: ignore[reportUnknownMemberType] with self._loc(node), self.insertion_point(): cf_ops.CondBranchOp(cond, [], [], true_block, false_block) diff --git a/tools/CLI.cpp b/tools/CLI.cpp index ece58fc..515f2c3 100644 --- a/tools/CLI.cpp +++ b/tools/CLI.cpp @@ -111,24 +111,82 @@ if __name__ == "__main__": } LogicalResult runPipeline(ModuleOp module, MLIRContext& context) { - PassManager pm(&context); - // Verify @native functions before any transformations - // This enforces the modal logic separation (Primitive World vs Object World) - pm.addPass(py::createNativeVerificationPass()); - // Insert reference counting operations using Affine SSA (Linear Type) logic - pm.addPass(py::createRefCountInsertionPass()); - // Early canonicalization and CSE for arith/func ops - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createCSEPass()); - pm.addPass(py::createRuntimeLoweringPass()); - pm.addPass(mlir::createConvertSCFToCFPass()); - pm.addPass(mlir::createArithToLLVMConversionPass()); - pm.addPass(mlir::createConvertControlFlowToLLVMPass()); - pm.addPass(mlir::createConvertToLLVMPass()); - pm.addPass(mlir::createReconcileUnrealizedCastsPass()); - pm.addNestedPass(mlir::createReconcileUnrealizedCastsPass()); - pm.addPass(mlir::createCanonicalizerPass()); - return pm.run(module); + bool dumpIR = static_cast( + llvm::sys::Process::GetEnv("LYTHON_DUMP_LOWERING_IR")); + + if (dumpIR) { + llvm::errs() << "=== [Frontend Output (before any passes)] ===\n"; + module.dump(); + } + + // Phase 1: Native verification + { + PassManager pm(&context); + pm.addPass(py::createNativeVerificationPass()); + if (failed(pm.run(module))) + return failure(); + } + + if (dumpIR) { + llvm::errs() << "\n=== [After NativeVerificationPass] ===\n"; + module.dump(); + } + + // Phase 2: Reference counting insertion + { + PassManager pm(&context); + pm.addPass(py::createRefCountInsertionPass()); + if (failed(pm.run(module))) + return failure(); + } + + if (dumpIR) { + llvm::errs() << "\n=== [After RefCountInsertionPass] ===\n"; + module.dump(); + } + + // Phase 3: Early canonicalization and CSE + { + PassManager pm(&context); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + if (failed(pm.run(module))) + return failure(); + } + + if (dumpIR) { + llvm::errs() << "\n=== [After Canonicalizer + CSE] ===\n"; + module.dump(); + } + + // Phase 4: Runtime lowering (Py dialect -> func/LLVM) + { + PassManager pm(&context); + pm.addPass(py::createRuntimeLoweringPass()); + if (failed(pm.run(module))) + return failure(); + } + + // Phase 5: Final lowering to LLVM + { + PassManager pm(&context); + pm.addPass(mlir::createConvertSCFToCFPass()); + pm.addPass(mlir::createArithToLLVMConversionPass()); + pm.addPass(mlir::createConvertControlFlowToLLVMPass()); + pm.addPass(mlir::createConvertToLLVMPass()); + pm.addPass(mlir::createReconcileUnrealizedCastsPass()); + pm.addNestedPass(mlir::createReconcileUnrealizedCastsPass()); + pm.addPass(mlir::createCanonicalizerPass()); + if (failed(pm.run(module))) + return failure(); + } + + if (dumpIR) { + llvm::errs() << "\n=== [Final LLVM IR] ===\n"; + module.dump(); + } + + return success(); } void registerRuntimeSymbols(ExecutionEngine& engine) {