diff --git a/.github/workflows/js.yml b/.github/workflows/js.yml index 16ed872..a4c94e8 100644 --- a/.github/workflows/js.yml +++ b/.github/workflows/js.yml @@ -8,6 +8,7 @@ on: jobs: build: + if: false # disable the entire workflow runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..5e17664 --- /dev/null +++ b/Makefile @@ -0,0 +1,22 @@ +ROOT_DIR ?= $(realpath $(PWD)) +BUILD_DIR ?= $(ROOT_DIR)/build +BUILD_TYPE ?= Debug + +all: check + +configure: + cmake -B${BUILD_DIR} -S. -G Ninja -DCMAKE_EXPORT_COMPILE_COMMANDS=1 -DCMAKE_BUILD_TYPE=${BUILD_TYPE} + +build: configure + cmake --build ${BUILD_DIR} + +test: build + ctest --test-dir ${BUILD_DIR}/tests/ --output-on-failure + python3 -m pytest . + +check: test + python3 -m mypy . + +clean: + rm -rf ${BUILD_DIR} + diff --git a/README.md b/README.md index 2413ec0..48b51db 100644 --- a/README.md +++ b/README.md @@ -17,27 +17,37 @@ Features: - Supported output formats: C++, JSON, TypeScript - Supported output formats TODO: Go, Markdown (documentation) -## Dependencies +## Runtime Dependencies - Python 3.X On Linux: -``` +```bash sudo apt install python3 ``` -On Windows 10: +On Windows: 1. Download https://bootstrap.pypa.io/get-pip.py 2. Execute `python3 get_pip.py` 3. Execute `pip3 install pyyaml` -### Build dependencies +## Build & Test Dependencies - libgtest-dev (for testing) +- pytest (for testing) +- cmake +- ninja +- mypy + +## Testing & Verification + +```bash +make check +``` -## Generate messages +## Generating messages All data types should be placed in one directory. Each protocol can be placed in any arbitrary directory. @@ -177,18 +187,19 @@ Type ids can be assigned to structs in `weather_station.yaml` file (see below). #### Protocol -**Protocol** defines the protocol ID and type IDs for structs that will be used as messages. -Type ID used during serialization/deserialization to identify the message type. -Multiple protocols may be used in one system, e.g. `my_namespace/bootloader` and `my_namespace/application`. -Parser can check the protocol by protocol ID, that can be serialized in message header. +**Protocol** defines the protocol ID and message IDs for structs that will be used +as messages. Message ID used during serialization/deserialization to identify the +message type. Multiple protocols may be used in one system, e.g. +`my_namespace/bootloader` and `my_namespace/application`. Parser can check the +protocol by protocol ID, that can be serialized in message header. Example protocol definition (`weather_station.yaml`): ```yaml comment: "Weather station application protocol" -types_map: - 0: "heartbeat" - 1: "system_status" - 2: "system_command" - 3: "baro_report" +messages: + 0: { name: "heartbeat", type: "application/heartbeat", comment: "Heartbeat message" } + 1: { name: "system_status", type: "system/status", comment: "System status message" } + 2: { name: "system_command", type: "system/command", comment: "System command message" } + 3: { name: "baro_report", type: "measurement/baro_report", comment: "Barometer report message" } ``` diff --git a/messgen-generate.py b/messgen-generate.py index 4216117..e1915fd 100644 --- a/messgen-generate.py +++ b/messgen-generate.py @@ -1,6 +1,7 @@ import argparse import os +from messgen.validation import validate_protocol, validate_types from messgen import generator, yaml_parser from pathlib import Path @@ -26,10 +27,14 @@ def generate(args: argparse.Namespace): if (gen := generator.get_generator(args.lang, opts)) is not None: if parsed_protocols and parsed_types: - gen.generate(Path(args.outdir), parsed_types, parsed_protocols) - elif parsed_types: + for proto_def in parsed_protocols.values(): + validate_protocol(proto_def, parsed_types) + + if parsed_types: + validate_types(parsed_types) gen.generate_types(Path(args.outdir), parsed_types) - elif parsed_protocols: + + if parsed_protocols: gen.generate_protocols(Path(args.outdir), parsed_protocols) else: diff --git a/messgen/cpp_generator.py b/messgen/cpp_generator.py index ed8d92f..3d31bd0 100644 --- a/messgen/cpp_generator.py +++ b/messgen/cpp_generator.py @@ -9,8 +9,6 @@ Path, ) -from .validation import validate_protocol - from .common import ( SEPARATOR, SIZE_TYPE, @@ -58,6 +56,7 @@ def _namespace(name: str, code:list[str]): if ns_name: code.append("") code.append(f"}} // namespace {ns_name}") + code.append("") @contextmanager @@ -67,6 +66,7 @@ def _struct(name: str, code: list[str]): yield finally: code.append("};") + code.append("") def _inline_comment(type_def: FieldType | EnumValue): @@ -125,15 +125,6 @@ def __init__(self, options: dict): self._ctx: dict = {} self._types: dict[str, MessgenType] = {} - def generate(self, out_dir: Path, types: dict[str, MessgenType], protocols: dict[str, Protocol]) -> None: - self.validate(types, protocols) - self.generate_types(out_dir, types) - self.generate_protocols(out_dir, protocols) - - def validate(self, types: dict[str, MessgenType], protocols: dict[str, Protocol]): - for proto_def in protocols.values(): - validate_protocol(proto_def, types) - def generate_types(self, out_dir: Path, types: dict[str, MessgenType]) -> None: self._types = types for type_name, type_def in types.items(): @@ -169,7 +160,7 @@ def _generate_type_file(self, type_name: str, type_def: MessgenType) -> list: elif isinstance(type_def, StructType): code.extend(self._generate_type_struct(type_name, type_def)) - code.extend(self._generate_members_of(type_name, type_def)) + code.extend(self._generate_type_members_of(type_name, type_def)) code = self._PREAMBLE_HEADER + self._generate_includes() + code @@ -185,82 +176,92 @@ def _generate_proto_file(self, proto_name: str, proto_def: Protocol) -> list[str self._add_include("messgen/messgen.h") namespace_name, class_name = _split_last_name(proto_name) - print(f"Namespace: {namespace_name}, Class: {class_name}") with _namespace(namespace_name, code): with _struct(class_name, code): - for type_name in proto_def.types.values(): - self._add_include(type_name + self._EXT_HEADER) + for message in proto_def.messages.values(): + self._add_include(message.type + self._EXT_HEADER) proto_id = proto_def.proto_id if proto_id is not None: - code.append(f" constexpr static int PROTO_ID = {proto_id};") + code.append(f" constexpr static inline int PROTO_ID = {proto_id};") + code.append(f" constexpr static inline uint32_t HASH = {hash(proto_def)};") - code.extend(self._generate_type_id_decl(proto_def)) - code.extend(self._generate_reflect_type_decl()) + code.extend(self._generate_messages(class_name, proto_def)) + code.extend(self._generate_reflect_message_decl()) code.extend(self._generate_dispatcher_decl()) - code.extend(self._generate_type_ids(class_name, proto_def)) - code.extend(self._generate_reflect_type(class_name, proto_def)) + code.extend(self._generate_protocol_members_of(class_name, proto_def)) + code.extend(self._generate_reflect_message(class_name, proto_def)) code.extend(self._generate_dispatcher(class_name)) code.append("") return self._PREAMBLE_HEADER + self._generate_includes() + code - @staticmethod - def _generate_type_id_decl(proto: Protocol) -> list[str]: - return textwrap.indent(textwrap.dedent(""" - template - constexpr static inline int TYPE_ID = []{ - static_assert(sizeof(Msg) == 0, \"Provided type is not part of the protocol.\"); - return 0; - }();"""), " ").splitlines() - @staticmethod - def _generate_type_ids(class_name: str, proto: Protocol) -> list[str]: + def _generate_messages(self, class_name: str, proto_def: Protocol): + self._add_include("tuple") code: list[str] = [] - for type_id, type_name in proto.types.items(): - code.append(f" template <> constexpr inline int {class_name}::TYPE_ID<{_qual_name(type_name)}> = {type_id};") + for message in proto_def.messages.values(): + code.extend(textwrap.indent(textwrap.dedent(f""" + struct {message.name} : {_qual_name(message.type)} {{ + using data_type = {_qual_name(message.type)}; + using protocol_type = {class_name}; + constexpr inline static int PROTO_ID = protocol_type::PROTO_ID; + constexpr inline static int MESSAGE_ID = {message.message_id}; + }};"""), " ").splitlines()) + return code + + def _generate_protocol_members_of(self, class_name: str, proto_def: Protocol): + self._add_include("tuple") + code: list[str] = [] + code.append(f"[[nodiscard]] consteval auto members_of(::messgen::reflect_t<{class_name}>) noexcept {{") + code.append(" return std::tuple{") + for message in proto_def.messages.values(): + code.append(f" ::messgen::member<{class_name}, {class_name}::{message.name}>{{\"{message.name}\"}},") + code.append(" };") + code.append("}") + code.append("") return code @staticmethod - def _generate_reflect_type_decl() -> list[str]: + def _generate_reflect_message_decl() -> list[str]: return textwrap.indent(textwrap.dedent(""" template - constexpr static auto reflect_message(int type_id, Fn&& fn); + constexpr static auto reflect_message(int msg_id, Fn &&fn); """), " ").splitlines() @staticmethod - def _generate_reflect_type(class_name: str, proto: Protocol) -> list[str]: + def _generate_reflect_message(class_name: str, proto: Protocol) -> list[str]: code: list[str] = [] - code.append(" template ") - code.append(f" constexpr auto {class_name}::reflect_message(int type_id, Fn&& fn) {{") - code.append(" switch (type_id) {") - for type_name in proto.types.values(): - qual_name = _qual_name(type_name) - code.append(f" case TYPE_ID<{qual_name}>:") - code.append(f" std::forward(fn)(::messgen::reflect_type<{qual_name}>);") - code.append(f" return;") - code.append(" }") + code.append("template ") + code.append(f"constexpr auto {class_name}::reflect_message(int msg_id, Fn &&fn) {{") + code.append(" switch (msg_id) {") + for message in proto.messages.values(): + msg_type = f"{class_name}::{_unqual_name(message.name)}" + code.append(f" case {msg_type}::MESSAGE_ID:") + code.append(f" std::forward(fn)(::messgen::reflect_type<{msg_type}>);") + code.append(f" return;") code.append(" }") + code.append("}") return code @staticmethod def _generate_dispatcher_decl() -> list[str]: return textwrap.indent(textwrap.dedent(""" template - static bool dispatch_message(int msg_id, const uint8_t *payload, T handler); + constexpr static bool dispatch_message(int msg_id, const uint8_t *payload, T handler); """), " ").splitlines() @staticmethod def _generate_dispatcher(class_name: str) -> list[str]: return textwrap.dedent(f""" template - bool {class_name}::dispatch_message(int msg_id, const uint8_t *payload, T handler) {{ + constexpr bool {class_name}::dispatch_message(int msg_id, const uint8_t *payload, T handler) {{ auto result = false; reflect_message(msg_id, [&](R) {{ using message_type = messgen::splice_t; if constexpr (requires(message_type msg) {{ handler(msg); }}) {{ - message_type msg; + auto msg = message_type{{}}; msg.deserialize(payload); handler(std::move(msg)); result = true; @@ -269,17 +270,6 @@ def _generate_dispatcher(class_name: str) -> list[str]: return result; }}""").splitlines() - @staticmethod - def _generate_traits() -> list[str]: - return textwrap.dedent(""" - namespace messgen { - template - struct reflect_t {}; - - template - struct splice_t {}; - }""").splitlines() - @staticmethod def _generate_comment_type(type_def): if not type_def.comment: @@ -379,11 +369,12 @@ def _generate_type_struct(self, type_name: str, type_def: StructType): is_empty = len(groups) == 0 is_flat = is_empty or (len(groups) == 1 and groups[0].size is not None) if is_flat: - code.append(_indent("static constexpr size_t FLAT_SIZE = %d;" % (0 if is_empty else groups[0].size))) + code.append(_indent("constexpr static inline size_t FLAT_SIZE = %d;" % (0 if is_empty else groups[0].size))) is_flat_str = "true" - code.append(_indent(f"static constexpr bool IS_FLAT = {is_flat_str};")) - code.append(_indent(f"static constexpr const char* NAME = \"{_qual_name(type_name)}\";")) - code.append(_indent(f"static constexpr const char* SCHEMA = R\"_({self._generate_schema(type_def)})_\";")) + code.append(_indent(f"constexpr static inline bool IS_FLAT = {is_flat_str};")) + code.append(_indent(f"constexpr static inline uint32_t HASH = {hash(type_def)};")) + code.append(_indent(f"constexpr static inline const char* NAME = \"{_qual_name(type_name)}\";")) + code.append(_indent(f"constexpr static inline const char* SCHEMA = R\"_({self._generate_schema(type_def)})_\";")) code.append("") for field in type_def.fields: @@ -483,7 +474,7 @@ def _generate_type_struct(self, type_name: str, type_def: StructType): if self._get_cpp_standard() >= 20: # Operator <=> code.append("") - code.append(_indent("auto operator<=>(const %s&) const = default;" % unqual_name)) + code.append(_indent("auto operator<=>(const %s &) const = default;" % unqual_name)) code.append("};") @@ -527,17 +518,17 @@ def _generate_includes(self): code.append("") return code - def _generate_members_of(self, type_name: str, type_def: StructType): + def _generate_type_members_of(self, type_name: str, type_def: StructType): self._add_include("tuple") unqual_name = _unqual_name(type_name) code: list[str] = [] code.append("") - code.append(f"[[nodiscard]] inline constexpr auto members_of(::messgen::reflect_t<{unqual_name}>) noexcept {{") + code.append(f"[[nodiscard]] consteval auto members_of(::messgen::reflect_t<{unqual_name}>) noexcept {{") code.append(" return std::tuple{") for field in type_def.fields: - code.append(f" ::messgen::member{{\"{field.name}\", &{unqual_name}::{field.name}}},") + code.append(f" ::messgen::member_variable{{{{\"{field.name}\"}}, &{unqual_name}::{field.name}}},") code.append(" };") code.append("}") diff --git a/messgen/dynamic.py b/messgen/dynamic.py index f024a40..2b97790 100644 --- a/messgen/dynamic.py +++ b/messgen/dynamic.py @@ -1,5 +1,7 @@ import struct +from functools import singledispatchmethod +from abc import ABC, abstractmethod from pathlib import Path from .model import ( @@ -35,88 +37,113 @@ class MessgenError(Exception): pass -class Converter: +class TypeConverter(ABC): + def __init__(self, types: dict[str, MessgenType], type_name: str): - self.type_name = type_name - self.type_def = types[type_name] - self.type_class = self.type_def.type_class + self._type_name = type_name + self._type_def = types[type_name] + self._type_class = self._type_def.type_class + def type_name(self) -> str: + return self._type_name -class ScalarConverter(Converter): - def __init__(self, types: dict[str, MessgenType], type_name:str): + def type_hash(self) -> int: + return hash(self._type_def) + + def serialize(self, data: dict) -> bytes: + return self._serialize(data) + + def deserialize(self, data: bytes) -> dict: + msg, sz = self._deserialize(data) + if sz != len(data): + raise MessgenError( + f"Invalid message size: expected={sz} actual={len(data)} type_name={self._type_name}") + return msg + + @abstractmethod + def _serialize(self, data) -> bytes: + pass + + @abstractmethod + def _deserialize(self, data) -> tuple[dict, int]: + pass + + +class ScalarConverter(TypeConverter): + def __init__(self, types: dict[str, MessgenType], type_name: str): super().__init__(types, type_name) - assert self.type_class == TypeClass.scalar + assert self._type_class == TypeClass.scalar self.struct_fmt = STRUCT_TYPES_MAP.get(type_name) if self.struct_fmt is None: - raise RuntimeError("Unsupported scalar type \"%s\"" % self.type_name) + raise RuntimeError("Unsupported scalar type \"%s\"" % type_name) self.struct_fmt = "<" + self.struct_fmt self.size = struct.calcsize(self.struct_fmt) self.def_value: bool | float | int = 0 - if self.type_name == "bool": + if type_name == "bool": self.def_value = False - elif self.type_name == "float32" or self.type_name == "float64": + elif type_name == "float32" or type_name == "float64": self.def_value = 0.0 - def serialize(self, data): + def _serialize(self, data): return struct.pack(self.struct_fmt, data) - def deserialize(self, data): + def _deserialize(self, data): return struct.unpack(self.struct_fmt, data[:self.size])[0], self.size def default_value(self): return self.def_value -class EnumConverter(Converter): +class EnumConverter(TypeConverter): def __init__(self, types: dict[str, MessgenType], type_name:str): super().__init__(types, type_name) - assert self.type_class == TypeClass.enum - assert isinstance(self.type_def, EnumType) - self.base_type = self.type_def.base_type + assert self._type_class == TypeClass.enum + assert isinstance(self._type_def, EnumType) + self.base_type = self._type_def.base_type self.struct_fmt = STRUCT_TYPES_MAP.get(self.base_type, None) if self.struct_fmt is None: - raise RuntimeError("Unsupported base type \"%s\" in %s" % (self.base_type, self.type_name)) + raise RuntimeError("Unsupported base type \"%s\" in %s" % (self.base_type, type_name)) self.struct_fmt = "<" + self.struct_fmt self.size = struct.calcsize(self.struct_fmt) self.mapping = {} - for item in self.type_def.values: + for item in self._type_def.values: self.mapping[item.value] = item.name self.rev_mapping = {v: k for k, v in self.mapping.items()} - def serialize(self, data): + def _serialize(self, data): v = self.rev_mapping[data] return struct.pack(self.struct_fmt, v) - def deserialize(self, data): + def _deserialize(self, data): v, = struct.unpack(self.struct_fmt, data[:self.size]) return self.mapping[v], self.size def default_value(self): - return self.type_def.values[0].name + return self._type_def.values[0].name -class StructConverter(Converter): +class StructConverter(TypeConverter): def __init__(self, types: dict[str, MessgenType], type_name:str): super().__init__(types, type_name) - assert self.type_class == TypeClass.struct - assert isinstance(self.type_def, StructType) - self.fields = [(field.name, get_type(types, field.type)) - for field in self.type_def.fields] + assert self._type_class == TypeClass.struct + assert isinstance(self._type_def, StructType) + self.fields = [(field.name, create_type_converter(types, field.type)) + for field in self._type_def.fields] - def serialize(self, data): + def _serialize(self, data): out = [] for field_name, field_type in self.fields: v = data.get(field_name, None) if v is None: v = field_type.default_value() - out.append(field_type.serialize(v)) + out.append(field_type._serialize(v)) return b"".join(out) - def deserialize(self, data): + def _deserialize(self, data): out = {} offset = 0 for field_name, field_type in self.fields: - value, size = field_type.deserialize(data[offset:]) + value, size = field_type._deserialize(data[offset:]) out[field_name] = value offset += size return out, offset @@ -126,26 +153,26 @@ def default_value(self): for field_name, field_type in self.fields} -class ArrayConverter(Converter): +class ArrayConverter(TypeConverter): def __init__(self, types: dict[str, MessgenType], type_name:str): super().__init__(types, type_name) - assert self.type_class == TypeClass.array - assert isinstance(self.type_def, ArrayType) - self.element_type = get_type(types, self.type_def.element_type) - self.array_size = self.type_def.array_size + assert self._type_class == TypeClass.array + assert isinstance(self._type_def, ArrayType) + self.element_type = create_type_converter(types, self._type_def.element_type) + self.array_size = self._type_def.array_size - def serialize(self, data): + def _serialize(self, data): out = [] assert len(data) == self.array_size for item in data: - out.append(self.element_type.serialize(item)) + out.append(self.element_type._serialize(item)) return b"".join(out) - def deserialize(self, data): + def _deserialize(self, data): out = [] offset = 0 for i in range(self.array_size): - value, size = self.element_type.deserialize(data[offset:]) + value, size = self.element_type._deserialize(data[offset:]) out.append(value) offset += size return out, offset @@ -157,29 +184,29 @@ def default_value(self): return out -class VectorConverter(Converter): +class VectorConverter(TypeConverter): def __init__(self, types: dict[str, MessgenType], type_name: str): super().__init__(types, type_name) - assert self.type_class == TypeClass.vector - assert isinstance(self.type_def, VectorType) - self.size_type = get_type(types, "uint32") - self.element_type = get_type(types, self.type_def.element_type) + assert self._type_class == TypeClass.vector + assert isinstance(self._type_def, VectorType) + self.size_type = create_type_converter(types, "uint32") + self.element_type = create_type_converter(types, self._type_def.element_type) - def serialize(self, data): + def _serialize(self, data): out = [] - out.append(self.size_type.serialize(len(data))) + out.append(self.size_type._serialize(len(data))) for item in data: - out.append(self.element_type.serialize(item)) + out.append(self.element_type._serialize(item)) return b"".join(out) - def deserialize(self, data): + def _deserialize(self, data): out = [] offset = 0 - n, n_size = self.size_type.deserialize(data[offset:]) + n, n_size = self.size_type._deserialize(data[offset:]) offset += n_size for i in range(n): - value, n = self.element_type.deserialize(data[offset:]) + value, n = self.element_type._deserialize(data[offset:]) out.append(value) offset += n return out, offset @@ -188,32 +215,32 @@ def default_value(self): return [] -class MapConverter(Converter): +class MapConverter(TypeConverter): def __init__(self, types: dict[str, MessgenType], type_name:str): super().__init__(types, type_name) - assert self.type_class == TypeClass.map - assert isinstance(self.type_def, MapType) - self.size_type = get_type(types, "uint32") - self.key_type = get_type(types, self.type_def.key_type) - self.value_type = get_type(types, self.type_def.value_type) + assert self._type_class == TypeClass.map + assert isinstance(self._type_def, MapType) + self.size_type = create_type_converter(types, "uint32") + self.key_type = create_type_converter(types, self._type_def.key_type) + self.value_type = create_type_converter(types, self._type_def.value_type) - def serialize(self, data): + def _serialize(self, data): out = [] - out.append(self.size_type.serialize(len(data))) + out.append(self.size_type._serialize(len(data))) for k, v in data.items(): - out.append(self.key_type.serialize(k)) - out.append(self.value_type.serialize(v)) + out.append(self.key_type._serialize(k)) + out.append(self.value_type._serialize(v)) return b"".join(out) - def deserialize(self, data): + def _deserialize(self, data): out = {} offset = 0 - n, n_size = self.size_type.deserialize(data[offset:]) + n, n_size = self.size_type._deserialize(data[offset:]) offset += n_size for i in range(n): - key, n = self.key_type.deserialize(data[offset:]) + key, n = self.key_type._deserialize(data[offset:]) offset += n - value, n = self.value_type.deserialize(data[offset:]) + value, n = self.value_type._deserialize(data[offset:]) offset += n out[key] = value return out, offset @@ -222,18 +249,18 @@ def default_value(self): return {} -class StringConverter(Converter): +class StringConverter(TypeConverter): def __init__(self, types: dict[str, MessgenType], type_name:str): super().__init__(types, type_name) - assert self.type_class == TypeClass.string - self.size_type = get_type(types, "uint32") + assert self._type_class == TypeClass.string + self.size_type = create_type_converter(types, "uint32") self.struct_fmt = "<%is" - def serialize(self, data): - return self.size_type.serialize(len(data)) + struct.pack(self.struct_fmt % len(data), data.encode("utf-8")) + def _serialize(self, data): + return self.size_type._serialize(len(data)) + struct.pack(self.struct_fmt % len(data), data.encode("utf-8")) - def deserialize(self, data): - n, n_size = self.size_type.deserialize(data) + def _deserialize(self, data): + n, n_size = self.size_type._deserialize(data) offset = n_size value = struct.unpack(self.struct_fmt % n, data[offset:offset + n])[0] offset += n @@ -243,18 +270,18 @@ def default_value(self): return "" -class BytesConverter(Converter): +class BytesConverter(TypeConverter): def __init__(self, types: dict[str, MessgenType], type_name:str): super().__init__(types, type_name) - assert self.type_class == TypeClass.bytes - self.size_type = get_type(types, "uint32") + assert self._type_class == TypeClass.bytes + self.size_type = create_type_converter(types, "uint32") self.struct_fmt = "<%is" - def serialize(self, data): - return self.size_type.serialize(len(data)) + struct.pack(self.struct_fmt % len(data), data) + def _serialize(self, data): + return self.size_type._serialize(len(data)) + struct.pack(self.struct_fmt % len(data), data) - def deserialize(self, data): - n, n_size = self.size_type.deserialize(data) + def _deserialize(self, data): + n, n_size = self.size_type._deserialize(data) offset = n_size value = struct.unpack(self.struct_fmt % n, data[offset:offset + n])[0] offset += n @@ -264,7 +291,7 @@ def default_value(self): return b"" -def get_type(types: dict[str, MessgenType], type_name:str) -> Converter: +def create_type_converter(types: dict[str, MessgenType], type_name: str) -> TypeConverter: type_def = types[type_name] type_class = type_def.type_class if type_class == TypeClass.scalar: @@ -286,50 +313,75 @@ def get_type(types: dict[str, MessgenType], type_name:str) -> Converter: raise RuntimeError("Unsupported field type class \"%s\" in %s" % (type_class, type_def.type)) +class MessageInfo: + + def __init__(self, proto_id: int, message_id: int, proto_name: str, message_name: str, type_converter: TypeConverter): + self._proto_id = proto_id + self._message_id = message_id + self._proto_name = proto_name + self._message_name = message_name + self._type_converter = type_converter + + def proto_name(self) -> str: + return self._proto_name + + def message_name(self) -> str: + return self._message_name + + def proto_id(self) -> int: + return self._proto_id + + def message_id(self) -> int: + return self._message_id + + def type_name(self) -> str: + return self._type_converter.type_name() + + def type_hash(self) -> int: + return self._type_converter.type_hash() + + def type_converter(self) -> TypeConverter: + return self._type_converter + + class Codec: - def __init__(self): - self.types_by_name = {} - self.types_by_id = {} + + def __init__(self) -> None: + self._converters_by_name: dict[str, TypeConverter] = {} + self._id_by_name: dict[tuple[str, str], tuple[int, int, str]] = {} + self._name_by_id: dict[tuple[int, int], tuple[str, str, str]] = {} def load(self, type_dirs: list[str | Path], protocols: list[str] | None = None): parsed_types = parse_types(type_dirs) - if protocols: - parsed_protocols = parse_protocols(protocols) + if not protocols: + return + + for type_name in parsed_types: + self._converters_by_name[type_name] = create_type_converter(parsed_types, type_name) + parsed_protocols = parse_protocols(protocols) for proto_name, proto_def in parsed_protocols.items(): - by_name: tuple[int, dict] = (proto_def.proto_id, {}) - by_id: tuple[str, dict] = (proto_name, {}) - for type_id, type_name in proto_def.types.items(): - t = get_type(parsed_types, type_name) - by_name[1][type_name] = t - if type_id is not None: - by_id[1][type_id] = t - self.types_by_name[proto_name] = by_name - self.types_by_id[proto_def.proto_id] = by_id - - def get_type_by_name(self, proto_name: str, type_name: str): - return self.types_by_name[proto_name][1][type_name] - - def serialize(self, proto_name: str, msg_name: str, msg: dict) -> tuple[int, int, bytes]: - p = self.types_by_name.get(proto_name) - if p is None: - raise MessgenError("Unsupported proto_name in serialization: proto_name=%s" % proto_name) - t = p[1].get(msg_name) - if t is None: - raise MessgenError( - "Unsupported msg_name in serialization: proto_name=%s msg_name=%s" % (proto_name, msg_name)) - payload = t.serialize(msg) - return p[0], t.id, payload - - def deserialize(self, proto_id: int, msg_id: int, data: bytes) -> tuple[int, str, dict]: - p = self.types_by_id.get(proto_id) - if p is None: - raise MessgenError("Unsupported proto_id in deserialization: proto_id=%s" % proto_id) - t = p[1].get(msg_id) - if t is None: - raise MessgenError("Unsupported msg_id in deserialization: proto_id=%s msg_id=%s" % (proto_id, msg_id)) - msg, sz = t.deserialize(data) - if sz != len(data): - raise MessgenError( - "Invalid message size: expected=%s actual=%s proto_id=%s msg_id=%s" % (sz, len(data), proto_id, msg_id)) - return p[0], t.type_name, msg + for msg_id, message in proto_def.messages.items(): + self._id_by_name[(proto_name, message.name)] = (proto_def.proto_id, msg_id, message.type) + self._name_by_id[(proto_def.proto_id, msg_id)] = (proto_name, message.name, message.type) + + def type_converter(self, type_name: str) -> TypeConverter: + if converter := self._converters_by_name.get(type_name): + return converter + raise MessgenError(f"Unsupported type_name={type_name}") + + def message_info_by_id(self, proto_id: int, message_id: int) -> MessageInfo: + key = (proto_id, message_id) + if not key in self._name_by_id: + raise MessgenError(f"Unsupported proto_id={proto_id} message_id={message_id}") + + proto_name, message_name, type_name = self._name_by_id[key] + return MessageInfo(proto_id, message_id, proto_name, message_name, self._converters_by_name[type_name]) + + def message_info_by_name(self, proto_name: str, message_name: str) -> MessageInfo: + key = (proto_name, message_name) + if not key in self._id_by_name: + raise MessgenError(f"Unsupported proto_name={proto_name} message_name={message_name}") + + proto_id, message_id, type_name = self._id_by_name[key] + return MessageInfo(proto_id, message_id, proto_name, message_name, self._converters_by_name[type_name]) diff --git a/messgen/json_generator.py b/messgen/json_generator.py index c2e67e6..3f34919 100644 --- a/messgen/json_generator.py +++ b/messgen/json_generator.py @@ -39,12 +39,12 @@ def generate_types(self, out_dir: Path, types: dict[str, MessgenType]) -> None: def generate_protocols(self, out_dir: Path, protocols: dict[str, Protocol]) -> None: combined: list = [] - + for proto_def in protocols.values(): proto_dict = asdict(proto_def) proto_dict["version"] = version_hash(proto_dict) combined.append(proto_dict) - + self._write_file(out_dir, "protocols", combined) def _write_file(self, out_dir: Path, name: str, data: list) -> None: diff --git a/messgen/model.py b/messgen/model.py index 75f3544..d2776c1 100644 --- a/messgen/model.py +++ b/messgen/model.py @@ -1,8 +1,19 @@ -from dataclasses import dataclass +import hashlib +import json + +from dataclasses import dataclass, asdict from enum import Enum, auto from typing import Union +def _hash_model_type(dt) -> int: + input_string = json.dumps(asdict(dt)).replace(" ", "") + hash_object = hashlib.md5(input_string.encode()) + hex_digest = hash_object.hexdigest() + hash_32_bits = int(hex_digest[:8], 16) + return hash_32_bits + + class TypeClass(str, Enum): scalar = auto() string = auto() @@ -20,6 +31,9 @@ class BasicType: type_class: TypeClass size: int | None + def __hash__(self): + return _hash_model_type(self) + @dataclass class ArrayType: @@ -29,6 +43,9 @@ class ArrayType: array_size: int size: int | None + def __hash__(self): + return _hash_model_type(self) + @dataclass class VectorType: @@ -37,6 +54,9 @@ class VectorType: element_type: str size: None + def __hash__(self): + return _hash_model_type(self) + @dataclass class MapType: @@ -46,6 +66,9 @@ class MapType: value_type: str size: None + def __hash__(self): + return _hash_model_type(self) + @dataclass class EnumValue: @@ -53,6 +76,9 @@ class EnumValue: value: int comment: str + def __hash__(self): + return _hash_model_type(self) + @dataclass class EnumType: @@ -63,6 +89,9 @@ class EnumType: values: list[EnumValue] size: int + def __hash__(self): + return _hash_model_type(self) + @dataclass class FieldType: @@ -70,6 +99,9 @@ class FieldType: type: str comment: str | None + def __hash__(self): + return _hash_model_type(self) + @dataclass class StructType: @@ -79,6 +111,9 @@ class StructType: fields: list[FieldType] size: int | None + def __hash__(self): + return _hash_model_type(self) + MessgenType = Union[ ArrayType, @@ -90,8 +125,22 @@ class StructType: ] +@dataclass +class Message: + message_id: int + name: str + type: str + comment: str | None + + def __hash__(self): + return _hash_model_type(self) + + @dataclass class Protocol: name: str proto_id: int - types: dict[int, str] + messages: dict[int, Message] + + def __hash__(self): + return _hash_model_type(self) diff --git a/messgen/ts_generator.py b/messgen/ts_generator.py index 3452bb7..c9cd2b4 100644 --- a/messgen/ts_generator.py +++ b/messgen/ts_generator.py @@ -70,23 +70,23 @@ def validate(self, types: dict[str, MessgenType], protocols: dict[str, Protocol] def generate_protocols(self, out_dir: Path, protocols: dict[str, Protocol]) -> None: types = set() code = [] - + for proto_name, proto_def in protocols.items(): code.append(f"export interface {self._to_camel_case(proto_name)} {{") code.append(f" '{proto_name}': {{") - for struct_name in proto_def.types.values(): - ts_struct_name = self._to_camel_case(struct_name) - code.append(f" '{struct_name}': {ts_struct_name};") + for message in proto_def.messages.values(): + ts_struct_name = self._to_camel_case(message.type) + code.append(f" '{message.name}': {message.type};") types.add(ts_struct_name) code.append(' }') code.append('}') code.append('') - + import_statements = self._generate_protocol_imports(types) protocol_types = ' | '.join(self._to_camel_case(proto_name) for proto_name in protocols.keys()) final_code = '\n'.join(import_statements + code) + f'export type Protocol = {protocol_types};' - - + + self._write_output_file(out_dir, self._PROTOCOLS_FILE, final_code) def _generate_protocol_imports(self, types: set[str]) -> list[str]: @@ -105,9 +105,9 @@ def generate_types(self, out_dir: Path, types: dict[str, MessgenType]) -> None: self._generate_struct(type_name, type_def) elif type_def.type_class == TypeClass.enum: self._generate_enum(type_name, type_def) - + code = '\n'.join(self._types) - + self._write_output_file(out_dir, self._TYPES_FILE, code) def _generate_enum(self, enum_name, type_def): @@ -121,7 +121,7 @@ def _generate_enum(self, enum_name, type_def): self._types.append("}") self._types.append("") - + def _generate_struct(self, name: str, type_def: MessgenType): self._types.append(f"export interface {self._to_camel_case(name)} {{") fields = getattr(type_def, 'fields', []) or [] @@ -175,7 +175,7 @@ def _is_typed_array(self, field_type): if typed_array: return typed_array return None - + def _write_output_file(self, output_path, file, content): output_file = os.path.join(output_path, f"{file}") @@ -186,5 +186,5 @@ def _write_output_file(self, output_path, file, content): def _to_camel_case(s: str): name = '_'.join(s.split(SEPARATOR)) return ''.join(word.capitalize() for word in name.split('_')) - + diff --git a/messgen/validation.py b/messgen/validation.py index cb99f30..6701707 100644 --- a/messgen/validation.py +++ b/messgen/validation.py @@ -122,9 +122,24 @@ def validate_protocol(protocol: Protocol, types: dict[str, MessgenType]): - for type_name in protocol.types.values(): - if type_name not in types: - raise RuntimeError(f"Type {type_name} required by {protocol.name} protocol not found") + seen_names = set() + for msg_id, msg in protocol.messages.items(): + if msg.name in seen_names: + raise RuntimeError(f"Message with name={msg.name} appears multiple times in protocol={protocol.name}") + if msg.type not in types: + raise RuntimeError(f"Type {msg.type} required by message={msg.name} protocol={protocol.name} not found") + if msg.message_id != msg_id: + raise RuntimeError(f"Message {msg.name} has different message_id={msg.message_id} than key={msg_id} in protocol={protocol.name}") + seen_names.add(msg.name) + + +def validate_types(types: dict[str, MessgenType]): + seen_hashes: dict[int, Any] = {} + for type_name, type_def in types.items(): + type_hash = hash(type_def) + if hash_conflict := seen_hashes.get(type_hash): + raise RuntimeError(f"Type {type_name} has the same hash as {hash_conflict.type}") + seen_hashes[type_hash] = type_def # Checks if `s` is a valid name for a field or a message type diff --git a/messgen/yaml_parser.py b/messgen/yaml_parser.py index 9cc9d83..f6eb2ed 100644 --- a/messgen/yaml_parser.py +++ b/messgen/yaml_parser.py @@ -12,6 +12,7 @@ EnumValue, FieldType, MapType, + Message, MessgenType, Protocol, StructType, @@ -76,7 +77,14 @@ def _parse_protocol(protocol_file: Path) -> Protocol: def _get_protocol(proto_name, protocol_desc: dict[str, Any]) -> Protocol: return Protocol(name=proto_name, proto_id=int(protocol_desc["proto_id"]), - types=protocol_desc.get("types_map", {})) + messages={msg_id: _get_message_type(msg_id, msg) for msg_id, msg in protocol_desc.get("messages", {}).items()}) + + +def _get_message_type(msg_id: int, message_desc: dict[str, Any]) -> Message: + return Message(message_id=msg_id, + name=message_desc["name"], + type=message_desc["type"], + comment=message_desc.get("comment")) def parse_types(base_dirs: list[str | Path]) -> dict[str, MessgenType]: @@ -85,7 +93,7 @@ def parse_types(base_dirs: list[str | Path]) -> dict[str, MessgenType]: type_descriptors = {} for directory in base_dirs: - base_dir = Path.cwd() / directory + base_dir = Path.cwd() / directory if not isinstance(directory, Path) else directory type_files = base_dir.rglob(f"*{_CONFIG_EXT}") for type_file in type_files: with open(type_file, "r") as f: diff --git a/port/cpp_nostl/messgen/concepts.h b/port/cpp_nostl/messgen/concepts.h index 079f423..ce536c2 100644 --- a/port/cpp_nostl/messgen/concepts.h +++ b/port/cpp_nostl/messgen/concepts.h @@ -9,21 +9,27 @@ namespace messgen { -template -concept serializable = requires(std::remove_cvref_t msg, uint8_t *buf, Allocator &allocator) { +template +concept serializable = requires(std::remove_cvref_t msg, uint8_t *buf, Allocator &allocator) { { msg.serialized_size() } -> std::same_as; { msg.serialize(buf) } -> std::same_as; { msg.deserialize(buf, allocator) } -> std::same_as; }; -template -concept type = serializable && requires(std::remove_cvref_t msg) { +template +concept type = serializable && requires(std::remove_cvref_t msg) { { msg.NAME } -> std::convertible_to; { msg.SCHEMA } -> std::convertible_to; { msg.IS_FLAT } -> std::convertible_to; }; -template -concept flat_type = type && std::remove_cvref_t::IS_FLAT; +template +concept flat_type = type && std::remove_cvref_t::IS_FLAT; + +template +concept message = type::data_type> && requires(std::remove_cvref_t msg) { + { msg.PROTO_ID } -> std::convertible_to; + { msg.MESSAGE_ID } -> std::convertible_to; +}; } // namespace messgen \ No newline at end of file diff --git a/port/cpp_nostl/messgen/messgen.h b/port/cpp_nostl/messgen/messgen.h index 5933327..4aa08e3 100644 --- a/port/cpp_nostl/messgen/messgen.h +++ b/port/cpp_nostl/messgen/messgen.h @@ -29,12 +29,20 @@ struct member { using member_type = std::remove_cvref_t; const char *name; +}; + +template +struct member_variable : member { + using member::name; M C::*ptr; }; +template +member_variable(const char *, M C::*) -> member_variable; + template requires std::same_as, std::remove_cvref_t> -[[nodiscard]] constexpr decltype(auto) value_of(S &&obj, const member &m) noexcept { +[[nodiscard]] constexpr decltype(auto) value_of(S &&obj, const member_variable &m) noexcept { return std::forward(obj).*m.ptr; } diff --git a/port/cpp_stl/messgen/concepts.h b/port/cpp_stl/messgen/concepts.h index 277c279..aa0c134 100644 --- a/port/cpp_stl/messgen/concepts.h +++ b/port/cpp_stl/messgen/concepts.h @@ -7,21 +7,27 @@ namespace messgen { -template -concept serializable = requires(std::remove_cvref_t msg, uint8_t *buf) { +template +concept serializable = requires(std::remove_cvref_t msg, uint8_t *buf) { { msg.serialized_size() } -> std::same_as; { msg.serialize(buf) } -> std::same_as; { msg.deserialize(buf) } -> std::same_as; }; -template -concept type = serializable && requires(std::remove_cvref_t msg) { +template +concept type = serializable && requires(std::remove_cvref_t msg) { { msg.NAME } -> std::convertible_to; { msg.SCHEMA } -> std::convertible_to; { msg.IS_FLAT } -> std::convertible_to; }; -template -concept flat_type = type && std::remove_cvref_t::IS_FLAT; +template +concept flat_type = type && std::remove_cvref_t::IS_FLAT; + +template +concept message = type::data_type> && requires(std::remove_cvref_t msg) { + { msg.PROTO_ID } -> std::convertible_to; + { msg.MESSAGE_ID } -> std::convertible_to; +}; } // namespace messgen \ No newline at end of file diff --git a/port/cpp_stl/messgen/messgen.h b/port/cpp_stl/messgen/messgen.h index d5382b1..7064225 100644 --- a/port/cpp_stl/messgen/messgen.h +++ b/port/cpp_stl/messgen/messgen.h @@ -32,13 +32,20 @@ struct member { using class_type = C; using member_type = std::remove_cvref_t; - std::string_view name; + const char *name; +}; + +template +struct member_variable : member { M C::*ptr; }; +template +member_variable(const char *, M C::*) -> member_variable; + template requires std::same_as, std::remove_cvref_t> -[[nodiscard]] constexpr decltype(auto) value_of(S &&obj, const member &m) noexcept { +[[nodiscard]] constexpr decltype(auto) value_of(S &&obj, const member_variable &m) noexcept { return std::forward(obj).*m.ptr; } diff --git a/tests/cpp/CppNostlTest.cpp b/tests/cpp/CppNostlTest.cpp index d8c4714..2009057 100644 --- a/tests/cpp/CppNostlTest.cpp +++ b/tests/cpp/CppNostlTest.cpp @@ -93,7 +93,7 @@ TEST_F(CppNostlTest, EmptyStruct) { test_serialization(e); } -TEST_F(CppNostlTest, MessageConcept) { +TEST_F(CppNostlTest, TypeConcept) { using namespace messgen; struct not_a_message {}; @@ -103,10 +103,10 @@ TEST_F(CppNostlTest, MessageConcept) { EXPECT_FALSE(type); } -TEST_F(CppNostlTest, FlatMessageConcept) { +TEST_F(CppNostlTest, FlatTypeConcept) { using namespace messgen; EXPECT_TRUE(flat_type); EXPECT_FALSE(flat_type); EXPECT_FALSE(flat_type); -} \ No newline at end of file +} diff --git a/tests/cpp/CppTest.cpp b/tests/cpp/CppTest.cpp index 8982924..7f99c13 100644 --- a/tests/cpp/CppTest.cpp +++ b/tests/cpp/CppTest.cpp @@ -49,15 +49,16 @@ class CppTest : public ::testing::Test { }; TEST_F(CppTest, SimpleStruct) { - messgen::test::simple_struct msg{}; - msg.f0 = 1; - msg.f1 = 2; - msg.f2 = 3; - msg.f3 = 4; - msg.f4 = 5; - msg.f5 = 6; - msg.f6 = 7; - msg.f8 = 9; + test_proto::simple_struct_msg msg{{ + .f0 = 1, + .f1 = 2, + .f2 = 3, + .f3 = 4, + .f4 = 5, + .f5 = 6, + .f6 = 7, + .f8 = 9, + }}; test_serialization(msg); } @@ -263,7 +264,7 @@ TEST_F(CppTest, DispatchMessage) { auto handler = [&](auto &&actual) { using ActualType = std::decay_t; - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { EXPECT_EQ(expected.f0, actual.f0); EXPECT_EQ(expected.f1, actual.f1); invoked = true; @@ -271,25 +272,37 @@ TEST_F(CppTest, DispatchMessage) { FAIL() << "Unexpected message type handled."; } }; - test_proto::dispatch_message(test_proto::TYPE_ID, _buf.data(), handler); + + test_proto::dispatch_message(test_proto::simple_struct_msg::MESSAGE_ID, _buf.data(), handler); EXPECT_TRUE(invoked); } -TEST_F(CppTest, MessageConcept) { +TEST_F(CppTest, TypeConcept) { using namespace messgen; struct not_a_message {}; EXPECT_TRUE(type); + EXPECT_TRUE(type); EXPECT_FALSE(type); EXPECT_FALSE(type); } -TEST_F(CppTest, FlatMessageConcept) { +TEST_F(CppTest, FlatTypeConcept) { using namespace messgen; EXPECT_TRUE(flat_type); EXPECT_FALSE(flat_type); EXPECT_FALSE(flat_type); +} + +TEST_F(CppTest, MessageConcept) { + using namespace messgen; + + struct not_a_message {}; + + EXPECT_FALSE(message); + EXPECT_FALSE(message); + EXPECT_TRUE(message); } \ No newline at end of file diff --git a/tests/data/protocols/nested/another_proto.yaml b/tests/data/protocols/nested/another_proto.yaml index fb2f079..87a82b6 100644 --- a/tests/data/protocols/nested/another_proto.yaml +++ b/tests/data/protocols/nested/another_proto.yaml @@ -1,3 +1,4 @@ proto_id: 2 -types_map: - 0: "cross_proto" +messages: + 0: {name: "cross_proto_msg", type: "cross_proto", comment: "Cross proto message"} + diff --git a/tests/data/protocols/test_proto.yaml b/tests/data/protocols/test_proto.yaml index 49394ea..ac17160 100644 --- a/tests/data/protocols/test_proto.yaml +++ b/tests/data/protocols/test_proto.yaml @@ -1,10 +1,10 @@ proto_id: 1 -types_map: - 0: "messgen/test/simple_struct" - 1: "messgen/test/complex_struct" - 2: "messgen/test/var_size_struct" - 3: "messgen/test/struct_with_enum" - 4: "messgen/test/empty_struct" - 5: "messgen/test/complex_struct_with_empty" - 6: "messgen/test/complex_struct_nostl" - 7: "messgen/test/flat_struct" +messages: + 0: {name: "simple_struct_msg", type: "messgen/test/simple_struct", comment: "Simple struct message"} + 1: {name: "complex_struct_msg", type: "messgen/test/complex_struct", comment: "Complex struct message"} + 2: {name: "var_size_struct_msg", type: "messgen/test/var_size_struct", comment: "Variable size struct message"} + 3: {name: "struct_with_enum_msg", type: "messgen/test/struct_with_enum", comment: "Struct with enum message"} + 4: {name: "empty_struct_msg", type: "messgen/test/empty_struct", comment: "Empty struct message"} + 5: {name: "complex_struct_with_empty_msg", type: "messgen/test/complex_struct_with_empty", comment: "Complex struct with empty message"} + 6: {name: "complex_struct_nostl_msg", type: "messgen/test/complex_struct_nostl", comment: "Complex struct without STL message"} + 7: {name: "flat_struct_msg", type: "messgen/test/flat_struct", comment: "Flat struct message"} \ No newline at end of file diff --git a/tests/python/generate_serialized_data.py b/tests/python/generate_serialized_data.py index a7bd62c..1cd75f5 100644 --- a/tests/python/generate_serialized_data.py +++ b/tests/python/generate_serialized_data.py @@ -13,7 +13,7 @@ codec.load(type_dirs=['tests/data/types'], protocols=["tests/data/protocols:test_proto", "tests/data/protocols:nested/another_proto"]) # simple_struct - t = codec.get_type_by_name("test_proto", "messgen/test/simple_struct") + t = codec.type_converter("messgen/test/simple_struct") msg1: dict[str, Any] = { "f0": 0x1234567890abcdef, "f1": 0x1234567890abcdef, @@ -33,7 +33,7 @@ print("Successfully generated serialized data to tests/data/serialized/bin/simple_struct.bin") # var_size_struct - t = codec.get_type_by_name("test_proto", "messgen/test/var_size_struct") + t = codec.type_converter("messgen/test/var_size_struct") msg1 = { "f0": 0x1234567890abcdef, "f1_vec": [-0x1234567890abcdef, 5, 1], @@ -48,7 +48,7 @@ print("Successfully generated serialized data to tests/data/serialized/bin/var_size_struct.bin") # struct_with_enum - t = codec.get_type_by_name("test_proto", "messgen/test/struct_with_enum") + t = codec.type_converter("messgen/test/struct_with_enum") msg1 = { "f0": 0x1234567890abcdef, "f1": 0x1234567890abcdef, @@ -61,7 +61,7 @@ print("Successfully generated serialized data to tests/data/serialized/bin/struct_with_enum.bin") # empty_struct - t = codec.get_type_by_name("test_proto", "messgen/test/empty_struct") + t = codec.type_converter("messgen/test/empty_struct") msg1 = {} b = t.serialize(msg1) with open('tests/data/serialized/bin/empty_struct.bin', 'wb') as f: @@ -69,7 +69,7 @@ print("Successfully generated serialized data to tests/data/serialized/bin/empty_struct.bin") # complex_struct_with_empty - t = codec.get_type_by_name("test_proto", "messgen/test/complex_struct_with_empty") + t = codec.type_converter("messgen/test/complex_struct_with_empty") msg1 = { "e": {}, # empty_struct "dynamic_array": [{} for _ in range(3)], # list of empty_struct, replace 3 with desired length @@ -89,7 +89,7 @@ # complex_struct_nostl - t = codec.get_type_by_name("test_proto", "messgen/test/complex_struct_nostl") + t = codec.type_converter("messgen/test/complex_struct_nostl") simple_struct = { "f0": 0x1234567890abcdef, "f1": 0x1234567890abcdef, @@ -129,7 +129,7 @@ # complex_struct - t = codec.get_type_by_name("test_proto", "messgen/test/complex_struct") + t = codec.type_converter("messgen/test/complex_struct") simple_struct = { "f0": 0x1234567890abcdef, @@ -174,7 +174,7 @@ # flat_struct - t = codec.get_type_by_name("test_proto", "messgen/test/flat_struct") + t = codec.type_converter("messgen/test/flat_struct") msg1 = { "f0": 0x1234567890abcdef, "f1": 0x1234567890abcdef, diff --git a/tests/python/test_invalid_types.py b/tests/python/test_invalid_types.py deleted file mode 100644 index 5ccca06..0000000 --- a/tests/python/test_invalid_types.py +++ /dev/null @@ -1,22 +0,0 @@ -import getpass -import os -import pytest - -from itertools import product -from pathlib import Path - -from messgen import ( - generator, - yaml_parser, -) - -_GENERATOR_LANGS = ["cpp"] -_OUTPUT_DIR = Path("/") / "tmp" / getpass.getuser() / "messgen_tests" -_TYPE_DIRS = list((Path() / "tests" / "data" / "types_invalid").glob("*")) - - -@pytest.mark.parametrize("lang, types_dir", product(["cpp"], _TYPE_DIRS)) -def test_yaml_parser_validation(lang, types_dir): - with pytest.raises(Exception): - types = yaml_parser.parse_types([types_dir]) - generator.get_generator(lang, {}).generate_types(types, _OUTPUT_DIR) diff --git a/tests/python/test_serialization.py b/tests/python/test_serialization.py index afcc395..4f29220 100644 --- a/tests/python/test_serialization.py +++ b/tests/python/test_serialization.py @@ -10,9 +10,9 @@ def codec(): yield codec_ -def test_serialization1(codec): - type_def = codec.get_type_by_name("test_proto", "messgen/test/simple_struct") - expected_msg = { +@pytest.fixture +def simple_struct(): + return { "f0": 0x1234567890abcdef, "f2": 1.2345678901234567890, "f3": 0x12345678, @@ -22,17 +22,21 @@ def test_serialization1(codec): "f8": -0x12, "f9": True, } + + +def test_serialization1(codec, simple_struct): + type_def = codec.type_converter("messgen/test/simple_struct") + expected_msg = simple_struct expected_bytes = type_def.serialize(expected_msg) assert expected_bytes - actual_msg, actual_size = type_def.deserialize(expected_bytes) - assert actual_size == len(expected_bytes) + actual_msg = type_def.deserialize(expected_bytes) for key in expected_msg: assert actual_msg[key] == pytest.approx(expected_msg[key]) def test_serialization2(codec): - type_def = codec.get_type_by_name("test_proto", "messgen/test/var_size_struct") + type_def = codec.type_converter("messgen/test/var_size_struct") expected_msg = { "f0": 0x1234567890abcdef, "f1_vec": [-0x1234567890abcdef, 5, 1], @@ -42,6 +46,29 @@ def test_serialization2(codec): expected_bytes = type_def.serialize(expected_msg) assert expected_bytes - actual_msg, actual_size = type_def.deserialize(expected_bytes) - assert actual_size == len(expected_bytes) + actual_msg = type_def.deserialize(expected_bytes) assert actual_msg == expected_msg + + +def test_protocol_deserialization(codec, simple_struct): + message_info_by_name = codec.message_info_by_name(proto_name="test_proto", message_name="simple_struct_msg") + expected_bytes = message_info_by_name.type_converter().serialize(simple_struct) + assert expected_bytes + + message_info_by_id = codec.message_info_by_id(proto_id=message_info_by_name.proto_id(), message_id=message_info_by_name.message_id()) + actual_msg = message_info_by_id.type_converter().deserialize(expected_bytes) + + assert message_info_by_name.proto_id() == 1 + assert message_info_by_name.message_id() == 0 + assert message_info_by_name.proto_name() == "test_proto" + assert message_info_by_name.message_name() == "simple_struct_msg" + assert message_info_by_name.type_name() == "messgen/test/simple_struct" + + assert message_info_by_name.proto_id() == message_info_by_id.proto_id() + assert message_info_by_name.message_id() == message_info_by_id.message_id() + assert message_info_by_name.proto_name() == message_info_by_id.proto_name() + assert message_info_by_name.message_name() == message_info_by_id.message_name() + assert message_info_by_name.type_name() == message_info_by_id.type_name() + + for key in simple_struct: + assert actual_msg[key] == pytest.approx(simple_struct[key]) diff --git a/tests/python/test_validation.py b/tests/python/test_validation.py new file mode 100644 index 0000000..9a8d654 --- /dev/null +++ b/tests/python/test_validation.py @@ -0,0 +1,92 @@ +import getpass +import os +import pytest + +from itertools import product +from pathlib import Path + +from messgen import ( + generator, + yaml_parser, + model, + validation, +) + + +_OUTPUT_DIR = Path("/") / "tmp" / getpass.getuser() / "messgen_tests" +_TYPE_DIRS = list((Path() / "tests" / "data" / "types_invalid").glob("*")) + + +@pytest.mark.parametrize("lang, types_dir", product(["cpp"], _TYPE_DIRS)) +def test_yaml_parser_validation(lang, types_dir): + with pytest.raises(Exception): + types = yaml_parser.parse_types([types_dir]) + generator.get_generator(lang, {}).generate_types(types, _OUTPUT_DIR) + + +def test_validate_protocol_correct(): + test_type = model.StructType(type="types/test", type_class=model.TypeClass.struct, comment="", fields=[], size=None) + proto = model.Protocol( + name="test", + proto_id=1, + messages={0: model.Message(message_id=0, name="some_msg", type=test_type.type, comment=""), + 1: model.Message(message_id=1, name="other_msg", type=test_type.type, comment="")}, + ) + + validation.validate_protocol(protocol=proto, types={test_type.type: test_type.type}) + + +def test_validate_protocol_id_mismatch(): + test_type = model.StructType(type="types/test", type_class=model.TypeClass.struct, comment="", fields=[], size=None) + proto = model.Protocol( + name="test", + proto_id=1, + messages={0: model.Message(message_id=0, name="some_msg", type=test_type.type, comment=""), + 1: model.Message(message_id=0, name="other_msg", type=test_type.type, comment="")}, + ) + + with pytest.raises(RuntimeError, match="Message other_msg has different message_id=0 than key=1 in protocol=test"): + validation.validate_protocol(protocol=proto, types={test_type.type: test_type.type}) + + +def test_validate_protocol_missing_type(): + test_type = model.StructType(type="types/test", type_class=model.TypeClass.struct, comment="", fields=[], size=None) + proto = model.Protocol( + name="test", + proto_id=1, + messages={0: model.Message(message_id=0, name="some_msg", type="types/missing", comment="")}, + ) + + with pytest.raises(RuntimeError, match="Type types/missing required by message=some_msg protocol=test not found"): + validation.validate_protocol(protocol=proto, types={test_type.type: test_type}) + + +def test_validate_protocol_duplicated_msg_name(): + test_type = model.StructType(type="types/test", type_class=model.TypeClass.struct, comment="", fields=[], size=None) + proto = model.Protocol( + name="test", + proto_id=1, + messages={0: model.Message(message_id=0, name="some_msg", type=test_type.type, comment=""), + 1: model.Message(message_id=1, name="some_msg", type=test_type.type, comment="")}, + ) + + with pytest.raises(RuntimeError, match="Message with name=some_msg appears multiple times in protocol=test"): + validation.validate_protocol(protocol=proto, types={test_type.type: test_type.type}) + + +def test_validate_types_no_conflict(): + type1 = model.StructType(type="types/test1", type_class=model.TypeClass.struct, comment="", fields=[], size=None) + type2 = model.StructType(type="types/test2", type_class=model.TypeClass.struct, comment="", fields=[], size=None) + + try: + validation.validate_types({type1.type: type1, type2.type: type2}) + except RuntimeError: + pytest.fail("validate_types raised RuntimeError unexpectedly!") + + +def test_validate_types_with_conflict(): + type1 = model.StructType(type="types/test1", type_class=model.TypeClass.struct, comment="", fields=[], size=None) + type2 = model.StructType(type="types/test2", type_class=model.TypeClass.struct, comment="", fields=[], size=None) + + with pytest.raises(RuntimeError, match="Type types/test2 has the same hash as types/test1"): + validation.validate_types({type1.type: type1, type2.type: type1})