package node

import (
	"encoding/binary"
	"fmt"
	"lux/host"
	"net"
	"net/netip"
	"strings"
)

// LUX_DNS_TTL is the TTL, in seconds, written into every answer record.
const LUX_DNS_TTL uint32 = 60 * 1 // 1 minute

// LuxDnsServer answers A and AAAA queries for hosts known to a LuxNode.
type LuxDnsServer struct {
	node     *LuxNode
	stopChan chan bool
}

// NewLuxDnsServer returns a DNS server that resolves names against node.
func NewLuxDnsServer(node *LuxNode) *LuxDnsServer {
	return &LuxDnsServer{
		node:     node,
		stopChan: make(chan bool),
	}
}

// NO is the byte order used for every multi-byte field of a DNS message.
var NO = binary.BigEndian

// dnsEntry is one decoded entry of the question section.
type dnsEntry struct {
	fullName string   // labels joined with "."
	labels   []string // name split into labels, leftmost first
	offset   int      // byte offset of the name within the request message
	dnsType  uint16
	dnsClass uint16
}

// luxDomainInfo is a queried name broken into its parts. Accepted name
// shapes are:
//
//	host
//	host.lux
//	host.wan.lux
//	host.re0.lux
//	service1.host.re0.lux
type luxDomainInfo struct {
	hostname string
	netif    string
	tld      string
}

// ParseDomainInfo splits the entry's labels into hostname, interface and TLD.
// Missing parts default to netif "wan" and tld "lux"; when more than three
// labels are present only the last three are used. A name without any labels
// (the DNS root) yields the zero luxDomainInfo.
func (entry *dnsEntry) ParseDomainInfo() luxDomainInfo {
	labels := entry.labels
	switch n := len(labels); n {
	case 0:
		// Root name: nothing to split. The empty tld is never "lux".
		return luxDomainInfo{}
	case 1:
		return luxDomainInfo{hostname: labels[0], netif: "wan", tld: "lux"}
	case 2:
		return luxDomainInfo{hostname: labels[0], netif: "wan", tld: labels[1]}
	default:
		return luxDomainInfo{
			hostname: labels[n-3],
			netif:    labels[n-2],
			tld:      labels[n-1],
		}
	}
}

const dnsTypeA = 1
const dnsTypeAAAA = 28

// Sizes and response codes of the DNS wire format.
const (
	dnsHeaderLen = 12

	dnsRcodeFormErr  uint16 = 1 // request could not be decoded
	dnsRcodeNXDomain uint16 = 3 // name (or address of the wanted family) unknown
	dnsRcodeNotImp   uint16 = 4 // not a standard query
	dnsRcodeRefused  uint16 = 5 // name is outside the "lux" TLD
)

// HandleRequest decodes one DNS message and returns the encoded reply.
//
// Only standard queries are served; anything else gets a "not implemented"
// reply that echoes the request body. Questions of a type other than A or
// AAAA are echoed without an answer. A malformed question section yields a
// header-only "format error" reply. A name outside the "lux" TLD is refused,
// and an unknown host or interface gives "name error"; both replies carry
// the question section only.
//
// A nil slice is returned when req is too short to hold a DNS header, in
// which case nothing should be sent back. req is never modified.
func (sv *LuxDnsServer) HandleRequest(req []byte) []byte {
	if len(req) < dnsHeaderLen {
		return nil
	}

	id := NO.Uint16(req[0:2])
	flags := NO.Uint16(req[2:4])
	qdcount := NO.Uint16(req[4:6])
	ancount := NO.Uint16(req[6:8])
	nscount := NO.Uint16(req[8:10])
	arcount := NO.Uint16(req[10:12])

	qr := flags >> 15
	opcode := (flags >> 11) & 0b1111
	if qr != 0 || opcode != 0 {
		// The body was not parsed, so it is echoed with its own counts.
		return dnsBuildReply(id, flags, dnsRcodeNotImp,
			qdcount, ancount, nscount, arcount, req[dnsHeaderLen:])
	}

	entries, end, ok := dnsParseQuestions(req, int(qdcount))
	if !ok {
		return dnsBuildReply(id, flags, dnsRcodeFormErr, 0, 0, 0, 0, nil)
	}

	// The reply starts with a copy of the question section. Keeping the
	// questions at their original offsets lets answers refer to the names
	// through compression pointers.
	questions := req[dnsHeaderLen:end]
	body := make([]byte, 0, len(questions)+len(entries)*28)
	body = append(body, questions...)

	var answers uint16
	for _, entry := range entries {
		if entry.dnsType != dnsTypeA && entry.dnsType != dnsTypeAAAA {
			continue // only A and AAAA are supported
		}

		// NOTE(review): names are compared case-sensitively although DNS
		// names are case-insensitive; confirm how GetHostByName treats case
		// before changing this.
		info := entry.ParseDomainInfo()
		if info.tld != "lux" {
			return dnsBuildReply(id, flags, dnsRcodeRefused, qdcount, 0, 0, 0, questions)
		}

		addr, ok := sv.lookupAddr(info, entry.dnsType)
		if !ok {
			return dnsBuildReply(id, flags, dnsRcodeNXDomain, qdcount, 0, 0, 0, questions)
		}

		body = dnsAppendAnswer(body, entry, addr)
		answers++
	}

	flags |= 1 << 10 // Authoritative Answer
	return dnsBuildReply(id, flags, 0, qdcount, answers, 0, 0, body)
}

