introducing StructType.get_field_format
This commit is contained in:
parent
69c33690af
commit
54031df457
1 changed files with 61 additions and 55 deletions
|
@ -1,5 +1,6 @@
|
||||||
import re
|
import re
|
||||||
import struct
|
import struct
|
||||||
|
import typing
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import Field, dataclass
|
from dataclasses import Field, dataclass
|
||||||
from dataclasses import fields as dataclass_fields
|
from dataclasses import fields as dataclass_fields
|
||||||
|
@ -357,6 +358,24 @@ class StructType:
|
||||||
existing_c_struct = None
|
existing_c_struct = None
|
||||||
c_includes = set()
|
c_includes = set()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_field_format(cls, attr_name):
|
||||||
|
attr = getattr(cls, attr_name, None)
|
||||||
|
fields = [f for f in dataclass_fields(cls) if f.name == attr_name]
|
||||||
|
if not fields:
|
||||||
|
raise TypeError(f"{cls}.{attr_name} not a field")
|
||||||
|
field = fields[0]
|
||||||
|
type_ = typing.get_type_hints(cls)[attr_name]
|
||||||
|
if "format" in field.metadata:
|
||||||
|
field_format = field.metadata["format"]
|
||||||
|
field_format.set_field_type(type_)
|
||||||
|
return field_format
|
||||||
|
|
||||||
|
if issubclass(type_, StructType):
|
||||||
|
return type_
|
||||||
|
raise TypeError('field %s.%s has no format and is no StructType' %
|
||||||
|
(cls.__class__.__name__, attr_name))
|
||||||
|
|
||||||
# noinspection PyMethodOverriding
|
# noinspection PyMethodOverriding
|
||||||
def __init_subclass__(cls, /, union_type_field=None, existing_c_struct=None, c_includes=None, **kwargs):
|
def __init_subclass__(cls, /, union_type_field=None, existing_c_struct=None, c_includes=None, **kwargs):
|
||||||
cls.union_type_field = union_type_field
|
cls.union_type_field = union_type_field
|
||||||
|
@ -383,9 +402,6 @@ class StructType:
|
||||||
metadata = dict(attr.metadata)
|
metadata = dict(attr.metadata)
|
||||||
if "defining_class" not in metadata:
|
if "defining_class" not in metadata:
|
||||||
metadata["defining_class"] = cls
|
metadata["defining_class"] = cls
|
||||||
if "format" in metadata:
|
|
||||||
metadata["format"].set_field_type(cls.__annotations__[attr_name])
|
|
||||||
|
|
||||||
attr.metadata = metadata
|
attr.metadata = metadata
|
||||||
|
|
||||||
for key, values in cls._union_options.items():
|
for key, values in cls._union_options.items():
|
||||||
|
@ -411,22 +427,24 @@ class StructType:
|
||||||
continue
|
continue
|
||||||
fields.append((attr_name, type_, metadata))
|
fields.append((attr_name, type_, metadata))
|
||||||
for name, type_, metadata in fields:
|
for name, type_, metadata in fields:
|
||||||
if metadata.get("format", None):
|
try:
|
||||||
|
field_format = cls.get_field_format(name)
|
||||||
|
except TypeError:
|
||||||
|
# todo: in case of not a field, ignore it?
|
||||||
|
continue
|
||||||
|
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||||
cls._pydantic_fields[name] = (type_, ...)
|
cls._pydantic_fields[name] = (type_, ...)
|
||||||
elif issubclass(type_, StructType):
|
else:
|
||||||
if metadata.get("json_embed"):
|
if metadata.get("json_embed"):
|
||||||
cls._pydantic_fields.update(type_._pydantic_fields)
|
cls._pydantic_fields.update(type_._pydantic_fields)
|
||||||
else:
|
else:
|
||||||
cls._pydantic_fields[name] = (type_.schema, ...)
|
cls._pydantic_fields[name] = (type_.schema, ...)
|
||||||
else:
|
|
||||||
raise TypeError('field %s.%s has no format and is no StructType' %
|
|
||||||
(cls.__class__.__name__, name))
|
|
||||||
cls.schema = create_model(cls.__name__ + 'Schema', **cls._pydantic_fields)
|
cls.schema = create_model(cls.__name__ + 'Schema', **cls._pydantic_fields)
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_var_num(cls):
|
def get_var_num(cls):
|
||||||
return sum([f.metadata.get("format", f.type).get_var_num() for f in dataclass_fields(cls)], start=0)
|
return sum([cls.get_field_format(f.name).get_var_num() for f in dataclass_fields(cls)], start=0)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_types(cls):
|
def get_types(cls):
|
||||||
|
@ -448,7 +466,7 @@ class StructType:
|
||||||
raise ValueError('expected value of type %r, got %r' % (cls, instance))
|
raise ValueError('expected value of type %r, got %r' % (cls, instance))
|
||||||
|
|
||||||
for field_ in dataclass_fields(cls):
|
for field_ in dataclass_fields(cls):
|
||||||
data += field_.metadata["format"].encode(getattr(instance, field_.name))
|
data += cls.get_field_format(field_.name).encode(getattr(instance, field_.name))
|
||||||
|
|
||||||
# todo: better
|
# todo: better
|
||||||
data += instance.encode(instance, ignore_fields=set(f.name for f in dataclass_fields(cls)))
|
data += instance.encode(instance, ignore_fields=set(f.name for f in dataclass_fields(cls)))
|
||||||
|
@ -458,16 +476,14 @@ class StructType:
|
||||||
if field_.name in ignore_fields:
|
if field_.name in ignore_fields:
|
||||||
continue
|
continue
|
||||||
value = getattr(instance, field_.name)
|
value = getattr(instance, field_.name)
|
||||||
if "format" in field_.metadata:
|
field_format = cls.get_field_format(field_.name)
|
||||||
data += field_.metadata["format"].encode(value)
|
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||||
elif issubclass(field_.type, StructType):
|
data += field_format.encode(value)
|
||||||
|
else:
|
||||||
if not isinstance(value, field_.type):
|
if not isinstance(value, field_.type):
|
||||||
raise ValueError('expected value of type %r for %s.%s, got %r' %
|
raise ValueError('expected value of type %r for %s.%s, got %r' %
|
||||||
(field_.type, cls.__name__, field_.name, value))
|
(field_.type, cls.__name__, field_.name, value))
|
||||||
data += value.encode(value)
|
data += value.encode(value)
|
||||||
else:
|
|
||||||
raise TypeError('field %s.%s has no format and is no StructType' %
|
|
||||||
(cls.__class__.__name__, field_.name))
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -476,13 +492,11 @@ class StructType:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
no_init_data = {}
|
no_init_data = {}
|
||||||
for field_ in dataclass_fields(cls):
|
for field_ in dataclass_fields(cls):
|
||||||
if "format" in field_.metadata:
|
field_format = cls.get_field_format(field_.name)
|
||||||
value, data = field_.metadata["format"].decode(data)
|
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||||
elif issubclass(field_.type, StructType):
|
value, data = field_format.decode(data)
|
||||||
value, data = field_.type.decode(data)
|
|
||||||
else:
|
else:
|
||||||
raise TypeError('field %s.%s has no format and is no StructType' %
|
value, data = field_.type.decode(data)
|
||||||
(cls.__name__, field_.name))
|
|
||||||
if field_.init:
|
if field_.init:
|
||||||
kwargs[field_.name] = value
|
kwargs[field_.name] = value
|
||||||
else:
|
else:
|
||||||
|
@ -512,7 +526,7 @@ class StructType:
|
||||||
|
|
||||||
for field_ in dataclass_fields(instance):
|
for field_ in dataclass_fields(instance):
|
||||||
if field_.name is cls.union_type_field:
|
if field_.name is cls.union_type_field:
|
||||||
result[field_.name] = field_.metadata["format"].tojson(getattr(instance, field_.name))
|
result[field_.name] = cls.get_field_format(field_.name).tojson(getattr(instance, field_.name))
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise TypeError('couldn\'t find %s value' % cls.union_type_field)
|
raise TypeError('couldn\'t find %s value' % cls.union_type_field)
|
||||||
|
@ -522,9 +536,10 @@ class StructType:
|
||||||
|
|
||||||
for field_ in dataclass_fields(cls):
|
for field_ in dataclass_fields(cls):
|
||||||
value = getattr(instance, field_.name)
|
value = getattr(instance, field_.name)
|
||||||
if "format" in field_.metadata:
|
field_format = cls.get_field_format(field_.name)
|
||||||
result[field_.name] = field_.metadata["format"].tojson(value)
|
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||||
elif issubclass(field_.type, StructType):
|
result[field_.name] = field_format.tojson(value)
|
||||||
|
else:
|
||||||
if not isinstance(value, field_.type):
|
if not isinstance(value, field_.type):
|
||||||
raise ValueError('expected value of type %r for %s.%s, got %r' %
|
raise ValueError('expected value of type %r for %s.%s, got %r' %
|
||||||
(field_.type, cls.__name__, field_.name, value))
|
(field_.type, cls.__name__, field_.name, value))
|
||||||
|
@ -534,9 +549,6 @@ class StructType:
|
||||||
result[k] = v
|
result[k] = v
|
||||||
else:
|
else:
|
||||||
result[field_.name] = value.tojson(value)
|
result[field_.name] = value.tojson(value)
|
||||||
else:
|
|
||||||
raise TypeError('field %s.%s has no format and is no StructType' %
|
|
||||||
(cls.__class__.__name__, field_.name))
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -554,16 +566,14 @@ class StructType:
|
||||||
no_init_data = {}
|
no_init_data = {}
|
||||||
for field_ in dataclass_fields(cls):
|
for field_ in dataclass_fields(cls):
|
||||||
raw_value = data.get(field_.name, None)
|
raw_value = data.get(field_.name, None)
|
||||||
if "format" in field_.metadata:
|
field_format = cls.get_field_format(field_.name)
|
||||||
value = field_.metadata["format"].fromjson(raw_value)
|
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||||
elif issubclass(field_.type, StructType):
|
value = field_format.fromjson(raw_value)
|
||||||
|
else:
|
||||||
if field_.metadata.get("json_embed"):
|
if field_.metadata.get("json_embed"):
|
||||||
value = field_.type.fromjson(data)
|
value = field_.type.fromjson(data)
|
||||||
else:
|
else:
|
||||||
value = field_.type.fromjson(raw_value)
|
value = field_.type.fromjson(raw_value)
|
||||||
else:
|
|
||||||
raise TypeError('field %s.%s has no format and is no StructType' %
|
|
||||||
(cls.__name__, field_.name))
|
|
||||||
if field_.init:
|
if field_.init:
|
||||||
kwargs[field_.name] = value
|
kwargs[field_.name] = value
|
||||||
else:
|
else:
|
||||||
|
@ -597,20 +607,21 @@ class StructType:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
name = field_.metadata.get("c_name", field_.name)
|
name = field_.metadata.get("c_name", field_.name)
|
||||||
if "format" in field_.metadata:
|
field_format = cls.get_field_format(field_.name)
|
||||||
|
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||||
if not field_.metadata.get("union_discriminator") or field_.metadata.get("defining_class") == cls:
|
if not field_.metadata.get("union_discriminator") or field_.metadata.get("defining_class") == cls:
|
||||||
items.append((
|
items.append((
|
||||||
(
|
(
|
||||||
("%(typedef_name)s %(name)s;" % {
|
("%(typedef_name)s %(name)s;" % {
|
||||||
"typedef_name": field_.metadata["format"].get_typedef_name(),
|
"typedef_name": field_format.get_typedef_name(),
|
||||||
"name": name,
|
"name": name,
|
||||||
})
|
})
|
||||||
if field_.metadata.get("as_definition")
|
if field_.metadata.get("as_definition")
|
||||||
else field_.metadata["format"].get_c_code(name)
|
else field_format.get_c_code(name)
|
||||||
),
|
),
|
||||||
field_.metadata.get("doc", None),
|
field_.metadata.get("doc", None),
|
||||||
)),
|
)),
|
||||||
elif issubclass(field_.type, StructType):
|
else:
|
||||||
if field_.metadata.get("c_embed"):
|
if field_.metadata.get("c_embed"):
|
||||||
embedded_items = field_.type.get_c_struct_items(ignore_fields, no_empty, top_level, union_only)
|
embedded_items = field_.type.get_c_struct_items(ignore_fields, no_empty, top_level, union_only)
|
||||||
items.extend(embedded_items)
|
items.extend(embedded_items)
|
||||||
|
@ -625,10 +636,7 @@ class StructType:
|
||||||
else field_.type.get_c_code(name, typedef=False)
|
else field_.type.get_c_code(name, typedef=False)
|
||||||
),
|
),
|
||||||
field_.metadata.get("doc", None),
|
field_.metadata.get("doc", None),
|
||||||
)),
|
))
|
||||||
else:
|
|
||||||
raise TypeError('field %s.%s has no format and is no StructType' %
|
|
||||||
(cls.__name__, field_.name))
|
|
||||||
|
|
||||||
if cls.union_type_field:
|
if cls.union_type_field:
|
||||||
if not union_only:
|
if not union_only:
|
||||||
|
@ -649,22 +657,20 @@ class StructType:
|
||||||
def get_c_definitions(cls) -> dict[str, str]:
|
def get_c_definitions(cls) -> dict[str, str]:
|
||||||
definitions = {}
|
definitions = {}
|
||||||
for field_ in dataclass_fields(cls):
|
for field_ in dataclass_fields(cls):
|
||||||
if "format" in field_.metadata:
|
field_format = cls.get_field_format(field_.name)
|
||||||
definitions.update(field_.metadata["format"].get_c_definitions())
|
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||||
|
definitions.update(field_format.get_c_definitions())
|
||||||
if field_.metadata.get("as_definition"):
|
if field_.metadata.get("as_definition"):
|
||||||
typedef_name = field_.metadata["format"].get_typedef_name()
|
typedef_name = field_format.get_typedef_name()
|
||||||
definitions[typedef_name] = 'typedef %(code)s %(name)s;' % {
|
definitions[typedef_name] = 'typedef %(code)s %(name)s;' % {
|
||||||
"code": ''.join(field_.metadata["format"].get_c_parts()),
|
"code": ''.join(field_format.get_c_parts()),
|
||||||
"name": typedef_name,
|
"name": typedef_name,
|
||||||
}
|
}
|
||||||
elif issubclass(field_.type, StructType):
|
else:
|
||||||
definitions.update(field_.type.get_c_definitions())
|
definitions.update(field_.type.get_c_definitions())
|
||||||
if field_.metadata.get("as_definition"):
|
if field_.metadata.get("as_definition"):
|
||||||
typedef_name = field_.type.get_typedef_name()
|
typedef_name = field_.type.get_typedef_name()
|
||||||
definitions[typedef_name] = field_.type.get_c_code(name=typedef_name, typedef=True)
|
definitions[typedef_name] = field_.type.get_c_code(name=typedef_name, typedef=True)
|
||||||
else:
|
|
||||||
raise TypeError('field %s.%s has no format and is no StructType' %
|
|
||||||
(cls.__name__, field_.name))
|
|
||||||
if cls.union_type_field:
|
if cls.union_type_field:
|
||||||
for key, option in cls._union_options[cls.union_type_field].items():
|
for key, option in cls._union_options[cls.union_type_field].items():
|
||||||
definitions.update(option.get_c_definitions())
|
definitions.update(option.get_c_definitions())
|
||||||
|
@ -753,7 +759,7 @@ class StructType:
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_size(cls, no_inherited_fields=False) -> int:
|
def get_min_size(cls, no_inherited_fields=False) -> int:
|
||||||
if cls.union_type_field:
|
if cls.union_type_field:
|
||||||
own_size = sum([f.metadata.get("format", f.type).get_min_size() for f in dataclass_fields(cls)])
|
own_size = sum([cls.get_field_format(f.name).get_min_size() for f in dataclass_fields(cls)])
|
||||||
union_size = max(
|
union_size = max(
|
||||||
[0] + [option.get_min_size(True) for option in cls._union_options[cls.union_type_field].values()])
|
[0] + [option.get_min_size(True) for option in cls._union_options[cls.union_type_field].values()])
|
||||||
return own_size + union_size
|
return own_size + union_size
|
||||||
|
@ -761,13 +767,13 @@ class StructType:
|
||||||
relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
|
relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
|
||||||
else:
|
else:
|
||||||
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
|
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
|
||||||
return sum((f.metadata.get("format", f.type).get_min_size() for f in relevant_fields), start=0)
|
return sum((cls.get_field_format(f.name).get_min_size() for f in relevant_fields), start=0)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_size(cls, no_inherited_fields=False, calculate_max=False) -> int:
|
def get_size(cls, no_inherited_fields=False, calculate_max=False) -> int:
|
||||||
if cls.union_type_field:
|
if cls.union_type_field:
|
||||||
own_size = sum(
|
own_size = sum(
|
||||||
[f.metadata.get("format", f.type).get_size(calculate_max=calculate_max) for f in dataclass_fields(cls)])
|
[cls.get_field_format(f.name).get_size(calculate_max=calculate_max) for f in dataclass_fields(cls)])
|
||||||
union_size = max(
|
union_size = max(
|
||||||
[0] + [option.get_size(no_inherited_fields=True, calculate_max=calculate_max) for option in
|
[0] + [option.get_size(no_inherited_fields=True, calculate_max=calculate_max) for option in
|
||||||
cls._union_options[cls.union_type_field].values()])
|
cls._union_options[cls.union_type_field].values()])
|
||||||
|
@ -776,7 +782,7 @@ class StructType:
|
||||||
relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
|
relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
|
||||||
else:
|
else:
|
||||||
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
|
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
|
||||||
return sum((f.metadata.get("format", f.type).get_size(calculate_max=calculate_max) for f in relevant_fields),
|
return sum((cls.get_field_format(f.name).get_size(calculate_max=calculate_max) for f in relevant_fields),
|
||||||
start=0)
|
start=0)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue