diff --git a/README.md b/README.md index e656b51..481664a 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ discodns A DNS *fowarder* and *nameserver* that first queries an [etcd](http://github.com/coreos/etcd) database of domains and records. It forwards requests it's not authoritative for onto a configured set of upstream nameservers (Google DNS by default). -The authoritative domains are configured using the `-domain` argument to the server, which switches the server from a *forwarder* to a *nameserver* for that domain zone. For example, `-domain=discodns.net.` will mean any domain queries within the `discodns.net.` zone will be served from the local database. +The authoritative domains are configured using the `--domain` argument to the server, which switches the server from a *forwarder* to a *nameserver* for that domain zone. For example, `--domain=discodns.net.` will mean any domain queries within the `discodns.net.` zone will be served from the local database. #### Key Features @@ -49,7 +49,7 @@ It's as simple as launching the binary to start a DNS server listening on port 5 ````shell cd discodns/build/ -sudo ./bin/discodns -domain=discodns.net +sudo ./bin/discodns --domain=discodns.net --ns=8.8.8.8 --ns=8.8.4.4 ```` ### Try it out @@ -72,7 +72,7 @@ curl -L http://127.0.0.1:4001/v2/keys/net/discodns/.A/foobar -XPUT -d value="10. discodns.net. 0 IN A 10.1.1.1 ```` -#### Authority +### Authority If you're not familiar with the DNS specification, to support correct DNS Delegation using `NS` records, each top level domain needs to have it's own `SOA` record (stands for Start Of Authority) to asset it's authority. Since discodns can support multiple authoritative domains, it's up to you to enter this `SOA` record for each domain you use. Here's an example of creating this record for `discodns.net.`. diff --git a/main.go b/main.go index 0089019..f51be56 100644 --- a/main.go +++ b/main.go @@ -2,69 +2,71 @@ package main import ( "github.com/coreos/go-etcd/etcd" - "github.com/miekg/dns" + "github.com/jessevdk/go-flags" "runtime" "os/signal" "os" - "strings" "log" - "flag" "time" ) var ( logger = log.New(os.Stderr, "[discodns] ", log.Ldate|log.Ltime) log_debug = false + + // Define all of the command line arguments + Options struct { + ListenAddress string `short:"l" long:"listen" description:"Listen IP address" default:"0.0.0.0"` + ListenPort int `short:"p" long:"port" description:"Port to listen on" default:"53"` + EtcdHosts []string `short:"e" long:"etcd" description:"host:port for etcd hosts" default:"127.0.0.1:4001"` + Nameservers []string `short:"n" long:"ns" description:"Upstream nameservers for forwarding"` + Timeout string `short:"t" long:"ns-timeout" description:"Default forwarding timeout" default:"1s"` + Domain []string `short:"d" long:"domain" description:"Domain for this server to be authoritative over"` + Debug bool `short:"v" long:"debug" description:"Enable debug logging"` + } ) func main() { - var addr = flag.String("listen", "0.0.0.0", "Listen IP address") - var port = flag.Int("port", 53, "Port to listen on") - var hosts = flag.String("etcd", "0.0.0.0:4001", "List of etcd hosts (comma separated)") - var nameservers = flag.String("ns", "8.8.8.8,8.8.4.4", "Fallback nameservers (comma separated)") - var timeout = flag.String("ns-timeout", "1s", "Default nameserver timeout") - var domain = flag.String("domain", "discodns.local", "Constrain discodns to a domain") - var authority = flag.String("authority", "dns.discodns.local", "Authoritative DNS server hostname") - var debug = flag.Bool("debug", false, "Enable debug logging") - - flag.Parse() + _, err := flags.ParseArgs(&Options, os.Args[1:]) + if err != nil { + os.Exit(1) + } - if *debug { + if Options.Debug { log_debug = true debugMsg("Debug mode enabled") } - // Parse the list of nameservers - ns := strings.Split(*nameservers, ",") - // Parse the timeout string - nsTimeout, err := time.ParseDuration(*timeout) + nsTimeout, err := time.ParseDuration(Options.Timeout) if err != nil { - logger.Fatalf("Failed to parse duration '%s'", timeout) + logger.Fatalf("Failed to parse duration '%s'", Options.Timeout) } - // Create an ETCD client - etcd := etcd.NewClient(strings.Split(*hosts, ",")) + if len(Options.Nameservers) == 0 { + logger.Fatalf("Upstream nameservers are required with -n") + } + // Create an ETCD client + etcd := etcd.NewClient(Options.EtcdHosts) if !etcd.SyncCluster() { logger.Printf("[WARNING] Failed to connect to etcd cluster at launch time") } // Start up the DNS resolver server server := &Server{ - addr: *addr, - port: *port, + addr: Options.ListenAddress, + port: Options.ListenPort, etcd: etcd, rTimeout: nsTimeout, wTimeout: nsTimeout, - domain: dns.Fqdn(*domain), - authority: dns.Fqdn(*authority), - ns: ns} + domains: Options.Domain, + ns: Options.Nameservers} server.Run() - logger.Printf("Listening on %s:%d\n", *addr, *port) + logger.Printf("Listening on %s:%d\n", Options.ListenAddress, Options.ListenPort) sig := make(chan os.Signal) signal.Notify(sig, os.Interrupt) diff --git a/resolver.go b/resolver.go index 66a9f71..ea8b03e 100644 --- a/resolver.go +++ b/resolver.go @@ -14,8 +14,7 @@ import ( type Resolver struct { etcd *etcd.Client dns *dns.Client - domain string - authority string + domains []string nameservers []string rTimeout time.Duration } @@ -46,20 +45,37 @@ func (r *Resolver) GetFromStorage(key string) (nodes []*etcd.Node, err error) { return } -// Authority returns a dns.RR describing this authority (SOA) -func (r *Resolver) Authority() []dns.RR { - domain := dns.Fqdn(r.domain) - authority := &dns.SOA{Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 86400}, - Ns: dns.Fqdn(r.authority), - Mbox: domain, - Serial: uint32(time.Now().Truncate(time.Hour).Unix()), - Refresh: 10000, - Retry: 2400, - Expire: 604800, - Minttl: 60, +// Authority returns a dns.RR describing the know authority for the given +// domain. It will recurse up the domain structure to find an SOA record that +// matches. +func (r *Resolver) Authority(domain string) []dns.RR { + // tree := strings.Split(domain, ".") + // for i, _ := range tree { + // subdomain := strings.Join(tree[i:], ".") + // answers, _ := r.LookupAnswersForType(subdomain, dns.TypeSOA) + + // if len(answers) > 0 { + // for _, answer := range answers { + // answer.(*dns.SOA).Serial = uint32(time.Now().Truncate(time.Hour).Unix()) + // } + + // return answers + // } + // } + + return make([]dns.RR, 0) +} + +// IsAuthoritative will return true if this discodns server is authoritative +// for the given domain name. +func (r *Resolver) IsAuthoritative(name string) (bool) { + for _, domain := range r.domains { + if strings.HasSuffix(strings.ToLower(name), domain) { + return true + } } - return []dns.RR{authority} + return false } // Lookup responds to DNS messages of type Query, with a dns message containing Answers. @@ -84,6 +100,10 @@ func (r *Resolver) Lookup(req *dns.Msg) (msg *dns.Msg) { authorities := make(chan dns.RR) errors := make(chan error) + if q.Qclass == dns.ClassINET { + r.AnswerQuestion(req.RecursionDesired, answers, errors, authorities, q, &wait) + } + // Spawn a goroutine to close the channel as soon as all of the things // are done. This allows us to ensure we'll wait for all workers to finish // but allows us to collect up answers concurrently. @@ -96,10 +116,6 @@ func (r *Resolver) Lookup(req *dns.Msg) (msg *dns.Msg) { close(errors) }() - if q.Qclass == dns.ClassINET { - r.AnswerQuestion(req.RecursionDesired, answers, errors, authorities, q, &wait) - } - // Collect up all of the answers and any errors done := 0 for done < 3 { @@ -128,13 +144,12 @@ func (r *Resolver) Lookup(req *dns.Msg) (msg *dns.Msg) { // If we've not found any answers if len(msg.Answer) == 0 { - msg = nil - // msg.SetRcode(req, dns.RcodeNameError) + msg.SetRcode(req, dns.RcodeNameError) - // // If the domain query was within our authority, we need to send our SOA record - // if strings.HasSuffix(strings.ToLower(q.Name), r.domain) { - // msg.Ns = r.Authority() - // } + // If the domain query was within our authority, we need to send our SOA record + if r.IsAuthoritative(q.Name) { + msg.Ns = r.Authority(q.Name) + } } return @@ -236,7 +251,8 @@ forever: // has been completed. func (r *Resolver) AnswerQuestion(recurse bool, answers chan dns.RR, errors chan error, authorities chan dns.RR, q dns.Question, wg *sync.WaitGroup) { - if strings.HasSuffix(strings.ToLower(q.Name), r.domain) { + debugMsg("Answering question ", q) + if r.IsAuthoritative(q.Name) { if q.Qtype == dns.TypeANY { wg.Add(len(converters)) @@ -277,6 +293,7 @@ func (r *Resolver) AnswerQuestion(recurse bool, answers chan dns.RR, errors chan } else if recurse && q.Qtype != dns.TypeCNAME && q.Qtype != dns.TypeNS { // Check for any aliases + debugMsg("Looking for CNAME records at " + q.Name) cnames, err := r.LookupAnswersForType(q.Name, dns.TypeCNAME) if err != nil { errors <- err @@ -361,7 +378,6 @@ func (r *Resolver) AnswerQuestion(recurse bool, answers chan dns.RR, errors chan func (r *Resolver) LookupAnswersForType(name string, rrType uint16) (answers []dns.RR, err error) { name = strings.ToLower(name) - typeStr := dns.TypeToString[rrType] nodes, err := r.GetFromStorage(nameToKey(name, "/." + typeStr)) diff --git a/server.go b/server.go index db4bf69..7b41054 100644 --- a/server.go +++ b/server.go @@ -12,8 +12,7 @@ type Server struct { port int etcd *etcd.Client ns []string - domain string - authority string + domains []string rTimeout time.Duration wTimeout time.Duration } @@ -51,9 +50,8 @@ func (s *Server) Run() { DialTimeout: s.rTimeout, ReadTimeout: s.rTimeout, WriteTimeout: s.wTimeout}, - domain: s.domain, + domains: s.domains, nameservers: s.ns, - authority: s.authority, rTimeout: s.rTimeout, } } @@ -64,8 +62,6 @@ func (s *Server) Run() { udpHandler := dns.NewServeMux() tcpHandler := dns.NewServeMux() - // TODO(tarnfeld): Perhaps we could move up resolution of "." to here and - // specifically only call our handler for s.domain? tcpHandler.HandleFunc(".", tcpDNShandler.Handle) udpHandler.HandleFunc(".", udpDNShandler.Handle)