// roost/birdwhisperer/birdwhisperer.go
//
// 244 lines
// 7.0 KiB
// Go

package birdwhisperer
import (
"encoding/binary"
"encoding/hex"
"fmt"
"log/slog"
"net"
"time"
)
// Packet pairs a raw datagram payload with the remote UDP address it is
// associated with (presumably the destination for outgoing packets and the
// source for incoming ones — confirm against callers).
type Packet struct {
	// TargetAddr is the remote UDP endpoint; it may be nil.
	TargetAddr *net.UDPAddr
	// Buffer holds the raw bytes of the datagram or chunk.
	Buffer []byte
}
// ChunkHeaderV1 is the decoded form of the 9-byte version-1 chunk header.
// All multi-byte fields are big-endian on the wire.
type ChunkHeaderV1 struct {
	// MessageID identifies the message this chunk belongs to (bytes 0-1).
	MessageID uint16
	// Length is the chunk's total size in bytes, header included, as written
	// by ChunkMessage (bytes 2-3).
	Length uint16
	// Version is the header format version (byte 4); always 1 for V1.
	Version uint8
	// Index is this chunk's position within its message (bytes 5-6).
	Index uint16
	// NumChunks is how many chunks make up the whole message (bytes 7-8).
	NumChunks uint16
	// HeaderLength is not transmitted; the parser fills it from the version.
	HeaderLength uint8
}
// toBytes serializes the header into its CHUNK_V1_HEADER_LENGTH-byte wire
// form, big-endian:
//
//	[0:2] MessageID  [2:4] Length  [4] Version  [5:7] Index  [7:9] NumChunks
//
// HeaderLength is not part of the wire format (the parser derives it from
// the version), so it is not written. Writing it at byte 8 would clobber the
// low byte of NumChunks.
func (h *ChunkHeaderV1) toBytes() []byte {
	result := make([]byte, CHUNK_V1_HEADER_LENGTH)
	binary.BigEndian.PutUint16(result[0:2], h.MessageID)
	binary.BigEndian.PutUint16(result[2:4], h.Length)
	// This is the V1 encoder, so the version byte is always 1.
	result[CHUNK_HEADER_VERSION_OFFSET] = 1
	binary.BigEndian.PutUint16(result[5:7], h.Index)
	binary.BigEndian.PutUint16(result[7:9], h.NumChunks)
	return result
}
// toHexString returns the lowercase hexadecimal encoding of the header's
// wire bytes, as a byte slice.
func (h *ChunkHeaderV1) toHexString() []byte {
	raw := h.toBytes()
	encoded := make([]byte, hex.EncodedLen(len(raw)))
	hex.Encode(encoded, raw)
	return encoded
}
// ReconstructedMessage tracks the reassembly of a single chunked message.
type ReconstructedMessage struct {
	// Chunks has one slot per chunk index and is filled in by ReceiveChunk.
	Chunks []Packet
	// ReceivedTime is when the message was first seen, in Unix seconds.
	ReceivedTime int64
	// CompletedTime is when reassembly finished, in Unix seconds; 0 while
	// the message is still incomplete.
	CompletedTime int64
	// NumChunks is the total number of chunks the message was split into.
	NumChunks uint16
	// ReceivedChunks counts the chunks received so far.
	ReceivedChunks uint16
}
// Wire-format offsets and size limits for chunked messages.
const (
	// CHUNK_HEADER_VERSION_OFFSET is the byte offset of the version field,
	// which is inspected before the rest of the header is parsed.
	CHUNK_HEADER_VERSION_OFFSET = 4
	// CHUNK_V1_HEADER_LENGTH is the encoded size in bytes of a V1 header.
	CHUNK_V1_HEADER_LENGTH = 9
	// RECONSTRUCTED_BUFFER_SIZE_ALERT is the largest reassembled message, in
	// bytes, that ReconstructMessageFromChunks will build.
	RECONSTRUCTED_BUFFER_SIZE_ALERT = 1e8
	// CHUNK_MAX_LENGTH is the maximum payload bytes carried per chunk. It is
	// based on a 1500-byte MTU, leaving headroom for the chunk header and
	// lower-layer overhead.
	CHUNK_MAX_LENGTH = 1400
)
// BirdWhisperer reassembles chunked messages. ChunkRecord maps a message ID
// to its in-progress or recently completed reassembly state. The map has no
// internal locking, so concurrent access must be serialized by the caller.
type BirdWhisperer struct {
	ChunkRecord map[uint16]ReconstructedMessage
}
// NewBirdWhisperer returns a BirdWhisperer whose ChunkRecord is an empty,
// non-nil map, ready to receive chunks.
func NewBirdWhisperer() BirdWhisperer {
	var whisperer BirdWhisperer
	whisperer.ChunkRecord = map[uint16]ReconstructedMessage{}
	return whisperer
}
// ReconstructMessageFromChunks concatenates the Buffer of every entry in
// chunkRecord.Chunks, in slice order, into a newly allocated message.
//
// It returns an error, and no message, when the combined size exceeds
// RECONSTRUCTED_BUFFER_SIZE_ALERT bytes.
func ReconstructMessageFromChunks(chunkRecord ReconstructedMessage) ([]byte, error) {
	total := 0
	for _, part := range chunkRecord.Chunks {
		total += len(part.Buffer)
	}
	if total > RECONSTRUCTED_BUFFER_SIZE_ALERT {
		// Refuse to allocate an implausibly large buffer.
		return nil, fmt.Errorf("Attempted to reconstruct message bigger than allowed size: [%d] bytes", total)
	}

	message := make([]byte, 0, total)
	for _, part := range chunkRecord.Chunks {
		message = append(message, part.Buffer...)
	}
	return message, nil
}
// PruneChunkRecord removes reassembly state that is no longer needed:
//   - messages that completed more than 30 seconds before currentTimestamp;
//   - messages first seen more than 3 minutes before it that never completed.
//
// currentTimestamp must be Unix time in seconds. That is the same clock
// ReceiveChunk uses when it stamps ReceivedTime and CompletedTime with
// time.Now().Unix().
func (bw *BirdWhisperer) PruneChunkRecord(currentTimestamp int64) {
	// The timeouts are in seconds because the record timestamps are in seconds.
	// Expressing them in milliseconds (as before) made them 1000x too long.
	const (
		completedRetentionSec  = 30
		incompleteRetentionSec = 3 * 60
	)
	for id, rec := range bw.ChunkRecord {
		completedLongAgo := rec.CompletedTime > 0 &&
			currentTimestamp-rec.CompletedTime > completedRetentionSec
		abandoned := rec.ReceivedTime > 0 && rec.CompletedTime == 0 &&
			currentTimestamp-rec.ReceivedTime > incompleteRetentionSec
		if completedLongAgo || abandoned {
			// Deleting the current key while ranging over a map is permitted.
			delete(bw.ChunkRecord, id)
		}
	}
}
// ReceiveChunk feeds one incoming datagram into the reassembly state.
//
// chunk.Buffer may contain several chunks back to back. Each chunk's header
// Length is its total size, header included (as written by ChunkMessage), so
// the next chunk starts Length bytes later. Every chunk is processed in order.
//
// When a chunk completes a message, the reassembled payload is sent on
// completedMessages. That send blocks until the channel can accept it.
// Duplicate chunks, and chunks for messages that already completed, are
// ignored.
//
// An error is returned for:
//   - a truncated or malformed chunk, or an unsupported version;
//   - a chunk whose index or chunk count is inconsistent with its message;
//   - a message too large to reassemble.
//
// Chunks that precede the faulty one in the same datagram have already been
// applied when the error is returned.
func (bw *BirdWhisperer) ReceiveChunk(chunk Packet, completedMessages chan []byte) error {
	remaining := chunk.Buffer
	for {
		consumed, err := bw.receiveSingleChunk(Packet{TargetAddr: chunk.TargetAddr, Buffer: remaining}, completedMessages)
		if err != nil {
			return err
		}
		// consumed is always >= the header length, so this loop terminates.
		if consumed >= len(remaining) {
			return nil
		}
		remaining = remaining[consumed:]
	}
}

