diff --git a/lib/http/client.go b/lib/http/client.go index f7b05c51..283ecbf8 100644 --- a/lib/http/client.go +++ b/lib/http/client.go @@ -534,8 +534,9 @@ func (c *Client) Do(req *Request) (resp *Response, err error) { for { // For all but the first request, create the next // request hop and replace req. + loc := req.URL.String() if len(reqs) > 0 { - loc := resp.Header.Get("Location") + loc = resp.Header.Get("Location") if loc == "" { return nil, uerr(fmt.Errorf("%d response missing Location header", resp.StatusCode)) } @@ -571,14 +572,6 @@ func (c *Client) Do(req *Request) (resp *Response, err error) { if ref := refererForURL(reqs[len(reqs)-1].URL, req.URL); ref != "" { req.Header.Set("Referer", ref) } - err = c.checkRedirect(req, resp, reqs) - - // Sentinel error to let users select the - // previous response, without closing its - // body. See Issue 10069. - if err == ErrUseLastResponse { - return resp, nil - } // Close the previous response's body. But // read at least some of the body so if it's @@ -590,16 +583,6 @@ func (c *Client) Do(req *Request) (resp *Response, err error) { io.CopyN(ioutil.Discard, resp.Body, maxBodySlurpSize) } resp.Body.Close() - - if err != nil { - // Special case for Go 1 compatibility: return both the response - // and an error if the CheckRedirect function failed. - // See https://golang.org/issue/3795 - // The resp.Body has already been closed. - ue := uerr(err) - ue.(*url.Error).URL = loc - return resp, ue - } } reqs = append(reqs, req) @@ -614,6 +597,22 @@ func (c *Client) Do(req *Request) (resp *Response, err error) { } return nil, uerr(err) } + err = c.checkRedirect(req, resp, reqs) + + // Sentinel error to let users select the + // previous response, without closing its + // body. See Issue 10069. + if err == ErrUseLastResponse { + return resp, nil + } + if err != nil { + // Special case for Go 1 compatibility: return both the response + // and an error if the CheckRedirect function failed. + // See https://golang.org/issue/3795 + ue := uerr(err) + ue.(*url.Error).URL = loc + return resp, err + } var shouldRedirect bool redirectMethod, shouldRedirect, includeBody = redirectBehavior(req.Method, resp, reqs[0]) diff --git a/modules/http/scanner.go b/modules/http/scanner.go index a124b306..6ebf5ed8 100644 --- a/modules/http/scanner.go +++ b/modules/http/scanner.go @@ -37,6 +37,8 @@ var ( // ErrTooManyRedirects is returned when the number of HTTP redirects exceeds // MaxRedirects. ErrTooManyRedirects = errors.New("Too many redirects") + + ErrDoNotRedirect = errors.New("No redirects configured") ) // Flags holds the command-line configuration for the HTTP scan module. @@ -388,6 +390,13 @@ func redirectsToLocalhost(host string) bool { // the redirectToLocalhost and MaxRedirects config func (scan *scan) getCheckRedirect() func(*http.Request, *http.Response, []*http.Request) error { return func(req *http.Request, res *http.Response, via []*http.Request) error { + if scan.scanner.config.MaxRedirects == 0 { + return ErrDoNotRedirect + } + //len-1 because otherwise we'll return a failure on 1 redirect when we specify only 1 redirect. I.e. we are 0 + if len(via)-1 > scan.scanner.config.MaxRedirects { + return ErrTooManyRedirects + } if !scan.scanner.config.FollowLocalhostRedirects && redirectsToLocalhost(req.URL.Hostname()) { return ErrRedirLocalhost } @@ -413,10 +422,6 @@ func (scan *scan) getCheckRedirect() func(*http.Request, *http.Response, []*http } } - if len(via) > scan.scanner.config.MaxRedirects { - return ErrTooManyRedirects - } - return nil } } @@ -529,6 +534,8 @@ func (scan *scan) Grab() *zgrab2.ScanError { } if err != nil { switch err { + case ErrDoNotRedirect: + break case ErrRedirLocalhost: break case ErrTooManyRedirects: