new parsing works

This commit is contained in:
Laura Klünder 2023-10-06 01:06:30 +02:00
parent 16f47168a2
commit da5ff59c96
8 changed files with 91 additions and 89 deletions

View file

@ -2,7 +2,7 @@ import re
import struct
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, fields
from enum import IntEnum, unique
from enum import IntEnum, unique, Enum
from typing import Self, Sequence, Any
from c3nav.mesh.utils import indent_c
@ -17,7 +17,7 @@ class BaseFormat(ABC):
@classmethod
@abstractmethod
def decode(cls, data) -> tuple[Any, bytes]:
def decode(cls, data: bytes) -> tuple[Any, bytes]:
pass
def fromjson(self, data):
@ -48,7 +48,9 @@ class SimpleFormat(BaseFormat):
self.num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1
def encode(self, value):
return struct.pack(self.fmt, (value, ) if self.num == 1 else tuple(value))
if self.num == 1:
return struct.pack(self.fmt, value)
return struct.pack(self.fmt, *value)
def decode(self, data: bytes) -> tuple[Any, bytes]:
value = struct.unpack(self.fmt, data[:self.size])
@ -104,7 +106,7 @@ class FixedHexFormat(SimpleFormat):
super().__init__('%dB' % self.num)
def encode(self, value: str):
return super().encode(tuple(bytes.fromhex(value)))
return super().encode(tuple(bytes.fromhex(value.replace(':', ''))))
def decode(self, data: bytes) -> tuple[str, bytes]:
return self.sep.join(('%02x' % i) for i in data[:self.num]), data[self.num:]
@ -137,10 +139,12 @@ class VarArrayFormat(BaseVarFormat):
def decode(self, data: bytes) -> tuple[list[Any], bytes]:
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
return [
self.child_type.decode(data[i:i+self.child_size])
for i in range(self.num_size, self.num_size+num*self.child_size, self.child_size)
], data[self.num_size+num*self.child_size:]
data = data[self.num_size:]
result = []
for i in range(num):
item, data = self.child_type.decode(data)
result.append(item)
return result, data
def get_c_parts(self):
pre, post = self.child_type.get_c_parts()
@ -192,23 +196,34 @@ class StructType:
super().__init_subclass__(**kwargs)
@classmethod
def encode(cls, instance) -> bytes:
def get_types(cls):
if not cls.union_type_field:
raise TypeError('Not a union class')
return cls._union_options[cls.union_type_field]
@classmethod
def get_type(cls, type_id) -> Self:
if not cls.union_type_field:
raise TypeError('Not a union class')
return cls.get_types()[type_id]
@classmethod
def encode(cls, instance, ignore_fields=()) -> bytes:
data = bytes()
if cls.union_type_field and type(instance) is not cls:
if not isinstance(instance, cls):
raise ValueError('expected value of type %r, got %r' % (cls, instance))
for field_ in fields(instance):
if field_.name is cls.union_type_field:
data += field_.metadata["format"].encode(getattr(instance, field_.name))
break
else:
raise TypeError('couldn\'t find %s value' % cls.union_type_field)
for field_ in fields(cls):
data += field_.metadata["format"].encode(getattr(instance, field_.name))
data += instance.encode(instance)
# todo: better
data += instance.encode(instance, ignore_fields=set(f.name for f in fields(cls)))
return data
for field_ in fields(cls):
if field_.name in ignore_fields:
continue
value = getattr(instance, field_.name)
if "format" in field_.metadata:
data += field_.metadata["format"].encode(value)
@ -224,30 +239,35 @@ class StructType:
@classmethod
def decode(cls, data: bytes) -> Self:
values = {}
orig_data = data
kwargs = {}
no_init_data = {}
for field_ in fields(cls):
if "format" in field_.metadata:
data = field_.metadata["format"].decode(data)
value, data = field_.metadata["format"].decode(data)
elif issubclass(field_.type, StructType):
data = field_.type.decode(data)
value, data = field_.type.decode(data)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__name__, field_.name))
values[field_.name] = field_.metadata["format"].decode(data)
if field_.init:
kwargs[field_.name] = value
else:
no_init_data[field_.name] = value
if cls.union_type_field:
try:
type_value = values[cls.union_type_field]
type_value = no_init_data[cls.union_type_field]
except KeyError:
raise TypeError('union_type_field %s.%s is missing' %
(cls.__name__, cls.union_type_field))
try:
klass = cls._union_options[type_value]
klass = cls.get_type(type_value)
except KeyError:
raise TypeError('union_type_field %s.%s value %r no known' %
(cls.__name__, cls.union_type_field, type_value))
return klass.decode(data)
return cls(**values)
return klass.decode(orig_data)
return cls(**kwargs), data
@classmethod
def tojson(cls, instance) -> dict:
@ -259,7 +279,7 @@ class StructType:
for field_ in fields(instance):
if field_.name is cls.union_type_field:
result[field_.name] = field_.metadata["format"].encode(getattr(instance, field_.name))
result[field_.name] = field_.metadata["format"].tojson(getattr(instance, field_.name))
break
else:
raise TypeError('couldn\'t find %s value' % cls.union_type_field)
@ -282,32 +302,42 @@ class StructType:
return result
@classmethod
def fromjson(cls, data):
def upgrade_json(cls, data):
return data
@classmethod
def fromjson(cls, data: dict):
data = data.copy()
# todo: upgrade_json
cls.upgrade_json(data)
kwargs = {}
no_init_data = {}
for field_ in fields(cls):
raw_value = data.get(field_.name, None)
if "format" in field_.metadata:
data = field_.metadata["format"].decode(data)
value = field_.metadata["format"].fromjson(raw_value)
elif issubclass(field_.type, StructType):
data = field_.type.decode(data)
value = field_.type.fromjson(raw_value)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__name__, field_.name))
kwargs[field_.name], data = field_.metadata["format"].decode(data)
if field_.init:
kwargs[field_.name] = value
else:
no_init_data[field_.name] = value
if cls.union_type_field:
try:
type_value = kwargs[cls.union_type_field]
type_value = no_init_data.pop(cls.union_type_field)
except KeyError:
raise TypeError('union_type_field %s.%s is missing' %
(cls.__name__, cls.union_type_field))
try:
klass = cls._union_options[type_value]
klass = cls.get_type(type_value)
except KeyError:
raise TypeError('union_type_field %s.%s value %r no known' %
raise TypeError('union_type_field %s.%s value 0x%02x no known' %
(cls.__name__, cls.union_type_field, type_value))
return klass.fromjson(data)
@ -342,7 +372,7 @@ class StructType:
if cls.union_type_field:
parent_fields = set(field_.name for field_ in fields(cls))
union_items = []
for key, option in cls._union_options[cls.union_type_field].items():
for key, option in cls.get_types().items():
base_name = normalize_name(getattr(key, 'name', option.__name__))
if union_member_as_types:
struct_name = cls.get_struct_name(base_name)
@ -360,7 +390,7 @@ class StructType:
)
union_items.append(
"uint8_t bytes[%s];" % max(
(option.get_min_size() for option in cls._union_options[cls.union_type_field].values()),
(option.get_min_size() for option in cls.get_types().values()),
default=0,
)
)
@ -417,7 +447,7 @@ class StructType:
if cls.union_type_field:
return (
{f.name: field for f in fields()}[cls.union_type_field].metadata["format"].get_min_size() +
sum((option.get_min_size() for option in cls._union_options[cls.union_type_field].values()), start=0)
sum((option.get_min_size() for option in cls.get_types().values()), start=0)
)
return sum((f.metadata.get("format", f.type).get_min_size() for f in fields(cls)), start=0)