team-3/src/c3nav/mesh/models.py

439 lines
17 KiB
Python
Raw Normal View History

from collections import UserDict, namedtuple
from contextlib import suppress
from dataclasses import dataclass
2023-11-07 16:35:46 +01:00
from datetime import datetime, timedelta
from functools import cached_property
2023-10-04 23:24:49 +02:00
from operator import attrgetter
from typing import Any, Mapping, Optional, Self
from django.contrib.auth import get_user_model
2023-11-10 19:00:09 +01:00
from django.core.validators import RegexValidator
2023-10-06 02:46:43 +02:00
from django.db import NotSupportedError, models
from django.db.models import Q, UniqueConstraint
2023-11-07 16:35:46 +01:00
from django.utils import timezone
from django.utils.text import slugify
2022-04-15 20:02:42 +02:00
from django.utils.translation import gettext_lazy as _
from c3nav.mesh.dataformats import BoardType, ChipType, FirmwareImage
from c3nav.mesh.messages import ConfigFirmwareMessage, ConfigHardwareMessage
2023-10-06 02:46:43 +02:00
from c3nav.mesh.messages import MeshMessage as MeshMessage
from c3nav.mesh.messages import MeshMessageType
2023-11-07 16:35:46 +01:00
from c3nav.mesh.utils import UPLINK_TIMEOUT
2023-11-10 20:11:50 +01:00
from c3nav.routing.rangelocator import RangeLocator
# Hashable key identifying a firmware image; used to match node-reported firmware against FirmwareBuilds.
FirmwareLookup = namedtuple("FirmwareLookup", "sha256_hash chip project_name version idf_version")
@dataclass
class FirmwareDescription:
    """
    Describes a firmware image, either as reported by a node or as known from a FirmwareBuild.

    ``build`` and ``created`` are optional extras that are only set when known.
    """
    chip: ChipType
    project_name: str
    version: str
    idf_version: str
    sha256_hash: str
    build: Optional["FirmwareBuild"] = None
    created: datetime | None = None

    def get_lookup(self) -> FirmwareLookup:
        """Return the hashable FirmwareLookup key for this description (build/created are not part of it)."""
        # positional arguments follow FirmwareLookup's field order
        return FirmwareLookup(self.sha256_hash, self.chip, self.project_name,
                              self.version, self.idf_version)
2023-11-10 16:08:55 +01:00
@dataclass(eq=True, frozen=True)
class HardwareDescription:
    """Immutable, hashable pair of chip type and board type."""
    chip: ChipType
    board: BoardType