// receiveSingleChunk applies the chunk at the start of chunk.Buffer to the
// reassembly state and reports how many bytes that chunk occupied.
func (bw *BirdWhisperer) receiveSingleChunk(chunk Packet, completedMessages chan []byte) (int, error) {
	if len(chunk.Buffer) <= CHUNK_HEADER_VERSION_OFFSET {
		return 0, fmt.Errorf("Packet not long enough to be a chunk!")
	}
	if chunk.Buffer[CHUNK_HEADER_VERSION_OFFSET] != 1 {
		return 0, fmt.Errorf("Packet has unsupported chunk version")
	}
	header, err := ParseChunkHeaderV1(chunk)
	if err != nil {
		return 0, err
	}

	// Length counts header plus payload. It comes off the network, so check
	// it before slicing.
	chunkEnd := int(header.Length)
	if chunkEnd < int(header.HeaderLength) || chunkEnd > len(chunk.Buffer) {
		return 0, fmt.Errorf("chunk length %d invalid for %d-byte buffer", chunkEnd, len(chunk.Buffer))
	}

	record, exists := bw.ChunkRecord[header.MessageID]
	if !exists {
		record = ReconstructedMessage{
			Chunks:       make([]Packet, header.NumChunks),
			ReceivedTime: time.Now().Unix(),
			NumChunks:    header.NumChunks,
		}
	}
	if record.CompletedTime > 0 {
		// Late duplicate of an already-delivered message. Skip its bytes.
		return chunkEnd, nil
	}
	// Index and NumChunks are untrusted. Reject anything that would index
	// outside the slot slice or disagree with the message's known size.
	if header.NumChunks != record.NumChunks || int(header.Index) >= len(record.Chunks) {
		return 0, fmt.Errorf("chunk %d of %d inconsistent with message %d (expects %d chunks)",
			header.Index, header.NumChunks, header.MessageID, record.NumChunks)
	}

	slot := &record.Chunks[header.Index]
	if slot.Buffer == nil { // nil marks a slot that has not been filled yet
		payload := chunk.Buffer[header.HeaderLength:chunkEnd]
		// Copy the payload: the caller may reuse chunk.Buffer for its next read.
		// make() yields a non-nil slice even for an empty payload, so the slot
		// still reads as filled.
		slot.Buffer = make([]byte, len(payload))
		copy(slot.Buffer, payload)
		slot.TargetAddr = chunk.TargetAddr
		record.ReceivedChunks++
	}
	bw.ChunkRecord[header.MessageID] = record

	if record.ReceivedChunks == record.NumChunks {
		msg, err := ReconstructMessageFromChunks(record)
		if err != nil {
			return 0, err
		}
		record.CompletedTime = time.Now().Unix()
		bw.ChunkRecord[header.MessageID] = record
		slog.Debug("Received all chunks of a message", "firstReceived", record.ReceivedTime, "lastReceived", record.CompletedTime)
		completedMessages <- msg
	}
	return chunkEnd, nil
}
// ParseChunkHeaderV1 decodes the fixed-size V1 header at the start of
// chunk.Buffer. All multi-byte fields are big-endian. HeaderLength is not read
// from the wire; it is always set to CHUNK_V1_HEADER_LENGTH.
//
// It returns an error in two cases:
//   - the buffer is shorter than a V1 header (the header is then zero-valued);
//   - the version byte is not 1 (the decoded fields are still returned).
func ParseChunkHeaderV1(chunk Packet) (header ChunkHeaderV1, err error) {
	buf := chunk.Buffer
	if len(buf) < CHUNK_V1_HEADER_LENGTH {
		return header, fmt.Errorf("Malformed chunk")
	}
	// Decode every field first, then validate the version. The caller has
	// usually checked the version already, and decoding in wire order keeps
	// the byte layout readable.
	be := binary.BigEndian
	header = ChunkHeaderV1{
		MessageID:    be.Uint16(buf[0:2]),
		Length:       be.Uint16(buf[2:4]),
		Version:      buf[CHUNK_HEADER_VERSION_OFFSET],
		Index:        be.Uint16(buf[5:7]),
		NumChunks:    be.Uint16(buf[7:9]),
		HeaderLength: CHUNK_V1_HEADER_LENGTH,
	}
	if header.Version != 1 {
		return header, fmt.Errorf("Not a V1 header!")
	}
	return header, nil
}
// ChunkMessage splits message into V1 chunks. Each chunk is an encoded
// ChunkHeaderV1 followed by at most CHUNK_MAX_LENGTH payload bytes.
//
// Header fields in each chunk:
//   - Index runs from 0 to NumChunks-1, matching the receiver, which stores
//     chunk i in slot i of a NumChunks-long slice.
//   - Length is the chunk's total size, header included.
//
// Edge cases:
//   - An empty message yields a single header-only chunk, so it can still be
//     delivered.
//   - A message needing more chunks than the 16-bit NumChunks field can
//     express cannot be represented on the wire; nil is returned for it.
func ChunkMessage(message []byte, messageId uint16) [][]byte {
	// Ceiling division. A zero-length message still needs one chunk.
	numChunks := (len(message) + CHUNK_MAX_LENGTH - 1) / CHUNK_MAX_LENGTH
	if numChunks == 0 {
		numChunks = 1
	}
	if numChunks > int(^uint16(0)) {
		// Index/NumChunks would wrap around; refuse instead of emitting corrupt headers.
		return nil
	}

	messageChunks := make([][]byte, 0, numChunks)
	for i := 0; i < numChunks; i++ {
		start := i * CHUNK_MAX_LENGTH
		end := min(start+CHUNK_MAX_LENGTH, len(message))
		payload := message[start:end]
		header := ChunkHeaderV1{
			MessageID:    messageId,
			Length:       uint16(len(payload) + CHUNK_V1_HEADER_LENGTH),
			Version:      1,
			Index:        uint16(i),
			NumChunks:    uint16(numChunks),
			HeaderLength: CHUNK_V1_HEADER_LENGTH,
		}
		// Header followed by the slice of the message it describes.
		messageChunks = append(messageChunks, append(header.toBytes(), payload...))
	}
	return messageChunks
}