Skip to content

Commit

Permalink
support mtproxy
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Dec 21, 2023
1 parent f48d164 commit 630fbf1
Show file tree
Hide file tree
Showing 11 changed files with 526 additions and 289 deletions.
286 changes: 11 additions & 275 deletions client/client.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,15 @@
package client

import (
"context"
"crypto/tls"
"fmt"
"log"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"

"github.com/wwqgtxx/wstunnel/common"
"github.com/wwqgtxx/wstunnel/config"
"github.com/wwqgtxx/wstunnel/fallback"
"github.com/wwqgtxx/wstunnel/listener"
"github.com/wwqgtxx/wstunnel/proxy"
"github.com/wwqgtxx/wstunnel/tunnel"
"github.com/wwqgtxx/wstunnel/utils"
)

const DialTimeout = 8 * time.Second
Expand Down Expand Up @@ -75,197 +65,23 @@ func (c *client) GetServerWSPath() string {
return c.serverWSPath
}

type wsClientImpl struct {
header http.Header
wsUrl *url.URL
tlsConfig *tls.Config
dialer proxy.ContextDialer
ed uint32
proxy string
v2rayHttpUpgrade bool
}

type tcpClientImpl struct {
targetAddress string
dialer proxy.ContextDialer
proxy string
}

func (c *wsClientImpl) Target() string {
return c.wsUrl.String()
}

func (c *wsClientImpl) Proxy() string {
return c.proxy
}

func (c *wsClientImpl) Handle(tcp net.Conn) {
defer tcp.Close()
log.Println("Incoming --> ", tcp.RemoteAddr(), " --> ", c.Target(), c.Proxy())
edBuf, err := utils.PrepareXray0rtt(tcp, c.ed)
if err != nil {
log.Println(err)
return
}
conn, err := c.Dial(edBuf, nil)
if err != nil {
log.Println(err)
return
}
defer conn.Close()
conn.TunnelTcp(tcp)
}

func (c *wsClientImpl) Dial(edBuf []byte, inHeader http.Header) (common.ClientConn, error) {
var header http.Header
if len(inHeader) > 0 {
// copy from inHeader
header = inHeader.Clone()
// don't use inHeader's `Host`
header.Del("Host")

// merge from c.header
for k, vs := range c.header {
header[k] = vs
}

// duplicate header is not allowed, remove
header.Del("Upgrade")
header.Del("Connection")
header.Del("Sec-Websocket-Key")
header.Del("Sec-Websocket-Version")
header.Del("Sec-Websocket-Extensions")
header.Del("Sec-WebSocket-Protocol")

// force use inHeader's `Sec-WebSocket-Protocol` for Xray's 0rtt ws
if secProtocol := inHeader.Get("Sec-WebSocket-Protocol"); len(secProtocol) > 0 {
if c.ed > 0 {
header.Set("Sec-WebSocket-Protocol", secProtocol)
edBuf = nil
} else {
edBuf, _ = utils.DecodeEd(secProtocol)
}
}
} else {
// copy from c.header
header = c.header.Clone()
if header == nil {
header = http.Header{}
}
}
if c.ed > 0 && len(edBuf) > 0 {
header.Set("Sec-WebSocket-Protocol", utils.EncodeEd(edBuf))
edBuf = nil
}

ctx, cancel := context.WithTimeout(context.Background(), DialTimeout)
defer cancel()
conn, respHeader, err := utils.ClientWebsocketDial(ctx, *c.wsUrl, header, c.dialer, c.tlsConfig, c.v2rayHttpUpgrade)
log.Println("Dial to", c.Target(), c.Proxy(), "with", header, "response", respHeader)
if err != nil {
return nil, err
}

if len(edBuf) > 0 {
_, err = conn.Write(edBuf)
if err != nil {
return nil, err
}
}
if wsConn, ok := conn.(*utils.WebsocketConn); ok {
return &wsClientConn{wsConn: wsConn}, err
} else {
return &tcpClientConn{tcp: conn}, err
}
}

type wsClientConn struct {
wsConn *utils.WebsocketConn
close sync.Once
}

func (c *wsClientConn) Close() {
c.close.Do(func() {
_ = c.wsConn.Close()
})
}

func (c *wsClientConn) TunnelTcp(tcp net.Conn) {
tunnel.Tunnel(tcp, c.wsConn)
}

func (c *wsClientConn) TunnelWs(wsConn *utils.WebsocketConn) {
if wsConn.ReaderReplaceable() == c.wsConn.ReaderReplaceable() {
// fastpath for direct tunnel underlying ws connection
tunnel.Tunnel(wsConn.Conn, c.wsConn.Conn)
} else {
tunnel.Tunnel(wsConn, c.wsConn)
}
}

func (c *tcpClientImpl) Target() string {
return c.targetAddress
}

func (c *tcpClientImpl) Proxy() string {
return c.proxy
}

