From 8607de0eea54f5f6503bacb2342815eb2a73edad Mon Sep 17 00:00:00 2001
From: Manuel Kieweg <mail@manuelkieweg.de>
Date: Mon, 15 Mar 2021 16:14:39 +0000
Subject: [PATCH] updated interface signature. first set integration test

---
 mocks/Transport.go             | 12 ++---
 nucleus/errors.go              |  9 ++++
 nucleus/gnmi_transport.go      | 37 ++++++++++++++-
 nucleus/gnmi_transport_test.go |  2 +-
 nucleus/integration_test.go    | 87 +++++++++++++++++++++++++++-------
 nucleus/transport.go           |  2 +-
 6 files changed, 119 insertions(+), 30 deletions(-)

diff --git a/mocks/Transport.go b/mocks/Transport.go
index 3d284039a..c83ad1949 100644
--- a/mocks/Transport.go
+++ b/mocks/Transport.go
@@ -60,18 +60,14 @@ func (_m *Transport) ProcessResponse(resp interface{}, root interface{}, models
 }
 
 // Set provides a mock function with given fields: ctx, params
-func (_m *Transport) Set(ctx context.Context, params ...string) (interface{}, error) {
-	_va := make([]interface{}, len(params))
-	for _i := range params {
-		_va[_i] = params[_i]
-	}
+func (_m *Transport) Set(ctx context.Context, params ...interface{}) (interface{}, error) {
 	var _ca []interface{}
 	_ca = append(_ca, ctx)
-	_ca = append(_ca, _va...)
+	_ca = append(_ca, params...)
 	ret := _m.Called(_ca...)
 
 	var r0 interface{}
-	if rf, ok := ret.Get(0).(func(context.Context, ...string) interface{}); ok {
+	if rf, ok := ret.Get(0).(func(context.Context, ...interface{}) interface{}); ok {
 		r0 = rf(ctx, params...)
 	} else {
 		if ret.Get(0) != nil {
@@ -80,7 +76,7 @@ func (_m *Transport) Set(ctx context.Context, params ...string) (interface{}, er
 	}
 
 	var r1 error
-	if rf, ok := ret.Get(1).(func(context.Context, ...string) error); ok {
+	if rf, ok := ret.Get(1).(func(context.Context, ...interface{}) error); ok {
 		r1 = rf(ctx, params...)
 	} else {
 		r1 = ret.Error(1)
diff --git a/nucleus/errors.go b/nucleus/errors.go
index 93097feda..9656b7327 100644
--- a/nucleus/errors.go
+++ b/nucleus/errors.go
@@ -47,3 +47,12 @@ type ErrNotYetImplemented struct{}
 func (e ErrNotYetImplemented) Error() string {
 	return fmt.Sprintf("function not yet implemented")
 }
+
+type ErrInvalidParameters struct {
+	f interface{}
+	r interface{}
+}
+
+func (e ErrInvalidParameters) Error() string {
+	return fmt.Sprintf("invalid parameters for %v: %v", e.f, e.r)
+}
\ No newline at end of file
diff --git a/nucleus/gnmi_transport.go b/nucleus/gnmi_transport.go
index 39fb4e6ba..1519721e2 100644
--- a/nucleus/gnmi_transport.go
+++ b/nucleus/gnmi_transport.go
@@ -47,11 +47,44 @@ func (g *Gnmi) Get(ctx context.Context, params ...string) (interface{}, error) {
 	paths := gnmi.SplitPaths(params)
 	return g.get(ctx, paths, "")
 }
-func (g *Gnmi) Set(ctx context.Context, params ...string) (interface{}, error) {
+
+// Set takes a slice of params. This slice must contain at least one operation.
+// It can contain an additional arbitrary amount of operations and extensions.
+func (g *Gnmi) Set(ctx context.Context, params ...interface{}) (interface{}, error) {
 	if g.client == nil {
 		return nil, &ErrNilClient{}
 	}
-	return nil, nil
+	if len(params) == 0 {
+		return nil, &ErrInvalidParameters{
+			f: "gnmi.Set()",
+			r: "no parameters provided",
+		}
+	}
+
+	// Loop over params and create ops and exts
+	// Invalid params cause unhealable error
+	ops := make([]*gnmi.Operation, 0)
+	exts := make([]*gnmi_ext.Extension, 0)
+	for _,p := range params{
+		switch p.(type) {
+		case *gnmi.Operation:
+			ops = append(ops, p.(*gnmi.Operation))
+		case *gnmi_ext.Extension:
+			exts = append(exts, p.(*gnmi_ext.Extension))
+		default:
+			return nil, &ErrInvalidParameters{
+				f: "gnmi.Set()",
+				r: "params contain invalid type",
+			}
+		}
+	}
+	if len(ops) == 0 {
+		return nil, &ErrInvalidParameters{
+			f: "gnmi.Set()",
+			r: "no operations provided",
+		}
+	}
+	return g.set(ctx, ops, exts...)
 }
 
 func (g *Gnmi) Subscribe(ctx context.Context, params ...string) error {
diff --git a/nucleus/gnmi_transport_test.go b/nucleus/gnmi_transport_test.go
index 5c1215d1d..d604ecc65 100644
--- a/nucleus/gnmi_transport_test.go
+++ b/nucleus/gnmi_transport_test.go
@@ -353,7 +353,7 @@ func TestGnmi_Set(t *testing.T) {
 		transport *Gnmi
 	}
 	type args struct {
-		params      []string
+		params      []interface{}
 		runEndpoint bool
 	}
 	tests := []struct {
diff --git a/nucleus/integration_test.go b/nucleus/integration_test.go
index caed6e9d9..b8633c1c0 100644
--- a/nucleus/integration_test.go
+++ b/nucleus/integration_test.go
@@ -31,22 +31,73 @@ func TestGnmi_SetIntegration(t *testing.T) {
 	if testing.Short() {
 		t.Skip("skipping integration test")
 	}
-	t.Run("Test GNMI Set", func(t *testing.T) {
-		transport, err := NewGnmiTransport(cfg)
-		if err != nil {
-			t.Error(err)
-		}
-		p := []string{"/interfaces/interface"}
-		resp, err := transport.Set(context.Background(), p...)
-		if err != nil {
-			t.Error(err)
-		}
-		if resp == nil {
-			t.Error("resp is nil")
-		}
-	})
+	type fields struct {
+		config *gnmi.Config
+	}
+	type args struct {
+		ctx    context.Context
+		params []interface{}
+	}
+	tests := []struct {
+		name    string
+		fields  fields
+		args    args
+		want    interface{}
+		wantErr bool
+	}{
+		{
+			name: "destination unreachable",
+			fields: fields{config: &gnmi.Config{
+				Addr: "203.0.113.10:6030",
+			},
+			},
+			args:    args{
+				ctx: context.Background(),
+				params: []interface{}{&gnmi.Operation{}},
+			},
+			want:    nil,
+			wantErr: true,
+		},
+		{
+			name:    "valid update",
+			fields:  fields{config: cfg},
+			args:    args{
+				ctx:    context.Background(),
+				params: []interface{}{
+					&gnmi.Operation{
+						Type:   "update",
+						Origin: "",
+						Target: "",
+						Path: []string{
+							"system",
+							"config",
+							"hostname",
+						},
+						Val: "ceos3000",
+					},
+				},
+			},
+			want:    &gpb.SetResponse{},
+			wantErr: false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			g, err := NewGnmiTransport(cfg)
+			if err != nil {
+				t.Error(err)
+			}
+			got, err := g.Set(tt.args.ctx, tt.args.params...)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("Set() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("Set() got = %v, want %v", got, tt.want)
+			}
+		})
+	}
 }
-
 func TestGnmi_GetIntegration(t *testing.T) {
 	if testing.Short() {
 		t.Skip("skipping integration test")
@@ -147,10 +198,10 @@ func TestGnmi_CapabilitiesIntegration(t *testing.T) {
 			wantErr: false,
 		},
 		{
-			name:	 "destination unreachable",
-			fields:  fields{config: &gnmi.Config{
+			name: "destination unreachable",
+			fields: fields{config: &gnmi.Config{
 				Addr: "203.0.113.10:6030",
-				},
+			},
 			},
 			args:    args{ctx: context.Background()},
 			want:    nil,
diff --git a/nucleus/transport.go b/nucleus/transport.go
index eb0aead92..5724cf86e 100644
--- a/nucleus/transport.go
+++ b/nucleus/transport.go
@@ -12,7 +12,7 @@ import (
 // or gnmi
 type Transport interface {
 	Get(ctx context.Context, params ...string) (interface{}, error)
-	Set(ctx context.Context, params ...string) (interface{}, error)
+	Set(ctx context.Context, params ...interface{}) (interface{}, error)
 	Subscribe(ctx context.Context, params ...string) error
 	Type() string
 	ProcessResponse(resp interface{}, root interface{}, models *ytypes.Schema) error
-- 
GitLab