max message size fix by gwen

This commit is contained in:
Laura Klünder 2023-11-25 18:56:36 +01:00
parent 518219a7a4
commit 3a45af8837
4 changed files with 98 additions and 34 deletions

View file

@ -34,6 +34,16 @@ class BaseFormat(ABC):
def get_min_size(self):
pass
@abstractmethod
def get_max_size(self):
pass
def get_size(self, calculate_max=False):
if calculate_max:
return self.get_max_size()
else:
return self.get_min_size()
@abstractmethod
def get_c_parts(self) -> tuple[str, str]:
pass
@ -74,6 +84,9 @@ class SimpleFormat(BaseFormat):
def get_min_size(self):
return self.size
def get_max_size(self):
return self.size
c_types = {
"B": "uint8_t",
"H": "uint16_t",
@ -110,7 +123,7 @@ class EnumFormat(SimpleFormat):
def set_field_type(self, field_type):
super().set_field_type(field_type)
self.c_struct_name = normalize_name(field_type.__name__)+'_t'
self.c_struct_name = normalize_name(field_type.__name__) + '_t'
def decode(self, data: bytes) -> tuple[Any, bytes]:
value, out_data = super().decode(data)
@ -134,7 +147,7 @@ class EnumFormat(SimpleFormat):
options = []
last_value = None
for item in self.field_type:
if last_value is not None and item.value != last_value+1:
if last_value is not None and item.value != last_value + 1:
options.append('')
last_value = item.value
options.append("%(prefix)s_%(name)s = %(value)s," % {
@ -158,13 +171,13 @@ class TwoNibblesEnumFormat(SimpleFormat):
def decode(self, data: bytes) -> tuple[bool, bytes]:
fields = dataclass_fields(self.field_type)
value, data = super().decode(data)
return self.field_type(fields[0].type(value//2**4), fields[1].type(value//2**4)), data
return self.field_type(fields[0].type(value // 2 ** 4), fields[1].type(value // 2 ** 4)), data
def encode(self, value):
fields = dataclass_fields(self.field_type)
return super().encode(
getattr(value, fields[0].name).value * 2**4 +
getattr(value, fields[1].name).value * 2**4
getattr(value, fields[0].name).value * 2 ** 4 +
getattr(value, fields[1].name).value * 2 ** 4
)
def fromjson(self, data):
@ -187,7 +200,7 @@ class ChipRevFormat(SimpleFormat):
return (value // 100, value % 100), data
def encode(self, value):
return value[0]*100 + value[1]
return value[0] * 100 + value[1]
class BoolFormat(SimpleFormat):
@ -231,20 +244,24 @@ class FixedHexFormat(SimpleFormat):
@abstractmethod
class BaseVarFormat(BaseFormat, ABC):
def __init__(self, num_fmt='B'):
def __init__(self, max_num, num_fmt='B'):
self.num_fmt = num_fmt
self.num_size = struct.calcsize(self.num_fmt)
self.max_num = max_num
def get_min_size(self):
return self.num_size
def get_max_size(self):
return self.num_size + self.max_num * self.get_var_num()
def get_num_c_code(self):
return SimpleFormat(self.num_fmt).get_c_code("num")
class VarArrayFormat(BaseVarFormat):
def __init__(self, child_type, num_fmt='B'):
super().__init__(num_fmt=num_fmt)
def __init__(self, child_type, max_num, num_fmt='B'):
super().__init__(num_fmt=num_fmt, max_num=max_num)
self.child_type = child_type
self.child_size = self.child_type.get_min_size()
@ -253,13 +270,18 @@ class VarArrayFormat(BaseVarFormat):
pass
def encode(self, values: Sequence) -> bytes:
data = struct.pack(self.num_fmt, len(values))
num = len(values)
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
data = struct.pack(self.num_fmt, num)
for value in values:
data += self.child_type.encode(value)
return data
def decode(self, data: bytes) -> tuple[list[Any], bytes]:
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
data = data[self.num_size:]
result = []
for i in range(num):
@ -283,14 +305,22 @@ class VarArrayFormat(BaseVarFormat):
class VarStrFormat(BaseVarFormat):
def __init__(self, max_len):
super().__init__(max_num=max_len)
def get_var_num(self):
return 1
def encode(self, value: str) -> bytes:
return struct.pack(self.num_fmt, len(value)) + value.encode()
num = len(value)
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
return struct.pack(self.num_fmt, num) + value.encode()
def decode(self, data: bytes) -> tuple[str, bytes]:
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
return data[self.num_size:self.num_size + num].rstrip(bytes((0,))).decode(), data[self.num_size + num:]
def get_c_parts(self):
@ -298,14 +328,22 @@ class VarStrFormat(BaseVarFormat):
class VarBytesFormat(BaseVarFormat):
def __init__(self, max_size):
super().__init__(max_num=max_size)
def get_var_num(self):
return 1
def encode(self, value: bytes) -> bytes:
return struct.pack(self.num_fmt, len(value)) + value
num = len(value)
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
return struct.pack(self.num_fmt, num) + value
def decode(self, data: bytes) -> tuple[bytes, bytes]:
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
return data[self.num_size:self.num_size + num].rstrip(bytes((0,))), data[self.num_size + num:]
def get_c_parts(self):
@ -380,7 +418,7 @@ class StructType:
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__class__.__name__, name))
cls.schema = create_model(cls.__name__+'Schema', **cls._pydantic_fields)
cls.schema = create_model(cls.__name__ + 'Schema', **cls._pydantic_fields)
super().__init_subclass__(**kwargs)
@classmethod
@ -722,6 +760,22 @@ class StructType:
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
return sum((f.metadata.get("format", f.type).get_min_size() for f in relevant_fields), start=0)
@classmethod
def get_size(cls, no_inherited_fields=False, calculate_max=False) -> int:
if cls.union_type_field:
own_size = sum(
[f.metadata.get("format", f.type).get_size(calculate_max=calculate_max) for f in dataclass_fields(cls)])
union_size = max(
[0] + [option.get_size(no_inherited_fields=True, calculate_max=calculate_max) for option in
cls._union_options[cls.union_type_field].values()])
return own_size + union_size
if no_inherited_fields:
relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
else:
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
return sum((f.metadata.get("format", f.type).get_size(calculate_max=calculate_max) for f in relevant_fields),
start=0)
def normalize_name(name):
if '_' in name:

View file

@ -14,8 +14,8 @@ class MacAddressFormat(FixedHexFormat):
class MacAddressesListFormat(VarArrayFormat):
def __init__(self):
super().__init__(child_type=MacAddressFormat())
def __init__(self, max_num):
super().__init__(child_type=MacAddressFormat(), max_num=max_num)
@unique

View file

@ -15,6 +15,7 @@ class Command(BaseCommand):
nodata = set()
struct_lines = {}
struct_sizes = []
struct_max_sizes = []
done_definitions = set()
ignore_names = set(field_.name for field_ in fields(MeshMessage))
@ -38,9 +39,11 @@ class Command(BaseCommand):
code = msg_class.get_c_code(name, ignore_fields=ignore_names, no_empty=True)
if code:
size = msg_class.get_min_size(no_inherited_fields=True)
size = msg_class.get_size(no_inherited_fields=True, calculate_max=False)
max_size = msg_class.get_size(no_inherited_fields=True, calculate_max=True)
struct_lines[base_name] = "%s %s;" % (name, base_name.replace('_announce', ''))
struct_sizes.append(size)
struct_max_sizes.append(max_size)
print(code)
print("static_assert(sizeof(%s) == %d, \"size of generated message structs is calculated wrong\");" %
(name, size))
@ -58,6 +61,9 @@ class Command(BaseCommand):
% max(struct_sizes)
)
print()
print('#define MESH_MSG_MAX_LENGTH (%d)' % max(struct_max_sizes))
print()
max_msg_type = max(MeshMessage.get_types().keys())
@ -69,15 +75,17 @@ class Command(BaseCommand):
getattr(msg_class.msg_type, 'name', msg_class.__name__)
))
macro_data.append((
msg_class.get_c_enum_name()+',',
("nodata" if msg_class in nodata else name)+',',
msg_class.get_c_enum_name(),
("nodata" if msg_class in nodata else name),
msg_class.get_var_num(),
msg_class.get_size(no_inherited_fields=True, calculate_max=True),
msg_class.__doc__.strip(),
))
else:
macro_data.append((
"RESERVED_%02X," % i,
"nodata,",
"RESERVED_%02X" % i,
"nodata",
0,
0,
"",
))
@ -85,13 +93,15 @@ class Command(BaseCommand):
max0 = max(len(d[0]) for d in macro_data)
max1 = max(len(d[1]) for d in macro_data)
max2 = max(len(str(d[2])) for d in macro_data)
max3 = max(len(str(d[3])) for d in macro_data)
lines = []
for i, (macro_name, struct_name, num_len, comment) in enumerate(macro_data):
for i, (macro_name, struct_name, num_len, max_len, comment) in enumerate(macro_data):
lines.append(indent_c(
"FN(%s %s %s) /** 0x%02X %s*/" % (
macro_name.ljust(max0),
struct_name.ljust(max1),
str(num_len).rjust(max2),
"FN(%s %s %s %s) /** 0x%02X %s*/" % (
f'{macro_name},'.ljust(max0+1),
f'{struct_name},'.ljust(max1+1),
f'{num_len},'.rjust(max2+1),
f'{max_len}'.rjust(max3),
i,
comment+(" " if comment else ""),
)

View file

@ -136,7 +136,7 @@ class NoopMessage(MeshMessage, msg_type=MeshMessageType.NOOP):
class BaseEchoMessage(MeshMessage, c_struct_name="echo"):
""" repeat back string """
content: str = field(default='', metadata={
"format": VarStrFormat(),
"format": VarStrFormat(max_len=255),
"doc": "string to echo",
"c_name": "str",
})
@ -173,7 +173,7 @@ class MeshLayerAnnounceMessage(MeshMessage, msg_type=MeshMessageType.MESH_LAYER_
class BaseDestinationsMessage(MeshMessage, c_struct_name="destinations"):
""" downstream node announces served/no longer served destination """
addresses: list[str] = field(default_factory=list, metadata={
"format": MacAddressesListFormat(),
"format": MacAddressesListFormat(max_num=16),
"doc": "adresses of the destinations",
"c_name": "addresses",
})
@ -216,7 +216,7 @@ class MeshRouteTraceMessage(MeshMessage, msg_type=MeshMessageType.MESH_ROUTE_TRA
""" special message, collects all hop adresses on its way """
request_id: int = field(metadata={"format": SimpleFormat('I')})
trace: list[str] = field(default_factory=list, metadata={
"format": MacAddressesListFormat(),
"format": MacAddressesListFormat(max_num=16),
"doc": "addresses encountered by this message",
})
@ -311,7 +311,7 @@ class OTAURLMessage(MeshMessage, msg_type=MeshMessageType.OTA_URL):
""" supply download URL for OTA update and who to distribute it to """
update_id: int = field(metadata={"format": SimpleFormat('I')})
distribute_to: str = field(metadata={"format": MacAddressFormat()})
url: str = field(metadata={"format": VarStrFormat()})
url: str = field(metadata={"format": VarStrFormat(max_len=255)})
@dataclass
@ -319,14 +319,14 @@ class OTAFragmentMessage(MeshMessage, msg_type=MeshMessageType.OTA_FRAGMENT):
""" supply OTA fragment """
update_id: int = field(metadata={"format": SimpleFormat('I')})
chunk: int = field(metadata={"format": SimpleFormat('H')})
data: str = field(metadata={"format": VarBytesFormat()})
data: str = field(metadata={"format": VarBytesFormat(max_size=512)})
@dataclass
class OTARequestFragmentsMessage(MeshMessage, msg_type=MeshMessageType.OTA_REQUEST_FRAGMENTS):
""" request missing fragments """
update_id: int = field(metadata={"format": SimpleFormat('I')})
chunks: list[int] = field(metadata={"format": VarArrayFormat(SimpleFormat('H'))})
chunks: list[int] = field(metadata={"format": VarArrayFormat(SimpleFormat('H'), max_num=32)})
@dataclass
@ -359,14 +359,14 @@ class LocateRequestRangeMessage(MeshMessage, msg_type=MeshMessageType.LOCATE_REQ
@dataclass
class LocateRangeResults(MeshMessage, msg_type=MeshMessageType.LOCATE_RANGE_RESULTS):
""" reports distance to given nodes """
ranges: list[RangeResultItem] = field(metadata={"format": VarArrayFormat(RangeResultItem)})
ranges: list[RangeResultItem] = field(metadata={"format": VarArrayFormat(RangeResultItem, max_num=16)})
@dataclass
class LocateRawFTMResults(MeshMessage, msg_type=MeshMessageType.LOCATE_RAW_FTM_RESULTS):
""" reports distance to given nodes """
peer: str = field(metadata={"format": MacAddressFormat()})
results: list[RawFTMEntry] = field(metadata={"format": VarArrayFormat(RawFTMEntry)})
results: list[RawFTMEntry] = field(metadata={"format": VarArrayFormat(RawFTMEntry, max_num=16)})
@dataclass
@ -378,4 +378,4 @@ class Reboot(MeshMessage, msg_type=MeshMessageType.REBOOT):
@dataclass
class ReportError(MeshMessage, msg_type=MeshMessageType.REPORT_ERROR):
""" report a critical error to upstream """
message: str = field(metadata={"format": VarStrFormat()})
message: str = field(metadata={"format": VarStrFormat(max_len=255)})