diff --git a/resolver.go b/resolver.go index 555def0..689ab19 100644 --- a/resolver.go +++ b/resolver.go @@ -23,6 +23,12 @@ func (e ResolvError) Error() string { return errmsg } +type RResp struct { + msg *dns.Msg + nameserver string + rtt time.Duration +} + type Resolver struct { servers []string domain_server *suffixTreeNode @@ -126,7 +132,7 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error qname := req.Question[0].Name - res := make(chan *dns.Msg, 1) + res := make(chan *RResp, 1) var wg sync.WaitGroup L := func(nameserver string) { defer wg.Done() @@ -145,11 +151,10 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error if r.Rcode == dns.RcodeServerFailure { return } - } else { - logger.Debug("%s resolv on %s (%s) ttl: %v", UnFqdn(qname), nameserver, net, rtt) } + re := &RResp{r, nameserver, rtt} select { - case res <- r: + case res <- re: default: } } @@ -163,9 +168,9 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error go L(nameserver) // but exit early, if we have an answer select { - case r := <-res: - // logger.Debug("%s resolv on %s rtt: %v", UnFqdn(qname), nameserver, rtt) - return r, nil + case re := <-res: + logger.Debug("%s resolv on %s rtt: %v", UnFqdn(qname), re.nameserver, re.rtt) + return re.msg, nil case <-ticker.C: continue } @@ -173,9 +178,9 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error // wait for all the namservers to finish wg.Wait() select { - case r := <-res: - // logger.Debug("%s resolv on %s rtt: %v", UnFqdn(qname), nameserver, rtt) - return r, nil + case re := <-res: + logger.Debug("%s resolv on %s rtt: %v", UnFqdn(qname), re.nameserver, re.rtt) + return re.msg, nil default: return nil, ResolvError{qname, net, nameservers} } @@ -190,10 +195,12 @@ func (r *Resolver) Nameservers(qname string) []string { ns := []string{} if v, found := r.domain_server.search(queryKeys); found { - logger.Debug("found upstream: %v", v) + logger.Debug("%s be found in domain server list, upstream: %v", qname, v) server := v - nameserver := server + ":53" + nameserver := net.JoinHostPort(server, "53") ns = append(ns, nameserver) + //Ensure query the specific upstream nameserver in async Lookup() function. + return ns } for _, nameserver := range r.servers {