diff --git a/protocols/bgp/server/fsm.go b/protocols/bgp/server/fsm.go index aa10948ea4e6af7437960503a2b18ce85bb32301..1e614ab1a45315f7ba13df63a18531c5937a92e8 100644 --- a/protocols/bgp/server/fsm.go +++ b/protocols/bgp/server/fsm.go @@ -1,7 +1,9 @@ package server import ( + "context" "fmt" + "io" "net" "sync" "time" @@ -66,6 +68,8 @@ type FSM struct { stateMu sync.RWMutex reason string active bool + + connectionCancelFunc context.CancelFunc } // NewPassiveFSM2 initiates a new passive FSM @@ -110,8 +114,11 @@ func newFSM2(peer *peer) *FSM { } func (fsm *FSM) start() { + ctx, cancel := context.WithCancel(context.Background()) + fsm.connectionCancelFunc = cancel + go fsm.run() - go fsm.tcpConnector() + go fsm.tcpConnector(ctx) return } @@ -120,6 +127,8 @@ func (fsm *FSM) activate() { } func (fsm *FSM) run() { + defer fsm.cancelRunningGoRoutines() + next, reason := fsm.state.run() for { newState := stateName(next) @@ -146,6 +155,12 @@ func (fsm *FSM) run() { } } +func (fsm *FSM) cancelRunningGoRoutines() { + if fsm.connectionCancelFunc != nil { + fsm.connectionCancelFunc() + } +} + func stateName(s state) string { switch s.(type) { case *idleState: @@ -171,7 +186,7 @@ func (fsm *FSM) cease() { fsm.eventCh <- Cease } -func (fsm *FSM) tcpConnector() error { +func (fsm *FSM) tcpConnector(ctx context.Context) { for { select { case <-fsm.initiateCon: @@ -191,6 +206,8 @@ func (fsm *FSM) tcpConnector() error { case <-time.NewTimer(time.Second * 30).C: c.Close() continue + case <-ctx.Done(): + return } } } @@ -276,6 +293,23 @@ func (fsm *FSM) sendKeepalive() error { return nil } +func recvMsg(c net.Conn) (msg []byte, err error) { + buffer := make([]byte, packet.MaxLen) + _, err = io.ReadFull(c, buffer[0:packet.MinLen]) + if err != nil { + return nil, fmt.Errorf("Read failed: %v", err) + } + + l := int(buffer[16])*256 + int(buffer[17]) + toRead := l + _, err = io.ReadFull(c, buffer[packet.MinLen:toRead]) + if err != nil { + return nil, fmt.Errorf("Read failed: %v", err) + } + + return buffer, nil +} + func stopTimer(t *time.Timer) { if !t.Stop() { select { diff --git a/protocols/bgp/server/server.go b/protocols/bgp/server/server.go index 1bccb7b64e617223a8445346630f016e8aa764ab..170ca0a2e327ea0c49045c88d1070d4e5f6ca005 100644 --- a/protocols/bgp/server/server.go +++ b/protocols/bgp/server/server.go @@ -2,13 +2,11 @@ package server import ( "fmt" - "io" "net" "strings" "sync" "github.com/bio-routing/bio-rd/config" - "github.com/bio-routing/bio-rd/protocols/bgp/packet" log "github.com/sirupsen/logrus" ) @@ -124,20 +122,3 @@ func (b *bgpServer) AddPeer(c config.Peer) error { return nil } - -func recvMsg(c net.Conn) (msg []byte, err error) { - buffer := make([]byte, packet.MaxLen) - _, err = io.ReadFull(c, buffer[0:packet.MinLen]) - if err != nil { - return nil, fmt.Errorf("Read failed: %v", err) - } - - l := int(buffer[16])*256 + int(buffer[17]) - toRead := l - _, err = io.ReadFull(c, buffer[packet.MinLen:toRead]) - if err != nil { - return nil, fmt.Errorf("Read failed: %v", err) - } - - return buffer, nil -}