class MeshNodeQuerySet(models.QuerySet):
    """
    QuerySet for MeshNode that can batch-load related data for all returned nodes.

    The prefetch_*() methods only set flags on a cloned queryset; the actual loading happens
    in _fetch_all() after the result cache has been filled, so each kind of related data is
    loaded with a few queries for all nodes instead of one query per node.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # message types whose newest NodeMessage per node should be prefetched
        self._prefetch_last_messages = set()
        self._prefetch_last_messages_done = False
        self._prefetch_firmwares = False
        self._prefetch_firmwares_done = False
        self._prefetch_ota = False
        self._prefetch_ota_done = False

    def _clone(self):
        clone = super()._clone()
        # copy the set: prefetch_last_messages() extends it in place with |=, which would
        # otherwise also modify the queryset this clone was created from
        clone._prefetch_last_messages = set(self._prefetch_last_messages)
        clone._prefetch_firmwares = self._prefetch_firmwares
        clone._prefetch_ota = self._prefetch_ota
        return clone

    def prefetch_last_messages(self, *types: MeshMessageType):
        """Prefetch the newest message of each given type (all types if none given) per node."""
        clone = self._chain()
        clone._prefetch_last_messages |= set(types) if types else set(MeshMessageType)
        return clone

    def prefetch_firmwares(self, *types: MeshMessageType):
        """
        Prefetch firmware descriptions (matched against FirmwareBuilds) for all nodes.

        ``types`` is accepted for backwards compatibility but ignored; the CONFIG_FIRMWARE and
        CONFIG_HARDWARE messages this needs are always prefetched.
        """
        clone = self.prefetch_last_messages(MeshMessageType.CONFIG_FIRMWARE,
                                            MeshMessageType.CONFIG_HARDWARE)
        clone._prefetch_firmwares = True
        return clone

    def prefetch_ota(self):
        """Prefetch the most recent OTAUpdateRecipient (with update and build) per node."""
        clone = self._chain()
        clone._prefetch_ota = True
        return clone

    def _fetch_all(self):
        super()._fetch_all()
        wants_messages = bool(self._prefetch_last_messages) and not self._prefetch_last_messages_done
        wants_firmwares = self._prefetch_firmwares and not self._prefetch_firmwares_done
        wants_ota = self._prefetch_ota and not self._prefetch_ota_done
        if not (wants_messages or wants_firmwares or wants_ota):
            return

        # _fetch_all() can be called more than once per queryset, so always build the node index
        # here instead of relying on an earlier prefetch step having built it
        nodes: dict[str, MeshNode] = {node.pk: node for node in self._result_cache}

        # each step is attempted once; if the backend does not support it, data is loaded lazily
        if wants_messages:
            self._fetch_last_messages(nodes)
            self._prefetch_last_messages_done = True
        if wants_firmwares:
            self._fetch_firmwares(nodes)
            self._prefetch_firmwares_done = True
        if wants_ota:
            self._fetch_ota(nodes)
            self._prefetch_ota_done = True

    def _fetch_last_messages(self, nodes):
        """Store the newest message of every requested type into each node's last_messages."""
        try:
            messages = list(
                NodeMessage.objects.filter(
                    message_type__in=[msgtype.name for msgtype in self._prefetch_last_messages],
                    src_node__in=list(nodes.keys()),
                ).order_by(
                    'message_type', 'src_node', '-datetime', '-pk'
                ).prefetch_related("uplink").distinct('message_type', 'src_node')
            )
        except NotSupportedError:
            # DISTINCT ON fields is not supported by this database backend;
            # last_messages will be queried lazily per node instead
            return

        for message in messages:
            nodes[message.src_node_id].last_messages[message.message_type] = message

        # The newest message of any type can only be derived from the prefetched ones if every
        # type was prefetched; otherwise "any" is left to be queried lazily on access.
        if self._prefetch_last_messages >= set(MeshMessageType):
            for node in nodes.values():
                node.last_messages["any"] = max(
                    (msg for key, msg in node.last_messages.data.items() if key != "any" and msg is not None),
                    key=attrgetter("datetime"),
                    default=None,  # node without any messages
                )

    def _fetch_firmwares(self, nodes):
        """Attach matching FirmwareBuild descriptions and first-seen dates to the nodes' firmware descriptions."""
        # only nodes that reported both firmware and hardware config have a firmware description
        candidates = [
            node for node in nodes.values()
            if node.last_messages[MeshMessageType.CONFIG_FIRMWARE] is not None
            and node.last_messages[MeshMessageType.CONFIG_HARDWARE] is not None
        ]
        if not candidates:
            return

        # fetch matching firmware builds (select_related: firmware_description reads build.version)
        sha256_hashes = {
            node.last_messages[MeshMessageType.CONFIG_FIRMWARE].parsed.app_desc.app_elf_sha256
            for node in candidates
        }
        firmwares = {}
        for build in FirmwareBuild.objects.filter(sha256_hash__in=sha256_hashes).select_related("version"):
            fw_desc = build.firmware_description
            firmwares[fw_desc.get_lookup()] = fw_desc

        # assign firmware descriptions
        for node in candidates:
            own_desc = node.firmware_description
            desc = firmwares.get(own_desc.get_lookup(), own_desc)
            node._firmware_description = desc
            # firmware_description is a cached_property that was just evaluated above, so its
            # cached value must be replaced too, or the matched build would be lost
            node.firmware_description = desc

        # get date of first appearance for firmwares that are not known builds
        nodes_to_complete = [node for node in candidates if node._firmware_description.build is None]
        if not nodes_to_complete:
            return
        hashes_to_complete = {node._firmware_description.sha256_hash for node in nodes_to_complete}

        # NOTE(review): messages are filtered on data__app_elf_sha256 while the parsed message
        # exposes the hash as app_desc.app_elf_sha256 — confirm the stored JSON layout
        try:
            created_lookup = {
                msg.parsed.app_desc.app_elf_sha256: msg.datetime
                for msg in NodeMessage.objects.filter(
                    message_type=MeshMessageType.CONFIG_FIRMWARE.name,
                    data__app_elf_sha256__in=list(hashes_to_complete),
                ).order_by('data__app_elf_sha256', 'datetime').distinct('data__app_elf_sha256')
            }
        except NotSupportedError:
            # no DISTINCT ON fields: one query per distinct hash
            created_lookup = {}
            for app_elf_sha256 in hashes_to_complete:
                first_msg = NodeMessage.objects.filter(
                    message_type=MeshMessageType.CONFIG_FIRMWARE.name,
                    data__app_elf_sha256=app_elf_sha256,
                ).order_by('datetime').first()
                if first_msg is not None:
                    created_lookup[app_elf_sha256] = first_msg.datetime

        for node in nodes_to_complete:
            node._firmware_description.created = created_lookup.get(node._firmware_description.sha256_hash)

    def _fetch_ota(self, nodes):
        """Set each node's _current_ota to its most recent OTAUpdateRecipient, or None."""
        try:
            recipients = list(
                OTAUpdateRecipient.objects.filter(
                    node__in=list(nodes.keys()),
                ).order_by('node', '-update__created').select_related("update", "update__build").distinct('node')
            )
        except NotSupportedError:
            # current_ota will be queried lazily per node instead
            return
        for node in nodes.values():
            node._current_ota = None
        for ota in recipients:
            # noinspection PyUnresolvedReferences
            nodes[ota.node_id]._current_ota = ota
