Skip to content

Commit

Permalink
add: DisableIPv4, DisableIPv6
Browse files Browse the repository at this point in the history
  • Loading branch information
CorentinB committed Sep 24, 2024
1 parent 7c0b7a6 commit 9495206
Show file tree
Hide file tree
Showing 3 changed files with 266 additions and 22 deletions.
4 changes: 3 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type HTTPClientSettings struct {
FullOnDisk bool
VerifyCerts bool
RandomLocalIP bool
DisableIPv4 bool
DisableIPv6 bool
}

type CustomHTTPClient struct {
Expand Down Expand Up @@ -147,7 +149,7 @@ func NewWARCWritingHTTPClient(HTTPClientSettings HTTPClientSettings) (httpClient
httpClient.TLSHandshakeTimeout = HTTPClientSettings.TLSHandshakeTimeout

// Configure custom dialer / transport
customDialer, err := newCustomDialer(httpClient, HTTPClientSettings.Proxy, HTTPClientSettings.DialTimeout)
customDialer, err := newCustomDialer(httpClient, HTTPClientSettings.Proxy, HTTPClientSettings.DialTimeout, HTTPClientSettings.DisableIPv4, HTTPClientSettings.DisableIPv6)
if err != nil {
return nil, err
}
Expand Down
207 changes: 207 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package warc

import (
"context"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -1267,6 +1268,212 @@ func TestHTTPClientWithZStandardDictionary(t *testing.T) {
}
}

// TestHTTPClientConnectionClosedEarly issues a request with a client timeout
// (1s) that is shorter than the delay the test server waits before answering
// (2s). It then checks that no files are left in the client's temp directory
// and hash-checks every entry found in the output directory.
func TestHTTPClientConnectionClosedEarly(t *testing.T) {
	var (
		rotatorSettings = NewRotatorSettings()
		errWg           sync.WaitGroup
		err             error
	)

	// init test HTTP endpoint
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fileBytes, err := os.ReadFile(path.Join("testdata", "2MB.jpg"))
		if err != nil {
			// The handler runs on a server goroutine. t.Fatal may only be
			// called from the goroutine running the test, so record the
			// failure with t.Errorf and end the request instead.
			t.Errorf("reading test fixture: %v", err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		// Answer later than the client timeout configured below.
		time.Sleep(2 * time.Second)

		w.WriteHeader(http.StatusOK)
		w.Write(fileBytes)
	}))
	defer server.Close()

	rotatorSettings.OutputDirectory, err = os.MkdirTemp("", "warc-tests-")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(rotatorSettings.OutputDirectory)

	rotatorSettings.Prefix = "TESTEARLYCLOSE"
	rotatorSettings.Compression = "GZIP"

	tempDir := filepath.Join(rotatorSettings.OutputDirectory, "temp")

	// init the HTTP client responsible for recording HTTP(s) requests / responses
	httpClient, err := NewWARCWritingHTTPClient(HTTPClientSettings{
		RotatorSettings: rotatorSettings,
		TempDir:         tempDir,
	})
	if err != nil {
		t.Fatalf("Unable to init WARC writing HTTP client: %s", err)
	}

	httpClient.Timeout = 1 * time.Second

	// NOTE(review): errWg is never waited on in this test. Confirm whether
	// httpClient.Close() closes ErrChan before adding an errWg.Wait().
	errWg.Add(1)
	go func() {
		defer errWg.Done()
		for err := range httpClient.ErrChan {
			t.Errorf("Error writing to WARC: %s", err.Err.Error())
		}
	}()

	req, err := http.NewRequest("GET", server.URL, nil)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := httpClient.Do(req)
	if err != nil {
		// resp is nil when Do fails, so it must not be touched here.
		t.Error(err)
	} else {
		defer resp.Body.Close()
		io.Copy(io.Discard, resp.Body)
	}

	httpClient.Close()

	// Check if there are any files left in the temp directory
	files, err := filepath.Glob(filepath.Join(tempDir, "*"))
	if err != nil {
		t.Fatal(err)
	}

	if len(files) > 0 {
		t.Fatalf("Expected no files in temp directory, got %d", len(files))
	}

	files, err = filepath.Glob(filepath.Join(rotatorSettings.OutputDirectory, "*"))
	if err != nil {
		t.Fatal(err)
	}

	for _, file := range files {
		testFileSingleHashCheck(t, file, "sha1:UIRWL5DFIPQ4MX3D3GFHM2HCVU3TZ6I3", []string{"26872"}, 1)
	}
}

// setupIPv4Server starts an HTTP server on an ephemeral port of the IPv4
// loopback address (127.0.0.1) that answers every request with the body
// "IPv4 Server".
//
// It returns the server's base URL and a cleanup function. The cleanup
// function shuts the server down, waits for the serving goroutine to exit and
// reports any unexpected shutdown or serve error on t. The test is failed
// immediately if the listener cannot be created.
func setupIPv4Server(t *testing.T) (string, func()) {
	t.Helper()

	listener, err := net.Listen("tcp4", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to set up IPv4 server: %v", err)
	}

	server := &http.Server{
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("IPv4 Server"))
		}),
	}

	// serveErr is written by the serving goroutine before done is closed and
	// only read after done has been received from.
	var serveErr error
	done := make(chan struct{})
	go func() {
		defer close(done)
		serveErr = server.Serve(listener)
	}()

	return "http://" + listener.Addr().String(), func() {
		if err := server.Shutdown(context.Background()); err != nil {
			t.Errorf("Failed to shut down IPv4 server: %v", err)
		}

		// Wait for Serve to return so no goroutine outlives the cleanup.
		<-done
		if serveErr != nil && serveErr != http.ErrServerClosed {
			t.Errorf("IPv4 server stopped unexpectedly: %v", serveErr)
		}
	}
}

