diff --git a/protocols/bgp/server/fsm.go b/protocols/bgp/server/fsm.go index 991f1b60853631a7b88d7fbb38c0e81ff551cc6e..a5b48ef195df148ad69139c638601e0dae935586 100644 --- a/protocols/bgp/server/fsm.go +++ b/protocols/bgp/server/fsm.go @@ -1,7 +1,9 @@ package server import ( + "context" "fmt" + "io" "net" "sync" "time" @@ -70,6 +72,8 @@ type FSM struct { stateMu sync.RWMutex reason string active bool + + connectionCancelFunc context.CancelFunc } // NewPassiveFSM2 initiates a new passive FSM @@ -105,8 +109,11 @@ func newFSM2(peer *peer) *FSM { } func (fsm *FSM) start() { + ctx, cancel := context.WithCancel(context.Background()) + fsm.connectionCancelFunc = cancel + go fsm.run() - go fsm.tcpConnector() + go fsm.tcpConnector(ctx) return } @@ -115,6 +122,8 @@ func (fsm *FSM) activate() { } func (fsm *FSM) run() { + defer fsm.cancelRunningGoRoutines() + next, reason := fsm.state.run() for { newState := stateName(next) @@ -141,6 +150,12 @@ func (fsm *FSM) run() { } } +func (fsm *FSM) cancelRunningGoRoutines() { + if fsm.connectionCancelFunc != nil { + fsm.connectionCancelFunc() + } +} + func stateName(s state) string { switch s.(type) { case *idleState: @@ -166,7 +181,7 @@ func (fsm *FSM) cease() { fsm.eventCh <- Cease } -func (fsm *FSM) tcpConnector() error { +func (fsm *FSM) tcpConnector(ctx context.Context) error { for { select { case <-fsm.initiateCon: @@ -186,6 +201,8 @@ func (fsm *FSM) tcpConnector() error { case <-time.NewTimer(time.Second * 30).C: c.Close() continue + case <-ctx.Done(): + return nil } } } @@ -271,6 +288,23 @@ func (fsm *FSM) sendKeepalive() error { return nil } +func recvMsg(c net.Conn) (msg []byte, err error) { + buffer := make([]byte, packet.MaxLen) + _, err = io.ReadFull(c, buffer[0:packet.MinLen]) + if err != nil { + return nil, fmt.Errorf("Read failed: %v", err) + } + + l := int(buffer[16])*256 + int(buffer[17]) + toRead := l + _, err = io.ReadFull(c, buffer[packet.MinLen:toRead]) + if err != nil { + return nil, fmt.Errorf("Read failed: %v", err) + } + + return buffer, nil +} + func stopTimer(t *time.Timer) { if !t.Stop() { select { diff --git a/protocols/bgp/server/server.go b/protocols/bgp/server/server.go index db6ae24f310cf027114e5e95aee7fee3211aac8b..a27e522276e3fd61791c7f8926db87b2041d2d3f 100644 --- a/protocols/bgp/server/server.go +++ b/protocols/bgp/server/server.go @@ -2,13 +2,11 @@ package server import ( "fmt" - "io" "net" "strings" "sync" "github.com/bio-routing/bio-rd/config" - "github.com/bio-routing/bio-rd/protocols/bgp/packet" "github.com/bio-routing/bio-rd/routingtable/locRIB" log "github.com/sirupsen/logrus" ) @@ -125,20 +123,3 @@ func (b *bgpServer) AddPeer(c config.Peer, rib *locRIB.LocRIB) error { return nil } - -func recvMsg(c net.Conn) (msg []byte, err error) { - buffer := make([]byte, packet.MaxLen) - _, err = io.ReadFull(c, buffer[0:packet.MinLen]) - if err != nil { - return nil, fmt.Errorf("Read failed: %v", err) - } - - l := int(buffer[16])*256 + int(buffer[17]) - toRead := l - _, err = io.ReadFull(c, buffer[packet.MinLen:toRead]) - if err != nil { - return nil, fmt.Errorf("Read failed: %v", err) - } - - return buffer, nil -}