diff --git a/Makefile b/Makefile index 72a11ef..6336cc5 100644 --- a/Makefile +++ b/Makefile @@ -3,16 +3,16 @@ check: flake8 . test: - pytest -v --cov - + pytest -v --cov --doctest-modules cov: test codecov +format: + yapf --recursive -i py_hcl tests examples clean: rm -rf .eggs dist py_hcl.egg-info .coverage .pytest_cache */__pycache__ - upload: test python setup.py sdist && twine upload dist/* \ No newline at end of file diff --git a/README.md b/README.md index a2f23c4..de3efaa 100644 --- a/README.md +++ b/README.md @@ -3,25 +3,25 @@ [![codecov](https://codecov.io/gh/scutdig/py-hcl/branch/master/graph/badge.svg)](https://codecov.io/gh/scutdig/py-hcl) [![PyPI](https://img.shields.io/pypi/v/py-hcl.svg)](https://pypi.python.org/pypi) -PyHCL is a hardware construct language like [Chisel](https://github.com/freechipsproject/chisel3) but more lightweight and more relaxed to use. -As a novel hardware construction framework embedded in Python, PyHCL supports several useful features include object-oriented, functional programming, +PyHCL is a hardware construct language similar to [Chisel](https://github.com/freechipsproject/chisel3) but more lightweight and more relaxed to use. +As a novel hardware construction framework embedded in Python, PyHCL supports several useful features including object-oriented, functional programming, and dynamically typed objects. -The goal of PyHCL is providing a complete design and verification tool flow for heterogeneous computing systems flexibly using the same design methodology. +The goal of PyHCL is providing a complete design and verification toolchain for heterogeneous computing systems. -PyHCL is powered by [FIRRTL](https://github.com/freechipsproject/firrtl), an intermediate representation for digital circuit design. With the FIRRTL +PyHCL is powered by [FIRRTL](https://github.com/freechipsproject/firrtl), an intermediate representation for digital circuit design. Leveraging the FIRRTL compiler framework, PyHCL-generated circuits can be compiled to the widely-used HDL Verilog. ## Getting Started -#### Installing PyHCL +#### Install PyHCL ```shell script $ pip install py-hcl ``` -#### Writing A Full Adder +#### Write a Full Adder PyHCL defines modules using only simple Python syntax that looks like this: ```python from py_hcl import * @@ -42,14 +42,16 @@ class FullAdder(Module): io.cout <<= io.a & io.b | io.b & io.cin | io.a & io.cin ``` -#### Compiling To FIRRTL +#### Compile To FIRRTL -Compiling module by calling `compile_to_firrtl`: +Compile module via `compile_to_firrtl`: ```python +from py_hcl import compile_to_firrtl + compile_to_firrtl(FullAdder, 'full_adder.fir') ``` -Will generate the following FIRRTL codes: +Will generate the following FIRRTL code: ``` circuit FullAdder : module FullAdder : @@ -72,7 +74,7 @@ circuit FullAdder : FullAdder_io_cout <= _T_6 ``` -#### Compiling To Verilog +#### Compile To Verilog While FIRRTL is generated, PyHCL's job is complete. To further compile to Verilog, the [FIRRTL compiler framework]( https://github.com/freechipsproject/firrtl) is required: diff --git a/py_hcl/core/module_factory/inherit_list/__init__.py b/examples/__init__.py similarity index 100% rename from py_hcl/core/module_factory/inherit_list/__init__.py rename to examples/__init__.py diff --git a/examples/full_adder.py b/examples/full_adder.py new file mode 100644 index 0000000..119394c --- /dev/null +++ b/examples/full_adder.py @@ -0,0 +1,37 @@ +import tempfile +import os + +from py_hcl import Module, IO, Input, Bool, Output + + +class FullAdder(Module): + io = IO( + a=Input(Bool), + b=Input(Bool), + cin=Input(Bool), + sum=Output(Bool), + cout=Output(Bool), + ) + + # Generate the sum + io.sum <<= io.a ^ io.b ^ io.cin + + # Generate the carry + io.cout <<= io.a & io.b | io.b & io.cin | io.a & io.cin + + +def main(): + tmp_dir = tempfile.mkdtemp() + path = os.path.join(tmp_dir, "full_adder.fir") + + FullAdder.compile_to_firrtl(path) + + with open(path) as f: + print(f.read()) + + os.remove(path) + os.removedirs(tmp_dir) + + +if __name__ == '__main__': + main() diff --git a/py_hcl/compile/__init__.py b/py_hcl/compile/__init__.py index d087766..f182bad 100644 --- a/py_hcl/compile/__init__.py +++ b/py_hcl/compile/__init__.py @@ -2,6 +2,47 @@ def compile_to_firrtl(module_class, path=None): + """ + Compiles PyHCL Module `module_class` to FIRRTL source code file. + + Examples + -------- + + Define a PyHCL module: + + >>> from py_hcl import * + >>> class N(Module): + ... io = IO( + ... i=Input(U.w(8)), + ... o=Output(U.w(8)), + ... ) + ... io.o <<= io.i + + Compile to FIRRTL: + + >>> from tempfile import mktemp + >>> tmp_file = mktemp() + >>> compile_to_firrtl(N, tmp_file) + + Read the content: + + >>> with open(tmp_file) as f: + ... print(f.read()) + circuit N : + module N : + input clock : Clock + input reset : UInt<1> + input N_io_i : UInt<8> + output N_io_o : UInt<8> + + N_io_o <= N_io_i + + + + >>> from os import remove + >>> remove(tmp_file) + """ + m = convert(module_class.packed_module) if path is None: @@ -9,3 +50,4 @@ def compile_to_firrtl(module_class, path=None): with open(path, 'wb') as f: m.serialize_stmt(f) + f.flush() diff --git a/py_hcl/core/__init__.py b/py_hcl/core/__init__.py index afcbf77..774e05f 100644 --- a/py_hcl/core/__init__.py +++ b/py_hcl/core/__init__.py @@ -1,8 +1,8 @@ def install_ops(): import py_hcl.core.expr.add # noqa: F401 - import py_hcl.core.expr.and_ # noqa: F401 + import py_hcl.core.expr.ands # noqa: F401 import py_hcl.core.expr.xor # noqa: F401 - import py_hcl.core.expr.or_ # noqa: F401 + import py_hcl.core.expr.ors # noqa: F401 import py_hcl.core.stmt.connect # noqa: F401 import py_hcl.core.expr.field # noqa: F401 import py_hcl.core.expr.slice # noqa: F401 diff --git a/py_hcl/core/error/__init__.py b/py_hcl/core/error/__init__.py index a9e288f..5e47efc 100644 --- a/py_hcl/core/error/__init__.py +++ b/py_hcl/core/error/__init__.py @@ -1,4 +1,4 @@ -from py_hcl.error import PyHclError +from py_hcl.utils.error import PyHclError class CoreError(PyHclError): diff --git a/py_hcl/core/expr/__init__.py b/py_hcl/core/expr/__init__.py index 813523c..ac2baf6 100644 --- a/py_hcl/core/expr/__init__.py +++ b/py_hcl/core/expr/__init__.py @@ -1,9 +1,9 @@ from multipledispatch.dispatcher import MethodDispatcher from py_hcl.core.hcl_ops import op_apply -from py_hcl.core.stmt.connect import ConnSide +from py_hcl.core.stmt.connect import VariableType from py_hcl.core.type import UnknownType, HclType -from py_hcl.utils import json_serialize +from py_hcl.utils.serialization import json_serialize class ExprIdGen: @@ -30,7 +30,7 @@ def get(cls, i): @json_serialize class HclExpr(object): hcl_type = UnknownType() - conn_side = ConnSide.UNKNOWN + variable_type = VariableType.UNKNOWN def __new__(cls, *args): obj = super().__new__(cls) @@ -94,10 +94,12 @@ def to_bool(self): return op_apply('to_bool')(self) -@json_serialize(json_fields=['id', 'type', 'hcl_type', 'conn_side', 'op_node']) +@json_serialize( + json_fields=['id', 'type', 'hcl_type', 'variable_type', 'op_node']) class ExprHolder(HclExpr): - def __init__(self, hcl_type: HclType, conn_side: ConnSide, op_node): + def __init__(self, hcl_type: HclType, variable_type: VariableType, + op_node): self.type = 'expr_holder' self.hcl_type = hcl_type - self.conn_side = conn_side + self.variable_type = variable_type self.op_node = op_node diff --git a/py_hcl/core/expr/add.py b/py_hcl/core/expr/add.py index 7f5974a..86262c5 100644 --- a/py_hcl/core/expr/add.py +++ b/py_hcl/core/expr/add.py @@ -1,16 +1,54 @@ +""" +Implement addition operation for pyhcl values. + +Examples +-------- + +>>> from py_hcl import U, S, Wire, Bundle + + +Add two literals of Uint type: + +>>> res = U(1) + U(2) + + +Add two literals of Sint type: + +>>> res = S(1) + S(2) + + +Add two wires of Uint type: + +>>> w1 = Wire(U.w(8)); w2 = Wire(U.w(9)) +>>> res = w1 + w2 + + +Add two wires of Vector type: + +>>> w1 = Wire(U.w(8)[8]); w2 = Wire(U.w(9)[8]) +>>> res = w1 + w2 + + +Add two wires of Bundle type: + +>>> w1 = Wire(Bundle(a=U.w(2), b=~S.w(3))) +>>> w2 = Wire(Bundle(a=U.w(3), b=~S.w(4))) +>>> res = w1 + w2 +""" + from py_hcl.core.expr import ExprHolder from py_hcl.core.expr.bundle_holder import BundleHolder from py_hcl.core.expr.error import ExprError -from py_hcl.core.expr.utils import assert_right_side +from py_hcl.core.expr.utils import ensure_all_args_are_readable from py_hcl.core.expr.vec_holder import VecHolder from py_hcl.core.hcl_ops import op_register -from py_hcl.core.stmt.connect import ConnSide +from py_hcl.core.stmt.connect import VariableType from py_hcl.core.type import HclType -from py_hcl.core.type.bundle import BundleT, Dir +from py_hcl.core.type.bundle import BundleT, BundleDirection from py_hcl.core.type.sint import SIntT from py_hcl.core.type.uint import UIntT from py_hcl.core.type.vector import VectorT -from py_hcl.utils import json_serialize +from py_hcl.utils.serialization import json_serialize @json_serialize @@ -25,34 +63,34 @@ def __init__(self, left, right): @adder(UIntT, UIntT) -@assert_right_side +@ensure_all_args_are_readable def _(lf, rt): w = max(lf.hcl_type.width, rt.hcl_type.width) + 1 t = UIntT(w) - return ExprHolder(t, ConnSide.RT, Add(lf, rt)) + return ExprHolder(t, VariableType.ReadOnly, Add(lf, rt)) @adder(SIntT, SIntT) -@assert_right_side +@ensure_all_args_are_readable def _(lf, rt): w = max(lf.hcl_type.width, rt.hcl_type.width) + 1 t = SIntT(w) - return ExprHolder(t, ConnSide.RT, Add(lf, rt)) + return ExprHolder(t, VariableType.ReadOnly, Add(lf, rt)) @adder(VectorT, VectorT) -@assert_right_side +@ensure_all_args_are_readable def _(lf, rt): - # TODO: Accurate Error Message - assert lf.hcl_type.size == rt.hcl_type.size + if lf.hcl_type.size != rt.hcl_type.size: + raise ExprError.unmatched_vec_size(lf.hcl_type.size, rt.hcl_type.size) values = [lf[i] + rt[i] for i in range(lf.hcl_type.size)] v_type = VectorT(values[0].hcl_type, len(values)) - return VecHolder(v_type, ConnSide.RT, values) + return VecHolder(v_type, VariableType.ReadOnly, values) @adder(BundleT, BundleT) -@assert_right_side +@ensure_all_args_are_readable def _(lf, rt): # TODO: Accurate Error Message assert set(lf.hcl_type.fields.keys()) == set(rt.hcl_type.fields.keys()) @@ -61,10 +99,14 @@ def _(lf, rt): bd_values = {} for k in lf.hcl_type.fields.keys(): res = getattr(lf, k) + getattr(rt, k) - bd_type_fields[k] = {"dir": Dir.SRC, "hcl_type": res.hcl_type} + bd_type_fields[k] = { + "dir": BundleDirection.SOURCE, + "hcl_type": res.hcl_type + } bd_values[k] = res - return BundleHolder(BundleT(bd_type_fields), ConnSide.RT, bd_values) + return BundleHolder(BundleT(bd_type_fields), VariableType.ReadOnly, + bd_values) @adder(HclType, HclType) diff --git a/py_hcl/core/expr/and_.py b/py_hcl/core/expr/and_.py deleted file mode 100644 index a52d20a..0000000 --- a/py_hcl/core/expr/and_.py +++ /dev/null @@ -1,72 +0,0 @@ -from py_hcl.core.expr import ExprHolder -from py_hcl.core.expr.bundle_holder import BundleHolder -from py_hcl.core.expr.error import ExprError -from py_hcl.core.expr.utils import assert_right_side -from py_hcl.core.expr.vec_holder import VecHolder -from py_hcl.core.hcl_ops import op_register -from py_hcl.core.stmt.connect import ConnSide -from py_hcl.core.type import HclType -from py_hcl.core.type.bundle import BundleT, Dir -from py_hcl.core.type.sint import SIntT -from py_hcl.core.type.uint import UIntT -from py_hcl.core.type.vector import VectorT -from py_hcl.utils import json_serialize - - -@json_serialize -class And(object): - def __init__(self, left, right): - self.operation = 'and' - self.left_expr_id = left.id - self.right_expr_id = right.id - - -ander = op_register('&') - - -@ander(UIntT, UIntT) -@assert_right_side -def _(lf, rt): - w = max(lf.hcl_type.width, rt.hcl_type.width) - t = UIntT(w) - return ExprHolder(t, ConnSide.RT, And(lf, rt)) - - -@ander(SIntT, SIntT) -@assert_right_side -def _(lf, rt): - w = max(lf.hcl_type.width, rt.hcl_type.width) - t = UIntT(w) - return ExprHolder(t, ConnSide.RT, And(lf, rt)) - - -@ander(VectorT, VectorT) -@assert_right_side -def _(lf, rt): - # TODO: Accurate Error Message - assert lf.hcl_type.size == rt.hcl_type.size - - values = [lf[i] & rt[i] for i in range(lf.hcl_type.size)] - v_type = VectorT(values[0].hcl_type, len(values)) - return VecHolder(v_type, ConnSide.RT, values) - - -@ander(BundleT, BundleT) -@assert_right_side -def _(lf, rt): - # TODO: Accurate Error Message - assert set(lf.hcl_type.fields.keys()) == set(rt.hcl_type.fields.keys()) - - bd_type_fields = {} - bd_values = {} - for k in lf.hcl_type.fields.keys(): - res = getattr(lf, k) & getattr(rt, k) - bd_type_fields[k] = {"dir": Dir.SRC, "hcl_type": res.hcl_type} - bd_values[k] = res - - return BundleHolder(BundleT(bd_type_fields), ConnSide.RT, bd_values) - - -@ander(HclType, HclType) -def _(_0, _1): - raise ExprError.op_type_err('and', _0, _1) diff --git a/py_hcl/core/expr/ands.py b/py_hcl/core/expr/ands.py new file mode 100644 index 0000000..35aca9b --- /dev/null +++ b/py_hcl/core/expr/ands.py @@ -0,0 +1,117 @@ +""" +Implement and operation for pyhcl values. + +*Note: `and` is a reserved keyword in Python. So change the mod name to `ands` +to get around parse error*. + +Examples +-------- + +>>> from py_hcl import U, S, Wire, Bundle + + +And two literals of Uint type: + +>>> res = U(1) & U(2) + + +And two literals of Sint type: + +>>> res = S(1) & S(2) + + +And two wires of Uint type: + +>>> w1 = Wire(U.w(8)); w2 = Wire(U.w(9)) +>>> res = w1 & w2 + + +And two wires of Vector type: + +>>> w1 = Wire(U.w(8)[8]); w2 = Wire(U.w(9)[8]) +>>> res = w1 & w2 + + +And two wires of Bundle type: + +>>> w1 = Wire(Bundle(a=U.w(2), b=~S.w(3))) +>>> w2 = Wire(Bundle(a=U.w(3), b=~S.w(4))) +>>> res = w1 & w2 +""" + +from py_hcl.core.expr import ExprHolder +from py_hcl.core.expr.bundle_holder import BundleHolder +from py_hcl.core.expr.error import ExprError +from py_hcl.core.expr.utils import ensure_all_args_are_readable +from py_hcl.core.expr.vec_holder import VecHolder +from py_hcl.core.hcl_ops import op_register +from py_hcl.core.stmt.connect import VariableType +from py_hcl.core.type import HclType +from py_hcl.core.type.bundle import BundleT, BundleDirection +from py_hcl.core.type.sint import SIntT +from py_hcl.core.type.uint import UIntT +from py_hcl.core.type.vector import VectorT +from py_hcl.utils.serialization import json_serialize + + +@json_serialize +class And(object): + def __init__(self, left, right): + self.operation = 'and' + self.left_expr_id = left.id + self.right_expr_id = right.id + + +ander = op_register('&') + + +@ander(UIntT, UIntT) +@ensure_all_args_are_readable +def _(lf, rt): + w = max(lf.hcl_type.width, rt.hcl_type.width) + t = UIntT(w) + return ExprHolder(t, VariableType.ReadOnly, And(lf, rt)) + + +@ander(SIntT, SIntT) +@ensure_all_args_are_readable +def _(lf, rt): + w = max(lf.hcl_type.width, rt.hcl_type.width) + t = UIntT(w) + return ExprHolder(t, VariableType.ReadOnly, And(lf, rt)) + + +@ander(VectorT, VectorT) +@ensure_all_args_are_readable +def _(lf, rt): + if lf.hcl_type.size != rt.hcl_type.size: + raise ExprError.unmatched_vec_size(lf.hcl_type.size, rt.hcl_type.size) + + values = [lf[i] & rt[i] for i in range(lf.hcl_type.size)] + v_type = VectorT(values[0].hcl_type, len(values)) + return VecHolder(v_type, VariableType.ReadOnly, values) + + +@ander(BundleT, BundleT) +@ensure_all_args_are_readable +def _(lf, rt): + # TODO: Accurate Error Message + assert set(lf.hcl_type.fields.keys()) == set(rt.hcl_type.fields.keys()) + + bd_type_fields = {} + bd_values = {} + for k in lf.hcl_type.fields.keys(): + res = getattr(lf, k) & getattr(rt, k) + bd_type_fields[k] = { + "dir": BundleDirection.SOURCE, + "hcl_type": res.hcl_type + } + bd_values[k] = res + + return BundleHolder(BundleT(bd_type_fields), VariableType.ReadOnly, + bd_values) + + +@ander(HclType, HclType) +def _(_0, _1): + raise ExprError.op_type_err('and', _0, _1) diff --git a/py_hcl/core/expr/bundle_holder.py b/py_hcl/core/expr/bundle_holder.py index e94043d..93a81fc 100644 --- a/py_hcl/core/expr/bundle_holder.py +++ b/py_hcl/core/expr/bundle_holder.py @@ -3,9 +3,9 @@ class BundleHolder(HclExpr): - def __init__(self, hcl_type, conn_side, assoc_value): + def __init__(self, hcl_type, variable_type, assoc_value): self.hcl_type = hcl_type - self.conn_side = conn_side + self.variable_type = variable_type assert isinstance(hcl_type, BundleT) assert set(hcl_type.fields.keys()) == set(assoc_value.keys()) diff --git a/py_hcl/core/expr/convert.py b/py_hcl/core/expr/convert.py index 7015df0..ba3ffba 100644 --- a/py_hcl/core/expr/convert.py +++ b/py_hcl/core/expr/convert.py @@ -1,12 +1,12 @@ from py_hcl.core.expr import ExprHolder from py_hcl.core.expr.error import ExprError -from py_hcl.core.expr.utils import assert_right_side +from py_hcl.core.expr.utils import ensure_all_args_are_readable from py_hcl.core.hcl_ops import op_register -from py_hcl.core.stmt.connect import ConnSide +from py_hcl.core.stmt.connect import VariableType from py_hcl.core.type import HclType from py_hcl.core.type.sint import SIntT from py_hcl.core.type.uint import UIntT -from py_hcl.utils import json_serialize +from py_hcl.utils.serialization import json_serialize to_bool = op_register('to_bool') to_uint = op_register('to_uint') @@ -28,13 +28,13 @@ def __init__(self, expr): @to_bool(UIntT) -@assert_right_side +@ensure_all_args_are_readable def _(uint): return uint[0] @to_bool(SIntT) -@assert_right_side +@ensure_all_args_are_readable def _(sint): return sint[0] @@ -45,16 +45,16 @@ def _(_0, *_): @to_uint(UIntT) -@assert_right_side +@ensure_all_args_are_readable def _(uint): return uint @to_uint(SIntT) -@assert_right_side +@ensure_all_args_are_readable def _(sint): t = UIntT(sint.hcl_type.width) - return ExprHolder(t, ConnSide.RT, ToUInt(sint)) + return ExprHolder(t, VariableType.ReadOnly, ToUInt(sint)) @to_uint(HclType) @@ -63,14 +63,14 @@ def _(_0, *_): @to_sint(UIntT) -@assert_right_side +@ensure_all_args_are_readable def _(uint): t = SIntT(uint.hcl_type.width) - return ExprHolder(t, ConnSide.RT, ToSInt(uint)) + return ExprHolder(t, VariableType.ReadOnly, ToSInt(uint)) @to_sint(SIntT) -@assert_right_side +@ensure_all_args_are_readable def _(sint): return sint diff --git a/py_hcl/core/expr/error/__init__.py b/py_hcl/core/expr/error/__init__.py index 4f59de8..6ea304a 100644 --- a/py_hcl/core/expr/error/__init__.py +++ b/py_hcl/core/expr/error/__init__.py @@ -5,24 +5,52 @@ def set_up(): ExprError.append({ 'IOValueError': { 'code': 200, - 'value': ExprError('io items should wrap with Input or Output')}, + 'value': + ExprError('IO items should be wrapped with Input or Output.') + }, 'OpTypeError': { 'code': 201, - 'value': ExprError('specified arguments contain unexpected types') + 'value': ExprError('Specified arguments contain unexpected type.') + }, + 'OutOfRangeError': { + 'code': 202, + 'value': + ExprError('Specified value out of range for the given type') + }, + 'VarTypeError': { + 'code': 203, + 'value': + ExprError('Specified expresion has an invalid variable type') + }, + 'UnmatchedVecSizeError': { + 'code': 204, + 'value': ExprError('Sizes of vectors are unmatched') } }) class ExprError(CoreError): @staticmethod - def io_value(msg): + def io_value_err(msg: str): return ExprError.err('IOValueError', msg) @staticmethod def op_type_err(op, *args): ts = ', '.join([type(a.hcl_type).__name__ for a in args]) - msg = '{}(): unsupported operand types: {}'.format(op, ts) + msg = '{}(): unsupported operand type: {}'.format(op, ts) return ExprError.err('OpTypeError', msg) + @staticmethod + def out_of_range_err(msg: str): + return ExprError.err('OutOfRangeError', msg) + + @staticmethod + def var_type_err(msg: str): + return ExprError.err('VarTypeError', msg) + + @staticmethod + def unmatched_vec_size(size0: int, size1: int): + return ExprError.err('UnmatchedVecSizeError', f'{size0} != {size1}') + set_up() diff --git a/py_hcl/core/expr/extend.py b/py_hcl/core/expr/extend.py index dee8eac..b7c33ca 100644 --- a/py_hcl/core/expr/extend.py +++ b/py_hcl/core/expr/extend.py @@ -1,10 +1,10 @@ from py_hcl.core.expr import ExprHolder -from py_hcl.core.expr.utils import assert_right_side +from py_hcl.core.expr.utils import ensure_all_args_are_readable from py_hcl.core.hcl_ops import op_register -from py_hcl.core.stmt.connect import ConnSide +from py_hcl.core.stmt.connect import VariableType from py_hcl.core.type.sint import SIntT from py_hcl.core.type.uint import UIntT -from py_hcl.utils import json_serialize +from py_hcl.utils.serialization import json_serialize extend = op_register('extend') @@ -17,12 +17,12 @@ def __init__(self, expr): @extend(UIntT) -@assert_right_side +@ensure_all_args_are_readable def _(uint, size): - return ExprHolder(UIntT(size), ConnSide.RT, Extend(uint)) + return ExprHolder(UIntT(size), VariableType.ReadOnly, Extend(uint)) @extend(SIntT) -@assert_right_side +@ensure_all_args_are_readable def _(sint, size): - return ExprHolder(SIntT(size), ConnSide.RT, Extend(sint)) + return ExprHolder(SIntT(size), VariableType.ReadOnly, Extend(sint)) diff --git a/py_hcl/core/expr/field.py b/py_hcl/core/expr/field.py index 530a467..4737b51 100644 --- a/py_hcl/core/expr/field.py +++ b/py_hcl/core/expr/field.py @@ -2,10 +2,10 @@ from py_hcl.core.expr.bundle_holder import BundleHolder from py_hcl.core.expr.error import ExprError from py_hcl.core.hcl_ops import op_register -from py_hcl.core.stmt.connect import ConnSide +from py_hcl.core.stmt.connect import VariableType from py_hcl.core.type import HclType -from py_hcl.core.type.bundle import BundleT, Dir -from py_hcl.utils import json_serialize +from py_hcl.core.type.bundle import BundleT, BundleDirection +from py_hcl.utils.serialization import json_serialize field_accessor = op_register('.') @@ -26,12 +26,12 @@ def _(bd, item): return bd.assoc_value[item] # build connect side - sd = bd.conn_side + var_type = bd.variable_type f = bd.hcl_type.fields[item] dr, tpe = f["dir"], f["hcl_type"] - new_sd = build_new_sd(sd, dr) + new_var_type = build_new_var_type(var_type, dr) - return ExprHolder(tpe, new_sd, FieldAccess(bd, item)) + return ExprHolder(tpe, new_var_type, FieldAccess(bd, item)) @field_accessor(HclType) @@ -39,11 +39,12 @@ def _(o, *_): raise ExprError.op_type_err('field_accessor', o) -def build_new_sd(sd: ConnSide, dr: Dir) -> ConnSide: - if sd == ConnSide.BOTH: - return ConnSide.BOTH - if sd == ConnSide.RT and dr == dr.SINK: - return ConnSide.LF - if sd == ConnSide.LF and dr == dr.SRC: - return ConnSide.LF - return ConnSide.RT +def build_new_var_type(var_type: VariableType, + dr: BundleDirection) -> VariableType: + if var_type == VariableType.ReadWrite: + return VariableType.ReadWrite + if var_type == VariableType.ReadOnly and dr == dr.SINK: + return VariableType.WriteOnly + if var_type == VariableType.WriteOnly and dr == dr.SOURCE: + return VariableType.WriteOnly + return VariableType.ReadOnly diff --git a/py_hcl/core/expr/index.py b/py_hcl/core/expr/index.py index a467f56..f5feadf 100644 --- a/py_hcl/core/expr/index.py +++ b/py_hcl/core/expr/index.py @@ -1,13 +1,13 @@ from py_hcl.core.expr import ExprHolder from py_hcl.core.expr.error import ExprError -from py_hcl.core.expr.utils import assert_right_side +from py_hcl.core.expr.utils import ensure_all_args_are_readable from py_hcl.core.expr.vec_holder import VecHolder from py_hcl.core.hcl_ops import op_register from py_hcl.core.type import HclType from py_hcl.core.type.sint import SIntT from py_hcl.core.type.uint import UIntT from py_hcl.core.type.vector import VectorT -from py_hcl.utils import json_serialize +from py_hcl.utils.serialization import json_serialize index = op_register('[i]') @@ -21,13 +21,13 @@ def __init__(self, expr, idx: int): @index(UIntT) -@assert_right_side +@ensure_all_args_are_readable def _(uint, i: int): return uint[i:i] @index(SIntT) -@assert_right_side +@ensure_all_args_are_readable def _(sint, i: int): return sint[i:i] @@ -38,7 +38,8 @@ def _(vec, i: int): assert i < vec.hcl_type.size if isinstance(vec, VecHolder): return vec.assoc_value[i] - return ExprHolder(vec.hcl_type.inner_type, vec.conn_side, VecIndex(vec, i)) + return ExprHolder(vec.hcl_type.inner_type, vec.variable_type, + VecIndex(vec, i)) @index(HclType) diff --git a/py_hcl/core/expr/io.py b/py_hcl/core/expr/io.py index 667538d..9c84eca 100644 --- a/py_hcl/core/expr/io.py +++ b/py_hcl/core/expr/io.py @@ -1,12 +1,9 @@ -from typing import Dict, Union, Optional, Tuple +from typing import List from py_hcl.core.expr import HclExpr -from py_hcl.core.expr.error import ExprError -from py_hcl.core.stmt.connect import ConnSide +from py_hcl.core.stmt.connect import VariableType from py_hcl.core.type import HclType -from py_hcl.core.type.bundle import Dir, BundleT -from py_hcl.core.utils import module_inherit_mro -from py_hcl.utils import json_serialize +from py_hcl.utils.serialization import json_serialize @json_serialize @@ -23,63 +20,11 @@ def __init__(self, hcl_type: HclType): self.hcl_type = hcl_type -@json_serialize -class IOHolder(object): - def __init__(self, named_ports: Dict[str, Union[Input, Output]], - module_name: Optional[str] = None): - self.named_ports = named_ports - self.module_name = module_name - - -@json_serialize -class IONode(object): - def __init__(self, io_holder: IOHolder, - next_node: Optional["IOHolder"]): - self.io_holder = io_holder - if next_node is not None: - self.next_node = next_node - - -@json_serialize(json_fields=["id", "type", "hcl_type", - "conn_side", "io_chain_head"]) +@json_serialize( + json_fields=["id", "type", "hcl_type", "variable_type", "io_chain"]) class IO(HclExpr): - def __init__(self, hcl_type: HclType, io_chain_head: IONode): + def __init__(self, hcl_type: HclType, io_chain: List["IOHolder"]): self.type = 'io' self.hcl_type = hcl_type - self.conn_side = ConnSide.RT - self.io_chain_head = io_chain_head - - -def io_extend(modules: Tuple[type]): - modules = module_inherit_mro(modules) - - current_ports = {} - io_chain = None - for m in modules[::-1]: - h = m.io.io_chain_head.io_holder - current_ports.update(h.named_ports) - io_chain = IONode(h, io_chain) - - def _(named_ports: Dict[str, Union[Input, Output]]): - current_ports.update(named_ports) - io_chain_head = IONode(IOHolder(named_ports), io_chain) - return IO(calc_type_from_ports(current_ports), io_chain_head) - - return _ - - -def calc_type_from_ports(named_ports: Dict[str, Union[Input, Output]]): - types = {} - for k, v in named_ports.items(): - if isinstance(v, Input): - types[k] = {"dir": Dir.SRC, "hcl_type": v.hcl_type} - continue - - if isinstance(v, Output): - types[k] = {"dir": Dir.SINK, "hcl_type": v.hcl_type} - continue - - raise ExprError.io_value( - "type of '{}' is {}, not Input or Output".format(k, type(v))) - - return BundleT(types) + self.variable_type = VariableType.ReadOnly + self.io_chain = io_chain diff --git a/py_hcl/core/expr/lit_sint.py b/py_hcl/core/expr/lit_sint.py index 8cc3ed4..2488ad4 100644 --- a/py_hcl/core/expr/lit_sint.py +++ b/py_hcl/core/expr/lit_sint.py @@ -1,13 +1,13 @@ from py_hcl.core.expr import HclExpr -from py_hcl.core.stmt.connect import ConnSide +from py_hcl.core.stmt.connect import VariableType from py_hcl.core.type.sint import SIntT -from py_hcl.utils import signed_num_bin_len +from py_hcl.utils import signed_num_bin_width class SLiteral(HclExpr): def __init__(self, value: int): self.value = value - w = signed_num_bin_len(value) + w = signed_num_bin_width(value) self.hcl_type = SIntT(w) - self.conn_side = ConnSide.RT + self.variable_type = VariableType.ReadOnly diff --git a/py_hcl/core/expr/lit_uint.py b/py_hcl/core/expr/lit_uint.py index bd0d010..6023f44 100644 --- a/py_hcl/core/expr/lit_uint.py +++ b/py_hcl/core/expr/lit_uint.py @@ -1,13 +1,13 @@ from py_hcl.core.expr import HclExpr -from py_hcl.core.stmt.connect import ConnSide +from py_hcl.core.stmt.connect import VariableType from py_hcl.core.type.uint import UIntT -from py_hcl.utils import unsigned_num_bin_len +from py_hcl.utils import unsigned_num_bin_width class ULiteral(HclExpr): def __init__(self, value: int): self.value = value - w = unsigned_num_bin_len(value) + w = unsigned_num_bin_width(value) self.hcl_type = UIntT(w) - self.conn_side = ConnSide.RT + self.variable_type = VariableType.ReadOnly diff --git a/py_hcl/core/expr/mod_inst.py b/py_hcl/core/expr/mod_inst.py index 395a1b2..d0b01f0 100644 --- a/py_hcl/core/expr/mod_inst.py +++ b/py_hcl/core/expr/mod_inst.py @@ -1,15 +1,15 @@ from py_hcl.core.expr import HclExpr -from py_hcl.core.stmt.connect import ConnSide -from py_hcl.utils import json_serialize +from py_hcl.core.stmt.connect import VariableType +from py_hcl.utils.serialization import json_serialize @json_serialize( - json_fields=['id', 'type', 'hcl_type', "conn_side", 'module_name']) + json_fields=['id', 'type', 'hcl_type', "variable_type", 'module_name']) class ModuleInst(HclExpr): def __init__(self, module_cls): self.type = 'module_inst' self.hcl_type = module_cls.io.hcl_type - self.conn_side = ConnSide.LF + self.variable_type = VariableType.WriteOnly self.packed_module = module_cls.packed_module self.module_name = module_cls.packed_module.name diff --git a/py_hcl/core/expr/or_.py b/py_hcl/core/expr/or_.py deleted file mode 100644 index 2f57d00..0000000 --- a/py_hcl/core/expr/or_.py +++ /dev/null @@ -1,72 +0,0 @@ -from py_hcl.core.expr import ExprHolder -from py_hcl.core.expr.bundle_holder import BundleHolder -from py_hcl.core.expr.error import ExprError -from py_hcl.core.expr.utils import assert_right_side -from py_hcl.core.expr.vec_holder import VecHolder -from py_hcl.core.hcl_ops import op_register -from py_hcl.core.stmt.connect import ConnSide -from py_hcl.core.type import HclType -from py_hcl.core.type.bundle import BundleT, Dir -from py_hcl.core.type.sint import SIntT -from py_hcl.core.type.uint import UIntT -from py_hcl.core.type.vector import VectorT -from py_hcl.utils import json_serialize - - -@json_serialize -class Or(object): - def __init__(self, left, right): - self.operation = 'or' - self.left_expr_id = left.id - self.right_expr_id = right.id - - -orer = op_register('|') - - -@orer(UIntT, UIntT) -@assert_right_side -def _(lf, rt): - w = max(lf.hcl_type.width, rt.hcl_type.width) - t = UIntT(w) - return ExprHolder(t, ConnSide.RT, Or(lf, rt)) - - -@orer(SIntT, SIntT) -@assert_right_side -def _(lf, rt): - w = max(lf.hcl_type.width, rt.hcl_type.width) - t = UIntT(w) - return ExprHolder(t, ConnSide.RT, Or(lf, rt)) - - -@orer(VectorT, VectorT) -@assert_right_side -def _(lf, rt): - # TODO: Accurate Error Message - assert lf.hcl_type.size == rt.hcl_type.size - - values = [lf[i] | rt[i] for i in range(lf.hcl_type.size)] - v_type = VectorT(values[0].hcl_type, len(values)) - return VecHolder(v_type, ConnSide.RT, values) - - -@orer(BundleT, BundleT) -@assert_right_side -def _(lf, rt): - # TODO: Accurate Error Message - assert set(lf.hcl_type.fields.keys()) == set(rt.hcl_type.fields.keys()) - - bd_type_fields = {} - bd_values = {} - for k in lf.hcl_type.fields.keys(): - res = getattr(lf, k) | getattr(rt, k) - bd_type_fields[k] = {"dir": Dir.SRC, "hcl_type": res.hcl_type} - bd_values[k] = res - - return BundleHolder(BundleT(bd_type_fields), ConnSide.RT, bd_values) - - -@orer(HclType, HclType) -def _(_0, _1): - raise ExprError.op_type_err('or', _0, _1) diff --git a/py_hcl/core/expr/ors.py b/py_hcl/core/expr/ors.py new file mode 100644 index 0000000..2781fc5 --- /dev/null +++ b/py_hcl/core/expr/ors.py @@ -0,0 +1,117 @@ +""" +Implement or operation for pyhcl values. + +*Note: `or` is a reserved keyword in Python. So change the mod name to `ors` +to get around parse error*. + +Examples +-------- + +>>> from py_hcl import U, S, Wire, Bundle + + +Or two literals of Uint type: + +>>> res = U(1) | U(2) + + +Or two literals of Sint type: + +>>> res = S(1) | S(2) + + +Or two wires of Uint type: + +>>> w1 = Wire(U.w(8)); w2 = Wire(U.w(9)) +>>> res = w1 | w2 + + +Or two wires of Vector type: + +>>> w1 = Wire(U.w(8)[8]); w2 = Wire(U.w(9)[8]) +>>> res = w1 | w2 + + +Or two wires of Bundle type: + +>>> w1 = Wire(Bundle(a=U.w(2), b=~S.w(3))) +>>> w2 = Wire(Bundle(a=U.w(3), b=~S.w(4))) +>>> res = w1 | w2 +""" + +from py_hcl.core.expr import ExprHolder +from py_hcl.core.expr.bundle_holder import BundleHolder +from py_hcl.core.expr.error import ExprError +from py_hcl.core.expr.utils import ensure_all_args_are_readable +from py_hcl.core.expr.vec_holder import VecHolder +from py_hcl.core.hcl_ops import op_register +from py_hcl.core.stmt.connect import VariableType +from py_hcl.core.type import HclType +from py_hcl.core.type.bundle import BundleT, BundleDirection +from py_hcl.core.type.sint import SIntT +from py_hcl.core.type.uint import UIntT +from py_hcl.core.type.vector import VectorT +from py_hcl.utils.serialization import json_serialize + + +@json_serialize +class Or(object): + def __init__(self, left, right): + self.operation = 'or' + self.left_expr_id = left.id + self.right_expr_id = right.id + + +orer = op_register('|') + + +@orer(UIntT, UIntT) +@ensure_all_args_are_readable +def _(lf, rt): + w = max(lf.hcl_type.width, rt.hcl_type.width) + t = UIntT(w) + return ExprHolder(t, VariableType.ReadOnly, Or(lf, rt)) + + +@orer(SIntT, SIntT) +@ensure_all_args_are_readable +def _(lf, rt): + w = max(lf.hcl_type.width, rt.hcl_type.width) + t = UIntT(w) + return ExprHolder(t, VariableType.ReadOnly, Or(lf, rt)) + + +@orer(VectorT, VectorT) +@ensure_all_args_are_readable +def _(lf, rt): + if lf.hcl_type.size != rt.hcl_type.size: + raise ExprError.unmatched_vec_size(lf.hcl_type.size, rt.hcl_type.size) + + values = [lf[i] | rt[i] for i in range(lf.hcl_type.size)] + v_type = VectorT(values[0].hcl_type, len(values)) + return VecHolder(v_type, VariableType.ReadOnly, values) + + +@orer(BundleT, BundleT) +@ensure_all_args_are_readable +def _(lf, rt): + # TODO: Accurate Error Message + assert set(lf.hcl_type.fields.keys()) == set(rt.hcl_type.fields.keys()) + + bd_type_fields = {} + bd_values = {} + for k in lf.hcl_type.fields.keys(): + res = getattr(lf, k) | getattr(rt, k) + bd_type_fields[k] = { + "dir": BundleDirection.SOURCE, + "hcl_type": res.hcl_type + } + bd_values[k] = res + + return BundleHolder(BundleT(bd_type_fields), VariableType.ReadOnly, + bd_values) + + +@orer(HclType, HclType) +def _(_0, _1): + raise ExprError.op_type_err('or', _0, _1) diff --git a/py_hcl/core/expr/slice.py b/py_hcl/core/expr/slice.py index ed9b49e..61b52d7 100644 --- a/py_hcl/core/expr/slice.py +++ b/py_hcl/core/expr/slice.py @@ -1,14 +1,14 @@ from py_hcl.core.expr import ExprHolder from py_hcl.core.expr.error import ExprError -from py_hcl.core.expr.utils import assert_right_side +from py_hcl.core.expr.utils import ensure_all_args_are_readable from py_hcl.core.expr.vec_holder import VecHolder from py_hcl.core.hcl_ops import op_register -from py_hcl.core.stmt.connect import ConnSide +from py_hcl.core.stmt.connect import VariableType from py_hcl.core.type import HclType from py_hcl.core.type.sint import SIntT from py_hcl.core.type.uint import UIntT from py_hcl.core.type.vector import VectorT -from py_hcl.utils import json_serialize +from py_hcl.utils.serialization import json_serialize slice_ = op_register('[i:j]') @@ -23,19 +23,19 @@ def __init__(self, expr, high, low): @slice_(UIntT) -@assert_right_side +@ensure_all_args_are_readable def _(uint, high: int, low: int): check_bit_width(uint, high, low) t = UIntT(high - low + 1) - return ExprHolder(t, ConnSide.RT, Bits(uint, high, low)) + return ExprHolder(t, VariableType.ReadOnly, Bits(uint, high, low)) @slice_(SIntT) -@assert_right_side +@ensure_all_args_are_readable def _(sint, high: int, low: int): check_bit_width(sint, high, low) t = UIntT(high - low + 1) - return ExprHolder(t, ConnSide.RT, Bits(sint, high, low)) + return ExprHolder(t, VariableType.ReadOnly, Bits(sint, high, low)) @slice_(VectorT) @@ -43,12 +43,12 @@ def _(vec, low: int, high: int): check_vec_size(vec, low, high) if isinstance(vec, VecHolder): - values = vec.assoc_value[low: high] + values = vec.assoc_value[low:high] else: values = [vec[i] for i in range(low, high, 1)] v_type = VectorT(vec.hcl_type.inner_type, high - low) - return VecHolder(v_type, vec.conn_side, values) + return VecHolder(v_type, vec.variable_type, values) @slice_(HclType) diff --git a/py_hcl/core/expr/utils/__init__.py b/py_hcl/core/expr/utils/__init__.py index ae8ec36..6eca3f1 100644 --- a/py_hcl/core/expr/utils/__init__.py +++ b/py_hcl/core/expr/utils/__init__.py @@ -1,13 +1,58 @@ -from py_hcl.core.stmt.connect import ConnSide +import functools +from py_hcl.core.expr.error import ExprError +from py_hcl.core.stmt.connect import VariableType -def assert_right_side(f): + +def ensure_all_args_are_readable(f): + """ + A helper decorator to ensure that the variable type of all arguments within + a function call should be `ReadOnly` or `ReadWrite` if they are PyHCL + expressions. + + It's useful to check validity when constructing nodes of binary operation + like `add`, and unary operation like `invert`. + + Examples + -------- + + >>> from py_hcl import * + >>> @ensure_all_args_are_readable + ... def func(*args): + ... pass + + + Literals are `ReadOnly` so they will pass the check: + + >>> func(U(10), S(30)) + + + Also for wires as they're `ReadWrite`: + + >>> func(Wire(U.w(10)), Wire(S.w(10))) + + + But not for output as they're `WriteOnly`: + + >>> class TempModule(Module): + ... io = IO(o=Output(U.w(10))) + ... func(io.o) + Traceback (most recent call last): + ... + py_hcl.core.expr.error.ExprError: Specified expresion has an invalid + variable type + """ + @functools.wraps(f) def _(*args): - check_lists = [a for a in args if hasattr(a, 'conn_side')] - sides = [ConnSide.RT, ConnSide.BOTH] + check_lists = [a for a in args if hasattr(a, 'variable_type')] + sides = [VariableType.ReadOnly, VariableType.ReadWrite] + + for a in check_lists: + if a.variable_type not in sides: + msg = f'{a}\'s variable_type neither `ReadOnly` ' \ + f'nor `ReadWrite`' + raise ExprError.var_type_err(msg) - # TODO: Accurate Error Message - assert all(a.conn_side in sides for a in check_lists) return f(*args) return _ diff --git a/py_hcl/core/expr/vec_holder.py b/py_hcl/core/expr/vec_holder.py index 3f1f65e..915665c 100644 --- a/py_hcl/core/expr/vec_holder.py +++ b/py_hcl/core/expr/vec_holder.py @@ -3,9 +3,9 @@ class VecHolder(HclExpr): - def __init__(self, hcl_type, conn_side, assoc_value): + def __init__(self, hcl_type, variable_type, assoc_value): self.hcl_type = hcl_type - self.conn_side = conn_side + self.variable_type = variable_type assert isinstance(hcl_type, VectorT) assert hcl_type.size == len(assoc_value) diff --git a/py_hcl/core/expr/wire.py b/py_hcl/core/expr/wire.py index a90eff8..da4dad6 100644 --- a/py_hcl/core/expr/wire.py +++ b/py_hcl/core/expr/wire.py @@ -1,11 +1,11 @@ from py_hcl.core.expr import HclExpr -from py_hcl.core.stmt.connect import ConnSide +from py_hcl.core.stmt.connect import VariableType from py_hcl.core.type import HclType -from py_hcl.utils import json_serialize +from py_hcl.utils.serialization import json_serialize -@json_serialize(json_fields=['hcl_type', 'conn_side']) +@json_serialize(json_fields=['hcl_type', 'variable_type']) class Wire(HclExpr): def __init__(self, hcl_type: HclType): self.hcl_type = hcl_type - self.conn_side = ConnSide.BOTH + self.variable_type = VariableType.ReadWrite diff --git a/py_hcl/core/expr/xor.py b/py_hcl/core/expr/xor.py index 7225cf0..43f0ba2 100644 --- a/py_hcl/core/expr/xor.py +++ b/py_hcl/core/expr/xor.py @@ -1,16 +1,54 @@ +""" +Implement xor operation for pyhcl values. + +Examples +-------- + +>>> from py_hcl import U, S, Wire, Bundle + + +Xor two literals of Uint type: + +>>> res = U(1) ^ U(2) + + +Xor two literals of Sint type: + +>>> res = S(1) ^ S(2) + + +Xor two wires of Uint type: + +>>> w1 = Wire(U.w(8)); w2 = Wire(U.w(9)) +>>> res = w1 ^ w2 + + +Xor two wires of Vector type: + +>>> w1 = Wire(U.w(8)[8]); w2 = Wire(U.w(9)[8]) +>>> res = w1 ^ w2 + + +Xor two wires of Bundle type: + +>>> w1 = Wire(Bundle(a=U.w(2), b=~S.w(3))) +>>> w2 = Wire(Bundle(a=U.w(3), b=~S.w(4))) +>>> res = w1 ^ w2 +""" + from py_hcl.core.expr import ExprHolder from py_hcl.core.expr.bundle_holder import BundleHolder from py_hcl.core.expr.error import ExprError -from py_hcl.core.expr.utils import assert_right_side +from py_hcl.core.expr.utils import ensure_all_args_are_readable from py_hcl.core.expr.vec_holder import VecHolder from py_hcl.core.hcl_ops import op_register -from py_hcl.core.stmt.connect import ConnSide +from py_hcl.core.stmt.connect import VariableType from py_hcl.core.type import HclType -from py_hcl.core.type.bundle import BundleT, Dir +from py_hcl.core.type.bundle import BundleT, BundleDirection from py_hcl.core.type.sint import SIntT from py_hcl.core.type.uint import UIntT from py_hcl.core.type.vector import VectorT -from py_hcl.utils import json_serialize +from py_hcl.utils.serialization import json_serialize @json_serialize @@ -25,34 +63,34 @@ def __init__(self, left, right): @xorer(UIntT, UIntT) -@assert_right_side +@ensure_all_args_are_readable def _(lf, rt): w = max(lf.hcl_type.width, rt.hcl_type.width) t = UIntT(w) - return ExprHolder(t, ConnSide.RT, Xor(lf, rt)) + return ExprHolder(t, VariableType.ReadOnly, Xor(lf, rt)) @xorer(SIntT, SIntT) -@assert_right_side +@ensure_all_args_are_readable def _(lf, rt): w = max(lf.hcl_type.width, rt.hcl_type.width) t = UIntT(w) - return ExprHolder(t, ConnSide.RT, Xor(lf, rt)) + return ExprHolder(t, VariableType.ReadOnly, Xor(lf, rt)) @xorer(VectorT, VectorT) -@assert_right_side +@ensure_all_args_are_readable def _(lf, rt): - # TODO: Accurate Error Message - assert lf.hcl_type.size == rt.hcl_type.size + if lf.hcl_type.size != rt.hcl_type.size: + raise ExprError.unmatched_vec_size(lf.hcl_type.size, rt.hcl_type.size) values = [lf[i] ^ rt[i] for i in range(lf.hcl_type.size)] v_type = VectorT(values[0].hcl_type, len(values)) - return VecHolder(v_type, ConnSide.RT, values) + return VecHolder(v_type, VariableType.ReadOnly, values) @xorer(BundleT, BundleT) -@assert_right_side +@ensure_all_args_are_readable def _(lf, rt): # TODO: Accurate Error Message assert set(lf.hcl_type.fields.keys()) == set(rt.hcl_type.fields.keys()) @@ -61,10 +99,14 @@ def _(lf, rt): bd_values = {} for k in lf.hcl_type.fields.keys(): res = getattr(lf, k) ^ getattr(rt, k) - bd_type_fields[k] = {"dir": Dir.SRC, "hcl_type": res.hcl_type} + bd_type_fields[k] = { + "dir": BundleDirection.SOURCE, + "hcl_type": res.hcl_type + } bd_values[k] = res - return BundleHolder(BundleT(bd_type_fields), ConnSide.RT, bd_values) + return BundleHolder(BundleT(bd_type_fields), VariableType.ReadOnly, + bd_values) @xorer(HclType, HclType) diff --git a/py_hcl/core/hcl_ops/__init__.py b/py_hcl/core/hcl_ops/__init__.py index fe5be0e..7594467 100644 --- a/py_hcl/core/hcl_ops/__init__.py +++ b/py_hcl/core/hcl_ops/__init__.py @@ -31,7 +31,7 @@ def _(*objects): if func is not None: return func(*objects) - msg = 'No matched functions for types {} while calling operation ' \ + msg = 'No matched functions for type {} while calling operation ' \ '"{}"'.format([type(o).__name__ for o in objects], operation) raise NotImplementedError(msg) diff --git a/py_hcl/core/module/base_module.py b/py_hcl/core/module/base_module.py index 0113ad4..d61e661 100644 --- a/py_hcl/core/module/base_module.py +++ b/py_hcl/core/module/base_module.py @@ -1,5 +1,6 @@ +from py_hcl.compile import compile_to_firrtl from py_hcl.core import install_ops -from py_hcl.core.expr.io import io_extend +from py_hcl.core.module_factory.inherit_chain.io import io_extend from py_hcl.core.module.meta_module import MetaModule install_ops() @@ -7,3 +8,48 @@ class BaseModule(metaclass=MetaModule): io = io_extend(tuple())({}) + + @classmethod + def compile_to_firrtl(cls, path=None): + """ + Compiles current PyHCL Module to FIRRTL source code file. + + Examples + -------- + + Define a PyHCL module: + + >>> from py_hcl import * + >>> class M(Module): + ... io = IO( + ... i=Input(U.w(8)), + ... o=Output(U.w(8)), + ... ) + ... io.o <<= io.i + + Compile to FIRRTL: + + >>> from tempfile import mktemp + >>> tmp_file = mktemp() + >>> M.compile_to_firrtl(tmp_file) + + Read the content: + + >>> with open(tmp_file) as f: + ... print(f.read()) + circuit M : + module M : + input clock : Clock + input reset : UInt<1> + input M_io_i : UInt<8> + output M_io_o : UInt<8> + + M_io_o <= M_io_i + + + + >>> from os import remove + >>> remove(tmp_file) + """ + + compile_to_firrtl(cls, path) diff --git a/py_hcl/core/module/meta_module.py b/py_hcl/core/module/meta_module.py index 0c5a227..16dd3d4 100644 --- a/py_hcl/core/module/meta_module.py +++ b/py_hcl/core/module/meta_module.py @@ -14,7 +14,7 @@ def __init__(cls, name, bases, dct): name = fetch_module_name(name) check_io_exist(dct, name) - dct["io"].io_chain_head.io_holder.module_name = name + dct["io"].io_chain[0].module_name = name packed = packer.pack(bases, dct, name) cls.packed_module = packed diff --git a/py_hcl/core/module/packed_module.py b/py_hcl/core/module/packed_module.py index ed06ac3..32c4bf6 100644 --- a/py_hcl/core/module/packed_module.py +++ b/py_hcl/core/module/packed_module.py @@ -1,4 +1,4 @@ -from py_hcl.utils import json_serialize +from py_hcl.utils.serialization import json_serialize @json_serialize diff --git a/py_hcl/core/module_factory/error/__init__.py b/py_hcl/core/module_factory/error/__init__.py index 7ef30b8..f1befa5 100644 --- a/py_hcl/core/module_factory/error/__init__.py +++ b/py_hcl/core/module_factory/error/__init__.py @@ -5,7 +5,8 @@ def set_up(): ModuleError.append({ 'NotContainsIO': { 'code': 100, - 'value': ModuleError('the module lack of io attribute')}, + 'value': ModuleError('the module lack of io attribute') + }, }) diff --git a/py_hcl/dsl/tpe/__init__.py b/py_hcl/core/module_factory/inherit_chain/__init__.py similarity index 100% rename from py_hcl/dsl/tpe/__init__.py rename to py_hcl/core/module_factory/inherit_chain/__init__.py diff --git a/py_hcl/core/module_factory/inherit_chain/io.py b/py_hcl/core/module_factory/inherit_chain/io.py new file mode 100644 index 0000000..bd47b7d --- /dev/null +++ b/py_hcl/core/module_factory/inherit_chain/io.py @@ -0,0 +1,52 @@ +from typing import Dict, Union, Optional, Tuple + +from py_hcl.core.expr.error import ExprError +from py_hcl.core.expr.io import Input, Output, IO +from py_hcl.core.type.bundle import BundleT, BundleDirection +from py_hcl.core.utils import module_inherit_mro +from py_hcl.utils.serialization import json_serialize + + +@json_serialize +class IOHolder(object): + def __init__(self, + named_ports: Dict[str, Union[Input, Output]], + module_name: Optional[str] = None): + self.named_ports = named_ports + self.module_name = module_name + + +def io_extend(modules: Tuple[type]): + modules = module_inherit_mro(modules) + + current_ports = {} + io_chain = [] + for m in modules[::-1]: + h = m.io.io_chain[0] + current_ports.update(h.named_ports) + io_chain.insert(0, h) + + def _(named_ports: Dict[str, Union[Input, Output]]): + current_ports.update(named_ports) + io_chain.insert(0, IOHolder(named_ports)) + return IO(__build_bundle_type_from_ports(current_ports), io_chain) + + return _ + + +def __build_bundle_type_from_ports( + named_ports: Dict[str, Union[Input, Output]]) -> BundleT: + fields = {} + for k, v in named_ports.items(): + if isinstance(v, Input): + fields[k] = {"dir": BundleDirection.SOURCE, "hcl_type": v.hcl_type} + continue + + if isinstance(v, Output): + fields[k] = {"dir": BundleDirection.SINK, "hcl_type": v.hcl_type} + continue + + raise ExprError.io_value_err( + "type of '{}' is {}, not Input or Output".format(k, type(v))) + + return BundleT(fields) diff --git a/py_hcl/core/module_factory/inherit_chain/named_expr.py b/py_hcl/core/module_factory/inherit_chain/named_expr.py new file mode 100644 index 0000000..c5d3860 --- /dev/null +++ b/py_hcl/core/module_factory/inherit_chain/named_expr.py @@ -0,0 +1,11 @@ +from typing import Dict + +from py_hcl.utils.serialization import json_serialize + + +@json_serialize +class NamedExprHolder(object): + def __init__(self, module_name: str, named_expression_table: Dict[int, + str]): + self.module_name = module_name + self.named_expression_table = named_expression_table diff --git a/py_hcl/core/module_factory/inherit_chain/stmt_holder.py b/py_hcl/core/module_factory/inherit_chain/stmt_holder.py new file mode 100644 index 0000000..ef9cd54 --- /dev/null +++ b/py_hcl/core/module_factory/inherit_chain/stmt_holder.py @@ -0,0 +1,9 @@ +from py_hcl.core.stmt import ClusterStatement +from py_hcl.utils.serialization import json_serialize + + +@json_serialize +class StmtHolder(object): + def __init__(self, module_name: str, top_statement: ClusterStatement): + self.module_name = module_name + self.top_statement = top_statement diff --git a/py_hcl/core/module_factory/inherit_list/named_expr.py b/py_hcl/core/module_factory/inherit_list/named_expr.py deleted file mode 100644 index 7e47458..0000000 --- a/py_hcl/core/module_factory/inherit_list/named_expr.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Optional, Dict - -from py_hcl.utils import json_serialize - - -@json_serialize -class NamedExprHolder(object): - def __init__(self, module_name: str, - named_expression_table: Dict[int, str]): - self.module_name = module_name - self.named_expression_table = named_expression_table - - -@json_serialize -class NamedExprNode(object): - def __init__(self, - named_expr_holder: NamedExprHolder, - next_node: Optional["NamedExprNode"]): - self.named_expr_holder = named_expr_holder - if next_node: - self.next_node = next_node - - -@json_serialize -class NamedExprChain(object): - def __init__(self, named_expr_chain_head: NamedExprNode): - self.named_expr_chain_head = named_expr_chain_head diff --git a/py_hcl/core/module_factory/inherit_list/stmt_holder.py b/py_hcl/core/module_factory/inherit_list/stmt_holder.py deleted file mode 100644 index 3d248bc..0000000 --- a/py_hcl/core/module_factory/inherit_list/stmt_holder.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Optional - -from py_hcl.core.stmt import ClusterStatement -from py_hcl.utils import json_serialize - - -@json_serialize -class StmtHolder(object): - def __init__(self, module_name: str, top_statement: ClusterStatement): - self.module_name = module_name - self.top_statement = top_statement - - -@json_serialize -class StmtNode(object): - def __init__(self, stmt_holder: StmtHolder, - next_node: Optional["StmtNode"]): - self.stmt_holder = stmt_holder - if next_node: - self.next_node = next_node - - -@json_serialize -class StmtChain(object): - def __init__(self, stmt_chain_head: StmtNode): - self.stmt_chain_head = stmt_chain_head diff --git a/py_hcl/core/module_factory/merger.py b/py_hcl/core/module_factory/merger.py deleted file mode 100644 index 4701239..0000000 --- a/py_hcl/core/module_factory/merger.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import List - -from .inherit_list.named_expr import NamedExprChain, \ - NamedExprNode, NamedExprHolder -from .inherit_list.stmt_holder import StmtChain, \ - StmtHolder, StmtNode - - -def merge_expr(modules: List[type], - expr_holder: NamedExprHolder) -> NamedExprChain: - expr_list = None - for m in modules[::-1]: - h = m.packed_module.named_expr_chain \ - .named_expr_chain_head.named_expr_holder - expr_list = NamedExprNode(h, expr_list) - - expr_list = NamedExprNode(expr_holder, expr_list) - return NamedExprChain(expr_list) - - -def merge_statement(modules: List[type], - stmt_holder: StmtHolder) -> StmtChain: - stmt_list = None - for m in modules[::-1]: - h = m.packed_module.statement_chain \ - .stmt_chain_head.stmt_holder - stmt_list = StmtNode(h, stmt_list) - - stmt_list = StmtNode(stmt_holder, stmt_list) - return StmtChain(stmt_list) diff --git a/py_hcl/core/module_factory/packer.py b/py_hcl/core/module_factory/packer.py index daaf6fe..fcfc578 100644 --- a/py_hcl/core/module_factory/packer.py +++ b/py_hcl/core/module_factory/packer.py @@ -1,10 +1,9 @@ from py_hcl.core.module.packed_module import PackedModule -from py_hcl.core.module_factory.inherit_list.named_expr import NamedExprHolder -from py_hcl.core.module_factory.inherit_list.stmt_holder import StmtHolder +from py_hcl.core.module_factory.inherit_chain.named_expr import NamedExprHolder +from py_hcl.core.module_factory.inherit_chain.stmt_holder import StmtHolder from py_hcl.core.stmt_factory.trapper import StatementTrapper from py_hcl.core.utils import module_inherit_mro from . import extractor -from . import merger from ..stmt import ClusterStatement, ConditionStatement from ..stmt_factory.scope import ScopeType @@ -23,11 +22,11 @@ def pack(bases, dct, name) -> PackedModule: def handle_inherit(bases, named_expression, top_statement, name): modules = module_inherit_mro(bases) - named_expr_chain = \ - merger.merge_expr(modules, NamedExprHolder(name, named_expression)) + named_expr_chain = [NamedExprHolder(name, named_expression)] + \ + [m.packed_module.named_expr_chain[0] for m in modules] - statement_chain = \ - merger.merge_statement(modules, StmtHolder(name, top_statement)) + statement_chain = [StmtHolder(name, top_statement)] + \ + [m.packed_module.statement_chain[0] for m in modules] return named_expr_chain, statement_chain diff --git a/py_hcl/core/stmt/__init__.py b/py_hcl/core/stmt/__init__.py index 966638e..effd786 100644 --- a/py_hcl/core/stmt/__init__.py +++ b/py_hcl/core/stmt/__init__.py @@ -1,4 +1,4 @@ -from py_hcl.utils import json_serialize +from py_hcl.utils.serialization import json_serialize @json_serialize(json_fields=['stmt_class', 'statement']) diff --git a/py_hcl/core/stmt/branch.py b/py_hcl/core/stmt/branch.py index 4d8b447..99df51a 100644 --- a/py_hcl/core/stmt/branch.py +++ b/py_hcl/core/stmt/branch.py @@ -4,7 +4,7 @@ from py_hcl.core.stmt_factory.scope import ScopeManager, ScopeType from py_hcl.core.stmt_factory.trapper import StatementTrapper from py_hcl.core.type.uint import UIntT -from py_hcl.utils import json_serialize +from py_hcl.utils.serialization import json_serialize @json_serialize @@ -55,9 +55,8 @@ def do_otherwise_exit(): def check_bool_expr(cond_expr: HclExpr): if isinstance(cond_expr.hcl_type, UIntT) and cond_expr.hcl_type.width == 1: return - raise StatementError.wrong_branch_syntax( - 'check_bool_expr(): ' - 'expected bool-type expression') + raise StatementError.wrong_branch_syntax('check_bool_expr(): ' + 'expected bool-type expression') def check_branch_syntax(): @@ -68,9 +67,8 @@ def check_branch_syntax(): def check_exists_pre_stmts(): if len(StatementTrapper.trapped_stmts[-1]) == 0: - raise StatementError.wrong_branch_syntax( - 'check_exists_pre_stmts(): ' - 'expected when block') + raise StatementError.wrong_branch_syntax('check_exists_pre_stmts(): ' + 'expected when block') def check_exists_pre_when_block(): @@ -102,6 +100,5 @@ def check_correct_block_level(): if last_scope_level == current_scope_level + 1: return - raise StatementError.wrong_branch_syntax( - 'check_correct_block_level(): ' - 'branch block not matched') + raise StatementError.wrong_branch_syntax('check_correct_block_level(): ' + 'branch block not matched') diff --git a/py_hcl/core/stmt/connect.py b/py_hcl/core/stmt/connect.py index 1153717..b44577c 100644 --- a/py_hcl/core/stmt/connect.py +++ b/py_hcl/core/stmt/connect.py @@ -1,3 +1,46 @@ +""" +Implement connection between two PyHCL expressions. + +Examples +-------- + +>>> from py_hcl import * + +Connect literal to output: + +>>> class _(Module): +... io = IO(o=Output(U.w(5))) +... io.o <<= U(10) + + +Connect input to output: + +>>> class _(Module): +... io = IO(i=Input(U.w(8)), o=Output(U.w(5))) +... io.o <<= io.i + + +Connect wire to output and connect input to wire: + +>>> class _(Module): +... io = IO(i=Input(U.w(8)), o=Output(U.w(5))) +... w = Wire(U.w(6)) +... io.o <<= w +... w <<= io.i + + +Connection with wrong direction + +>>> class _(Module): +... io = IO(i=Input(U.w(8))) +... lit = U(8) +... lit <<= io.i +Traceback (most recent call last): +... +py_hcl.core.stmt.error.StatementError: Connection statement with unexpected +direction. +""" + import logging from enum import Enum @@ -5,18 +48,18 @@ from py_hcl.core.stmt.error import StatementError from py_hcl.core.stmt_factory.trapper import StatementTrapper from py_hcl.core.type import HclType -from py_hcl.core.type.bundle import BundleT, Dir +from py_hcl.core.type.bundle import BundleT, BundleDirection from py_hcl.core.type.sint import SIntT from py_hcl.core.type.uint import UIntT from py_hcl.core.type.vector import VectorT -from py_hcl.utils import json_serialize +from py_hcl.utils.serialization import json_serialize -class ConnSide(Enum): +class VariableType(Enum): UNKNOWN = 0 - LF = 1 - RT = 2 - BOTH = 3 + WriteOnly = 1 + ReadOnly = 2 + ReadWrite = 3 @json_serialize @@ -30,14 +73,31 @@ def __init__(self, left, right): connector = op_register('<<=') +def check_connect_direction(f): + def _(left: HclType, right: HclType): + if left.variable_type not in (VariableType.WriteOnly, + VariableType.ReadWrite): + direction = left.variable_type + raise StatementError.connect_direction_error( + f'The lhs of connection statement can not be a {direction}') + if right.variable_type not in (VariableType.ReadOnly, + VariableType.ReadWrite): + direction = right.variable_type + raise StatementError.connect_direction_error( + f'The rhs of connection statement can not be a {direction}') + + return f(left, right) + + return _ + + @connector(UIntT, UIntT) +@check_connect_direction def _(left, right): - check_connect_dir(left, right) - if left.hcl_type.width < right.hcl_type.width: - msg = 'connect(): connecting {} to {} will truncate the bits'.format( - right.hcl_type, left.hcl_type) - logging.warning(msg) + logging.warning( + f'connect(): connecting {right.hcl_type} to {left.hcl_type} ' + f'will truncate the bits') right = right[left.hcl_type.width - 1:0] if left.hcl_type.width > right.hcl_type.width: @@ -49,14 +109,12 @@ def _(left, right): @connector(SIntT, SIntT) +@check_connect_direction def _(left, right): - check_connect_dir(left, right) - if left.hcl_type.width < right.hcl_type.width: logging.warning( - 'connect(): connecting {} to {} will truncate the bits'.format( - right.hcl_type, left.hcl_type - )) + f'connect(): connecting {right.hcl_type} to {left.hcl_type} ' + f'will truncate the bits') right = right[left.hcl_type.width - 1:0].to_sint() if left.hcl_type.width > right.hcl_type.width: @@ -68,39 +126,38 @@ def _(left, right): @connector(UIntT, SIntT) +@check_connect_direction def _(left, right): - msg = 'connect(): connecting SInt to UInt, an auto-conversion will occur' - logging.warning(msg) + logging.warning( + 'connect(): connecting SInt to UInt will cause auto-conversion') if left.hcl_type.width < right.hcl_type.width: logging.warning( - 'connect(): connecting {} to {} will truncate the bits'.format( - right.hcl_type, left.hcl_type - )) + f'connect(): connect {right.hcl_type} to {left.hcl_type} ' + f'will truncate the bits') return op_apply('<<=')(left, right[left.hcl_type.width - 1:0]) return op_apply('<<=')(left, right.to_uint()) @connector(SIntT, UIntT) +@check_connect_direction def _(left, right): - msg = 'connect(): connecting UInt to SInt, an auto-conversion will occur' - logging.warning(msg) + logging.warning( + 'connect(): connecting UInt to SInt will cause auto-conversion') if left.hcl_type.width < right.hcl_type.width: logging.warning( - 'connect(): connecting {} to {} will truncate the bits'.format( - right.hcl_type, left.hcl_type - )) + f'connect(): connecting {right.hcl_type} to {left.hcl_type} ' + f'will truncate the bits') right = right[left.hcl_type.width - 1:0] return op_apply('<<=')(left, right.to_sint()) @connector(BundleT, BundleT) +@check_connect_direction def _(left, right): - check_connect_dir(left, right) - # TODO: Accurate Error Message dir_and_types = right.hcl_type.fields keys = dir_and_types.keys() @@ -109,7 +166,7 @@ def _(left, right): for k in keys: lf = op_apply('.')(left, k) rt = op_apply('.')(right, k) - if dir_and_types[k]['dir'] == Dir.SRC: + if dir_and_types[k]['dir'] == BundleDirection.SOURCE: op_apply('<<=')(lf, rt) else: op_apply('<<=')(rt, lf) @@ -118,9 +175,8 @@ def _(left, right): @connector(VectorT, VectorT) +@check_connect_direction def _(left, right): - check_connect_dir(left, right) - # TODO: Accurate Error Message assert left.hcl_type.size == right.hcl_type.size @@ -133,9 +189,3 @@ def _(left, right): @connector(HclType, HclType) def _(_0, _1): raise StatementError.connect_type_error(_0, _1) - - -def check_connect_dir(left, right): - # TODO: Accurate Error Message - assert left.conn_side in (ConnSide.LF, ConnSide.BOTH) - assert right.conn_side in (ConnSide.RT, ConnSide.BOTH) diff --git a/py_hcl/core/stmt/error/__init__.py b/py_hcl/core/stmt/error/__init__.py index 5625f57..4b91a76 100644 --- a/py_hcl/core/stmt/error/__init__.py +++ b/py_hcl/core/stmt/error/__init__.py @@ -4,13 +4,22 @@ def set_up(): StatementError.append({ 'WrongBranchSyntax': { - 'code': 300, - 'value': StatementError( - 'expected a well-defined when-else_when-otherwise block')}, + 'code': + 300, + 'value': + StatementError( + 'expected a well-defined when-else_when-otherwise block') + }, 'ConnectTypeError': { 'code': 301, - 'value': StatementError( - 'connect statement contains unexpected types') + 'value': + StatementError('Connect statement contains unexpected type.') + }, + 'ConnectDirectionError': { + 'code': + 302, + 'value': + StatementError('Connection statement with unexpected direction.') } }) @@ -23,8 +32,12 @@ def wrong_branch_syntax(msg): @staticmethod def connect_type_error(*args): ts = ', '.join([type(a.hcl_type).__name__ for a in args]) - msg = 'connect(): unsupported connect types: {}'.format(ts) + msg = 'connect(): unsupported connect type: {}'.format(ts) return StatementError.err('ConnectTypeError', msg) + @staticmethod + def connect_direction_error(msg): + return StatementError.err('ConnectDirectionError', msg) + set_up() diff --git a/py_hcl/core/stmt_factory/scope.py b/py_hcl/core/stmt_factory/scope.py index 889eba2..3932825 100644 --- a/py_hcl/core/stmt_factory/scope.py +++ b/py_hcl/core/stmt_factory/scope.py @@ -1,6 +1,6 @@ from enum import Enum -from py_hcl.utils import json_serialize +from py_hcl.utils.serialization import json_serialize class ScopeType(Enum): @@ -56,9 +56,7 @@ class ScopeManager(object): scope_type=ScopeType.TOP, ) ] - scope_id_map = { - scope_list[0].scope_id: scope_list[0] - } + scope_id_map = {scope_list[0].scope_id: scope_list[0]} scope_expanding_hooks = [] scope_shrinking_hooks = [] diff --git a/py_hcl/core/stmt_factory/trapper.py b/py_hcl/core/stmt_factory/trapper.py index 1de18fd..d82b06c 100644 --- a/py_hcl/core/stmt_factory/trapper.py +++ b/py_hcl/core/stmt_factory/trapper.py @@ -3,10 +3,8 @@ def set_up(): - ScopeManager.register_scope_expanding( - StatementTrapper.on_scope_expanding) - ScopeManager.register_scope_shrinking( - StatementTrapper.on_scope_shrinking) + ScopeManager.register_scope_expanding(StatementTrapper.on_scope_expanding) + ScopeManager.register_scope_shrinking(StatementTrapper.on_scope_shrinking) ScopeManager.expand_scope(ScopeType.GROUND) @@ -28,10 +26,8 @@ def trap(cls): @classmethod def track(cls, statement): - statement = LineStatement( - ScopeManager.current_scope().scope_id, - statement - ) + statement = LineStatement(ScopeManager.current_scope().scope_id, + statement) cls.trapped_stmts[-1].append(statement) @classmethod diff --git a/py_hcl/core/type/__init__.py b/py_hcl/core/type/__init__.py index 029f12b..24baa50 100644 --- a/py_hcl/core/type/__init__.py +++ b/py_hcl/core/type/__init__.py @@ -1,4 +1,4 @@ -from py_hcl.utils import json_serialize +from py_hcl.utils.serialization import json_serialize @json_serialize diff --git a/py_hcl/core/type/bundle.py b/py_hcl/core/type/bundle.py index 54ac664..af292a1 100644 --- a/py_hcl/core/type/bundle.py +++ b/py_hcl/core/type/bundle.py @@ -5,8 +5,8 @@ from py_hcl.core.type.wrapper import vec_wrap, bd_fld_wrap -class Dir(Enum): - SRC = 1 +class BundleDirection(Enum): + SOURCE = 1 SINK = 2 diff --git a/py_hcl/core/type/error.py b/py_hcl/core/type/error.py new file mode 100644 index 0000000..a610753 --- /dev/null +++ b/py_hcl/core/type/error.py @@ -0,0 +1,19 @@ +from py_hcl.core.error import CoreError + + +def set_up(): + TypeError.append({ + 'SizeError': { + 'code': 500, + 'value': TypeError('Specified size is invalid.') + } + }) + + +class TypeError(CoreError): + @staticmethod + def size_err(msg): + return TypeError.err('SizeError', msg) + + +set_up() diff --git a/py_hcl/core/type/sint.py b/py_hcl/core/type/sint.py index 2fb0ea3..5adcfce 100644 --- a/py_hcl/core/type/sint.py +++ b/py_hcl/core/type/sint.py @@ -1,20 +1,29 @@ from py_hcl.core.type import HclType from py_hcl.core.type.wrapper import bd_fld_wrap, vec_wrap -from py_hcl.utils import signed_num_bin_len +from py_hcl.utils import signed_num_bin_width +from py_hcl.core.type.error import TypeError +from py_hcl.core.expr.error import ExprError @bd_fld_wrap @vec_wrap class SIntT(HclType): def __init__(self, width): + if width <= 1: + raise TypeError.size_err( + f'SInt width can not equal or less than 1, got {width}') + self.type = "sint" self.width = width def __call__(self, value: int): from py_hcl.core.expr.lit_sint import SLiteral - # TODO: Accurate Error Message - assert signed_num_bin_len(value) <= self.width + least_len = signed_num_bin_width(value) + if least_len > self.width: + raise ExprError.out_of_range_err( + f'Literal {value} out of range for sint[{self.width}]') + u = SLiteral(value) u.hcl_type = SIntT(self.width) return u diff --git a/py_hcl/core/type/uint.py b/py_hcl/core/type/uint.py index 1d691fb..42823a2 100644 --- a/py_hcl/core/type/uint.py +++ b/py_hcl/core/type/uint.py @@ -1,20 +1,29 @@ +from py_hcl.core.expr.error import ExprError from py_hcl.core.type import HclType from py_hcl.core.type.wrapper import bd_fld_wrap, vec_wrap -from py_hcl.utils import unsigned_num_bin_len +from py_hcl.utils import unsigned_num_bin_width +from py_hcl.core.type.error import TypeError @bd_fld_wrap @vec_wrap class UIntT(HclType): def __init__(self, width): + if width <= 0: + raise TypeError.size_err( + f'UInt width can not equal or less than 0, got {width}') + self.type = "uint" self.width = width def __call__(self, value: int): from py_hcl.core.expr.lit_uint import ULiteral - # TODO: Accurate Error Message - assert unsigned_num_bin_len(value) <= self.width + least_len = unsigned_num_bin_width(value) + if least_len > self.width: + raise ExprError.out_of_range_err( + f'Literal {value} out of range for uint[{self.width}]') + u = ULiteral(value) u.hcl_type = UIntT(self.width) return u diff --git a/py_hcl/core/type/vector.py b/py_hcl/core/type/vector.py index a840537..737c210 100644 --- a/py_hcl/core/type/vector.py +++ b/py_hcl/core/type/vector.py @@ -1,12 +1,15 @@ from py_hcl.core.type import HclType +from py_hcl.core.type.error import TypeError from py_hcl.core.type.wrapper import bd_fld_wrap @bd_fld_wrap class VectorT(HclType): def __init__(self, inner_type: HclType, size: int): - # TODO: Accurate Error Message - assert size > 0 + if size <= 0: + raise TypeError.size_err( + f'Vector size can not equal or less than 0, got {size}') + self.type = "vector" self.size = size self.inner_type = inner_type diff --git a/py_hcl/core/type/wrapper.py b/py_hcl/core/type/wrapper.py index 2f2ee2c..8173921 100644 --- a/py_hcl/core/type/wrapper.py +++ b/py_hcl/core/type/wrapper.py @@ -5,8 +5,8 @@ @dispatch() def invert_exp(self: HclType): - from py_hcl.core.type.bundle import Dir - return {'dir': Dir.SINK, 'hcl_type': self} + from py_hcl.core.type.bundle import BundleDirection + return {'dir': BundleDirection.SINK, 'hcl_type': self} def bd_fld_wrap(cls): @@ -18,13 +18,13 @@ def bd_fld_wrap(cls): @dispatch() -def vec_exp(self: HclType, i: int): +def vec_ext(self: HclType, i: int): from py_hcl.core.type.vector import VectorT return VectorT(self, i) @dispatch() -def vec_exp(self: HclType, t: tuple): +def vec_ext(self: HclType, t: tuple): from py_hcl.core.type.vector import VectorT # TODO: Accurate Error Message @@ -39,5 +39,5 @@ def vec_wrap(cls): if hasattr(cls, '__getitem__'): return cls - cls.__getitem__ = vec_exp + cls.__getitem__ = vec_ext return cls diff --git a/py_hcl/core/utils/__init__.py b/py_hcl/core/utils/__init__.py index 0a13545..9098975 100644 --- a/py_hcl/core/utils/__init__.py +++ b/py_hcl/core/utils/__init__.py @@ -1,9 +1,79 @@ from typing import Tuple, List -def module_inherit_mro(modules: Tuple[type]) -> List[type]: +def module_inherit_mro(bases: Tuple[type]) -> List[type]: + """ + Returns modules in method resolution order. + + As we want to handle inherit stuff, the mro is a very basic entry to travel + the inherit graph. + + + Examples + -------- + + >>> from py_hcl import * + + A normal PyHCL module `V`: + + >>> class V(Module): + ... io = IO() + + + A PyHCL module `W` inheriting `V`: + + >>> class W(V): + ... io = io_extend(V)() + + + Let's see the bases of `W`: + + >>> W.__bases__ + (,) + + + So we can get the mro of its bases via `module_inherit_mro`: + + >>> module_inherit_mro(W.__bases__) + [, + ] + + + A more complicated case: + + >>> class X(W): + ... io = io_extend(W)() + + >>> class Y(V): + ... io = io_extend(V)() + + >>> class Z(X, Y): + ... io = io_extend(X, Y)() + + + To handle expression inheritance or statement inheritance of the module + `Z`, we can first get the mro of `Z`'s bases: + + >>> module_inherit_mro(Z.__bases__) + [, + , + , + , + ] + + """ + from py_hcl.core.module.meta_module import MetaModule - modules = type("_hcl_fake_module", modules, {}).mro() - modules = [m for m in modules[1:] if isinstance(m, MetaModule)] + # Step 1: Build a fake Python class extending bases. + # + # Since `Module`s inherit from `MetaModule`, class construction here will + # trigger `MetaModule.__init__`. We get around the side effect by adding a + # conditional early return at the beginning of `MetaModule.__init__`. + fake_type = type("_hcl_fake_module", bases, {}) + + # Step 2: Get the method resolution order of the fake class. + # Step 3: Filter useless types in the mro. + modules = [m for m in fake_type.mro()[1:] if isinstance(m, MetaModule)] + return modules diff --git a/py_hcl/dsl/__init__.py b/py_hcl/dsl/__init__.py index e204fb4..6f27ce8 100644 --- a/py_hcl/dsl/__init__.py +++ b/py_hcl/dsl/__init__.py @@ -1,8 +1,8 @@ from .branch import when, else_when, otherwise # noqa: F401 from .module import Module # noqa: F401 -from .tpe.clock import Clock # noqa: F401 -from .tpe.uint import U, Bool # noqa: F401 -from .tpe.sint import S # noqa: F401 -from .tpe.bundle import Bundle # noqa: F401 +from .type.clock import Clock # noqa: F401 +from .type.uint import U, Bool # noqa: F401 +from .type.sint import S # noqa: F401 +from .type.bundle import Bundle # noqa: F401 from .expr.io import IO, Input, Output, io_extend # noqa: F401 from .expr.wire import Wire # noqa: F401 diff --git a/py_hcl/dsl/error/__init__.py b/py_hcl/dsl/error/__init__.py index 60b5854..5c2860c 100644 --- a/py_hcl/dsl/error/__init__.py +++ b/py_hcl/dsl/error/__init__.py @@ -1,4 +1,4 @@ -from py_hcl.error import PyHclError +from py_hcl.utils.error import PyHclError class DslError(PyHclError): diff --git a/py_hcl/dsl/expr/io.py b/py_hcl/dsl/expr/io.py index d862919..7ee77d2 100644 --- a/py_hcl/dsl/expr/io.py +++ b/py_hcl/dsl/expr/io.py @@ -1,12 +1,13 @@ from typing import Union import py_hcl.core.expr.io as cio +import py_hcl.core.module_factory.inherit_chain.io as inherit_io from py_hcl.core.module.base_module import BaseModule from py_hcl.core.type import HclType def IO(**named_ports: Union[cio.Input, cio.Output]) -> cio.IO: - return cio.io_extend((BaseModule,))(named_ports) + return inherit_io.io_extend((BaseModule, ))(named_ports) def Input(hcl_type: HclType) -> cio.Input: @@ -19,6 +20,6 @@ def Output(hcl_type: HclType) -> cio.Output: def io_extend(*modules: type): def _(**named_ports: Union[cio.Input, cio.Output]): - return cio.io_extend(modules)(named_ports) + return inherit_io.io_extend(modules)(named_ports) return _ diff --git a/py_hcl/dsl/tpe/bundle.py b/py_hcl/dsl/tpe/bundle.py deleted file mode 100644 index 53420cb..0000000 --- a/py_hcl/dsl/tpe/bundle.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Union - -from py_hcl.core.type import HclType -from py_hcl.core.type.bundle import BundleT, Dir - - -def Bundle(**named_ports: Union[HclType, dict]) -> BundleT: - t = {k: ({'dir': Dir.SRC, 'hcl_type': v} if isinstance(v, HclType) else v) - for k, v in named_ports.items()} - - return BundleT(t) diff --git a/py_hcl/dsl/type/__init__.py b/py_hcl/dsl/type/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/py_hcl/dsl/type/bundle.py b/py_hcl/dsl/type/bundle.py new file mode 100644 index 0000000..7557708 --- /dev/null +++ b/py_hcl/dsl/type/bundle.py @@ -0,0 +1,16 @@ +from typing import Union + +from py_hcl.core.type import HclType +from py_hcl.core.type.bundle import BundleT, BundleDirection + + +def Bundle(**named_ports: Union[HclType, dict]) -> BundleT: + t = { + k: ({ + 'dir': BundleDirection.SOURCE, + 'hcl_type': v + } if isinstance(v, HclType) else v) + for k, v in named_ports.items() + } + + return BundleT(t) diff --git a/py_hcl/dsl/tpe/clock.py b/py_hcl/dsl/type/clock.py similarity index 100% rename from py_hcl/dsl/tpe/clock.py rename to py_hcl/dsl/type/clock.py diff --git a/py_hcl/dsl/tpe/sint.py b/py_hcl/dsl/type/sint.py similarity index 100% rename from py_hcl/dsl/tpe/sint.py rename to py_hcl/dsl/type/sint.py diff --git a/py_hcl/dsl/tpe/uint.py b/py_hcl/dsl/type/uint.py similarity index 100% rename from py_hcl/dsl/tpe/uint.py rename to py_hcl/dsl/type/uint.py diff --git a/py_hcl/firrtl_ir/type_checker/expr/accessor.py b/py_hcl/firrtl_ir/type_checker/expr/accessor.py index 3b0d815..09cc60c 100644 --- a/py_hcl/firrtl_ir/type_checker/expr/accessor.py +++ b/py_hcl/firrtl_ir/type_checker/expr/accessor.py @@ -9,24 +9,22 @@ checker = dispatch - ############################################################### # TYPE CHECKERS # ############################################################### + @checker(SubField) def check(sub_field): from .. import check_all_expr if not check_all_expr(sub_field.bundle_ref): logging.error("sub_field: reference check failed - {}".format( - sub_field.bundle_ref - )) + sub_field.bundle_ref)) return False if not type_in(sub_field.bundle_ref.tpe, BundleType): logging.error("sub_field: reference type check failed - {}".format( - sub_field.bundle_ref - )) + sub_field.bundle_ref)) return False field = None @@ -36,8 +34,7 @@ def check(sub_field): break if field is None: logging.error("sub_field: field not exist - {} not in {}".format( - sub_field.name, sub_field.bundle_ref.tpe.fields - )) + sub_field.name, sub_field.bundle_ref.tpe.fields)) return False if not equal(field.tpe, sub_field.tpe): diff --git a/py_hcl/firrtl_ir/type_checker/expr/literal.py b/py_hcl/firrtl_ir/type_checker/expr/literal.py index f19bd37..4f21ad9 100644 --- a/py_hcl/firrtl_ir/type_checker/expr/literal.py +++ b/py_hcl/firrtl_ir/type_checker/expr/literal.py @@ -4,15 +4,15 @@ from ..utils import type_in from ...expr.literal import SIntLiteral, SIntType, UIntLiteral, UIntType -from py_hcl.utils import signed_num_bin_len, unsigned_num_bin_len +from py_hcl.utils import signed_num_bin_width, unsigned_num_bin_width checker = dispatch - ############################################################### # TYPE CHECKERS # ############################################################### + @checker(UIntLiteral) def check(uint: UIntLiteral): if not type_in(uint.tpe, UIntType): @@ -23,10 +23,9 @@ def check(uint: UIntLiteral): logging.error("uint: value check failed - {}".format(uint.value)) return False - if unsigned_num_bin_len(uint.value) > uint.tpe.width.width: + if unsigned_num_bin_width(uint.value) > uint.tpe.width.width: logging.error("uint: width check failed - {} > {}".format( - unsigned_num_bin_len(uint.value), uint.tpe.width.width) - ) + unsigned_num_bin_width(uint.value), uint.tpe.width.width)) return False return True @@ -38,10 +37,9 @@ def check(sint: SIntLiteral): logging.error("sint: type check failed - {}".format(sint.tpe)) return False - if signed_num_bin_len(sint.value) > sint.tpe.width.width: + if signed_num_bin_width(sint.value) > sint.tpe.width.width: logging.error("sint: width check failed - {} > {}".format( - signed_num_bin_len(sint.value), sint.tpe.width.width) - ) + signed_num_bin_width(sint.value), sint.tpe.width.width)) return False return True diff --git a/py_hcl/firrtl_ir/type_checker/expr/mux.py b/py_hcl/firrtl_ir/type_checker/expr/mux.py index fd8ddd9..6767871 100644 --- a/py_hcl/firrtl_ir/type_checker/expr/mux.py +++ b/py_hcl/firrtl_ir/type_checker/expr/mux.py @@ -6,11 +6,11 @@ checker = dispatch - ############################################################### # TYPE CHECKERS # ############################################################### + @checker(Mux) def check(mux: Mux): from .. import check_all_expr diff --git a/py_hcl/firrtl_ir/type_checker/expr/prim_ops.py b/py_hcl/firrtl_ir/type_checker/expr/prim_ops.py index 3a517c9..d26b227 100644 --- a/py_hcl/firrtl_ir/type_checker/expr/prim_ops.py +++ b/py_hcl/firrtl_ir/type_checker/expr/prim_ops.py @@ -10,20 +10,18 @@ checker = dispatch - ############################################################### # TYPE CHECKERS # ############################################################### + @checker(Add) def check(add: Add): from .. import check_all_expr if not check_all_expr(*add.args): return False - if not check_all_same_uint_sint(add.args[0].tpe, - add.args[1].tpe, - add.tpe): + if not check_all_same_uint_sint(add.args[0].tpe, add.args[1].tpe, add.tpe): return False expected_type_width = max(add.args[0].tpe.width.width, @@ -40,8 +38,7 @@ def check(sub: Sub): if not check_all_expr(*sub.args): return False - if not check_all_same_uint_sint(sub.args[0].tpe, - sub.args[1].tpe): + if not check_all_same_uint_sint(sub.args[0].tpe, sub.args[1].tpe): return False if not type_in(sub.tpe, SIntType): @@ -61,9 +58,7 @@ def check(mul: Mul): if not check_all_expr(*mul.args): return False - if not check_all_same_uint_sint(mul.args[0].tpe, - mul.args[1].tpe, - mul.tpe): + if not check_all_same_uint_sint(mul.args[0].tpe, mul.args[1].tpe, mul.tpe): return False expected_type_width = \ @@ -80,9 +75,7 @@ def check(div: Div): if not check_all_expr(*div.args): return False - if not check_all_same_uint_sint(div.args[0].tpe, - div.args[1].tpe, - div.tpe): + if not check_all_same_uint_sint(div.args[0].tpe, div.args[1].tpe, div.tpe): return False expected_type_width = div.args[0].tpe.width.width @@ -100,9 +93,7 @@ def check(rem: Rem): if not check_all_expr(*rem.args): return False - if not check_all_same_uint_sint(rem.args[0].tpe, - rem.args[1].tpe, - rem.tpe): + if not check_all_same_uint_sint(rem.args[0].tpe, rem.args[1].tpe, rem.tpe): return False expected_type_width = min(rem.args[0].tpe.width.width, @@ -197,8 +188,7 @@ def check(cat: Cat): if not check_all_expr(*cat.args): return False - if not check_all_same_uint_sint(cat.args[0].tpe, - cat.args[1].tpe): + if not check_all_same_uint_sint(cat.args[0].tpe, cat.args[1].tpe): return False if not type_in(cat.tpe, UIntType): @@ -322,8 +312,7 @@ def check(dshl: Dshl): if not check_all_expr(*dshl.args): return False - if not check_all_same_uint_sint(dshl.args[0].tpe, - dshl.tpe): + if not check_all_same_uint_sint(dshl.args[0].tpe, dshl.tpe): return False if not type_in(dshl.args[1].tpe, UIntType): @@ -343,8 +332,7 @@ def check(dshr: Dshr): if not check_all_expr(*dshr.args): return False - if not check_all_same_uint_sint(dshr.args[0].tpe, - dshr.tpe): + if not check_all_same_uint_sint(dshr.args[0].tpe, dshr.tpe): return False if not type_in(dshr.args[1].tpe, UIntType): diff --git a/py_hcl/firrtl_ir/type_checker/stmt/block.py b/py_hcl/firrtl_ir/type_checker/stmt/block.py index c296972..9111e2d 100644 --- a/py_hcl/firrtl_ir/type_checker/stmt/block.py +++ b/py_hcl/firrtl_ir/type_checker/stmt/block.py @@ -5,11 +5,11 @@ checker = dispatch - ############################################################### # TYPE CHECKERS # ############################################################### + @checker(Block) def check(block: Block): from .. import check_all_stmt diff --git a/py_hcl/firrtl_ir/type_checker/stmt/conditionally.py b/py_hcl/firrtl_ir/type_checker/stmt/conditionally.py index 9cc08de..bce9259 100644 --- a/py_hcl/firrtl_ir/type_checker/stmt/conditionally.py +++ b/py_hcl/firrtl_ir/type_checker/stmt/conditionally.py @@ -8,11 +8,11 @@ checker = dispatch - ############################################################### # TYPE CHECKERS # ############################################################### + @checker(Conditionally) def check(cond: Conditionally): from .. import check_all_expr, check_all_stmt diff --git a/py_hcl/firrtl_ir/type_checker/stmt/connect.py b/py_hcl/firrtl_ir/type_checker/stmt/connect.py index 23e3ce8..2974ae3 100644 --- a/py_hcl/firrtl_ir/type_checker/stmt/connect.py +++ b/py_hcl/firrtl_ir/type_checker/stmt/connect.py @@ -7,11 +7,11 @@ checker = dispatch - ############################################################### # TYPE CHECKERS # ############################################################### + @checker(Connect) def check(connect: Connect): from .. import check_all_expr @@ -21,8 +21,7 @@ def check(connect: Connect): if not equal(connect.loc_ref.tpe, connect.expr_ref.tpe): logging.error("connect: type unmatched - {} & {}".format( - connect.loc_ref.tpe, connect.expr_ref.tpe - )) + connect.loc_ref.tpe, connect.expr_ref.tpe)) return False return True diff --git a/py_hcl/firrtl_ir/type_checker/stmt/definition.py b/py_hcl/firrtl_ir/type_checker/stmt/definition.py index 14d6f20..d5a2cdf 100644 --- a/py_hcl/firrtl_ir/type_checker/stmt/definition.py +++ b/py_hcl/firrtl_ir/type_checker/stmt/definition.py @@ -14,11 +14,11 @@ checker = dispatch - ############################################################### # TYPE CHECKERS # ############################################################### + @checker(DefWire) def check(_: DefWire): return True diff --git a/py_hcl/transformer/pyhcl_to_firrtl/context.py b/py_hcl/transformer/pyhcl_to_firrtl/context.py deleted file mode 100644 index 170afd2..0000000 --- a/py_hcl/transformer/pyhcl_to_firrtl/context.py +++ /dev/null @@ -1,8 +0,0 @@ -from py_hcl.core.expr import ExprTable - - -class Context(object): - modules = {} - expr_id_to_name = {} - expr_obj_id_to_ref = {} - expr_table = ExprTable.table diff --git a/py_hcl/transformer/pyhcl_to_firrtl/conv_expr.py b/py_hcl/transformer/pyhcl_to_firrtl/conv_expr.py index 04b62ca..1265a9b 100644 --- a/py_hcl/transformer/pyhcl_to_firrtl/conv_expr.py +++ b/py_hcl/transformer/pyhcl_to_firrtl/conv_expr.py @@ -2,15 +2,15 @@ from multipledispatch import dispatch -from py_hcl.transformer.pyhcl_to_firrtl.context import Context +from py_hcl.transformer.pyhcl_to_firrtl.global_context import GlobalContext from py_hcl.transformer.pyhcl_to_firrtl.conv_port import ports_to_bundle_type from py_hcl.transformer.pyhcl_to_firrtl.conv_type import convert_type from py_hcl.transformer.pyhcl_to_firrtl.utils import build_io_name, get_io_obj from py_hcl.core.expr import ExprHolder from py_hcl.core.expr.add import Add as CAdd -from py_hcl.core.expr.and_ import And as CAnd +from py_hcl.core.expr.ands import And as CAnd from py_hcl.core.expr.xor import Xor as CXor -from py_hcl.core.expr.or_ import Or as COr +from py_hcl.core.expr.ors import Or as COr from py_hcl.core.expr.convert import ToSInt, ToUInt from py_hcl.core.expr.extend import Extend from py_hcl.core.expr.field import FieldAccess @@ -35,9 +35,9 @@ def convert_expr_by_id(expr_id: int): - obj = Context.expr_table[expr_id] - if id(obj) in Context.expr_obj_id_to_ref: - return [], Context.expr_obj_id_to_ref[id(obj)] + obj = GlobalContext.expr_table[expr_id] + if id(obj) in GlobalContext.expr_obj_id_to_ref: + return [], GlobalContext.expr_obj_id_to_ref[id(obj)] return convert_expr(obj) @@ -48,8 +48,8 @@ def convert_expr_op(expr_holder: ExprHolder, add: CAdd): r_stmts, r_ref = convert_expr_by_id(add.right_expr_id) name = NameGetter.get(expr_holder.id) typ = convert_type(expr_holder.hcl_type) - stmt, ref = save_node_ref(Add([l_ref, r_ref], typ), - name, typ, id(expr_holder)) + stmt, ref = save_node_ref(Add([l_ref, r_ref], typ), name, typ, + id(expr_holder)) return [*l_stmts, *r_stmts, stmt], ref @@ -59,8 +59,8 @@ def convert_expr_op(expr_holder: ExprHolder, and_: CAnd): r_stmts, r_ref = convert_expr_by_id(and_.right_expr_id) name = NameGetter.get(expr_holder.id) typ = convert_type(expr_holder.hcl_type) - stmt, ref = save_node_ref(And([l_ref, r_ref], typ), - name, typ, id(expr_holder)) + stmt, ref = save_node_ref(And([l_ref, r_ref], typ), name, typ, + id(expr_holder)) return [*l_stmts, *r_stmts, stmt], ref @@ -70,8 +70,8 @@ def convert_expr_op(expr_holder: ExprHolder, xor: CXor): r_stmts, r_ref = convert_expr_by_id(xor.right_expr_id) name = NameGetter.get(expr_holder.id) typ = convert_type(expr_holder.hcl_type) - stmt, ref = save_node_ref(Xor([l_ref, r_ref], typ), - name, typ, id(expr_holder)) + stmt, ref = save_node_ref(Xor([l_ref, r_ref], typ), name, typ, + id(expr_holder)) return [*l_stmts, *r_stmts, stmt], ref @@ -81,8 +81,8 @@ def convert_expr_op(expr_holder: ExprHolder, or_: COr): r_stmts, r_ref = convert_expr_by_id(or_.right_expr_id) name = NameGetter.get(expr_holder.id) typ = convert_type(expr_holder.hcl_type) - stmt, ref = save_node_ref(Or([l_ref, r_ref], typ), - name, typ, id(expr_holder)) + stmt, ref = save_node_ref(Or([l_ref, r_ref], typ), name, typ, + id(expr_holder)) return [*l_stmts, *r_stmts, stmt], ref @@ -91,8 +91,8 @@ def convert_expr_op(expr_holder: ExprHolder, bs: CBits): stmts, v_ref = convert_expr_by_id(bs.ref_expr_id) name = NameGetter.get(expr_holder.id) typ = convert_type(expr_holder.hcl_type) - stmt, ref = save_node_ref(Bits(v_ref, [bs.high, bs.low], typ), - name, typ, id(expr_holder)) + stmt, ref = save_node_ref(Bits(v_ref, [bs.high, bs.low], typ), name, typ, + id(expr_holder)) return [*stmts, stmt], ref @@ -101,8 +101,7 @@ def convert_expr_op(expr_holder: ExprHolder, ts: ToSInt): stmts, v_ref = convert_expr_by_id(ts.ref_expr_id) name = NameGetter.get(expr_holder.id) typ = convert_type(expr_holder.hcl_type) - stmt, ref = save_node_ref(AsSInt(v_ref, typ), - name, typ, id(expr_holder)) + stmt, ref = save_node_ref(AsSInt(v_ref, typ), name, typ, id(expr_holder)) return [*stmts, stmt], ref @@ -111,8 +110,7 @@ def convert_expr_op(expr_holder: ExprHolder, tu: ToUInt): stmts, v_ref = convert_expr_by_id(tu.ref_expr_id) name = NameGetter.get(expr_holder.id) typ = convert_type(expr_holder.hcl_type) - stmt, ref = save_node_ref(AsUInt(v_ref, typ), - name, typ, id(expr_holder)) + stmt, ref = save_node_ref(AsUInt(v_ref, typ), name, typ, id(expr_holder)) return [*stmts, stmt], ref @@ -122,7 +120,7 @@ def convert_expr_op(expr_holder: ExprHolder, et: Extend): typ = convert_type(expr_holder.hcl_type) ref = copy.copy(v_ref) ref.tpe = typ - Context.expr_obj_id_to_ref[id(expr_holder)] = ref + GlobalContext.expr_obj_id_to_ref[id(expr_holder)] = ref return stmts, ref @@ -130,27 +128,25 @@ def convert_expr_op(expr_holder: ExprHolder, et: Extend): def convert_expr_op(expr_holder: ExprHolder, vi: VecIndex): stmts, v_ref = convert_expr_by_id(vi.ref_expr_id) ref = SubIndex(v_ref, vi.index, convert_type(expr_holder.hcl_type)) - Context.expr_obj_id_to_ref[id(expr_holder)] = ref + GlobalContext.expr_obj_id_to_ref[id(expr_holder)] = ref return stmts, ref @dispatch() def convert_expr_op(expr_holder: ExprHolder, fa: FieldAccess): typ = convert_type(expr_holder.hcl_type) - obj = Context.expr_table[fa.ref_expr_id] + obj = GlobalContext.expr_table[fa.ref_expr_id] def fetch_current_io_holder(obj): - current_node = obj.io_chain_head - while True: - if fa.item in current_node.io_holder.named_ports: - return current_node.io_holder - current_node = current_node.next_node + for io_holder in obj.io_chain: + if fa.item in io_holder.named_ports: + return io_holder if isinstance(obj, IO): io_holder = fetch_current_io_holder(obj) name = build_io_name(io_holder.module_name, fa.item) ref = Reference(name, typ) - Context.expr_obj_id_to_ref[id(expr_holder)] = ref + GlobalContext.expr_obj_id_to_ref[id(expr_holder)] = ref return [], ref elif isinstance(obj, ModuleInst): @@ -158,13 +154,13 @@ def fetch_current_io_holder(obj): io_holder = fetch_current_io_holder(get_io_obj(obj.packed_module)) name = build_io_name(io_holder.module_name, fa.item) ref = SubField(b_ref, name, typ) - Context.expr_obj_id_to_ref[id(expr_holder)] = ref + GlobalContext.expr_obj_id_to_ref[id(expr_holder)] = ref return stmts, ref else: stmts, src_ref = convert_expr_by_id(fa.ref_expr_id) ref = SubField(src_ref, fa.item, typ) - Context.expr_obj_id_to_ref[id(expr_holder)] = ref + GlobalContext.expr_obj_id_to_ref[id(expr_holder)] = ref return stmts, ref @@ -176,14 +172,14 @@ def convert_expr(expr_holder: ExprHolder): @dispatch() def convert_expr(slit: SLiteral): ft = SIntLiteral(slit.value, convert_type(slit.hcl_type).width) - Context.expr_obj_id_to_ref[id(slit)] = ft + GlobalContext.expr_obj_id_to_ref[id(slit)] = ft return [], ft @dispatch() def convert_expr(ulit: ULiteral): ft = UIntLiteral(ulit.value, convert_type(ulit.hcl_type).width) - Context.expr_obj_id_to_ref[id(ulit)] = ft + GlobalContext.expr_obj_id_to_ref[id(ulit)] = ft return [], ft @@ -194,24 +190,26 @@ def convert_expr(wire: Wire): stmt = DefWire(name, typ) ref = Reference(name, typ) - Context.expr_obj_id_to_ref[id(wire)] = ref + GlobalContext.expr_obj_id_to_ref[id(wire)] = ref return [stmt], ref @dispatch() def convert_expr(mi: ModuleInst): - if mi.module_name not in Context.modules: + if mi.module_name not in GlobalContext.modules: from .conv_module import convert_module convert_module(mi.packed_module) - module = Context.modules[mi.module_name] + module = GlobalContext.modules[mi.module_name] name = NameGetter.get(mi.id) ref = Reference(name, ports_to_bundle_type(module.ports)) - stmts = [DefInstance(name, mi.module_name), - Connect(SubField(ref, 'clock', ClockType()), - Reference('clock', ClockType())), - Connect(SubField(ref, 'reset', UIntType(Width(1))), - Reference('reset', UIntType(Width(1))))] - Context.expr_obj_id_to_ref[id(mi)] = ref + stmts = [ + DefInstance(name, mi.module_name), + Connect(SubField(ref, 'clock', ClockType()), + Reference('clock', ClockType())), + Connect(SubField(ref, 'reset', UIntType(Width(1))), + Reference('reset', UIntType(Width(1)))) + ] + GlobalContext.expr_obj_id_to_ref[id(mi)] = ref return stmts, ref @@ -221,7 +219,7 @@ class NameGetter(object): @classmethod def get(cls, expr_id: int): try: - return Context.expr_id_to_name[expr_id] + return GlobalContext.expr_id_to_name[expr_id] except KeyError: cls.cnt += 1 return "_T_" + str(cls.cnt) @@ -230,5 +228,5 @@ def get(cls, expr_id: int): def save_node_ref(op_ir, name, tpe, obj_id): stmt = DefNode(name, op_ir) ref = Reference(name, tpe) - Context.expr_obj_id_to_ref[obj_id] = ref + GlobalContext.expr_obj_id_to_ref[obj_id] = ref return stmt, ref diff --git a/py_hcl/transformer/pyhcl_to_firrtl/conv_module.py b/py_hcl/transformer/pyhcl_to_firrtl/conv_module.py index b7145de..e599350 100644 --- a/py_hcl/transformer/pyhcl_to_firrtl/conv_module.py +++ b/py_hcl/transformer/pyhcl_to_firrtl/conv_module.py @@ -1,18 +1,20 @@ -from py_hcl.transformer.pyhcl_to_firrtl.context import Context +from typing import List + +from py_hcl.core.module_factory.inherit_chain.named_expr import NamedExprHolder +from py_hcl.core.module_factory.inherit_chain.stmt_holder import StmtHolder +from py_hcl.transformer.pyhcl_to_firrtl.global_context import GlobalContext from py_hcl.transformer.pyhcl_to_firrtl.conv_port import convert_ports from py_hcl.transformer.pyhcl_to_firrtl.conv_stmt import convert_stmt from py_hcl.transformer.pyhcl_to_firrtl.utils import build_reserve_name, \ build_io_name, get_io_obj from py_hcl.core.expr.io import IO from py_hcl.core.module.packed_module import PackedModule -from py_hcl.core.module_factory.inherit_list.named_expr import NamedExprChain -from py_hcl.core.module_factory.inherit_list.stmt_holder import StmtChain from py_hcl.firrtl_ir.stmt.block import Block from py_hcl.firrtl_ir.stmt.defn.module import DefModule def convert_module(packed_module: PackedModule): - Context.expr_id_to_name.update( + GlobalContext.expr_id_to_name.update( flatten_named_expr_chain(packed_module.named_expr_chain)) name = packed_module.name @@ -23,55 +25,35 @@ def convert_module(packed_module: PackedModule): final_stmts = [ss for s in stmts for ss in convert_stmt(s)] module = DefModule(name, ports, Block(final_stmts)) - Context.modules[name] = module + GlobalContext.modules[name] = module return module -def flatten_named_expr_chain(named_expr_chain: NamedExprChain): +def flatten_named_expr_chain(named_expr_chain: List[NamedExprHolder]): named_exprs = {} - node = named_expr_chain.named_expr_chain_head - while True: - holder = node.named_expr_holder + for holder in named_expr_chain: for k, v in holder.named_expression_table.items(): named_exprs[k] = build_reserve_name(holder.module_name, v) - if not hasattr(node, "next_node"): - break - node = node.next_node - return named_exprs -def flatten_statement_chain(statement_chain: StmtChain): +def flatten_statement_chain(statement_chain: List[StmtHolder]): stmts = [] - node = statement_chain.stmt_chain_head - while True: - holder = node.stmt_holder + for holder in statement_chain: for stmt in reversed(holder.top_statement.statements): stmts.append(stmt) - if not hasattr(node, "next_node"): - break - node = node.next_node - - return list(reversed(stmts)) + return stmts[::-1] def flatten_io_chain(io: IO): ports = {} - node = io.io_chain_head - - while True: - holder = node.io_holder - # reverse the dict order - for k in list(holder.named_ports.keys())[::-1]: - v = holder.named_ports[k] - ports[build_io_name(holder.module_name, k)] = v - if not hasattr(node, "next_node"): - break - node = node.next_node + for io_holder in io.io_chain: + for k, v in io_holder.named_ports.items(): + ports[build_io_name(io_holder.module_name, k)] = v - return {k: ports[k] for k in list(ports.keys())[::-1]} + return ports diff --git a/py_hcl/transformer/pyhcl_to_firrtl/conv_port.py b/py_hcl/transformer/pyhcl_to_firrtl/conv_port.py index a67d25c..dd3cad9 100644 --- a/py_hcl/transformer/pyhcl_to_firrtl/conv_port.py +++ b/py_hcl/transformer/pyhcl_to_firrtl/conv_port.py @@ -9,8 +9,10 @@ def convert_ports(raw_ports: Dict[str, Union[Input, Output]]): - ports = [InputPort('clock', ClockType()), - InputPort('reset', UIntType(Width(1)))] + ports = [ + InputPort('clock', ClockType()), + InputPort('reset', UIntType(Width(1))) + ] for k, v in raw_ports.items(): p = InputPort if v.port_dir == 'input' else OutputPort ports.append(p(k, convert_type(v.hcl_type))) diff --git a/py_hcl/transformer/pyhcl_to_firrtl/conv_type.py b/py_hcl/transformer/pyhcl_to_firrtl/conv_type.py index 93361e5..a9a1f12 100644 --- a/py_hcl/transformer/pyhcl_to_firrtl/conv_type.py +++ b/py_hcl/transformer/pyhcl_to_firrtl/conv_type.py @@ -1,7 +1,7 @@ from multipledispatch import dispatch from py_hcl.core.type import UnknownType as HclUnknownType -from py_hcl.core.type.bundle import BundleT, Dir +from py_hcl.core.type.bundle import BundleT, BundleDirection from py_hcl.core.type.clock import ClockT from py_hcl.core.type.sint import SIntT from py_hcl.core.type.uint import UIntT @@ -41,6 +41,7 @@ def convert_type(vec: VectorT): def convert_type(bundle: BundleT): fields = [] for k, v in bundle.fields.items(): - f = Field(k, convert_type(v['hcl_type']), v['dir'] == Dir.SINK) + f = Field(k, convert_type(v['hcl_type']), + v['dir'] == BundleDirection.SINK) fields.append(f) return BundleType(fields) diff --git a/py_hcl/transformer/pyhcl_to_firrtl/convertor.py b/py_hcl/transformer/pyhcl_to_firrtl/convertor.py index 24914ba..9e7e311 100644 --- a/py_hcl/transformer/pyhcl_to_firrtl/convertor.py +++ b/py_hcl/transformer/pyhcl_to_firrtl/convertor.py @@ -1,4 +1,4 @@ -from py_hcl.transformer.pyhcl_to_firrtl.context import Context +from py_hcl.transformer.pyhcl_to_firrtl.global_context import GlobalContext from py_hcl.transformer.pyhcl_to_firrtl.conv_module import convert_module from py_hcl.core.module.packed_module import PackedModule from py_hcl.firrtl_ir.stmt.defn.circuit import DefCircuit @@ -7,10 +7,10 @@ def convert(packed_module: PackedModule): convert_module(packed_module) - modules = list(Context.modules.values()) - Context.modules.clear() - Context.expr_obj_id_to_ref.clear() - Context.expr_id_to_name.clear() + modules = list(GlobalContext.modules.values()) + + GlobalContext.clear() + cir = DefCircuit(packed_module.name, modules) assert check(cir) return cir diff --git a/py_hcl/transformer/pyhcl_to_firrtl/global_context.py b/py_hcl/transformer/pyhcl_to_firrtl/global_context.py new file mode 100644 index 0000000..75a6b22 --- /dev/null +++ b/py_hcl/transformer/pyhcl_to_firrtl/global_context.py @@ -0,0 +1,14 @@ +from py_hcl.core.expr import ExprTable + + +class GlobalContext(object): + modules = {} + expr_id_to_name = {} + expr_obj_id_to_ref = {} + expr_table = ExprTable.table + + @staticmethod + def clear(): + GlobalContext.modules.clear() + GlobalContext.expr_id_to_name.clear() + GlobalContext.expr_obj_id_to_ref.clear() diff --git a/py_hcl/transformer/pyhcl_to_firrtl/utils.py b/py_hcl/transformer/pyhcl_to_firrtl/utils.py index 2aa4edf..71d2adb 100644 --- a/py_hcl/transformer/pyhcl_to_firrtl/utils.py +++ b/py_hcl/transformer/pyhcl_to_firrtl/utils.py @@ -1,4 +1,5 @@ -from py_hcl.transformer.pyhcl_to_firrtl.context import Context +from py_hcl.transformer.pyhcl_to_firrtl.global_context import GlobalContext +from py_hcl.utils import get_key_by_value def build_io_name(module_name: str, field_name: str): @@ -10,7 +11,6 @@ def build_reserve_name(module_name: str, expr_name: str): def get_io_obj(packed_module): - table = packed_module.named_expr_chain.named_expr_chain_head \ - .named_expr_holder.named_expression_table - io_id = list(table.keys())[list(table.values()).index('io')] - return Context.expr_table[io_id] + table = packed_module.named_expr_chain[0].named_expression_table + io_id = get_key_by_value(table, 'io') + return GlobalContext.expr_table[io_id] diff --git a/py_hcl/utils/__init__.py b/py_hcl/utils/__init__.py new file mode 100644 index 0000000..01f43e1 --- /dev/null +++ b/py_hcl/utils/__init__.py @@ -0,0 +1,78 @@ +def signed_num_bin_width(num: int): + """ + Returns least binary width to hold the specified signed `num`. + + Examples + -------- + + >>> signed_num_bin_width(10) + 5 + + >>> signed_num_bin_width(-1) + 2 + + >>> signed_num_bin_width(-2) + 3 + + >>> signed_num_bin_width(0) + 2 + """ + + return len("{:+b}".format(num)) + + +def unsigned_num_bin_width(num: int): + """ + Returns least binary width to hold the specified unsigned `num`. + + Examples + -------- + + >>> unsigned_num_bin_width(10) + 4 + + >>> unsigned_num_bin_width(1) + 1 + + >>> unsigned_num_bin_width(0) + 1 + + >>> unsigned_num_bin_width(-1) + Traceback (most recent call last): + ... + ValueError: Unexpected negative number: -1 + """ + + if num < 0: + raise ValueError(f"Unexpected negative number: {num}") + + return len("{:b}".format(num)) + + +def get_key_by_value(kvs: dict, value): + """ + Returns key associated to specified value from dictionary. + + Examples + -------- + + >>> get_key_by_value({1: 'a', 2: 'b'}, 'b') + 2 + + >>> get_key_by_value({'a': 1, 'b': 2}, 2) + 'b' + + >>> get_key_by_value({'a': 1, 'b': 1}, 1) + 'a' + + >>> get_key_by_value({1: 'a'}, 'b') + Traceback (most recent call last): + ... + ValueError: b is not in dict values + """ + + vs = list(kvs.values()) + if value not in vs: + raise ValueError(f"{value} is not in dict values") + + return list(kvs.keys())[vs.index(value)] diff --git a/py_hcl/error/__init__.py b/py_hcl/utils/error/__init__.py similarity index 100% rename from py_hcl/error/__init__.py rename to py_hcl/utils/error/__init__.py diff --git a/py_hcl/utils.py b/py_hcl/utils/serialization/__init__.py similarity index 52% rename from py_hcl/utils.py rename to py_hcl/utils/serialization/__init__.py index 8dc5697..a81e97e 100644 --- a/py_hcl/utils.py +++ b/py_hcl/utils/serialization/__init__.py @@ -2,16 +2,6 @@ from enum import Enum from functools import partial -from multipledispatch import dispatch - - -def signed_num_bin_len(num): - return len("{:+b}".format(num)) - - -def unsigned_num_bin_len(num): - return len("{:b}".format(num)) - def json_serialize(cls=None, json_fields=()): def rec(v): @@ -45,42 +35,3 @@ def js(self): return _(cls, json_fields) return partial(_, _json_fields=json_fields) - - -@dispatch() -def _fm(vd: dict): - ls = ['{}: {}'.format(k, _fm(v)) for k, v in vd.items()] - fs = _iter_repr(ls) - return '{%s}' % (''.join(fs)) - - -@dispatch() -def _fm(v: list): - ls = [_fm(a) for a in v] - fs = _iter_repr(ls) - return '[%s]' % (''.join(fs)) - - -@dispatch() -def _fm(v: tuple): - ls = [_fm(a) for a in v] - fs = _iter_repr(ls) - return '(%s)' % (''.join(fs)) - - -@dispatch() -def _fm(v: object): - return str(v) - - -def _iter_repr(ls): - if len(ls) <= 1: - fs = ''.join(ls) - else: - fs = ''.join(['\n {},'.format(_indent(l)) for l in ls]) + '\n' - return fs - - -def _indent(s: str) -> str: - s = s.split('\n') - return '\n '.join(s) diff --git a/requirements.txt b/requirements.txt index aedc790..72b8836 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ pytest-cov codecov multipledispatch twine +yapf diff --git a/setup.cfg b/setup.cfg index 23d70d9..5ac6df0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,8 +4,8 @@ description-file = README.md [aliases] test = pytest -[tool:pytest] -addopts = --capture=no --cov - [flake8] -ignore = F811 \ No newline at end of file +ignore = F811 + +[tool:pytest] +doctest_optionflags = NORMALIZE_WHITESPACE \ No newline at end of file diff --git a/tests/test_dsl/test_branch.py b/tests/test_dsl/test_branch.py index ddf310f..833a851 100644 --- a/tests/test_dsl/test_branch.py +++ b/tests/test_dsl/test_branch.py @@ -8,7 +8,7 @@ from py_hcl.dsl.expr.io import IO from py_hcl.dsl.expr.wire import Wire from py_hcl.dsl.module import Module -from py_hcl.dsl.tpe.uint import U +from py_hcl.dsl.type.uint import U def test_branch(): @@ -34,8 +34,7 @@ class A(Module): c <<= a + b c <<= a + b - s = A.packed_module.statement_chain \ - .stmt_chain_head.stmt_holder.top_statement.statements + s = A.packed_module.statement_chain[0].top_statement.statements assert len(s) == 2 si = ScopeManager.get_scope_info(s[0].scope_id) @@ -57,6 +56,7 @@ class A(Module): def test_branch_syntax_error1(): with pytest.raises(StatementError): + class A(Module): io = IO() a = Wire(U.w(8)) @@ -72,6 +72,7 @@ class A(Module): def test_branch_syntax_error2(): with pytest.raises(StatementError): + class A(Module): io = IO() a = Wire(U.w(8)) @@ -84,6 +85,7 @@ class A(Module): def test_branch_syntax_error3(): with pytest.raises(StatementError): + class A(Module): io = IO() a = Wire(U.w(8)) @@ -98,6 +100,7 @@ class A(Module): def test_branch_syntax_error4(): with pytest.raises(StatementError): + class A(Module): io = IO() a = Wire(U.w(8)) @@ -114,6 +117,7 @@ class A(Module): def test_branch_syntax_error5(): with pytest.raises(StatementError): + class A(Module): io = IO() a = Wire(U.w(8)) diff --git a/tests/test_dsl/test_int.py b/tests/test_dsl/test_int.py new file mode 100644 index 0000000..6f1f091 --- /dev/null +++ b/tests/test_dsl/test_int.py @@ -0,0 +1,54 @@ +import pytest + +from py_hcl import Module, IO, U, S, Input, Output +from py_hcl.core.expr.error import ExprError +from py_hcl.core.type.error import TypeError + + +def test_uint_width_too_small(): + with pytest.raises(TypeError, match='Specified size is invalid'): + + class A(Module): + io = IO(i=Input(U.w(0))) + + with pytest.raises(TypeError, match='Specified size is invalid'): + + class A(Module): + io = IO(i=Input(U.w(-42))) + + +def test_sint_width_too_small(): + with pytest.raises(TypeError, match='Specified size is invalid'): + + class A(Module): + io = IO(i=Input(S.w(0))) + + with pytest.raises(TypeError, match='Specified size is invalid'): + + class A(Module): + io = IO(i=Input(S.w(1))) + + with pytest.raises(TypeError, match='Specified size is invalid'): + + class A(Module): + io = IO(i=Input(S.w(-42))) + + +def test_uint_lit_out_of_range(): + with pytest.raises( + ExprError, + match='Specified value out of range for the given type'): + + class A(Module): + io = IO(o=Output(U.w(1))) + io.o <<= U.w(1)(42) + + +def test_sint_lit_out_of_range(): + with pytest.raises( + ExprError, + match='Specified value out of range for the given type'): + + class A(Module): + io = IO(o=Output(S.w(2))) + io.o <<= S.w(2)(42) diff --git a/tests/test_dsl/test_io.py b/tests/test_dsl/test_io.py index 4cf5832..a037f54 100644 --- a/tests/test_dsl/test_io.py +++ b/tests/test_dsl/test_io.py @@ -6,21 +6,22 @@ from py_hcl.dsl.expr.io import IO, Input, Output, io_extend from py_hcl.dsl.module import Module from py_hcl.core.type import HclType -from py_hcl.dsl.tpe.uint import U +from py_hcl.dsl.type.uint import U +from py_hcl.utils import get_key_by_value class A(Module): io = IO( i=Input(U.w(8)), - o=Output(U.w(8))) + o=Output(U.w(8)), + ) io.o <<= io.i def test_io(): - table = A.packed_module.named_expr_chain.named_expr_chain_head \ - .named_expr_holder.named_expression_table - id = list(table.keys())[list(table.values()).index('io')] + table = A.packed_module.named_expr_chain[0].named_expression_table + id = get_key_by_value(table, 'io') t = ExprTable.table[id].hcl_type assert isinstance(t, BundleT) assert len(t.fields) == 2 @@ -28,14 +29,11 @@ def test_io(): def test_io_inherit_basis(): class B(A): - io = io_extend(A)( - i1=Input(U.w(9)), - ) + io = io_extend(A)(i1=Input(U.w(9)), ) io.o <<= io.i1 - table = B.packed_module.named_expr_chain.named_expr_chain_head \ - .named_expr_holder.named_expression_table - id = list(table.keys())[list(table.values()).index('io')] + table = B.packed_module.named_expr_chain[0].named_expression_table + id = get_key_by_value(table, 'io') t = ExprTable.table[id].hcl_type assert isinstance(t, BundleT) assert len(t.fields) == 3 @@ -43,14 +41,11 @@ class B(A): def test_io_inherit_override(): class B(A): - io = io_extend(A)( - i=Input(U.w(9)), - ) + io = io_extend(A)(i=Input(U.w(9)), ) io.o <<= io.i - table = B.packed_module.named_expr_chain.named_expr_chain_head \ - .named_expr_holder.named_expression_table - id = list(table.keys())[list(table.values()).index('io')] + table = B.packed_module.named_expr_chain[0].named_expression_table + id = get_key_by_value(table, 'io') t = ExprTable.table[id].hcl_type assert isinstance(t, BundleT) assert len(t.fields) == 2 @@ -58,17 +53,16 @@ class B(A): def test_io_no_wrap_io(): with pytest.raises(ExprError, match='^.*Input.*Output.*$'): + class A(Module): io = IO(i=HclType()) with pytest.raises(ExprError, match='^.*Input.*Output.*$'): + class A(Module): - io = IO( - i=HclType(), - o=Output(HclType())) + io = IO(i=HclType(), o=Output(HclType())) with pytest.raises(ExprError, match='^.*Input.*Output.*$'): + class A(Module): - io = IO( - i=Input(HclType()), - o=HclType()) + io = IO(i=Input(HclType()), o=HclType()) diff --git a/tests/test_dsl/test_module.py b/tests/test_dsl/test_module.py index 9886eed..94ae3dd 100644 --- a/tests/test_dsl/test_module.py +++ b/tests/test_dsl/test_module.py @@ -12,11 +12,11 @@ class A(Module): a = HclExpr() assert hasattr(A, "packed_module") - assert len(A.packed_module.named_expr_chain.named_expr_chain_head - .named_expr_holder.named_expression_table) == 2 + assert len(A.packed_module.named_expr_chain[0].named_expression_table) == 2 def test_module_not_contains_io(): with pytest.raises(ModuleError, match='^.*lack of io.*$'): + class A(Module): b = HclExpr() diff --git a/tests/test_dsl/test_statement.py b/tests/test_dsl/test_statement.py index d3d7b6d..63093a3 100644 --- a/tests/test_dsl/test_statement.py +++ b/tests/test_dsl/test_statement.py @@ -2,7 +2,7 @@ from py_hcl.dsl.expr.io import IO from py_hcl.dsl.expr.wire import Wire from py_hcl.dsl.module import Module -from py_hcl.dsl.tpe.uint import U +from py_hcl.dsl.type.uint import U def test_statement(): @@ -15,7 +15,6 @@ class A(Module): c <<= a + b - s = A.packed_module.statement_chain \ - .stmt_chain_head.stmt_holder.top_statement.statements + s = A.packed_module.statement_chain[0].top_statement.statements assert len(s) == 1 assert isinstance(s[0].statement, Connect) diff --git a/tests/test_dsl/test_vector.py b/tests/test_dsl/test_vector.py new file mode 100644 index 0000000..dad7995 --- /dev/null +++ b/tests/test_dsl/test_vector.py @@ -0,0 +1,16 @@ +import pytest + +from py_hcl import Module, IO, U, Input +from py_hcl.core.type.error import TypeError + + +def test_vec_size_too_small(): + with pytest.raises(TypeError, match='Specified size is invalid'): + + class A(Module): + io = IO(i=Input(U.w(8)[0])) + + with pytest.raises(TypeError, match='Specified size is invalid'): + + class A(Module): + io = IO(i=Input(U.w(8)[-42])) diff --git a/tests/test_firrtl_ir/test_expr/test_accessor.py b/tests/test_firrtl_ir/test_expr/test_accessor.py index 6e44238..4b340ba 100644 --- a/tests/test_firrtl_ir/test_expr/test_accessor.py +++ b/tests/test_firrtl_ir/test_expr/test_accessor.py @@ -100,8 +100,8 @@ def test_sub_access_non_vector(): sa = SubAccess(n("vc", uw(8)), u(2, w(3)), uw(8)) assert not check(sa) - sa = SubAccess(n("vc", bdl(a=(vec(uw(8), 10), True))), - u(2, w(3)), vec(uw(8), 10)) + sa = SubAccess(n("vc", bdl(a=(vec(uw(8), 10), True))), u(2, w(3)), + vec(uw(8), 10)) assert not check(sa) diff --git a/tests/test_firrtl_ir/test_expr/test_mux.py b/tests/test_firrtl_ir/test_expr/test_mux.py index 256255c..5b06cb7 100644 --- a/tests/test_firrtl_ir/test_expr/test_mux.py +++ b/tests/test_firrtl_ir/test_expr/test_mux.py @@ -9,9 +9,7 @@ def test_mux_basis(): assert check(mux) serialize_equal(mux, "mux(c, a, b)") - mux = Mux(u(1, w(1)), - n("b", vec(sw(8), 10)), - n("c", vec(sw(8), 10)), + mux = Mux(u(1, w(1)), n("b", vec(sw(8), 10)), n("c", vec(sw(8), 10)), vec(sw(8), 10)) assert check(mux) serialize_equal(mux, 'mux(UInt<1>("h1"), b, c)') @@ -28,7 +26,7 @@ def test_mux_cond_type_wrong(): assert not check(mux) -def test_mux_tf_value_type_wrong(): +def test_mux_tf_variable_type_wrong(): mux = Mux(n("c", uw(1)), n("a", uw(7)), n("b", uw(8)), uw(8)) assert not check(mux) diff --git a/tests/test_firrtl_ir/test_prim_ops/helper.py b/tests/test_firrtl_ir/test_prim_ops/helper.py index c314d8a..3504bda 100644 --- a/tests/test_firrtl_ir/test_prim_ops/helper.py +++ b/tests/test_firrtl_ir/test_prim_ops/helper.py @@ -6,7 +6,7 @@ UnknownType, BundleType, VectorType, ClockType from py_hcl.firrtl_ir.type.field import Field from py_hcl.firrtl_ir.type_checker import check -from py_hcl.utils import signed_num_bin_len, unsigned_num_bin_len +from py_hcl.utils import signed_num_bin_width, unsigned_num_bin_width class OpCase(object): @@ -45,7 +45,7 @@ def name_gen(): def u_gen(): if random.randint(0, 1): rand_u_value = random.randint(0, 1024) - rand_u_value_width = unsigned_num_bin_len(rand_u_value) + rand_u_value_width = unsigned_num_bin_width(rand_u_value) rand_u_width = random.randint(rand_u_value_width, 2 * rand_u_value_width) return u(rand_u_value, w(rand_u_width)) @@ -58,7 +58,7 @@ def u_gen(): def s_gen(): if random.randint(0, 1): rand_s_value = random.randint(-1024, 1024) - rand_s_value_width = signed_num_bin_len(rand_s_value) + rand_s_value_width = signed_num_bin_width(rand_s_value) rand_s_width = random.randint(rand_s_value_width, 2 * rand_s_value_width) return s(rand_s_value, w(rand_s_width)) diff --git a/tests/test_firrtl_ir/test_prim_ops/test_assint.py b/tests/test_firrtl_ir/test_prim_ops/test_assint.py index a1e5a87..b557ba3 100644 --- a/tests/test_firrtl_ir/test_prim_ops/test_assint.py +++ b/tests/test_firrtl_ir/test_prim_ops/test_assint.py @@ -40,9 +40,6 @@ def test_assint(): basis_tester(assint_basis_cases) encounter_error_tester(assint_type_wrong_cases) encounter_error_tester(assint_width_wrong_cases) - serialize_equal(AsSInt(u(20, w(5)), sw(5)), - 'asSInt(UInt<5>("h14"))') - serialize_equal(AsSInt(s(-20, w(6)), sw(5)), - 'asSInt(SInt<6>("h-14"))') - serialize_equal(AsSInt(n("clock", ClockType()), sw(1)), - 'asSInt(clock)') + serialize_equal(AsSInt(u(20, w(5)), sw(5)), 'asSInt(UInt<5>("h14"))') + serialize_equal(AsSInt(s(-20, w(6)), sw(5)), 'asSInt(SInt<6>("h-14"))') + serialize_equal(AsSInt(n("clock", ClockType()), sw(1)), 'asSInt(clock)') diff --git a/tests/test_firrtl_ir/test_prim_ops/test_asuint.py b/tests/test_firrtl_ir/test_prim_ops/test_asuint.py index 8119a95..55fdd16 100644 --- a/tests/test_firrtl_ir/test_prim_ops/test_asuint.py +++ b/tests/test_firrtl_ir/test_prim_ops/test_asuint.py @@ -40,9 +40,6 @@ def test_asuint(): basis_tester(asuint_basis_cases) encounter_error_tester(asuint_type_wrong_cases) encounter_error_tester(asuint_width_wrong_cases) - serialize_equal(AsUInt(u(20, w(5)), uw(5)), - 'asUInt(UInt<5>("h14"))') - serialize_equal(AsUInt(s(-20, w(6)), uw(5)), - 'asUInt(SInt<6>("h-14"))') - serialize_equal(AsUInt(n("clock", ClockType()), uw(1)), - 'asUInt(clock)') + serialize_equal(AsUInt(u(20, w(5)), uw(5)), 'asUInt(UInt<5>("h14"))') + serialize_equal(AsUInt(s(-20, w(6)), uw(5)), 'asUInt(SInt<6>("h-14"))') + serialize_equal(AsUInt(n("clock", ClockType()), uw(1)), 'asUInt(clock)') diff --git a/tests/test_firrtl_ir/test_prim_ops/test_bits.py b/tests/test_firrtl_ir/test_prim_ops/test_bits.py index efbb0cd..b6382fb 100644 --- a/tests/test_firrtl_ir/test_prim_ops/test_bits.py +++ b/tests/test_firrtl_ir/test_prim_ops/test_bits.py @@ -31,54 +31,49 @@ def tpe(res_type): bits_basis_cases = [ - args(UIntType).const(int, int).filter( - lambda u, a, b: width(u) > a >= b >= 0).tpe( - lambda u, a, b: uw(a - b + 1)), - args(SIntType).const(int, int).filter( - lambda u, a, b: width(u) > a >= b >= 0).tpe( - lambda u, a, b: uw(a - b + 1)), + args(UIntType).const( + int, int).filter(lambda u, a, b: width(u) > a >= b >= 0).tpe( + lambda u, a, b: uw(a - b + 1)), + args(SIntType).const( + int, int).filter(lambda u, a, b: width(u) > a >= b >= 0).tpe( + lambda u, a, b: uw(a - b + 1)), ] bits_type_wrong_cases = [ - args(UnknownType).const(int, int).filter( - lambda u, a, b: a >= b >= 0).tpe( + args(UnknownType).const(int, int).filter(lambda u, a, b: a >= b >= 0).tpe( lambda u, a, b: uw(a - b + 1)), - args(VectorType).const(int, int).filter( - lambda u, a, b: a >= b >= 0).tpe( + args(VectorType).const(int, int).filter(lambda u, a, b: a >= b >= 0).tpe( lambda u, a, b: uw(a - b + 1)), - args(BundleType).const(int, int).filter( - lambda u, a, b: a >= b >= 0).tpe( + args(BundleType).const(int, int).filter(lambda u, a, b: a >= b >= 0).tpe( lambda u, a, b: uw(a - b + 1)), ] bits_width_wrong_cases = [ + args(UIntType).const( + int, int).filter(lambda u, a, b: width(u) > a >= b >= 0).tpe( + lambda u, a, b: uw(a - b + 2)), + args(SIntType).const( + int, int).filter(lambda u, a, b: width(u) > a >= b >= 0).tpe( + lambda u, a, b: uw(a - b + 2)), args(UIntType).const(int, int).filter( - lambda u, a, b: width(u) > a >= b >= 0).tpe( - lambda u, a, b: uw(a - b + 2)), - args(SIntType).const(int, int).filter( - lambda u, a, b: width(u) > a >= b >= 0).tpe( - lambda u, a, b: uw(a - b + 2)), - args(UIntType).const(int, int).filter( - lambda u, a, b: width(u) > a >= b >= 0).tpe( - lambda u, a, b: uw(a - b)), + lambda u, a, b: width(u) > a >= b >= 0).tpe(lambda u, a, b: uw(a - b)), args(SIntType).const(int, int).filter( - lambda u, a, b: width(u) > a >= b >= 0).tpe( - lambda u, a, b: uw(a - b)), + lambda u, a, b: width(u) > a >= b >= 0).tpe(lambda u, a, b: uw(a - b)), ] bits_invalid_cases = [ - args(UIntType).const(int, int).filter( - lambda u, a, b: width(u) <= a and a >= b >= 0).tpe( - lambda u, a, b: uw(a - b + 1)), - args(SIntType).const(int, int).filter( - lambda u, a, b: width(u) <= a and a >= b >= 0).tpe( - lambda u, a, b: uw(a - b + 1)), - args(UIntType).const(int, int).filter( - lambda u, a, b: width(u) > a and b > a >= 0).tpe( - lambda u, a, b: uw(b - a + 1)), - args(SIntType).const(int, int).filter( - lambda u, a, b: width(u) > a and b > a >= 0).tpe( - lambda u, a, b: uw(b - a + 1)), + args(UIntType).const( + int, int).filter(lambda u, a, b: width(u) <= a and a >= b >= 0).tpe( + lambda u, a, b: uw(a - b + 1)), + args(SIntType).const( + int, int).filter(lambda u, a, b: width(u) <= a and a >= b >= 0).tpe( + lambda u, a, b: uw(a - b + 1)), + args(UIntType).const( + int, int).filter(lambda u, a, b: width(u) > a and b > a >= 0).tpe( + lambda u, a, b: uw(b - a + 1)), + args(SIntType).const( + int, int).filter(lambda u, a, b: width(u) > a and b > a >= 0).tpe( + lambda u, a, b: uw(b - a + 1)), ] diff --git a/tests/test_firrtl_ir/test_prim_ops/test_dshl.py b/tests/test_firrtl_ir/test_prim_ops/test_dshl.py index 34dd68d..10bbfd6 100644 --- a/tests/test_firrtl_ir/test_prim_ops/test_dshl.py +++ b/tests/test_firrtl_ir/test_prim_ops/test_dshl.py @@ -16,23 +16,21 @@ def tpe(res_type): dshl_basis_cases = [ - args(UIntType, UIntType).tpe( - lambda x, y: uw(width(x) + pow(2, width(y)) - 1)), - args(SIntType, UIntType).tpe( - lambda x, y: sw(width(x) + pow(2, width(y)) - 1)), + args(UIntType, + UIntType).tpe(lambda x, y: uw(width(x) + pow(2, width(y)) - 1)), + args(SIntType, + UIntType).tpe(lambda x, y: sw(width(x) + pow(2, width(y)) - 1)), ] dshl_type_wrong_cases = type_wrong_cases_2_args_gen(Dshl) dshl_width_wrong_cases = [ - args(UIntType, UIntType).tpe( - lambda x, y: uw(width(x) + pow(2, width(y)))), - args(SIntType, UIntType).tpe( - lambda x, y: sw(width(x) + pow(2, width(y)))), - args(UIntType, UIntType).tpe( - lambda x, y: uw(width(x) + pow(2, width(y)) - 2)), - args(SIntType, UIntType).tpe( - lambda x, y: sw(width(x) + pow(2, width(y)) - 2)), + args(UIntType, UIntType).tpe(lambda x, y: uw(width(x) + pow(2, width(y)))), + args(SIntType, UIntType).tpe(lambda x, y: sw(width(x) + pow(2, width(y)))), + args(UIntType, + UIntType).tpe(lambda x, y: uw(width(x) + pow(2, width(y)) - 2)), + args(SIntType, + UIntType).tpe(lambda x, y: sw(width(x) + pow(2, width(y)) - 2)), ] diff --git a/tests/test_firrtl_ir/test_prim_ops/test_dshr.py b/tests/test_firrtl_ir/test_prim_ops/test_dshr.py index 203b95d..65fb29f 100644 --- a/tests/test_firrtl_ir/test_prim_ops/test_dshr.py +++ b/tests/test_firrtl_ir/test_prim_ops/test_dshr.py @@ -16,23 +16,17 @@ def tpe(res_type): dshr_basis_cases = [ - args(UIntType, UIntType).tpe( - lambda x, y: uw(width(x))), - args(SIntType, UIntType).tpe( - lambda x, y: sw(width(x))), + args(UIntType, UIntType).tpe(lambda x, y: uw(width(x))), + args(SIntType, UIntType).tpe(lambda x, y: sw(width(x))), ] dshr_type_wrong_cases = type_wrong_cases_2_args_gen(Dshr) dshr_width_wrong_cases = [ - args(UIntType, UIntType).tpe( - lambda x, y: uw(width(x) + 1)), - args(SIntType, UIntType).tpe( - lambda x, y: sw(width(x) + 1)), - args(UIntType, UIntType).tpe( - lambda x, y: uw(width(x) - 1)), - args(SIntType, UIntType).tpe( - lambda x, y: sw(width(x) - 1)), + args(UIntType, UIntType).tpe(lambda x, y: uw(width(x) + 1)), + args(SIntType, UIntType).tpe(lambda x, y: sw(width(x) + 1)), + args(UIntType, UIntType).tpe(lambda x, y: uw(width(x) - 1)), + args(SIntType, UIntType).tpe(lambda x, y: sw(width(x) - 1)), ] diff --git a/tests/test_firrtl_ir/test_prim_ops/test_neg.py b/tests/test_firrtl_ir/test_prim_ops/test_neg.py index 8154891..f3a334a 100644 --- a/tests/test_firrtl_ir/test_prim_ops/test_neg.py +++ b/tests/test_firrtl_ir/test_prim_ops/test_neg.py @@ -39,7 +39,5 @@ def test_neg(): basis_tester(neg_basis_cases) encounter_error_tester(neg_type_wrong_cases) encounter_error_tester(neg_width_wrong_cases) - serialize_equal(Neg(u(20, w(5)), sw(6)), - 'neg(UInt<5>("h14"))') - serialize_equal(Neg(s(-20, w(6)), sw(7)), - 'neg(SInt<6>("h-14"))') + serialize_equal(Neg(u(20, w(5)), sw(6)), 'neg(UInt<5>("h14"))') + serialize_equal(Neg(s(-20, w(6)), sw(7)), 'neg(SInt<6>("h-14"))') diff --git a/tests/test_firrtl_ir/test_prim_ops/test_not.py b/tests/test_firrtl_ir/test_prim_ops/test_not.py index 44307f2..3701f69 100644 --- a/tests/test_firrtl_ir/test_prim_ops/test_not.py +++ b/tests/test_firrtl_ir/test_prim_ops/test_not.py @@ -37,7 +37,5 @@ def test_not(): basis_tester(not_basis_cases) encounter_error_tester(not_type_wrong_cases) encounter_error_tester(not_width_wrong_cases) - serialize_equal(Not(u(20, w(5)), uw(5)), - 'not(UInt<5>("h14"))') - serialize_equal(Not(s(-20, w(6)), uw(6)), - 'not(SInt<6>("h-14"))') + serialize_equal(Not(u(20, w(5)), uw(5)), 'not(UInt<5>("h14"))') + serialize_equal(Not(s(-20, w(6)), uw(6)), 'not(SInt<6>("h-14"))') diff --git a/tests/test_firrtl_ir/test_prim_ops/test_shl.py b/tests/test_firrtl_ir/test_prim_ops/test_shl.py index 9c48104..6c20a06 100644 --- a/tests/test_firrtl_ir/test_prim_ops/test_shl.py +++ b/tests/test_firrtl_ir/test_prim_ops/test_shl.py @@ -46,7 +46,5 @@ def test_shl(): basis_tester(shl_basis_cases) encounter_error_tester(shl_type_wrong_cases) encounter_error_tester(shl_width_wrong_cases) - serialize_equal(Shl(u(20, w(5)), 6, uw(11)), - 'shl(UInt<5>("h14"), 6)') - serialize_equal(Shl(s(-20, w(6)), 6, sw(12)), - 'shl(SInt<6>("h-14"), 6)') + serialize_equal(Shl(u(20, w(5)), 6, uw(11)), 'shl(UInt<5>("h14"), 6)') + serialize_equal(Shl(s(-20, w(6)), 6, sw(12)), 'shl(SInt<6>("h-14"), 6)') diff --git a/tests/test_firrtl_ir/test_prim_ops/test_shr.py b/tests/test_firrtl_ir/test_prim_ops/test_shr.py index ea8a72d..a9d339f 100644 --- a/tests/test_firrtl_ir/test_prim_ops/test_shr.py +++ b/tests/test_firrtl_ir/test_prim_ops/test_shr.py @@ -24,8 +24,10 @@ def tpe(res_type): shr_basis_cases = [ - args(UIntType).const(int).tpe(lambda x, y: uw(max(1, width(x) - y))), - args(SIntType).const(int).tpe(lambda x, y: sw(max(1, width(x) - y))), + args(UIntType).const(int).tpe(lambda x, y: uw(max(1, + width(x) - y))), + args(SIntType).const(int).tpe(lambda x, y: sw(max(1, + width(x) - y))), ] shr_type_wrong_cases = [ @@ -35,10 +37,14 @@ def tpe(res_type): ] shr_width_wrong_cases = [ - args(UIntType).const(int).tpe(lambda x, y: uw(max(1, width(x) - y) + 1)), - args(SIntType).const(int).tpe(lambda x, y: sw(max(1, width(x) - y) + 1)), - args(UIntType).const(int).tpe(lambda x, y: uw(max(1, width(x) - y) - 1)), - args(SIntType).const(int).tpe(lambda x, y: sw(max(1, width(x) - y) - 1)), + args(UIntType).const(int).tpe(lambda x, y: uw(max(1, + width(x) - y) + 1)), + args(SIntType).const(int).tpe(lambda x, y: sw(max(1, + width(x) - y) + 1)), + args(UIntType).const(int).tpe(lambda x, y: uw(max(1, + width(x) - y) - 1)), + args(SIntType).const(int).tpe(lambda x, y: sw(max(1, + width(x) - y) - 1)), ] @@ -46,7 +52,5 @@ def test_shr(): basis_tester(shr_basis_cases) encounter_error_tester(shr_type_wrong_cases) encounter_error_tester(shr_width_wrong_cases) - serialize_equal(Shr(u(20, w(5)), 3, uw(2)), - 'shr(UInt<5>("h14"), 3)') - serialize_equal(Shr(s(-20, w(6)), 3, uw(3)), - 'shr(SInt<6>("h-14"), 3)') + serialize_equal(Shr(u(20, w(5)), 3, uw(2)), 'shr(UInt<5>("h14"), 3)') + serialize_equal(Shr(s(-20, w(6)), 3, uw(3)), 'shr(SInt<6>("h-14"), 3)') diff --git a/tests/test_firrtl_ir/test_stmt/test_block.py b/tests/test_firrtl_ir/test_stmt/test_block.py index 652a347..fab6a30 100644 --- a/tests/test_firrtl_ir/test_stmt/test_block.py +++ b/tests/test_firrtl_ir/test_stmt/test_block.py @@ -13,17 +13,18 @@ def test_block_basis(): assert check(blk) serialize_stmt_equal(blk, "skip") - blk = Block([DefNode("n", u(1, w(1))), - Conditionally(n("n", uw(1)), - EmptyStmt(), - Connect(n("a", uw(8)), n("b", uw(8)))) - ]) + blk = Block([ + DefNode("n", u(1, w(1))), + Conditionally(n("n", uw(1)), EmptyStmt(), + Connect(n("a", uw(8)), n("b", uw(8)))) + ]) assert check(blk) - serialize_stmt_equal(blk, 'node n = UInt<1>("h1")\n' - 'when n :\n' - ' skip\n' - 'else :\n' - ' a <= b') + serialize_stmt_equal( + blk, 'node n = UInt<1>("h1")\n' + 'when n :\n' + ' skip\n' + 'else :\n' + ' a <= b') def test_block_empty(): diff --git a/tests/test_firrtl_ir/test_stmt/test_conditionally.py b/tests/test_firrtl_ir/test_stmt/test_conditionally.py index 552af9d..ec12972 100644 --- a/tests/test_firrtl_ir/test_stmt/test_conditionally.py +++ b/tests/test_firrtl_ir/test_stmt/test_conditionally.py @@ -12,10 +12,7 @@ def test_conditionally_basis(): s2 = Connect(n("a", uw(8)), n("b", uw(8))) cn = Conditionally(n("a", uw(1)), s1, s2) assert check(cn) - serialize_stmt_equal(cn, "when a :\n" - " skip\n" - "else :\n" - " a <= b") + serialize_stmt_equal(cn, "when a :\n" " skip\n" "else :\n" " a <= b") s1 = Block([ Connect(n("a", uw(8)), n("b", uw(8))), @@ -26,10 +23,10 @@ def test_conditionally_basis(): assert check(cn) serialize_stmt_equal( cn, 'when UInt<1>("h1") :\n' - ' a <= b\n' - ' c <= d\n' - 'else :\n' - ' skip') + ' a <= b\n' + ' c <= d\n' + 'else :\n' + ' skip') def test_conditionally_type_wrong(): diff --git a/tests/test_firrtl_ir/test_stmt/test_def/test_circuit.py b/tests/test_firrtl_ir/test_stmt/test_def/test_circuit.py index 8012d50..a50d26f 100644 --- a/tests/test_firrtl_ir/test_stmt/test_def/test_circuit.py +++ b/tests/test_firrtl_ir/test_stmt/test_def/test_circuit.py @@ -14,44 +14,47 @@ def test_circuit_basis(): m1 = DefModule("m1", [OutputPort("p", uw(8))], Connect(n("p", uw(8)), u(2, w(8)))) - m2 = DefModule("m2", [InputPort("b", uw(8)), - OutputPort("a", uw(8))], - Block([DefNode("n", u(1, w(1))), - Conditionally(n("n", uw(1)), - EmptyStmt(), - Connect(n("a", uw(8)), n("b", uw(8))))] - )) + m2 = DefModule( + "m2", + [InputPort("b", uw(8)), OutputPort("a", uw(8))], + Block([ + DefNode("n", u(1, w(1))), + Conditionally(n("n", uw(1)), EmptyStmt(), + Connect(n("a", uw(8)), n("b", uw(8)))) + ])) ct = DefCircuit("m1", [m1, m2]) assert check(ct) - serialize_stmt_equal(ct, 'circuit m1 :\n' - ' module m1 :\n' - ' output p : UInt<8>\n' - '\n' - ' p <= UInt<8>("h2")\n' - '\n' - ' module m2 :\n' - ' input b : UInt<8>\n' - ' output a : UInt<8>\n' - '\n' - ' node n = UInt<1>("h1")\n' - ' when n :\n' - ' skip\n' - ' else :\n' - ' a <= b\n' - '\n') + serialize_stmt_equal( + ct, 'circuit m1 :\n' + ' module m1 :\n' + ' output p : UInt<8>\n' + '\n' + ' p <= UInt<8>("h2")\n' + '\n' + ' module m2 :\n' + ' input b : UInt<8>\n' + ' output a : UInt<8>\n' + '\n' + ' node n = UInt<1>("h1")\n' + ' when n :\n' + ' skip\n' + ' else :\n' + ' a <= b\n' + '\n') def test_circuit_module_not_exist(): m1 = DefModule("m1", [OutputPort("p", uw(8))], Connect(n("p", uw(8)), u(2, w(8)))) - m2 = DefModule("m2", [InputPort("b", uw(8)), - OutputPort("a", uw(8))], - Block([DefNode("n", u(1, w(1))), - Conditionally(n("n", uw(1)), - EmptyStmt(), - Connect(n("a", uw(8)), n("b", uw(8))))] - )) + m2 = DefModule( + "m2", + [InputPort("b", uw(8)), OutputPort("a", uw(8))], + Block([ + DefNode("n", u(1, w(1))), + Conditionally(n("n", uw(1)), EmptyStmt(), + Connect(n("a", uw(8)), n("b", uw(8)))) + ])) ct = DefCircuit("m3", [m1, m2]) assert not check(ct) diff --git a/tests/test_firrtl_ir/test_stmt/test_def/test_memory.py b/tests/test_firrtl_ir/test_stmt/test_def/test_memory.py index 076be4b..2b62982 100644 --- a/tests/test_firrtl_ir/test_stmt/test_def/test_memory.py +++ b/tests/test_firrtl_ir/test_stmt/test_def/test_memory.py @@ -52,8 +52,8 @@ def test_read_port_index_wrong(): assert not check(mr) mem_ref = n("m", vec(bdl(a=(uw(8), False)), 10)) - mr = DefMemReadPort("mr", mem_ref, - n("a", vec(uw(1), 10)), n("clock", ClockType())) + mr = DefMemReadPort("mr", mem_ref, n("a", vec(uw(1), 10)), + n("clock", ClockType())) assert not check(mr) @@ -95,8 +95,8 @@ def test_write_port_index_wrong(): assert not check(mw) mem_ref = n("m", vec(bdl(a=(uw(8), False)), 10)) - mw = DefMemWritePort("mw", mem_ref, - n("a", vec(uw(1), 10)), n("clock", ClockType())) + mw = DefMemWritePort("mw", mem_ref, n("a", vec(uw(1), 10)), + n("clock", ClockType())) assert not check(mw) @@ -106,6 +106,5 @@ def test_write_port_mem_wrong(): assert not check(mw) mem_ref = n("m", uw(9)) - mw = DefMemWritePort("mw", mem_ref, - n("a", uw(2)), n("clock", ClockType())) + mw = DefMemWritePort("mw", mem_ref, n("a", uw(2)), n("clock", ClockType())) assert not check(mw) diff --git a/tests/test_firrtl_ir/test_stmt/test_def/test_module.py b/tests/test_firrtl_ir/test_stmt/test_def/test_module.py index 51e829f..6e3871e 100644 --- a/tests/test_firrtl_ir/test_stmt/test_def/test_module.py +++ b/tests/test_firrtl_ir/test_stmt/test_def/test_module.py @@ -14,28 +14,31 @@ def test_module_basis(): mod = DefModule("m", [OutputPort("p", uw(8))], Connect(n("p", uw(8)), u(2, w(8)))) assert check(mod) - serialize_stmt_equal(mod, 'module m :\n' - ' output p : UInt<8>\n' - '\n' - ' p <= UInt<8>("h2")') + serialize_stmt_equal( + mod, 'module m :\n' + ' output p : UInt<8>\n' + '\n' + ' p <= UInt<8>("h2")') - mod = DefModule("m", [InputPort("b", uw(8)), - OutputPort("a", uw(8))], - Block([DefNode("n", u(1, w(1))), - Conditionally(n("n", uw(1)), - EmptyStmt(), - Connect(n("a", uw(8)), n("b", uw(8)))) - ])) + mod = DefModule( + "m", + [InputPort("b", uw(8)), OutputPort("a", uw(8))], + Block([ + DefNode("n", u(1, w(1))), + Conditionally(n("n", uw(1)), EmptyStmt(), + Connect(n("a", uw(8)), n("b", uw(8)))) + ])) assert check(mod) - serialize_stmt_equal(mod, 'module m :\n' - ' input b : UInt<8>\n' - ' output a : UInt<8>\n' - '\n' - ' node n = UInt<1>("h1")\n' - ' when n :\n' - ' skip\n' - ' else :\n' - ' a <= b') + serialize_stmt_equal( + mod, 'module m :\n' + ' input b : UInt<8>\n' + ' output a : UInt<8>\n' + '\n' + ' node n = UInt<1>("h1")\n' + ' when n :\n' + ' skip\n' + ' else :\n' + ' a <= b') def test_module_empty_ports(): @@ -50,14 +53,16 @@ def test_module_body_wrong(): def test_ext_module_basis(): - mod = DefExtModule("em", [InputPort("b", uw(8)), - OutputPort("a", uw(8))], "em") + mod = DefExtModule( + "em", + [InputPort("b", uw(8)), OutputPort("a", uw(8))], "em") assert check(mod) - serialize_stmt_equal(mod, 'extmodule em :\n' - ' input b : UInt<8>\n' - ' output a : UInt<8>\n' - '\n' - ' defname = em') + serialize_stmt_equal( + mod, 'extmodule em :\n' + ' input b : UInt<8>\n' + ' output a : UInt<8>\n' + '\n' + ' defname = em') def test_ext_module_empty_ports(): diff --git a/tests/test_firrtl_ir/test_stmt/test_def/test_node.py b/tests/test_firrtl_ir/test_stmt/test_def/test_node.py index 22ab384..0d3eb16 100644 --- a/tests/test_firrtl_ir/test_stmt/test_def/test_node.py +++ b/tests/test_firrtl_ir/test_stmt/test_def/test_node.py @@ -16,7 +16,7 @@ def test_node_basis(): def test_node_expr_wrong(): - node = DefNode("n1", s(20, w(5))) + node = DefNode("n1", s(20, w(4))) assert not check(node) node = DefNode("n2", SubIndex(n("v", vec(uw(8), 10)), 10, uw(8))) diff --git a/tests/test_firrtl_ir/test_stmt/test_def/test_register.py b/tests/test_firrtl_ir/test_stmt/test_def/test_register.py index 28c3617..40ebaef 100644 --- a/tests/test_firrtl_ir/test_stmt/test_def/test_register.py +++ b/tests/test_firrtl_ir/test_stmt/test_def/test_register.py @@ -24,44 +24,46 @@ def test_register_clock_wrong(): def test_init_register_basis(): - r1 = DefInitRegister("r1", uw(8), - n("clock", ClockType()), n("r", uw(1)), u(5, w(8))) + r1 = DefInitRegister("r1", uw(8), n("clock", ClockType()), n("r", uw(1)), + u(5, w(8))) assert check(r1) - serialize_stmt_equal(r1, 'reg r1 : UInt<8>, clock with :\n' - ' reset => (r, UInt<8>("h5"))') + serialize_stmt_equal( + r1, 'reg r1 : UInt<8>, clock with :\n' + ' reset => (r, UInt<8>("h5"))') - r2 = DefInitRegister("r2", sw(8), - n("clock", ClockType()), u(0, w(1)), s(5, w(8))) + r2 = DefInitRegister("r2", sw(8), n("clock", ClockType()), u(0, w(1)), + s(5, w(8))) assert check(r2) - serialize_stmt_equal(r2, 'reg r2 : SInt<8>, clock with :\n' - ' reset => (UInt<1>("h0"), SInt<8>("h5"))') + serialize_stmt_equal( + r2, 'reg r2 : SInt<8>, clock with :\n' + ' reset => (UInt<1>("h0"), SInt<8>("h5"))') def test_init_register_clock_wrong(): - r1 = DefInitRegister("r1", uw(8), - n("clock", uw(1)), n("r", uw(1)), u(5, w(8))) + r1 = DefInitRegister("r1", uw(8), n("clock", uw(1)), n("r", uw(1)), + u(5, w(8))) assert not check(r1) - r2 = DefInitRegister("r2", sw(8), - n("clock", sw(1)), u(0, w(1)), s(5, w(8))) + r2 = DefInitRegister("r2", sw(8), n("clock", sw(1)), u(0, w(1)), + s(5, w(8))) assert not check(r2) def test_init_register_reset_wrong(): - r1 = DefInitRegister("r1", uw(8), - n("clock", ClockType()), n("r", sw(1)), u(5, w(8))) + r1 = DefInitRegister("r1", uw(8), n("clock", ClockType()), n("r", sw(1)), + u(5, w(8))) assert not check(r1) - r2 = DefInitRegister("r2", sw(8), - n("clock", ClockType()), s(0, w(1)), s(5, w(8))) + r2 = DefInitRegister("r2", sw(8), n("clock", ClockType()), s(0, w(1)), + s(5, w(8))) assert not check(r2) def test_init_register_type_not_match(): - r1 = DefInitRegister("r1", uw(8), - n("clock", ClockType()), n("r", uw(1)), s(5, w(8))) + r1 = DefInitRegister("r1", uw(8), n("clock", ClockType()), n("r", uw(1)), + s(5, w(8))) assert not check(r1) - r2 = DefInitRegister("r2", uw(8), - n("clock", ClockType()), u(0, w(1)), s(5, w(8))) + r2 = DefInitRegister("r2", uw(8), n("clock", ClockType()), u(0, w(1)), + s(5, w(8))) assert not check(r2) diff --git a/tests/test_firrtl_ir/test_type.py b/tests/test_firrtl_ir/test_type.py index dc5d276..0b70b01 100644 --- a/tests/test_firrtl_ir/test_type.py +++ b/tests/test_firrtl_ir/test_type.py @@ -52,16 +52,16 @@ def test_bundle_type(): Field("c", VectorType(vt, 32)), ]) serialize_equal( - bd, "{a : UInt<8>[16], flip b : UInt<8>, c : UInt<8>[16][32]}" - ) + bd, "{a : UInt<8>[16], flip b : UInt<8>, c : UInt<8>[16][32]}") # TODO: Is it valid? bd = BundleType([ - Field("l1", BundleType([ - Field("l2", BundleType([ - Field("l3", UIntType(Width(8)), True) - ])), - Field("vt", vt), - ])) + Field( + "l1", + BundleType([ + Field("l2", BundleType([Field("l3", UIntType(Width(8)), + True)])), + Field("vt", vt), + ])) ]) serialize_equal(bd, "{l1 : {l2 : {flip l3 : UInt<8>}, vt : UInt<8>[16]}}") diff --git a/tests/test_firrtl_ir/test_type_equal.py b/tests/test_firrtl_ir/test_type_equal.py index f8407d3..776537b 100644 --- a/tests/test_firrtl_ir/test_type_equal.py +++ b/tests/test_firrtl_ir/test_type_equal.py @@ -37,8 +37,7 @@ def test_type_neq(): assert not equal(uw(10), sw(10)) assert not equal(uw(10), vec(uw(10), 8)) assert not equal(uw(10), vec(sw(10), 8)) - assert not equal(uw(10), - bdl(a=(vec(uw(10), 8), False), b=(uw(10), False))) + assert not equal(uw(10), bdl(a=(vec(uw(10), 8), False), b=(uw(10), False))) assert not equal(sw(10), UnknownType()) assert not equal(sw(10), ClockType()) @@ -47,8 +46,7 @@ def test_type_neq(): assert not equal(sw(10), uw(10)) assert not equal(sw(10), vec(uw(10), 8)) assert not equal(sw(10), vec(sw(10), 8)) - assert not equal(sw(10), - bdl(a=(vec(uw(10), 8), False), b=(uw(10), False))) + assert not equal(sw(10), bdl(a=(vec(uw(10), 8), False), b=(uw(10), False))) assert not equal(vec(uw(10), 8), UnknownType()) assert not equal(vec(uw(10), 8), ClockType()) @@ -65,12 +63,9 @@ def test_type_neq(): UnknownType()) assert not equal(bdl(a=(vec(uw(10), 8), False), b=(uw(10), False)), ClockType()) - assert not equal(bdl(a=(vec(uw(10), 8), False), b=(uw(10), False)), - sw(8)) - assert not equal(bdl(a=(vec(uw(10), 8), False), b=(uw(10), False)), - uw(10)) - assert not equal(bdl(a=(vec(uw(10), 8), False), b=(uw(10), False)), - uw(10)) + assert not equal(bdl(a=(vec(uw(10), 8), False), b=(uw(10), False)), sw(8)) + assert not equal(bdl(a=(vec(uw(10), 8), False), b=(uw(10), False)), uw(10)) + assert not equal(bdl(a=(vec(uw(10), 8), False), b=(uw(10), False)), uw(10)) assert not equal(bdl(a=(vec(uw(10), 8), False), b=(uw(10), False)), vec(sw(10), 8)) assert not equal(bdl(a=(vec(uw(10), 8), False), b=(uw(10), False)),