Skip to content

Commit

Permalink
Use channel instead of semaphore for locks
Browse files Browse the repository at this point in the history
  • Loading branch information
mtmk committed Jan 23, 2024
1 parent d607449 commit 7b87f22
Showing 1 changed file with 91 additions and 28 deletions.
119 changes: 91 additions & 28 deletions src/NATS.Client.Core/Commands/CommandWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Net.Sockets;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Threading.Channels;
using Microsoft.Extensions.Logging;
using NATS.Client.Core.Internal;

Expand All @@ -28,9 +29,9 @@ internal sealed class CommandWriter : IAsyncDisposable
private readonly Action<PingCommand> _enqueuePing;
private readonly NatsOpts _opts;
private readonly ProtocolWriter _protocolWriter;
private readonly SemaphoreSlim _semLock;
private readonly Task _readerLoopTask;
private readonly HeaderWriter _headerWriter;
private readonly Channel<int> _channelLock;
private ISocketConnection? _socketConnection;
private PipeReader? _pipeReader;
private PipeWriter? _pipeWriter;
Expand All @@ -45,7 +46,7 @@ public CommandWriter(ObjectPool pool, NatsOpts opts, ConnectionStatsCounter coun
_enqueuePing = enqueuePing;
_opts = opts;
_protocolWriter = new ProtocolWriter(opts.SubjectEncoding);
_semLock = new SemaphoreSlim(1);
_channelLock = Channel.CreateBounded<int>(1);
_headerWriter = new HeaderWriter(_opts.HeaderEncoding);
_cts = new CancellationTokenSource();
_readerLoopTask = Task.Run(ReaderLoopAsync);
Expand Down Expand Up @@ -76,7 +77,11 @@ public async ValueTask DisposeAsync()
await _cts.CancelAsync().ConfigureAwait(false);
#endif

await _semLock.WaitAsync().ConfigureAwait(false);
while (!_channelLock.Writer.TryWrite(1))
{
await _channelLock.Writer.WaitToWriteAsync().ConfigureAwait(false);
}

try
{
if (_disposed)
Expand All @@ -95,14 +100,22 @@ public async ValueTask DisposeAsync()
}
finally
{
_semLock.Release();
while (!_channelLock.Reader.TryRead(out _))
{
await _channelLock.Reader.WaitToReadAsync().ConfigureAwait(false);
}
}
}

public async ValueTask ConnectAsync(ClientOpts connectOpts, CancellationToken cancellationToken)
{
Interlocked.Add(ref _counter.PendingMessages, 1);
await _semLock.WaitAsync(cancellationToken).ConfigureAwait(false);
Interlocked.Increment(ref _counter.PendingMessages);

while (!_channelLock.Writer.TryWrite(1))
{
await _channelLock.Writer.WaitToWriteAsync(cancellationToken).ConfigureAwait(false);
}

try
{
if (_disposed)
Expand All @@ -118,20 +131,29 @@ public async ValueTask ConnectAsync(ClientOpts connectOpts, CancellationToken ca
bw = _pipeWriter!;
}

_protocolWriter.WriteConnect(bw, connectOpts!);
_protocolWriter.WriteConnect(bw, connectOpts);
await bw.FlushAsync(cancellationToken).ConfigureAwait(false);
}
finally
{
_semLock.Release();
Interlocked.Add(ref _counter.PendingMessages, -1);
while (!_channelLock.Reader.TryRead(out _))
{
await _channelLock.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false);
}

Interlocked.Decrement(ref _counter.PendingMessages);
}
}

public async ValueTask PingAsync(PingCommand pingCommand, CancellationToken cancellationToken)
{
Interlocked.Add(ref _counter.PendingMessages, 1);
await _semLock.WaitAsync(cancellationToken).ConfigureAwait(false);
Interlocked.Increment(ref _counter.PendingMessages);

while (!_channelLock.Writer.TryWrite(1))
{
await _channelLock.Writer.WaitToWriteAsync(cancellationToken).ConfigureAwait(false);
}

try
{
if (_disposed)
Expand All @@ -154,15 +176,24 @@ public async ValueTask PingAsync(PingCommand pingCommand, CancellationToken canc
}
finally
{
_semLock.Release();
Interlocked.Add(ref _counter.PendingMessages, -1);
while (!_channelLock.Reader.TryRead(out _))
{
await _channelLock.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false);
}

Interlocked.Decrement(ref _counter.PendingMessages);
}
}