2023-10-06 02:46:43 +02:00
class LastMessagesByTypeLookup(UserDict):
    """
    Lazily filled mapping of message type -> newest NodeMessage of that type for one node.

    Keys may be MeshMessageType members, their names, or their values; they are normalized to
    MeshMessageType members. The special key "any" holds the newest message of any type.
    A missing entry is queried from node.received_messages on first access and then cached
    (the cached value is None if the node has no such message).
    """

    def __init__(self, node):
        super().__init__()
        self.node = node

    def _get_key(self, item):
        """Normalize a member, member name or member value to a MeshMessageType member."""
        if isinstance(item, MeshMessageType):
            return item
        if isinstance(item, str):
            try:
                return getattr(MeshMessageType, item)
            except AttributeError:
                pass
        return MeshMessageType(item)

    def __getitem__(self, key):
        if key == "any":
            # a value set by a prefetch or by an earlier lookup wins over a new query
            try:
                return self.data["any"]
            except KeyError:
                pass
            msg = self.node.received_messages.order_by('-datetime', '-pk').first()
            self.data["any"] = msg
            return msg

        key = self._get_key(key)
        try:
            return self.data[key]
        except KeyError:
            pass
        msg = self.node.received_messages.filter(message_type=key.name).order_by('-datetime', '-pk').first()
        self.data[key] = msg
        return msg

    def __setitem__(self, key, item):
        if key == "any":
            self.data["any"] = item
            return
        self.data[self._get_key(key)] = item
