max message size fix by gwen
This commit is contained in:
parent
518219a7a4
commit
3a45af8837
4 changed files with 98 additions and 34 deletions
|
@ -34,6 +34,16 @@ class BaseFormat(ABC):
|
|||
def get_min_size(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_max_size(self):
|
||||
pass
|
||||
|
||||
def get_size(self, calculate_max=False):
|
||||
if calculate_max:
|
||||
return self.get_max_size()
|
||||
else:
|
||||
return self.get_min_size()
|
||||
|
||||
@abstractmethod
|
||||
def get_c_parts(self) -> tuple[str, str]:
|
||||
pass
|
||||
|
@ -74,6 +84,9 @@ class SimpleFormat(BaseFormat):
|
|||
def get_min_size(self):
|
||||
return self.size
|
||||
|
||||
def get_max_size(self):
|
||||
return self.size
|
||||
|
||||
c_types = {
|
||||
"B": "uint8_t",
|
||||
"H": "uint16_t",
|
||||
|
@ -110,7 +123,7 @@ class EnumFormat(SimpleFormat):
|
|||
|
||||
def set_field_type(self, field_type):
|
||||
super().set_field_type(field_type)
|
||||
self.c_struct_name = normalize_name(field_type.__name__)+'_t'
|
||||
self.c_struct_name = normalize_name(field_type.__name__) + '_t'
|
||||
|
||||
def decode(self, data: bytes) -> tuple[Any, bytes]:
|
||||
value, out_data = super().decode(data)
|
||||
|
@ -134,7 +147,7 @@ class EnumFormat(SimpleFormat):
|
|||
options = []
|
||||
last_value = None
|
||||
for item in self.field_type:
|
||||
if last_value is not None and item.value != last_value+1:
|
||||
if last_value is not None and item.value != last_value + 1:
|
||||
options.append('')
|
||||
last_value = item.value
|
||||
options.append("%(prefix)s_%(name)s = %(value)s," % {
|
||||
|
@ -158,13 +171,13 @@ class TwoNibblesEnumFormat(SimpleFormat):
|
|||
def decode(self, data: bytes) -> tuple[bool, bytes]:
|
||||
fields = dataclass_fields(self.field_type)
|
||||
value, data = super().decode(data)
|
||||
return self.field_type(fields[0].type(value//2**4), fields[1].type(value//2**4)), data
|
||||
return self.field_type(fields[0].type(value // 2 ** 4), fields[1].type(value // 2 ** 4)), data
|
||||
|
||||
def encode(self, value):
|
||||
fields = dataclass_fields(self.field_type)
|
||||
return super().encode(
|
||||
getattr(value, fields[0].name).value * 2**4 +
|
||||
getattr(value, fields[1].name).value * 2**4
|
||||
getattr(value, fields[0].name).value * 2 ** 4 +
|
||||
getattr(value, fields[1].name).value * 2 ** 4
|
||||
)
|
||||
|
||||
def fromjson(self, data):
|
||||
|
@ -187,7 +200,7 @@ class ChipRevFormat(SimpleFormat):
|
|||
return (value // 100, value % 100), data
|
||||
|
||||
def encode(self, value):
|
||||
return value[0]*100 + value[1]
|
||||
return value[0] * 100 + value[1]
|
||||
|
||||
|
||||
class BoolFormat(SimpleFormat):
|
||||
|
@ -231,20 +244,24 @@ class FixedHexFormat(SimpleFormat):
|
|||
|
||||
@abstractmethod
|
||||
class BaseVarFormat(BaseFormat, ABC):
|
||||
def __init__(self, num_fmt='B'):
|
||||
def __init__(self, max_num, num_fmt='B'):
|
||||
self.num_fmt = num_fmt
|
||||
self.num_size = struct.calcsize(self.num_fmt)
|
||||
self.max_num = max_num
|
||||
|
||||
def get_min_size(self):
|
||||
return self.num_size
|
||||
|
||||
def get_max_size(self):
|
||||
return self.num_size + self.max_num * self.get_var_num()
|
||||
|
||||
def get_num_c_code(self):
|
||||
return SimpleFormat(self.num_fmt).get_c_code("num")
|
||||
|
||||
|
||||
class VarArrayFormat(BaseVarFormat):
|
||||
def __init__(self, child_type, num_fmt='B'):
|
||||
super().__init__(num_fmt=num_fmt)
|
||||
def __init__(self, child_type, max_num, num_fmt='B'):
|
||||
super().__init__(num_fmt=num_fmt, max_num=max_num)
|
||||
self.child_type = child_type
|
||||
self.child_size = self.child_type.get_min_size()
|
||||
|
||||
|
@ -253,13 +270,18 @@ class VarArrayFormat(BaseVarFormat):
|
|||
pass
|
||||
|
||||
def encode(self, values: Sequence) -> bytes:
|
||||
data = struct.pack(self.num_fmt, len(values))
|
||||
num = len(values)
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
data = struct.pack(self.num_fmt, num)
|
||||
for value in values:
|
||||
data += self.child_type.encode(value)
|
||||
return data
|
||||
|
||||
def decode(self, data: bytes) -> tuple[list[Any], bytes]:
|
||||
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
data = data[self.num_size:]
|
||||
result = []
|
||||
for i in range(num):
|
||||
|
@ -283,14 +305,22 @@ class VarArrayFormat(BaseVarFormat):
|
|||
|
||||
|
||||
class VarStrFormat(BaseVarFormat):
|
||||
def __init__(self, max_len):
|
||||
super().__init__(max_num=max_len)
|
||||
|
||||
def get_var_num(self):
|
||||
return 1
|
||||
|
||||
def encode(self, value: str) -> bytes:
|
||||
return struct.pack(self.num_fmt, len(value)) + value.encode()
|
||||
num = len(value)
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
return struct.pack(self.num_fmt, num) + value.encode()
|
||||
|
||||
def decode(self, data: bytes) -> tuple[str, bytes]:
|
||||
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
return data[self.num_size:self.num_size + num].rstrip(bytes((0,))).decode(), data[self.num_size + num:]
|
||||
|
||||
def get_c_parts(self):
|
||||
|
@ -298,14 +328,22 @@ class VarStrFormat(BaseVarFormat):
|
|||
|
||||
|
||||
class VarBytesFormat(BaseVarFormat):
|
||||
def __init__(self, max_size):
|
||||
super().__init__(max_num=max_size)
|
||||
|
||||
def get_var_num(self):
|
||||
return 1
|
||||
|
||||
def encode(self, value: bytes) -> bytes:
|
||||
return struct.pack(self.num_fmt, len(value)) + value
|
||||
num = len(value)
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
return struct.pack(self.num_fmt, num) + value
|
||||
|
||||
def decode(self, data: bytes) -> tuple[bytes, bytes]:
|
||||
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
return data[self.num_size:self.num_size + num].rstrip(bytes((0,))), data[self.num_size + num:]
|
||||
|
||||
def get_c_parts(self):
|
||||
|
@ -380,7 +418,7 @@ class StructType:
|
|||
else:
|
||||
raise TypeError('field %s.%s has no format and is no StructType' %
|
||||
(cls.__class__.__name__, name))
|
||||
cls.schema = create_model(cls.__name__+'Schema', **cls._pydantic_fields)
|
||||
cls.schema = create_model(cls.__name__ + 'Schema', **cls._pydantic_fields)
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@classmethod
|
||||
|
@ -722,6 +760,22 @@ class StructType:
|
|||
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
|
||||
return sum((f.metadata.get("format", f.type).get_min_size() for f in relevant_fields), start=0)
|
||||
|
||||
@classmethod
|
||||
def get_size(cls, no_inherited_fields=False, calculate_max=False) -> int:
|
||||
if cls.union_type_field:
|
||||
own_size = sum(
|
||||
[f.metadata.get("format", f.type).get_size(calculate_max=calculate_max) for f in dataclass_fields(cls)])
|
||||
union_size = max(
|
||||
[0] + [option.get_size(no_inherited_fields=True, calculate_max=calculate_max) for option in
|
||||
cls._union_options[cls.union_type_field].values()])
|
||||
return own_size + union_size
|
||||
if no_inherited_fields:
|
||||
relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
|
||||
else:
|
||||
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
|
||||
return sum((f.metadata.get("format", f.type).get_size(calculate_max=calculate_max) for f in relevant_fields),
|
||||
start=0)
|
||||
|
||||
|
||||
def normalize_name(name):
|
||||
if '_' in name:
|
||||
|
|
|
@ -14,8 +14,8 @@ class MacAddressFormat(FixedHexFormat):
|
|||
|
||||
|
||||
class MacAddressesListFormat(VarArrayFormat):
|
||||
def __init__(self):
|
||||
super().__init__(child_type=MacAddressFormat())
|
||||
def __init__(self, max_num):
|
||||
super().__init__(child_type=MacAddressFormat(), max_num=max_num)
|
||||
|
||||
|
||||
@unique
|
||||
|
|
|
@ -15,6 +15,7 @@ class Command(BaseCommand):
|
|||
nodata = set()
|
||||
struct_lines = {}
|
||||
struct_sizes = []
|
||||
struct_max_sizes = []
|
||||
done_definitions = set()
|
||||
|
||||
ignore_names = set(field_.name for field_ in fields(MeshMessage))
|
||||
|
@ -38,9 +39,11 @@ class Command(BaseCommand):
|
|||
|
||||
code = msg_class.get_c_code(name, ignore_fields=ignore_names, no_empty=True)
|
||||
if code:
|
||||
size = msg_class.get_min_size(no_inherited_fields=True)
|
||||
size = msg_class.get_size(no_inherited_fields=True, calculate_max=False)
|
||||
max_size = msg_class.get_size(no_inherited_fields=True, calculate_max=True)
|
||||
struct_lines[base_name] = "%s %s;" % (name, base_name.replace('_announce', ''))
|
||||
struct_sizes.append(size)
|
||||
struct_max_sizes.append(max_size)
|
||||
print(code)
|
||||
print("static_assert(sizeof(%s) == %d, \"size of generated message structs is calculated wrong\");" %
|
||||
(name, size))
|
||||
|
@ -58,6 +61,9 @@ class Command(BaseCommand):
|
|||
% max(struct_sizes)
|
||||
)
|
||||
|
||||
print()
|
||||
print('#define MESH_MSG_MAX_LENGTH (%d)' % max(struct_max_sizes))
|
||||
|
||||
print()
|
||||
|
||||
max_msg_type = max(MeshMessage.get_types().keys())
|
||||
|
@ -69,15 +75,17 @@ class Command(BaseCommand):
|
|||
getattr(msg_class.msg_type, 'name', msg_class.__name__)
|
||||
))
|
||||
macro_data.append((
|
||||
msg_class.get_c_enum_name()+',',
|
||||
("nodata" if msg_class in nodata else name)+',',
|
||||
msg_class.get_c_enum_name(),
|
||||
("nodata" if msg_class in nodata else name),
|
||||
msg_class.get_var_num(),
|
||||
msg_class.get_size(no_inherited_fields=True, calculate_max=True),
|
||||
msg_class.__doc__.strip(),
|
||||
))
|
||||
else:
|
||||
macro_data.append((
|
||||
"RESERVED_%02X," % i,
|
||||
"nodata,",
|
||||
"RESERVED_%02X" % i,
|
||||
"nodata",
|
||||
0,
|
||||
0,
|
||||
"",
|
||||
))
|
||||
|
@ -85,13 +93,15 @@ class Command(BaseCommand):
|
|||
max0 = max(len(d[0]) for d in macro_data)
|
||||
max1 = max(len(d[1]) for d in macro_data)
|
||||
max2 = max(len(str(d[2])) for d in macro_data)
|
||||
max3 = max(len(str(d[3])) for d in macro_data)
|
||||
lines = []
|
||||
for i, (macro_name, struct_name, num_len, comment) in enumerate(macro_data):
|
||||
for i, (macro_name, struct_name, num_len, max_len, comment) in enumerate(macro_data):
|
||||
lines.append(indent_c(
|
||||
"FN(%s %s %s) /** 0x%02X %s*/" % (
|
||||
macro_name.ljust(max0),
|
||||
struct_name.ljust(max1),
|
||||
str(num_len).rjust(max2),
|
||||
"FN(%s %s %s %s) /** 0x%02X %s*/" % (
|
||||
f'{macro_name},'.ljust(max0+1),
|
||||
f'{struct_name},'.ljust(max1+1),
|
||||
f'{num_len},'.rjust(max2+1),
|
||||
f'{max_len}'.rjust(max3),
|
||||
i,
|
||||
comment+(" " if comment else ""),
|
||||
)
|
||||
|
|
|
@ -136,7 +136,7 @@ class NoopMessage(MeshMessage, msg_type=MeshMessageType.NOOP):
|
|||
class BaseEchoMessage(MeshMessage, c_struct_name="echo"):
|
||||
""" repeat back string """
|
||||
content: str = field(default='', metadata={
|
||||
"format": VarStrFormat(),
|
||||
"format": VarStrFormat(max_len=255),
|
||||
"doc": "string to echo",
|
||||
"c_name": "str",
|
||||
})
|
||||
|
@ -173,7 +173,7 @@ class MeshLayerAnnounceMessage(MeshMessage, msg_type=MeshMessageType.MESH_LAYER_
|
|||
class BaseDestinationsMessage(MeshMessage, c_struct_name="destinations"):
|
||||
""" downstream node announces served/no longer served destination """
|
||||
addresses: list[str] = field(default_factory=list, metadata={
|
||||
"format": MacAddressesListFormat(),
|
||||
"format": MacAddressesListFormat(max_num=16),
|
||||
"doc": "adresses of the destinations",
|
||||
"c_name": "addresses",
|
||||
})
|
||||
|
@ -216,7 +216,7 @@ class MeshRouteTraceMessage(MeshMessage, msg_type=MeshMessageType.MESH_ROUTE_TRA
|
|||
""" special message, collects all hop adresses on its way """
|
||||
request_id: int = field(metadata={"format": SimpleFormat('I')})
|
||||
trace: list[str] = field(default_factory=list, metadata={
|
||||
"format": MacAddressesListFormat(),
|
||||
"format": MacAddressesListFormat(max_num=16),
|
||||
"doc": "addresses encountered by this message",
|
||||
})
|
||||
|
||||
|
@ -311,7 +311,7 @@ class OTAURLMessage(MeshMessage, msg_type=MeshMessageType.OTA_URL):
|
|||
""" supply download URL for OTA update and who to distribute it to """
|
||||
update_id: int = field(metadata={"format": SimpleFormat('I')})
|
||||
distribute_to: str = field(metadata={"format": MacAddressFormat()})
|
||||
url: str = field(metadata={"format": VarStrFormat()})
|
||||
url: str = field(metadata={"format": VarStrFormat(max_len=255)})
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -319,14 +319,14 @@ class OTAFragmentMessage(MeshMessage, msg_type=MeshMessageType.OTA_FRAGMENT):
|
|||
""" supply OTA fragment """
|
||||
update_id: int = field(metadata={"format": SimpleFormat('I')})
|
||||
chunk: int = field(metadata={"format": SimpleFormat('H')})
|
||||
data: str = field(metadata={"format": VarBytesFormat()})
|
||||
data: str = field(metadata={"format": VarBytesFormat(max_size=512)})
|
||||
|
||||
|
||||
@dataclass
|
||||
class OTARequestFragmentsMessage(MeshMessage, msg_type=MeshMessageType.OTA_REQUEST_FRAGMENTS):
|
||||
""" request missing fragments """
|
||||
update_id: int = field(metadata={"format": SimpleFormat('I')})
|
||||
chunks: list[int] = field(metadata={"format": VarArrayFormat(SimpleFormat('H'))})
|
||||
chunks: list[int] = field(metadata={"format": VarArrayFormat(SimpleFormat('H'), max_num=32)})
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -359,14 +359,14 @@ class LocateRequestRangeMessage(MeshMessage, msg_type=MeshMessageType.LOCATE_REQ
|
|||
@dataclass
|
||||
class LocateRangeResults(MeshMessage, msg_type=MeshMessageType.LOCATE_RANGE_RESULTS):
|
||||
""" reports distance to given nodes """
|
||||
ranges: list[RangeResultItem] = field(metadata={"format": VarArrayFormat(RangeResultItem)})
|
||||
ranges: list[RangeResultItem] = field(metadata={"format": VarArrayFormat(RangeResultItem, max_num=16)})
|
||||
|
||||
|
||||
@dataclass
|
||||
class LocateRawFTMResults(MeshMessage, msg_type=MeshMessageType.LOCATE_RAW_FTM_RESULTS):
|
||||
""" reports distance to given nodes """
|
||||
peer: str = field(metadata={"format": MacAddressFormat()})
|
||||
results: list[RawFTMEntry] = field(metadata={"format": VarArrayFormat(RawFTMEntry)})
|
||||
results: list[RawFTMEntry] = field(metadata={"format": VarArrayFormat(RawFTMEntry, max_num=16)})
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -378,4 +378,4 @@ class Reboot(MeshMessage, msg_type=MeshMessageType.REBOOT):
|
|||
@dataclass
|
||||
class ReportError(MeshMessage, msg_type=MeshMessageType.REPORT_ERROR):
|
||||
""" report a critical error to upstream """
|
||||
message: str = field(metadata={"format": VarStrFormat()})
|
||||
message: str = field(metadata={"format": VarStrFormat(max_len=255)})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue