diff --git a/agents/ack_agent.go b/agents/ack_agent.go index d930c4d..42da844 100644 --- a/agents/ack_agent.go +++ b/agents/ack_agent.go @@ -16,7 +16,9 @@ type AckAgent struct { func (a *AckAgent) Run(conn *Connection) { a.BaseAgent.Init("AckAgent", conn.OriginalDestinationCID) a.FrameProducingAgent.InitFPA(conn) - a.DisableAcks = make(map[PNSpace]bool) + if a.DisableAcks == nil { + a.DisableAcks = make(map[PNSpace]bool) + } a.TotalDataAcked = make(map[PNSpace]uint64) incomingPackets := conn.IncomingPackets.RegisterNewChan(1000) diff --git a/agents/base_agent.go b/agents/base_agent.go index c3e2f8b..aea78ff 100644 --- a/agents/base_agent.go +++ b/agents/base_agent.go @@ -21,26 +21,27 @@ type Agent interface { Init(name string, SCID ConnectionID) Run(conn *Connection) Stop() + Restart() Join() } type RequestFrameArgs struct { availableSpace int - level EncryptionLevel - number PacketNumber + level EncryptionLevel + number PacketNumber } // All agents should embed this structure type BaseAgent struct { name string Logger *log.Logger - close chan bool + close chan bool // true if should restart, false otherwise closed chan bool } func (a *BaseAgent) Name() string { return a.name } -// All agents that embed this structure must call InitFPA() as soon as their Run() method is called +// All agents that embed this structure must call Init() as soon as their Run() method is called func (a *BaseAgent) Init(name string, ODCID ConnectionID) { a.name = name a.Logger = log.New(os.Stderr, fmt.Sprintf("[%s/%s] ", hex.EncodeToString(ODCID), a.Name()), log.Lshortfile) @@ -57,6 +58,14 @@ func (a *BaseAgent) Stop() { } } +func (a *BaseAgent) Restart() { + select { + case <-a.close: + default: + a.close <- true + } +} + func (a *BaseAgent) Join() { <-a.closed } @@ -105,6 +114,25 @@ func AttachAgentsToConnection(conn *Connection, agents ...Agent) *ConnectionAgen c.Add(a) } + go func() { + for { + select { + case <-conn.ConnectionRestart: + conn.Logger.Printf("Restarting all agents\n") + for _, a := range agents { + a.Restart() + a.Join() + a.Run(conn) + } + conn.ConnectionRestart = make(chan bool, 1) + close(conn.ConnectionRestarted) + conn.Logger.Printf("Restarting all agents: done\n") + case <-conn.ConnectionClosed: + return + } + } + }() + return &c } @@ -132,7 +160,7 @@ func (c *ConnectionAgents) GetFrameProducingAgents() []FrameProducer { return agents } -func (c *ConnectionAgents) Stop(names... string) { +func (c *ConnectionAgents) Stop(names ...string) { for _, n := range names { c.Get(n).Stop() c.Get(n).Join() diff --git a/agents/closing_agent.go b/agents/closing_agent.go index 91e245e..cc3a1a4 100644 --- a/agents/closing_agent.go +++ b/agents/closing_agent.go @@ -15,7 +15,7 @@ type ClosingAgent struct { IdleTimeout *time.Timer } -func (a *ClosingAgent) Run(conn *Connection) { +func (a *ClosingAgent) Run(conn *Connection) { // TODO: Observe incoming CC and AC a.Init("ClosingAgent", conn.OriginalDestinationCID) a.conn = conn a.IdleDuration = time.Duration(a.conn.TLSTPHandler.IdleTimeout) * time.Second @@ -27,7 +27,6 @@ func (a *ClosingAgent) Run(conn *Connection) { go func() { defer a.Logger.Println("Agent terminated") defer close(a.closed) - defer close(a.conn.ConnectionClosed) for { select { @@ -37,6 +36,7 @@ func (a *ClosingAgent) Run(conn *Connection) { switch p := i.(type) { case Framer: if a.closing && (p.Contains(ConnectionCloseType) || p.Contains(ApplicationCloseType)) { + close(a.conn.ConnectionClosed) return } } @@ -46,8 +46,12 @@ func (a *ClosingAgent) Run(conn *Connection) { case <-a.IdleTimeout.C: a.closing = true a.Logger.Printf("Idle timeout of %v reached, closing\n", a.IdleDuration.String()) + close(a.conn.ConnectionClosed) return - case <-a.close: + case shouldRestart := <-a.close: + if !shouldRestart { + close(a.conn.ConnectionClosed) + } return } } diff --git a/agents/handshake_agent.go b/agents/handshake_agent.go index fdaea73..4927f63 100644 --- a/agents/handshake_agent.go +++ b/agents/handshake_agent.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" . "github.com/QUIC-Tracker/quic-tracker" + "github.com/davecgh/go-spew/spew" "strings" ) @@ -38,13 +39,9 @@ func (a *HandshakeAgent) Run(conn *Connection) { a.sendInitial = make(chan bool, 1) incPackets := conn.IncomingPackets.RegisterNewChan(1000) - outPackets := conn.OutgoingPackets.RegisterNewChan(1000) - tlsStatus := a.TLSAgent.TLSStatus.RegisterNewChan(10) - - socketStatus := make(chan interface{}, 10) - a.SocketAgent.SocketStatus.Register(socketStatus) + socketStatus := a.SocketAgent.SocketStatus.RegisterNewChan(10) firstInitialReceived := false tlsCompleted := false @@ -66,20 +63,17 @@ func (a *HandshakeAgent) Run(conn *Connection) { a.HandshakeStatus.Submit(HandshakeStatus{false, p, err}) return } - conn.SendPacket(conn.GetInitialPacket(), EncryptionLevelInitial) + close(conn.ConnectionRestart) case *RetryPacket: if !a.IgnoreRetry && bytes.Equal(conn.DestinationCID, p.OriginalDestinationCID) && !a.receivedRetry { // TODO: Check the original_connection_id TP too a.receivedRetry = true conn.DestinationCID = p.Header().(*LongHeader).SourceCID - tlsTP := conn.TLSTPHandler - conn.TransitionTo(QuicVersion, QuicALPNToken) + tlsTP, alpn := conn.TLSTPHandler, conn.ALPN + spew.Dump(tlsTP) + conn.TransitionTo(QuicVersion, alpn) conn.TLSTPHandler = tlsTP conn.Token = p.RetryToken - a.TLSAgent.Stop() - a.TLSAgent.Join() - a.TLSAgent.Run(conn) - a.TLSAgent.TLSStatus.Register(tlsStatus) - conn.SendPacket(conn.GetInitialPacket(), EncryptionLevelInitial) + close(conn.ConnectionRestart) } case Framer: if p.Contains(ConnectionCloseType) || p.Contains(ApplicationCloseType) { @@ -122,6 +116,13 @@ func (a *HandshakeAgent) Run(conn *Connection) { if strings.Contains(i.(error).Error(), "connection refused") { a.HandshakeStatus.Submit(HandshakeStatus{false, nil , i.(error)}) } + case <-conn.ConnectionRestarted: + incPackets = conn.IncomingPackets.RegisterNewChan(1000) + outPackets = conn.OutgoingPackets.RegisterNewChan(1000) + tlsStatus = a.TLSAgent.TLSStatus.RegisterNewChan(10) + socketStatus = a.SocketAgent.SocketStatus.RegisterNewChan(10) + conn.ConnectionRestarted = make(chan bool, 1) + conn.SendPacket(conn.GetInitialPacket(), EncryptionLevelInitial) case <-a.close: return } diff --git a/agents/http_agent.go b/agents/http_agent.go index b3913fd..39b461f 100644 --- a/agents/http_agent.go +++ b/agents/http_agent.go @@ -212,10 +212,10 @@ func (a *HTTPAgent) SendRequest(path, method, authority string, headers map[stri } hdrs := []HTTPHeader{ - HTTPHeader{":method", method}, - HTTPHeader{":scheme", "https"}, - HTTPHeader{":authority", authority}, - HTTPHeader{":path", path}, + {":method", method}, + {":scheme", "https"}, + {":authority", authority}, + {":path", path}, } for k, v := range headers { hdrs = append(hdrs, HTTPHeader{k, v}) diff --git a/agents/socket_agent.go b/agents/socket_agent.go index 0e937d4..a25940c 100644 --- a/agents/socket_agent.go +++ b/agents/socket_agent.go @@ -45,6 +45,11 @@ func (a *SocketAgent) Run(conn *Connection) { if err != nil { a.Logger.Println("Closing UDP socket because of error", err.Error()) + select { + case <-recChan: + return + default: + } close(recChan) a.SocketStatus.Submit(err) break @@ -65,6 +70,11 @@ func (a *SocketAgent) Run(conn *Connection) { a.Logger.Printf("Received %d bytes from UDP socket\n", i) payload := make([]byte, i) copy(payload, recBuf[:i]) + select { + case <-recChan: + return + default: + } recChan <- payload } }() @@ -80,9 +90,16 @@ func (a *SocketAgent) Run(conn *Connection) { } conn.IncomingPayloads.Submit(p) - case <-a.close: - conn.UdpConnection.Close() - // TODO: Close this agent gracefully + case shouldRestart := <-a.close: + if !shouldRestart { + conn.UdpConnection.Close() + } + select { + case <-recChan: + return + default: + } + close(recChan) return } } diff --git a/common.go b/common.go index cd595c1..e774835 100644 --- a/common.go +++ b/common.go @@ -49,7 +49,7 @@ const ( MinimumInitialLengthv6 = 1232 MaxUDPPayloadSize = 65507 MaximumVersion = 0xff000012 - MinimumVersion = 0xff000012 + MinimumVersion = 0xff000011 ) // errors diff --git a/connection.go b/connection.go index a0e84f3..ad1e5d9 100644 --- a/connection.go +++ b/connection.go @@ -47,6 +47,8 @@ type Connection struct { StreamInput Broadcaster //type: StreamInput ConnectionClosed chan bool + ConnectionRestart chan bool // Triggered when receiving a Retry or a VN packet + ConnectionRestarted chan bool OriginalDestinationCID ConnectionID SourceCID ConnectionID @@ -290,6 +292,8 @@ func NewConnection(serverName string, version uint32, ALPN string, SCID []byte, c.FrameQueue = NewBroadcaster(1000) c.TransportParameters = NewBroadcaster(10) c.ConnectionClosed = make(chan bool, 1) + c.ConnectionRestart = make(chan bool, 1) + c.ConnectionRestarted = make(chan bool, 1) c.PreparePacket = NewBroadcaster(1000) c.StreamInput = NewBroadcaster(1000)