stun: check high bits in Is, add tests
Also use new stun.TxID type in stunner. Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>pull/106/head
parent
2489ea4268
commit
14abc82033
|
@ -218,9 +218,7 @@ func mappedAddress(b []byte) (addr []byte, port uint16, err error) {
|
|||
|
||||
// Is reports whether b is a STUN message.
|
||||
func Is(b []byte) bool {
|
||||
if len(b) < headerLen {
|
||||
return false // every STUN message must have a 20-byte header
|
||||
}
|
||||
// TODO RFC5389 suggests checking the first 2 bits of the header are zero.
|
||||
return string(b[4:8]) == magicCookie
|
||||
return len(b) >= headerLen &&
|
||||
b[0]&0b11000000 == 0 && // top two bits must be zero
|
||||
string(b[4:8]) == magicCookie
|
||||
}
|
||||
|
|
|
@ -166,3 +166,29 @@ func TestParseResponse(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIs(t *testing.T) {
|
||||
const magicCookie = "\x21\x12\xa4\x42"
|
||||
tests := []struct {
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{"", false},
|
||||
{"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", false},
|
||||
{"\x00\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", false},
|
||||
{"\x00\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", true},
|
||||
{"\x00\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00foo", true},
|
||||
// high bits set:
|
||||
{"\xf0\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", false},
|
||||
{"\x40\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", false},
|
||||
// first byte non-zero, but not high bits:
|
||||
{"\x20\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", true},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
pkt := []byte(tt.in)
|
||||
got := stun.Is(pkt)
|
||||
if got != tt.want {
|
||||
t.Errorf("%d. In(%q (%v)) = %v; want %v", i, pkt, pkt, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ type Stunner struct {
|
|||
|
||||
type session struct {
|
||||
replied chan struct{} // closed when server responds
|
||||
tIDs [][12]byte // transaction IDs sent to a server
|
||||
tIDs []stun.TxID // transaction IDs sent to a server
|
||||
}
|
||||
|
||||
// Receive delivers a STUN packet to the stunner.
|
||||
|
@ -90,7 +90,7 @@ func (s *Stunner) Run(ctx context.Context) error {
|
|||
}
|
||||
for _, server := range s.Servers {
|
||||
// Generate the transaction IDs for this session.
|
||||
tIDs := make([][12]byte, len(retryDurations))
|
||||
tIDs := make([]stun.TxID, len(retryDurations))
|
||||
for i := range tIDs {
|
||||
if _, err := rand.Read(tIDs[i][:]); err != nil {
|
||||
return fmt.Errorf("stunner: rand failed: %v", err)
|
||||
|
@ -147,7 +147,7 @@ func (s *Stunner) runServer(ctx context.Context, server string) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Stunner) sendSTUN(ctx context.Context, tID [12]byte, server string) error {
|
||||
func (s *Stunner) sendSTUN(ctx context.Context, tID stun.TxID, server string) error {
|
||||
host, port, err := net.SplitHostPort(server)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
Loading…
Reference in New Issue