From 21b9a65c7d96aba159d9014a5ae2afd775c8f623 Mon Sep 17 00:00:00 2001
From: Manuel Kieweg <manuel.kieweg@h-da.de>
Date: Mon, 26 Apr 2021 18:12:38 +0200
Subject: [PATCH] basic tests for change

---
 build/ci/.test.yml                     |   4 +-
 nucleus/controller_test.go             |   3 +
 nucleus/http.go                        |  14 +-
 nucleus/http_test.go                   |  33 ++++-
 nucleus/initialise_test.go             |  19 ---
 nucleus/pnd/change.go                  |  83 ++++++++----
 nucleus/pnd/change_test.go             | 176 +++++++++++++++++++++++++
 nucleus/pnd/initialise_test.go         |  22 ++++
 nucleus/principalNetworkDomain_test.go |   1 +
 9 files changed, 303 insertions(+), 52 deletions(-)
 create mode 100644 nucleus/pnd/change_test.go
 create mode 100644 nucleus/pnd/initialise_test.go

diff --git a/build/ci/.test.yml b/build/ci/.test.yml
index 225356ea6..3379d59ff 100644
--- a/build/ci/.test.yml
+++ b/build/ci/.test.yml
@@ -41,8 +41,8 @@ unit-test:
     - go test -short -race $(go list ./... | grep -v /forks/ | grep -v /api/ | grep -v /mocks ) -v -coverprofile=coverage.out
   <<: *test
 
-http-api-test:
+controller-test:
   script:
     - cd ./nucleus
-    - go test -race -v -run Test_httpApi -coverprofile=coverage.out
+    - go test -race -v -run TestRun -coverprofile=../coverage.out
   <<: *test
