Skip to content

Commit

Permalink
ws client support v2ray-http-upgrade
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Dec 19, 2023
1 parent 8d74452 commit 2e32c6a
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 91 deletions.
70 changes: 41 additions & 29 deletions client/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"crypto/tls"
"fmt"
"log"
Expand Down Expand Up @@ -75,17 +76,18 @@ func (c *client) GetServerWSPath() string {
}

type wsClientImpl struct {
header http.Header
wsUrl *url.URL
tlsConfig *tls.Config
netDial proxy.NetDialerFunc
ed uint32
proxy string
header http.Header
wsUrl *url.URL
tlsConfig *tls.Config
dialer proxy.ContextDialer
ed uint32
proxy string
v2rayHttpUpgrade bool
}

type tcpClientImpl struct {
targetAddress string
netDial proxy.NetDialerFunc
dialer proxy.ContextDialer
proxy string
}

Expand Down Expand Up @@ -147,25 +149,34 @@ func (c *wsClientImpl) Dial(edBuf []byte, inHeader http.Header) (common.ClientCo
} else {
// copy from c.header
header = c.header.Clone()
if header == nil {
header = http.Header{}
}
}
if c.ed > 0 && len(edBuf) > 0 {
header.Set("Sec-WebSocket-Protocol", utils.EncodeEd(edBuf))
edBuf = nil
}

wsConn, header, err := utils.ClientWebsocketDial(*c.wsUrl, header, c.netDial, c.tlsConfig, DialTimeout)
log.Println("Dial to", c.Target(), c.Proxy(), "with", header)
ctx, cancel := context.WithTimeout(context.Background(), DialTimeout)
defer cancel()
conn, respHeader, err := utils.ClientWebsocketDial(ctx, *c.wsUrl, header, c.dialer, c.tlsConfig, c.v2rayHttpUpgrade)
log.Println("Dial to", c.Target(), c.Proxy(), "with", header, "response", respHeader)
if err != nil {
return nil, err
}

if len(edBuf) > 0 {
_, err = wsConn.Write(edBuf)
_, err = conn.Write(edBuf)
if err != nil {
return nil, err
}
}
return &wsClientConn{wsConn: wsConn}, err
if wsConn, ok := conn.(*utils.WebsocketConn); ok {
return &wsClientConn{wsConn: wsConn}, err
} else {
return &tcpClientConn{tcp: conn}, err
}
}

