"net"
"net/url"
"path"
"sort"
"strings"
"text/template"
"libvirt.org/libvirt-go"
libvirtxml "libvirt.org/libvirt-go-xml"
"slrz.net/runtopo/topology"
)
// Runner implements the topology.Runner interface using libvirt/qemu.
type Runner struct {
conn *libvirt.Connect
devices map[string]*device
domains map[string]*libvirt.Domain
baseImages map[string]*libvirt.StorageVol
sshConfigOut io.Writer
// fields below are immutable after initialization
uri string // libvirt connection URI
namePrefix string
tunnelIP net.IP
macBase net.HardwareAddr
portBase int
portGap int
storagePool string
	authorizedKeys []string
	configFS       fs.FS
}
// A RunnerOption may be passed to NewRunner to customize the Runner's
// behaviour.
type RunnerOption func(*Runner)
// WithConnectionURI sets the connection URI used to connect to libvirtd.
// Defaults to "qemu:///system".
func WithConnectionURI(uri string) RunnerOption {
return func(r *Runner) {
r.uri = uri
}
}
// WithNamePrefix configures the prefix to use when naming resources like guest
// domains. The default is "runtopo-".
func WithNamePrefix(prefix string) RunnerOption {
return func(r *Runner) {
r.namePrefix = prefix
}
}
// WithTunnelIP sets the default IP address for libvirt UDP tunnels. This is
// used only for devices that do not have an explicit address configured
// (tunnelip node attribute).
func WithTunnelIP(ip net.IP) RunnerOption {
return func(r *Runner) {
r.tunnelIP = ip
}
}
// WithMACAddressBase determines the starting address for automatically
// assigned MAC addresses. Explicitly configured MAC addresses
// (left_mac/right_mac edge attributes) are unaffected by this option.
func WithMACAddressBase(mac net.HardwareAddr) RunnerOption {
return func(r *Runner) {
r.macBase = mac
}
}
// WithPortBase specifies the starting port for allocating UDP tunnel ports.
func WithPortBase(port int) RunnerOption {
return func(r *Runner) {
r.portBase = port
}
}
// WithPortGap sets the gap left between a link's local and remote UDP ports.
// It effectively limits the maximum number of links supported in a topology.
func WithPortGap(delta int) RunnerOption {
return func(r *Runner) {
r.portGap = delta
}
}
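
// Note (explanatory, derived from buildInventory below): with the defaults
// (portBase 10000, portGap 1000), the i-th link gets the UDP port pair
// (10000+i, 11000+i); one endpoint listens on the low port and sends to the
// high one, the peer does the opposite. A topology can therefore contain at
// most portGap links.
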
// WithStoragePool sets the libvirt storage pool where we create volumes.
func WithStoragePool(pool string) RunnerOption {
return func(r *Runner) {
r.storagePool = pool
}
}
// WithAuthorizedKeys adds the provided SSH public keys to authorized_keys for
// all started VMs.
func WithAuthorizedKeys(keys ...string) RunnerOption {
return func(r *Runner) {
r.authorizedKeys = keys
}
}
// WriteSSHConfig configures the Runner to write an OpenSSH client
// configuration file to w. See ssh_config(5) for a description of its format.
func WriteSSHConfig(w io.Writer) RunnerOption {
return func(r *Runner) {
r.sshConfigOut = w
}
}
// WithConfigFS specifies a filesystem implementation for loading config
// snippets requested with the node attribute config.
func WithConfigFS(fsys fs.FS) RunnerOption {
return func(r *Runner) {
r.configFS = fsys
}
}
// NewRunner constructs a runner configured with the specified options.
func NewRunner(opts ...RunnerOption) *Runner {
r := &Runner{
uri: "qemu:///system",
namePrefix: "runtopo-",
tunnelIP: net.IPv4(127, 0, 0, 1),
		// BUG(ls): The default MAC address range matches the one used
		// by topology_converter. It belongs to Cumulus though, so we
		// should eventually switch to a locally administered range.
		macBase:     mustParseMAC("44:38:39:00:00:00"),
portBase: 1e4,
portGap: 1e3,
storagePool: "default",
devices: make(map[string]*device),
domains: make(map[string]*libvirt.Domain),
}
for _, opt := range opts {
opt(r)
}
return r
}
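
// A minimal usage sketch (illustrative only; names like ctx and topo are
// assumed to exist in the caller):
//
//	r := NewRunner(
//		WithNamePrefix("demo-"),
//		WithAuthorizedKeys("ssh-ed25519 AAAA… user@host"),
//	)
//	if err := r.Run(ctx, topo); err != nil {
//		// handle error
//	}
//	defer r.Destroy(ctx, topo)
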
// Run starts up the topology described by t.
func (r *Runner) Run(ctx context.Context, t *topology.T) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("libvirt.(*Runner).Run: %w", err)
}
}()
if n := len(r.macBase); n != 6 {
return fmt.Errorf("got base MAC of len %d, want len 6", n)
}
if err := r.buildInventory(t); err != nil {
return err
}
c, err := libvirt.NewConnect(r.uri)
if err != nil {
return err
}
r.conn = c
defer func() {
if err != nil {
c.Close()
r.conn = nil
}
}()
	if err := r.downloadBaseImages(ctx, t); err != nil {
		return err
	}
defer func() {
if err != nil {
for _, v := range r.baseImages {
v.Free()
}
r.baseImages = nil
}
}()
	if err := r.createVolumes(ctx, t); err != nil {
		return err
	}
	defer func() {
		if err != nil {
			r.deleteVolumes(ctx, t)
		}
	}()
if err := r.defineDomains(ctx, t); err != nil {
return err
}
defer func() {
if err != nil {
r.undefineDomains(ctx, t)
for _, d := range r.domains {
d.Free()
}
r.domains = nil
}
}()
if err := r.customizeDomains(ctx, t); err != nil {
return err
}
if err := r.startDomains(ctx, t); err != nil {
return err
}
if r.sshConfigOut != nil {
// Caller asked us to write out an ssh_config.
if err := r.writeSSHConfig(ctx, t); err != nil {
return err
}
	}
	return nil
}
// Destroy destroys any resources created by a previous Run invocation. Destroy
// may be called on a different Runner instance than Run, as long as that
// instance was created with the same set of RunnerOptions.
func (r *Runner) Destroy(ctx context.Context, t *topology.T) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("libvirt.(*Runner).Destroy: %w", err)
}
}()
if err := r.buildInventory(t); err != nil {
return err
}
if r.conn == nil {
c, err := libvirt.NewConnect(r.uri)
if err != nil {
return err
}
r.conn = c
}
if err := r.undefineDomains(ctx, t); err != nil {
return err
}
for _, d := range r.domains {
d.Free()
}
r.domains = nil
	if err := r.deleteVolumes(ctx, t); err != nil {
		return err
	}
	for _, v := range r.baseImages {
		v.Free()
	}
	r.baseImages = nil
	return nil
}
func (r *Runner) buildInventory(t *topology.T) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("buildInventory: %w", err)
}
}()
var macInt uint64
for _, b := range r.macBase {
macInt = macInt<<8 | uint64(b)
}
allocateMAC := func() net.HardwareAddr {
mac := macAddrFromUint64(macInt)
macInt++
return mac
}
for _, topoDev := range t.Devices() {
if topoDev.Function() == topology.Fake {
continue
}
tunnelIP := r.tunnelIP
if s := topoDev.Attr("tunnelip"); s != "" {
if tunnelIP = net.ParseIP(s); tunnelIP == nil {
return fmt.Errorf(
"device %s: cannot parse tunnelip %q",
topoDev.Name, s)
}
}
var config []byte
if file := topoDev.Attr("config"); file != "" && r.configFS != nil {
p, err := fs.ReadFile(r.configFS, file)
if err != nil {
return fmt.Errorf("device %s: %w",
topoDev.Name, err)
}
config = p
}
		r.devices[topoDev.Name] = &device{
			name:     r.namePrefix + topoDev.Name,
tunnelIP: tunnelIP,
pool: r.storagePool,
config: config,
topoDev: topoDev,
}
}
nextPort := uint(r.portBase)
	for _, l := range t.Links() {
		from := r.devices[l.From]
		fromTunnelIP := r.tunnelIP
		if from != nil {
			fromTunnelIP = from.tunnelIP
		}
mac, hasMAC := l.FromMAC()
if !hasMAC {
mac = allocateMAC()
}
if (l.From == "oob-mgmt-server" || l.From == "oob-mgmt-switch") &&
l.To == "" {
// XXX
from.interfaces = append(from.interfaces, iface{
name: l.FromPort,
mac: mac,
network: "default",
})
continue
}
toTunnelIP := r.tunnelIP
if to := r.devices[l.To]; to != nil {
toTunnelIP = to.tunnelIP
}
		if from != nil {
			from.interfaces = append(from.interfaces, iface{
				name:           l.FromPort,
mac: mac,
port: nextPort,
localPort: nextPort + uint(r.portGap),
remoteTunnelIP: toTunnelIP,
})
}
if to := r.devices[l.To]; to != nil {
mac, hasMAC := l.ToMAC()
if !hasMAC {
mac = allocateMAC()
}
to.interfaces = append(to.interfaces, iface{
name: l.ToPort,
mac: mac,
port: nextPort + uint(r.portGap),
localPort: nextPort,
remoteTunnelIP: fromTunnelIP,
})
}
nextPort++
}
for _, d := range r.devices {
sort.Slice(d.interfaces, func(i, j int) bool {
di, dj := d.interfaces[i], d.interfaces[j]
if di.name == "eth0" && dj.name != "eth0" {
return true
}
return natCompare(di.name, dj.name) < 0
})
}
return nil
}
func (r *Runner) downloadBaseImages(ctx context.Context, t *topology.T) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("downloadBaseImages: %w", err)
}
}()
pool, err := r.conn.LookupStoragePoolByName(r.storagePool)
if err != nil {
return err
}
defer pool.Free()
wantImages := make(map[string]struct{})
haveImages := make(map[string]*libvirt.StorageVol)
for _, d := range r.devices {
u, err := url.Parse(d.topoDev.OSImage())
if err != nil {
return err
}
vol, err := pool.LookupStorageVolByName(path.Base(u.Path))
if err == nil {
// skip over already present volumes
haveImages[d.topoDev.OSImage()] = vol
continue
}
wantImages[d.topoDev.OSImage()] = struct{}{}
}
type result struct {
vol *libvirt.StorageVol
url string
err error
}
ch := make(chan result)
fetchCtx, cancel := context.WithCancel(ctx)
defer cancel()
numStarted := 0
for sourceURL := range wantImages {
sourceURL := sourceURL
go func() {
vol, err := createVolumeFromURL(fetchCtx, r.conn, pool, sourceURL)
if err != nil {
ch <- result{err: err, url: sourceURL}
return
}
ch <- result{vol: vol, url: sourceURL}
}()
numStarted++
}
for i := 0; i < numStarted; i++ {
res := <-ch
if res.err == nil {
haveImages[res.url] = res.vol
continue
}
if res.err != nil {
cancel() // tell other goroutines to shut down
if err == nil {
err = res.err
}
}
}
if err != nil {
for _, v := range haveImages {
v.Free()
}
return err
}
r.baseImages = haveImages
return nil
}
func (r *Runner) createVolumes(ctx context.Context, t *topology.T) (err error) {
var created []*libvirt.StorageVol
	defer func() {
		if err != nil {
			for _, vol := range created {
				vol.Free()
			}
			err = fmt.Errorf("createVolumes: %w", err)
		}
	}()
pool, err := r.conn.LookupStoragePoolByName(r.storagePool)
if err != nil {
return err
}
defer pool.Free()
for _, d := range r.devices {
var backing *libvirtxml.StorageVolumeBackingStore
var baseInfo *libvirt.StorageVolInfo
if osImage := d.topoDev.OSImage(); osImage != "none" {
base := r.baseImages[osImage]
if base == nil {
// we should've failed earlier already
panic("unexpected missing base image: " +
osImage)
}
baseInfo, err = base.GetInfo()
if err != nil {
return fmt.Errorf("get-info: %w (bvol: %s)",
err, osImage)
}
backing, err = newBackingStoreFromVol(base)
if err != nil {
return err
}
}
		// baseInfo is nil when OSImage is "none"; fall back to a zero
		// capacity rather than dereferencing it.
		var capacity int64
		if baseInfo != nil {
			capacity = int64(baseInfo.Capacity)
		}
		xmlVol := newVolume(d.name, capacity)
xmlVol.BackingStore = backing
xmlStr, err := xmlVol.Marshal()
if err != nil {
return err
}
vol, err := pool.StorageVolCreateXML(xmlStr, 0)
if err != nil {
return fmt.Errorf("vol-create: %w", err)
}
created = append(created, vol)
d.pool = r.storagePool
}
return nil
}
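
// Sketches (assumed shapes) of the volume helpers used by createVolumes
// above; the original definitions are elided from this excerpt and may
// differ.

// newVolume builds a storage volume description with the given name and
// capacity in bytes, using qcow2 as the target format.
func newVolume(name string, capacity int64) *libvirtxml.StorageVolume {
	return &libvirtxml.StorageVolume{
		Name: name,
		Capacity: &libvirtxml.StorageVolumeSize{
			Value: uint64(capacity),
			Unit:  "bytes",
		},
		Target: &libvirtxml.StorageVolumeTarget{
			Format: &libvirtxml.StorageVolumeTargetFormat{Type: "qcow2"},
		},
	}
}

// newBackingStoreFromVol points a backing store at an existing base volume.
func newBackingStoreFromVol(base *libvirt.StorageVol) (*libvirtxml.StorageVolumeBackingStore, error) {
	path, err := base.GetPath()
	if err != nil {
		return nil, err
	}
	return &libvirtxml.StorageVolumeBackingStore{
		Path:   path,
		Format: &libvirtxml.StorageVolumeTargetFormat{Type: "qcow2"},
	}, nil
}
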
// deleteVolumes deletes any created volumes from the libvirt storage pool
// r.storagePool.
func (r *Runner) deleteVolumes(ctx context.Context, t *topology.T) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("deleteVolumes: %w", err)
}
}()
pool, err := r.conn.LookupStoragePoolByName(r.storagePool)
if err != nil {
return err
}
defer pool.Free()
for _, d := range r.devices {
v, lerr := pool.LookupStorageVolByName(d.name)
if lerr != nil {
continue
}
_ = v.Delete(0)
v.Free()
}
return nil
}
func (r *Runner) defineDomains(ctx context.Context, t *topology.T) (err error) {
var defined []*libvirt.Domain
defer func() {
if err != nil {
for _, dom := range defined {
dom.Undefine()
dom.Free()
}
err = fmt.Errorf("defineDomains: %w", err)
}
}()
tmpl, err := template.New("").
Funcs(templateFuncs).
Parse(domainTemplateText)
if err != nil {
return err
}
var buf bytes.Buffer
for _, d := range r.devices {
if err := tmpl.Execute(&buf, d.templateArgs()); err != nil {
return fmt.Errorf("domain %s: %w", d.name, err)
}
domXML := buf.String()
buf.Reset()
dom, err := r.conn.DomainDefineXMLFlags(
domXML, libvirt.DOMAIN_DEFINE_VALIDATE)
if err != nil {
return fmt.Errorf("define domain %s: %w", d.name, err)
}
defined = append(defined, dom)
r.domains[d.name] = dom
}
return nil
}
// undefineDomains deletes any created libvirt guest domains, shutting them
// down if running.
func (r *Runner) undefineDomains(ctx context.Context, t *topology.T) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("undefineDomains: %w", err)
}
}()
for _, d := range r.devices {
dom, lerr := r.conn.LookupDomainByName(d.name)
if lerr != nil {
continue
}
_ = dom.Destroy()
_ = dom.Undefine()
dom.Free()
}
return nil
}
func (r *Runner) customizeDomains(ctx context.Context, t *topology.T) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("customizeDomains: %w", err)
}
}()
var buf bytes.Buffer
ch := make(chan error)
numStarted := 0
customizeCtx, cancel := context.WithCancel(ctx)
defer cancel()
for _, d := range r.devices {
user := "root"
if hasCumulusFunction(d) {
fmt.Fprintf(&buf, "write /etc/ptm.d/topology.dot:%s\n",
bytes.Replace(t.DOT(), []byte("\n"),
[]byte("\\\n"), -1))
}
for _, k := range r.authorizedKeys {
fmt.Fprintf(&buf, "ssh-inject %s:string:%s\n", user, k)
if user != "root" {
fmt.Fprintf(&buf, "ssh-inject root:string:%s\n", k)
			}
		}
if d.topoDev.Function() == topology.OOBServer {
hosts := gatherHosts(ctx, r, t)
for _, h := range hosts {
fmt.Fprintf(&buf, "append-line /etc/hosts:%s %s\n",
					h.ip, h.name)
			}
dnsmasqHosts := generateDnsmasqHostsFile(hosts)
fmt.Fprintf(&buf, "write /etc/dnsmasq.hostsfile:%s\n",
bytes.Replace(dnsmasqHosts, []byte("\n"),
[]byte("\\\n"), -1))
}
extra := strings.NewReader(buf.String())
buf.Reset()
d := d
go func() {
ch <- customizeDomain(customizeCtx, r.uri, d, extra)
}()
numStarted++
}
for i := 0; i < numStarted; i++ {
res := <-ch
if res != nil {
cancel() // tell other goroutines to shut down
if err == nil {
err = res
}
}
}
return err
}
func (r *Runner) startDomains(ctx context.Context, t *topology.T) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("startDomains: %w", err)
}
}()
ds := t.Devices()
sort.Slice(ds, func(i, j int) bool {
return ds[i].Function() < ds[j].Function()
})
var started []*libvirt.Domain
defer func() {
if err != nil {
for _, d := range started {
d.Destroy()
}
}
}()
for _, d := range ds {
if d.Function() == topology.Fake {
continue
}
dom := r.domains[r.namePrefix+d.Name]
if err := dom.Create(); err != nil {
return fmt.Errorf("domain %s: %w",
r.namePrefix+d.Name, err)
}
started = append(started, dom)
}
return nil
}
// writeSSHConfig generates an OpenSSH client config and writes it to r.sshConfigOut.
func (r *Runner) writeSSHConfig(ctx context.Context, t *topology.T) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("writeSSHConfig: %w", err)
}
}()
// Retrieve the mgmt server's DHCP lease as we're going to use
// it as a jump host.
ip, err := waitForLease(ctx, r.domains[r.namePrefix+"oob-mgmt-server"])
if err != nil {
return err
}
w := bufio.NewWriter(r.sshConfigOut)
fmt.Fprintf(w, `Host oob-mgmt-server
Hostname %s
User root
UserKnownHostsFile /dev/null
StrictHostKeyChecking no
`, ip)
for _, d := range t.Devices() {
if topology.HasFunction(&d, topology.OOBServer, topology.OOBSwitch) {
continue
}
user := "root"
if hasCumulusFunction(&device{topoDev: d}) {
user = "cumulus"
}
fmt.Fprintf(w, `Host %s
User %s
ProxyJump oob-mgmt-server
UserKnownHostsFile /dev/null
StrictHostKeyChecking no
`, d.Name, user)
}
return w.Flush()
}
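
// The generated configuration looks roughly like this (illustrative output;
// addresses and host names depend on the topology):
//
//	Host oob-mgmt-server
//	    Hostname 192.0.2.10
//	    User root
//	    UserKnownHostsFile /dev/null
//	    StrictHostKeyChecking no
//	Host leaf01
//	    User cumulus
//	    ProxyJump oob-mgmt-server
//	    UserKnownHostsFile /dev/null
//	    StrictHostKeyChecking no
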
// internal representation for a device
type device struct {
name string
tunnelIP net.IP
interfaces []iface
	pool       string
	config     []byte // config snippet loaded via the config node attribute
topoDev topology.Device
}
func (d *device) templateArgs() *domainTemplateArgs {
args := &domainTemplateArgs{
Name: d.name,
VCPUs: d.topoDev.VCPUs(),
Memory: d.topoDev.Memory() >> 10, // libvirt wants KiB
Pool: d.pool,
PXEBoot: false, // set below if enabled for an interface
}
for _, intf := range d.interfaces {
typ := "udp"
netSrc, udpSrc := intf.network, udpSource{
Address: intf.remoteTunnelIP.String(),
Port: intf.port,
LocalAddress: d.tunnelIP.String(),
LocalPort: intf.localPort,
}
if intf.network != "" {
typ = "network"
}
args.Interfaces = append(args.Interfaces, domainInterface{
Type: typ,
MACAddr: intf.mac.String(),
TargetDev: intf.name,
Model: "virtio",
if intf.pxe {
args.PXEBoot = true
}
}
return args
}
// internal representation for an interface
type iface struct {
name string
mac net.HardwareAddr
network string
port uint
localPort uint
	remoteTunnelIP net.IP
	pxe            bool // request PXE boot for this interface
}
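
// The helpers below are plausible implementations of the MAC-address
// utilities referenced above (mustParseMAC, macAddrFromUint64); their real
// definitions are not part of this excerpt and may differ.

// mustParseMAC parses s as a MAC address and panics on error. Only meant for
// hard-coded defaults.
func mustParseMAC(s string) net.HardwareAddr {
	mac, err := net.ParseMAC(s)
	if err != nil {
		panic(err)
	}
	return mac
}

// macAddrFromUint64 turns the low 48 bits of v into a MAC address.
func macAddrFromUint64(v uint64) net.HardwareAddr {
	return net.HardwareAddr{
		byte(v >> 40), byte(v >> 32), byte(v >> 24),
		byte(v >> 16), byte(v >> 8), byte(v),
	}
}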