Skip to content
Snippets Groups Projects
Commit 29d24891 authored by takt's avatar takt Committed by Daniel Czerwonk
Browse files

Protocol Device (#147)

parent 99207738
No related branches found
No related tags found
No related merge requests found
package main
import (
"fmt"
"os"
"github.com/bio-routing/bio-rd/protocols/device"
log "github.com/sirupsen/logrus"
)
// Client is a device protocol client
type Client struct {
}
// DeviceUpdate is a callback to get updated device information
func (c *Client) DeviceUpdate(d *device.Device) {
fmt.Printf("Device Update! %s\n", d.Name)
fmt.Printf("New State: %v\n", d.OperState)
}
func main() {
s, err := device.New()
if err != nil {
log.Errorf("%v", err)
os.Exit(1)
}
err = s.Start()
if err != nil {
log.Errorf("%v", err)
os.Exit(1)
}
c := &Client{}
s.Subscribe(c, "virbr0")
select {}
}
package device
import (
"net"
"sync"
bnet "github.com/bio-routing/bio-rd/net"
)
// Device represents a network device
type Device struct {
Name string
Index uint64
MTU uint16
HardwareAddr net.HardwareAddr
Flags net.Flags
OperState uint8
Addrs []bnet.Prefix
l sync.RWMutex
}
func newDevice() *Device {
return &Device{
Addrs: make([]bnet.Prefix, 0),
}
}
func (d *Device) addAddr(pfx bnet.Prefix) {
d.l.Lock()
defer d.l.Unlock()
d.Addrs = append(d.Addrs, pfx)
}
func (d *Device) delAddr(del bnet.Prefix) {
d.l.Lock()
defer d.l.Unlock()
for i, pfx := range d.Addrs {
if !pfx.Equal(del) {
continue
}
d.Addrs = append(d.Addrs[:i], d.Addrs[i+1:]...)
}
}
func (d *Device) copy() *Device {
d.l.RLock()
defer d.l.RUnlock()
n := &Device{
Name: d.Name,
Index: d.Index,
MTU: d.MTU,
Flags: d.Flags,
OperState: d.OperState,
Addrs: make([]bnet.Prefix, len(d.Addrs)),
}
copy(n.HardwareAddr, d.HardwareAddr)
for i, a := range d.Addrs {
n.Addrs[i] = a
}
return n
}
package device
import "github.com/vishvananda/netlink"
func (d *Device) updateLink(attrs *netlink.LinkAttrs) {
d.l.Lock()
defer d.l.Unlock()
d.MTU = uint16(attrs.MTU)
d.Name = attrs.Name
copy(d.HardwareAddr, attrs.HardwareAddr)
d.Flags = attrs.Flags
d.OperState = uint8(attrs.OperState)
}
func (d *Device) notify(clients []Client) {
n := d.copy()
for _, c := range clients {
c.DeviceUpdate(n)
}
}
package device
import (
"testing"
bnet "github.com/bio-routing/bio-rd/net"
"github.com/stretchr/testify/assert"
)
func TestDeviceCopy(t *testing.T) {
tests := []struct {
name string
dev *Device
expected *Device
}{
{
name: "Test #1",
dev: &Device{
Name: "Foo",
Addrs: []bnet.Prefix{
bnet.NewPfx(bnet.IPv4(100), 8),
},
},
expected: &Device{
Name: "Foo",
Addrs: []bnet.Prefix{
bnet.NewPfx(bnet.IPv4(100), 8),
},
},
},
}
for _, test := range tests {
copy := test.dev.copy()
test.dev.addAddr(bnet.NewPfx(bnet.IPv4(200), 8))
assert.Equalf(t, test.expected, copy, "Test %q", test.name)
}
}
func TestDeviceDelAddr(t *testing.T) {
tests := []struct {
name string
dev *Device
delete bnet.Prefix
expected *Device
}{
{
name: "Test #1",
dev: &Device{
Addrs: []bnet.Prefix{
bnet.NewPfx(bnet.IPv4(100), 8),
bnet.NewPfx(bnet.IPv4(200), 8),
bnet.NewPfx(bnet.IPv4(300), 8),
},
},
delete: bnet.NewPfx(bnet.IPv4(200), 8),
expected: &Device{
Addrs: []bnet.Prefix{
bnet.NewPfx(bnet.IPv4(100), 8),
bnet.NewPfx(bnet.IPv4(300), 8),
},
},
},
{
name: "Test #2",
dev: &Device{
Addrs: []bnet.Prefix{
bnet.NewPfx(bnet.IPv4(100), 8),
bnet.NewPfx(bnet.IPv4(200), 8),
bnet.NewPfx(bnet.IPv4(300), 8),
},
},
delete: bnet.NewPfx(bnet.IPv4(100), 8),
expected: &Device{
Addrs: []bnet.Prefix{
bnet.NewPfx(bnet.IPv4(200), 8),
bnet.NewPfx(bnet.IPv4(300), 8),
},
},
},
}
for _, test := range tests {
test.dev.delAddr(test.delete)
assert.Equalf(t, test.expected, test.dev, "Test %q", test.name)
}
}
func TestDeviceAddAddr(t *testing.T) {
tests := []struct {
name string
dev *Device
input bnet.Prefix
expected *Device
}{
{
name: "Test #1",
dev: &Device{
Addrs: []bnet.Prefix{
bnet.NewPfx(bnet.IPv4(100), 8),
},
},
input: bnet.NewPfx(bnet.IPv4(200), 8),
expected: &Device{
Addrs: []bnet.Prefix{
bnet.NewPfx(bnet.IPv4(100), 8),
bnet.NewPfx(bnet.IPv4(200), 8),
},
},
},
}
for _, test := range tests {
test.dev.addAddr(test.input)
assert.Equalf(t, test.expected, test.dev, "Test %q", test.name)
}
}
package device
import (
"fmt"
"sync"
)
// Server represents a device server
type Server struct {
devices map[uint64]*Device
devicesMu sync.RWMutex
clientsByDevice map[string][]Client
clientsByDeviceMu sync.RWMutex
osAdapter osAdapter
done chan struct{}
}
// Client represents a client of the device server
type Client interface {
DeviceUpdate(*Device)
}
type osAdapter interface {
start() error
}
// New creates a new device server
func New() (*Server, error) {
srv := newWithAdapter(nil)
err := srv.loadAdapter()
if err != nil {
return nil, fmt.Errorf("Unable to create OS adapter: %v", err)
}
return srv, nil
}
func newWithAdapter(a osAdapter) *Server {
return &Server{
devices: make(map[uint64]*Device),
clientsByDevice: make(map[string][]Client),
osAdapter: a,
done: make(chan struct{}),
}
}
// Start starts the device server
func (ds *Server) Start() error {
err := ds.osAdapter.start()
if err != nil {
return fmt.Errorf("Unable to start osAdapter: %v", err)
}
return nil
}
// Stop stops the device server
func (ds *Server) Stop() {
close(ds.done)
}
// Subscribe allows a client to subscribe for status updates on interface `devName`
func (ds *Server) Subscribe(client Client, devName string) {
d := ds.getLinkState(devName)
if d != nil {
client.DeviceUpdate(d)
}
ds.clientsByDeviceMu.RLock()
defer ds.clientsByDeviceMu.RUnlock()
if _, ok := ds.clientsByDevice[devName]; !ok {
ds.clientsByDevice[devName] = make([]Client, 0)
}
ds.clientsByDevice[devName] = append(ds.clientsByDevice[devName], client)
}
func (ds *Server) addDevice(d *Device) {
ds.devicesMu.Lock()
defer ds.devicesMu.Unlock()
ds.devices[d.Index] = d
}
func (ds *Server) delDevice(index uint64) {
delete(ds.devices, index)
}
func (ds *Server) getLinkState(name string) *Device {
ds.devicesMu.RLock()
defer ds.devicesMu.RUnlock()
for _, d := range ds.devices {
if d.Name != name {
continue
}
return d.copy()
}
return nil
}
func (ds *Server) notify(index uint64) {
ds.clientsByDeviceMu.RLock()
defer ds.clientsByDeviceMu.RUnlock()
for i, d := range ds.devices {
if i != index {
continue
}
for _, c := range ds.clientsByDevice[d.Name] {
c.DeviceUpdate(d.copy())
}
}
}
package device
import "fmt"
type osAdapterDarwin struct {
}
func newOSAdapterDarwin(srv *Server) (*osAdapterDarwin, error) {
return nil, nil
}
func (o *osAdapterDarwin) start() error {
return fmt.Errorf("Not implemented")
}
package device
import (
"fmt"
bnet "github.com/bio-routing/bio-rd/net"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
)
func (ds *Server) loadAdapter() error {
a, err := newOSAdapterLinux(ds)
if err != nil {
return fmt.Errorf("Unable to create linux adapter: %v", err)
}
ds.osAdapter = a
return nil
}
type osAdapterLinux struct {
srv *Server
handle *netlink.Handle
done chan struct{}
}
func newOSAdapterLinux(srv *Server) (*osAdapterLinux, error) {
o := &osAdapterLinux{
srv: srv,
}
h, err := netlink.NewHandle()
if err != nil {
return nil, fmt.Errorf("Failed to create netlink handle: %v", err)
}
o.handle = h
return o, nil
}
func (o *osAdapterLinux) start() error {
chLU := make(chan netlink.LinkUpdate)
err := netlink.LinkSubscribe(chLU, o.done)
if err != nil {
return fmt.Errorf("Unable to subscribe for link updates: %v", err)
}
chAU := make(chan netlink.AddrUpdate)
err = netlink.AddrSubscribe(chAU, o.done)
if err != nil {
return fmt.Errorf("Unable to subscribe for address updates: %v", err)
}
err = o.init()
if err != nil {
return fmt.Errorf("Init failed: %v", err)
}
go o.monitorLinks(chLU)
go o.monitorAddrs(chAU)
return nil
}
func (o *osAdapterLinux) init() error {
links, err := o.handle.LinkList()
if err != nil {
return fmt.Errorf("Unable to get links: %v", err)
}
for _, l := range links {
d := linkUpdateToDevice(l.Attrs())
for _, f := range []int{4, 6} {
addrs, err := o.handle.AddrList(l, f)
if err != nil {
return fmt.Errorf("Unable to get addresses for interface %s: %v", d.Name, err)
}
for _, addr := range addrs {
d.Addrs = append(d.Addrs, bnet.NewPfxFromIPNet(addr.IPNet))
}
}
o.srv.addDevice(d)
}
return nil
}
func (o *osAdapterLinux) monitorAddrs(chAU chan netlink.AddrUpdate) {
for {
select {
case <-o.done:
return
case au := <-chAU:
o.processAddrUpdate(&au)
}
}
}
func (o *osAdapterLinux) monitorLinks(chLU chan netlink.LinkUpdate) {
for {
select {
case <-o.done:
return
case lu := <-chLU:
o.processLinkUpdate(&lu)
}
}
return
}
func linkUpdateToDevice(attrs *netlink.LinkAttrs) *Device {
return &Device{
Index: uint64(attrs.Index),
MTU: uint16(attrs.MTU),
Name: attrs.Name,
HardwareAddr: attrs.HardwareAddr,
Flags: attrs.Flags,
OperState: uint8(attrs.OperState),
}
}
func (o *osAdapterLinux) processAddrUpdate(au *netlink.AddrUpdate) {
o.srv.devicesMu.RLock()
defer o.srv.devicesMu.RUnlock()
if _, ok := o.srv.devices[uint64(au.LinkIndex)]; !ok {
log.Warningf("Received address update for non existent device index %d", au.LinkIndex)
return
}
d := o.srv.devices[uint64(au.LinkIndex)]
if au.NewAddr {
d.addAddr(bnet.NewPfxFromIPNet(&au.LinkAddress))
return
}
d.delAddr(bnet.NewPfxFromIPNet(&au.LinkAddress))
}
func (o *osAdapterLinux) processLinkUpdate(lu *netlink.LinkUpdate) {
attrs := lu.Attrs()
o.srv.devicesMu.Lock()
defer o.srv.devicesMu.Unlock()
if _, ok := o.srv.devices[uint64(attrs.Index)]; !ok {
d := newDevice()
d.Index = uint64(attrs.Index)
o.srv.addDevice(d)
}
o.srv.devices[uint64(attrs.Index)].updateLink(attrs)
o.srv.notify(uint64(attrs.Index))
if attrs.OperState == netlink.OperNotPresent {
o.srv.delDevice(uint64(attrs.Index))
return
}
}
package device
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
type mockAdapter struct {
started bool
startFail bool
}
func (m *mockAdapter) start() error {
m.started = true
if m.startFail {
return fmt.Errorf("Fail")
}
return nil
}
func (m *mockAdapter) loadAdapter() error {
return nil
}
func TestStart(t *testing.T) {
tests := []struct {
name string
adapter *mockAdapter
wantFail bool
expected *mockAdapter
}{
{
name: "Test with failure",
adapter: &mockAdapter{
startFail: true,
},
wantFail: true,
},
{
name: "Test with success",
adapter: &mockAdapter{},
wantFail: false,
expected: &mockAdapter{
started: true,
},
},
}
for _, test := range tests {
s := newWithAdapter(test.adapter)
err := s.Start()
if err != nil {
if test.wantFail {
continue
}
t.Errorf("Unexpected failure for test %q: %v", test.name, err)
continue
}
if test.wantFail {
t.Errorf("Unexpected success for test %q", test.name)
continue
}
assert.Equalf(t, test.expected, test.adapter, "Test %q", test.name)
}
}
func TestStop(t *testing.T) {
a := &mockAdapter{}
s := newWithAdapter(a)
s.Stop()
// This will cause a timeout if channel was not closed
<-s.done
}
type mockClient struct {
deviceUpdateCalled uint
}
func (m *mockClient) DeviceUpdate(d *Device) {
m.deviceUpdateCalled++
}
func TestNotify(t *testing.T) {
mc := &mockClient{}
a := &mockAdapter{}
s := newWithAdapter(a)
s.addDevice(&Device{
Name: "eth0",
Index: 100,
})
s.addDevice(&Device{
Name: "eth1",
Index: 101,
})
s.addDevice(&Device{
Name: "eth2",
Index: 102,
})
s.Subscribe(mc, "eth1")
assert.Equal(t, uint(1), mc.deviceUpdateCalled)
s.notify(100)
assert.Equal(t, uint(1), mc.deviceUpdateCalled)
s.notify(101)
assert.Equal(t, uint(2), mc.deviceUpdateCalled)
s.delDevice(101)
s.notify(101)
assert.Equal(t, uint(2), mc.deviceUpdateCalled)
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment