diff --git a/protocols/bgp/packet/path_attributes.go b/protocols/bgp/packet/path_attributes.go index 52fad7ed2129c8527b73eafb7d2842607c8d4420..3b3192034590bd357b29a13833a2bcaaa632f0cf 100644 --- a/protocols/bgp/packet/path_attributes.go +++ b/protocols/bgp/packet/path_attributes.go @@ -180,41 +180,15 @@ func (pa *PathAttribute) decodeASPath(buf *bytes.Buffer) error { } func (pa *PathAttribute) decodeNextHop(buf *bytes.Buffer) error { - addr := [4]byte{} - - p := uint16(0) - n, err := buf.Read(addr[:]) - if err != nil { - return err - } - if n != 4 { - return fmt.Errorf("Unable to read next hop: buf.Read read %d bytes", n) - } - - pa.Value = fourBytesToUint32(addr) - p += 4 - - return dumpNBytes(buf, pa.Length-p) + return pa.decodeUint32(buf, "next hop") } func (pa *PathAttribute) decodeMED(buf *bytes.Buffer) error { - med, err := pa.decodeUint32(buf) - if err != nil { - return fmt.Errorf("Unable to decode MED: %v", err) - } - - pa.Value = uint32(med) - return nil + return pa.decodeUint32(buf, "MED") } func (pa *PathAttribute) decodeLocalPref(buf *bytes.Buffer) error { - lpref, err := pa.decodeUint32(buf) - if err != nil { - return fmt.Errorf("Unable to decode local pref: %v", err) - } - - pa.Value = uint32(lpref) - return nil + return pa.decodeUint32(buf, "local pref") } func (pa *PathAttribute) decodeAggregator(buf *bytes.Buffer) error { @@ -252,7 +226,7 @@ func (pa *PathAttribute) decodeCommunities(buf *bytes.Buffer) error { coms := make([]uint32, count) for i := uint16(0); i < count; i++ { - v, err := read4BytesAsUin32(buf) + v, err := read4BytesAsUint32(buf) if err != nil { return err } @@ -274,19 +248,19 @@ func (pa *PathAttribute) decodeLargeCommunities(buf *bytes.Buffer) error { for i := uint16(0); i < count; i++ { com := LargeCommunity{} - v, err := read4BytesAsUin32(buf) + v, err := read4BytesAsUint32(buf) if err != nil { return err } com.GlobalAdministrator = v - v, err = read4BytesAsUin32(buf) + v, err = read4BytesAsUint32(buf) if err != nil { return err } com.DataPart1 = v - v, err = read4BytesAsUin32(buf) + v, err = read4BytesAsUint32(buf) if err != nil { return err } @@ -300,22 +274,27 @@ func (pa *PathAttribute) decodeLargeCommunities(buf *bytes.Buffer) error { } func (pa *PathAttribute) decodeAS4Path(buf *bytes.Buffer) error { - as4Path, err := pa.decodeUint32(buf) + return pa.decodeUint32(buf, "AS4Path") +} + +func (pa *PathAttribute) decodeAS4Aggregator(buf *bytes.Buffer) error { + return pa.decodeUint32(buf, "AS4Aggregator") +} + +func (pa *PathAttribute) decodeUint32(buf *bytes.Buffer, attrName string) error { + v, err := read4BytesAsUint32(buf) if err != nil { - return fmt.Errorf("Unable to decode AS4Path: %v", err) + return fmt.Errorf("Unable to decode %s: %v", attrName, err) } - pa.Value = as4Path - return nil -} + pa.Value = v -func (pa *PathAttribute) decodeAS4Aggregator(buf *bytes.Buffer) error { - as4Aggregator, err := pa.decodeUint32(buf) + p := uint16(4) + err = dumpNBytes(buf, pa.Length-p) if err != nil { - return fmt.Errorf("Unable to decode AS4Aggregator: %v", err) + return fmt.Errorf("dumpNBytes failed: %v", err) } - pa.Value = as4Aggregator return nil } @@ -339,24 +318,6 @@ func (pa *PathAttribute) setLength(buf *bytes.Buffer) (int, error) { return bytesRead, nil } -func (pa *PathAttribute) decodeUint32(buf *bytes.Buffer) (uint32, error) { - var v uint32 - - p := uint16(0) - err := decode(buf, []interface{}{&v}) - if err != nil { - return 0, err - } - - p += 4 - err = dumpNBytes(buf, pa.Length-p) - if err != nil { - return 0, fmt.Errorf("dumpNBytes failed: %v", err) - } - - return v, nil -} - func (pa *PathAttribute) ASPathString() (ret string) { for _, p := range pa.Value.(ASPath) { if p.Type == ASSet { @@ -722,14 +683,14 @@ func fourBytesToUint32(address [4]byte) uint32 { return uint32(address[0])<<24 + uint32(address[1])<<16 + uint32(address[2])<<8 + uint32(address[3]) } -func read4BytesAsUin32(buf *bytes.Buffer) (uint32, error) { +func read4BytesAsUint32(buf *bytes.Buffer) (uint32, error) { b := [4]byte{} n, err := buf.Read(b[:]) if err != nil { return 0, err } if n != 4 { - return 0, fmt.Errorf("Unable to read as uint32: buf.Read read %d bytes", n) + return 0, fmt.Errorf("Unable to read as uint32. Expected 4 bytes but got only %d", n) } return fourBytesToUint32(b), nil diff --git a/protocols/bgp/packet/path_attributes_test.go b/protocols/bgp/packet/path_attributes_test.go index 75915d34c81f3895b92dfcdfdc43c91af094f971..c0da61a3e1c38ee310f1eb738ff75bd095a5e3a2 100644 --- a/protocols/bgp/packet/path_attributes_test.go +++ b/protocols/bgp/packet/path_attributes_test.go @@ -783,7 +783,7 @@ func TestSetLength(t *testing.T) { } } -func TestDecodeUint32(t *testing.T) { +func TestRead4BytesAsUint32(t *testing.T) { tests := []struct { name string input []byte @@ -822,7 +822,7 @@ func TestDecodeUint32(t *testing.T) { pa := &PathAttribute{ Length: l, } - res, err := pa.decodeUint32(bytes.NewBuffer(test.input)) + err := pa.decodeUint32(bytes.NewBuffer(test.input), "test") if test.wantFail { if err != nil { @@ -837,7 +837,7 @@ func TestDecodeUint32(t *testing.T) { continue } - assert.Equal(t, test.expected, res) + assert.Equal(t, test.expected, pa.Value) } }