// File: lux/node/lux_dns.go
// Snapshot: 2025-01-29 06:17:14 +02:00 (356 lines, 6.7 KiB, Go).

package node
import (
"encoding/binary"
"fmt"
"lux/host"
"net"
"net/netip"
"strings"
)
// LUX_DNS_TTL is the time-to-live, in seconds, written into every DNS answer
// record this server produces (one minute).
const LUX_DNS_TTL = uint32(60)
// LuxDnsServer answers DNS queries for hosts known to a LuxNode.
type LuxDnsServer struct {
	node     *LuxNode  // host table consulted for every lookup
	stopChan chan bool // carries the shutdown signal sent by Stop
}
// NewLuxDnsServer builds a DNS server that resolves names against node's
// hosts. The stop channel is unbuffered, so Stop blocks until a running
// frontend receives the signal.
func NewLuxDnsServer(node *LuxNode) *LuxDnsServer {
	sv := new(LuxDnsServer)
	sv.node = node
	sv.stopChan = make(chan bool)
	return sv
}
// NO is network byte order (big-endian), used to read and write the integer
// fields of DNS wire-format messages.
var (
	NO = binary.BigEndian
)
// dnsEntry is one decoded question from the question section of a DNS
// request.
type dnsEntry struct {
	fullName string   // labels joined with "."
	labels   []string // name labels in wire order, e.g. ["host", "wan", "lux"]
	offset   int      // byte offset of the question's name within the request message
	dnsType  uint16   // QTYPE read after the name
	dnsClass uint16   // QCLASS read after the type
}

// luxDomainInfo is the lux-specific interpretation of a queried name.
// Accepted forms, with a missing interface defaulting to "wan" and a missing
// top-level domain defaulting to "lux":
//
//	host
//	host.lux
//	host.wan.lux
//	host.re0.lux
//	service1.host.re0.lux (labels before the last three are ignored)
type luxDomainInfo struct {
	hostname string
	netif    string
	tld      string
}

// ParseDomainInfo splits entry.labels into hostname, network interface and
// top-level domain according to the forms listed on luxDomainInfo.
//
// A name without any labels (the DNS root, ".") yields the zero
// luxDomainInfo; its empty tld never equals "lux". Previously this case
// indexed labels[-3] and panicked on a single crafted query.
func (entry *dnsEntry) ParseDomainInfo() luxDomainInfo {
	labels := entry.labels
	n := len(labels)
	switch n {
	case 0:
		return luxDomainInfo{}
	case 1:
		return luxDomainInfo{hostname: labels[0], netif: "wan", tld: "lux"}
	case 2:
		return luxDomainInfo{hostname: labels[0], netif: "wan", tld: labels[1]}
	default:
		// three or more labels: only the rightmost three are significant
		return luxDomainInfo{
			hostname: labels[n-3],
			netif:    labels[n-2],
			tld:      labels[n-1],
		}
	}
}
// DNS resource record TYPE codes the server answers.
const (
	dnsTypeA    = 1  // IPv4 host address
	dnsTypeAAAA = 28 // IPv6 host address
)
// DNS wire-format constants used by HandleRequest (RFC 1035 section 4.1.1).
const (
	dnsHeaderLen = 12

	dnsFlagQR = 1 << 15 // message is a response
	dnsFlagAA = 1 << 10 // authoritative answer
	// dnsFlagsEcho selects the request bits copied into a reply: the opcode,
	// RD (recursion desired) and CD (checking disabled).
	dnsFlagsEcho = 0x7800 | 1<<8 | 1<<4

	dnsRcodeNoError  = 0
	dnsRcodeFormErr  = 1
	dnsRcodeNXDomain = 3
	dnsRcodeNotImp   = 4
	dnsRcodeRefused  = 5

	// dnsMaxNameLen is the longest encoded domain name allowed (RFC 1035 2.3.4).
	dnsMaxNameLen = 255
)