2022-04-15 20:02:42 +02:00
class MeshNode(models.Model):
    """
    A mesh node. Any node.
    """
    address = models.CharField(_('mac address'), max_length=17, primary_key=True,
                               validators=[RegexValidator(
                                   regex='^([a-f0-9]{2}:){5}[a-f0-9]{2}$',
                                   message='Must be a lower-case mac address',
                                   code='invalid_macaddress'
                               )])
    name = models.CharField(_('name'), max_length=32, null=True, blank=True)
    first_seen = models.DateTimeField(_('first seen'), auto_now_add=True)
    uplink = models.ForeignKey('MeshUplink', models.PROTECT, null=True,
                               related_name='routed_nodes', verbose_name=_('uplink'))
    last_signin = models.DateTimeField(_('last signin'), null=True)
    objects = models.Manager.from_queryset(MeshNodeQuerySet)()

    def __str__(self):
        if self.name:
            return '%s (%s)' % (self.address, self.name)
        return self.address

    @cached_property
    def last_messages(self) -> Mapping[Any, "NodeMessage"]:
        """Newest received message per message type (plus "any"), loaded lazily or via prefetch."""
        return LastMessagesByTypeLookup(self)

    @cached_property
    def current_ota(self) -> Optional["OTAUpdateRecipient"]:
        """Most recent OTA update recipient entry for this node, or None."""
        try:
            # set by MeshNodeQuerySet.prefetch_ota()
            # noinspection PyUnresolvedReferences
            return self._current_ota
        except AttributeError:
            return self.ota_updates.order_by('-update__created').first()

    # noinspection PyUnresolvedReferences
    @cached_property
    def firmware_description(self) -> FirmwareDescription:
        """
        Description of the firmware the node reported.

        Uses _firmware_description if MeshNodeQuerySet.prefetch_firmwares() attached one,
        otherwise builds it from the latest CONFIG_FIRMWARE and CONFIG_HARDWARE messages.
        """
        with suppress(AttributeError):
            return self._firmware_description

        # noinspection PyTypeChecker
        firmware_msg: ConfigFirmwareMessage = self.last_messages[MeshMessageType.CONFIG_FIRMWARE].parsed
        # noinspection PyTypeChecker
        hardware_msg: ConfigHardwareMessage = self.last_messages[MeshMessageType.CONFIG_HARDWARE].parsed
        return FirmwareDescription(
            chip=hardware_msg.chip,
            project_name=firmware_msg.app_desc.project_name,
            version=firmware_msg.app_desc.version,
            idf_version=firmware_msg.app_desc.idf_version,
            sha256_hash=firmware_msg.app_desc.app_elf_sha256,
        )

    @cached_property
    def hardware_description(self) -> HardwareDescription:
        """Chip and board as reported in the latest CONFIG_HARDWARE and CONFIG_BOARD messages."""
        # noinspection PyUnresolvedReferences
        return HardwareDescription(
            chip=self.last_messages[MeshMessageType.CONFIG_HARDWARE].parsed.chip,
            board=self.last_messages[MeshMessageType.CONFIG_BOARD].parsed.board_config.board,
        )

    # Legacy attribute, kept for compatibility. prefetch_firmwares() sets _firmware_description,
    # not this attribute. NOTE(review): looks unused in this module.
    firmware_desc = None

    @cached_property
    def chip(self) -> ChipType:
        """Chip type from the latest CONFIG_HARDWARE message."""
        return self.last_messages[MeshMessageType.CONFIG_HARDWARE].parsed.chip

    @cached_property
    def board(self) -> BoardType:
        """Board type from the latest CONFIG_BOARD message."""
        # noinspection PyUnresolvedReferences
        return self.last_messages[MeshMessageType.CONFIG_BOARD].parsed.board_config.board

    def get_uplink(self) -> Optional["MeshUplink"]:
        """Return the node's uplink, or None if there is none or its last ping is older than UPLINK_TIMEOUT seconds."""
        if self.uplink_id is None:
            return None
        if self.uplink.last_ping + timedelta(seconds=UPLINK_TIMEOUT) < timezone.now():
            return None
        return self.uplink

    @classmethod
    def get_node_and_uplink(cls, address) -> Optional["MeshUplink"] | bool:
        """
        Look up the node with the given address and return its live uplink.

        Returns False if no node with this address exists. Returns None if the node has no
        live uplink (see get_uplink()).
        """
        try:
            dst_node = cls.objects.select_related('uplink').get(address=address)
        except cls.DoesNotExist:
            return False
        return dst_node.get_uplink()

    def get_locator_beacon(self):
        """Return the RangeLocator beacon entry for this node's address, or None."""
        locator = RangeLocator.load()
        return locator.beacons.get(self.address, None)
