From d6b86425eea2c9da4b04965e934553783bc91d20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Sun, 26 Nov 2023 17:39:10 +0100 Subject: [PATCH] lots of adjustments in consumer code, more state handling and stuff --- src/c3nav/mesh/consumers.py | 208 ++++++++++++++++++++++-------------- 1 file changed, 127 insertions(+), 81 deletions(-) diff --git a/src/c3nav/mesh/consumers.py b/src/c3nav/mesh/consumers.py index 9464074b..c6e86ea5 100644 --- a/src/c3nav/mesh/consumers.py +++ b/src/c3nav/mesh/consumers.py @@ -27,29 +27,31 @@ class Unknown: @unique -class OTAWaitingFor(IntEnum): +class NodeWaitingFor(IntEnum): NOTHING = auto() - START_OR_CANCEL_CONFIRM = auto() + CONFIG = auto() + OTA_CHECK = auto() @dataclass -class OTADeviceState: - waiting_for: OTAWaitingFor = OTAWaitingFor.NOTHING - reported_ota: Optional[int] = None # None = unknown, 0 = no update +class NodeState: + waiting_for: NodeWaitingFor = NodeWaitingFor.CONFIG + attempt: int = 0 + last_msg: dict[MeshMessageType: MeshMessage] = field(default_factory=dict) last_sent: datetime = field(default_factory=timezone.now) - recipient: Optional[OTAUpdateRecipient] = None + reported_ota_update: Optional[int] = None # None = unknown, 0 = no update + ota_recipient: Optional[OTAUpdateRecipient] = None class MeshConsumer(AsyncWebsocketConsumer): def __init__(self): super().__init__() self.uplink = None - self.dst_nodes = set() + self.dst_nodes: dict[str, NodeState] = {} # keys are addresses self.open_requests = set() self.ping_task = None - self.check_ota_states_task = None + self.check_node_state_task = None self.ota_send_task = None - self.ota_states: dict[str, OTADeviceState] = {} # keys are addresses self.ota_chunks: dict[int, set[int]] = {} # keys are update IDs, values are a list of chunk IDs self.ota_chunks_available_condition = asyncio.Condition() @@ -59,12 +61,12 @@ class MeshConsumer(AsyncWebsocketConsumer): # await self.log_text(None, "new mesh websocket connection") await self.accept() self.ping_task = get_event_loop().create_task(self.ping_regularly()) - self.check_ota_states_task = get_event_loop().create_task(self.check_node_ota_states()) + self.check_node_state_task = get_event_loop().create_task(self.check_node_states()) self.ota_send_task = get_event_loop().create_task(self.ota_send()) async def disconnect(self, close_code): self.ping_task.cancel() - self.check_ota_states_task.cancel() + self.check_node_state_task.cancel() self.ota_send_task.cancel() await self.log_text(self.uplink.node, "mesh websocket disconnected") if self.uplink is not None: @@ -171,6 +173,7 @@ class MeshConsumer(AsyncWebsocketConsumer): # add this node as a destination that this uplink handles (duh) await self.add_dst_nodes(nodes=(src_node, )) + self.dst_nodes[msg.src].last_msg[MeshMessageType.MESH_SIGNIN] = msg return @@ -181,6 +184,9 @@ class MeshConsumer(AsyncWebsocketConsumer): await self.log_received_message(src_node, msg) + node_status = self.dst_nodes[msg.src] + node_status.last_msg[msg.msg_type] = msg + if isinstance(msg, messages.MeshAddDestinationsMessage): await self.add_dst_nodes(addresses=msg.addresses) @@ -207,22 +213,27 @@ class MeshConsumer(AsyncWebsocketConsumer): route=uplink.node_id if uplink else MESH_NONE_ADDRESS, )) + if isinstance(msg, (messages.ConfigHardwareMessage, + messages.ConfigFirmwareMessage, + messages.ConfigBoardMessage)): + if (node_status.waiting_for == NodeWaitingFor.CONFIG and + not {MeshMessageType.CONFIG_HARDWARE, + MeshMessageType.CONFIG_FIRMWARE, + MeshMessageType.CONFIG_BOARD} - set(node_status.last_msg.keys())): + print('got all config, checking ota') + await self.check_ota([msg.src], first_time=True) + if isinstance(msg, messages.OTAStatusMessage): print('got OTA status', msg) - try: - ota_status = self.ota_states[msg.src] - except KeyError: - print('ota status from node where we didn\'t expect it') - await self.check_ota(msg.src) - else: - if ota_status.waiting_for == OTAWaitingFor.START_OR_CANCEL_CONFIRM: - update_id = ota_status.recipient.update_id if ota_status.recipient else 0 - if update_id == msg.update_id: - print('start/cancel confirmed!') - ota_status.waiting_for = OTAWaitingFor.NOTHING - if update_id: - print('queue chunk sending') - await self.ota_set_chunks(ota_status.recipient.update) + node_status.reported_ota_update = msg.update_id + if node_status.waiting_for == NodeWaitingFor.OTA_CHECK: + update_id = node_status.ota_recipient.update_id if node_status.ota_recipient else 0 + if update_id == msg.update_id: + print('start/cancel confirmed!') + node_status.waiting_for = NodeWaitingFor.NOTHING # todo: probably good to continue from here + if update_id: + print('queue chunk sending') + await self.ota_set_chunks(node_status.ota_recipient.update, min_chunk=msg.highest_chunk+1) @database_sync_to_async def create_uplink_in_database(self, address): @@ -273,7 +284,7 @@ class MeshConsumer(AsyncWebsocketConsumer): if data["uplink"] != self.channel_name: await self.log_text(data["node"], "node now served by new consumer") # going the short way cause the other consumer will already have done database stuff - self.dst_nodes.discard(data["node"]) + self.dst_nodes.pop(data["node"], None) async def mesh_send(self, data): if self.uplink.node.address == data["exclude_uplink_address"]: @@ -320,66 +331,109 @@ class MeshConsumer(AsyncWebsocketConsumer): }) print("MESH %s: [%s] %s" % (self.uplink.node, address, text)) - async def check_ota(self, addresses): + """ connection state machine """ + + async def check_ota(self, addresses, first_time: bool = False): + """ + this method will check the latest OTA for these nodes in the database + + it will ignore nodes that are still waiting for their config, unless first_time is set + """ recipients = await self.get_nodes_with_ota(addresses) for address, recipient in recipients.items(): - ota_state = self.ota_states.setdefault(address, OTADeviceState()) - update_id = recipient.update_id if recipient else 0 - if update_id != ota_state.reported_ota: - ota_state.waiting_for = OTAWaitingFor.START_OR_CANCEL_CONFIRM - ota_state.recipient = recipient - await self.ota_resend_ask(address) + node_state = self.dst_nodes[address] + if not first_time and node_state.waiting_for == NodeWaitingFor.CONFIG: + # too soon + continue + + # check if the installed firmware is the one we want to install + if recipient: + target_app_desc = recipient.update.build.firmware_image.app_desc + fw_msg = node_state.last_msg.get(MeshMessageType.CONFIG_FIRMWARE, None) + current_app_desc = fw_msg.app_desc if fw_msg else None + print('target app desc:', target_app_desc) + print('current app desc:', current_app_desc) + if target_app_desc == current_app_desc: + print('the node already has this firmware, awesome') + # todo: do something with this information + recipient = False + else: + print('app desc does not match') + + desired_update_id = recipient.update_id if recipient else 0 + if desired_update_id != node_state.reported_ota_update: + print('changing OTA state on node') + node_state.waiting_for = NodeWaitingFor.OTA_CHECK + node_state.attempt = 0 + node_state.ota_recipient = recipient + await self.node_resend_ask(address) + else: + node_state.ota_recipient = None + @database_sync_to_async - def get_nodes_with_ota(self, addresses) -> dict: + def get_nodes_with_ota(self, addresses) -> dict[str, Optional[OTAUpdateRecipient]]: return {node.address: node.current_ota for node in MeshNode.objects.prefetch_ota().filter(address__in=addresses)} - async def ota_resend_ask(self, address): - ota_state = self.ota_states[address] - if ota_state.waiting_for == OTAWaitingFor.START_OR_CANCEL_CONFIRM: - ota_state.last_sent = timezone.now() - if ota_state.recipient: - print('starting ota') + async def node_resend_ask(self, address): + node_state = self.dst_nodes[address] - await self.send_msg(messages.OTAStartMessage( + match(node_state.waiting_for): + case NodeWaitingFor.NOTHING: + return + + case NodeWaitingFor.CONFIG: + node_state.last_sent = timezone.now() + print('request config dump, attempt #%d' % node_state.attempt) + node_state.attempt += 1 + await self.send_msg(messages.ConfigDumpMessage( src=MESH_ROOT_ADDRESS, dst=address, - update_id=ota_state.recipient.update_id, # noqa - total_bytes=ota_state.recipient.update.build.binary.size, - auto_apply=False, - auto_reboot=False, - )) - else: - print('canceling ota') - await self.send_msg(messages.OTAAbortMessage( - src=MESH_ROOT_ADDRESS, - dst=address, - update_id=0, )) - async def check_node_ota_states(self): + case NodeWaitingFor.OTA_CHECK: + node_state.last_sent = timezone.now() + if node_state.ota_recipient: + print('starting ota, attempt #%d' % node_state.attempt) + await self.send_msg(messages.OTAStartMessage( + src=MESH_ROOT_ADDRESS, + dst=address, + update_id=node_state.ota_recipient.update_id, # noqa + total_bytes=node_state.ota_recipient.update.build.binary.size, + auto_apply=False, + auto_reboot=False, + )) + else: + print('canceling ota, attempt #%d' % node_state.attempt) + await self.send_msg(messages.OTAAbortMessage( + src=MESH_ROOT_ADDRESS, + dst=address, + update_id=0, + )) + + async def check_node_states(self): while True: - for address in tuple(self.ota_states.keys()): + for address in tuple(self.dst_nodes.keys()): try: if address not in self.dst_nodes: - self.ota_states.pop(address, None) + self.dst_nodes.pop(address, None) continue - ota_state = self.ota_states.get(address, None) - if ota_state: - if (ota_state.waiting_for != OTAWaitingFor.NOTHING and - ota_state.last_sent+timedelta(seconds=10) < timezone.now()): - await self.ota_resend_ask(address) + node_state = self.dst_nodes.get(address, None) + if node_state: + if (node_state.waiting_for != NodeWaitingFor.NOTHING and + node_state.last_sent+timedelta(seconds=10) < timezone.now()): + await self.node_resend_ask(address) except Exception: # noqa - print('failure in check_node_ota_states') + print('failure in check_node_states') traceback.print_exc() await asyncio.sleep(1) - async def ota_set_chunks(self, update: OTAUpdate, chunks: Optional[set[int]] = None): + async def ota_set_chunks(self, update: OTAUpdate, chunks: Optional[set[int]] = None, min_chunk: int=0): async with self.ota_chunks_available_condition: num_chunks = (update.build.binary.size-1)//OTA_CHUNK_SIZE+1 print('queueing chunks for update', update.id, 'num_chunks=%d' % num_chunks, "chunks:", chunks) - chunks = set(range(num_chunks)) if chunks is None else {chunk for chunk in chunks if chunk < num_chunks} + chunks = set(range(min_chunk, num_chunks*0+10)) if chunks is None else {chunk for chunk in chunks if chunk < num_chunks} self.ota_chunks.setdefault(update.id, set()).update(chunks) self.ota_chunks_available_condition.notify_all() @@ -395,8 +449,8 @@ class MeshConsumer(AsyncWebsocketConsumer): continue # find recipients, so we know if broadcast or not - recipients = [address for address, state in self.ota_states.items() - if state.recipient and state.recipient.update_id == update_id] + recipients = [address for address, state in self.dst_nodes.items() + if state.ota_recipient and state.ota_recipient.update_id == update_id] if not recipients: # no recipients? then lets stop print('no more recipients for', update_id, 'stopping sending…') @@ -404,7 +458,7 @@ class MeshConsumer(AsyncWebsocketConsumer): continue # send the message - with self.ota_states[recipients[0]].recipient.update.build.binary.open('rb') as f: + with self.dst_nodes[recipients[0]].ota_recipient.update.build.binary.open('rb') as f: f.seek(chunk * OTA_CHUNK_SIZE) data = f.read(OTA_CHUNK_SIZE) await self.send_msg(messages.OTAFragmentMessage( @@ -422,6 +476,8 @@ class MeshConsumer(AsyncWebsocketConsumer): if not self.ota_chunks: await self.ota_chunks_available_condition.wait() + """ routing """ + async def add_dst_nodes(self, nodes=None, addresses=None): nodes = list(nodes) if nodes else [] addresses = set(addresses) if addresses else set() @@ -444,16 +500,11 @@ class MeshConsumer(AsyncWebsocketConsumer): # add ourselves as uplink await self._add_destination(address) - # tell the node to dump its current information - await self.send_msg( - messages.ConfigDumpMessage( - src=messages.MESH_ROOT_ADDRESS, - dst=address, - ) - ) + # if we aren't handling this address yet, write it down + if address not in self.dst_nodes: + self.dst_nodes[address] = NodeState() - self.ota_states.pop(address, None) - await self.check_ota([address]) + await self.node_resend_ask(address) @database_sync_to_async def _add_destination(self, address): @@ -471,15 +522,10 @@ class MeshConsumer(AsyncWebsocketConsumer): "uplink": self.channel_name }) - # if we aren't handling this address yet, write it down - if address not in self.dst_nodes: - self.dst_nodes.add(address) async def remove_dst_nodes(self, addresses): for address in tuple(addresses): await self.log_text(address, "destination removed") - - self.ota_states.pop(address, None) await self._remove_destination(address) @database_sync_to_async @@ -495,7 +541,7 @@ class MeshConsumer(AsyncWebsocketConsumer): # no longer serving this node if address in self.dst_nodes: - self.dst_nodes.discard(address) + self.dst_nodes.pop(address, None) class MeshUIConsumer(AsyncJsonWebsocketConsumer):