public async ValueTask PongAsync(CancellationToken cancellationToken = default)
{
Interlocked.Add(ref _counter.PendingMessages, 1);
await _semLock.WaitAsync(cancellationToken).ConfigureAwait(false);
Interlocked.Increment(ref _counter.PendingMessages);

while (!_channelLock.Writer.TryWrite(1))
{
await _channelLock.Writer.WaitToWriteAsync(cancellationToken).ConfigureAwait(false);
}

try
{
if (_disposed)
Expand All @@ -183,7 +214,12 @@ public async ValueTask PongAsync(CancellationToken cancellationToken = default)
}
finally
{
_semLock.Release();
while (!_channelLock.Reader.TryRead(out _))
{
await _channelLock.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false);
}

Interlocked.Decrement(ref _counter.PendingMessages);
}
}

Expand All @@ -208,8 +244,13 @@ public ValueTask PublishAsync<T>(string subject, T? value, NatsHeaders? headers,

public async ValueTask SubscribeAsync(int sid, string subject, string? queueGroup, int? maxMsgs, CancellationToken cancellationToken)
{
Interlocked.Add(ref _counter.PendingMessages, 1);
await _semLock.WaitAsync(cancellationToken).ConfigureAwait(false);
Interlocked.Increment(ref _counter.PendingMessages);

while (!_channelLock.Writer.TryWrite(1))
{
await _channelLock.Writer.WaitToWriteAsync(cancellationToken).ConfigureAwait(false);
}

try
{
if (_disposed)
Expand All @@ -230,15 +271,24 @@ public async ValueTask SubscribeAsync(int sid, string subject, string? queueGrou
}
finally
{
_semLock.Release();
Interlocked.Add(ref _counter.PendingMessages, -1);
while (!_channelLock.Reader.TryRead(out _))
{
await _channelLock.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false);
}

Interlocked.Decrement(ref _counter.PendingMessages);
}
}

public async ValueTask UnsubscribeAsync(int sid, int? maxMsgs, CancellationToken cancellationToken)
{
Interlocked.Add(ref _counter.PendingMessages, 1);
await _semLock.WaitAsync(cancellationToken).ConfigureAwait(false);
Interlocked.Increment(ref _counter.PendingMessages);

while (!_channelLock.Writer.TryWrite(1))
{
await _channelLock.Writer.WaitToWriteAsync(cancellationToken).ConfigureAwait(false);
}

try
{
if (_disposed)
Expand All @@ -259,8 +309,12 @@ public async ValueTask UnsubscribeAsync(int sid, int? maxMsgs, CancellationToken
}
finally
{
_semLock.Release();
Interlocked.Add(ref _counter.PendingMessages, -1);
while (!_channelLock.Reader.TryRead(out _))
{
await _channelLock.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false);
}

Interlocked.Decrement(ref _counter.PendingMessages);
}
}

Expand All @@ -270,8 +324,13 @@ public async ValueTask UnsubscribeAsync(int sid, int? maxMsgs, CancellationToken
[AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))]
private async ValueTask PublishLockedAsync(string subject, string? replyTo, NatsPooledBufferWriter<byte> payloadBuffer, NatsPooledBufferWriter<byte>? headersBuffer, CancellationToken cancellationToken)
{
Interlocked.Add(ref _counter.PendingMessages, 1);
await _semLock.WaitAsync(cancellationToken).ConfigureAwait(false);
Interlocked.Increment(ref _counter.PendingMessages);

while (!_channelLock.Writer.TryWrite(1))
{
await _channelLock.Writer.WaitToWriteAsync(cancellationToken).ConfigureAwait(false);
}

try
{
var payload = payloadBuffer.WrittenMemory;
Expand Down Expand Up @@ -305,8 +364,12 @@ private async ValueTask PublishLockedAsync(string subject, string? replyTo, Nat
}
finally
{
_semLock.Release();
Interlocked.Add(ref _counter.PendingMessages, -1);
while (!_channelLock.Reader.TryRead(out _))
{
await _channelLock.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false);
}

Interlocked.Decrement(ref _counter.PendingMessages);
}
}

Expand Down

0 comments on commit 7b87f22

Please sign in to comment.