diff --git a/net/packet/packet.go b/net/packet/packet.go index 8dee86ea7..0d2dfa260 100644 --- a/net/packet/packet.go +++ b/net/packet/packet.go @@ -556,14 +556,18 @@ func updateV4PacketChecksums(p *Parsed, old, new netip.Addr) { // TODO(maisem): more protocols (sctp, gre, dccp) } -// updateV4Checksum calculates and updates the checksum in the packet buffer -// for a change between old and new. The checksum is updated in place. +// updateV4Checksum calculates and updates the checksum in the packet buffer for +// a change between old and new. The oldSum must point to the 16-bit checksum +// field in the packet buffer that holds the old checksum value, it will be +// updated in place. +// +// The old and new must be the same length, and must be an even number of bytes. func updateV4Checksum(oldSum, old, new []byte) { if len(old) != len(new) { panic("old and new must be the same length") } if len(old)%2 != 0 { - panic("old and new must be even length") + panic("old and new must be of even length") } /* RFC 1624 diff --git a/net/packet/packet_test.go b/net/packet/packet_test.go index 9d6254f09..acf878296 100644 --- a/net/packet/packet_test.go +++ b/net/packet/packet_test.go @@ -5,6 +5,7 @@ package packet import ( "bytes" + "encoding/binary" "encoding/hex" "net/netip" "reflect" @@ -29,6 +30,72 @@ const ( Fragment = ipproto.Fragment ) +func fullHeaderChecksumV4(b []byte) uint16 { + s := uint32(0) + for i := 0; i < len(b); i += 2 { + if i == 10 { + // Skip checksum field. + continue + } + s += uint32(binary.BigEndian.Uint16(b[i : i+2])) + } + for s>>16 > 0 { + s = s&0xFFFF + s>>16 + } + return ^uint16(s) +} + +func TestHeaderChecksums(t *testing.T) { + // This is not a good enough test, because it doesn't + // check the various packet types or the many edge cases + // of the checksum algorithm. But it's a start. + + tests := []struct { + name string + packet []byte + }{ + { + name: "ICMPv4", + packet: []byte{ + 0x45, 0x00, 0x00, 0x54, 0xb7, 0x96, 0x40, 0x00, 0x40, 0x01, 0x7a, 0x06, 0x64, 0x7f, 0x3f, 0x4c, 0x64, 0x40, 0x01, 0x01, 0x08, 0x00, 0x47, 0x1a, 0x00, 0x11, 0x01, 0xac, 0xcc, 0xf5, 0x95, 0x63, 0x00, 0x00, 0x00, 0x00, 0x8d, 0xfc, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + }, + }, + { + name: "TLS", + packet: []byte{ + 0x45, 0x00, 0x00, 0x3c, 0x54, 0x29, 0x40, 0x00, 0x40, 0x06, 0xb1, 0xac, 0x64, 0x42, 0xd4, 0x33, 0x64, 0x61, 0x98, 0x0f, 0xb1, 0x94, 0x01, 0xbb, 0x0a, 0x51, 0xce, 0x7c, 0x00, 0x00, 0x00, 0x00, 0xa0, 0x02, 0xfb, 0xe0, 0x38, 0xf6, 0x00, 0x00, 0x02, 0x04, 0x04, 0xd8, 0x04, 0x02, 0x08, 0x0a, 0x86, 0x2b, 0xcc, 0xd5, 0x00, 0x00, 0x00, 0x00, 0x01, 0x03, 0x03, 0x07, + }, + }, + { + name: "DNS", + packet: []byte{ + 0x45, 0x00, 0x00, 0x74, 0xe2, 0x85, 0x00, 0x00, 0x40, 0x11, 0x96, 0xb5, 0x64, 0x64, 0x64, 0x64, 0x64, 0x42, 0xd4, 0x33, 0x00, 0x35, 0xec, 0x55, 0x00, 0x60, 0xd9, 0x19, 0xed, 0xfd, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x73, 0x34, 0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x01, 0x1e, 0x00, 0x0c, 0x07, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x73, 0x01, 0x6c, 0xc0, 0x15, 0xc0, 0x31, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 0x1e, 0x00, 0x04, 0x8e, 0xfa, 0xbd, 0xce, 0x00, 0x00, 0x29, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + }, + } + var p Parsed + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p.Decode(tt.packet) + t.Log(p.String()) + p.UpdateSrcAddr(netip.MustParseAddr("100.64.0.1")) + + got := binary.BigEndian.Uint16(tt.packet[10:12]) + want := fullHeaderChecksumV4(tt.packet[:20]) + if got != want { + t.Fatalf("got %x want %x", got, want) + } + + p.UpdateDstAddr(netip.MustParseAddr("100.64.0.2")) + got = binary.BigEndian.Uint16(tt.packet[10:12]) + want = fullHeaderChecksumV4(tt.packet[:20]) + if got != want { + t.Fatalf("got %x want %x", got, want) + } + }) + } +} + func mustIPPort(s string) netip.AddrPort { ipp, err := netip.ParseAddrPort(s) if err != nil {