// dnsParseQuestions decodes qdcount questions that start right after the
// header of req. It returns the decoded entries and the offset of the first
// byte after the question section; ok is false when the section is truncated
// or malformed.
func dnsParseQuestions(req []byte, qdcount int) (entries []dnsEntry, end int, ok bool) {
	offset := dnsHeaderLen
	// A question occupies at least 5 bytes (root name, type, class). Checking
	// this first keeps a forged count from driving a large allocation.
	if len(req) < offset || qdcount > (len(req)-offset)/5 {
		return nil, 0, false
	}

	entries = make([]dnsEntry, 0, qdcount)
	for i := 0; i < qdcount; i++ {
		start := offset
		var labels []string

		if offset >= len(req) {
			return nil, 0, false
		}
		switch req[offset] & 0xC0 {
		case 0xC0:
			// The name is a compression pointer to an earlier question's name.
			if offset+2 > len(req) {
				return nil, 0, false
			}
			// Only the low 14 bits carry the target offset.
			target := int(NO.Uint16(req[offset:offset+2]) & 0x3FFF)
			found := false
			for _, prev := range entries {
				if prev.offset == target {
					labels = prev.labels
					found = true
				}
			}
			if !found {
				return nil, 0, false
			}
			offset += 2
		case 0x00:
			// Plain sequence of length-prefixed labels ended by a zero byte.
			for {
				if offset >= len(req) {
					return nil, 0, false
				}
				labelLen := int(req[offset])
				offset++
				if labelLen == 0 {
					break
				}
				// A pointer in the middle of a name is not supported.
				if labelLen&0xC0 != 0 || offset+labelLen > len(req) {
					return nil, 0, false
				}
				labels = append(labels, string(req[offset:offset+labelLen]))
				offset += labelLen
			}
		default:
			// 0x40 and 0x80 label types are reserved.
			return nil, 0, false
		}

		if offset+4 > len(req) {
			return nil, 0, false
		}
		dnsType := NO.Uint16(req[offset : offset+2])
		dnsClass := NO.Uint16(req[offset+2 : offset+4])
		offset += 4

		entries = append(entries, dnsEntry{
			fullName: strings.Join(labels, "."),
			labels:   labels,
			offset:   start,
			dnsType:  dnsType,
			dnsClass: dnsClass,
		})
	}
	return entries, offset, true
}

// lookupAddr resolves info to the address to return for a query of dnsType
// (dnsTypeA or dnsTypeAAAA). ok is false when the host, the interface or an
// address of the requested family is unknown.
func (sv *LuxDnsServer) lookupAddr(info luxDomainInfo, dnsType uint16) (netip.Addr, bool) {
	luxHost, ok := sv.node.GetHostByName(info.hostname)
	if !ok {
		return netip.Addr{}, false
	}
	state := &luxHost.State

	var addr netip.Addr
	switch info.netif {
	case "wan":
		opt, ok := state.Options[host.LuxOptionTypeWAN]
		if !ok {
			// no WAN info
			return netip.Addr{}, false
		}
		wan, ok := opt.(*host.LuxOptionWAN)
		if !ok {
			return netip.Addr{}, false
		}
		if dnsType == dnsTypeA {
			addr = wan.Addr4
		} else {
			addr = wan.Addr6
		}
	default:
		opt, ok := state.Options[host.LuxOptionTypeNetIf]
		if !ok {
			return netip.Addr{}, false
		}
		netifs, ok := opt.(*host.LuxOptionNetIf)
		if !ok {
			return netip.Addr{}, false
		}
		netif, ok := netifs.Interfaces[info.netif]
		if !ok {
			return netip.Addr{}, false
		}

		found := false
		for _, netAddr := range netif.Addrs {
			if dnsType == dnsTypeA {
				if netAddr.Type == host.LuxNetAddrType4 {
					addr = netAddr.Addr
					found = true
					break
				}
			} else {
				// Global unicast is preferred over unique local, so the scan
				// only stops early on a global address.
				if netAddr.Type == host.LuxNetAddrType6GUA {
					addr = netAddr.Addr
					found = true
					break
				} else if netAddr.Type == host.LuxNetAddrType6ULA {
					addr = netAddr.Addr
					found = true
				}
			}
		}
		if !found {
			return netip.Addr{}, false
		}
	}

	// The zero Addr is not "unspecified" and As4/As16 panic on it, so both
	// conditions have to be rejected before the address is encoded.
	if !addr.IsValid() || addr.IsUnspecified() {
		return netip.Addr{}, false
	}
	// The record data must match the record type, not whatever family the
	// stored address happens to have.
	if dnsType == dnsTypeA {
		addr = addr.Unmap()
		if !addr.Is4() {
			return netip.Addr{}, false
		}
	} else if !addr.Is6() {
		return netip.Addr{}, false
	}
	return addr, true
}

