package birdwhisperer import ( "encoding/binary" "encoding/hex" "fmt" "log/slog" "net" "time" ) type Packet struct { TargetAddr *net.UDPAddr Buffer []byte } type ChunkHeaderV1 struct { MessageID uint16 Length uint16 Version uint8 Index uint16 NumChunks uint16 HeaderLength uint8 } func (h *ChunkHeaderV1) toBytes() []byte { result := make([]byte, CHUNK_V1_HEADER_LENGTH) binary.BigEndian.PutUint16(result[0:2], h.MessageID) binary.BigEndian.PutUint16(result[2:4], h.Length) result[CHUNK_HEADER_VERSION_OFFSET] = 1 binary.BigEndian.PutUint16(result[5:7], h.Index) binary.BigEndian.PutUint16(result[7:9], h.NumChunks) result[8] = CHUNK_V1_HEADER_LENGTH return result } func (h *ChunkHeaderV1) toHexString() []byte { return []byte(hex.EncodeToString(h.toBytes())) } type ReconstructedMessage struct { Chunks []Packet ReceivedTime int64 CompletedTime int64 NumChunks uint16 ReceivedChunks uint16 } const CHUNK_HEADER_VERSION_OFFSET = 4 const CHUNK_V1_HEADER_LENGTH = 9 const RECONSTRUCTED_BUFFER_SIZE_ALERT = 1e8 // Based on an MTU of 1500 bytes, with some flexibility for overhead const CHUNK_MAX_LENGTH = 1400 type BirdWhisperer struct { ChunkRecord map[uint16]ReconstructedMessage } func NewBirdWhisperer() BirdWhisperer { return BirdWhisperer{ ChunkRecord: make(map[uint16]ReconstructedMessage), } } func ReconstructMessageFromChunks(chunkRecord ReconstructedMessage) ([]byte, error) { reconstructedLength := 0 for i := range len(chunkRecord.Chunks) { reconstructedLength += len(chunkRecord.Chunks[i].Buffer) } if reconstructedLength > RECONSTRUCTED_BUFFER_SIZE_ALERT { // We have some huge thing... return nil, fmt.Errorf("Attempted to reconstruct message bigger than allowed size: [%d] bytes", reconstructedLength) } reconstructedMessage := make([]byte, reconstructedLength) offset := 0 for i := range len(chunkRecord.Chunks) { chunkLength := len(chunkRecord.Chunks[i].Buffer) copy(reconstructedMessage[offset:offset+chunkLength], chunkRecord.Chunks[i].Buffer[:]) offset += chunkLength } return reconstructedMessage, nil } func (bw *BirdWhisperer) PruneChunkRecord(currentTimestamp int64) { const CHUNK_PRUNING_TIMEOUT_MS = 30_000 const CHUNK_RECEIVE_COMPLETE_TIMEOUT_MS = 60_000 * 3 for k, v := range bw.ChunkRecord { if v.CompletedTime > 0 && currentTimestamp-v.CompletedTime > CHUNK_PRUNING_TIMEOUT_MS { // Remove this from map - safe to do in loop delete(bw.ChunkRecord, k) } else if v.ReceivedTime > 0 && v.CompletedTime == 0 && currentTimestamp-v.ReceivedTime > CHUNK_RECEIVE_COMPLETE_TIMEOUT_MS { // Remove messages that were never completed after enough time delete(bw.ChunkRecord, k) } } } func (bw *BirdWhisperer) ReceiveChunk(chunk Packet, completedMessages chan []byte) error { if len(chunk.Buffer) <= CHUNK_HEADER_VERSION_OFFSET { return fmt.Errorf("Packet not long enough to be a chunk!") } switch chunk.Buffer[CHUNK_HEADER_VERSION_OFFSET] { case 1: header, err := ParseChunkHeaderV1(chunk) if err != nil { return err } _, exists := bw.ChunkRecord[header.MessageID] if !exists { bw.ChunkRecord[header.MessageID] = ReconstructedMessage{ make([]Packet, header.NumChunks), time.Now().Unix(), 0, header.NumChunks, 0, } } if chunkRecord, ok := bw.ChunkRecord[header.MessageID]; ok { if chunkRecord.CompletedTime > 0 { // Already received all of the chunks for this message return nil } if chunkRecord.Chunks[header.Index].TargetAddr == nil { // New chunk for existing message chunkRecord.Chunks[header.Index].Buffer = chunk.Buffer[header.Length : header.Length+uint16(header.HeaderLength)] chunkRecord.ReceivedChunks += 1 // Update map bw.ChunkRecord[header.MessageID] = chunkRecord } // If we have all chunks reconstruct the message and pass it along for processing if chunkRecord.ReceivedChunks == chunkRecord.NumChunks { // Have all chunks msg, err := ReconstructMessageFromChunks(chunkRecord) if err != nil { return err } chunkRecord.CompletedTime = time.Now().Unix() // Update in map bw.ChunkRecord[header.MessageID] = chunkRecord slog.Debug("Received all chunks of a message", "firstReceived", chunkRecord.ReceivedTime, "lastReceived", chunkRecord.CompletedTime) // Pass on to other things completedMessages <- msg } } else { return fmt.Errorf("Critical: Chunk record does not exist for message ID") } if int(header.Length+uint16(header.HeaderLength)) < len(chunk.Buffer) { // There is more data in our incoming chunk, perhaps another message appended bw.ReceiveChunk(Packet{TargetAddr: chunk.TargetAddr, Buffer: chunk.Buffer[header.Length+uint16(header.HeaderLength):]}, completedMessages) } default: return fmt.Errorf("Packet has unsupported chunk version") } return nil } func ParseChunkHeaderV1(chunk Packet) (header ChunkHeaderV1, err error) { if len(chunk.Buffer) < CHUNK_V1_HEADER_LENGTH { return header, fmt.Errorf("Malformed chunk") } // We chould check the version beforehand, but it doesn't cost // much to just have it parsed and then check anyway and this // makes the numeric ordering clear header.MessageID = binary.BigEndian.Uint16(chunk.Buffer[0:2]) header.Length = binary.BigEndian.Uint16(chunk.Buffer[2:4]) header.Version = chunk.Buffer[CHUNK_HEADER_VERSION_OFFSET] header.Index = binary.BigEndian.Uint16(chunk.Buffer[5:7]) header.NumChunks = binary.BigEndian.Uint16(chunk.Buffer[7:9]) header.HeaderLength = CHUNK_V1_HEADER_LENGTH if header.Version != 1 { return header, fmt.Errorf("Not a V1 header!") } return } /* * message: Array of *chars* */ func ChunkMessage(message []byte, messageId uint16) [][]byte { messageLen := len(message) numChunks := uint16(messageLen / CHUNK_MAX_LENGTH) leftoverBytes := uint16(messageLen % CHUNK_MAX_LENGTH) if leftoverBytes > 0 { numChunks += 1 } // +1 due to flooring messageChunks := make([][]byte, numChunks) consumedBytes := 0 // Loop through all full-length chunks for i := range numChunks - 1 { // TODO: When remaining bytes < CHUNK_MAX_LENGTH, does it set length to max anyway? Other edges thisHeader := ChunkHeaderV1{ MessageID: messageId, Index: i + 1, NumChunks: numChunks + 1, Length: CHUNK_MAX_LENGTH + CHUNK_V1_HEADER_LENGTH, Version: 1, HeaderLength: CHUNK_V1_HEADER_LENGTH, } // Header + portion of data it represents messageChunks[i] = append(thisHeader.toBytes(), message[(CHUNK_MAX_LENGTH*i):(CHUNK_MAX_LENGTH*(i+1))]...) consumedBytes += CHUNK_MAX_LENGTH } // This handles numChunks == 0 (i.e. single message) or any remaining bytes if leftoverBytes > 0 { thisHeader := ChunkHeaderV1{ MessageID: messageId, Index: numChunks + 1, NumChunks: numChunks + 1, Length: leftoverBytes + CHUNK_V1_HEADER_LENGTH, Version: 1, HeaderLength: CHUNK_V1_HEADER_LENGTH, } messageChunks[len(messageChunks)-1] = append(thisHeader.toBytes(), message[consumedBytes:]...) } return messageChunks }