diff --git a/common/build.go b/common/build.go index afebf09195982f25d736fb33d6591413897216b5..2c7c0432265c56821b0f6e53e823bc9f29307fce 100644 --- a/common/build.go +++ b/common/build.go @@ -352,10 +352,16 @@ func (b *Build) waitForTerminal(ctx context.Context, timeout time.Duration) { return } + expiryTime, _ := ctx.Deadline() + + if expiryTime.Before(time.Now().Add(timeout)) { + timeout = expiryTime.Sub(time.Now()) + } + b.logger.Infoln( fmt.Sprintf( "Terminal is connected, will time out in %s...", - timeout, + helpers.RoundDuration(timeout, time.Second), ), ) @@ -369,7 +375,7 @@ func (b *Build) waitForTerminal(ctx context.Context, timeout time.Duration) { case <-time.After(timeout): err := fmt.Errorf( "Terminal session timed out (maximum time allowed - %s)", - timeout, + helpers.RoundDuration(timeout, time.Second), ) b.logger.Infoln(err.Error()) b.Log().WithError(err).Debugln("Connection closed") diff --git a/common/build_test.go b/common/build_test.go index 24b8760b348bad9795f2778fb97a122cc9da36b8..d0157692012f85ca77784f9fcd1b69791ffc1edc 100644 --- a/common/build_test.go +++ b/common/build_test.go @@ -1,15 +1,26 @@ package common import ( + "bytes" + "context" "errors" "fmt" + "net/http" + "net/http/httptest" + "net/url" "os" + "strings" "testing" + "time" + "github.com/gorilla/websocket" "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-runner/session" + "gitlab.com/gitlab-org/gitlab-runner/session/terminal" ) func init() { @@ -661,3 +672,167 @@ func TestIsFeatureFlagOn(t *testing.T) { } } + +func TestWaitForTerminal_CancelBuild(t *testing.T) { + build, out, hook, cancelFn, mockConn := bootstrapWaitForTerminal(t, 1*time.Hour, 1800*time.Second) + + mockConn.On("Close").Return(nil).Once() + defer mockConn.AssertExpectations(t) + + cancelFn() + + buildSessionDisconnect := func() bool { + return !build.Session.Connected() + } + + waitFor(5*time.Second, buildSessionDisconnect) + + assert.Contains(t, out.String(), "Terminal is connected, will time out in 30m0s") + assert.Contains(t, hook.LastEntry().Message, "Build cancelled, killing session") +} + +func TestWaitForTerminal_BuildTimeout(t *testing.T) { + build, out, _, _, _ := bootstrapWaitForTerminal(t, 1*time.Hour, 5*time.Second) + + <-build.Session.TimeoutCh + + assert.Contains(t, out.String(), "Terminal is connected, will time out") + assert.Contains(t, out.String(), "Terminal session timed out") +} + +func TestWaitForTerminal_UserDisconnect(t *testing.T) { + build, out, _, _, _ := bootstrapWaitForTerminal(t, 1*time.Hour, 1800*time.Second) + + build.Session.DisconnectCh <- errors.New("user disconnect") + + assert.Contains(t, out.String(), "Terminal is connected, will time out") + + terminalDisconnectedMessage := func() bool { + return strings.Contains(out.String(), "Terminal disconnected") + } + + waitFor(5*time.Second, terminalDisconnectedMessage) + + assert.Contains(t, out.String(), "Terminal disconnected") +} + +func TestWaitForTerminal_SystemInterrupt(t *testing.T) { + build, out, _, _, mockConn := bootstrapWaitForTerminal(t, 1*time.Hour, 1800*time.Second) + + mockConn.On("Close").Return(nil).Once() + defer mockConn.AssertExpectations(t) + + build.SystemInterrupt <- os.Interrupt + + assert.Contains(t, out.String(), "Terminal is connected, will time out") + terminalDisconnectedMessage := func() bool { + return strings.Contains(out.String(), "Terminal disconnected") + } + + waitFor(5*time.Second, terminalDisconnectedMessage) + assert.Contains(t, out.String(), "Terminal disconnected") +} + +func TestWaitForTerminal_ContextTimeoutShorterThenTerminalTimeout(t *testing.T) { + build, out, _, _, _ := bootstrapWaitForTerminal(t, 5*time.Second, 1800*time.Second) + + <-build.Session.TimeoutCh + + assert.NotContains(t, out.String(), "Terminal is connected, will time out in 30m") + assert.Contains(t, out.String(), "Terminal is connected, will time out") + assert.Contains(t, out.String(), "Terminal session timed out (maximum time allowed - 5s") +} + +func bootstrapWaitForTerminal(t *testing.T, buildTimeout, sessionTimeout time.Duration) (*Build, *bytes.Buffer, *test.Hook, context.CancelFunc, *terminal.MockConn) { + hook := test.NewGlobal() + e := MockExecutor{} + defer e.AssertExpectations(t) + + p := MockExecutorProvider{} + defer p.AssertExpectations(t) + + p.On("GetDefaultShell").Return("bash").Once() + p.On("CanCreate").Return(true).Once() + p.On("GetFeatures", mock.Anything).Return(nil).Once() + + mockConn := terminal.MockConn{} + defer mockConn.AssertExpectations(t) + mockConn.On("Start", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + upgrader := &websocket.Upgrader{} + r := args[1].(*http.Request) + w := args[0].(http.ResponseWriter) + + _, _ = upgrader.Upgrade(w, r, nil) + + time.Sleep(10 * time.Second) + }).Once() + + mockTerminal := terminal.MockInteractiveTerminal{} + defer mockTerminal.AssertExpectations(t) + mockTerminal.On("Connect").Return(&mockConn, nil).Once() + + // Register unique executor per test, to prevent collisions. + RegisterExecutor(fmt.Sprintf("shell-%s", time.Now().String()), &p) + + build := Build{ + Runner: &RunnerConfig{ + RunnerSettings: RunnerSettings{ + Executor: "shell", + }, + }, + SystemInterrupt: make(chan os.Signal), + } + + var buildOut bytes.Buffer + trace := Trace{Writer: &buildOut} + build.logger = NewBuildLogger(&trace, build.Log()) + sess, err := session.NewSession(nil) + require.NoError(t, err) + build.Session = sess + + build.Session.SetInteractiveTerminal(&mockTerminal) + + srv := httptest.NewServer(build.Session.Mux()) + defer srv.Close() + + u := url.URL{ + Scheme: "ws", + Host: srv.Listener.Addr().String(), + Path: build.Session.Endpoint + "/exec", + } + headers := http.Header{ + "Authorization": []string{build.Session.Token}, + } + + connectToWebsocket := func() bool { + _, _, err = websocket.DefaultDialer.Dial(u.String(), headers) + return err != nil + } + + waitFor(5*time.Second, connectToWebsocket) + + ctx, cancelFn := context.WithTimeout(context.Background(), buildTimeout) + + go build.waitForTerminal(ctx, sessionTimeout) + + terminalToTimeout := func() bool { + return strings.Contains(buildOut.String(), "Terminal is connected,") + } + + waitFor(5*time.Second, terminalToTimeout) + + return &build, &buildOut, hook, cancelFn, &mockConn +} + +// waitFor takes a timeout and a func that returns a bool, it will keep running +// the func until it returns true, or a timeout is elapsed. +func waitFor(timeout time.Duration, fn func() bool) { + started := time.Now() + for time.Since(started) < timeout { + if !fn() { + time.Sleep(50 * time.Millisecond) + } + + return + } +} diff --git a/helpers/time.go b/helpers/time.go new file mode 100644 index 0000000000000000000000000000000000000000..e8aba42e55fa284137990a0deb8805c1c303854c --- /dev/null +++ b/helpers/time.go @@ -0,0 +1,43 @@ +package helpers + +import "time" + +const ( + minDuration time.Duration = -1 << 63 + maxDuration time.Duration = 1<<63 - 1 +) + +// RoundDuration does exactly the same as time.Round in go.1.9+ since we are +// still on go1.8 we do not have this available. You can check the actual +// implementation in +// https://github.com/golang/go/blob/dev.boringcrypto.go1.9/src/time/time.go#L819-L841 +// and the it can be found in go1.9 change log https://golang.org/doc/go1.9 +func RoundDuration(d time.Duration, m time.Duration) time.Duration { + if m <= 0 { + return d + } + r := d % m + if d < 0 { + r = -r + if lessThanHalf(r, m) { + return d + r + } + if d1 := d - m + r; d1 < d { + return d1 + } + return minDuration // overflow + } + if lessThanHalf(r, m) { + return d - r + } + if d1 := d + m - r; d1 > d { + return d1 + } + return maxDuration // overflow +} + +// lessThanHalf reports whether x+x < y but avoids overflow, +// assuming x and y are both positive (Duration is signed). +func lessThanHalf(x, y time.Duration) bool { + return uint64(x)+uint64(x) < uint64(y) +}