// HandleRequest answers one DNS request message and returns the encoded
// reply, or nil when nothing should be sent back.
//
// Only standard queries (opcode 0) for A and AAAA records are answered; other
// question types are echoed back without answers. Names are interpreted by
// ParseDomainInfo: the "wan" interface maps to the host's WAN option, any
// other interface name is looked up in the host's NetIf option.
//
// Reply codes: FORMERR for malformed questions, NOTIMP for non-query opcodes,
// REFUSED for names outside .lux, NXDOMAIN when the host, interface or a
// matching address is unknown. Messages shorter than a header, and messages
// that are themselves responses (QR=1), are dropped so the server can never be
// used to bounce replies back and forth.
func (sv *LuxDnsServer) HandleRequest(req []byte) []byte {
	if len(req) < dnsHeaderLen {
		return nil // not even a header: nothing to reply to
	}
	id := NO.Uint16(req[0:2])
	flags := NO.Uint16(req[2:4])
	qdcount := NO.Uint16(req[4:6])

	if flags&dnsFlagQR != 0 {
		// Never answer a response; doing so could start a packet loop.
		return nil
	}
	respFlags := dnsFlagQR | flags&dnsFlagsEcho
	if opcode := (flags >> 11) & 0xF; opcode != 0 {
		return buildDnsReply(id, respFlags|dnsRcodeNotImp, 0, 0, nil)
	}

	entries, end, err := parseDnsQuestions(req, int(qdcount))
	if err != nil {
		return buildDnsReply(id, respFlags|dnsRcodeFormErr, 0, 0, nil)
	}

	// The reply body starts with the question section copied verbatim, so each
	// question name keeps the offset it had in the request and answers can
	// point back at it. Copying (instead of appending to a subslice of req)
	// keeps the answers from overwriting the caller's buffer.
	body := make([]byte, 0, end-dnsHeaderLen+len(entries)*(12+16))
	body = append(body, req[dnsHeaderLen:end]...)

	var ancount uint16
	for i := range entries {
		entry := &entries[i]
		if entry.dnsType != dnsTypeA && entry.dnsType != dnsTypeAAAA {
			continue // only A and AAAA are supported
		}
		rdata, rcode := sv.lookupDnsAddr(entry)
		if rcode != dnsRcodeNoError {
			// The first unresolvable question decides the reply code; answers
			// built for earlier questions are kept.
			return buildDnsReply(id, respFlags|rcode, qdcount, ancount, body)
		}
		body = appendDnsAnswer(body, entry, rdata)
		ancount++
	}
	return buildDnsReply(id, respFlags|dnsFlagAA, qdcount, ancount, body)
}

// buildDnsReply prepends a DNS header to body. NSCOUNT and ARCOUNT are left at
// zero because replies never carry authority or additional records.
func buildDnsReply(id, flags, qdcount, ancount uint16, body []byte) []byte {
	reply := make([]byte, dnsHeaderLen, dnsHeaderLen+len(body))
	NO.PutUint16(reply[0:2], id)
	NO.PutUint16(reply[2:4], flags)
	NO.PutUint16(reply[4:6], qdcount)
	NO.PutUint16(reply[6:8], ancount)
	return append(reply, body...)
}

// parseDnsQuestions decodes qdcount questions that follow the header of msg
// and returns them together with the offset just past the question section.
func parseDnsQuestions(msg []byte, qdcount int) ([]dnsEntry, int, error) {
	// The smallest question is 5 bytes (root name, type, class); reject counts
	// that cannot fit before allocating for them.
	if qdcount > (len(msg)-dnsHeaderLen)/5 {
		return nil, 0, fmt.Errorf("question count %d does not fit in %d-byte message", qdcount, len(msg))
	}
	entries := make([]dnsEntry, 0, qdcount)
	off := dnsHeaderLen
	for i := 0; i < qdcount; i++ {
		labels, next, err := decodeDnsName(msg, off)
		if err != nil {
			return nil, 0, fmt.Errorf("question %d: %w", i, err)
		}
		if next+4 > len(msg) {
			return nil, 0, fmt.Errorf("question %d: truncated type/class", i)
		}
		entries = append(entries, dnsEntry{
			fullName: strings.Join(labels, "."),
			labels:   labels,
			offset:   off,
			dnsType:  NO.Uint16(msg[next : next+2]),
			dnsClass: NO.Uint16(msg[next+2 : next+4]),
		})
		off = next + 4
	}
	return entries, off, nil
}

