Skip to content

Commit

Permalink
support waitRead in windows
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Dec 21, 2023
1 parent 630fbf1 commit dfadc7d
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 12 deletions.
2 changes: 1 addition & 1 deletion tunnel/copy_stub.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//go:build !unix
//go:build !unix && !windows

package tunnel

Expand Down
98 changes: 98 additions & 0 deletions tunnel/copy_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package tunnel

import (
"io"
"log"
"syscall"
"unsafe"

"golang.org/x/sys/windows"
)

//go:linkname modws2_32 golang.org/x/sys/windows.modws2_32
var modws2_32 *windows.LazyDLL

var procrecv = modws2_32.NewProc("recv")

//go:linkname errnoErr golang.org/x/sys/windows.errnoErr
func errnoErr(e syscall.Errno) error

func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
r0, _, e1 := syscall.SyscallN(procrecv.Addr(), uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags))
n = int32(r0)
if n == -1 {
err = errnoErr(e1)
}
return
}

func syscallCopy(src io.Reader, srcRaw syscall.RawConn, dst io.Writer) (handed bool, written int64, err error) {
log.Printf("syscallCopy %T %T", src, dst)
handed = true
var sysErr error = nil
var buf []byte
getBuf := func() []byte {
if buf == nil {
buf = BufPool.Get().([]byte)
}
return buf
}
putBuf := func() {
if buf != nil {
BufPool.Put(buf)
buf = nil
}
}
for {
var rn int
var wn int
hasData := false
err = srcRaw.Read(func(fd uintptr) (done bool) {
if !hasData {
hasData = true
// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
// socket is readable if we return false. So the `recv` syscall will not block the system thread.
return false
}
n, err := recv(windows.Handle(fd), getBuf(), 0)
rn = int(n)
if n <= 0 {
putBuf()
}
switch {
case n == 0 && err == nil:
sysErr = io.EOF
case err == windows.WSAEWOULDBLOCK || err == syscall.EAGAIN || err == syscall.EWOULDBLOCK || err == syscall.EINTR:
return false
//sysErr = nil
default:
sysErr = err
}
hasData = false
return true
})
if err == nil {
err = sysErr
}
if err != nil {
if err == io.EOF {
err = nil
}
return
}
wn, err = dst.Write(buf[:rn])
putBuf()
written += int64(wn)
if rn != wn {
err = io.ErrShortWrite
return
}
if err != nil {
return
}
}
}
70 changes: 59 additions & 11 deletions udp/packet_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,76 @@ package udp
import (
"net"
"net/netip"
"strconv"
"syscall"

"golang.org/x/sys/windows"
)

type enhanceUDPConn struct {
*net.UDPConn
rawConn syscall.RawConn
}

func newEnhancePacketConn(udpConn *net.UDPConn) EnhancePacketConn {
return &enhanceUDPConn{UDPConn: udpConn}
c := &enhanceUDPConn{UDPConn: udpConn}
c.rawConn, _ = udpConn.SyscallConn()
return c
}

func (c *enhanceUDPConn) WaitReadFrom() (data []byte, put func(), addr netip.AddrPort, err error) {
readBuf := BufPool.Get().([]byte)
put = func() {
BufPool.Put(readBuf)
if c.rawConn == nil {
c.rawConn, _ = c.UDPConn.SyscallConn()
}
var readErr error
hasData := false
err = c.rawConn.Read(func(fd uintptr) (done bool) {
if !hasData {
hasData = true
// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
// socket is readable if we return false. So the `recvfrom` syscall will not block the system thread.
return false
}
readBuf := BufPool.Get().([]byte)
put = func() {
BufPool.Put(readBuf)
}
var readFrom windows.Sockaddr
var readN int
readN, readFrom, readErr = windows.Recvfrom(windows.Handle(fd), readBuf, 0)
if readN > 0 {
data = readBuf[:readN]
} else {
put()
put = nil
data = nil
}
if readErr == windows.WSAEWOULDBLOCK {
return false
}
if readFrom != nil {
switch from := readFrom.(type) {
case *windows.SockaddrInet4:
ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 4 bytes
addr = netip.AddrPortFrom(netip.AddrFrom4(ip), uint16(from.Port))
case *windows.SockaddrInet6:
ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 16 bytes
addr = netip.AddrPortFrom(netip.AddrFrom16(ip).WithZone(strconv.FormatInt(int64(from.ZoneId), 10)), uint16(from.Port))
}
}
// udp should not convert readN == 0 to io.EOF
//if readN == 0 {
// readErr = io.EOF
//}
hasData = false
return true
})
if err != nil {
return
}
var readN int
readN, addr, err = c.UDPConn.ReadFromUDPAddrPort(readBuf)
if readN > 0 {
data = readBuf[:readN]
} else {
put()
put = nil
if readErr != nil {
err = readErr
return
}
return
}

0 comments on commit dfadc7d

Please sign in to comment.