// setupIPv6Server starts an HTTP server on an ephemeral port of the IPv6
// loopback address ([::1]) that answers every request with the body
// "IPv6 Server".
//
// It returns the server's base URL and a cleanup function. The cleanup
// function shuts the server down, waits for the serving goroutine to exit and
// reports any unexpected shutdown or serve error on t. The test is failed
// immediately if the listener cannot be created.
func setupIPv6Server(t *testing.T) (string, func()) {
	t.Helper()

	listener, err := net.Listen("tcp6", "[::1]:0")
	if err != nil {
		t.Fatalf("Failed to set up IPv6 server: %v", err)
	}

	server := &http.Server{
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("IPv6 Server"))
		}),
	}

	// serveErr is written by the serving goroutine before done is closed and
	// only read after done has been received from.
	var serveErr error
	done := make(chan struct{})
	go func() {
		defer close(done)
		serveErr = server.Serve(listener)
	}()

	return "http://" + listener.Addr().String(), func() {
		if err := server.Shutdown(context.Background()); err != nil {
			t.Errorf("Failed to shut down IPv6 server: %v", err)
		}

		// Wait for Serve to return so no goroutine outlives the cleanup.
		<-done
		if serveErr != nil && serveErr != http.ErrServerClosed {
			t.Errorf("IPv6 server stopped unexpectedly: %v", serveErr)
		}
	}
}

