introducing StructType.get_field_format
This commit is contained in:
parent
69c33690af
commit
54031df457
1 changed files with 61 additions and 55 deletions
|
@ -1,5 +1,6 @@
|
|||
import re
|
||||
import struct
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import Field, dataclass
|
||||
from dataclasses import fields as dataclass_fields
|
||||
|
@ -357,6 +358,24 @@ class StructType:
|
|||
existing_c_struct = None
|
||||
c_includes = set()
|
||||
|
||||
@classmethod
|
||||
def get_field_format(cls, attr_name):
|
||||
attr = getattr(cls, attr_name, None)
|
||||
fields = [f for f in dataclass_fields(cls) if f.name == attr_name]
|
||||
if not fields:
|
||||
raise TypeError(f"{cls}.{attr_name} not a field")
|
||||
field = fields[0]
|
||||
type_ = typing.get_type_hints(cls)[attr_name]
|
||||
if "format" in field.metadata:
|
||||
field_format = field.metadata["format"]
|
||||
field_format.set_field_type(type_)
|
||||
return field_format
|
||||
|
||||
if issubclass(type_, StructType):
|
||||
return type_
|
||||
raise TypeError('field %s.%s has no format and is no StructType' %
|
||||
(cls.__class__.__name__, attr_name))
|
||||
|
||||
# noinspection PyMethodOverriding
|
||||
def __init_subclass__(cls, /, union_type_field=None, existing_c_struct=None, c_includes=None, **kwargs):
|
||||
cls.union_type_field = union_type_field
|
||||
|
@ -383,9 +402,6 @@ class StructType:
|
|||
metadata = dict(attr.metadata)
|
||||
if "defining_class" not in metadata:
|
||||
metadata["defining_class"] = cls
|
||||
if "format" in metadata:
|
||||
metadata["format"].set_field_type(cls.__annotations__[attr_name])
|
||||
|
||||
attr.metadata = metadata
|
||||
|
||||
for key, values in cls._union_options.items():
|
||||
|
@ -411,22 +427,24 @@ class StructType:
|
|||
continue
|
||||
fields.append((attr_name, type_, metadata))
|
||||
for name, type_, metadata in fields:
|
||||
if metadata.get("format", None):
|
||||
try:
|
||||
field_format = cls.get_field_format(name)
|
||||
except TypeError:
|
||||
# todo: in case of not a field, ignore it?
|
||||
continue
|
||||
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||
cls._pydantic_fields[name] = (type_, ...)
|
||||
elif issubclass(type_, StructType):
|
||||
else:
|
||||
if metadata.get("json_embed"):
|
||||
cls._pydantic_fields.update(type_._pydantic_fields)
|
||||
else:
|
||||
cls._pydantic_fields[name] = (type_.schema, ...)
|
||||
else:
|
||||
raise TypeError('field %s.%s has no format and is no StructType' %
|
||||
(cls.__class__.__name__, name))
|
||||
cls.schema = create_model(cls.__name__ + 'Schema', **cls._pydantic_fields)
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_var_num(cls):
|
||||
return sum([f.metadata.get("format", f.type).get_var_num() for f in dataclass_fields(cls)], start=0)
|
||||
return sum([cls.get_field_format(f.name).get_var_num() for f in dataclass_fields(cls)], start=0)
|
||||
|
||||
@classmethod
|
||||
def get_types(cls):
|
||||
|
@ -448,7 +466,7 @@ class StructType:
|
|||
raise ValueError('expected value of type %r, got %r' % (cls, instance))
|
||||
|
||||
for field_ in dataclass_fields(cls):
|
||||
data += field_.metadata["format"].encode(getattr(instance, field_.name))
|
||||
data += cls.get_field_format(field_.name).encode(getattr(instance, field_.name))
|
||||
|
||||
# todo: better
|
||||
data += instance.encode(instance, ignore_fields=set(f.name for f in dataclass_fields(cls)))
|
||||
|
@ -458,16 +476,14 @@ class StructType:
|
|||
if field_.name in ignore_fields:
|
||||
continue
|
||||
value = getattr(instance, field_.name)
|
||||
if "format" in field_.metadata:
|
||||
data += field_.metadata["format"].encode(value)
|
||||
elif issubclass(field_.type, StructType):
|
||||
field_format = cls.get_field_format(field_.name)
|
||||
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||
data += field_format.encode(value)
|
||||
else:
|
||||
if not isinstance(value, field_.type):
|
||||
raise ValueError('expected value of type %r for %s.%s, got %r' %
|
||||
(field_.type, cls.__name__, field_.name, value))
|
||||
data += value.encode(value)
|
||||
else:
|
||||
raise TypeError('field %s.%s has no format and is no StructType' %
|
||||
(cls.__class__.__name__, field_.name))
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
|
@ -476,13 +492,11 @@ class StructType:
|
|||
kwargs = {}
|
||||
no_init_data = {}
|
||||
for field_ in dataclass_fields(cls):
|
||||
if "format" in field_.metadata:
|
||||
value, data = field_.metadata["format"].decode(data)
|
||||
elif issubclass(field_.type, StructType):
|
||||
value, data = field_.type.decode(data)
|
||||
field_format = cls.get_field_format(field_.name)
|
||||
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||
value, data = field_format.decode(data)
|
||||
else:
|
||||
raise TypeError('field %s.%s has no format and is no StructType' %
|
||||
(cls.__name__, field_.name))
|
||||
value, data = field_.type.decode(data)
|
||||
if field_.init:
|
||||
kwargs[field_.name] = value
|
||||
else:
|
||||
|
@ -512,7 +526,7 @@ class StructType:
|
|||
|
||||
for field_ in dataclass_fields(instance):
|
||||
if field_.name is cls.union_type_field:
|
||||
result[field_.name] = field_.metadata["format"].tojson(getattr(instance, field_.name))
|
||||
result[field_.name] = cls.get_field_format(field_.name).tojson(getattr(instance, field_.name))
|
||||
break
|
||||
else:
|
||||
raise TypeError('couldn\'t find %s value' % cls.union_type_field)
|
||||
|
@ -522,9 +536,10 @@ class StructType:
|
|||
|
||||
for field_ in dataclass_fields(cls):
|
||||
value = getattr(instance, field_.name)
|
||||
if "format" in field_.metadata:
|
||||
result[field_.name] = field_.metadata["format"].tojson(value)
|
||||
elif issubclass(field_.type, StructType):
|
||||
field_format = cls.get_field_format(field_.name)
|
||||
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||
result[field_.name] = field_format.tojson(value)
|
||||
else:
|
||||
if not isinstance(value, field_.type):
|
||||
raise ValueError('expected value of type %r for %s.%s, got %r' %
|
||||
(field_.type, cls.__name__, field_.name, value))
|
||||
|
@ -534,9 +549,6 @@ class StructType:
|
|||
result[k] = v
|
||||
else:
|
||||
result[field_.name] = value.tojson(value)
|
||||
else:
|
||||
raise TypeError('field %s.%s has no format and is no StructType' %
|
||||
(cls.__class__.__name__, field_.name))
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
|
@ -554,16 +566,14 @@ class StructType:
|
|||
no_init_data = {}
|
||||
for field_ in dataclass_fields(cls):
|
||||
raw_value = data.get(field_.name, None)
|
||||
if "format" in field_.metadata:
|
||||
value = field_.metadata["format"].fromjson(raw_value)
|
||||
elif issubclass(field_.type, StructType):
|
||||
field_format = cls.get_field_format(field_.name)
|
||||
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||
value = field_format.fromjson(raw_value)
|
||||
else:
|
||||
if field_.metadata.get("json_embed"):
|
||||
value = field_.type.fromjson(data)
|
||||
else:
|
||||
value = field_.type.fromjson(raw_value)
|
||||
else:
|
||||
raise TypeError('field %s.%s has no format and is no StructType' %
|
||||
(cls.__name__, field_.name))
|
||||
if field_.init:
|
||||
kwargs[field_.name] = value
|
||||
else:
|
||||
|
@ -597,20 +607,21 @@ class StructType:
|
|||
continue
|
||||
|
||||
name = field_.metadata.get("c_name", field_.name)
|
||||
if "format" in field_.metadata:
|
||||
field_format = cls.get_field_format(field_.name)
|
||||
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||
if not field_.metadata.get("union_discriminator") or field_.metadata.get("defining_class") == cls:
|
||||
items.append((
|
||||
(
|
||||
("%(typedef_name)s %(name)s;" % {
|
||||
"typedef_name": field_.metadata["format"].get_typedef_name(),
|
||||
"typedef_name": field_format.get_typedef_name(),
|
||||
"name": name,
|
||||
})
|
||||
if field_.metadata.get("as_definition")
|
||||
else field_.metadata["format"].get_c_code(name)
|
||||
else field_format.get_c_code(name)
|
||||
),
|
||||
field_.metadata.get("doc", None),
|
||||
)),
|
||||
elif issubclass(field_.type, StructType):
|
||||
else:
|
||||
if field_.metadata.get("c_embed"):
|
||||
embedded_items = field_.type.get_c_struct_items(ignore_fields, no_empty, top_level, union_only)
|
||||
items.extend(embedded_items)
|
||||
|
@ -625,10 +636,7 @@ class StructType:
|
|||
else field_.type.get_c_code(name, typedef=False)
|
||||
),
|
||||
field_.metadata.get("doc", None),
|
||||
)),
|
||||
else:
|
||||
raise TypeError('field %s.%s has no format and is no StructType' %
|
||||
(cls.__name__, field_.name))
|
||||
))
|
||||
|
||||
if cls.union_type_field:
|
||||
if not union_only:
|
||||
|
@ -649,22 +657,20 @@ class StructType:
|
|||
def get_c_definitions(cls) -> dict[str, str]:
|
||||
definitions = {}
|
||||
for field_ in dataclass_fields(cls):
|
||||
if "format" in field_.metadata:
|
||||
definitions.update(field_.metadata["format"].get_c_definitions())
|
||||
field_format = cls.get_field_format(field_.name)
|
||||
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||
definitions.update(field_format.get_c_definitions())
|
||||
if field_.metadata.get("as_definition"):
|
||||
typedef_name = field_.metadata["format"].get_typedef_name()
|
||||
typedef_name = field_format.get_typedef_name()
|
||||
definitions[typedef_name] = 'typedef %(code)s %(name)s;' % {
|
||||
"code": ''.join(field_.metadata["format"].get_c_parts()),
|
||||
"code": ''.join(field_format.get_c_parts()),
|
||||
"name": typedef_name,
|
||||
}
|
||||
elif issubclass(field_.type, StructType):
|
||||
else:
|
||||
definitions.update(field_.type.get_c_definitions())
|
||||
if field_.metadata.get("as_definition"):
|
||||
typedef_name = field_.type.get_typedef_name()
|
||||
definitions[typedef_name] = field_.type.get_c_code(name=typedef_name, typedef=True)
|
||||
else:
|
||||
raise TypeError('field %s.%s has no format and is no StructType' %
|
||||
(cls.__name__, field_.name))
|
||||
if cls.union_type_field:
|
||||
for key, option in cls._union_options[cls.union_type_field].items():
|
||||
definitions.update(option.get_c_definitions())
|
||||
|
@ -753,7 +759,7 @@ class StructType:
|
|||
@classmethod
|
||||
def get_min_size(cls, no_inherited_fields=False) -> int:
|
||||
if cls.union_type_field:
|
||||
own_size = sum([f.metadata.get("format", f.type).get_min_size() for f in dataclass_fields(cls)])
|
||||
own_size = sum([cls.get_field_format(f.name).get_min_size() for f in dataclass_fields(cls)])
|
||||
union_size = max(
|
||||
[0] + [option.get_min_size(True) for option in cls._union_options[cls.union_type_field].values()])
|
||||
return own_size + union_size
|
||||
|
@ -761,13 +767,13 @@ class StructType:
|
|||
relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
|
||||
else:
|
||||
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
|
||||
return sum((f.metadata.get("format", f.type).get_min_size() for f in relevant_fields), start=0)
|
||||
return sum((cls.get_field_format(f.name).get_min_size() for f in relevant_fields), start=0)
|
||||
|
||||
@classmethod
|
||||
def get_size(cls, no_inherited_fields=False, calculate_max=False) -> int:
|
||||
if cls.union_type_field:
|
||||
own_size = sum(
|
||||
[f.metadata.get("format", f.type).get_size(calculate_max=calculate_max) for f in dataclass_fields(cls)])
|
||||
[cls.get_field_format(f.name).get_size(calculate_max=calculate_max) for f in dataclass_fields(cls)])
|
||||
union_size = max(
|
||||
[0] + [option.get_size(no_inherited_fields=True, calculate_max=calculate_max) for option in
|
||||
cls._union_options[cls.union_type_field].values()])
|
||||
|
@ -776,7 +782,7 @@ class StructType:
|
|||
relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
|
||||
else:
|
||||
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
|
||||
return sum((f.metadata.get("format", f.type).get_size(calculate_max=calculate_max) for f in relevant_fields),
|
||||
return sum((cls.get_field_format(f.name).get_size(calculate_max=calculate_max) for f in relevant_fields),
|
||||
start=0)
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue