diff --git a/src/c3nav/mesh/baseformats.py b/src/c3nav/mesh/baseformats.py index 6742b8de..0b3f5de3 100644 --- a/src/c3nav/mesh/baseformats.py +++ b/src/c3nav/mesh/baseformats.py @@ -1,5 +1,6 @@ import re import struct +import typing from abc import ABC, abstractmethod from dataclasses import Field, dataclass from dataclasses import fields as dataclass_fields @@ -357,6 +358,24 @@ class StructType: existing_c_struct = None c_includes = set() + @classmethod + def get_field_format(cls, attr_name): + attr = getattr(cls, attr_name, None) + fields = [f for f in dataclass_fields(cls) if f.name == attr_name] + if not fields: + raise TypeError(f"{cls}.{attr_name} not a field") + field = fields[0] + type_ = typing.get_type_hints(cls)[attr_name] + if "format" in field.metadata: + field_format = field.metadata["format"] + field_format.set_field_type(type_) + return field_format + + if issubclass(type_, StructType): + return type_ + raise TypeError('field %s.%s has no format and is no StructType' % + (cls.__class__.__name__, attr_name)) + # noinspection PyMethodOverriding def __init_subclass__(cls, /, union_type_field=None, existing_c_struct=None, c_includes=None, **kwargs): cls.union_type_field = union_type_field @@ -383,9 +402,6 @@ class StructType: metadata = dict(attr.metadata) if "defining_class" not in metadata: metadata["defining_class"] = cls - if "format" in metadata: - metadata["format"].set_field_type(cls.__annotations__[attr_name]) - attr.metadata = metadata for key, values in cls._union_options.items(): @@ -411,22 +427,24 @@ class StructType: continue fields.append((attr_name, type_, metadata)) for name, type_, metadata in fields: - if metadata.get("format", None): + try: + field_format = cls.get_field_format(name) + except TypeError: + # todo: in case of not a field, ignore it? + continue + if not (isinstance(field_format, type) and issubclass(field_format, StructType)): cls._pydantic_fields[name] = (type_, ...) - elif issubclass(type_, StructType): + else: if metadata.get("json_embed"): cls._pydantic_fields.update(type_._pydantic_fields) else: cls._pydantic_fields[name] = (type_.schema, ...) - else: - raise TypeError('field %s.%s has no format and is no StructType' % - (cls.__class__.__name__, name)) cls.schema = create_model(cls.__name__ + 'Schema', **cls._pydantic_fields) super().__init_subclass__(**kwargs) @classmethod def get_var_num(cls): - return sum([f.metadata.get("format", f.type).get_var_num() for f in dataclass_fields(cls)], start=0) + return sum([cls.get_field_format(f.name).get_var_num() for f in dataclass_fields(cls)], start=0) @classmethod def get_types(cls): @@ -448,7 +466,7 @@ class StructType: raise ValueError('expected value of type %r, got %r' % (cls, instance)) for field_ in dataclass_fields(cls): - data += field_.metadata["format"].encode(getattr(instance, field_.name)) + data += cls.get_field_format(field_.name).encode(getattr(instance, field_.name)) # todo: better data += instance.encode(instance, ignore_fields=set(f.name for f in dataclass_fields(cls))) @@ -458,16 +476,14 @@ class StructType: if field_.name in ignore_fields: continue value = getattr(instance, field_.name) - if "format" in field_.metadata: - data += field_.metadata["format"].encode(value) - elif issubclass(field_.type, StructType): + field_format = cls.get_field_format(field_.name) + if not (isinstance(field_format, type) and issubclass(field_format, StructType)): + data += field_format.encode(value) + else: if not isinstance(value, field_.type): raise ValueError('expected value of type %r for %s.%s, got %r' % (field_.type, cls.__name__, field_.name, value)) data += value.encode(value) - else: - raise TypeError('field %s.%s has no format and is no StructType' % - (cls.__class__.__name__, field_.name)) return data @classmethod @@ -476,13 +492,11 @@ class StructType: kwargs = {} no_init_data = {} for field_ in dataclass_fields(cls): - if "format" in field_.metadata: - value, data = field_.metadata["format"].decode(data) - elif issubclass(field_.type, StructType): - value, data = field_.type.decode(data) + field_format = cls.get_field_format(field_.name) + if not (isinstance(field_format, type) and issubclass(field_format, StructType)): + value, data = field_format.decode(data) else: - raise TypeError('field %s.%s has no format and is no StructType' % - (cls.__name__, field_.name)) + value, data = field_.type.decode(data) if field_.init: kwargs[field_.name] = value else: @@ -512,7 +526,7 @@ class StructType: for field_ in dataclass_fields(instance): if field_.name is cls.union_type_field: - result[field_.name] = field_.metadata["format"].tojson(getattr(instance, field_.name)) + result[field_.name] = cls.get_field_format(field_.name).tojson(getattr(instance, field_.name)) break else: raise TypeError('couldn\'t find %s value' % cls.union_type_field) @@ -522,9 +536,10 @@ class StructType: for field_ in dataclass_fields(cls): value = getattr(instance, field_.name) - if "format" in field_.metadata: - result[field_.name] = field_.metadata["format"].tojson(value) - elif issubclass(field_.type, StructType): + field_format = cls.get_field_format(field_.name) + if not (isinstance(field_format, type) and issubclass(field_format, StructType)): + result[field_.name] = field_format.tojson(value) + else: if not isinstance(value, field_.type): raise ValueError('expected value of type %r for %s.%s, got %r' % (field_.type, cls.__name__, field_.name, value)) @@ -534,9 +549,6 @@ class StructType: result[k] = v else: result[field_.name] = value.tojson(value) - else: - raise TypeError('field %s.%s has no format and is no StructType' % - (cls.__class__.__name__, field_.name)) return result @classmethod @@ -554,16 +566,14 @@ class StructType: no_init_data = {} for field_ in dataclass_fields(cls): raw_value = data.get(field_.name, None) - if "format" in field_.metadata: - value = field_.metadata["format"].fromjson(raw_value) - elif issubclass(field_.type, StructType): + field_format = cls.get_field_format(field_.name) + if not (isinstance(field_format, type) and issubclass(field_format, StructType)): + value = field_format.fromjson(raw_value) + else: if field_.metadata.get("json_embed"): value = field_.type.fromjson(data) else: value = field_.type.fromjson(raw_value) - else: - raise TypeError('field %s.%s has no format and is no StructType' % - (cls.__name__, field_.name)) if field_.init: kwargs[field_.name] = value else: @@ -597,20 +607,21 @@ class StructType: continue name = field_.metadata.get("c_name", field_.name) - if "format" in field_.metadata: + field_format = cls.get_field_format(field_.name) + if not (isinstance(field_format, type) and issubclass(field_format, StructType)): if not field_.metadata.get("union_discriminator") or field_.metadata.get("defining_class") == cls: items.append(( ( ("%(typedef_name)s %(name)s;" % { - "typedef_name": field_.metadata["format"].get_typedef_name(), + "typedef_name": field_format.get_typedef_name(), "name": name, }) if field_.metadata.get("as_definition") - else field_.metadata["format"].get_c_code(name) + else field_format.get_c_code(name) ), field_.metadata.get("doc", None), )), - elif issubclass(field_.type, StructType): + else: if field_.metadata.get("c_embed"): embedded_items = field_.type.get_c_struct_items(ignore_fields, no_empty, top_level, union_only) items.extend(embedded_items) @@ -625,10 +636,7 @@ class StructType: else field_.type.get_c_code(name, typedef=False) ), field_.metadata.get("doc", None), - )), - else: - raise TypeError('field %s.%s has no format and is no StructType' % - (cls.__name__, field_.name)) + )) if cls.union_type_field: if not union_only: @@ -649,22 +657,20 @@ class StructType: def get_c_definitions(cls) -> dict[str, str]: definitions = {} for field_ in dataclass_fields(cls): - if "format" in field_.metadata: - definitions.update(field_.metadata["format"].get_c_definitions()) + field_format = cls.get_field_format(field_.name) + if not (isinstance(field_format, type) and issubclass(field_format, StructType)): + definitions.update(field_format.get_c_definitions()) if field_.metadata.get("as_definition"): - typedef_name = field_.metadata["format"].get_typedef_name() + typedef_name = field_format.get_typedef_name() definitions[typedef_name] = 'typedef %(code)s %(name)s;' % { - "code": ''.join(field_.metadata["format"].get_c_parts()), + "code": ''.join(field_format.get_c_parts()), "name": typedef_name, } - elif issubclass(field_.type, StructType): + else: definitions.update(field_.type.get_c_definitions()) if field_.metadata.get("as_definition"): typedef_name = field_.type.get_typedef_name() definitions[typedef_name] = field_.type.get_c_code(name=typedef_name, typedef=True) - else: - raise TypeError('field %s.%s has no format and is no StructType' % - (cls.__name__, field_.name)) if cls.union_type_field: for key, option in cls._union_options[cls.union_type_field].items(): definitions.update(option.get_c_definitions()) @@ -753,7 +759,7 @@ class StructType: @classmethod def get_min_size(cls, no_inherited_fields=False) -> int: if cls.union_type_field: - own_size = sum([f.metadata.get("format", f.type).get_min_size() for f in dataclass_fields(cls)]) + own_size = sum([cls.get_field_format(f.name).get_min_size() for f in dataclass_fields(cls)]) union_size = max( [0] + [option.get_min_size(True) for option in cls._union_options[cls.union_type_field].values()]) return own_size + union_size @@ -761,13 +767,13 @@ class StructType: relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls] else: relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")] - return sum((f.metadata.get("format", f.type).get_min_size() for f in relevant_fields), start=0) + return sum((cls.get_field_format(f.name).get_min_size() for f in relevant_fields), start=0) @classmethod def get_size(cls, no_inherited_fields=False, calculate_max=False) -> int: if cls.union_type_field: own_size = sum( - [f.metadata.get("format", f.type).get_size(calculate_max=calculate_max) for f in dataclass_fields(cls)]) + [cls.get_field_format(f.name).get_size(calculate_max=calculate_max) for f in dataclass_fields(cls)]) union_size = max( [0] + [option.get_size(no_inherited_fields=True, calculate_max=calculate_max) for option in cls._union_options[cls.union_type_field].values()]) @@ -776,7 +782,7 @@ class StructType: relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls] else: relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")] - return sum((f.metadata.get("format", f.type).get_size(calculate_max=calculate_max) for f in relevant_fields), + return sum((cls.get_field_format(f.name).get_size(calculate_max=calculate_max) for f in relevant_fields), start=0)