type wsClientConn struct {
Expand Down Expand Up @@ -213,7 +224,9 @@ func (c *tcpClientImpl) Handle(tcp net.Conn) {
}

func (c *tcpClientImpl) Dial(edBuf []byte, inHeader http.Header) (common.ClientConn, error) {
tcp, err := c.netDial("tcp", c.Target())
ctx, cancel := context.WithTimeout(context.Background(), DialTimeout)
defer cancel()
tcp, err := c.dialer.DialContext(ctx, "tcp", c.Target())
if err == nil && len(edBuf) > 0 {
_, err = tcp.Write(edBuf)
if err != nil {
Expand Down Expand Up @@ -279,25 +292,23 @@ func parseProxy(proxyString string) (proxyUrl *url.URL, proxyStr string) {
return
}

func getNetDial(proxyUrl *url.URL) (netDial proxy.NetDialerFunc) {
tcpDialer := &net.Dialer{
Timeout: DialTimeout,
}
netDial = tcpDialer.Dial
func getDialer(proxyUrl *url.URL) proxy.ContextDialer {
tcpDialer := &net.Dialer{}

proxyDialer := proxy.FromEnvironment()
if proxyUrl != nil {
dialer, err := proxy.FromURL(proxyUrl, netDial)
dialer, err := proxy.FromURL(proxyUrl, tcpDialer)
if err != nil {
log.Println(err)
} else {
proxyDialer = dialer
}
}
if proxyDialer != proxy.Direct {
netDial = proxyDialer.Dial
return proxy.NewContextDialer(proxyDialer)
} else {
return tcpDialer
}
return
}

func NewClientImpl(clientConfig config.ClientConfig) common.ClientImpl {
Expand All @@ -310,18 +321,18 @@ func NewClientImpl(clientConfig config.ClientConfig) common.ClientImpl {

func NewTcpClientImpl(clientConfig config.ClientConfig) common.ClientImpl {
proxyUrl, proxyStr := parseProxy(clientConfig.Proxy)
netDial := getNetDial(proxyUrl)
dialer := getDialer(proxyUrl)

return &tcpClientImpl{
targetAddress: clientConfig.TargetAddress,
netDial: netDial,
dialer: dialer,
proxy: proxyStr,
}
}

func NewWsClientImpl(clientConfig config.ClientConfig) common.ClientImpl {
proxyUrl, proxyStr := parseProxy(clientConfig.Proxy)
netDial := getNetDial(proxyUrl)
netDial := getDialer(proxyUrl)

header := http.Header{}
if len(clientConfig.WSHeaders) != 0 {
Expand All @@ -347,12 +358,13 @@ func NewWsClientImpl(clientConfig config.ClientConfig) common.ClientImpl {
}

return &wsClientImpl{
header: header,
wsUrl: u,
netDial: netDial,
tlsConfig: tlsConfig,
ed: ed,
proxy: proxyStr,
header: header,
wsUrl: u,
dialer: netDial,
tlsConfig: tlsConfig,
ed: ed,
proxy: proxyStr,
v2rayHttpUpgrade: clientConfig.V2rayHttpUpgrade,
}
}

Expand Down
17 changes: 9 additions & 8 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ import (
)

type ClientConfig struct {
ListenerConfig `yaml:",inline"`
ProxyConfig `yaml:",inline"`
TargetAddress string `yaml:"target-address"`
WSUrl string `yaml:"ws-url"`
WSHeaders map[string]string `yaml:"ws-headers"`
SkipCertVerify bool `yaml:"skip-cert-verify"`
ServerName string `yaml:"servername"`
ServerWSPath string `yaml:"server-ws-path"`
ListenerConfig `yaml:",inline"`
ProxyConfig `yaml:",inline"`
TargetAddress string `yaml:"target-address"`
WSUrl string `yaml:"ws-url"`
WSHeaders map[string]string `yaml:"ws-headers"`
V2rayHttpUpgrade bool `yaml:"v2ray-http-upgrade"`
SkipCertVerify bool `yaml:"skip-cert-verify"`
ServerName string `yaml:"servername"`
ServerWSPath string `yaml:"server-ws-path"`
}

type ServerConfig struct {
Expand Down
31 changes: 31 additions & 0 deletions proxy/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package proxy

import (
"context"
"net"
)

// SetupContextForConn is a helper function that starts connection I/O interrupter goroutine.
func SetupContextForConn(ctx context.Context, conn net.Conn) (done func(*error)) {
var (
quit = make(chan struct{})
interrupt = make(chan error, 1)
)
go func() {
select {
case <-quit:
interrupt <- nil
case <-ctx.Done():
// Close the connection, discarding the error
_ = conn.Close()
interrupt <- ctx.Err()
}
}()
return func(inputErr *error) {
close(quit)
if ctxErr := <-interrupt; ctxErr != nil && inputErr != nil {
// Return context error to user.
inputErr = &ctxErr
}
}
}
58 changes: 41 additions & 17 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package proxy

import (
"bufio"
"context"
"encoding/base64"
"errors"
"net"
Expand All @@ -16,14 +17,26 @@ import (

func init() {
RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer Dialer) (Dialer, error) {
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
return &httpProxyDialer{proxyURL: proxyURL, forwardDialer: NewContextDialer(forwardDialer)}, nil
})
}

type NetDialerFunc func(network, addr string) (net.Conn, error)
func NewContextDialer(d Dialer) ContextDialer {
if xd, ok := d.(ContextDialer); ok {
return xd
}
return contextDialer{d}
}

type contextDialer struct {
Dialer
}

func (fn NetDialerFunc) Dial(network, addr string) (net.Conn, error) {
return fn(network, addr)
func (d contextDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if ctx.Done() != nil {
return dialContext(ctx, d, network, addr)
}
return d.Dial(network, addr)
}

func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
Expand All @@ -45,15 +58,19 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
}

type httpProxyDialer struct {
proxyURL *url.URL
forwardDial func(network, addr string) (net.Conn, error)
proxyURL *url.URL
forwardDialer ContextDialer
}

func (hpd *httpProxyDialer) Dial(network string, addr string) (conn net.Conn, err error) {
return hpd.DialContext(context.Background(), network, addr)
}

func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, addr string) (conn net.Conn, err error) {
hostPort, _ := hostPortNoPort(hpd.proxyURL)
conn, err := hpd.forwardDial(network, hostPort)
conn, err = hpd.forwardDialer.DialContext(ctx, network, hostPort)
if err != nil {
return nil, err
return
}

connectHeader := make(http.Header)
Expand All @@ -72,24 +89,31 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
Header: connectHeader,
}

if err := connectReq.Write(conn); err != nil {
conn.Close()
return nil, err
done := SetupContextForConn(ctx, conn)
defer done(&err)

if err = connectReq.Write(conn); err != nil {
_ = conn.Close()
conn = nil
return
}

// Read response. It's OK to use and discard buffered reader here becaue
// the remote server does not speak until spoken to.
br := bufio.NewReader(conn)
resp, err := http.ReadResponse(br, connectReq)
if err != nil {
conn.Close()
return nil, err
_ = conn.Close()
conn = nil
return
}

if resp.StatusCode != 200 {
conn.Close()
_ = conn.Close()
conn = nil
f := strings.SplitN(resp.Status, " ", 2)
return nil, errors.New(f[1])
err = errors.New(f[1])
return
}
return conn, nil
return
}
Loading

0 comments on commit 2e32c6a

Please sign in to comment.