diff --git a/internal/martian/proxy_handler.go b/internal/martian/proxy_handler.go index 0a509a9f..982e3802 100644 --- a/internal/martian/proxy_handler.go +++ b/internal/martian/proxy_handler.go @@ -236,7 +236,7 @@ func (p proxyHandler) tunnel(name string, rw http.ResponseWriter, req *http.Requ cc = []copier{ {"outbound " + name, crw, req.Body}, - {"inbound " + name, writeFlusher{rw, rc}, crw}, + {"inbound " + name, makeH2Writer(rw, rc, req), crw}, } default: err := fmt.Errorf("unsupported protocol version: %d", req.ProtoMajor) @@ -337,16 +337,25 @@ func (p proxyHandler) handleRequest(rw http.ResponseWriter, req *http.Request) { } } -type writeFlusher struct { - rw io.Writer - rc *http.ResponseController +type h2Writer struct { + w io.Writer + flush func() error + close func() error } -func (w writeFlusher) Write(p []byte) (n int, err error) { - n, err = w.rw.Write(p) +func makeH2Writer(rw http.ResponseWriter, rc *http.ResponseController, req *http.Request) *h2Writer { + return &h2Writer{ + w: rw, + flush: rc.Flush, + close: req.Body.Close, + } +} + +func (w h2Writer) Write(p []byte) (n int, err error) { + n, err = w.w.Write(p) if n > 0 { - if err := w.rc.Flush(); err != nil { + if err := w.flush(); err != nil { log.Errorf(context.TODO(), "got error while flushing response back to client: %v", err) } } @@ -354,10 +363,10 @@ func (w writeFlusher) Write(p []byte) (n int, err error) { return } -func (w writeFlusher) CloseWrite() error { - // This is a nop implementation of closeWriter. - // It avoids printing the error log "cannot close write side of inbound CONNECT tunnel". - return nil +func (w h2Writer) CloseWrite() error { + // Close request body to signal the end of the request. + // This results RST_STREAM frame with error code NO_ERROR to be sent to the other side. + return w.close() } func (p proxyHandler) writeErrorResponse(rw http.ResponseWriter, req *http.Request, err error) {