// decodeDnsName reads the wire-format domain name at msg[off] and returns its
// labels plus the offset of the first byte after the name as encoded in place.
// Compression pointers (RFC 1035 section 4.1.4) are followed anywhere in the
// name. Each pointer must point strictly before the segment currently being
// read, so every jump moves backwards and malicious pointer loops terminate.
func decodeDnsName(msg []byte, off int) ([]string, int, error) {
	var labels []string
	end := -1      // offset after the in-place name; fixed at the first pointer
	segStart := off // start of the segment being read; decreases on every jump
	nameLen := 1    // encoded length so far, counting the terminating root label
	for {
		if off >= len(msg) {
			return nil, 0, fmt.Errorf("name at offset %d: truncated", segStart)
		}
		n := int(msg[off])
		switch n & 0xC0 {
		case 0x00:
			if n == 0 {
				if end < 0 {
					end = off + 1
				}
				return labels, end, nil
			}
			if off+1+n > len(msg) {
				return nil, 0, fmt.Errorf("label at offset %d: truncated", off)
			}
			nameLen += 1 + n
			if nameLen > dnsMaxNameLen {
				return nil, 0, fmt.Errorf("name longer than %d bytes", dnsMaxNameLen)
			}
			labels = append(labels, string(msg[off+1:off+1+n]))
			off += 1 + n
		case 0xC0:
			if off+2 > len(msg) {
				return nil, 0, fmt.Errorf("compression pointer at offset %d: truncated", off)
			}
			ptr := int(NO.Uint16(msg[off:off+2]) & 0x3FFF)
			if ptr >= segStart {
				return nil, 0, fmt.Errorf("compression pointer at offset %d does not point backwards", off)
			}
			if end < 0 {
				end = off + 2
			}
			off, segStart = ptr, ptr
		default:
			// 0x40 (extended) and 0x80 label types are not supported
			return nil, 0, fmt.Errorf("unsupported label type %#x at offset %d", n&0xC0, off)
		}
	}
}

// lookupDnsAddr resolves one A or AAAA question to the RDATA of its answer:
// 4 address bytes for A, 16 for AAAA. On failure it returns a non-zero rcode:
// REFUSED for names outside .lux, NXDOMAIN when the host, interface or a
// usable address of the requested family is missing.
func (sv *LuxDnsServer) lookupDnsAddr(entry *dnsEntry) ([]byte, uint16) {
	if len(entry.labels) == 0 {
		return nil, dnsRcodeRefused // the root name is not ours
	}
	info := entry.ParseDomainInfo()
	// DNS name comparisons are case-insensitive (RFC 4343).
	if !strings.EqualFold(info.tld, "lux") {
		return nil, dnsRcodeRefused
	}
	luxHost, ok := sv.node.GetHostByName(info.hostname)
	if !ok {
		return nil, dnsRcodeNXDomain
	}
	state := &luxHost.State

	var addr netip.Addr
	if info.netif == "wan" {
		opt, ok := state.Options[host.LuxOptionTypeWAN]
		if !ok {
			return nil, dnsRcodeNXDomain // no WAN info
		}
		wan, ok := opt.(*host.LuxOptionWAN)
		if !ok || wan == nil {
			return nil, dnsRcodeNXDomain
		}
		if entry.dnsType == dnsTypeA {
			addr = wan.Addr4
		} else {
			addr = wan.Addr6
		}
	} else {
		opt, ok := state.Options[host.LuxOptionTypeNetIf]
		if !ok {
			return nil, dnsRcodeNXDomain
		}
		netifs, ok := opt.(*host.LuxOptionNetIf)
		if !ok || netifs == nil {
			return nil, dnsRcodeNXDomain
		}
		netif, ok := netifs.Interfaces[info.netif]
		if !ok {
			return nil, dnsRcodeNXDomain
		}
		for _, netAddr := range netif.Addrs {
			if entry.dnsType == dnsTypeA {
				if netAddr.Type == host.LuxNetAddrType4 {
					addr = netAddr.Addr
					break
				}
				continue
			}
			// IPv6: a global unicast address wins immediately; a unique local
			// address is remembered but a later GUA may still replace it.
			if netAddr.Type == host.LuxNetAddrType6GUA {
				addr = netAddr.Addr
				break
			}
			if netAddr.Type == host.LuxNetAddrType6ULA {
				addr = netAddr.Addr
			}
		}
	}

	// The zero netip.Addr is not "unspecified", so IsValid must be checked
	// too; As4/As16 panic on it.
	if !addr.IsValid() || addr.IsUnspecified() {
		return nil, dnsRcodeNXDomain
	}
	// Encode by the question type so RDATA length always matches the record.
	if entry.dnsType == dnsTypeA {
		addr = addr.Unmap()
		if !addr.Is4() {
			return nil, dnsRcodeNXDomain
		}
		a4 := addr.As4()
		return a4[:], dnsRcodeNoError
	}
	if !addr.Is6() {
		return nil, dnsRcodeNXDomain
	}
	a16 := addr.As16()
	return a16[:], dnsRcodeNoError
}

