diff --git a/client.go b/client.go index 92a5b07bbb..8f02190e6f 100644 --- a/client.go +++ b/client.go @@ -163,6 +163,11 @@ type Client struct { // User-Agent header to be excluded from the Request. NoDefaultUserAgentHeader bool + // CookieJar stores cookies allowing user to handle cookies easily. + // + // If CookieJar is nil no cookie will be collected. + CookieJar *CookieJar + // Callback for establishing new connections to hosts. // // Default Dial is used if not set. @@ -410,6 +415,7 @@ func (c *Client) Do(req *Request, resp *Response) error { Name: c.Name, NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader, Dial: c.Dial, + CookieJar: c.CookieJar, DialDualStack: c.DialDualStack, IsTLS: isTLS, TLSConfig: c.TLSConfig, @@ -525,6 +531,10 @@ type HostClient struct { // Default Dial is used if not set. Dial DialFunc + // CookieJar stores cookies. If CookieJar is nil + // no cookie will be collected. + CookieJar *CookieJar + // Attempt to connect to both ipv4 and ipv6 host addresses // if set to true. // @@ -1101,8 +1111,16 @@ func (c *HostClient) do(req *Request, resp *Response) (bool, error) { resp = AcquireResponse() } + if c.CookieJar != nil { + c.CookieJar.dumpTo(req) + } + ok, err := c.doNonNilReqResp(req, resp) + if c.CookieJar != nil { + c.CookieJar.getFrom(req.Host(), req.URI().Path(), resp) + } + if nilResp { ReleaseResponse(resp) } diff --git a/client_test.go b/client_test.go index 90d1c63128..c55d9a8057 100644 --- a/client_test.go +++ b/client_test.go @@ -50,6 +50,69 @@ func TestClientPostArgs(t *testing.T) { } } +func TestClientCookieJar(t *testing.T) { + ln := fasthttputil.NewInmemoryListener() + key, value := "cometo", "thefasthttpcon" + key2, value2 := "hello", "world" + key3, value3 := "coo", "kie" + s := &Server{ + Handler: func(ctx *RequestCtx) { + cookie := AcquireCookie() + cookie.SetKey(key) + cookie.SetValue(value) + ctx.Response.Header.SetCookie(cookie) + + cookie2 := AcquireCookie() + cookie2.SetKey(key2) + cookie2.SetValue(value2) + cookie2.SetExpire(CookieExpireDelete) + ctx.Response.Header.SetCookie(cookie2) + + cookie3:= AcquireCookie() + cookie3.SetKey(key3) + cookie3.SetValue(value3) + cookie3.SetPath("/hello/") + ctx.Response.Header.SetCookie(cookie3) + }, + } + go s.Serve(ln) + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + CookieJar: &CookieJar{}, + } + req := AcquireRequest() + res := AcquireResponse() + req.SetRequestURI("http://fasthttp.con/hello/world") + err := c.Do(req, res) + if err != nil { + t.Fatal(err) + } + + uri := AcquireURI() + uri.SetHost("fasthttp.con") + + cs := c.CookieJar.Get(uri) + if len(cs) != 2 { + t.Fatalf("Unexpected len: %d. Expected: %d", len(cs), 2) + } + + if k := string(cs[0].Key()); k != key { + t.Fatalf("Unexpected key: %s <> %s", k, key) + } + if v := string(cs[0].Value()); v != value { + t.Fatalf("Unexpected value: %s <> %s", v, value) + } + + if k := string(cs[1].Key()); k != key3 { + t.Fatalf("Unexpected key: %s <> %s", k, key3) + } + if v := string(cs[1].Value()); v != value3 { + t.Fatalf("Unexpected value: %s <> %s", v, value3) + } +} + func TestClientRedirectSameSchema(t *testing.T) { listenHTTPS1 := testClientRedirectListener(t, true) diff --git a/cookie.go b/cookie.go index 1d9861d9ea..8137643c24 100644 --- a/cookie.go +++ b/cookie.go @@ -21,6 +21,7 @@ var ( // CookieSameSite is an enum for the mode in which the SameSite flag should be set for the given cookie. // See https://tools.ietf.org/html/draft-ietf-httpbis-cookie-same-site-00 for details. type CookieSameSite int + const ( // CookieSameSiteDisabled removes the SameSite flag CookieSameSiteDisabled CookieSameSite = iota diff --git a/cookiejar.go b/cookiejar.go new file mode 100644 index 0000000000..6646912b9d --- /dev/null +++ b/cookiejar.go @@ -0,0 +1,204 @@ +package fasthttp + +import ( + "bytes" + "net" + "sync" + "time" +) + +// CookieJar manages cookie storage +type CookieJar struct { + m sync.Mutex + hostCookies map[string][]*Cookie +} + +// Get returns the cookies stored from a specific domain. +// +// If there were no cookies related with host returned slice will be nil. +// +// CookieJar keeps a copy of the cookies, so the returned cookies can be released safely. +func (cj *CookieJar) Get(uri *URI) (cookies []*Cookie) { + if uri != nil { + cookies = cj.get(uri.Host(), uri.Path()) + } + return +} + +func (cj *CookieJar) get(host, path []byte) (rcs []*Cookie) { + if cj.hostCookies == nil { + return + } + + var ( + err error + cookies []*Cookie + hostStr = b2s(host) + ) + // port must not be included. + hostStr, _, err = net.SplitHostPort(hostStr) + if err != nil { + hostStr = b2s(host) + } + // get cookies deleting expired ones + cookies = cj.getCookies(hostStr) + + rcs = make([]*Cookie, 0, len(cookies)) + for i := 0; i < len(cookies); i++ { + cookie := cookies[i] + if len(path) > 1 && len(cookie.path) > 1 && !bytes.HasPrefix(cookie.Path(), path) { + continue + } + rcs = append(rcs, cookie) + } + + return +} + +// getCookies returns a cookie slice releasing expired cookies +func (cj *CookieJar) getCookies(hostStr string) (cookies []*Cookie) { + cj.m.Lock() + defer cj.m.Unlock() + + cookies = cj.hostCookies[hostStr] + var ( + t = time.Now() + n = len(cookies) + ) + for i := 0; i < len(cookies); i++ { + c := cookies[i] + if !c.Expire().Equal(CookieExpireUnlimited) && c.Expire().Before(t) { // cookie expired + cookies = append(cookies[:i], cookies[i+1:]...) + ReleaseCookie(c) + i-- + } + } + // has any cookie been deleted? + if n > len(cookies) { + cj.hostCookies[hostStr] = cookies + } + return +} + +// Set sets cookies for a specific host. +// +// The host is get from uri.Host(). +// +// If the cookie key already exists it will be replaced by the new cookie value. +// +// CookieJar keeps a copy of the cookies, so the parsed cookies can be released safely. +func (cj *CookieJar) Set(uri *URI, cookies ...*Cookie) { + if uri != nil { + cj.set(uri.Host(), cookies...) + } +} + +// SetByHost sets cookies for a specific host. +// +// If the cookie key already exists it will be replaced by the new cookie value. +// +// CookieJar keeps a copy of the cookies, so the parsed cookies can be released safely. +func (cj *CookieJar) SetByHost(host []byte, cookies ...*Cookie) { + cj.set(host, cookies...) +} + +func (cj *CookieJar) set(host []byte, cookies ...*Cookie) { + hostStr := b2s(host) + + cj.m.Lock() + defer cj.m.Unlock() + + if cj.hostCookies == nil { + cj.hostCookies = make(map[string][]*Cookie) + } + hcs, ok := cj.hostCookies[hostStr] + if !ok { + // If the key does not exists in the map then + // we must make a copy for the key. + hostStr = string(host) + } + for _, cookie := range cookies { + c := searchCookieByKeyAndPath(cookie.Key(), cookie.Path(), hcs) + if c == nil { + c = AcquireCookie() + hcs = append(hcs, c) + } + c.CopyTo(cookie) + } + cj.hostCookies[hostStr] = hcs +} + +// SetKeyValue sets a cookie by key and value for a specific host. +// +// This function prevents extra allocations by making repeated cookies +// not being duplicated. +func (cj *CookieJar) SetKeyValue(host, key, value string) { + cj.SetKeyValueBytes(host, s2b(key), s2b(value)) +} + +// SetKeyValueBytes sets a cookie by key and value for a specific host. +// +// This function prevents extra allocations by making repeated cookies +// not being duplicated. +func (cj *CookieJar) SetKeyValueBytes(host string, key, value []byte) { + cj.setKeyValue(host, key, value) +} + +func (cj *CookieJar) setKeyValue(host string, key, value []byte) { + c := AcquireCookie() + c.SetKeyBytes(key) + c.SetValueBytes(value) + cj.set(s2b(host), c) +} + +func (cj *CookieJar) dumpTo(req *Request) { + uri := req.URI() + cookies := cj.get(uri.Host(), uri.Path()) + for _, cookie := range cookies { + req.Header.SetCookieBytesKV(cookie.Key(), cookie.Value()) + } +} + +func (cj *CookieJar) getFrom(host, path []byte, res *Response) { + hostStr := b2s(host) + + cj.m.Lock() + defer cj.m.Unlock() + + if cj.hostCookies == nil { + cj.hostCookies = make(map[string][]*Cookie) + } + cookies, ok := cj.hostCookies[hostStr] + if !ok { + // If the key does not exists in the map then + // we must make a copy for the key. + hostStr = string(host) + } + t := time.Now() + res.Header.VisitAllCookie(func(key, value []byte) { + created := false + c := searchCookieByKeyAndPath(key, path, cookies) + if c == nil { + c, created = AcquireCookie(), true + } + c.ParseBytes(value) + if c.Expire().Equal(CookieExpireUnlimited) || c.Expire().After(t) { + cookies = append(cookies, c) + } else if created { + ReleaseCookie(c) + } + }) + cj.hostCookies[hostStr] = cookies +} + +func searchCookieByKeyAndPath(key, path []byte, cookies []*Cookie) (cookie *Cookie) { + for _, c := range cookies { + if bytes.Equal(key, c.Key()) { + if len(path) <= 1 || bytes.HasPrefix(c.Path(), path) { + cookie = c + break + } + } + } + return +} diff --git a/cookiejar_test.go b/cookiejar_test.go new file mode 100644 index 0000000000..6b57c042da --- /dev/null +++ b/cookiejar_test.go @@ -0,0 +1,236 @@ +package fasthttp + +import ( + "bytes" + "testing" + "time" +) + +func checkKeyValue(t *testing.T, cj *CookieJar, cookie *Cookie, uri *URI, n int) { + cs := cj.Get(uri) + if len(cs) < n { + t.Fatalf("Unexpected cookie length: %d. Expected %d", len(cs), n) + } + c := cs[n-1] + if c == nil { + t.Fatal("got a nil cookie") + } + if string(c.Key()) != string(cookie.Key()) { + t.Fatalf("key mismatch: %s <> %s", c.Key(), cookie.Key()) + } + if string(c.Value()) != string(cookie.Value()) { + t.Fatalf("value mismatch: %s <> %s", c.Value(), cookie.Value()) + } +} + +func TestCookieJarGet(t *testing.T) { + url := []byte("http://fasthttp.com/") + url1 := []byte("http://fasthttp.com/make") + url11 := []byte("http://fasthttp.com/hola") + url2 := []byte("http://fasthttp.com/make/fasthttp") + url3 := []byte("http://fasthttp.com/make/fasthttp/great") + prefix := []byte("/") + prefix1 := []byte("/make") + prefix2 := []byte("/make/fasthttp") + prefix3 := []byte("/make/fasthttp/great") + cj := &CookieJar{} + + c1 := &Cookie{} + c1.SetKey("k") + c1.SetValue("v") + c1.SetPath("/make/") + + c2 := &Cookie{} + c2.SetKey("kk") + c2.SetValue("vv") + c2.SetPath("/make/fasthttp") + + c3 := &Cookie{} + c3.SetKey("kkk") + c3.SetValue("vvv") + c3.SetPath("/make/fasthttp/great") + + uri := AcquireURI() + uri.Parse(nil, url) + + uri1 := AcquireURI() + uri1.Parse(nil, url1) + + uri11 := AcquireURI() + uri11.Parse(nil, url11) + + uri2 := AcquireURI() + uri2.Parse(nil, url2) + + uri3 := AcquireURI() + uri3.Parse(nil, url3) + + cj.Get(uri1) + cj.Get(uri11) + cj.Get(uri2) + cj.Get(uri3) + + cj.Set(uri1, c1, c2, c3) + + cookies := cj.Get(uri1) + if len(cookies) != 3 { + t.Fatalf("Unexpected len. Expected %d. Got %d", 3, len(cookies)) + } + for _, cookie := range cookies { + if !bytes.HasPrefix(cookie.Path(), prefix1) { + t.Fatalf("prefix mismatch: %s<>%s", cookie.Path(), prefix1) + } + } + + cookies = cj.Get(uri11) + if len(cookies) != 0 { + t.Fatalf("Unexpected len. Expected %d. Got %d", 0, len(cookies)) + } + + cookies = cj.Get(uri2) + if len(cookies) != 2 { + t.Fatalf("Unexpected len. Expected %d. Got %d", 2, len(cookies)) + } + for _, cookie := range cookies { + if !bytes.HasPrefix(cookie.Path(), prefix2) { + t.Fatalf("prefix mismatch: %s<>%s", cookie.Path(), prefix2) + } + } + + cookies = cj.Get(uri3) + if len(cookies) != 1 { + t.Fatalf("Unexpected len. Expected %d. Got %d: %v", 1, len(cookies), cookies) + } + for _, cookie := range cookies { + if !bytes.HasPrefix(cookie.Path(), prefix3) { + t.Fatalf("prefix mismatch: %s<>%s", cookie.Path(), prefix3) + } + } + + cookies = cj.Get(uri) + if len(cookies) != 3 { + t.Fatalf("Unexpected len. Expected %d. Got %d", 3, len(cookies)) + } + for _, cookie := range cookies { + if !bytes.HasPrefix(cookie.Path(), prefix) { + t.Fatalf("prefix mismatch: %s<>%s", cookie.Path(), prefix) + } + } +} + +func TestCookieJarGetExpired(t *testing.T) { + url1 := []byte("http://fasthttp.com/make/") + uri1 := AcquireURI() + uri1.Parse(nil, url1) + + c1 := &Cookie{} + c1.SetKey("k") + c1.SetValue("v") + c1.SetExpire(time.Now().Add(-time.Hour)) + + cj := &CookieJar{} + cj.Set(uri1, c1) + + cookies := cj.Get(uri1) + if len(cookies) != 0 { + t.Fatalf("unexpected cookie get result. Expected %d. Got %d", 0, len(cookies)) + } +} + +func TestCookieJarSet(t *testing.T) { + url := []byte("http://fasthttp.com/hello/world") + cj := &CookieJar{} + + cookie := &Cookie{} + cookie.SetKey("k") + cookie.SetValue("v") + + uri := AcquireURI() + uri.Parse(nil, url) + + cj.Set(uri, cookie) + checkKeyValue(t, cj, cookie, uri, 1) +} + +func TestCookieJarSetRepeatedCookieKeys(t *testing.T) { + host := "fast.http" + cj := &CookieJar{} + + uri := AcquireURI() + uri.SetHost(host) + + cookie := &Cookie{} + cookie.SetKey("k") + cookie.SetValue("v") + + cookie2 := &Cookie{} + cookie2.SetKey("k") + cookie2.SetValue("v2") + + cookie3 := &Cookie{} + cookie3.SetKey("key") + cookie3.SetValue("value") + + cj.Set(uri, cookie, cookie2, cookie3) + + cookies := cj.Get(uri) + if len(cookies) != 2 { + t.Fatalf("error getting cookies. Expected %d. Got %d", 2, len(cookies)) + } + if cookies[0] == cookie2 { + t.Fatalf("Unexpected cookie (%s)", cookies[0]) + } + if !bytes.Equal(cookies[0].Value(), cookie2.Value()) { + t.Fatalf("Unexpected cookie value. Expected %s. Got %s", cookies[0].Value(), cookie2.Value()) + } +} + +func TestCookieJarSetKeyValue(t *testing.T) { + host := "fast.http" + cj := &CookieJar{} + + uri := AcquireURI() + uri.SetHost(host) + + cj.SetKeyValue(host, "k", "v") + cj.SetKeyValue(host, "key", "value") + cj.SetKeyValue(host, "k", "vv") + cj.SetKeyValue(host, "key", "value2") + + cookies := cj.Get(uri) + if len(cookies) != 2 { + t.Fatalf("error getting cookies. Expected %d. Got %d: %v", 2, len(cookies), cookies) + } +} + +func TestCookieJarGetFromResponse(t *testing.T) { + res := AcquireResponse() + host := []byte("fast.http") + uri := AcquireURI() + uri.SetHostBytes(host) + + c := &Cookie{} + c.SetKey("key") + c.SetValue("val") + + c2 := &Cookie{} + c2.SetKey("k") + c2.SetValue("v") + + c3 := &Cookie{} + c3.SetKey("kk") + c3.SetValue("vv") + + res.Header.SetStatusCode(200) + res.Header.SetCookie(c) + res.Header.SetCookie(c2) + res.Header.SetCookie(c3) + + cj := &CookieJar{} + cj.getFrom(host, nil, res) + + cookies := cj.Get(uri) + if len(cookies) != 3 { + t.Fatalf("error cookies length. Expected %d. Got %d", 3, len(cookies)) + } +} diff --git a/header.go b/header.go index 3176f94600..3e4b90c42b 100644 --- a/header.go +++ b/header.go @@ -1253,9 +1253,8 @@ func (h *RequestHeader) peek(key []byte) []byte { case "Cookie": if h.cookiesCollected { return appendRequestCookieBytes(nil, h.cookies) - } else { - return peekArgBytes(h.h, key) } + return peekArgBytes(h.h, key) default: return peekArgBytes(h.h, key) }