Skip to content

Commit

Permalink
[Bug] Fix CloseAllAsync exception. Fixes #33 (#36)
Browse files Browse the repository at this point in the history
Fix the issue with the CloseAllAsync method that was throwing a
cancelation exception because we were canceling the token rather than
closing the channels.

Fixes #33
  • Loading branch information
mandel-macaque authored Aug 21, 2024
1 parent 1670a22 commit 01fb9f5
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 12 deletions.
38 changes: 37 additions & 1 deletion Marille.Tests/CancellationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public async Task CloseAllWorkersNoEvents ()

var topic2 = "topic2";
var tcs2 = new TaskCompletionSource<bool> ();
var worker2 = new BlockingWorker(tcs1);
var worker2 = new BlockingWorker(tcs2);
await _hub.CreateAsync<WorkQueuesEvent> (topic2, configuration);
await _hub.RegisterAsync (topic2, worker2);

Expand Down Expand Up @@ -154,4 +154,40 @@ public async Task MultithreadedClose ()
}
Assert.False (finalResult);
}

[Fact]
public async Task CloseAllChannelsAsync ()
{
// build several channels and then close them all, this should ensure that all the workers
// have consume all the messages
var eventCount = 100;
var list = new List<Task> (200);

configuration.Mode = ChannelDeliveryMode.AtLeastOnce;
var topic1 = "topic1";
var tcs1 = new TaskCompletionSource<bool> ();
var worker1 = new BlockingWorker(tcs1);
await _hub.CreateAsync<WorkQueuesEvent> (topic1, configuration);
await _hub.RegisterAsync (topic1, worker1);

var topic2 = "topic2";
var tcs2 = new TaskCompletionSource<bool> ();
var worker2 = new BlockingWorker(tcs2);
await _hub.CreateAsync<WorkQueuesEvent> (topic2, configuration);
await _hub.RegisterAsync (topic2, worker2);

for (var index = 0; index < eventCount; index++) {
await _hub.Publish (topic1, new WorkQueuesEvent($"myID{index}"));
await _hub.Publish (topic2, new WorkQueuesEvent($"myID{index}"));
}

// we are blocking the consume of the channels
Assert.True(tcs1.TrySetResult(true));
Assert.True(tcs2.TrySetResult(true));

// close the hub, should throw no cancellation token exceptions and events should have been processed
await _hub.CloseAllAsync ();
Assert.Equal (100, worker1.ConsumedCount);
Assert.Equal (100, worker2.ConsumedCount);
}
}
12 changes: 7 additions & 5 deletions Marille/Hub.cs
Original file line number Diff line number Diff line change
Expand Up @@ -250,12 +250,14 @@ public async Task CloseAllAsync ()
let tasks = topic.ConsumerTasks
from task in tasks select task;

var cancellationTasks = from topic in topics.Values
let cancellationTokens = topic.CancellationTokenSources
from source in cancellationTokens select source.CancelAsync ();
var topicInfos = from topic in topics.Values
let channels = topic.Channels
from ch in channels select ch;

foreach (var topicInfo in topicInfos) {
topicInfo.CloseChannel ();
}

// we could do a nested Task.WhenAll but we want to ensure that the cancellation tasks are done before
await Task.WhenAll (cancellationTasks);
await Task.WhenAll (consumingTasks);
} finally {
semaphoreSlim.Release ();
Expand Down
2 changes: 1 addition & 1 deletion Marille/Marille.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
<PropertyGroup>
<Title>Marille</Title>
<PackageId>Marille</PackageId>
<Version>0.4.2</Version>
<Version>0.4.3</Version>
<Authors>Manuel de la Peña Saenz</Authors>
<Owners>Manuel de la Peña Saenz</Owners>
<Copyright>Manuel de la Peña Saenz</Copyright>
Expand Down
6 changes: 2 additions & 4 deletions Marille/Topic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ internal class Topic (string name) {
where info.ConsumerTask is not null
select info.ConsumerTask;

public IEnumerable<CancellationTokenSource> CancellationTokenSources => from info in channels.Values
where info.CancellationTokenSource is not null
select info.CancellationTokenSource;
public IEnumerable<TopicInfo> Channels => channels.Values;

public bool TryGetChannel<T> ([NotNullWhen (true)] out TopicInfo<T>? channel) where T : struct
{
Expand Down Expand Up @@ -46,7 +44,7 @@ public void CloseChannel<T> () where T : struct
if (!TryGetChannel<T> (out var chInfo))
return;

chInfo.Channel.Writer.Complete ();
chInfo.CloseChannel ();
channels.Remove (typeof (T));
}

Expand Down
9 changes: 8 additions & 1 deletion Marille/TopicInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

namespace Marille;

internal record TopicInfo (TopicConfiguration Configuration){
internal abstract record TopicInfo (TopicConfiguration Configuration){
public CancellationTokenSource? CancellationTokenSource { get; set; }
public Task? ConsumerTask { get; set; }

public abstract void CloseChannel ();
}

internal record TopicInfo<T> (TopicConfiguration Configuration, Channel<Message<T>> Channel) : TopicInfo (Configuration)
Expand All @@ -16,4 +18,9 @@ public TopicInfo (TopicConfiguration configuration, Channel<Message<T>> channel,
{
Workers.AddRange (workers);
}

public override void CloseChannel ()
{
Channel.Writer.Complete ();
}
}

0 comments on commit 01fb9f5

Please sign in to comment.