From 3a45af8837dea67e5de869a3cc54c11e55689fe9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Sat, 25 Nov 2023 18:56:36 +0100 Subject: [PATCH] max message size fix by gwen --- src/c3nav/mesh/baseformats.py | 80 ++++++++++++++++--- src/c3nav/mesh/dataformats.py | 4 +- .../management/commands/generate_c_types.py | 30 ++++--- src/c3nav/mesh/messages.py | 18 ++--- 4 files changed, 98 insertions(+), 34 deletions(-) diff --git a/src/c3nav/mesh/baseformats.py b/src/c3nav/mesh/baseformats.py index 879ca138..2482b338 100644 --- a/src/c3nav/mesh/baseformats.py +++ b/src/c3nav/mesh/baseformats.py @@ -34,6 +34,16 @@ class BaseFormat(ABC): def get_min_size(self): pass + @abstractmethod + def get_max_size(self): + pass + + def get_size(self, calculate_max=False): + if calculate_max: + return self.get_max_size() + else: + return self.get_min_size() + @abstractmethod def get_c_parts(self) -> tuple[str, str]: pass @@ -74,6 +84,9 @@ class SimpleFormat(BaseFormat): def get_min_size(self): return self.size + def get_max_size(self): + return self.size + c_types = { "B": "uint8_t", "H": "uint16_t", @@ -110,7 +123,7 @@ class EnumFormat(SimpleFormat): def set_field_type(self, field_type): super().set_field_type(field_type) - self.c_struct_name = normalize_name(field_type.__name__)+'_t' + self.c_struct_name = normalize_name(field_type.__name__) + '_t' def decode(self, data: bytes) -> tuple[Any, bytes]: value, out_data = super().decode(data) @@ -134,7 +147,7 @@ class EnumFormat(SimpleFormat): options = [] last_value = None for item in self.field_type: - if last_value is not None and item.value != last_value+1: + if last_value is not None and item.value != last_value + 1: options.append('') last_value = item.value options.append("%(prefix)s_%(name)s = %(value)s," % { @@ -158,13 +171,13 @@ class TwoNibblesEnumFormat(SimpleFormat): def decode(self, data: bytes) -> tuple[bool, bytes]: fields = dataclass_fields(self.field_type) value, data = super().decode(data) - return self.field_type(fields[0].type(value//2**4), fields[1].type(value//2**4)), data + return self.field_type(fields[0].type(value // 2 ** 4), fields[1].type(value // 2 ** 4)), data def encode(self, value): fields = dataclass_fields(self.field_type) return super().encode( - getattr(value, fields[0].name).value * 2**4 + - getattr(value, fields[1].name).value * 2**4 + getattr(value, fields[0].name).value * 2 ** 4 + + getattr(value, fields[1].name).value * 2 ** 4 ) def fromjson(self, data): @@ -187,7 +200,7 @@ class ChipRevFormat(SimpleFormat): return (value // 100, value % 100), data def encode(self, value): - return value[0]*100 + value[1] + return value[0] * 100 + value[1] class BoolFormat(SimpleFormat): @@ -231,20 +244,24 @@ class FixedHexFormat(SimpleFormat): @abstractmethod class BaseVarFormat(BaseFormat, ABC): - def __init__(self, num_fmt='B'): + def __init__(self, max_num, num_fmt='B'): self.num_fmt = num_fmt self.num_size = struct.calcsize(self.num_fmt) + self.max_num = max_num def get_min_size(self): return self.num_size + def get_max_size(self): + return self.num_size + self.max_num * self.get_var_num() + def get_num_c_code(self): return SimpleFormat(self.num_fmt).get_c_code("num") class VarArrayFormat(BaseVarFormat): - def __init__(self, child_type, num_fmt='B'): - super().__init__(num_fmt=num_fmt) + def __init__(self, child_type, max_num, num_fmt='B'): + super().__init__(num_fmt=num_fmt, max_num=max_num) self.child_type = child_type self.child_size = self.child_type.get_min_size() @@ -253,13 +270,18 @@ class VarArrayFormat(BaseVarFormat): pass def encode(self, values: Sequence) -> bytes: - data = struct.pack(self.num_fmt, len(values)) + num = len(values) + if num > self.max_num: + raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') + data = struct.pack(self.num_fmt, num) for value in values: data += self.child_type.encode(value) return data def decode(self, data: bytes) -> tuple[list[Any], bytes]: num = struct.unpack(self.num_fmt, data[:self.num_size])[0] + if num > self.max_num: + raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') data = data[self.num_size:] result = [] for i in range(num): @@ -283,14 +305,22 @@ class VarArrayFormat(BaseVarFormat): class VarStrFormat(BaseVarFormat): + def __init__(self, max_len): + super().__init__(max_num=max_len) + def get_var_num(self): return 1 def encode(self, value: str) -> bytes: - return struct.pack(self.num_fmt, len(value)) + value.encode() + num = len(value) + if num > self.max_num: + raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') + return struct.pack(self.num_fmt, num) + value.encode() def decode(self, data: bytes) -> tuple[str, bytes]: num = struct.unpack(self.num_fmt, data[:self.num_size])[0] + if num > self.max_num: + raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') return data[self.num_size:self.num_size + num].rstrip(bytes((0,))).decode(), data[self.num_size + num:] def get_c_parts(self): @@ -298,14 +328,22 @@ class VarStrFormat(BaseVarFormat): class VarBytesFormat(BaseVarFormat): + def __init__(self, max_size): + super().__init__(max_num=max_size) + def get_var_num(self): return 1 def encode(self, value: bytes) -> bytes: - return struct.pack(self.num_fmt, len(value)) + value + num = len(value) + if num > self.max_num: + raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') + return struct.pack(self.num_fmt, num) + value def decode(self, data: bytes) -> tuple[bytes, bytes]: num = struct.unpack(self.num_fmt, data[:self.num_size])[0] + if num > self.max_num: + raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') return data[self.num_size:self.num_size + num].rstrip(bytes((0,))), data[self.num_size + num:] def get_c_parts(self): @@ -380,7 +418,7 @@ class StructType: else: raise TypeError('field %s.%s has no format and is no StructType' % (cls.__class__.__name__, name)) - cls.schema = create_model(cls.__name__+'Schema', **cls._pydantic_fields) + cls.schema = create_model(cls.__name__ + 'Schema', **cls._pydantic_fields) super().__init_subclass__(**kwargs) @classmethod @@ -722,6 +760,22 @@ class StructType: relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")] return sum((f.metadata.get("format", f.type).get_min_size() for f in relevant_fields), start=0) + @classmethod + def get_size(cls, no_inherited_fields=False, calculate_max=False) -> int: + if cls.union_type_field: + own_size = sum( + [f.metadata.get("format", f.type).get_size(calculate_max=calculate_max) for f in dataclass_fields(cls)]) + union_size = max( + [0] + [option.get_size(no_inherited_fields=True, calculate_max=calculate_max) for option in + cls._union_options[cls.union_type_field].values()]) + return own_size + union_size + if no_inherited_fields: + relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls] + else: + relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")] + return sum((f.metadata.get("format", f.type).get_size(calculate_max=calculate_max) for f in relevant_fields), + start=0) + def normalize_name(name): if '_' in name: diff --git a/src/c3nav/mesh/dataformats.py b/src/c3nav/mesh/dataformats.py index f27b4ed1..7b525caa 100644 --- a/src/c3nav/mesh/dataformats.py +++ b/src/c3nav/mesh/dataformats.py @@ -14,8 +14,8 @@ class MacAddressFormat(FixedHexFormat): class MacAddressesListFormat(VarArrayFormat): - def __init__(self): - super().__init__(child_type=MacAddressFormat()) + def __init__(self, max_num): + super().__init__(child_type=MacAddressFormat(), max_num=max_num) @unique diff --git a/src/c3nav/mesh/management/commands/generate_c_types.py b/src/c3nav/mesh/management/commands/generate_c_types.py index 0b334b70..9d36b694 100644 --- a/src/c3nav/mesh/management/commands/generate_c_types.py +++ b/src/c3nav/mesh/management/commands/generate_c_types.py @@ -15,6 +15,7 @@ class Command(BaseCommand): nodata = set() struct_lines = {} struct_sizes = [] + struct_max_sizes = [] done_definitions = set() ignore_names = set(field_.name for field_ in fields(MeshMessage)) @@ -38,9 +39,11 @@ class Command(BaseCommand): code = msg_class.get_c_code(name, ignore_fields=ignore_names, no_empty=True) if code: - size = msg_class.get_min_size(no_inherited_fields=True) + size = msg_class.get_size(no_inherited_fields=True, calculate_max=False) + max_size = msg_class.get_size(no_inherited_fields=True, calculate_max=True) struct_lines[base_name] = "%s %s;" % (name, base_name.replace('_announce', '')) struct_sizes.append(size) + struct_max_sizes.append(max_size) print(code) print("static_assert(sizeof(%s) == %d, \"size of generated message structs is calculated wrong\");" % (name, size)) @@ -58,6 +61,9 @@ class Command(BaseCommand): % max(struct_sizes) ) + print() + print('#define MESH_MSG_MAX_LENGTH (%d)' % max(struct_max_sizes)) + print() max_msg_type = max(MeshMessage.get_types().keys()) @@ -69,15 +75,17 @@ class Command(BaseCommand): getattr(msg_class.msg_type, 'name', msg_class.__name__) )) macro_data.append(( - msg_class.get_c_enum_name()+',', - ("nodata" if msg_class in nodata else name)+',', + msg_class.get_c_enum_name(), + ("nodata" if msg_class in nodata else name), msg_class.get_var_num(), + msg_class.get_size(no_inherited_fields=True, calculate_max=True), msg_class.__doc__.strip(), )) else: macro_data.append(( - "RESERVED_%02X," % i, - "nodata,", + "RESERVED_%02X" % i, + "nodata", + 0, 0, "", )) @@ -85,13 +93,15 @@ class Command(BaseCommand): max0 = max(len(d[0]) for d in macro_data) max1 = max(len(d[1]) for d in macro_data) max2 = max(len(str(d[2])) for d in macro_data) + max3 = max(len(str(d[3])) for d in macro_data) lines = [] - for i, (macro_name, struct_name, num_len, comment) in enumerate(macro_data): + for i, (macro_name, struct_name, num_len, max_len, comment) in enumerate(macro_data): lines.append(indent_c( - "FN(%s %s %s) /** 0x%02X %s*/" % ( - macro_name.ljust(max0), - struct_name.ljust(max1), - str(num_len).rjust(max2), + "FN(%s %s %s %s) /** 0x%02X %s*/" % ( + f'{macro_name},'.ljust(max0+1), + f'{struct_name},'.ljust(max1+1), + f'{num_len},'.rjust(max2+1), + f'{max_len}'.rjust(max3), i, comment+(" " if comment else ""), ) diff --git a/src/c3nav/mesh/messages.py b/src/c3nav/mesh/messages.py index ab4a934c..4d1a73b0 100644 --- a/src/c3nav/mesh/messages.py +++ b/src/c3nav/mesh/messages.py @@ -136,7 +136,7 @@ class NoopMessage(MeshMessage, msg_type=MeshMessageType.NOOP): class BaseEchoMessage(MeshMessage, c_struct_name="echo"): """ repeat back string """ content: str = field(default='', metadata={ - "format": VarStrFormat(), + "format": VarStrFormat(max_len=255), "doc": "string to echo", "c_name": "str", }) @@ -173,7 +173,7 @@ class MeshLayerAnnounceMessage(MeshMessage, msg_type=MeshMessageType.MESH_LAYER_ class BaseDestinationsMessage(MeshMessage, c_struct_name="destinations"): """ downstream node announces served/no longer served destination """ addresses: list[str] = field(default_factory=list, metadata={ - "format": MacAddressesListFormat(), + "format": MacAddressesListFormat(max_num=16), "doc": "adresses of the destinations", "c_name": "addresses", }) @@ -216,7 +216,7 @@ class MeshRouteTraceMessage(MeshMessage, msg_type=MeshMessageType.MESH_ROUTE_TRA """ special message, collects all hop adresses on its way """ request_id: int = field(metadata={"format": SimpleFormat('I')}) trace: list[str] = field(default_factory=list, metadata={ - "format": MacAddressesListFormat(), + "format": MacAddressesListFormat(max_num=16), "doc": "addresses encountered by this message", }) @@ -311,7 +311,7 @@ class OTAURLMessage(MeshMessage, msg_type=MeshMessageType.OTA_URL): """ supply download URL for OTA update and who to distribute it to """ update_id: int = field(metadata={"format": SimpleFormat('I')}) distribute_to: str = field(metadata={"format": MacAddressFormat()}) - url: str = field(metadata={"format": VarStrFormat()}) + url: str = field(metadata={"format": VarStrFormat(max_len=255)}) @dataclass @@ -319,14 +319,14 @@ class OTAFragmentMessage(MeshMessage, msg_type=MeshMessageType.OTA_FRAGMENT): """ supply OTA fragment """ update_id: int = field(metadata={"format": SimpleFormat('I')}) chunk: int = field(metadata={"format": SimpleFormat('H')}) - data: str = field(metadata={"format": VarBytesFormat()}) + data: str = field(metadata={"format": VarBytesFormat(max_size=512)}) @dataclass class OTARequestFragmentsMessage(MeshMessage, msg_type=MeshMessageType.OTA_REQUEST_FRAGMENTS): """ request missing fragments """ update_id: int = field(metadata={"format": SimpleFormat('I')}) - chunks: list[int] = field(metadata={"format": VarArrayFormat(SimpleFormat('H'))}) + chunks: list[int] = field(metadata={"format": VarArrayFormat(SimpleFormat('H'), max_num=32)}) @dataclass @@ -359,14 +359,14 @@ class LocateRequestRangeMessage(MeshMessage, msg_type=MeshMessageType.LOCATE_REQ @dataclass class LocateRangeResults(MeshMessage, msg_type=MeshMessageType.LOCATE_RANGE_RESULTS): """ reports distance to given nodes """ - ranges: list[RangeResultItem] = field(metadata={"format": VarArrayFormat(RangeResultItem)}) + ranges: list[RangeResultItem] = field(metadata={"format": VarArrayFormat(RangeResultItem, max_num=16)}) @dataclass class LocateRawFTMResults(MeshMessage, msg_type=MeshMessageType.LOCATE_RAW_FTM_RESULTS): """ reports distance to given nodes """ peer: str = field(metadata={"format": MacAddressFormat()}) - results: list[RawFTMEntry] = field(metadata={"format": VarArrayFormat(RawFTMEntry)}) + results: list[RawFTMEntry] = field(metadata={"format": VarArrayFormat(RawFTMEntry, max_num=16)}) @dataclass @@ -378,4 +378,4 @@ class Reboot(MeshMessage, msg_type=MeshMessageType.REBOOT): @dataclass class ReportError(MeshMessage, msg_type=MeshMessageType.REPORT_ERROR): """ report a critical error to upstream """ - message: str = field(metadata={"format": VarStrFormat()}) + message: str = field(metadata={"format": VarStrFormat(max_len=255)})