\ No newline at end of file
diff --git a/nucleus/controller_test.go b/nucleus/controller_test.go
index 47fa1026c..1dc5b8bd0 100644
--- a/nucleus/controller_test.go
+++ b/nucleus/controller_test.go
@@ -9,6 +9,9 @@ import (
 )
 
 func TestRun(t *testing.T) {
+	if testing.Short() {
+		t.Skip("this test is executed separately")
+	}
 	type args struct {
 		request string
 	}
diff --git a/nucleus/http.go b/nucleus/http.go
index 621d0d30c..fe644c7a1 100644
--- a/nucleus/http.go
+++ b/nucleus/http.go
@@ -199,11 +199,9 @@ func httpHandler(writer http.ResponseWriter, request *http.Request) {
 	case "change-list":
 		changes := pnd.ListCommitted()
 		writeIDs(writer, "Tentative changes", changes)
-		writer.WriteHeader(http.StatusOK)
 	case "change-list-pending":
 		changes := pnd.ListPending()
 		writeIDs(writer, "Pending changes", changes)
-		writer.WriteHeader(http.StatusOK)
 	case "change-commit":
 		cuid, err := uuid.Parse(query.Get("cuid"))
 		if err != nil {
@@ -215,7 +213,11 @@ func httpHandler(writer http.ResponseWriter, request *http.Request) {
 			handleServerError(writer, err)
 			return
 		}
-		change.(*Change).Commit()
+		err = change.(*Change).Commit()
+		if err != nil {
+			handleServerError(writer, err)
+			return
+		}
 		writer.WriteHeader(http.StatusAccepted)
 	case "change-confirm":
 		cuid, err := uuid.Parse(query.Get("cuid"))
@@ -228,7 +230,11 @@ func httpHandler(writer http.ResponseWriter, request *http.Request) {
 			handleServerError(writer, err)
 			return
 		}
-		change.(*Change).Confirm()
+		err = change.(*Change).Confirm()
+		if err != nil {
+			handleServerError(writer, err)
+			return
+		}
 		writer.WriteHeader(http.StatusAccepted)
 	default:
 		writer.WriteHeader(http.StatusBadRequest)
diff --git a/nucleus/http_test.go b/nucleus/http_test.go
index b7728d9ff..f6b20983a 100644
--- a/nucleus/http_test.go
+++ b/nucleus/http_test.go
@@ -39,9 +39,6 @@ func testSetupHTTP() {
 }
 
 func Test_httpApi(t *testing.T) {
-	if testing.Short() {
-		t.Skip("this test is executed separately")
-	}
 	tests := []struct {
 		name    string
 		request string
@@ -139,6 +136,36 @@ func Test_httpApi(t *testing.T) {
 			want:    &http.Response{StatusCode: http.StatusOK},
 			wantErr: false,
 		},
+		{
+			name:    "change list",
+			request: apiEndpoint + "/api?q=change-list" + args + "&path=/system/config/hostname",
+			want:    &http.Response{StatusCode: http.StatusOK},
+			wantErr: false,
+		},
+		{
+			name:    "change list pending",
+			request: apiEndpoint + "/api?q=change-list-pending" + args + "&path=/system/config/hostname",
+			want:    &http.Response{StatusCode: http.StatusOK},
+			wantErr: false,
+		},
+		{
+			name:    "change commit",
+			request: apiEndpoint + "/api?q=change-commit" + args + "&cuid=" + uuid.New().String(),
+			want:    &http.Response{StatusCode: http.StatusOK},
+			wantErr: false,
+		},
+		{
+			name:    "change confirm",
+			request: apiEndpoint + "/api?q=change-confirm" + args + "&cuid="  + uuid.New().String(),
+			want:    &http.Response{StatusCode: http.StatusOK},
+			wantErr: false,
+		},
+		{
+			name:    "bad request",
+			request: apiEndpoint + "/api?q=bad-request",
+			want:    &http.Response{StatusCode: http.StatusBadRequest},
+			wantErr: false,
+		},
 		{
 			name:    "internal server errror: wrong pnd",
 			request: apiEndpoint + "/api?pnd=" + uuid.New().String(),
diff --git a/nucleus/initialise_test.go b/nucleus/initialise_test.go
index 932adb147..7fb435337 100644
--- a/nucleus/initialise_test.go
+++ b/nucleus/initialise_test.go
@@ -108,29 +108,10 @@ func newGnmiTransportOptions() *GnmiTransportOptions {
 func readTestUUIDs() {
 	var err error
 	did, err = uuid.Parse("4d8246f8-e884-41d6-87f5-c2c784df9e44")
-	if err != nil {
-		log.Fatal(err)
-	}
-
 	mdid, err = uuid.Parse("688a264e-5f85-40f8-bd13-afc42fcd5c7a")
-	if err != nil {
-		log.Fatal(err)
-	}
-
 	defaultSbiID, err = uuid.Parse("b70c8425-68c7-4d4b-bb5e-5586572bd64b")
-	if err != nil {
-		log.Fatal(err)
-	}
-
 	defaultPndID, err = uuid.Parse("b4016412-eec5-45a1-aa29-f59915357bad")
-	if err != nil {
-		log.Fatal(err)
-	}
-
 	ocUUID, err = uuid.Parse("5e252b70-38f2-4c99-a0bf-1b16af4d7e67")
-	if err != nil {
-		log.Fatal(err)
-	}
 	iid, err = uuid.Parse("8495a8ac-a1e8-418e-b787-10f5878b2690")
 	altIid, err = uuid.Parse("edc5de93-2d15-4586-b2a7-fb1bc770986b")
 	if err != nil {
diff --git a/nucleus/pnd/change.go b/nucleus/pnd/change.go
index 8e67b431b..553e0368e 100644
--- a/nucleus/pnd/change.go
+++ b/nucleus/pnd/change.go
@@ -1,13 +1,35 @@
 package pnd
 
 import (
+	"errors"
 	"github.com/google/uuid"
 	"github.com/openconfig/ygot/ygot"
 	log "github.com/sirupsen/logrus"
+	"golang.org/x/net/context"
+	"os"
 	"sync"
 	"time"
 )
 
+var changeTimeout time.Duration
+
+func init() {
+	timeout, err := time.ParseDuration(os.Getenv("GOSDN_CHANGE_TIMEOUT"))
+	if err != nil {
+		log.Fatal(err)
+	}
+	if timeout != time.Duration(0) {
+		changeTimeout = timeout
+	} else {
+		var err error
+		changeTimeout, err = time.ParseDuration("10m")
+		if err != nil {
+			log.Fatal()
+		}
+	}
+	log.Debugf("change timeout set to %v", changeTimeout)
+}
+
 func NewChange(device uuid.UUID, currentState ygot.GoStruct, change ygot.GoStruct, callback func(ygot.GoStruct, ygot.GoStruct) error) *Change {
 	return &Change{
 		cuid:          uuid.New(),
@@ -21,7 +43,7 @@ func NewChange(device uuid.UUID, currentState ygot.GoStruct, change ygot.GoStruc
 	}
 }
 
-// Change is an intended change to a OND. It is unique and immutable.
+// Change is an intended change to an OND. It is unique and immutable.
 // It has a cuid, a timestamp, and holds both the previous and the new
 // state. It keeps track if the state is committed and confirmed. A callback
 // exists to acess the proper transport for the changed OND
@@ -35,53 +57,66 @@ type Change struct {
 	confirmed     bool
 	callback      func(ygot.GoStruct, ygot.GoStruct) error
 	lock          sync.RWMutex
+	cancelFunc    context.CancelFunc
 }
 
 func (c *Change) ID() uuid.UUID {
 	return c.cuid
 }
 
-func (c *Change) Commit() {
+func (c *Change) Commit() error {
 	c.committed = true
 	if err := c.callback(c.intendedState, c.previousState); err != nil {
-		log.WithFields(log.Fields{
-			"change uuid": c.cuid,
-			"device uuid": c.duid,
-		}).Error(err)
-		log.WithFields(log.Fields{
-			"change uuid": c.cuid,
-			"device uuid": c.duid,
-		}).Debug("change commited")
+		return err
 	}
-	go func() {
-		time.Sleep(time.Minute * 10)
+	log.WithFields(log.Fields{
+		"change uuid": c.cuid,
+		"device uuid": c.duid,
+	}).Debug("change commited")
+	ctx, cancel := context.WithCancel(context.Background())
+	c.cancelFunc = cancel
+	go c.rollbackHandler(ctx)
+	return nil
+}
+
+func (c *Change) rollbackHandler(ctx context.Context) {
+	select {
+	case <-ctx.Done():
+		return
+	case <-time.Tick(changeTimeout):
 		c.lock.RLock()
 		defer c.lock.RUnlock()
 		if !c.confirmed {
-			c.Rollback()
+			err := c.callback(c.previousState, c.intendedState)
+			if err != nil {
+				log.WithFields(log.Fields{
+					"change uuid": c.cuid,
+					"device uuid": c.duid,
+					"error":       err,
+				}).Error("rollback error")
+			}
 			log.WithFields(log.Fields{
 				"change uuid": c.cuid,
 				"device uuid": c.duid,
-			}).Error("change timed out")
+			}).Info("change timed out")
 		}
-	}()
-}
-
-func (c *Change) Rollback() {
-	if err := c.callback(c.previousState, c.intendedState); err != nil {
-		log.WithFields(log.Fields{
-			"change uuid": c.cuid,
-			"device uuid": c.duid,
-		}).Error(err)
 	}
 }
 
-func (c *Change) Confirm() {
+func (c *Change) Confirm() error {
+	c.lock.RLock()
+	if !c.committed {
+		defer c.lock.RUnlock()
+		return errors.New("cannot confirm uncommitted change")
+	}
+	c.lock.RUnlock()
 	c.lock.Lock()
 	defer c.lock.Unlock()
 	c.confirmed = true
+	c.cancelFunc()
 	log.WithFields(log.Fields{
 		"change uuid": c.cuid,
 		"device uuid": c.duid,
 	}).Info("change confirmed")
+	return nil
 }
diff --git a/nucleus/pnd/change_test.go b/nucleus/pnd/change_test.go
new file mode 100644
index 000000000..fca364618
--- /dev/null
+++ b/nucleus/pnd/change_test.go
@@ -0,0 +1,176 @@
+package pnd
+
+import (
+	"context"
+	"github.com/google/uuid"
+	"github.com/openconfig/ygot/exampleoc"
+	"github.com/openconfig/ygot/ygot"
+	"os"
+	"reflect"
+	"sync"
+	"testing"
+	"time"
+)
+
+var commit = "commit"
+var rollback = "rollback"
+
+var commitDevice = &exampleoc.Device{
+	System: &exampleoc.System{
+		Hostname: &commit,
+	},
+}
+
+var rollbackDevice = &exampleoc.Device{
+	System: &exampleoc.System{
+		Hostname: &rollback,
+	},
+}
+
+func TestChange_Commit(t *testing.T) {
+	callback := make(chan string)
+	tests := []struct {
+		name    string
+		want    string
+		wantErr bool
+	}{
+		{
+			name: commit,
+			want:    commit,
+			wantErr: false,
+		},
+		{
+			name: rollback,
+			want:    rollback,
+			wantErr: false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			c := &Change{
+				cuid:          changeUUID,
+				duid:          did,
+				timestamp:     time.Now(),
+				previousState: rollbackDevice,
+				intendedState: commitDevice,
+				callback:      func(first ygot.GoStruct, second ygot.GoStruct) error {
+					hostname := *first.(*exampleoc.Device).System.Hostname
+					t.Logf("hostname: %v", hostname)
+					callback <- hostname
+					return nil
+				},
+				lock:          sync.RWMutex{},
+			}
+			go func() {
+				time.Sleep(time.Millisecond * 10)
+				if err := c.Commit(); (err != nil) != tt.wantErr {
+					t.Errorf("Commit() error = %v, wantErr %v", err, tt.wantErr)
+				}
+				if tt.name == "rollback" {
+					timeout, err := time.ParseDuration(os.Getenv("GOSDN_CHANGE_TIMEOUT"))
+					if err != nil {
+						t.Error(err)
+					}
+					time.Sleep(timeout)
+				}
+			}()
+			var got string
+			switch tt.name {
+			case commit:
+				got = <-callback
+			case rollback:
+				_ = <-callback
+				got = <-callback
+			}
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("Commit() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+	close(callback)
+}
+
+func TestChange_Confirm(t *testing.T) {
+	_, cancel := context.WithCancel(context.Background())
+	type fields struct {
+		cuid          uuid.UUID
+		duid          uuid.UUID
+		timestamp     time.Time
+		previousState ygot.GoStruct
+		intendedState ygot.GoStruct
+		callback      func(ygot.GoStruct, ygot.GoStruct) error
+		committed     bool
+	}
+	tests := []struct {
+		name   string
+		fields fields
+		wantErr bool
+	}{
+		{
+			name: "committed",
+			fields: fields{
+				committed: true,
+			},
+			wantErr: false,
+		},
+		{
+			name:   "uncommitted",
+			fields: fields{},
+			wantErr: true,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			c := &Change{
+				cuid:      tt.fields.cuid,
+				duid:      tt.fields.duid,
+				timestamp: tt.fields.timestamp,
+				previousState: &exampleoc.Device{
+					System: &exampleoc.System{
+						Hostname: &rollback,
+					},
+				},
+				intendedState: &exampleoc.Device{
+					System: &exampleoc.System{
+						Hostname: &commit,
+					},
+				},
+				callback:   tt.fields.callback,
+				committed:  tt.fields.committed,
+				cancelFunc: cancel,
+				lock:       sync.RWMutex{},
+			}
+			if err := c.Confirm(); (err != nil) != tt.wantErr {
+				t.Errorf("Confirm() error = %v, wantErr %v", err, tt.wantErr)
+			}
+		})
+	}
+	cancel()
+}
+
+func TestChange_ID(t *testing.T) {
+	type fields struct {
+		cuid uuid.UUID
+	}
+	tests := []struct {
+		name   string
+		fields fields
+		want   uuid.UUID
+	}{
+		{
+			name:   "default",
+			fields: fields{cuid: changeUUID},
+			want:   changeUUID,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			c := &Change{
+				cuid: tt.fields.cuid,
+			}
+			if got := c.ID(); !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("ID() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
diff --git a/nucleus/pnd/initialise_test.go b/nucleus/pnd/initialise_test.go
new file mode 100644
index 000000000..ff7dcbf9f
--- /dev/null
+++ b/nucleus/pnd/initialise_test.go
@@ -0,0 +1,22 @@
+package pnd
+
+import (
+	"github.com/google/uuid"
+	log "github.com/sirupsen/logrus"
+	"os"
+	"testing"
+)
+
+// UUIDs for test cases
+var changeUUID uuid.UUID
+var did uuid.UUID
+
+func TestMain(m *testing.M) {
+	var err error
+	changeUUID, err = uuid.Parse("cfbb96cd-ecad-45d1-bebf-1851760f5087")
+	did, err = uuid.Parse("4d8246f8-e884-41d6-87f5-c2c784df9e44")
+	if err != nil {
+		log.Fatal(err)
+	}
+	os.Exit(m.Run())
+}
\ No newline at end of file
diff --git a/nucleus/principalNetworkDomain_test.go b/nucleus/principalNetworkDomain_test.go
index 4b885cdf5..ef4207edc 100644
--- a/nucleus/principalNetworkDomain_test.go
+++ b/nucleus/principalNetworkDomain_test.go
@@ -524,6 +524,7 @@ func Test_pndImplementation_RequestAll(t *testing.T) {
 }
 
 func Test_pndImplementation_ChangeOND(t *testing.T) {
+	t.Fail()
 	type fields struct {
 		name             string
 		description      string
-- 
GitLab