diff --git a/protocols/bgp/server/update_helper_test.go b/protocols/bgp/server/update_helper_test.go index 81792015b17fcdec0597b686084b2979221ef0b7..5752edded4aba03b66acd88950796e21c2c82620 100644 --- a/protocols/bgp/server/update_helper_test.go +++ b/protocols/bgp/server/update_helper_test.go @@ -1,25 +1,52 @@ package server import ( + "io" "testing" "bytes" "github.com/bio-routing/bio-rd/protocols/bgp/packet" + "errors" + "github.com/bio-routing/bio-rd/net" "github.com/stretchr/testify/assert" ) +type failingUpdate struct{} + +func (f *failingUpdate) SerializeUpdate() ([]byte, error) { + return nil, errors.New("general error") +} + +type WriterByter interface { + Bytes() []byte + io.Writer +} + +type failingReadWriter struct { +} + +func (f *failingReadWriter) Write(p []byte) (n int, err error) { + return 0, errors.New("general error") +} + +func (f *failingReadWriter) Bytes() []byte { + return []byte{} +} + func TestSerializeAndSendUpdate(t *testing.T) { tests := []struct { name string + buf WriterByter err error testUpdate serializeAbleUpdate expected []byte }{ { name: "normal bgp update", + buf: bytes.NewBuffer(nil), err: nil, testUpdate: &packet.BGPUpdate{ WithdrawnRoutesLen: 5, @@ -40,14 +67,37 @@ func TestSerializeAndSendUpdate(t *testing.T) { 0, 5, 8, 10, 16, 192, 168, 0, 0, // 2 withdraws }, }, + { + name: "failed serialization", + buf: bytes.NewBuffer(nil), + err: nil, + testUpdate: &failingUpdate{}, + expected: nil, + }, + { + name: "failed connection", + buf: &failingReadWriter{}, + err: errors.New("Failed sending Update: general error"), + testUpdate: &packet.BGPUpdate{ + WithdrawnRoutesLen: 5, + WithdrawnRoutes: &packet.NLRI{ + IP: strAddr("10.0.0.0"), + Pfxlen: 8, + Next: &packet.NLRI{ + IP: strAddr("192.168.0.0"), + Pfxlen: 16, + }, + }, + }, + expected: []byte{}, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - buf := bytes.NewBuffer(nil) - err := serializeAndSendUpdate(buf, test.testUpdate) + err := serializeAndSendUpdate(test.buf, test.testUpdate) assert.Equal(t, test.err, err) - assert.Equal(t, test.expected, buf.Bytes()) + assert.Equal(t, test.expected, test.buf.Bytes()) }) }