merge with Laura's commits

This commit is contained in:
Gwendolyn 2023-10-06 02:19:57 +02:00
parent 6ed57d99d2
commit 85fb5d2a7e
2 changed files with 66 additions and 486 deletions

View file

@ -9,7 +9,6 @@ from c3nav.mesh.utils import indent_c
class BaseFormat(ABC):
def get_var_num(self):
return 0
@ -50,7 +49,9 @@ class SimpleFormat(BaseFormat):
self.num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1
def encode(self, value):
return struct.pack(self.fmt, (value,) if self.num == 1 else tuple(value))
if self.num == 1:
return struct.pack(self.fmt, value)
return struct.pack(self.fmt, *value)
def decode(self, data: bytes) -> tuple[Any, bytes]:
value = struct.unpack(self.fmt, data[:self.size])
@ -106,7 +107,7 @@ class FixedHexFormat(SimpleFormat):
super().__init__('%dB' % self.num)
def encode(self, value: str):
return super().encode(tuple(bytes.fromhex(value)))
return super().encode(tuple(bytes.fromhex(value.replace(':', ''))))
def decode(self, data: bytes) -> tuple[str, bytes]:
return self.sep.join(('%02x' % i) for i in data[:self.num]), data[self.num:]
@ -136,17 +137,19 @@ class VarArrayFormat(BaseVarFormat):
pass
def encode(self, values: Sequence) -> bytes:
data = struct.pack(self.num_fmt, (len(values),))
data = struct.pack(self.num_fmt, len(values))
for value in values:
data += self.child_type.encode(value)
return data
def decode(self, data: bytes) -> tuple[list[Any], bytes]:
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
return [
self.child_type.decode(data[i:i + self.child_size])
for i in range(self.num_size, self.num_size + num * self.child_size, self.child_size)
], data[self.num_size + num * self.child_size:]
data = data[self.num_size:]
result = []
for i in range(num):
item, data = self.child_type.decode(data)
result.append(item)
return result, data
def get_c_parts(self):
pre, post = self.child_type.get_c_parts()
@ -159,7 +162,7 @@ class VarStrFormat(BaseVarFormat):
return 1
def encode(self, value: str) -> bytes:
return struct.pack(self.num_fmt, (len(str),)) + value.encode()
return struct.pack(self.num_fmt, len(str)) + value.encode()
def decode(self, data: bytes) -> tuple[str, bytes]:
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
@ -210,23 +213,34 @@ class StructType:
return sum([f.metadata.get("format", f.type).get_var_num() for f in fields(cls)], start=0)
@classmethod
def encode(cls, instance) -> bytes:
def get_types(cls):
if not cls.union_type_field:
raise TypeError('Not a union class')
return cls._union_options[cls.union_type_field]
@classmethod
def get_type(cls, type_id) -> Self:
if not cls.union_type_field:
raise TypeError('Not a union class')
return cls.get_types()[type_id]
@classmethod
def encode(cls, instance, ignore_fields=()) -> bytes:
data = bytes()
if cls.union_type_field and type(instance) is not cls:
if not isinstance(instance, cls):
raise ValueError('expected value of type %r, got %r' % (cls, instance))
for field_ in fields(instance):
if field_.name is cls.union_type_field:
data += field_.metadata["format"].encode(getattr(instance, field_.name))
break
else:
raise TypeError('couldn\'t find %s value' % cls.union_type_field)
for field_ in fields(cls):
data += field_.metadata["format"].encode(getattr(instance, field_.name))
data += instance.encode(instance)
# todo: better
data += instance.encode(instance, ignore_fields=set(f.name for f in fields(cls)))
return data
for field_ in fields(cls):
if field_.name in ignore_fields:
continue
value = getattr(instance, field_.name)
if "format" in field_.metadata:
data += field_.metadata["format"].encode(value)
@ -241,31 +255,36 @@ class StructType:
return data
@classmethod
def decode(cls, data: bytes) -> Self:
values = {}
def decode(cls, data: bytes) -> tuple[Self, bytes]:
orig_data = data
kwargs = {}
no_init_data = {}
for field_ in fields(cls):
if "format" in field_.metadata:
data = field_.metadata["format"].decode(data)
value, data = field_.metadata["format"].decode(data)
elif issubclass(field_.type, StructType):
data = field_.type.decode(data)
value, data = field_.type.decode(data)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__name__, field_.name))
values[field_.name] = field_.metadata["format"].decode(data)
if field_.init:
kwargs[field_.name] = value
else:
no_init_data[field_.name] = value
if cls.union_type_field:
try:
type_value = values[cls.union_type_field]
type_value = no_init_data[cls.union_type_field]
except KeyError:
raise TypeError('union_type_field %s.%s is missing' %
(cls.__name__, cls.union_type_field))
try:
klass = cls._union_options[type_value]
klass = cls.get_type(type_value)
except KeyError:
raise TypeError('union_type_field %s.%s value %r no known' %
(cls.__name__, cls.union_type_field, type_value))
return klass.decode(data)
return cls(**values)
return klass.decode(orig_data)
return cls(**kwargs), data
@classmethod
def tojson(cls, instance) -> dict:
@ -277,7 +296,7 @@ class StructType:
for field_ in fields(instance):
if field_.name is cls.union_type_field:
result[field_.name] = field_.metadata["format"].encode(getattr(instance, field_.name))
result[field_.name] = field_.metadata["format"].tojson(getattr(instance, field_.name))
break
else:
raise TypeError('couldn\'t find %s value' % cls.union_type_field)
@ -300,32 +319,42 @@ class StructType:
return result
@classmethod
def fromjson(cls, data):
def upgrade_json(cls, data):
return data
@classmethod
def fromjson(cls, data: dict):
data = data.copy()
# todo: upgrade_json
cls.upgrade_json(data)
kwargs = {}
no_init_data = {}
for field_ in fields(cls):
raw_value = data.get(field_.name, None)
if "format" in field_.metadata:
data = field_.metadata["format"].decode(data)
value = field_.metadata["format"].fromjson(raw_value)
elif issubclass(field_.type, StructType):
data = field_.type.decode(data)
value = field_.type.fromjson(raw_value)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__name__, field_.name))
kwargs[field_.name], data = field_.metadata["format"].decode(data)
if field_.init:
kwargs[field_.name] = value
else:
no_init_data[field_.name] = value
if cls.union_type_field:
try:
type_value = kwargs[cls.union_type_field]
type_value = no_init_data.pop(cls.union_type_field)
except KeyError:
raise TypeError('union_type_field %s.%s is missing' %
(cls.__name__, cls.union_type_field))
try:
klass = cls._union_options[type_value]
klass = cls.get_type(type_value)
except KeyError:
raise TypeError('union_type_field %s.%s value %r no known' %
raise TypeError('union_type_field %s.%s value 0x%02x no known' %
(cls.__name__, cls.union_type_field, type_value))
return klass.fromjson(data)