diff --git a/.gitignore b/.gitignore index 0ff7746cc07e234f4c3325080bf14d2bcc12e169..83d7fcf53ce03d7dcb7af95d784535e225453b15 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,5 @@ documentation/design/*.pdf restconf/bin/bin test/.terraform.local/ configs/gosdn.toml -debug.test \ No newline at end of file +api/api_test.toml +debug.test diff --git a/api/pnd.go b/api/pnd.go index f396625163114d0454ca10b57ec23908ca2afb6a..976fdb0ac17efa80e945babe8a83f040adaf6080 100644 --- a/api/pnd.go +++ b/api/pnd.go @@ -177,7 +177,7 @@ func (p *PrincipalNetworkDomainAdapter) CommittedChanges() []uuid.UUID { } // GetChange sends an API call to the controller requesting the specified change -func (p *PrincipalNetworkDomainAdapter) GetChange(uuid.UUID, ...int) (change.Change, error) { +func (p *PrincipalNetworkDomainAdapter) GetChange(uuid.UUID) (change.Change, error) { return nil, &errors.ErrNotYetImplemented{} } @@ -204,7 +204,7 @@ func (p *PrincipalNetworkDomainAdapter) Confirm(cuid uuid.UUID) error { func filterChanges(state ppb.Change_State, resp *ppb.GetResponse) []uuid.UUID { changes := make([]uuid.UUID, 0) for _, ch := range resp.Change { - if ch.State == ppb.Change_PENDING { + if ch.State == state { id, _ := uuid.Parse(ch.Id) changes = append(changes, id) } diff --git a/api/pnd_test.go b/api/pnd_test.go index 5f6238157715197eadb7e3b71aa696c2cc06cac3..f10c4d7b364185ffa44ea8d845d7e490a50922e8 100644 --- a/api/pnd_test.go +++ b/api/pnd_test.go @@ -555,8 +555,7 @@ func TestPrincipalNetworkDomainAdapter_GetChange(t *testing.T) { endpoint string } type args struct { - in0 uuid.UUID - in1 []int + in uuid.UUID } tests := []struct { name string @@ -573,7 +572,7 @@ func TestPrincipalNetworkDomainAdapter_GetChange(t *testing.T) { id: tt.fields.id, endpoint: tt.fields.endpoint, } - got, err := p.GetChange(tt.args.in0, tt.args.in1...) + got, err := p.GetChange(tt.args.in) if (err != nil) != tt.wantErr { t.Errorf("PrincipalNetworkDomainAdapter.GetChange() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/initialise_test.go b/initialise_test.go index 984cf81ec228f76245892538aab838209ac981c1..d47822bbd8de7398ab5ebe30f1d5f9e8bc8bff84 100644 --- a/initialise_test.go +++ b/initialise_test.go @@ -1,78 +1,40 @@ package gosdn import ( - "context" "os" "testing" - "code.fbi.h-da.de/cocsn/gosdn/interfaces/device" - "code.fbi.h-da.de/cocsn/gosdn/interfaces/networkdomain" - "code.fbi.h-da.de/cocsn/gosdn/interfaces/southbound" - - "code.fbi.h-da.de/cocsn/gosdn/nucleus/util/proto" "github.com/google/uuid" - gpb "github.com/openconfig/gnmi/proto/gnmi" log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/mock" - pb "google.golang.org/protobuf/proto" ) const apiEndpoint = "http://localhost:8080" // UUIDs for test cases var mdid uuid.UUID -var defaultSbiID uuid.UUID var defaultPndID uuid.UUID var cuid uuid.UUID -var sbi southbound.SouthboundInterface -var httpTestPND networkdomain.NetworkDomain -var gnmiMessages map[string]pb.Message -var httpTestDevice device.Device - -var args string -var argsNotFound string -var argsNotFoundGetDevice string - -var mockContext = mock.MatchedBy(func(ctx context.Context) bool { return true }) - -// TestMain bootstraps all tests. Humongous beast -// TODO: Move somewhere more sensible func TestMain(m *testing.M) { log.SetReportCaller(true) if os.Getenv("GOSDN_LOG") == "nolog" { log.SetLevel(log.PanicLevel) } - - gnmiMessages = map[string]pb.Message{ - "./test/proto/cap-resp-arista-ceos": &gpb.CapabilityResponse{}, - "./test/proto/req-full-node": &gpb.GetRequest{}, - "./test/proto/req-full-node-arista-ceos": &gpb.GetRequest{}, - "./test/proto/req-interfaces-arista-ceos": &gpb.GetRequest{}, - "./test/proto/req-interfaces-interface-arista-ceos": &gpb.GetRequest{}, - "./test/proto/req-interfaces-wildcard": &gpb.GetRequest{}, - "./test/proto/resp-full-node": &gpb.GetResponse{}, - "./test/proto/resp-full-node-arista-ceos": &gpb.GetResponse{}, - "./test/proto/resp-interfaces-arista-ceos": &gpb.GetResponse{}, - "./test/proto/resp-interfaces-interface-arista-ceos": &gpb.GetResponse{}, - "./test/proto/resp-interfaces-wildcard": &gpb.GetResponse{}, - "./test/proto/resp-set-system-config-hostname": &gpb.SetResponse{}, - } - for k, v := range gnmiMessages { - if err := proto.Read(k, v); err != nil { - log.Fatalf("error parsing %v: %v", k, err) - } - } readTestUUIDs() - os.Exit(m.Run()) } func readTestUUIDs() { var err error mdid, err = uuid.Parse("688a264e-5f85-40f8-bd13-afc42fcd5c7a") + if err != nil { + log.Fatal(err) + } defaultPndID, err = uuid.Parse("b4016412-eec5-45a1-aa29-f59915357bad") + if err != nil { + log.Fatal(err) + } cuid, err = uuid.Parse("3e8219b0-e926-400d-8660-217f2a25a7c6") if err != nil { log.Fatal(err) diff --git a/interfaces/change/change.go b/interfaces/change/change.go index 9c861b44ad34de54f177e91f2b0eed2a590d6558..e195c02b16a0ac5fe1ef38c61f8a4a437d1bddbe 100644 --- a/interfaces/change/change.go +++ b/interfaces/change/change.go @@ -1,6 +1,9 @@ package change -import "github.com/google/uuid" +import ( + ppb "code.fbi.h-da.de/cocsn/api/go/gosdn/pnd" + "github.com/google/uuid" +) // Change is an intended change to an OND. It is unique and immutable. // It has a cuid, a timestamp, and holds both the previous and the new @@ -10,4 +13,5 @@ type Change interface { ID() uuid.UUID Commit() error Confirm() error + State() ppb.Change_State } diff --git a/interfaces/networkdomain/pnd.go b/interfaces/networkdomain/pnd.go index ce032d856f5d3ff64705593d7c2e2cef5c6b7b58..b723d723470af129644b2a0f62f882abf0163614 100644 --- a/interfaces/networkdomain/pnd.go +++ b/interfaces/networkdomain/pnd.go @@ -31,7 +31,7 @@ type NetworkDomain interface { ID() uuid.UUID PendingChanges() []uuid.UUID CommittedChanges() []uuid.UUID - GetChange(uuid.UUID, ...int) (change.Change, error) + GetChange(uuid.UUID) (change.Change, error) Commit(uuid.UUID) error Confirm(uuid.UUID) error } diff --git a/mocks/Change.go b/mocks/Change.go index f59f4749d421b8e6dab4388044cd4c9feb3d6b8a..c211183e76791592c3dcb4274f3fd250df54969f 100644 --- a/mocks/Change.go +++ b/mocks/Change.go @@ -3,8 +3,10 @@ package mocks import ( - uuid "github.com/google/uuid" + pnd "code.fbi.h-da.de/cocsn/api/go/gosdn/pnd" mock "github.com/stretchr/testify/mock" + + uuid "github.com/google/uuid" ) // Change is an autogenerated mock type for the Change type @@ -55,3 +57,17 @@ func (_m *Change) ID() uuid.UUID { return r0 } + +// State provides a mock function with given fields: +func (_m *Change) State() pnd.Change_State { + ret := _m.Called() + + var r0 pnd.Change_State + if rf, ok := ret.Get(0).(func() pnd.Change_State); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(pnd.Change_State) + } + + return r0 +} diff --git a/mocks/NetworkDomain.go b/mocks/NetworkDomain.go index c43edae6837df98034f87947c4457ec1c7de7465..86e84069386bcc0247d8d9b10b3b5c8563c32bf7 100644 --- a/mocks/NetworkDomain.go +++ b/mocks/NetworkDomain.go @@ -161,20 +161,13 @@ func (_m *NetworkDomain) Devices() []uuid.UUID { return r0 } -// GetChange provides a mock function with given fields: _a0, _a1 -func (_m *NetworkDomain) GetChange(_a0 uuid.UUID, _a1 ...int) (change.Change, error) { - _va := make([]interface{}, len(_a1)) - for _i := range _a1 { - _va[_i] = _a1[_i] - } - var _ca []interface{} - _ca = append(_ca, _a0) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// GetChange provides a mock function with given fields: _a0 +func (_m *NetworkDomain) GetChange(_a0 uuid.UUID) (change.Change, error) { + ret := _m.Called(_a0) var r0 change.Change - if rf, ok := ret.Get(0).(func(uuid.UUID, ...int) change.Change); ok { - r0 = rf(_a0, _a1...) + if rf, ok := ret.Get(0).(func(uuid.UUID) change.Change); ok { + r0 = rf(_a0) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(change.Change) @@ -182,8 +175,8 @@ func (_m *NetworkDomain) GetChange(_a0 uuid.UUID, _a1 ...int) (change.Change, er } var r1 error - if rf, ok := ret.Get(1).(func(uuid.UUID, ...int) error); ok { - r1 = rf(_a0, _a1...) + if rf, ok := ret.Get(1).(func(uuid.UUID) error); ok { + r1 = rf(_a0) } else { r1 = ret.Error(1) } diff --git a/mocks/SouthboundInterface.go b/mocks/SouthboundInterface.go index c7936b27df0f42f4c6331711b357635badb1a05b..3528d150a28e19fc062e9d80e8d4439249980c61 100644 --- a/mocks/SouthboundInterface.go +++ b/mocks/SouthboundInterface.go @@ -12,6 +12,8 @@ import ( yang "github.com/openconfig/goyang/pkg/yang" + ygot "github.com/openconfig/ygot/ygot" + ytypes "github.com/openconfig/ygot/ytypes" ) @@ -66,17 +68,22 @@ func (_m *SouthboundInterface) Schema() *ytypes.Schema { return r0 } -// SetNode provides a mock function with given fields: -func (_m *SouthboundInterface) SetNode() func(*yang.Entry, interface{}, *gnmi.Path, interface{}, []ytypes.SetNodeOpt) error { - ret := _m.Called() - - var r0 func(*yang.Entry, interface{}, *gnmi.Path, interface{}, []ytypes.SetNodeOpt) error - if rf, ok := ret.Get(0).(func() func(*yang.Entry, interface{}, *gnmi.Path, interface{}, []ytypes.SetNodeOpt) error); ok { - r0 = rf() +// SetNode provides a mock function with given fields: schema, root, path, val, opts +func (_m *SouthboundInterface) SetNode(schema *yang.Entry, root interface{}, path *gnmi.Path, val interface{}, opts ...ytypes.SetNodeOpt) error { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, schema, root, path, val) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(*yang.Entry, interface{}, *gnmi.Path, interface{}, ...ytypes.SetNodeOpt) error); ok { + r0 = rf(schema, root, path, val, opts...) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(func(*yang.Entry, interface{}, *gnmi.Path, interface{}, []ytypes.SetNodeOpt) error) - } + r0 = ret.Error(0) } return r0 @@ -95,3 +102,24 @@ func (_m *SouthboundInterface) Type() gosdnsouthbound.Type { return r0 } + +// Unmarshal provides a mock function with given fields: _a0, _a1, _a2, _a3 +func (_m *SouthboundInterface) Unmarshal(_a0 []byte, _a1 []string, _a2 ygot.ValidatedGoStruct, _a3 ...ytypes.UnmarshalOpt) error { + _va := make([]interface{}, len(_a3)) + for _i := range _a3 { + _va[_i] = _a3[_i] + } + var _ca []interface{} + _ca = append(_ca, _a0, _a1, _a2) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func([]byte, []string, ygot.ValidatedGoStruct, ...ytypes.UnmarshalOpt) error); ok { + r0 = rf(_a0, _a1, _a2, _a3...) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/nucleus/change.go b/nucleus/change.go index b1a3242717c4736289dcb8ec78e6c9453bf06f4a..75c54d37850780f011ffcc7330b03bbec3573a4d 100644 --- a/nucleus/change.go +++ b/nucleus/change.go @@ -1,9 +1,8 @@ package nucleus import ( - "errors" + "fmt" "os" - "sync" "time" ppb "code.fbi.h-da.de/cocsn/api/go/gosdn/pnd" @@ -11,7 +10,6 @@ import ( "github.com/google/uuid" "github.com/openconfig/ygot/ygot" log "github.com/sirupsen/logrus" - "golang.org/x/net/context" ) var changeTimeout time.Duration @@ -24,13 +22,10 @@ func init() { if err != nil { log.Fatal(err) } + log.Debugf("change timeout set to %v", changeTimeout) } else { - changeTimeout, err = time.ParseDuration("10m") - if err != nil { - log.Fatal(err) - } + changeTimeout = time.Minute * 10 } - log.Debugf("change timeout set to %v", changeTimeout) } // NewChange takes a Device UUID, a pair GoStructs (current and intended state) @@ -38,6 +33,7 @@ func init() { // The callback function is used by the Commit() and Confirm() functions. It // must define how the change is carried out. func NewChange(device uuid.UUID, currentState ygot.GoStruct, change ygot.GoStruct, callback func(ygot.GoStruct, ygot.GoStruct) error, errChan chan error) *Change { + commit, confirm, out := stateManager(changeTimeout) return &Change{ cuid: uuid.New(), duid: device, @@ -48,7 +44,9 @@ func NewChange(device uuid.UUID, currentState ygot.GoStruct, change ygot.GoStruc confirmed: false, callback: callback, errChan: errChan, - done: make(chan int), + out: out, + commit: commit, + confirm: confirm, } } @@ -66,10 +64,10 @@ type Change struct { confirmed bool inconsistent bool callback func(ygot.GoStruct, ygot.GoStruct) error - lock sync.RWMutex - cancelFunc context.CancelFunc errChan chan error - done chan int + out <-chan bool + commit chan<- *Change + confirm chan<- *Change } // ID returns the Change's UUID @@ -77,60 +75,36 @@ func (c *Change) ID() uuid.UUID { return c.cuid } -// Commit pushes the cange to the OND using the callback() function +// Commit pushes the change to the OND using the callback() function // and starts the timeout-timer for the Change. If the timer expires // the change is rolled back. func (c *Change) Commit() error { - if err := c.callback(c.intendedState, c.previousState); err != nil { - return err - } - c.committed = true - log.WithFields(log.Fields{ - "change uuid": c.cuid, - "device uuid": c.duid, - }).Debug("change commited") - ctx, cancel := context.WithCancel(context.Background()) - c.cancelFunc = cancel - go c.rollbackHandler(ctx) - return nil -} - -func (c *Change) rollbackHandler(ctx context.Context) { + c.commit <- c select { - case <-ctx.Done(): - return - case <-time.Tick(changeTimeout): - c.lock.RLock() - defer c.lock.RUnlock() - if !c.confirmed { - c.errChan <- c.callback(c.previousState, c.intendedState) - log.WithFields(log.Fields{ - "change uuid": c.cuid, - "device uuid": c.duid, - }).Info("change timed out") + case err := <-c.errChan: + if err != nil { + return err } + case <-c.out: + return nil } + return nil } // Confirm confirms a committed Change and stops the rollback timer. func (c *Change) Confirm() error { - c.lock.RLock() if !c.committed { - defer c.lock.RUnlock() - return errors.New("cannot confirm uncommitted change") + return fmt.Errorf("cannot confirm uncommitted change %v", c.cuid) + } + c.confirm <- c + select { + case err := <-c.errChan: + if err != nil { + return err + } + case <-c.out: + return nil } - c.lock.RUnlock() - c.lock.Lock() - defer c.lock.Unlock() - c.confirmed = true - c.cancelFunc() - close(c.errChan) - c.done <- 0 - close(c.done) - log.WithFields(log.Fields{ - "change uuid": c.cuid, - "device uuid": c.duid, - }).Info("change confirmed") return nil } @@ -149,3 +123,36 @@ func (c *Change) State() ppb.Change_State { return ppb.Change_CONFIRMED } } + +func stateManager(timeout time.Duration) (chan<- *Change, chan<- *Change, <-chan bool) { + commit := make(chan *Change) + confirm := make(chan *Change) + out := make(chan bool) + ticker := time.NewTicker(timeout) + + go func() { + ch := <-commit + err := ch.callback(ch.previousState, ch.intendedState) + if err != nil { + ch.errChan <- err + } + ch.committed = true + out <- true + for { + select { + case <-ticker.C: + err := ch.callback(ch.intendedState, ch.previousState) + if err != nil { + ch.errChan <- err + } + ch.errChan <- fmt.Errorf("change %v timed out", ch.cuid) + break + case <-confirm: + ch.confirmed = true + out <- true + break + } + } + }() + return commit, confirm, out +} diff --git a/nucleus/change_test.go b/nucleus/change_test.go index eca18fa34152fddb2125737a7f6961c7e50e8f8a..4f35b89e4c6c29d746c91bdc5f3bf11994d4d81e 100644 --- a/nucleus/change_test.go +++ b/nucleus/change_test.go @@ -1,38 +1,43 @@ package nucleus import ( - "context" "errors" "reflect" - "sync" "testing" "time" + ppb "code.fbi.h-da.de/cocsn/api/go/gosdn/pnd" "github.com/google/uuid" "github.com/openconfig/ygot/exampleoc" "github.com/openconfig/ygot/ygot" ) -var commit = "commit" -var rollback = "rollback" +var commitHostname = "commit" +var rollbackHostname = "rollback" var commitDevice = &exampleoc.Device{ System: &exampleoc.System{ - Hostname: &commit, + Hostname: &commitHostname, }, } var rollbackDevice = &exampleoc.Device{ System: &exampleoc.System{ - Hostname: &rollback, + Hostname: &rollbackHostname, }, } func TestChange_CommitRollback(t *testing.T) { wantErr := false - want := rollback + want := rollbackHostname callback := make(chan string) + errChan := make(chan error, 10) + commit, confirm, out := stateManager(time.Millisecond * 100) c := &Change{ + commit: commit, + confirm: confirm, + errChan: errChan, + out: out, cuid: cuid, duid: did, timestamp: time.Now(), @@ -42,46 +47,47 @@ func TestChange_CommitRollback(t *testing.T) { hostname := *first.(*exampleoc.Device).System.Hostname t.Logf("hostname: %v", hostname) switch hostname { - case rollback: - callback <- rollback + case rollbackHostname: + callback <- rollbackHostname } return nil }, - lock: sync.RWMutex{}, } go func() { time.Sleep(time.Millisecond * 10) if err := c.Commit(); (err != nil) != wantErr { t.Errorf("Commit() error = %v, wantErr %v", err, wantErr) } - time.Sleep(changeTimeout) + time.Sleep(time.Millisecond * 100) }() got := <-callback if !reflect.DeepEqual(got, want) { t.Errorf("Commit() = %v, want %v", got, want) } - close(callback) } func TestChange_CommitRollbackError(t *testing.T) { wantErr := false want := errors.New("this is an expected error") + commit, confirm, out := stateManager(time.Millisecond * 100) c := &Change{ + commit: commit, + confirm: confirm, + out: out, cuid: cuid, duid: did, timestamp: time.Now(), previousState: rollbackDevice, intendedState: commitDevice, callback: func(first ygot.GoStruct, second ygot.GoStruct) error { - hostname := *first.(*exampleoc.Device).System.Hostname + hostname := *second.(*exampleoc.Device).System.Hostname t.Logf("hostname: %v", hostname) switch hostname { - case rollback: + case rollbackHostname: return errors.New("this is an expected error") } return nil }, - lock: sync.RWMutex{}, errChan: make(chan error), } go func() { @@ -89,18 +95,21 @@ func TestChange_CommitRollbackError(t *testing.T) { if err := c.Commit(); (err != nil) != wantErr { t.Errorf("Commit() error = %v, wantErr %v", err, wantErr) } - time.Sleep(changeTimeout) + time.Sleep(time.Millisecond * 100) }() got := <-c.errChan if !reflect.DeepEqual(got, want) { t.Errorf("Commit() = %v, want %v", got, want) } - close(c.errChan) } func TestChange_CommitError(t *testing.T) { wantErr := true + commit, confirm, out := stateManager(time.Millisecond * 100) c := &Change{ + commit: commit, + confirm: confirm, + out: out, cuid: cuid, duid: did, timestamp: time.Now(), @@ -109,7 +118,6 @@ func TestChange_CommitError(t *testing.T) { callback: func(first ygot.GoStruct, second ygot.GoStruct) error { return errors.New("this is an expected error") }, - lock: sync.RWMutex{}, } go func() { time.Sleep(time.Millisecond * 10) @@ -125,10 +133,13 @@ func TestChange_CommitError(t *testing.T) { func TestChange_Commit(t *testing.T) { wantErr := false - want := commit - callback := make(chan string) + want := ppb.Change_COMMITTED + commit, confirm, out := stateManager(time.Millisecond * 100) c := &Change{ + commit: commit, + confirm: confirm, + out: out, cuid: cuid, duid: did, timestamp: time.Now(), @@ -137,84 +148,71 @@ func TestChange_Commit(t *testing.T) { callback: func(first ygot.GoStruct, second ygot.GoStruct) error { hostname := *first.(*exampleoc.Device).System.Hostname t.Logf("hostname: %v", hostname) - callback <- hostname return nil }, - lock: sync.RWMutex{}, - errChan: make(chan error), - done: make(chan int), + errChan: make(chan error, 10), } - go func() { - time.Sleep(time.Millisecond * 10) - if err := c.Commit(); (err != nil) != wantErr { - t.Errorf("Commit() error = %v, wantErr %v", err, wantErr) - } - if err := c.Confirm(); err != nil { - t.Errorf("Commit() error = %v", err) - } - }() - got := <-callback + if err := c.Commit(); (err != nil) != wantErr { + t.Errorf("Commit() error = %v, wantErr %v", err, wantErr) + } + got := c.State() if !reflect.DeepEqual(got, want) { t.Errorf("Commit() = %v, want %v", got, want) } - close(callback) + if err := c.Confirm(); err != nil { + t.Errorf("Confirm() error = %v", err) + } } func TestChange_Confirm(t *testing.T) { - _, cancel := context.WithCancel(context.Background()) - type fields struct { - cuid uuid.UUID - duid uuid.UUID - timestamp time.Time - previousState ygot.GoStruct - intendedState ygot.GoStruct - callback func(ygot.GoStruct, ygot.GoStruct) error - committed bool - } tests := []struct { name string - fields fields wantErr bool }{ { - name: "committed", - fields: fields{ - committed: true, - }, + name: "committed", wantErr: false, }, { name: "uncommitted", - fields: fields{}, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + errChan := make(chan error, 10) + commit, confirm, out := stateManager(time.Millisecond * 100) c := &Change{ - committed: tt.fields.committed, - timestamp: tt.fields.timestamp, + commit: commit, + confirm: confirm, + errChan: errChan, + out: out, previousState: &exampleoc.Device{ System: &exampleoc.System{ - Hostname: &rollback, + Hostname: &rollbackHostname, }, }, intendedState: &exampleoc.Device{ System: &exampleoc.System{ - Hostname: &commit, + Hostname: &commitHostname, }, }, - cancelFunc: cancel, - lock: sync.RWMutex{}, - errChan: make(chan error), - done: make(chan int, 1), + callback: func(first ygot.GoStruct, second ygot.GoStruct) error { + hostname := *first.(*exampleoc.Device).System.Hostname + t.Logf("hostname: %v", hostname) + return nil + }, + } + if tt.name == "committed" { + if err := c.Commit(); err != nil { + t.Errorf("Commit() error = %v, wantErr %v", err, tt.wantErr) + } } if err := c.Confirm(); (err != nil) != tt.wantErr { t.Errorf("Confirm() error = %v, wantErr %v", err, tt.wantErr) } }) } - cancel() } func TestChange_ID(t *testing.T) { @@ -243,3 +241,47 @@ func TestChange_ID(t *testing.T) { }) } } + +func TestChange_State(t *testing.T) { + tests := []struct { + name string + want ppb.Change_State + }{ + { + name: "pending", + want: ppb.Change_PENDING, + }, + { + name: "committed", + want: ppb.Change_COMMITTED, + }, + { + name: "confirmed", + want: ppb.Change_CONFIRMED, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + callback := func(first ygot.GoStruct, second ygot.GoStruct) error { + hostname := *first.(*exampleoc.Device).System.Hostname + t.Logf("hostname: %v", hostname) + return nil + } + errChan := make(chan error) + c := NewChange(did, rollbackDevice, commitDevice, callback, errChan) + if tt.name != "pending" { + if err := c.Commit(); err != nil { + t.Errorf("Commit() error = %v", err) + } + } + if tt.name == "confirmed" { + if err := c.Confirm(); err != nil { + t.Errorf("Confirm() error = %v", err) + } + } + if got := c.State(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Change.State() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/nucleus/gnmi_transport.go b/nucleus/gnmi_transport.go index 06611db2a1a0af73141c55e5ea4ebce52f09e5dc..68568f7033cea220059862a228264e481a4a2f6e 100644 --- a/nucleus/gnmi_transport.go +++ b/nucleus/gnmi_transport.go @@ -102,15 +102,20 @@ func (g *Gnmi) Set(ctx context.Context, args ...interface{}) error { } } - if len(args) == 2 { - switch args[0].(type) { + // look at args, unpack any GoStructs present + goStructs := make([]ygot.GoStruct, 0) + for _, arg := range args { + switch a := arg.(type) { case ygot.GoStruct: - return g.applyDiff(ctx, args[0]) + goStructs = append(goStructs, a) default: - } } + if len(goStructs) == 2 { + return g.applyDiff(ctx, goStructs) + } + opts := make([]interface{}, 0) for _, o := range args { attrs, ok := o.([]string) @@ -170,7 +175,7 @@ func (g *Gnmi) Set(ctx context.Context, args ...interface{}) error { return nil } -func (g *Gnmi) applyDiff(ctx context.Context, payload ...interface{}) error { +func (g *Gnmi) applyDiff(ctx context.Context, payload []ygot.GoStruct) error { if len(payload) != 2 { return &errors.ErrInvalidParameters{} } diff --git a/nucleus/initialise_test.go b/nucleus/initialise_test.go index ffa4fa64a19b9fd44929c0287ef846127618f067..c93c3615ff866b3adcfd69e6e82acd185fa44bda 100644 --- a/nucleus/initialise_test.go +++ b/nucleus/initialise_test.go @@ -67,7 +67,6 @@ func TestMain(m *testing.M) { } } readTestUUIDs() - testSetupGnmi() os.Exit(m.Run()) } @@ -105,12 +104,33 @@ func newGnmiTransportOptions() *tpb.TransportOption { func readTestUUIDs() { var err error did, err = uuid.Parse("4d8246f8-e884-41d6-87f5-c2c784df9e44") + if err != nil { + log.Fatal(err) + } mdid, err = uuid.Parse("688a264e-5f85-40f8-bd13-afc42fcd5c7a") + if err != nil { + log.Fatal(err) + } defaultSbiID, err = uuid.Parse("b70c8425-68c7-4d4b-bb5e-5586572bd64b") + if err != nil { + log.Fatal(err) + } defaultPndID, err = uuid.Parse("b4016412-eec5-45a1-aa29-f59915357bad") + if err != nil { + log.Fatal(err) + } ocUUID, err = uuid.Parse("5e252b70-38f2-4c99-a0bf-1b16af4d7e67") + if err != nil { + log.Fatal(err) + } iid, err = uuid.Parse("8495a8ac-a1e8-418e-b787-10f5878b2690") + if err != nil { + log.Fatal(err) + } altIid, err = uuid.Parse("edc5de93-2d15-4586-b2a7-fb1bc770986b") + if err != nil { + log.Fatal(err) + } cuid, err = uuid.Parse("3e8219b0-e926-400d-8660-217f2a25a7c6") if err != nil { log.Fatal(err) @@ -129,14 +149,12 @@ func mockDevice() device.Device { func newPnd() pndImplementation { return pndImplementation{ - name: "default", - description: "default test pnd", - sbic: SbiStore{genericStore{}}, - devices: NewDeviceStore(), - pendingChanges: ChangeStore{genericStore{}}, - committedChanges: ChangeStore{genericStore{}}, - confirmedChanges: ChangeStore{genericStore{}}, - id: defaultPndID, - errChans: make(map[uuid.UUID]chan error), + name: "default", + description: "default test pnd", + sbic: SbiStore{genericStore{}}, + devices: NewDeviceStore(), + changes: ChangeStore{genericStore{}}, + id: defaultPndID, + errChans: make(map[uuid.UUID]chan error), } } diff --git a/nucleus/principalNetworkDomain.go b/nucleus/principalNetworkDomain.go index 82912dcbef3cf8af1606480634c6af950751df3c..92b3bdec99b8c97e0a1dd461c7adc44bbe3bd5d9 100644 --- a/nucleus/principalNetworkDomain.go +++ b/nucleus/principalNetworkDomain.go @@ -3,9 +3,10 @@ package nucleus import ( "context" "encoding/json" - "reflect" "time" + "code.fbi.h-da.de/cocsn/gosdn/nucleus/types" + cpb "code.fbi.h-da.de/cocsn/api/go/gosdn/csbi" ppb "code.fbi.h-da.de/cocsn/api/go/gosdn/pnd" tpb "code.fbi.h-da.de/cocsn/api/go/gosdn/transport" @@ -28,15 +29,13 @@ import ( // NewPND creates a Principle Network Domain func NewPND(name, description string, id uuid.UUID, sbi southbound.SouthboundInterface, c cpb.CsbiClient, callback func(uuid.UUID, chan DeviceDetails)) (networkdomain.NetworkDomain, error) { pnd := &pndImplementation{ - name: name, - description: description, - sbic: SbiStore{genericStore{}}, - devices: NewDeviceStore(), - pendingChanges: ChangeStore{genericStore{}}, - committedChanges: ChangeStore{genericStore{}}, - confirmedChanges: ChangeStore{genericStore{}}, - id: id, - errChans: make(map[uuid.UUID]chan error), + name: name, + description: description, + sbic: SbiStore{genericStore{}}, + devices: NewDeviceStore(), + changes: ChangeStore{genericStore{}}, + id: id, + errChans: make(map[uuid.UUID]chan error), csbiClient: c, callback: callback, @@ -48,105 +47,48 @@ func NewPND(name, description string, id uuid.UUID, sbi southbound.SouthboundInt } type pndImplementation struct { - name string - description string - sbic SbiStore - devices *DeviceStore - pendingChanges ChangeStore - committedChanges ChangeStore - confirmedChanges ChangeStore - id uuid.UUID - errChans map[uuid.UUID]chan error + name string + description string + sbic SbiStore + devices *DeviceStore + changes ChangeStore + id uuid.UUID + errChans map[uuid.UUID]chan error csbiClient cpb.CsbiClient callback func(uuid.UUID, chan DeviceDetails) } func (pnd *pndImplementation) PendingChanges() []uuid.UUID { - return pnd.pendingChanges.UUIDs() + return pnd.changes.Pending() } func (pnd *pndImplementation) CommittedChanges() []uuid.UUID { - return pnd.committedChanges.UUIDs() + return pnd.changes.Committed() } -func (pnd *pndImplementation) GetChange(cuid uuid.UUID, i ...int) (change.Change, error) { - var index int - if len(i) == 1 { - index = i[0] - } else if len(i) > 1 { - return nil, errors.ErrInvalidParameters{ - Func: pnd.GetChange, - Param: "length of 'i' cannot be greater than '1'", - } - } - stores := []*ChangeStore{ - &pnd.pendingChanges, - &pnd.committedChanges, - &pnd.confirmedChanges, - } - ch, err := stores[index].GetChange(cuid) - index++ - if err != nil { - switch err.(type) { - case *errors.ErrNotFound: - c, err := pnd.GetChange(cuid, index) - if err != nil { - return nil, err - } - var ok bool - ch, ok = c.(*Change) - if !ok { - return nil, &errors.ErrInvalidTypeAssertion{ - Value: c, - Type: reflect.TypeOf(&Change{}), - } - } - - default: - return nil, err - } - } - return ch, err +func (pnd *pndImplementation) ConfirmedChanges() []uuid.UUID { + return pnd.changes.Confirmed() +} + +func (pnd *pndImplementation) GetChange(cuid uuid.UUID) (change.Change, error) { + return pnd.changes.GetChange(cuid) } func (pnd *pndImplementation) Commit(u uuid.UUID) error { - ch, err := pnd.pendingChanges.GetChange(u) + ch, err := pnd.changes.GetChange(u) if err != nil { return err } - if err := ch.Commit(); err != nil { - return err - } - go func() { - for { - select { - case err := <-pnd.errChans[u]: - if err != nil { - handleRollbackError(ch.ID(), err) - } - case <-ch.done: - } - } - }() - if err := pnd.committedChanges.Add(ch); err != nil { - return err - } - return pnd.pendingChanges.Delete(u) + return ch.Commit() } func (pnd *pndImplementation) Confirm(u uuid.UUID) error { - ch, err := pnd.committedChanges.GetChange(u) + ch, err := pnd.changes.GetChange(u) if err != nil { return err } - if err := ch.Confirm(); err != nil { - return err - } - if err := pnd.confirmedChanges.Add(ch); err != nil { - return err - } - return pnd.committedChanges.Delete(u) + return ch.Confirm() } func (pnd *pndImplementation) ID() uuid.UUID { @@ -345,7 +287,6 @@ func (pnd *pndImplementation) ChangeOND(uuid uuid.UUID, operation ppb.ApiOperati Param: value, } } - switch operation { case ppb.ApiOperation_UPDATE, ppb.ApiOperation_REPLACE: typedValue := gnmi.TypedValue(value[0]) @@ -360,8 +301,9 @@ func (pnd *pndImplementation) ChangeOND(uuid uuid.UUID, operation ppb.ApiOperati return &errors.ErrOperationNotSupported{Op: operation} } + ygot.PruneEmptyBranches(cpy) callback := func(state ygot.GoStruct, change ygot.GoStruct) error { - ctx := context.Background() + ctx := context.WithValue(context.Background(), types.CtxKeyOperation, operation) // nolint return d.Transport().Set(ctx, state, change) } @@ -369,7 +311,7 @@ func (pnd *pndImplementation) ChangeOND(uuid uuid.UUID, operation ppb.ApiOperati ch := NewChange(uuid, d.Model(), cpy, callback, errChan) pnd.errChans[ch.ID()] = errChan - return pnd.pendingChanges.Add(ch) + return pnd.changes.Add(ch) } func handleRollbackError(id uuid.UUID, err error) { diff --git a/nucleus/principalNetworkDomain_test.go b/nucleus/principalNetworkDomain_test.go index 1bb567b1a06868343f4ece346812e80d93d3d2b3..5819d0e3b0735957683b41528ebcf7864c223ac9 100644 --- a/nucleus/principalNetworkDomain_test.go +++ b/nucleus/principalNetworkDomain_test.go @@ -8,7 +8,6 @@ import ( ppb "code.fbi.h-da.de/cocsn/api/go/gosdn/pnd" spb "code.fbi.h-da.de/cocsn/api/go/gosdn/southbound" tpb "code.fbi.h-da.de/cocsn/api/go/gosdn/transport" - "code.fbi.h-da.de/cocsn/gosdn/interfaces/device" "code.fbi.h-da.de/cocsn/gosdn/interfaces/networkdomain" "code.fbi.h-da.de/cocsn/gosdn/interfaces/southbound" @@ -635,8 +634,8 @@ func Test_pndImplementation_ChangeOND(t *testing.T) { return } if !tt.wantErr { - if len(pnd.pendingChanges.genericStore) != 1 { - t.Errorf("ChangeOND() unexpected change count. got %v, want 1", len(pnd.pendingChanges.genericStore)) + if len(pnd.changes.genericStore) != 1 { + t.Errorf("ChangeOND() unexpected change count. got %v, want 1", len(pnd.changes.genericStore)) } } }) @@ -786,3 +785,75 @@ func Test_pndImplementation_Confirm(t *testing.T) { }) } } + +func Test_pndImplementation_PendingChanges(t *testing.T) { + tests := []struct { + name string + want []uuid.UUID + }{ + { + name: "default", + want: []uuid.UUID{cuid}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pnd := newPnd() + pnd.changes.genericStore[cuid] = &Change{ + cuid: cuid, + } + if got := pnd.PendingChanges(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("pndImplementation.PendingChanges() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_pndImplementation_CommittedChanges(t *testing.T) { + tests := []struct { + name string + want []uuid.UUID + }{ + { + name: "default", + want: []uuid.UUID{cuid}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pnd := newPnd() + pnd.changes.genericStore[cuid] = &Change{ + cuid: cuid, + committed: true, + } + if got := pnd.CommittedChanges(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("pndImplementation.CommittedChanges() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_pndImplementation_ConfirmedChanges(t *testing.T) { + tests := []struct { + name string + want []uuid.UUID + }{ + { + name: "default", + want: []uuid.UUID{cuid}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pnd := newPnd() + pnd.changes.genericStore[cuid] = &Change{ + cuid: cuid, + committed: true, + confirmed: true, + } + if got := pnd.ConfirmedChanges(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("pndImplementation.ConfirmedChanges() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/nucleus/store.go b/nucleus/store.go index c9706303ae82c9080c6aad5de19a23c17f9b1ff7..631da74785b9a381530f485ee47f0931b1001e2f 100644 --- a/nucleus/store.go +++ b/nucleus/store.go @@ -5,6 +5,8 @@ import ( "reflect" "sync" + ppb "code.fbi.h-da.de/cocsn/api/go/gosdn/pnd" + "code.fbi.h-da.de/cocsn/gosdn/interfaces/device" "code.fbi.h-da.de/cocsn/gosdn/interfaces/networkdomain" "code.fbi.h-da.de/cocsn/gosdn/interfaces/southbound" @@ -78,9 +80,6 @@ func (s genericStore) Get(id uuid.UUID) (store.Storable, error) { if !s.Exists(id) { return nil, &errors.ErrNotFound{ID: id} } - log.WithFields(log.Fields{ - "uuid": id, - }).Debug("storable was accessed") storeLock.RLock() defer storeLock.RUnlock() return s[id], nil @@ -316,3 +315,31 @@ func (s ChangeStore) GetChange(id uuid.UUID) (*Change, error) { }).Debug("change was accessed") return change, nil } + +// Pending returns the UUIDs of all pending changes +func (s ChangeStore) Pending() []uuid.UUID { + return filterChanges(s, ppb.Change_PENDING) +} + +// Committed returns the UUIDs of all pending changes +func (s ChangeStore) Committed() []uuid.UUID { + return filterChanges(s, ppb.Change_COMMITTED) +} + +// Confirmed returns the UUIDs of all pending changes +func (s ChangeStore) Confirmed() []uuid.UUID { + return filterChanges(s, ppb.Change_CONFIRMED) +} + +func filterChanges(store ChangeStore, state ppb.Change_State) []uuid.UUID { + changes := make([]uuid.UUID, 0) + for _, ch := range store.genericStore { + switch change := ch.(type) { + case *Change: + if change.State() == state { + changes = append(changes, change.cuid) + } + } + } + return changes +} diff --git a/nucleus/store_test.go b/nucleus/store_test.go index 8468a34d458e438b390bb25ce6a57602e32374b7..42f6a95d6677e31c8227c016fba9cfdcac986277 100644 --- a/nucleus/store_test.go +++ b/nucleus/store_test.go @@ -5,6 +5,8 @@ import ( "sort" "testing" + ppb "code.fbi.h-da.de/cocsn/api/go/gosdn/pnd" + "code.fbi.h-da.de/cocsn/gosdn/interfaces/device" "code.fbi.h-da.de/cocsn/gosdn/interfaces/networkdomain" "code.fbi.h-da.de/cocsn/gosdn/interfaces/southbound" @@ -419,3 +421,172 @@ func Test_deviceStore_get(t *testing.T) { }) } } + +func TestChangeStore_Pending(t *testing.T) { + type fields struct { + genericStore genericStore + } + tests := []struct { + name string + fields fields + want []uuid.UUID + }{ + { + name: "default", + fields: fields{ + genericStore: genericStore{ + cuid: &Change{cuid: cuid}, + }, + }, + want: []uuid.UUID{cuid}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := ChangeStore{ + genericStore: tt.fields.genericStore, + } + if got := s.Pending(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Pending() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestChangeStore_Committed(t *testing.T) { + type fields struct { + genericStore genericStore + } + tests := []struct { + name string + fields fields + want []uuid.UUID + }{ + { + name: "default", + fields: fields{ + genericStore: genericStore{ + cuid: &Change{ + cuid: cuid, + committed: true, + }, + }, + }, + want: []uuid.UUID{cuid}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := ChangeStore{ + genericStore: tt.fields.genericStore, + } + if got := s.Committed(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Committed() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestChangeStore_Confirmed(t *testing.T) { + type fields struct { + genericStore genericStore + } + tests := []struct { + name string + fields fields + want []uuid.UUID + }{ + { + name: "default", + fields: fields{ + genericStore: genericStore{ + cuid: &Change{ + cuid: cuid, + committed: true, + confirmed: true, + }, + }, + }, + want: []uuid.UUID{cuid}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := ChangeStore{ + genericStore: tt.fields.genericStore, + } + if got := s.Confirmed(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Confirmed() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_filterChanges(t *testing.T) { + store := NewChangeStore() + pending := &Change{ + cuid: uuid.New(), + } + committed := &Change{ + cuid: uuid.New(), + committed: true, + } + confirmed := &Change{ + cuid: uuid.New(), + committed: true, + confirmed: true, + } + if err := store.Add(pending); err != nil { + t.Error(err) + return + } + if err := store.Add(committed); err != nil { + t.Error(err) + return + } + if err := store.Add(confirmed); err != nil { + t.Error(err) + return + } + type args struct { + store ChangeStore + state ppb.Change_State + } + tests := []struct { + name string + args args + want []uuid.UUID + }{ + { + name: "pending", + args: args{ + store: *store, + state: ppb.Change_PENDING, + }, + want: []uuid.UUID{pending.cuid}, + }, + { + name: "committed", + args: args{ + store: *store, + state: ppb.Change_COMMITTED, + }, + want: []uuid.UUID{committed.cuid}, + }, + { + name: "confirmed", + args: args{ + store: *store, + state: ppb.Change_CONFIRMED, + }, + want: []uuid.UUID{confirmed.cuid}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := filterChanges(tt.args.store, tt.args.state); !reflect.DeepEqual(got, tt.want) { + t.Errorf("filterChanges() = %v, want %v", got, tt.want) + } + }) + } +}