introducing StructType.get_field_format

This commit is contained in:
Laura Klünder 2024-02-11 18:15:27 +01:00
parent 69c33690af
commit 54031df457

View file

@ -1,5 +1,6 @@
import re import re
import struct import struct
import typing
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import Field, dataclass from dataclasses import Field, dataclass
from dataclasses import fields as dataclass_fields from dataclasses import fields as dataclass_fields
@ -357,6 +358,24 @@ class StructType:
existing_c_struct = None existing_c_struct = None
c_includes = set() c_includes = set()
@classmethod
def get_field_format(cls, attr_name):
attr = getattr(cls, attr_name, None)
fields = [f for f in dataclass_fields(cls) if f.name == attr_name]
if not fields:
raise TypeError(f"{cls}.{attr_name} not a field")
field = fields[0]
type_ = typing.get_type_hints(cls)[attr_name]
if "format" in field.metadata:
field_format = field.metadata["format"]
field_format.set_field_type(type_)
return field_format
if issubclass(type_, StructType):
return type_
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__class__.__name__, attr_name))
# noinspection PyMethodOverriding # noinspection PyMethodOverriding
def __init_subclass__(cls, /, union_type_field=None, existing_c_struct=None, c_includes=None, **kwargs): def __init_subclass__(cls, /, union_type_field=None, existing_c_struct=None, c_includes=None, **kwargs):
cls.union_type_field = union_type_field cls.union_type_field = union_type_field
@ -383,9 +402,6 @@ class StructType:
metadata = dict(attr.metadata) metadata = dict(attr.metadata)
if "defining_class" not in metadata: if "defining_class" not in metadata:
metadata["defining_class"] = cls metadata["defining_class"] = cls
if "format" in metadata:
metadata["format"].set_field_type(cls.__annotations__[attr_name])
attr.metadata = metadata attr.metadata = metadata
for key, values in cls._union_options.items(): for key, values in cls._union_options.items():
@ -411,22 +427,24 @@ class StructType:
continue continue
fields.append((attr_name, type_, metadata)) fields.append((attr_name, type_, metadata))
for name, type_, metadata in fields: for name, type_, metadata in fields:
if metadata.get("format", None): try:
field_format = cls.get_field_format(name)
except TypeError:
# todo: in case of not a field, ignore it?
continue
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
cls._pydantic_fields[name] = (type_, ...) cls._pydantic_fields[name] = (type_, ...)
elif issubclass(type_, StructType): else:
if metadata.get("json_embed"): if metadata.get("json_embed"):
cls._pydantic_fields.update(type_._pydantic_fields) cls._pydantic_fields.update(type_._pydantic_fields)
else: else:
cls._pydantic_fields[name] = (type_.schema, ...) cls._pydantic_fields[name] = (type_.schema, ...)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__class__.__name__, name))
cls.schema = create_model(cls.__name__ + 'Schema', **cls._pydantic_fields) cls.schema = create_model(cls.__name__ + 'Schema', **cls._pydantic_fields)
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
@classmethod @classmethod
def get_var_num(cls): def get_var_num(cls):
return sum([f.metadata.get("format", f.type).get_var_num() for f in dataclass_fields(cls)], start=0) return sum([cls.get_field_format(f.name).get_var_num() for f in dataclass_fields(cls)], start=0)
@classmethod @classmethod
def get_types(cls): def get_types(cls):
@ -448,7 +466,7 @@ class StructType:
raise ValueError('expected value of type %r, got %r' % (cls, instance)) raise ValueError('expected value of type %r, got %r' % (cls, instance))
for field_ in dataclass_fields(cls): for field_ in dataclass_fields(cls):
data += field_.metadata["format"].encode(getattr(instance, field_.name)) data += cls.get_field_format(field_.name).encode(getattr(instance, field_.name))
# todo: better # todo: better
data += instance.encode(instance, ignore_fields=set(f.name for f in dataclass_fields(cls))) data += instance.encode(instance, ignore_fields=set(f.name for f in dataclass_fields(cls)))
@ -458,16 +476,14 @@ class StructType:
if field_.name in ignore_fields: if field_.name in ignore_fields:
continue continue
value = getattr(instance, field_.name) value = getattr(instance, field_.name)
if "format" in field_.metadata: field_format = cls.get_field_format(field_.name)
data += field_.metadata["format"].encode(value) if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
elif issubclass(field_.type, StructType): data += field_format.encode(value)
else:
if not isinstance(value, field_.type): if not isinstance(value, field_.type):
raise ValueError('expected value of type %r for %s.%s, got %r' % raise ValueError('expected value of type %r for %s.%s, got %r' %
(field_.type, cls.__name__, field_.name, value)) (field_.type, cls.__name__, field_.name, value))
data += value.encode(value) data += value.encode(value)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__class__.__name__, field_.name))
return data return data
@classmethod @classmethod
@ -476,13 +492,11 @@ class StructType:
kwargs = {} kwargs = {}
no_init_data = {} no_init_data = {}
for field_ in dataclass_fields(cls): for field_ in dataclass_fields(cls):
if "format" in field_.metadata: field_format = cls.get_field_format(field_.name)
value, data = field_.metadata["format"].decode(data) if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
elif issubclass(field_.type, StructType): value, data = field_format.decode(data)
value, data = field_.type.decode(data)
else: else:
raise TypeError('field %s.%s has no format and is no StructType' % value, data = field_.type.decode(data)
(cls.__name__, field_.name))
if field_.init: if field_.init:
kwargs[field_.name] = value kwargs[field_.name] = value
else: else:
@ -512,7 +526,7 @@ class StructType:
for field_ in dataclass_fields(instance): for field_ in dataclass_fields(instance):
if field_.name is cls.union_type_field: if field_.name is cls.union_type_field:
result[field_.name] = field_.metadata["format"].tojson(getattr(instance, field_.name)) result[field_.name] = cls.get_field_format(field_.name).tojson(getattr(instance, field_.name))
break break
else: else:
raise TypeError('couldn\'t find %s value' % cls.union_type_field) raise TypeError('couldn\'t find %s value' % cls.union_type_field)
@ -522,9 +536,10 @@ class StructType:
for field_ in dataclass_fields(cls): for field_ in dataclass_fields(cls):
value = getattr(instance, field_.name) value = getattr(instance, field_.name)
if "format" in field_.metadata: field_format = cls.get_field_format(field_.name)
result[field_.name] = field_.metadata["format"].tojson(value) if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
elif issubclass(field_.type, StructType): result[field_.name] = field_format.tojson(value)
else:
if not isinstance(value, field_.type): if not isinstance(value, field_.type):
raise ValueError('expected value of type %r for %s.%s, got %r' % raise ValueError('expected value of type %r for %s.%s, got %r' %
(field_.type, cls.__name__, field_.name, value)) (field_.type, cls.__name__, field_.name, value))
@ -534,9 +549,6 @@ class StructType:
result[k] = v result[k] = v
else: else:
result[field_.name] = value.tojson(value) result[field_.name] = value.tojson(value)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__class__.__name__, field_.name))
return result return result
@classmethod @classmethod
@ -554,16 +566,14 @@ class StructType:
no_init_data = {} no_init_data = {}
for field_ in dataclass_fields(cls): for field_ in dataclass_fields(cls):
raw_value = data.get(field_.name, None) raw_value = data.get(field_.name, None)
if "format" in field_.metadata: field_format = cls.get_field_format(field_.name)
value = field_.metadata["format"].fromjson(raw_value) if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
elif issubclass(field_.type, StructType): value = field_format.fromjson(raw_value)
else:
if field_.metadata.get("json_embed"): if field_.metadata.get("json_embed"):
value = field_.type.fromjson(data) value = field_.type.fromjson(data)
else: else:
value = field_.type.fromjson(raw_value) value = field_.type.fromjson(raw_value)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__name__, field_.name))
if field_.init: if field_.init:
kwargs[field_.name] = value kwargs[field_.name] = value
else: else:
@ -597,20 +607,21 @@ class StructType:
continue continue
name = field_.metadata.get("c_name", field_.name) name = field_.metadata.get("c_name", field_.name)
if "format" in field_.metadata: field_format = cls.get_field_format(field_.name)
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
if not field_.metadata.get("union_discriminator") or field_.metadata.get("defining_class") == cls: if not field_.metadata.get("union_discriminator") or field_.metadata.get("defining_class") == cls:
items.append(( items.append((
( (
("%(typedef_name)s %(name)s;" % { ("%(typedef_name)s %(name)s;" % {
"typedef_name": field_.metadata["format"].get_typedef_name(), "typedef_name": field_format.get_typedef_name(),
"name": name, "name": name,
}) })
if field_.metadata.get("as_definition") if field_.metadata.get("as_definition")
else field_.metadata["format"].get_c_code(name) else field_format.get_c_code(name)
), ),
field_.metadata.get("doc", None), field_.metadata.get("doc", None),
)), )),
elif issubclass(field_.type, StructType): else:
if field_.metadata.get("c_embed"): if field_.metadata.get("c_embed"):
embedded_items = field_.type.get_c_struct_items(ignore_fields, no_empty, top_level, union_only) embedded_items = field_.type.get_c_struct_items(ignore_fields, no_empty, top_level, union_only)
items.extend(embedded_items) items.extend(embedded_items)
@ -625,10 +636,7 @@ class StructType:
else field_.type.get_c_code(name, typedef=False) else field_.type.get_c_code(name, typedef=False)
), ),
field_.metadata.get("doc", None), field_.metadata.get("doc", None),
)), ))
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__name__, field_.name))
if cls.union_type_field: if cls.union_type_field:
if not union_only: if not union_only:
@ -649,22 +657,20 @@ class StructType:
def get_c_definitions(cls) -> dict[str, str]: def get_c_definitions(cls) -> dict[str, str]:
definitions = {} definitions = {}
for field_ in dataclass_fields(cls): for field_ in dataclass_fields(cls):
if "format" in field_.metadata: field_format = cls.get_field_format(field_.name)
definitions.update(field_.metadata["format"].get_c_definitions()) if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
definitions.update(field_format.get_c_definitions())
if field_.metadata.get("as_definition"): if field_.metadata.get("as_definition"):
typedef_name = field_.metadata["format"].get_typedef_name() typedef_name = field_format.get_typedef_name()
definitions[typedef_name] = 'typedef %(code)s %(name)s;' % { definitions[typedef_name] = 'typedef %(code)s %(name)s;' % {
"code": ''.join(field_.metadata["format"].get_c_parts()), "code": ''.join(field_format.get_c_parts()),
"name": typedef_name, "name": typedef_name,
} }
elif issubclass(field_.type, StructType): else:
definitions.update(field_.type.get_c_definitions()) definitions.update(field_.type.get_c_definitions())
if field_.metadata.get("as_definition"): if field_.metadata.get("as_definition"):
typedef_name = field_.type.get_typedef_name() typedef_name = field_.type.get_typedef_name()
definitions[typedef_name] = field_.type.get_c_code(name=typedef_name, typedef=True) definitions[typedef_name] = field_.type.get_c_code(name=typedef_name, typedef=True)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__name__, field_.name))
if cls.union_type_field: if cls.union_type_field:
for key, option in cls._union_options[cls.union_type_field].items(): for key, option in cls._union_options[cls.union_type_field].items():
definitions.update(option.get_c_definitions()) definitions.update(option.get_c_definitions())
@ -753,7 +759,7 @@ class StructType:
@classmethod @classmethod
def get_min_size(cls, no_inherited_fields=False) -> int: def get_min_size(cls, no_inherited_fields=False) -> int:
if cls.union_type_field: if cls.union_type_field:
own_size = sum([f.metadata.get("format", f.type).get_min_size() for f in dataclass_fields(cls)]) own_size = sum([cls.get_field_format(f.name).get_min_size() for f in dataclass_fields(cls)])
union_size = max( union_size = max(
[0] + [option.get_min_size(True) for option in cls._union_options[cls.union_type_field].values()]) [0] + [option.get_min_size(True) for option in cls._union_options[cls.union_type_field].values()])
return own_size + union_size return own_size + union_size
@ -761,13 +767,13 @@ class StructType:
relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls] relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
else: else:
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")] relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
return sum((f.metadata.get("format", f.type).get_min_size() for f in relevant_fields), start=0) return sum((cls.get_field_format(f.name).get_min_size() for f in relevant_fields), start=0)
@classmethod @classmethod
def get_size(cls, no_inherited_fields=False, calculate_max=False) -> int: def get_size(cls, no_inherited_fields=False, calculate_max=False) -> int:
if cls.union_type_field: if cls.union_type_field:
own_size = sum( own_size = sum(
[f.metadata.get("format", f.type).get_size(calculate_max=calculate_max) for f in dataclass_fields(cls)]) [cls.get_field_format(f.name).get_size(calculate_max=calculate_max) for f in dataclass_fields(cls)])
union_size = max( union_size = max(
[0] + [option.get_size(no_inherited_fields=True, calculate_max=calculate_max) for option in [0] + [option.get_size(no_inherited_fields=True, calculate_max=calculate_max) for option in
cls._union_options[cls.union_type_field].values()]) cls._union_options[cls.union_type_field].values()])
@ -776,7 +782,7 @@ class StructType:
relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls] relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
else: else:
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")] relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
return sum((f.metadata.get("format", f.type).get_size(calculate_max=calculate_max) for f in relevant_fields), return sum((cls.get_field_format(f.name).get_size(calculate_max=calculate_max) for f in relevant_fields),
start=0) start=0)