package proto

import (
	"os"
	"reflect"
	"strings"
	"testing"

	gpb "github.com/openconfig/gnmi/proto/gnmi"
	pb "google.golang.org/protobuf/proto"
)

// gnmiPaths maps the path of each gNMI fixture file used by this file's tests
// to a message of the type that the fixture holds. TestMain populates it before
// any test runs; the tests then load the fixture contents into these messages.
var gnmiPaths map[string]pb.Message

// TestMain registers the gNMI fixture files under ../../../test/proto, then runs
// the package tests and exits with their status code.
func TestMain(m *testing.M) {
	// All fixtures live in the same directory; only the file names differ.
	const dir = "../../../test/proto/"
	gnmiPaths = map[string]pb.Message{
		dir + "cap-resp-arista-ceos":                  &gpb.CapabilityResponse{},
		dir + "req-full-node":                         &gpb.GetRequest{},
		dir + "req-full-node-arista-ceos":             &gpb.GetRequest{},
		dir + "req-interfaces-arista-ceos":            &gpb.GetRequest{},
		dir + "req-interfaces-interface-arista-ceos":  &gpb.GetRequest{},
		dir + "req-interfaces-wildcard":               &gpb.GetRequest{},
		dir + "resp-full-node":                        &gpb.GetResponse{},
		dir + "resp-full-node-arista-ceos":            &gpb.GetResponse{},
		dir + "resp-interfaces-arista-ceos":           &gpb.GetResponse{},
		dir + "resp-interfaces-interface-arista-ceos": &gpb.GetResponse{},
		dir + "resp-interfaces-wildcard":              &gpb.GetResponse{},
		dir + "resp-set-system-config-hostname":       &gpb.SetResponse{},
	}
	code := m.Run()
	os.Exit(code)
}

// TestRead loads every fixture listed in gnmiPaths with Read and checks that
// loading succeeds and yields a message of the fixture's registered type.
//
// Each subtest reads into a freshly allocated, empty message of the same type
// as the gnmiPaths entry, so this test does not mutate the shared package-level
// fixtures that other tests also use.
func TestRead(t *testing.T) {
	type args struct {
		filename string
		message  pb.Message
	}
	type test struct {
		name    string
		args    args
		want    reflect.Type
		wantErr bool
	}
	tests := make([]test, 0, len(gnmiPaths))
	for k, v := range gnmiPaths {
		// Name the subtest after the fixture file (last path element). Unlike
		// indexing a fixed split position, this does not depend on how many
		// directories precede the file name.
		name := k[strings.LastIndex(k, "/")+1:]
		tests = append(tests, test{
			name: name,
			args: args{
				filename: k,
				// New empty message of v's concrete type; v itself is left untouched.
				message: v.ProtoReflect().New().Interface(),
			},
			want:    reflect.TypeOf(v),
			wantErr: false,
		})
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := Read(tt.args.filename, tt.args.message); (err != nil) != tt.wantErr {
				t.Errorf("Read() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			got := reflect.TypeOf(tt.args.message)
			if got != tt.want {
				t.Errorf("Read() got Type %v, want Type %v", got, tt.want)
			}
		})
	}
}

// TestWrite round-trips every fixture in gnmiPaths and runs in three steps:
//  1. It loads each fixture with Read.
//  2. It serializes the fixture with Write into a per-subtest temporary directory.
//  3. It reads the written file back into a fresh message of the same type and
//     requires that message to be proto-equal to the original.
func TestWrite(t *testing.T) {
	for k, v := range gnmiPaths {
		if err := Read(k, v); err != nil {
			// Every comparison below depends on the fixtures being loaded.
			t.Fatalf("Read(%q) error = %v", k, err)
		}
	}
	type args struct {
		message  pb.Message
		filename string
	}
	type test struct {
		name    string
		args    args
		want    pb.Message
		wantErr bool
	}
	tests := make([]test, 0, len(gnmiPaths))
	for k, v := range gnmiPaths {
		// Name the subtest after the fixture file (last path element),
		// independent of how deep the fixture directory is.
		name := k[strings.LastIndex(k, "/")+1:]
		tests = append(tests, test{
			name: name,
			args: args{
				message:  v,
				filename: name + "_test",
			},
			want:    v,
			wantErr: false,
		})
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Write into a temporary directory that the testing package removes
			// afterwards, so no output files are left in the package directory.
			filename := t.TempDir() + string(os.PathSeparator) + tt.args.filename
			err := Write(tt.args.message, filename)
			if (err != nil) != tt.wantErr {
				t.Fatalf("Write() error = %v, wantErr %v", err, tt.wantErr)
			}
			if err != nil {
				// Expected failure: there is no file to read back.
				return
			}
			// Fresh empty message of the same concrete type as the expected one;
			// works for any proto message type without a hand-written type switch.
			got := tt.want.ProtoReflect().New().Interface()
			if err := Read(filename, got); err != nil {
				t.Fatalf("Read(%q) error = %v", filename, err)
			}
			// pb.Equal compares message contents; reflect.DeepEqual would also
			// compare internal protobuf bookkeeping state and is unreliable here.
			if !pb.Equal(got, tt.want) {
				t.Errorf("Write() round trip got %v, want %v", got, tt.want)
			}
		})
	}
}
