summaryrefslogtreecommitdiff
path: root/addr.go
blob: c99a94c91af2f4015e049ff7b858c9f32754b223 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package main

import (
	"net"
	"github.com/vishvananda/netlink"
	"fmt"
	"log"
	"io"
)

// AddrSet holds the addresses currently assigned to one network interface,
// together with the netlink notification channels Read consumes to keep them
// current.
type AddrSet struct {
	// Last attributes seen for the tracked link. Index filters incoming
	// notifications; Flags decides whether the link counts as up.
	linkAttrs netlink.LinkAttrs
	// Channels handed to netlink.LinkSubscribe and netlink.AddrSubscribe.
	linkChan chan netlink.LinkUpdate
	addrChan chan netlink.AddrUpdate

	// Tracked addresses keyed by their net.IP.String() form.
	addrs map[string]net.IP
}

// String names the set after the interface it tracks; it prefixes the log
// lines emitted by this type.
func (addrs *AddrSet) String() string {
	name := addrs.linkAttrs.Name
	return fmt.Sprintf("AddrSet iface=%v", name)
}

// testFlag reports whether any bit of flag is set in the cached link flags.
func (addrs *AddrSet) testFlag(flag net.Flags) bool {
	masked := addrs.linkAttrs.Flags & flag
	return masked != 0
}

// Up reports whether net.FlagUp was set in the most recently cached link
// attributes.
func (addrs *AddrSet) Up() bool {
	isUp := addrs.testFlag(net.FlagUp)
	return isUp
}

// InterfaceAddrs starts tracking the addresses of the interface named iface.
// It subscribes to netlink link and address notifications, then snapshots the
// link attributes and the current addresses of the given family. Later Read
// calls consume the notifications to keep the set current.
//
// The subscriptions are opened before the snapshot so that no change racing
// with the lookup is lost. A change already reflected in the snapshot may also
// arrive as a notification. Replaying it is harmless, because updateAddr only
// sets or deletes a map key and updateLink overwrites the cached attributes.
//
// If any step fails, subscriptions that were already opened are shut down
// before the error is returned. On success they stay open; this function
// offers no way to close them later.
func InterfaceAddrs(iface string, family Family) (*AddrSet, error) {
	addrs := &AddrSet{
		linkChan: make(chan netlink.LinkUpdate),
		addrChan: make(chan netlink.AddrUpdate),
		addrs:    make(map[string]net.IP),
	}

	// Closing done tears down the netlink subscriptions. It is closed only on
	// failure; after a successful return nothing closes it.
	done := make(chan struct{})
	succeeded := false
	defer func() {
		if !succeeded {
			close(done)
		}
	}()

	if err := netlink.LinkSubscribe(addrs.linkChan, done); err != nil {
		return nil, fmt.Errorf("netlink.LinkSubscribe: %w", err)
	}

	if err := netlink.AddrSubscribe(addrs.addrChan, done); err != nil {
		return nil, fmt.Errorf("netlink.AddrSubscribe: %w", err)
	}

	link, err := netlink.LinkByName(iface)
	if err != nil {
		return nil, fmt.Errorf("netlink.LinkByName %v: %w", iface, err)
	}
	addrs.linkAttrs = *link.Attrs()

	addrList, err := netlink.AddrList(link, int(family))
	if err != nil {
		return nil, fmt.Errorf("netlink.AddrList %v: %w", iface, err)
	}
	for _, addr := range addrList {
		addrs.updateAddr(addr, true)
	}

	succeeded = true
	return addrs, nil
}

// Read blocks until the tracked set changes in a way visible through Each.
// It applies every notification for the tracked link that it consumes along
// the way.
//
// It returns nil:
//   - after applying an address notification for the tracked link, or
//   - after a link notification that toggles the link's up state. Each yields
//     nothing while the link is down, so such a toggle is a visible change
//     even when no address notification accompanies it.
//
// Notifications for other links are skipped. So are link notifications that
// leave the up state unchanged, though their attributes are still cached.
// Read returns io.EOF once either subscription channel is closed.
func (addrs *AddrSet) Read() error {
	for {
		select {
		case linkUpdate, ok := <-addrs.linkChan:
			if !ok {
				return io.EOF
			}

			linkAttrs := linkUpdate.Attrs()
			if linkAttrs.Index != addrs.linkAttrs.Index {
				continue
			}

			wasUp := addrs.Up()
			addrs.updateLink(*linkAttrs)

			// Wake the caller only on an up/down transition; other link
			// attribute changes do not affect what Each yields.
			if addrs.Up() != wasUp {
				return nil
			}

		case addrUpdate, ok := <-addrs.addrChan:
			if !ok {
				return io.EOF
			}

			if addrUpdate.LinkIndex != addrs.linkAttrs.Index {
				continue
			}

			// XXX: scope and other filters? (updateAddr currently drops
			// scopes >= netlink.SCOPE_LINK.)
			addrs.updateAddr(addrUpdate.Addr, addrUpdate.NewAddr)

			return nil
		}
	}
}

// updateAddr applies a single address event to the set and logs it.
// Addresses whose scope is netlink.SCOPE_LINK or higher (link- and
// host-scoped) are ignored without logging. When up is true the address is
// stored under its string form; otherwise that entry is removed.
func (addrs *AddrSet) updateAddr(addr netlink.Addr, up bool) {
	tracked := addr.Scope < int(netlink.SCOPE_LINK)
	if !tracked {
		return
	}

	ip := addr.IP
	key := ip.String()

	if !up {
		log.Printf("%v: down %v", addrs, ip)
		delete(addrs.addrs, key)
		return
	}

	log.Printf("%v: up %v", addrs, ip)
	addrs.addrs[key] = ip
}

// updateLink replaces the cached link attributes. It logs a line whenever the
// new attributes show the link as not up.
func (addrs *AddrSet) updateLink(linkAttrs netlink.LinkAttrs) {
	addrs.linkAttrs = linkAttrs

	if addrs.Up() {
		return
	}
	log.Printf("%v: down", addrs)
}

// Each calls visitFunc once per tracked address, in unspecified (map) order.
// While the link is down nothing is visited, even if addresses are recorded.
func (addrs *AddrSet) Each(visitFunc func(net.IP)) {
	if addrs.Up() {
		for _, addr := range addrs.addrs {
			visitFunc(addr)
		}
	}
}