2022-04-15 20:02:42 +02:00
class MeshUplink(models.Model):
    """
    An uplink session, i.e. a direct connection to a node.

    A session counts as active while end_reason is NULL; the Meta constraint allows at most
    one active session per node.
    """

    class EndReason(models.TextChoices):
        CLOSED = "closed", _("closed")
        REPLACED = "replaced", _("replaced")
        NEW_TIMEOUT = "new_timeout", _("new (timeout)")

    name = models.CharField(_('channel name'), max_length=128)
    started = models.DateTimeField(_('started'), auto_now_add=True)
    node = models.ForeignKey(
        MeshNode,
        on_delete=models.PROTECT,
        related_name='uplink_sessions',
        verbose_name=_('node'),
    )
    last_ping = models.DateTimeField(_('last ping from consumer'))
    # NULL while the session is still open
    end_reason = models.CharField(_('end reason'), max_length=16, null=True, choices=EndReason.choices)

    class Meta:
        constraints = (
            UniqueConstraint(
                name='only_one_active_uplink',
                fields=["node"],
                condition=Q(end_reason__isnull=True),
            ),
        )
2022-04-15 20:02:42 +02:00
class NodeMessage(models.Model):
    """A message received from a mesh node via an uplink, stored as JSON with its message type name."""
    MESSAGE_TYPES = [(msgtype.name, msgtype.pretty_name) for msgtype in MeshMessageType]

    src_node = models.ForeignKey(MeshNode, models.PROTECT, related_name='received_messages',
                                 verbose_name=_('node'))
    uplink = models.ForeignKey(MeshUplink, models.PROTECT, related_name='relayed_messages',
                               verbose_name=_('uplink'))
    datetime = models.DateTimeField(_('datetime'), db_index=True, auto_now_add=True)
    message_type = models.CharField(_('message type'), max_length=24, db_index=True, choices=MESSAGE_TYPES)
    data = models.JSONField(_('message data'))

    def __str__(self):
        # %s instead of %d so unsaved messages (pk is None) can be printed as well
        return '(#%s) %s at %s' % (self.pk, self.get_message_type_display(), self.datetime)

    @cached_property
    def parsed(self) -> MeshMessage:
        """The stored JSON data decoded into a MeshMessage via MeshMessage.fromjson()."""
        return MeshMessage.fromjson(self.data)
2022-04-15 20:02:42 +02:00
class FirmwareVersion(models.Model):
    """A firmware version; its builds (one per variant) are available as ``builds``."""
    project_name = models.CharField(_('project name'), max_length=32)
    version = models.CharField(_('firmware version'), max_length=32, unique=True)
    idf_version = models.CharField(_('IDF version'), max_length=32)
    uploader = models.ForeignKey(get_user_model(), on_delete=models.SET_NULL, null=True)
    created = models.DateTimeField(_('creation/upload date'), auto_now_add=True)

    def serialize(self):
        """Return a dict describing this version, with its serialized builds keyed by variant name."""
        result = {
            'project_name': self.project_name,
            'version': self.version,
            'idf_version': self.idf_version,
            'created': self.created.isoformat(),
        }
        # boards are prefetched because FirmwareBuild.serialize() lists them
        builds = self.builds.all().prefetch_related("firmwarebuildboard_set")
        result['builds'] = {build.variant: build.serialize() for build in builds}
        return result
def firmware_upload_path(instance, filename):
    """
    Return the storage path for an uploaded FirmwareBuild binary.

    The path is ``firmware/<version>/<variant>/<filename>``, relative to the file storage root
    (MEDIA_ROOT with the default storage). Version and variant are slugified so they are safe
    to use as path components.

    NOTE(review): slugify drops dots, so e.g. "1.0.0" and "10.0" share a directory.
    """
    version = slugify(instance.version.version)
    variant = slugify(instance.variant)
    return f"firmware/{version}/{variant}/{filename}"
