introducing StructType.get_field_format

This commit is contained in:
Laura Klünder 2024-02-11 18:15:27 +01:00
parent 69c33690af
commit 54031df457

View file

@ -1,5 +1,6 @@
import re
import struct
import typing
from abc import ABC, abstractmethod
from dataclasses import Field, dataclass
from dataclasses import fields as dataclass_fields
@ -357,6 +358,24 @@ class StructType:
existing_c_struct = None
c_includes = set()
@classmethod
def get_field_format(cls, attr_name):
attr = getattr(cls, attr_name, None)
fields = [f for f in dataclass_fields(cls) if f.name == attr_name]
if not fields:
raise TypeError(f"{cls}.{attr_name} not a field")
field = fields[0]
type_ = typing.get_type_hints(cls)[attr_name]
if "format" in field.metadata:
field_format = field.metadata["format"]
field_format.set_field_type(type_)
return field_format
if issubclass(type_, StructType):
return type_
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__class__.__name__, attr_name))
# noinspection PyMethodOverriding
def __init_subclass__(cls, /, union_type_field=None, existing_c_struct=None, c_includes=None, **kwargs):
cls.union_type_field = union_type_field
@ -383,9 +402,6 @@ class StructType:
metadata = dict(attr.metadata)
if "defining_class" not in metadata:
metadata["defining_class"] = cls
if "format" in metadata:
metadata["format"].set_field_type(cls.__annotations__[attr_name])
attr.metadata = metadata
for key, values in cls._union_options.items():
@ -411,22 +427,24 @@ class StructType:
continue
fields.append((attr_name, type_, metadata))
for name, type_, metadata in fields:
if metadata.get("format", None):
try:
field_format = cls.get_field_format(name)
except TypeError:
# todo: in case of not a field, ignore it?
continue
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
cls._pydantic_fields[name] = (type_, ...)
elif issubclass(type_, StructType):
else:
if metadata.get("json_embed"):
cls._pydantic_fields.update(type_._pydantic_fields)
else:
cls._pydantic_fields[name] = (type_.schema, ...)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__class__.__name__, name))
cls.schema = create_model(cls.__name__ + 'Schema', **cls._pydantic_fields)
super().__init_subclass__(**kwargs)
@classmethod
def get_var_num(cls):
return sum([f.metadata.get("format", f.type).get_var_num() for f in dataclass_fields(cls)], start=0)
return sum([cls.get_field_format(f.name).get_var_num() for f in dataclass_fields(cls)], start=0)
@classmethod
def get_types(cls):
@ -448,7 +466,7 @@ class StructType:
raise ValueError('expected value of type %r, got %r' % (cls, instance))
for field_ in dataclass_fields(cls):
data += field_.metadata["format"].encode(getattr(instance, field_.name))
data += cls.get_field_format(field_.name).encode(getattr(instance, field_.name))
# todo: better
data += instance.encode(instance, ignore_fields=set(f.name for f in dataclass_fields(cls)))
@ -458,16 +476,14 @@ class StructType:
if field_.name in ignore_fields:
continue
value = getattr(instance, field_.name)
if "format" in field_.metadata:
data += field_.metadata["format"].encode(value)
elif issubclass(field_.type, StructType):
field_format = cls.get_field_format(field_.name)
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
data += field_format.encode(value)
else:
if not isinstance(value, field_.type):
raise ValueError('expected value of type %r for %s.%s, got %r' %
(field_.type, cls.__name__, field_.name, value))
data += value.encode(value)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__class__.__name__, field_.name))
return data
@classmethod
@ -476,13 +492,11 @@ class StructType:
kwargs = {}
no_init_data = {}
for field_ in dataclass_fields(cls):
if "format" in field_.metadata:
value, data = field_.metadata["format"].decode(data)
elif issubclass(field_.type, StructType):
value, data = field_.type.decode(data)
field_format = cls.get_field_format(field_.name)
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
value, data = field_format.decode(data)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__name__, field_.name))
value, data = field_.type.decode(data)
if field_.init:
kwargs[field_.name] = value
else:
@ -512,7 +526,7 @@ class StructType:
for field_ in dataclass_fields(instance):
if field_.name is cls.union_type_field:
result[field_.name] = field_.metadata["format"].tojson(getattr(instance, field_.name))
result[field_.name] = cls.get_field_format(field_.name).tojson(getattr(instance, field_.name))
break
else:
raise TypeError('couldn\'t find %s value' % cls.union_type_field)
@ -522,9 +536,10 @@ class StructType:
for field_ in dataclass_fields(cls):
value = getattr(instance, field_.name)
if "format" in field_.metadata:
result[field_.name] = field_.metadata["format"].tojson(value)
elif issubclass(field_.type, StructType):
field_format = cls.get_field_format(field_.name)
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
result[field_.name] = field_format.tojson(value)
else:
if not isinstance(value, field_.type):
raise ValueError('expected value of type %r for %s.%s, got %r' %
(field_.type, cls.__name__, field_.name, value))
@ -534,9 +549,6 @@ class StructType:
result[k] = v
else:
result[field_.name] = value.tojson(value)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__class__.__name__, field_.name))
return result
@classmethod
@ -554,16 +566,14 @@ class StructType:
no_init_data = {}
for field_ in dataclass_fields(cls):
raw_value = data.get(field_.name, None)
if "format" in field_.metadata:
value = field_.metadata["format"].fromjson(raw_value)
elif issubclass(field_.type, StructType):
field_format = cls.get_field_format(field_.name)
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
value = field_format.fromjson(raw_value)
else:
if field_.metadata.get("json_embed"):
value = field_.type.fromjson(data)
else:
value = field_.type.fromjson(raw_value)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__name__, field_.name))
if field_.init:
kwargs[field_.name] = value
else:
@ -597,20 +607,21 @@ class StructType:
continue
name = field_.metadata.get("c_name", field_.name)
if "format" in field_.metadata:
field_format = cls.get_field_format(field_.name)
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
if not field_.metadata.get("union_discriminator") or field_.metadata.get("defining_class") == cls:
items.append((
(
("%(typedef_name)s %(name)s;" % {
"typedef_name": field_.metadata["format"].get_typedef_name(),
"typedef_name": field_format.get_typedef_name(),
"name": name,
})
if field_.metadata.get("as_definition")
else field_.metadata["format"].get_c_code(name)
else field_format.get_c_code(name)
),
field_.metadata.get("doc", None),
)),
elif issubclass(field_.type, StructType):
else:
if field_.metadata.get("c_embed"):
embedded_items = field_.type.get_c_struct_items(ignore_fields, no_empty, top_level, union_only)
items.extend(embedded_items)
@ -625,10 +636,7 @@ class StructType:
else field_.type.get_c_code(name, typedef=False)
),
field_.metadata.get("doc", None),
)),
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__name__, field_.name))
))
if cls.union_type_field:
if not union_only:
@ -649,22 +657,20 @@ class StructType:
def get_c_definitions(cls) -> dict[str, str]:
definitions = {}
for field_ in dataclass_fields(cls):
if "format" in field_.metadata:
definitions.update(field_.metadata["format"].get_c_definitions())
field_format = cls.get_field_format(field_.name)
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
definitions.update(field_format.get_c_definitions())
if field_.metadata.get("as_definition"):
typedef_name = field_.metadata["format"].get_typedef_name()
typedef_name = field_format.get_typedef_name()
definitions[typedef_name] = 'typedef %(code)s %(name)s;' % {
"code": ''.join(field_.metadata["format"].get_c_parts()),
"code": ''.join(field_format.get_c_parts()),
"name": typedef_name,
}
elif issubclass(field_.type, StructType):
else:
definitions.update(field_.type.get_c_definitions())
if field_.metadata.get("as_definition"):
typedef_name = field_.type.get_typedef_name()
definitions[typedef_name] = field_.type.get_c_code(name=typedef_name, typedef=True)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__name__, field_.name))
if cls.union_type_field:
for key, option in cls._union_options[cls.union_type_field].items():
definitions.update(option.get_c_definitions())
@ -753,7 +759,7 @@ class StructType:
@classmethod
def get_min_size(cls, no_inherited_fields=False) -> int:
if cls.union_type_field:
own_size = sum([f.metadata.get("format", f.type).get_min_size() for f in dataclass_fields(cls)])
own_size = sum([cls.get_field_format(f.name).get_min_size() for f in dataclass_fields(cls)])
union_size = max(
[0] + [option.get_min_size(True) for option in cls._union_options[cls.union_type_field].values()])
return own_size + union_size
@ -761,13 +767,13 @@ class StructType:
relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
else:
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
return sum((f.metadata.get("format", f.type).get_min_size() for f in relevant_fields), start=0)
return sum((cls.get_field_format(f.name).get_min_size() for f in relevant_fields), start=0)
@classmethod
def get_size(cls, no_inherited_fields=False, calculate_max=False) -> int:
if cls.union_type_field:
own_size = sum(
[f.metadata.get("format", f.type).get_size(calculate_max=calculate_max) for f in dataclass_fields(cls)])
[cls.get_field_format(f.name).get_size(calculate_max=calculate_max) for f in dataclass_fields(cls)])
union_size = max(
[0] + [option.get_size(no_inherited_fields=True, calculate_max=calculate_max) for option in
cls._union_options[cls.union_type_field].values()])
@ -776,7 +782,7 @@ class StructType:
relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
else:
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
return sum((f.metadata.get("format", f.type).get_size(calculate_max=calculate_max) for f in relevant_fields),
return sum((cls.get_field_format(f.name).get_size(calculate_max=calculate_max) for f in relevant_fields),
start=0)