diff --git a/protocols/bgp/packet/path_attributes.go b/protocols/bgp/packet/path_attributes.go index 1d51e4a48f97e7a16ced7d8dd1e59410ad078742..ee0c5d89329d571ed298e780f4824027c7876d1a 100644 --- a/protocols/bgp/packet/path_attributes.go +++ b/protocols/bgp/packet/path_attributes.go @@ -281,23 +281,23 @@ func (pa *PathAttribute) decodeASPath(buf *bytes.Buffer, asnLength uint8) error func (pa *PathAttribute) decodeASN(buf *bytes.Buffer, asnSize uint8) (asn uint32, err error) { if asnSize == 4 { - return pa.decode4ByteASN(buf) + return decode4ByteASN(buf) } - return pa.decode2ByteASN(buf) + return decode2ByteASN(buf) } -func (pa *PathAttribute) decode4ByteASN(buf *bytes.Buffer) (asn uint32, err error) { +func decode4ByteASN(buf *bytes.Buffer) (asn uint32, err error) { asn4 := uint32(0) err = decode.DecodeUint32(buf, &asn4) if err != nil { return 0, err } - return uint32(asn4), nil + return asn4, nil } -func (pa *PathAttribute) decode2ByteASN(buf *bytes.Buffer) (asn uint32, err error) { +func decode2ByteASN(buf *bytes.Buffer) (asn uint32, err error) { asn2 := uint16(0) err = decode.DecodeUint16(buf, &asn2) if err != nil { diff --git a/protocols/bgp/packet/path_attributes_test.go b/protocols/bgp/packet/path_attributes_test.go index dcbcb3bbb6baa03b5f87eb71e12b1400c48033a6..fef287f8fe9657ca2ef0919e67ce2114c14f2663 100644 --- a/protocols/bgp/packet/path_attributes_test.go +++ b/protocols/bgp/packet/path_attributes_test.go @@ -2336,3 +2336,27 @@ func TestFourBytesToUint32(t *testing.T) { } } } + +func TestDecode4ByteASN(t *testing.T) { + tests := []struct { + name string + input *bytes.Buffer + expected uint32 + }{ + { + name: "Test #1", + input: bytes.NewBuffer([]byte{0b00000000, 0b00000011, 0b00010011, 0b11100101}), + expected: 201701, + }, + } + + for _, test := range tests { + res, err := decode4ByteASN(test.input) + if err != nil { + t.Errorf("error in test %q: %v", test.name, err) + continue + } + + assert.Equal(t, test.expected, res, test.name) + } +} diff --git a/protocols/bgp/server/bmp_router.go b/protocols/bgp/server/bmp_router.go index 1f4b522a97ed346239bfba7dadcadbcbc967b75e..5e7245aae64871ace7864b6a8207f0c3bbfd3a89 100644 --- a/protocols/bgp/server/bmp_router.go +++ b/protocols/bgp/server/bmp_router.go @@ -158,7 +158,9 @@ func (r *Router) processRouteMonitoringMsg(msg *bmppkt.RouteMonitoringMsg) { } s := n.fsm.state.(*establishedState) - s.msgReceived(msg.BGPUpdate, s.fsm.decodeOptions()) + opt := s.fsm.decodeOptions() + opt.Use32BitASN = !msg.PerPeerHeader.GetAFlag() + s.msgReceived(msg.BGPUpdate, opt) } func (r *Router) processInitiationMsg(msg *bmppkt.InitiationMessage) { @@ -385,9 +387,6 @@ func (p *peer) configureBySentOpen(msg *packet.BGPOpen) { MaxPaths: 10, } } - case packet.ASN4CapabilityCode: - asn4Cap := cap.Value.(packet.ASN4Capability) - p.localASN = asn4Cap.ASN4 } } } diff --git a/protocols/bgp/server/bmp_server_test.go b/protocols/bgp/server/bmp_server_test.go index 7f8c605eb2f7d3e54974d2ac3181b00dc44805e3..fef31b5c9822d4e5fdad754283adfb9da0d87ee7 100644 --- a/protocols/bgp/server/bmp_server_test.go +++ b/protocols/bgp/server/bmp_server_test.go @@ -224,7 +224,7 @@ func TestBMPServer(t *testing.T) { 0, // Msg Type (route monitoring) 0, // Peer Type (global instance peer) - 0, // Peer Flags + 0b00100000, // Peer Flags 0, 0, 0, 0, 0, 0, 0, 123, // Peer Distinguisher 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 1, 1, 1, // Peer Address (10.1.1.1) 0, 0, 0, 200, // Peer AS = 200 @@ -315,7 +315,7 @@ func TestBMPServer(t *testing.T) { 0, // Msg Type (route monitoring) 0, // Peer Type (global instance peer) - 0, // Peer Flags + 0b00100000, // Peer Flags 0, 0, 0, 0, 0, 0, 0, 123, // Peer Distinguisher 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 1, 2, 1, // Peer Address (10.1.2.1) 0, 0, 0, 222, // Peer AS = 222 diff --git a/protocols/bmp/packet/per_peer_header.go b/protocols/bmp/packet/per_peer_header.go index 7723d826f7a81b17f716d31a1b3bbdc3889bf3d1..6fdfedbe3e15abbca30dd9f044d13b489febea62 100644 --- a/protocols/bmp/packet/per_peer_header.go +++ b/protocols/bmp/packet/per_peer_header.go @@ -60,8 +60,14 @@ func decodePerPeerHeader(buf *bytes.Buffer) (*PerPeerHeader, error) { // GetIPVersion gets the IP version of the BGP session func (p *PerPeerHeader) GetIPVersion() uint8 { - if p.PeerFlags>>7 == 1 { + if p.PeerFlags&0b10000000 == 0b10000000 { return 6 } + return 4 } + +// GetAFlag checks if the A flag is set +func (p *PerPeerHeader) GetAFlag() bool { + return p.PeerFlags&0b00100000 == 0b00100000 +} diff --git a/protocols/bmp/packet/per_peer_header_test.go b/protocols/bmp/packet/per_peer_header_test.go index 8c51ee6c9b286faf927893d20ab093fe792ce0ce..17f48724cb3362b5929d70f9d1bb6a0baec2c8b2 100644 --- a/protocols/bmp/packet/per_peer_header_test.go +++ b/protocols/bmp/packet/per_peer_header_test.go @@ -125,28 +125,28 @@ func TestGetIPVersion(t *testing.T) { { name: "IPv4", p: &PerPeerHeader{ - PeerFlags: 0, + PeerFlags: 0b00000000, }, expected: 4, }, { name: "IPv4 #2", p: &PerPeerHeader{ - PeerFlags: 127, + PeerFlags: 0b01000000, }, expected: 4, }, { name: "IPv6", p: &PerPeerHeader{ - PeerFlags: 128, + PeerFlags: 0b10000000, }, expected: 6, }, { name: "IPv6 #2", p: &PerPeerHeader{ - PeerFlags: 129, + PeerFlags: 0b11000000, }, expected: 6, }, @@ -154,6 +154,33 @@ func TestGetIPVersion(t *testing.T) { for _, test := range tests { v := test.p.GetIPVersion() - assert.Equal(t, test.expected, v) + assert.Equal(t, test.expected, v, test.name) + } +} + +func TestGetAFlag(t *testing.T) { + tests := []struct { + name string + input *PerPeerHeader + expected bool + }{ + { + name: "Test #1", + input: &PerPeerHeader{ + PeerFlags: 0b11011111, + }, + expected: false, + }, + { + name: "Test #2", + input: &PerPeerHeader{ + PeerFlags: 0b00100000, + }, + expected: true, + }, + } + + for _, test := range tests { + assert.Equal(t, test.expected, test.input.GetAFlag()) } } diff --git a/util/decode/decode.go b/util/decode/decode.go index d32e2933272b85d7291ebeb2a8b9e9afd0d33220..5283ba79cce7302c99e02ba92f67905f02cbdffc 100644 --- a/util/decode/decode.go +++ b/util/decode/decode.go @@ -19,6 +19,7 @@ func Decode(buf *bytes.Buffer, fields []interface{}) error { return nil } +// DecodeUint8 decodes an uint8 func DecodeUint8(buf *bytes.Buffer, x *uint8) error { y, err := buf.ReadByte() if err != nil { @@ -29,6 +30,7 @@ func DecodeUint8(buf *bytes.Buffer, x *uint8) error { return nil } +// DecodeUint16 decodes an uint16 func DecodeUint16(buf *bytes.Buffer, x *uint16) error { a, err := buf.ReadByte() if err != nil { @@ -40,10 +42,11 @@ func DecodeUint16(buf *bytes.Buffer, x *uint16) error { return err } - *x = uint16(a)*256 + uint16(b) + *x = uint16(a)<<8 + uint16(b) return nil } +// DecodeUint32 decodes an uint32 func DecodeUint32(buf *bytes.Buffer, x *uint32) error { a, err := buf.ReadByte() if err != nil { @@ -65,6 +68,6 @@ func DecodeUint32(buf *bytes.Buffer, x *uint32) error { return err } - *x = uint32(a)*256*256*256 + uint32(b)*256*256*256 + uint32(c)*256 + uint32(d) + *x = uint32(a)<<24 + uint32(b)<<16 + uint32(c)<<8 + uint32(d) return nil }