class FirmwareBuild(models.Model):
    """One build (variant) of a FirmwareVersion for a single chip type, with its uploaded binary."""
    CHIPS = [(chiptype.value, chiptype.pretty_name) for chiptype in ChipType]

    version = models.ForeignKey(FirmwareVersion, related_name='builds', on_delete=models.CASCADE)
    variant = models.CharField(_('variant name'), max_length=64)
    chip = models.SmallIntegerField(_('chip'), db_index=True, choices=CHIPS)
    sha256_hash = models.CharField(_('SHA256 hash'), unique=True, max_length=64)
    project_description = models.JSONField(verbose_name=_('project_description.json'))
    binary = models.FileField(_('firmware file'), null=True, upload_to=firmware_upload_path)

    class Meta:
        unique_together = [
            ('version', 'variant'),
        ]

    @property
    def boards(self):
        """Set of BoardType members this build is registered for."""
        return {BoardType[board.board] for board in self.firmwarebuildboard_set.all()}

    @property
    def chip_type(self) -> ChipType:
        return ChipType(self.chip)

    def serialize(self):
        """Return a dict describing this build (chip name, hash, download URL, board names)."""
        return {
            'chip': self.chip_type.name,
            'sha256_hash': self.sha256_hash,
            # binary is nullable; FieldFile.url raises ValueError when no file is associated
            'url': self.binary.url if self.binary else None,
            'boards': [board.name for board in self.boards],
        }

    @cached_property
    def firmware_description(self) -> FirmwareDescription:
        """FirmwareDescription of this build; uses the version's creation date as ``created``."""
        return FirmwareDescription(
            chip=self.chip_type,
            project_name=self.version.project_name,
            version=self.version.version,
            idf_version=self.version.idf_version,
            sha256_hash=self.sha256_hash,
            created=self.version.created,
            build=self,
        )

    @cached_property
    def hardware_descriptions(self) -> list[HardwareDescription]:
        """One HardwareDescription per board this build is registered for."""
        return [
            HardwareDescription(
                chip=self.chip_type,
                board=board,
            )
            for board in self.boards
        ]

    @cached_property
    def firmware_image(self) -> FirmwareImage:
        """Decode the FirmwareImage header from the start of the uploaded binary."""
        # read only the header bytes, and close the file afterwards
        with self.binary.open('rb') as binary_file:
            header = binary_file.read(FirmwareImage.get_min_size())
        firmware_image, _remaining = FirmwareImage.decode(header)
        return firmware_image
class FirmwareBuildBoard(models.Model):
    """Associates a FirmwareBuild with one board type (stored by BoardType member name)."""
    BOARDS = [(boardtype.name, boardtype.pretty_name) for boardtype in BoardType]

    build = models.ForeignKey(FirmwareBuild, on_delete=models.CASCADE)
    board = models.CharField(_('board'), max_length=32, db_index=True, choices=BOARDS)

    class Meta:
        # each board may be listed only once per build
        unique_together = [('build', 'board')]
2023-11-10 16:08:55 +01:00
class OTAUpdate(models.Model):
    """An over-the-air update of one FirmwareBuild; recipient nodes are linked via OTAUpdateRecipient."""
    build = models.ForeignKey(to=FirmwareBuild, on_delete=models.CASCADE)
    created = models.DateTimeField(_('creation'), auto_now_add=True)
class OTAUpdateRecipient(models.Model):
    """Links an OTAUpdate to one node it is addressed to."""
    # the update; reverse accessor on OTAUpdate is ``recipients``
    update = models.ForeignKey(OTAUpdate, on_delete=models.CASCADE, related_name='recipients')
    # the receiving node; reverse accessor on MeshNode is ``ota_updates``
    node = models.ForeignKey(MeshNode, models.PROTECT, related_name='ota_updates',
                             verbose_name=_('node'))