// TestHTTPClientWithIPv4Disabled builds a WARC-writing HTTP client with
// DisableIPv4 set and checks that a request to a server listening on the IPv4
// loopback fails, while a request to a server listening on the IPv6 loopback
// succeeds and returns the expected body.
func TestHTTPClientWithIPv4Disabled(t *testing.T) {
	defer goleak.VerifyNone(t)

	ipv4URL, closeIPv4 := setupIPv4Server(t)
	defer closeIPv4()

	ipv6URL, closeIPv6 := setupIPv6Server(t)
	defer closeIPv6()

	rotatorSettings := NewRotatorSettings()

	outputDir, err := os.MkdirTemp("", "warc-tests-")
	if err != nil {
		t.Fatalf("Unable to create temporary output directory: %s", err)
	}
	defer os.RemoveAll(outputDir)

	rotatorSettings.OutputDirectory = outputDir
	rotatorSettings.Prefix = "TESTIPv6Only"

	httpClient, err := NewWARCWritingHTTPClient(HTTPClientSettings{
		RotatorSettings: rotatorSettings,
		DisableIPv4:     true,
	})
	if err != nil {
		t.Fatalf("Unable to init WARC writing HTTP client: %s", err)
	}
	defer httpClient.Close()

	// Try IPv4 - should fail
	if resp, err := httpClient.Get(ipv4URL); err == nil {
		// Close the unexpected response so the connection is released.
		resp.Body.Close()
		t.Fatalf("Expected error when connecting to IPv4 server, but got none")
	}

	// Try IPv6 - should succeed
	resp, err := httpClient.Get(ipv6URL)
	if err != nil {
		t.Fatalf("Failed to connect to IPv6 server: %v", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("Failed to read response from IPv6 server: %v", err)
	}
	if string(body) != "IPv6 Server" {
		t.Fatalf("Unexpected response from IPv6 server: %s", string(body))
	}
}

// TestHTTPClientWithIPv6Disabled builds a WARC-writing HTTP client with
// DisableIPv6 set and checks that a request to a server listening on the IPv6
// loopback fails, while a request to a server listening on the IPv4 loopback
// succeeds and returns the expected body.
func TestHTTPClientWithIPv6Disabled(t *testing.T) {
	defer goleak.VerifyNone(t)

	ipv4URL, closeIPv4 := setupIPv4Server(t)
	defer closeIPv4()

	ipv6URL, closeIPv6 := setupIPv6Server(t)
	defer closeIPv6()

	rotatorSettings := NewRotatorSettings()

	outputDir, err := os.MkdirTemp("", "warc-tests-")
	if err != nil {
		t.Fatalf("Unable to create temporary output directory: %s", err)
	}
	defer os.RemoveAll(outputDir)

	rotatorSettings.OutputDirectory = outputDir
	rotatorSettings.Prefix = "TESTIPv4Only"

	httpClient, err := NewWARCWritingHTTPClient(HTTPClientSettings{
		RotatorSettings: rotatorSettings,
		DisableIPv6:     true,
	})
	if err != nil {
		t.Fatalf("Unable to init WARC writing HTTP client: %s", err)
	}
	defer httpClient.Close()

	// Try IPv6 - should fail
	if resp, err := httpClient.Get(ipv6URL); err == nil {
		// Close the unexpected response so the connection is released.
		resp.Body.Close()
		t.Fatalf("Expected error when connecting to IPv6 server, but got none")
	}

	// Try IPv4 - should succeed
	resp, err := httpClient.Get(ipv4URL)
	if err != nil {
		t.Fatalf("Failed to connect to IPv4 server: %v", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("Failed to read response from IPv4 server: %v", err)
	}
	if string(body) != "IPv4 Server" {
		t.Fatalf("Unexpected response from IPv4 server: %s", string(body))
	}
}

func BenchmarkConcurrentUnder2MB(b *testing.B) {
var (
rotatorSettings = NewRotatorSettings()
Expand Down
77 changes: 56 additions & 21 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,18 @@ import (
// customDialer carries the dialing configuration used by CustomDial and
// CustomDialTLS. It embeds net.Dialer, so fields such as Timeout and LocalAddr
// are set directly on the customDialer.
type customDialer struct {
	// proxyDialer, when non-nil, is used to establish connections instead of
	// dialing directly.
	proxyDialer proxy.Dialer
	// client is the owning HTTP client; the dial methods read its
	// randomLocalIP setting.
	client *CustomHTTPClient
	// disableIPv4 and disableIPv6 are consulted by getNetworkType to narrow
	// or reject the network type requested for a dial.
	disableIPv4 bool
	disableIPv6 bool
	net.Dialer
}

func newCustomDialer(httpClient *CustomHTTPClient, proxyURL string, DialTimeout time.Duration) (d *customDialer, err error) {
func newCustomDialer(httpClient *CustomHTTPClient, proxyURL string, DialTimeout time.Duration, disableIPv4, disableIPv6 bool) (d *customDialer, err error) {
d = new(customDialer)

d.Timeout = DialTimeout
d.client = httpClient
d.disableIPv4 = disableIPv4
d.disableIPv6 = disableIPv6

if proxyURL != "" {
u, err := url.Parse(proxyURL)
Expand Down Expand Up @@ -87,59 +91,65 @@ func (d *customDialer) wrapConnection(c net.Conn, scheme string) net.Conn {
}

func (d *customDialer) CustomDial(network, address string) (conn net.Conn, err error) {
// Determine the network based on IPv4/IPv6 settings
network = d.getNetworkType(network)
if network == "" {
return nil, errors.New("no supported network type available")
}

if d.proxyDialer != nil {
conn, err = d.proxyDialer.Dial(network, address)
if err != nil {
return nil, err
}
} else {
if d.client.randomLocalIP {
localAddr := getLocalAddr(network, address)
if localAddr != nil {
if network == "tcp" {
if network == "tcp" || network == "tcp4" || network == "tcp6" {
d.LocalAddr = localAddr.(*net.TCPAddr)
} else if network == "udp" {
} else if network == "udp" || network == "udp4" || network == "udp6" {
d.LocalAddr = localAddr.(*net.UDPAddr)
}
}
}

conn, err = d.Dial(network, address)
if err != nil {
return nil, err
}
}

if err != nil {
return nil, err
}

return d.wrapConnection(conn, "http"), nil
}

func (d *customDialer) CustomDialTLS(network, address string) (net.Conn, error) {
var (
plainConn net.Conn
err error
)
// Determine the network based on IPv4/IPv6 settings
network = d.getNetworkType(network)
if network == "" {
return nil, errors.New("no supported network type available")
}

var plainConn net.Conn
var err error

if d.proxyDialer != nil {
plainConn, err = d.proxyDialer.Dial(network, address)
if err != nil {
return nil, err
}
} else {
if d.client.randomLocalIP {
localAddr := getLocalAddr(network, address)
if localAddr != nil {
if network == "tcp" {
if network == "tcp" || network == "tcp4" || network == "tcp6" {
d.LocalAddr = localAddr.(*net.TCPAddr)
} else if network == "udp" {
} else if network == "udp" || network == "udp4" || network == "udp6" {
d.LocalAddr = localAddr.(*net.UDPAddr)
}
}
}

plainConn, err = d.Dial(network, address)
if err != nil {
return nil, err
}
}

if err != nil {
return nil, err
}

cfg := new(tls.Config)
Expand Down Expand Up @@ -171,6 +181,31 @@ func (d *customDialer) CustomDialTLS(network, address string) (net.Conn, error)
return d.wrapConnection(tlsConn, "https"), nil
}

// getNetworkType maps the requested network onto the network that should
// actually be dialed given the dialer's disableIPv4/disableIPv6 settings.
//
// For the generic "tcp" and "udp" networks, the "4" or "6" suffix is appended
// when exactly one IP family is disabled; with both families enabled or both
// disabled the network is returned unchanged. For family-specific networks
// the input is returned as-is unless that family is disabled. An empty string
// means the network must not be dialed, either because its family is disabled
// or because the network type is not supported.
func (d *customDialer) getNetworkType(network string) string {
	v4Off, v6Off := d.disableIPv4, d.disableIPv6

	switch network {
	case "tcp4", "udp4":
		if v4Off {
			return ""
		}
	case "tcp6", "udp6":
		if v6Off {
			return ""
		}
	case "tcp", "udp":
		switch {
		case v4Off && !v6Off:
			return network + "6"
		case v6Off && !v4Off:
			return network + "4"
		}
		// Both enabled or both disabled: fall through to the default network.
	default:
		// Unsupported network type
		return ""
	}

	return network
}

func (d *customDialer) writeWARCFromConnection(reqPipe, respPipe *io.PipeReader, scheme string, conn net.Conn) {
defer d.client.WaitGroup.Done()

Expand Down

0 comments on commit 9495206

Please sign in to comment.