Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for non buffered body server responses. #1657

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions bytesconv.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,23 @@ func appendQuotedPath(dst, src []byte) []byte {
}
return dst
}

// countHexDigits returns the number of hex digits required to represent n when using writeHexInt
func countHexDigits(n int) int {
if n < 0 {
// developer sanity-check
panic("BUG: int must be positive")
}

if n == 0 {
return 1
}

count := 0
for n > 0 {
n = n >> 4
count++
}

return count
}
3 changes: 3 additions & 0 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ type Response struct {
raddr net.Addr
// Local TCPAddr from concurrently net.Conn
laddr net.Addr

headersWritten bool
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this state needs to be kept here, you can just add a headersWritten bool to UnbufferedWriterHttp1.

}

// SetHost sets host for the request.
Expand Down Expand Up @@ -1122,6 +1124,7 @@ func (resp *Response) Reset() {
resp.laddr = nil
resp.ImmediateHeaderFlush = false
resp.StreamBody = false
resp.headersWritten = false
}

func (resp *Response) resetSkipHeader() {
Expand Down
104 changes: 104 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,47 @@ type RequestCtx struct {
hijackHandler HijackHandler
hijackNoResponse bool
formValueFunc FormValueFunc

disableBuffering bool // disables buffered response body
writer *bufio.Writer // used to send response in non-buffered mode
bytesSent int // number of bytes sent to client (in non-buffered mode)
bodyChunkStarted bool // true if the response body chunk has been started
bodyLastChunkSent bool // true if the last chunk of the response body has been sent
}

// DisableBuffering modifies fasthttp to disable body buffering for this request.
// This is useful for requests that return large data or stream data.
//
// When buffering is disabled you must:
// 1. Set response status and header values before writing body
// 2. Set ContentLength is optional. If not set, the server will use chunked encoding.
// 3. Write body data using methods like ctx.Write or io.Copy(ctx,src), etc.
// 4. Optionally call CloseResponse to finalize the response.
//
// CLosing the response will finalize the response and send the last chunk.
// If the handler does not finish the response, it will be called automatically after handler returns.
// Closing the response will also set BytesSent with the correct number of total bytes sent.
func (ctx *RequestCtx) DisableBuffering() {
ctx.disableBuffering = true
}

// CloseResponse finalizes non-buffered response dispatch.
// This method must be called after performing non-buffered responses
// If the handler does not finish the response, it will be called automatically
// after the handler function returns.
func (ctx *RequestCtx) CloseResponse() {
if !ctx.disableBuffering || !ctx.bodyChunkStarted || ctx.bodyLastChunkSent {
return
}
if ctx.writer != nil {
// finalize chunks
if ctx.bodyChunkStarted && ctx.Response.Header.IsHTTP11() && !ctx.bodyLastChunkSent {
_, _ = ctx.writer.Write([]byte("0\r\n\r\n"))
_ = ctx.writer.Flush()
ctx.bytesSent += 5
}
ctx.bodyLastChunkSent = true
}
}

// HijackHandler must process the hijacked connection c.
Expand Down Expand Up @@ -822,6 +863,12 @@ func (ctx *RequestCtx) reset() {

ctx.hijackHandler = nil
ctx.hijackNoResponse = false

ctx.writer = nil
ctx.disableBuffering = false
ctx.bytesSent = 0
ctx.bodyChunkStarted = false
ctx.bodyLastChunkSent = false
}

type firstByteReader struct {
Expand Down Expand Up @@ -1443,10 +1490,58 @@ func (ctx *RequestCtx) NotFound() {

// Write writes p into response body.
func (ctx *RequestCtx) Write(p []byte) (int, error) {
if ctx.disableBuffering {
return ctx.writeDirect(p)
}

ctx.Response.AppendBody(p)
return len(p), nil
}

// writeDirect writes p to underlying connection bypassing any buffering.
func (ctx *RequestCtx) writeDirect(p []byte) (int, error) {
// Non buffered response
if ctx.writer == nil {
ctx.writer = acquireWriter(ctx)
}

// Write headers if not written yet
if !ctx.Response.headersWritten {
if ctx.Response.Header.contentLength == 0 && ctx.Response.Header.IsHTTP11() {
ctx.Response.Header.SetContentLength(-1) // means Transfer-Encoding = chunked
}
h := ctx.Response.Header.Header()
n, err := ctx.writer.Write(h)
if err != nil {
return 0, err
}
ctx.bytesSent += n
ctx.Response.headersWritten = true
}

// Write body. In chunks if content length is not set.
if ctx.Response.Header.contentLength == -1 && ctx.Response.Header.IsHTTP11() {
ctx.bodyChunkStarted = true
err := writeChunk(ctx.writer, p)
if err != nil {
return 0, err
}
ctx.bytesSent += len(p) + 4 + countHexDigits(len(p))
return len(p), nil
}

n, err := ctx.writer.Write(p)
ctx.bytesSent += n

return n, err
}

// BytesSent returns the number of bytes sent to the client after non buffered operation.
// Includes headers and body length.
func (ctx *RequestCtx) BytesSent() int {
return ctx.bytesSent
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add this function?

Copy link
Author

@pablolagos pablolagos Nov 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After streaming a response, it can be necessary to know the number of bytes sent. That number is known only if we a serving local resources, but unknown if we are proxying from external sources. Bytes sent can represent a cost related to data-transfer. It can be useful for logging and other analysis.

We could return that value in ctx.CloseReponse()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now I would remove this. This isn't supported with normal responses either so it's weird to only add it for this case now.


// WriteString appends s to response body.
func (ctx *RequestCtx) WriteString(s string) (int, error) {
ctx.Response.AppendBodyString(s)
Expand Down Expand Up @@ -2359,6 +2454,15 @@ func (s *Server) serveConn(c net.Conn) (err error) {
s.Handler(ctx)
}

if ctx.disableBuffering {
ctx.CloseResponse()
if ctx.writer != nil {
releaseWriter(s, ctx.writer)
ctx.writer = nil
}
break
}

timeoutResponse = ctx.timeoutResponse
if timeoutResponse != nil {
// Acquire a new ctx because the old one will still be in use by the timeout out handler.
Expand Down
50 changes: 50 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4237,6 +4237,56 @@ func TestServerChunkedResponse(t *testing.T) {
}
}

func TestServerDisableBuffering(t *testing.T) {
t.Parallel()

expectedBody := bytes.Repeat([]byte("a"), 4096)

s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.DisableBuffering()
ctx.SetStatusCode(StatusOK)
ctx.SetContentType("text/html; charset=utf-8")
reader := bytes.NewReader(expectedBody)
_, err := io.Copy(ctx, reader)
if err != nil {
t.Fatalf("Unexpected error when copying body: %v", err)
}
pablolagos marked this conversation as resolved.
Show resolved Hide resolved
if len(ctx.Response.Body()) > 0 {
t.Fatalf("Body was populated when buffer was disabled")
}
},
}

ln := fasthttputil.NewInmemoryListener()

go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %v", err)
}
}()

conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if _, err = conn.Write([]byte("GET /index.html HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatalf("unexpected error: %v", err)
}
br := bufio.NewReader(conn)

var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("Unexpected error when reading response: %v", err)
}
if resp.Header.ContentLength() != -1 {
t.Fatalf("Unexpected Content-Length %d. Expected %d", resp.Header.ContentLength(), -1)
}
if !bytes.Equal(resp.Body(), expectedBody) {
t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), "foobar")
}
}

func verifyResponse(t *testing.T, r *bufio.Reader, expectedStatusCode int, expectedContentType, expectedBody string) *Response {
var resp Response
if err := resp.Read(r); err != nil {
Expand Down