func (c *tcpClientImpl) Handle(tcp net.Conn) {
defer tcp.Close()
log.Println("Incoming --> ", tcp.RemoteAddr(), " --> ", c.Target(), c.Proxy())
conn, err := c.Dial(nil, nil)
func BuildClient(clientConfig config.ClientConfig) {
_, port, err := net.SplitHostPort(clientConfig.BindAddress)
if err != nil {
log.Println(err)
return
}
defer conn.Close()
conn.TunnelTcp(tcp)
}

func (c *tcpClientImpl) Dial(edBuf []byte, inHeader http.Header) (common.ClientConn, error) {
ctx, cancel := context.WithTimeout(context.Background(), DialTimeout)
defer cancel()
tcp, err := c.dialer.DialContext(ctx, "tcp", c.Target())
if err == nil && len(edBuf) > 0 {
_, err = tcp.Write(edBuf)
if err != nil {
return nil, err
}
}
return &tcpClientConn{tcp: tcp}, err
}

type tcpClientConn struct {
tcp net.Conn
close sync.Once
}

func (c *tcpClientConn) Close() {
c.close.Do(func() {
_ = c.tcp.Close()
})
}

func (c *tcpClientConn) TunnelTcp(tcp net.Conn) {
tunnel.Tunnel(tcp, c.tcp)
}

func (c *tcpClientConn) TunnelWs(wsConn *utils.WebsocketConn) {
tunnel.Tunnel(c.tcp, wsConn)
}
serverWSPath := strings.ReplaceAll(clientConfig.ServerWSPath, "{port}", port)

func BuildClient(clientConfig config.ClientConfig) {
_, port, err := net.SplitHostPort(clientConfig.BindAddress)
clientImpl, err := NewClientImpl(clientConfig)
if err != nil {
log.Println(err)
return
}

serverWSPath := strings.ReplaceAll(clientConfig.ServerWSPath, "{port}", port)

c := &client{
ClientImpl: NewClientImpl(clientConfig),
ClientImpl: clientImpl,
serverWSPath: serverWSPath,
listenerConfig: listener.Config{
ListenerConfig: clientConfig.ListenerConfig,
Expand All @@ -277,97 +93,17 @@ func BuildClient(clientConfig config.ClientConfig) {
common.PortToClient[port] = c
}

func parseProxy(proxyString string) (proxyUrl *url.URL, proxyStr string) {
if len(proxyString) > 0 {
u, err := url.Parse(proxyString)
if err != nil {
log.Println(err)
}
proxyUrl = u

ru := *u
ru.User = nil
proxyStr = ru.String()
}
return
}

func getDialer(proxyUrl *url.URL) proxy.ContextDialer {
tcpDialer := &net.Dialer{}

proxyDialer := proxy.FromEnvironment()
if proxyUrl != nil {
dialer, err := proxy.FromURL(proxyUrl, tcpDialer)
if err != nil {
log.Println(err)
} else {
proxyDialer = dialer
}
}
if proxyDialer != proxy.Direct {
return proxy.NewContextDialer(proxyDialer)
} else {
return tcpDialer
}
}

func NewClientImpl(clientConfig config.ClientConfig) common.ClientImpl {
if len(clientConfig.TargetAddress) > 0 {
func NewClientImpl(clientConfig config.ClientConfig) (common.ClientImpl, error) {
switch {
case len(clientConfig.Mtp) > 0:
return NewMtproxyClientImpl(clientConfig)
case len(clientConfig.TargetAddress) > 0:
return NewTcpClientImpl(clientConfig)
} else {
default:
return NewWsClientImpl(clientConfig)
}
}

func NewTcpClientImpl(clientConfig config.ClientConfig) common.ClientImpl {
proxyUrl, proxyStr := parseProxy(clientConfig.Proxy)
dialer := getDialer(proxyUrl)

return &tcpClientImpl{
targetAddress: clientConfig.TargetAddress,
dialer: dialer,
proxy: proxyStr,
}
}

func NewWsClientImpl(clientConfig config.ClientConfig) common.ClientImpl {
proxyUrl, proxyStr := parseProxy(clientConfig.Proxy)
netDial := getDialer(proxyUrl)

header := http.Header{}
if len(clientConfig.WSHeaders) != 0 {
for key, value := range clientConfig.WSHeaders {
header.Add(key, value)
}
}
tlsConfig := &tls.Config{
ServerName: clientConfig.ServerName,
InsecureSkipVerify: clientConfig.SkipCertVerify,
}
var ed uint32
u, err := url.Parse(clientConfig.WSUrl)
if err != nil {
panic(fmt.Errorf("parse url %s error: %w", clientConfig.WSUrl, err))
}
if q := u.Query(); q.Get("ed") != "" {
Ed, _ := strconv.Atoi(q.Get("ed"))
ed = uint32(Ed)
q.Del("ed")
u.RawQuery = q.Encode()
//clientConfig.WSUrl = u.String()
}

return &wsClientImpl{
header: header,
wsUrl: u,
dialer: netDial,
tlsConfig: tlsConfig,
ed: ed,
proxy: proxyStr,
v2rayHttpUpgrade: clientConfig.V2rayHttpUpgrade,
}
}

func StartClients() {
for clientPort, client := range common.PortToClient {
if !strings.HasPrefix(client.Target(), "ws") {
Expand Down
Loading

0 comments on commit 630fbf1

Please sign in to comment.