diff --git a/examples/device/main.go b/examples/device/main.go new file mode 100644 index 0000000000000000000000000000000000000000..354817a18b59480ed24fa84b7ab234d601b938de --- /dev/null +++ b/examples/device/main.go @@ -0,0 +1,38 @@ +package main + +import ( + "fmt" + "os" + + "github.com/bio-routing/bio-rd/protocols/device" + log "github.com/sirupsen/logrus" +) + +// Client is a device protocol client +type Client struct { +} + +// DeviceUpdate is a callback to get updated device information +func (c *Client) DeviceUpdate(d *device.Device) { + fmt.Printf("Device Update! %s\n", d.Name) + fmt.Printf("New State: %v\n", d.OperState) +} + +func main() { + s, err := device.New() + if err != nil { + log.Errorf("%v", err) + os.Exit(1) + } + + err = s.Start() + if err != nil { + log.Errorf("%v", err) + os.Exit(1) + } + + c := &Client{} + s.Subscribe(c, "virbr0") + + select {} +} diff --git a/protocols/device/device.go b/protocols/device/device.go new file mode 100644 index 0000000000000000000000000000000000000000..8939ca531a6530cb576aea4b93c25f7f14057327 --- /dev/null +++ b/protocols/device/device.go @@ -0,0 +1,67 @@ +package device + +import ( + "net" + "sync" + + bnet "github.com/bio-routing/bio-rd/net" +) + +// Device represents a network device +type Device struct { + Name string + Index uint64 + MTU uint16 + HardwareAddr net.HardwareAddr + Flags net.Flags + OperState uint8 + Addrs []bnet.Prefix + l sync.RWMutex +} + +func newDevice() *Device { + return &Device{ + Addrs: make([]bnet.Prefix, 0), + } +} + +func (d *Device) addAddr(pfx bnet.Prefix) { + d.l.Lock() + defer d.l.Unlock() + + d.Addrs = append(d.Addrs, pfx) +} + +func (d *Device) delAddr(del bnet.Prefix) { + d.l.Lock() + defer d.l.Unlock() + + for i, pfx := range d.Addrs { + if !pfx.Equal(del) { + continue + } + + d.Addrs = append(d.Addrs[:i], d.Addrs[i+1:]...) + } +} + +func (d *Device) copy() *Device { + d.l.RLock() + defer d.l.RUnlock() + + n := &Device{ + Name: d.Name, + Index: d.Index, + MTU: d.MTU, + Flags: d.Flags, + OperState: d.OperState, + Addrs: make([]bnet.Prefix, len(d.Addrs)), + } + + copy(n.HardwareAddr, d.HardwareAddr) + for i, a := range d.Addrs { + n.Addrs[i] = a + } + + return n +} diff --git a/protocols/device/device_linux.go b/protocols/device/device_linux.go new file mode 100644 index 0000000000000000000000000000000000000000..fc802a79f99373ff62f8f2ee02d660c0dc00dfa5 --- /dev/null +++ b/protocols/device/device_linux.go @@ -0,0 +1,22 @@ +package device + +import "github.com/vishvananda/netlink" + +func (d *Device) updateLink(attrs *netlink.LinkAttrs) { + d.l.Lock() + defer d.l.Unlock() + + d.MTU = uint16(attrs.MTU) + d.Name = attrs.Name + copy(d.HardwareAddr, attrs.HardwareAddr) + d.Flags = attrs.Flags + d.OperState = uint8(attrs.OperState) +} + +func (d *Device) notify(clients []Client) { + n := d.copy() + + for _, c := range clients { + c.DeviceUpdate(n) + } +} diff --git a/protocols/device/device_test.go b/protocols/device/device_test.go new file mode 100644 index 0000000000000000000000000000000000000000..59ee490be0fdb74b98d72772c73728fe2e615bb6 --- /dev/null +++ b/protocols/device/device_test.go @@ -0,0 +1,117 @@ +package device + +import ( + "testing" + + bnet "github.com/bio-routing/bio-rd/net" + "github.com/stretchr/testify/assert" +) + +func TestDeviceCopy(t *testing.T) { + tests := []struct { + name string + dev *Device + expected *Device + }{ + { + name: "Test #1", + dev: &Device{ + Name: "Foo", + Addrs: []bnet.Prefix{ + bnet.NewPfx(bnet.IPv4(100), 8), + }, + }, + expected: &Device{ + Name: "Foo", + Addrs: []bnet.Prefix{ + bnet.NewPfx(bnet.IPv4(100), 8), + }, + }, + }, + } + + for _, test := range tests { + copy := test.dev.copy() + test.dev.addAddr(bnet.NewPfx(bnet.IPv4(200), 8)) + assert.Equalf(t, test.expected, copy, "Test %q", test.name) + } +} + +func TestDeviceDelAddr(t *testing.T) { + tests := []struct { + name string + dev *Device + delete bnet.Prefix + expected *Device + }{ + { + name: "Test #1", + dev: &Device{ + Addrs: []bnet.Prefix{ + bnet.NewPfx(bnet.IPv4(100), 8), + bnet.NewPfx(bnet.IPv4(200), 8), + bnet.NewPfx(bnet.IPv4(300), 8), + }, + }, + delete: bnet.NewPfx(bnet.IPv4(200), 8), + expected: &Device{ + Addrs: []bnet.Prefix{ + bnet.NewPfx(bnet.IPv4(100), 8), + bnet.NewPfx(bnet.IPv4(300), 8), + }, + }, + }, + { + name: "Test #2", + dev: &Device{ + Addrs: []bnet.Prefix{ + bnet.NewPfx(bnet.IPv4(100), 8), + bnet.NewPfx(bnet.IPv4(200), 8), + bnet.NewPfx(bnet.IPv4(300), 8), + }, + }, + delete: bnet.NewPfx(bnet.IPv4(100), 8), + expected: &Device{ + Addrs: []bnet.Prefix{ + bnet.NewPfx(bnet.IPv4(200), 8), + bnet.NewPfx(bnet.IPv4(300), 8), + }, + }, + }, + } + + for _, test := range tests { + test.dev.delAddr(test.delete) + assert.Equalf(t, test.expected, test.dev, "Test %q", test.name) + } +} + +func TestDeviceAddAddr(t *testing.T) { + tests := []struct { + name string + dev *Device + input bnet.Prefix + expected *Device + }{ + { + name: "Test #1", + dev: &Device{ + Addrs: []bnet.Prefix{ + bnet.NewPfx(bnet.IPv4(100), 8), + }, + }, + input: bnet.NewPfx(bnet.IPv4(200), 8), + expected: &Device{ + Addrs: []bnet.Prefix{ + bnet.NewPfx(bnet.IPv4(100), 8), + bnet.NewPfx(bnet.IPv4(200), 8), + }, + }, + }, + } + + for _, test := range tests { + test.dev.addAddr(test.input) + assert.Equalf(t, test.expected, test.dev, "Test %q", test.name) + } +} diff --git a/protocols/device/server.go b/protocols/device/server.go new file mode 100644 index 0000000000000000000000000000000000000000..22eeb29d3914d9cf0ec100b4970ae576c406af1e --- /dev/null +++ b/protocols/device/server.go @@ -0,0 +1,118 @@ +package device + +import ( + "fmt" + "sync" +) + +// Server represents a device server +type Server struct { + devices map[uint64]*Device + devicesMu sync.RWMutex + clientsByDevice map[string][]Client + clientsByDeviceMu sync.RWMutex + osAdapter osAdapter + done chan struct{} +} + +// Client represents a client of the device server +type Client interface { + DeviceUpdate(*Device) +} + +type osAdapter interface { + start() error +} + +// New creates a new device server +func New() (*Server, error) { + srv := newWithAdapter(nil) + err := srv.loadAdapter() + if err != nil { + return nil, fmt.Errorf("Unable to create OS adapter: %v", err) + } + + return srv, nil +} + +func newWithAdapter(a osAdapter) *Server { + return &Server{ + devices: make(map[uint64]*Device), + clientsByDevice: make(map[string][]Client), + osAdapter: a, + done: make(chan struct{}), + } +} + +// Start starts the device server +func (ds *Server) Start() error { + err := ds.osAdapter.start() + if err != nil { + return fmt.Errorf("Unable to start osAdapter: %v", err) + } + + return nil +} + +// Stop stops the device server +func (ds *Server) Stop() { + close(ds.done) +} + +// Subscribe allows a client to subscribe for status updates on interface `devName` +func (ds *Server) Subscribe(client Client, devName string) { + d := ds.getLinkState(devName) + if d != nil { + client.DeviceUpdate(d) + } + + ds.clientsByDeviceMu.RLock() + defer ds.clientsByDeviceMu.RUnlock() + + if _, ok := ds.clientsByDevice[devName]; !ok { + ds.clientsByDevice[devName] = make([]Client, 0) + } + + ds.clientsByDevice[devName] = append(ds.clientsByDevice[devName], client) +} + +func (ds *Server) addDevice(d *Device) { + ds.devicesMu.Lock() + defer ds.devicesMu.Unlock() + + ds.devices[d.Index] = d +} + +func (ds *Server) delDevice(index uint64) { + delete(ds.devices, index) +} + +func (ds *Server) getLinkState(name string) *Device { + ds.devicesMu.RLock() + defer ds.devicesMu.RUnlock() + + for _, d := range ds.devices { + if d.Name != name { + continue + } + + return d.copy() + } + + return nil +} + +func (ds *Server) notify(index uint64) { + ds.clientsByDeviceMu.RLock() + defer ds.clientsByDeviceMu.RUnlock() + + for i, d := range ds.devices { + if i != index { + continue + } + + for _, c := range ds.clientsByDevice[d.Name] { + c.DeviceUpdate(d.copy()) + } + } +} diff --git a/protocols/device/server_darwin.go b/protocols/device/server_darwin.go new file mode 100644 index 0000000000000000000000000000000000000000..364843f7dd729f08a114d6e24e6b62a432024ec0 --- /dev/null +++ b/protocols/device/server_darwin.go @@ -0,0 +1,14 @@ +package device + +import "fmt" + +type osAdapterDarwin struct { +} + +func newOSAdapterDarwin(srv *Server) (*osAdapterDarwin, error) { + return nil, nil +} + +func (o *osAdapterDarwin) start() error { + return fmt.Errorf("Not implemented") +} diff --git a/protocols/device/server_linux.go b/protocols/device/server_linux.go new file mode 100644 index 0000000000000000000000000000000000000000..fa344d817aba04c01af0c433996c14b04c5ec729 --- /dev/null +++ b/protocols/device/server_linux.go @@ -0,0 +1,162 @@ +package device + +import ( + "fmt" + + bnet "github.com/bio-routing/bio-rd/net" + log "github.com/sirupsen/logrus" + "github.com/vishvananda/netlink" +) + +func (ds *Server) loadAdapter() error { + a, err := newOSAdapterLinux(ds) + if err != nil { + return fmt.Errorf("Unable to create linux adapter: %v", err) + } + + ds.osAdapter = a + return nil +} + +type osAdapterLinux struct { + srv *Server + handle *netlink.Handle + done chan struct{} +} + +func newOSAdapterLinux(srv *Server) (*osAdapterLinux, error) { + o := &osAdapterLinux{ + srv: srv, + } + + h, err := netlink.NewHandle() + if err != nil { + return nil, fmt.Errorf("Failed to create netlink handle: %v", err) + } + + o.handle = h + return o, nil +} + +func (o *osAdapterLinux) start() error { + chLU := make(chan netlink.LinkUpdate) + err := netlink.LinkSubscribe(chLU, o.done) + if err != nil { + return fmt.Errorf("Unable to subscribe for link updates: %v", err) + } + + chAU := make(chan netlink.AddrUpdate) + err = netlink.AddrSubscribe(chAU, o.done) + if err != nil { + return fmt.Errorf("Unable to subscribe for address updates: %v", err) + } + + err = o.init() + if err != nil { + return fmt.Errorf("Init failed: %v", err) + } + + go o.monitorLinks(chLU) + go o.monitorAddrs(chAU) + + return nil +} + +func (o *osAdapterLinux) init() error { + links, err := o.handle.LinkList() + if err != nil { + return fmt.Errorf("Unable to get links: %v", err) + } + + for _, l := range links { + d := linkUpdateToDevice(l.Attrs()) + + for _, f := range []int{4, 6} { + addrs, err := o.handle.AddrList(l, f) + if err != nil { + return fmt.Errorf("Unable to get addresses for interface %s: %v", d.Name, err) + } + + for _, addr := range addrs { + d.Addrs = append(d.Addrs, bnet.NewPfxFromIPNet(addr.IPNet)) + } + } + + o.srv.addDevice(d) + } + + return nil +} + +func (o *osAdapterLinux) monitorAddrs(chAU chan netlink.AddrUpdate) { + for { + select { + case <-o.done: + return + case au := <-chAU: + o.processAddrUpdate(&au) + } + } +} + +func (o *osAdapterLinux) monitorLinks(chLU chan netlink.LinkUpdate) { + for { + select { + case <-o.done: + return + case lu := <-chLU: + o.processLinkUpdate(&lu) + } + } + + return +} + +func linkUpdateToDevice(attrs *netlink.LinkAttrs) *Device { + return &Device{ + Index: uint64(attrs.Index), + MTU: uint16(attrs.MTU), + Name: attrs.Name, + HardwareAddr: attrs.HardwareAddr, + Flags: attrs.Flags, + OperState: uint8(attrs.OperState), + } +} + +func (o *osAdapterLinux) processAddrUpdate(au *netlink.AddrUpdate) { + o.srv.devicesMu.RLock() + defer o.srv.devicesMu.RUnlock() + + if _, ok := o.srv.devices[uint64(au.LinkIndex)]; !ok { + log.Warningf("Received address update for non existent device index %d", au.LinkIndex) + return + } + + d := o.srv.devices[uint64(au.LinkIndex)] + if au.NewAddr { + d.addAddr(bnet.NewPfxFromIPNet(&au.LinkAddress)) + return + } + + d.delAddr(bnet.NewPfxFromIPNet(&au.LinkAddress)) +} + +func (o *osAdapterLinux) processLinkUpdate(lu *netlink.LinkUpdate) { + attrs := lu.Attrs() + + o.srv.devicesMu.Lock() + defer o.srv.devicesMu.Unlock() + + if _, ok := o.srv.devices[uint64(attrs.Index)]; !ok { + d := newDevice() + d.Index = uint64(attrs.Index) + o.srv.addDevice(d) + } + + o.srv.devices[uint64(attrs.Index)].updateLink(attrs) + o.srv.notify(uint64(attrs.Index)) + if attrs.OperState == netlink.OperNotPresent { + o.srv.delDevice(uint64(attrs.Index)) + return + } +} diff --git a/protocols/device/server_test.go b/protocols/device/server_test.go new file mode 100644 index 0000000000000000000000000000000000000000..87a0f70e2bbaff47682b964e787a0865c560a08b --- /dev/null +++ b/protocols/device/server_test.go @@ -0,0 +1,119 @@ +package device + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +type mockAdapter struct { + started bool + startFail bool +} + +func (m *mockAdapter) start() error { + m.started = true + if m.startFail { + return fmt.Errorf("Fail") + } + + return nil +} + +func (m *mockAdapter) loadAdapter() error { + return nil +} + +func TestStart(t *testing.T) { + tests := []struct { + name string + adapter *mockAdapter + wantFail bool + expected *mockAdapter + }{ + { + name: "Test with failure", + adapter: &mockAdapter{ + startFail: true, + }, + wantFail: true, + }, + { + name: "Test with success", + adapter: &mockAdapter{}, + wantFail: false, + expected: &mockAdapter{ + started: true, + }, + }, + } + + for _, test := range tests { + s := newWithAdapter(test.adapter) + err := s.Start() + if err != nil { + if test.wantFail { + continue + } + + t.Errorf("Unexpected failure for test %q: %v", test.name, err) + continue + } + + if test.wantFail { + t.Errorf("Unexpected success for test %q", test.name) + continue + } + + assert.Equalf(t, test.expected, test.adapter, "Test %q", test.name) + } +} + +func TestStop(t *testing.T) { + a := &mockAdapter{} + s := newWithAdapter(a) + s.Stop() + + // This will cause a timeout if channel was not closed + <-s.done +} + +type mockClient struct { + deviceUpdateCalled uint +} + +func (m *mockClient) DeviceUpdate(d *Device) { + m.deviceUpdateCalled++ +} + +func TestNotify(t *testing.T) { + mc := &mockClient{} + a := &mockAdapter{} + s := newWithAdapter(a) + + s.addDevice(&Device{ + Name: "eth0", + Index: 100, + }) + s.addDevice(&Device{ + Name: "eth1", + Index: 101, + }) + s.addDevice(&Device{ + Name: "eth2", + Index: 102, + }) + + s.Subscribe(mc, "eth1") + assert.Equal(t, uint(1), mc.deviceUpdateCalled) + s.notify(100) + assert.Equal(t, uint(1), mc.deviceUpdateCalled) + + s.notify(101) + assert.Equal(t, uint(2), mc.deviceUpdateCalled) + + s.delDevice(101) + s.notify(101) + assert.Equal(t, uint(2), mc.deviceUpdateCalled) +}