diff --git a/nucleus/pnd/change_test.go b/nucleus/pnd/change_test.go index fca3646186d3916566ec5f6dfbdcafa66fa52ae3..67b97ee4e5b90ca668b8f8c8eab95bf5f4c9024c 100644 --- a/nucleus/pnd/change_test.go +++ b/nucleus/pnd/change_test.go @@ -27,65 +27,76 @@ var rollbackDevice = &exampleoc.Device{ }, } -func TestChange_Commit(t *testing.T) { +func TestChange_CommitRollback(t *testing.T) { + wantErr := false + want := rollback callback := make(chan string) - tests := []struct { - name string - want string - wantErr bool - }{ - { - name: commit, - want: commit, - wantErr: false, + c := &Change{ + cuid: changeUUID, + duid: did, + timestamp: time.Now(), + previousState: rollbackDevice, + intendedState: commitDevice, + callback: func(first ygot.GoStruct, second ygot.GoStruct) error { + hostname := *first.(*exampleoc.Device).System.Hostname + t.Logf("hostname: %v", hostname) + switch hostname { + case rollback: + callback <- rollback + } + return nil }, - { - name: rollback, - want: rollback, - wantErr: false, + lock: sync.RWMutex{}, + } + go func() { + time.Sleep(time.Millisecond * 10) + if err := c.Commit(); (err != nil) != wantErr { + t.Errorf("Commit() error = %v, wantErr %v", err, wantErr) + } + timeout, err := time.ParseDuration(os.Getenv("GOSDN_CHANGE_TIMEOUT")) + if err != nil { + t.Error(err) + } + time.Sleep(timeout) + }() + got := <-callback + if !reflect.DeepEqual(got, want) { + t.Errorf("Commit() = %v, want %v", got, want) + } + close(callback) +} + +func TestChange_Commit(t *testing.T) { + wantErr := false + want := commit + callback := make(chan string) + + c := &Change{ + cuid: changeUUID, + duid: did, + timestamp: time.Now(), + previousState: rollbackDevice, + intendedState: commitDevice, + callback: func(first ygot.GoStruct, second ygot.GoStruct) error { + hostname := *first.(*exampleoc.Device).System.Hostname + t.Logf("hostname: %v", hostname) + callback <- hostname + return nil }, + lock: sync.RWMutex{}, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - c := &Change{ - cuid: changeUUID, - duid: did, - timestamp: time.Now(), - previousState: rollbackDevice, - intendedState: commitDevice, - callback: func(first ygot.GoStruct, second ygot.GoStruct) error { - hostname := *first.(*exampleoc.Device).System.Hostname - t.Logf("hostname: %v", hostname) - callback <- hostname - return nil - }, - lock: sync.RWMutex{}, - } - go func() { - time.Sleep(time.Millisecond * 10) - if err := c.Commit(); (err != nil) != tt.wantErr { - t.Errorf("Commit() error = %v, wantErr %v", err, tt.wantErr) - } - if tt.name == "rollback" { - timeout, err := time.ParseDuration(os.Getenv("GOSDN_CHANGE_TIMEOUT")) - if err != nil { - t.Error(err) - } - time.Sleep(timeout) - } - }() - var got string - switch tt.name { - case commit: - got = <-callback - case rollback: - _ = <-callback - got = <-callback - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Commit() = %v, want %v", got, tt.want) - } - }) + go func() { + time.Sleep(time.Millisecond * 10) + if err := c.Commit(); (err != nil) != wantErr { + t.Errorf("Commit() error = %v, wantErr %v", err, wantErr) + } + if err := c.Confirm(); err != nil { + t.Errorf("Commit() error = %v", err) + } + }() + got := <-callback + if !reflect.DeepEqual(got, want) { + t.Errorf("Commit() = %v, want %v", got, want) } close(callback) } @@ -102,8 +113,8 @@ func TestChange_Confirm(t *testing.T) { committed bool } tests := []struct { - name string - fields fields + name string + fields fields wantErr bool }{ { @@ -114,8 +125,8 @@ func TestChange_Confirm(t *testing.T) { wantErr: false, }, { - name: "uncommitted", - fields: fields{}, + name: "uncommitted", + fields: fields{}, wantErr: true, }, }