Skip to content
Merged

fmt #35

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions examples/fib_prim.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from lyrt import native, from_prim
from lyrt import from_prim, native
from lyrt.prim import Int


p1 = Int[32](1)
p2 = Int[32](2)
p35 = Int[32](35)
Expand Down
62 changes: 35 additions & 27 deletions examples/tensor_prim.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from lyrt import from_prim
from lyrt.prim import Float, Vector, Matrix, Tensor
from lyrt.prim import Float, Matrix, Tensor, Vector

# ゼロ初期化

v = Vector[Float[32], 4].zeros()
m = Matrix[Float[32], 3, 4].zeros()

t = Tensor[Float[32], 2, 3, 4].zeros()

print(from_prim(v))
Expand All @@ -16,24 +15,28 @@

v2 = Vector[Float[32], 4]([0.3, -1.2, 4.5, 2.1])

m2 = Matrix[Float[32], 3, 4]([
[1.1, -0.7, 3.3, 0.0],
[2.4, 5.6, -2.2, 1.9],
[0.5, 4.8, -1.1, 3.7],
])

t2 = Tensor[Float[32], 2, 3, 4]([
m2 = Matrix[Float[32], 3, 4](
[
[0.9, -1.3, 2.2, 4.4],
[3.1, 0.8, -0.5, 6.6],
[7.7, -2.4, 1.2, 0.3],
],
[1.1, -0.7, 3.3, 0.0],
[2.4, 5.6, -2.2, 1.9],
[0.5, 4.8, -1.1, 3.7],
]
)

t2 = Tensor[Float[32], 2, 3, 4](
[
[5.5, 2.6, -3.3, 1.4],
[8.8, -0.9, 4.0, 2.2],
[6.1, 3.3, -1.7, 9.9],
],
])
[
[0.9, -1.3, 2.2, 4.4],
[3.1, 0.8, -0.5, 6.6],
[7.7, -2.4, 1.2, 0.3],
],
[
[5.5, 2.6, -3.3, 1.4],
[8.8, -0.9, 4.0, 2.2],
[6.1, 3.3, -1.7, 9.9],
],
]
)

print(from_prim(v2))
print(from_prim(m2))
Expand All @@ -43,17 +46,21 @@

v3 = Vector[Float[32], ...]([1.0, 2.0, 3.0, 4.0, 5.0])