// dnsAppendAnswer appends to dst a resource record that answers entry with
// addr and returns the extended slice. addr must be IPv4 for an A question
// and IPv6 for an AAAA question.
func dnsAppendAnswer(dst []byte, entry dnsEntry, addr netip.Addr) []byte {
	var fixed [12]byte
	NO.PutUint16(fixed[0:2], uint16(entry.offset)|0xC000) // name, as pointer to the question
	NO.PutUint16(fixed[2:4], entry.dnsType)               // type
	NO.PutUint16(fixed[4:6], entry.dnsClass)              // class
	NO.PutUint32(fixed[6:10], LUX_DNS_TTL)                // TTL

	if entry.dnsType == dnsTypeAAAA {
		NO.PutUint16(fixed[10:12], 16) // data length - IPv6
		octets := addr.As16()
		dst = append(dst, fixed[:]...)
		return append(dst, octets[:]...)
	}

	NO.PutUint16(fixed[10:12], 4) // data length - IPv4
	octets := addr.As4()
	dst = append(dst, fixed[:]...)
	return append(dst, octets[:]...)
}

// dnsBuildReply encodes a response header followed by a copy of body. flags
// are the request flags (plus any bits the caller added); the response bit is
// set and the response code field is replaced with rcode.
func dnsBuildReply(id, flags, rcode, qdcount, ancount, nscount, arcount uint16, body []byte) []byte {
	flags = (flags &^ 0x000F) | (1 << 15) | rcode

	reply := make([]byte, dnsHeaderLen, dnsHeaderLen+len(body))
	NO.PutUint16(reply[0:2], id)
	NO.PutUint16(reply[2:4], flags)
	NO.PutUint16(reply[4:6], qdcount)
	NO.PutUint16(reply[6:8], ancount)
	NO.PutUint16(reply[8:10], nscount)
	NO.PutUint16(reply[10:12], arcount)
	return append(reply, body...)
}

// udpPacket is one received datagram together with its sender.
type udpPacket struct {
	req  []byte
	addr *net.UDPAddr
}

// CreateFrontend starts serving DNS over UDP on udpListen. It returns once
// the socket is bound; requests are handled on a background goroutine until
// Stop is called.
func (sv *LuxDnsServer) CreateFrontend(udpListen string) error {
	listenAddr, err := net.ResolveUDPAddr("udp", udpListen)
	if err != nil {
		return fmt.Errorf("resolving dns listen address %q: %w", udpListen, err)
	}

	conn, err := net.ListenUDP("udp", listenAddr)
	if err != nil {
		return fmt.Errorf("listening for dns on %q: %w", udpListen, err)
	}

	go func() {
		defer conn.Close()

		packetChan := make(chan udpPacket)
		// done releases the reader if it is blocked handing over a packet
		// when this goroutine exits.
		done := make(chan struct{})
		defer close(done)

		go func() {
			buf := make([]byte, 512)
			for {
				n, addr, err := conn.ReadFromUDP(buf)
				if err != nil {
					log.Debugf("failed to recv dns udp: %v\n", err)
					return
				}

				// buf is reused for the next read, so the handler gets its
				// own copy of the datagram.
				packet := udpPacket{
					req:  append([]byte(nil), buf[:n]...),
					addr: addr,
				}
				select {
				case packetChan <- packet:
				case <-done:
					return
				}
			}
		}()

		for {
			select {
			case <-sv.stopChan:
				return
			case packet := <-packetChan:
				res := sv.HandleRequest(packet.req)
				if len(res) > 0 {
					if _, err := conn.WriteToUDP(res, packet.addr); err != nil {
						log.Debugf("failed to send dns reply: %v\n", err)
					}
				}
			}
		}
	}()

	return nil
}

// Stop shuts down one frontend started by CreateFrontend. The send is
// unbuffered, so Stop blocks until a frontend goroutine receives the signal;
// call it once per running frontend.
func (sv *LuxDnsServer) Stop() {
	sv.stopChan <- true
}