Skip to content
Snippets Groups Projects
change_test.go 3.5 KiB
Newer Older
  • Learn to ignore specific revisions
  • Manuel Kieweg's avatar
    Manuel Kieweg committed
    package pnd
    
    import (
    	"context"
    	"github.com/google/uuid"
    	"github.com/openconfig/ygot/exampleoc"
    	"github.com/openconfig/ygot/ygot"
    	"os"
    	"reflect"
    	"sync"
    	"testing"
    	"time"
    )
    
    // Hostname markers. The test callbacks report which device state they were
    // handed by reading its System.Hostname. These are variables rather than
    // constants because the device fixtures store their addresses.
    var (
    	commit   = "commit"
    	rollback = "rollback"
    )
    
    // Device fixtures that differ only in their hostname, so a callback can tell
    // them apart by reading System.Hostname.
    var (
    	// commitDevice carries the "commit" hostname; TestChange_Commit uses it
    	// as the change's intended state.
    	commitDevice = &exampleoc.Device{
    		System: &exampleoc.System{Hostname: &commit},
    	}
    	// rollbackDevice carries the "rollback" hostname; TestChange_Commit uses
    	// it as the change's previous state.
    	rollbackDevice = &exampleoc.Device{
    		System: &exampleoc.System{Hostname: &rollback},
    	}
    )
    
    // TestChange_Commit verifies that Commit hands the intended state to the
    // change's callback. For the rollback case it also verifies that an
    // unconfirmed change later invokes the callback again with its previous
    // state.
    //
    // The callback identifies the state it received by the device hostname
    // ("commit" or "rollback").
    func TestChange_Commit(t *testing.T) {
    	tests := []struct {
    		name    string
    		want    string
    		wantErr bool
    	}{
    		{
    			name:    commit,
    			want:    commit,
    			wantErr: false,
    		},
    		{
    			name:    rollback,
    			want:    rollback,
    			wantErr: false,
    		},
    	}
    	for _, tt := range tests {
    		t.Run(tt.name, func(t *testing.T) {
    			if tt.name == rollback {
    				// The rollback case waits for the change to time out. This
    				// test has always required GOSDN_CHANGE_TIMEOUT to be a valid
    				// duration (presumably the rollback timeout used by Change;
    				// confirm). Fail fast instead of waiting on a callback that may
    				// never arrive.
    				if _, err := time.ParseDuration(os.Getenv("GOSDN_CHANGE_TIMEOUT")); err != nil {
    					t.Fatalf("GOSDN_CHANGE_TIMEOUT: %v", err)
    				}
    			}
    			// Each subtest owns its channel, sized for the commit callback plus
    			// the rollback callback. The "commit" change is never confirmed, so
    			// it can still roll back after this subtest returns. The buffer lets
    			// that late send complete without blocking forever, without being
    			// read by another subtest, and without hitting a closed channel.
    			callback := make(chan string, 2)
    			c := &Change{
    				cuid:          changeUUID,
    				duid:          did,
    				timestamp:     time.Now(),
    				previousState: rollbackDevice,
    				intendedState: commitDevice,
    				// The callback may fire after the subtest has finished, so it
    				// must not touch t. It only forwards the hostname it received.
    				callback: func(first ygot.GoStruct, second ygot.GoStruct) error {
    					callback <- *first.(*exampleoc.Device).System.Hostname
    					return nil
    				},
    				lock: sync.RWMutex{},
    			}
    			// Commit runs on the test goroutine. The buffered channel means a
    			// synchronous callback cannot deadlock here.
    			err := c.Commit()
    			if (err != nil) != tt.wantErr {
    				t.Fatalf("Commit() error = %v, wantErr %v", err, tt.wantErr)
    			}
    			if err != nil {
    				return // expected failure; nothing further to check
    			}
    			got := <-callback
    			t.Logf("hostname: %v", got)
    			if tt.name == rollback {
    				// The first callback carried the committed state. The second
    				// one arrives once the unconfirmed change is rolled back.
    				got = <-callback
    				t.Logf("hostname: %v", got)
    			}
    			if got != tt.want {
    				t.Errorf("Commit() = %v, want %v", got, tt.want)
    			}
    		})
    	}
    }
    
    // TestChange_Confirm checks that confirming a committed change succeeds,
    // while confirming a change that was never committed returns an error.
    func TestChange_Confirm(t *testing.T) {
    	// Only the cancel function is needed; every Change below receives it as
    	// its cancelFunc.
    	_, cancel := context.WithCancel(context.Background())
    	defer cancel()
    	type fields struct {
    		cuid          uuid.UUID
    		duid          uuid.UUID
    		timestamp     time.Time
    		previousState ygot.GoStruct
    		intendedState ygot.GoStruct
    		callback      func(ygot.GoStruct, ygot.GoStruct) error
    		committed     bool
    	}
    	cases := []struct {
    		name    string
    		fields  fields
    		wantErr bool
    	}{
    		{name: "committed", fields: fields{committed: true}, wantErr: false},
    		{name: "uncommitted", fields: fields{}, wantErr: true},
    	}
    	for _, tc := range cases {
    		t.Run(tc.name, func(t *testing.T) {
    			// Fresh device states per subtest: "rollback" as the previous
    			// hostname, "commit" as the intended one.
    			previous := &exampleoc.Device{System: &exampleoc.System{Hostname: &rollback}}
    			intended := &exampleoc.Device{System: &exampleoc.System{Hostname: &commit}}
    			change := &Change{
    				cuid:          tc.fields.cuid,
    				duid:          tc.fields.duid,
    				timestamp:     tc.fields.timestamp,
    				previousState: previous,
    				intendedState: intended,
    				callback:      tc.fields.callback,
    				committed:     tc.fields.committed,
    				cancelFunc:    cancel,
    				lock:          sync.RWMutex{},
    			}
    			err := change.Confirm()
    			if gotErr := err != nil; gotErr != tc.wantErr {
    				t.Errorf("Confirm() error = %v, wantErr %v", err, tc.wantErr)
    			}
    		})
    	}
    }
    
    // TestChange_ID checks that ID reports the change UUID stored in the
    // Change's cuid field.
    func TestChange_ID(t *testing.T) {
    	type fields struct {
    		cuid uuid.UUID
    	}
    	for _, tc := range []struct {
    		name   string
    		fields fields
    		want   uuid.UUID
    	}{
    		{name: "default", fields: fields{cuid: changeUUID}, want: changeUUID},
    	} {
    		t.Run(tc.name, func(t *testing.T) {
    			got := (&Change{cuid: tc.fields.cuid}).ID()
    			if !reflect.DeepEqual(got, tc.want) {
    				t.Errorf("ID() = %v, want %v", got, tc.want)
    			}
    		})
    	}
    }