m3 = Matrix[Float[32], ...]([
[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
])
m3 = Matrix[Float[32], ...](
[
[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
]
)

t3 = Tensor[Float[32], ...]([
t3 = Tensor[Float[32], ...](
[
[1.0, 2.0],
[3.0, 4.0],
],
])
[
[1.0, 2.0],
[3.0, 4.0],
],
]
)

print(from_prim(v3))
print(from_prim(m3))
Expand All @@ -62,4 +69,5 @@
# 0階テンソルはスカラーと同値
t4 = Tensor[Float[32], ...](3.14)
assert from_prim(t4) == Float[32](3.14)

print(from_prim(t4))
15 changes: 13 additions & 2 deletions src/lyrt/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,22 @@ __all__ = ["prim", "native", "to_prim", "from_prim"]

T = TypeVar("T")
PrimT = TypeVar("PrimT", bound=prim.Prim[prim.Int] | prim.Prim[prim.Float])
PrimFunc = TypeVar("PrimFunc", bound=Callable[..., prim.Prim[prim.Int]] | Callable[..., prim.Prim[prim.Float]])
PrimFunc = TypeVar(
"PrimFunc",
bound=Callable[..., prim.Prim[prim.Int]] | Callable[..., prim.Prim[prim.Float]],
)

def native(
*,
gc: Literal["none", "shadow-stack", "rc"] = "none",
) -> Callable[[PrimFunc], PrimFunc]: ...
def to_prim(value: object, prim_type: type[PrimT]) -> PrimT: ...
def from_prim(prim_value: prim.Prim[prim.Int] | prim.Prim[prim.Float] | prim.Vector | prim.Matrix | prim.Tensor) -> object: ...
def from_prim(
prim_value: (
prim.Prim[prim.Int]
| prim.Prim[prim.Float]
| prim.Vector
| prim.Matrix
| prim.Tensor
),
) -> object: ...
19 changes: 8 additions & 11 deletions src/lyrt/prim/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class Prim(Generic[PrimT]):
def __neg__(self: PrimT) -> PrimT: ...
def __lt__(self: PrimT, other: PrimT) -> Int[1]: ...
def __le__(self: PrimT, other: PrimT) -> Int[1]: ...
def __eq__(self: PrimT, other: PrimT) -> Int[1]: ... # type: ignore
def __ne__(self: PrimT, other: PrimT) -> Int[1]: ... # type: ignore
def __eq__(self: PrimT, other: PrimT) -> Int[1]: ... # type: ignore
def __ne__(self: PrimT, other: PrimT) -> Int[1]: ... # type: ignore
def __gt__(self: PrimT, other: PrimT) -> Int[1]: ...
def __ge__(self: PrimT, other: PrimT) -> Int[1]: ...

Expand All @@ -46,14 +46,12 @@ class Int(PrimInt[Int]):
"""Integer primitive type"""

def __class_getitem__(cls, key: int) -> Type[Int]: ...

def __init__(self, value: int) -> None: ...

class Float(PrimFloat[Float]):
"""Floating-point primitive type"""

def __class_getitem__(cls, key: int) -> Type[Float]: ...

def __init__(self, value: int | float) -> None: ...

type NumberLike = Prim[Int | Float] | int | float
Expand All @@ -62,9 +60,7 @@ Shape = TypeVarTuple("Shape")

class Vector:
def __class_getitem__(cls, key: tuple[Prim[Int | Float], int]) -> Type[Vector]: ...

def __init__(self, value: List[NumberLike]) -> None: ...

@classmethod
def zeros(cls) -> Vector: ...
@classmethod
Expand All @@ -73,10 +69,10 @@ class Vector:
def full(cls, value: NumberLike) -> Vector: ...

class Matrix:
def __class_getitem__(cls, key: tuple[Prim[Int | Float], int, int]) -> Type[Matrix]: ...

def __class_getitem__(
cls, key: tuple[Prim[Int | Float], int, int]
) -> Type[Matrix]: ...
def __init__(self, value: List[List[NumberLike]]) -> None: ...

@classmethod
def zeros(cls) -> Matrix: ...
@classmethod
Expand All @@ -85,11 +81,12 @@ class Matrix:
def full(cls, value: NumberLike) -> Matrix: ...

class Tensor:
def __class_getitem__(cls, key: tuple[Prim[Int | Float], Unpack[Shape]]) -> Type[Tensor]: ...
def __class_getitem__(
cls, key: tuple[Prim[Int | Float], Unpack[Shape]]
) -> Type[Tensor]: ...

type _NestedNumberLike = NumberLike | List[_NestedNumberLike]
def __init__(self, value: _NestedNumberLike) -> None: ...

@classmethod
def zeros(cls) -> Tensor: ...
@classmethod
Expand Down
12 changes: 9 additions & 3 deletions src/lython/visitors/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,9 @@ def get_primitive_type_from_spec(self, base_type: str, bits: int) -> ir.Type:
return ir.IntegerType.get_signless(bits, context=self.ctx)
elif base_type == "Float":
if bits not in FLOAT_VALID_BITS:
raise ValueError(f"Float bit width must be one of {sorted(FLOAT_VALID_BITS)}, got {bits}")
raise ValueError(
f"Float bit width must be one of {sorted(FLOAT_VALID_BITS)}, got {bits}"
)
if bits == 16:
return ir.F16Type.get(context=self.ctx)
elif bits == 32:
Expand Down Expand Up @@ -400,11 +402,15 @@ def annotation_to_primitive_type(

if base_type in PRIMITIVE_BASE_TYPES:
# Get the bit width from the slice
if isinstance(annotation.slice, ast.Constant) and isinstance(annotation.slice.value, int):
if isinstance(annotation.slice, ast.Constant) and isinstance(
annotation.slice.value, int
):
bits = annotation.slice.value
return self.get_primitive_type_from_spec(base_type, bits)
else:
raise ValueError(f"Primitive type {base_type} requires an integer bit width")
raise ValueError(
f"Primitive type {base_type} requires an integer bit width"
)

return None

Expand Down
4 changes: 3 additions & 1 deletion src/lython/visitors/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,9 @@ def _handle_prim_constructor(

# Get the bit width from the subscript
assert isinstance(node.func, ast.Subscript)
if not isinstance(node.func.slice, ast.Constant) or not isinstance(node.func.slice.value, int):
if not isinstance(node.func.slice, ast.Constant) or not isinstance(
node.func.slice.value, int
):
raise ValueError(f"{base_type} requires an integer bit width")
bits = node.func.slice.value

Expand Down
20 changes: 15 additions & 5 deletions src/lython/visitors/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ def hoge(n: int) -> int:
ir.StringAttr.get(arg.arg, self.ctx) for arg in node.args.args
]
arg_names_attr = (
ir.ArrayAttr.get(arg_name_attrs, context=self.ctx) # pyright: ignore[reportUnknownMemberType]
ir.ArrayAttr.get(
arg_name_attrs, context=self.ctx
) # pyright: ignore[reportUnknownMemberType]
if arg_name_attrs
else None
)
Expand Down Expand Up @@ -403,7 +405,9 @@ def _visit_method_def(
ir.StringAttr.get(arg.arg, self.ctx) for arg in node.args.args
]
arg_names_attr = (
ir.ArrayAttr.get(arg_name_attrs, context=self.ctx) # pyright: ignore[reportUnknownMemberType]
ir.ArrayAttr.get(
arg_name_attrs, context=self.ctx
) # pyright: ignore[reportUnknownMemberType]
if arg_name_attrs
else None
)
Expand Down Expand Up @@ -692,9 +696,15 @@ def visit_If(self, node: ast.If) -> None:

assert self.current_block is not None
parent_region = self.current_block.region
true_block = parent_region.blocks.append() # pyright: ignore[reportUnknownMemberType]
false_block = parent_region.blocks.append() # pyright: ignore[reportUnknownMemberType]
merge_block = parent_region.blocks.append() # pyright: ignore[reportUnknownMemberType]
true_block = (
parent_region.blocks.append()
) # pyright: ignore[reportUnknownMemberType]
false_block = (
parent_region.blocks.append()
) # pyright: ignore[reportUnknownMemberType]
merge_block = (
parent_region.blocks.append()
) # pyright: ignore[reportUnknownMemberType]
with self._loc(node), self.insertion_point():
cf_ops.CondBranchOp(cond, [], [], true_block, false_block)

Expand Down
94 changes: 76 additions & 18 deletions tools/CLI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,24 +111,82 @@ if __name__ == "__main__":
}

LogicalResult runPipeline(ModuleOp module, MLIRContext& context) {
PassManager pm(&context);
// Verify @native functions before any transformations
// This enforces the modal logic separation (Primitive World vs Object World)
pm.addPass(py::createNativeVerificationPass());
// Insert reference counting operations using Affine SSA (Linear Type) logic
pm.addPass(py::createRefCountInsertionPass());
// Early canonicalization and CSE for arith/func ops
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(py::createRuntimeLoweringPass());
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(mlir::createArithToLLVMConversionPass());
pm.addPass(mlir::createConvertControlFlowToLLVMPass());
pm.addPass(mlir::createConvertToLLVMPass());
pm.addPass(mlir::createReconcileUnrealizedCastsPass());
pm.addNestedPass<mlir::func::FuncOp>(mlir::createReconcileUnrealizedCastsPass());
pm.addPass(mlir::createCanonicalizerPass());
return pm.run(module);
bool dumpIR = static_cast<bool>(
llvm::sys::Process::GetEnv("LYTHON_DUMP_LOWERING_IR"));

if (dumpIR) {
llvm::errs() << "=== [Frontend Output (before any passes)] ===\n";
module.dump();
}

// Phase 1: Native verification
{
PassManager pm(&context);
pm.addPass(py::createNativeVerificationPass());
if (failed(pm.run(module)))
return failure();
}

if (dumpIR) {
llvm::errs() << "\n=== [After NativeVerificationPass] ===\n";
module.dump();
}

// Phase 2: Reference counting insertion
{
PassManager pm(&context);
pm.addPass(py::createRefCountInsertionPass());
if (failed(pm.run(module)))
return failure();
}

if (dumpIR) {
llvm::errs() << "\n=== [After RefCountInsertionPass] ===\n";
module.dump();
}

// Phase 3: Early canonicalization and CSE
{
PassManager pm(&context);
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
if (failed(pm.run(module)))
return failure();
}

if (dumpIR) {
llvm::errs() << "\n=== [After Canonicalizer + CSE] ===\n";
module.dump();
}

// Phase 4: Runtime lowering (Py dialect -> func/LLVM)
{
PassManager pm(&context);
pm.addPass(py::createRuntimeLoweringPass());
if (failed(pm.run(module)))
return failure();
}

// Phase 5: Final lowering to LLVM
{
PassManager pm(&context);
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(mlir::createArithToLLVMConversionPass());
pm.addPass(mlir::createConvertControlFlowToLLVMPass());
pm.addPass(mlir::createConvertToLLVMPass());
pm.addPass(mlir::createReconcileUnrealizedCastsPass());
pm.addNestedPass<mlir::func::FuncOp>(mlir::createReconcileUnrealizedCastsPass());
pm.addPass(mlir::createCanonicalizerPass());
if (failed(pm.run(module)))
return failure();
}

if (dumpIR) {
llvm::errs() << "\n=== [Final LLVM IR] ===\n";
module.dump();
}

return success();
}

void registerRuntimeSymbols(ExecutionEngine& engine) {
Expand Down