// appendDnsAnswer appends to msg an answer record for entry that carries
// rdata. The owner name is a compression pointer to the question name, which
// the reply echoes at the same offset it had in the request.
func appendDnsAnswer(msg []byte, entry *dnsEntry, rdata []byte) []byte {
	var rr [12]byte
	NO.PutUint16(rr[0:2], uint16(entry.offset)|0xC000) // name (pointer)
	NO.PutUint16(rr[2:4], entry.dnsType)                // type
	NO.PutUint16(rr[4:6], entry.dnsClass)               // class
	NO.PutUint32(rr[6:10], LUX_DNS_TTL)                 // TTL
	NO.PutUint16(rr[10:12], uint16(len(rdata)))         // RDLENGTH
	msg = append(msg, rr[:]...)
	return append(msg, rdata...)
}
type udpPacket struct {
req []byte
addr *net.UDPAddr
}
// CreateFrontend binds a UDP socket on udpListen (any address accepted by
// net.ResolveUDPAddr) and serves DNS requests on it in the background until
// Stop is called, at which point the socket is closed. It returns only the
// resolve/bind error; failures while serving are logged at debug level.
func (sv *LuxDnsServer) CreateFrontend(udpListen string) error {
	listenAddr, err := net.ResolveUDPAddr("udp", udpListen)
	if err != nil {
		return err
	}
	conn, err := net.ListenUDP("udp", listenAddr)
	if err != nil {
		return err
	}

	packetChan := make(chan udpPacket)
	// done is closed when the serving loop exits, so a reader blocked while
	// handing over a packet can return instead of leaking.
	done := make(chan struct{})

	// Reader: pulls datagrams off the socket until it fails or is closed.
	go func() {
		buf := make([]byte, 512)
		for {
			n, addr, err := conn.ReadFromUDP(buf)
			if err != nil {
				log.Debugf("failed to recv dns udp: %v\n", err)
				return
			}
			// Copy out of buf: the next read overwrites it while the serving
			// loop may still be parsing this request.
			req := make([]byte, n)
			copy(req, buf[:n])
			select {
			case packetChan <- udpPacket{req: req, addr: addr}:
			case <-done:
				return
			}
		}
	}()

	// Serving loop: answers requests until Stop signals shutdown.
	go func() {
		defer conn.Close() // also unblocks the reader's ReadFromUDP
		defer close(done)
		for {
			select {
			case <-sv.stopChan:
				return
			case packet := <-packetChan:
				res := sv.HandleRequest(packet.req)
				if len(res) == 0 {
					continue // request deliberately left unanswered
				}
				if _, err := conn.WriteToUDP(res, packet.addr); err != nil {
					log.Debugf("failed to send dns reply: %v\n", err)
				}
			}
		}
	}()
	return nil
}
// Stop signals a frontend started by CreateFrontend to shut down and close its
// socket. The stop channel is unbuffered, so exactly one running frontend
// receives each call; Stop blocks until one does.
func (sv *LuxDnsServer) Stop() {
	var stop chan<- bool = sv.stopChan
	stop <- true
}