Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 34 additions & 11 deletions pysquared/hardware/radio/packetizer/packet_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time

from ....logger import Logger
from ....nvm.counter import Counter
from ....protos.radio import RadioProto

try:
Expand All @@ -17,17 +18,18 @@ def __init__(
logger: Logger,
radio: RadioProto,
license: str,
message_counter: Counter,
send_delay: float = 0.2,
) -> None:
"""Initialize the packet manager with maximum packet size"""
self._logger: Logger = logger
self._radio: RadioProto = radio
self._send_delay: float = send_delay
self._license: str = license
self._header_size: int = (
5 # 2 bytes for sequence number, 2 for total packets, 1 for rssi
)
# 1 byte for packet identifier, 2 bytes for sequence number, 2 for total packets, 1 for rssi
self._header_size: int = 6
self._payload_size: int = radio.get_max_packet_size() - self._header_size
self._message_counter: Counter = message_counter

def send(self, data: bytes) -> bool:
"""Send data"""
Expand Down Expand Up @@ -55,8 +57,10 @@ def _pack_data(self, data: bytes) -> list[bytes]:
"""
Takes input data and returns a list of packets ready for transmission
Each packet includes:
- 1 byte: packet identifier
- 2 bytes: sequence number (0-based)
- 2 bytes: total number of packets
- 1 byte: rssi
- remaining bytes: payload
"""
# Calculate number of packets needed
Expand All @@ -67,11 +71,14 @@ def _pack_data(self, data: bytes) -> list[bytes]:
data_length=len(data),
)

packet_identifier: int = self._get_packet_identifier()

packets: list[bytes] = []
for sequence_number in range(total_packets):
# Create header
header: bytes = (
sequence_number.to_bytes(2, "big")
packet_identifier.to_bytes(1, "big")
+ sequence_number.to_bytes(2, "big")
+ total_packets.to_bytes(2, "big")
+ abs(self._radio.get_rssi()).to_bytes(1, "big")
)
Expand Down Expand Up @@ -117,6 +124,8 @@ def listen(self, timeout: Optional[int] = None) -> bytes | None:
if packet is None:
continue

packet_identifier, _, total_packets, _ = self._get_header(packet)

# Log received packets
self._logger.debug(
"Received packet",
Expand All @@ -125,11 +134,19 @@ def listen(self, timeout: Optional[int] = None) -> bytes | None:
payload=self._get_payload(packet),
)

# Process received packet
if received_packets:
(
first_packet_identifier,
_,
_,
_,
) = self._get_header(received_packets[0])
if packet_identifier != first_packet_identifier:
continue

received_packets.append(packet)

# Check if we have all packets
_, total_packets, _ = self._get_header(packet)
if total_packets == len(received_packets):
self._logger.debug(
"Received all expected packets", received=total_packets
Expand All @@ -150,19 +167,25 @@ def _unpack_data(self, packets: list[bytes]) -> bytes:
Returns None if packets are missing or corrupted
"""
sorted_packets: list = sorted(
packets, key=lambda p: int.from_bytes(p[:2], "big")
packets, key=lambda p: int.from_bytes(p[1:3], "big")
)

return b"".join(self._get_payload(packet) for packet in sorted_packets)

def _get_header(self, packet: bytes) -> tuple[int, int, int]:
def _get_header(self, packet: bytes) -> tuple[int, int, int, int]:
"""Returns the sequence number and total packets stored in the header."""
return (
int.from_bytes(packet[0:2], "big"), # sequence number
int.from_bytes(packet[2:4], "big"), # total packets
-int.from_bytes(packet[4:5], "big"), # RSSI
int.from_bytes(packet[0:1], "big"), # packet identifier
int.from_bytes(packet[1:3], "big"), # sequence number
int.from_bytes(packet[3:5], "big"), # total packets
-int.from_bytes(packet[5:6], "big"), # RSSI
)

def _get_payload(self, packet: bytes) -> bytes:
"""Returns the payload of the packet."""
return packet[self._header_size :]

def _get_packet_identifier(self) -> int:
"""Increments message_counter and returns the current identifier for a packet"""
self._message_counter.increment()
return self._message_counter.get()
Loading
Loading