Monday 18 May 2020

Multi-path cancellation; a tale of two codependent async enumerators

Disclaimer: I'll be honest, many of the concepts in this post are a bit more advanced, so some viewer caution is advised! It touches on concurrent linked async enumerators that share a termination condition by combining multiple CancellationToken instances.


Something that I've been looking at recently - in the context of gRPC (and protobuf-net.Grpc in particular) - is the complex story of duplex data pipes. A full-duplex connection is a connection between two nodes, but instead of being request-response, either node can send messages at any time. There's still a notional "client" and "server", but that is purely a feature of which node was sat listening for connection attempts vs which node reached out and established a connection. Shaping a duplex API is much more complex than shaping a request-response API, and frankly: a lot of the details around timing are hard.

So: I had the idea that maybe we can reshape everything at the library level, and offer the consumer something more familiar. It makes an interesting (to me, at least) worked example of cancellation in practice. So; let's start with an imaginary transport API (the thing that is happening underneath) - let's say that we have:

  • a client establishes a connection (we're not going to worry about how)
  • there is a SendAsync method that sends a message from the client to the server
  • there is a TryReceiveAsync method that attempts to await a message from the server (reporting true if a message could be fetched, and false if the server has indicated that it won't ever be sending any more)
  • additionally, the server controls data flow termination; if the server indicates that it has sent the last message, the client should not send any more

something like (where TRequest is the data-type being sent from the client to the server, and TResponse is the data-type expected from the server to the client):

interface ITransport<TRequest, TResponse> : IAsyncDisposable
{
    ValueTask SendAsync(TRequest request,
        CancellationToken cancellationToken);

    ValueTask<(bool Success, TResponse Message)> TryReceiveAsync(
        CancellationToken cancellationToken);
}
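To make the semantics of this interface concrete, here is a minimal in-memory sketch. InMemoryTransport, ServerReply and ServerComplete are hypothetical names (not part of gRPC or protobuf-net.Grpc); it assumes C# 9+ top-level statements and uses System.Threading.Channels, with one channel per direction. Completing the client-bound channel plays the role of the server saying "no more messages", which is what makes TryReceiveAsync report false:

```csharp
using System;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

// Play the part of the server by hand: one reply, then "no more messages".
await using var transport = new InMemoryTransport<string, int>();
await transport.SendAsync("hello", CancellationToken.None);
transport.ServerReply(5);
transport.ServerComplete();

var received = 0;
while (true)
{
    var (success, message) = await transport.TryReceiveAsync(CancellationToken.None);
    if (!success) break; // the server has said it won't send any more
    Console.WriteLine(message);
    received++;
}

interface ITransport<TRequest, TResponse> : IAsyncDisposable
{
    ValueTask SendAsync(TRequest request, CancellationToken cancellationToken);
    ValueTask<(bool Success, TResponse Message)> TryReceiveAsync(CancellationToken cancellationToken);
}

// Hypothetical in-memory transport: one unbounded channel per direction.
sealed class InMemoryTransport<TRequest, TResponse> : ITransport<TRequest, TResponse>
{
    private readonly Channel<TRequest> _toServer = Channel.CreateUnbounded<TRequest>();
    private readonly Channel<TResponse> _toClient = Channel.CreateUnbounded<TResponse>();

    public ValueTask SendAsync(TRequest request, CancellationToken cancellationToken)
        => _toServer.Writer.WriteAsync(request, cancellationToken);

    public async ValueTask<(bool Success, TResponse Message)> TryReceiveAsync(
        CancellationToken cancellationToken)
    {
        // WaitToReadAsync returns false once the channel is completed and drained
        while (await _toClient.Reader.WaitToReadAsync(cancellationToken))
        {
            if (_toClient.Reader.TryRead(out var message)) return (true, message);
        }
        return (false, default);
    }

    // Hypothetical hooks standing in for the remote server.
    public void ServerReply(TResponse response) => _toClient.Writer.TryWrite(response);
    public void ServerComplete() => _toClient.Writer.TryComplete();

    public ValueTask DisposeAsync()
    {
        _toServer.Writer.TryComplete();
        _toClient.Writer.TryComplete();
        return default;
    }
}
```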

This API doesn't look all that complicated - it looks like (if we ignore connection etc for the moment) we can just create a couple of loops, and expose the data via enumerators - presumably starting the SendAsync via Task.Run or similar so it is on a parallel flow:

ITransport<TRequest, TResponse> transport;
public async IAsyncEnumerable<TResponse> ReceiveAsync(
    [EnumeratorCancellation] CancellationToken cancellationToken)
{
    while (true)
    {
        var (success, message) =
            await transport.TryReceiveAsync(cancellationToken);
        if (!success) break;
        yield return message;
    }
}

public async ValueTask SendAsync(
    IAsyncEnumerable<TRequest> data,
    CancellationToken cancellationToken)
{
    await foreach (var message in data
        .WithCancellation(cancellationToken))
    {
        await transport.SendAsync(message, cancellationToken);
    }
}

and it looks like we're all set for cancellation - we can pass in an external cancellation-token to both methods, and we're set. Right?

Well, it is a bit more complex than that, and the above doesn't take into consideration that these two flows are codependent. In particular, a big concern is that we don't want to leave the producer (the thing pumping SendAsync) still running in any scenario where the connection is doomed. There are actually many more cancellation paths than we might think:

  1. we might have supplied an external cancellation-token to both methods, and this token may have triggered
  2. the consumer of ReceiveAsync (the thing iterating it) might have supplied a cancellation-token to GetAsyncEnumerator (via WithCancellation), and this token may have been triggered (we looked at this last time)
  3. we could have faulted in our send/receive code
  4. the consumer of ReceiveAsync may have decided not to take all the data - that might be because of some async analogue of Enumerable.Take(), or it could be because they faulted when processing a message they had received
  5. the producer in SendAsync may have faulted

All of these scenarios essentially signify termination of the connection, so we need to be able to encompass all of these scenarios in some way that allows us to communicate the problem between the send and receive path. In a word, we want our own CancellationTokenSource.
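One property makes a single CancellationTokenSource a good fit as the shared "we're done" signal: Cancel() can be called from as many paths as we like, and only the first call has any effect (registered callbacks run once), so the paths don't need to coordinate with each other. A minimal sketch, assuming C# 9+ top-level statements; the three Cancel() calls are just stand-ins for the scenarios listed above:

```csharp
using System;
using System.Threading;

using var allDone = new CancellationTokenSource();
var notified = 0;
allDone.Token.Register(() => notified++); // invoked on the first Cancel() only

// Several independent paths may each decide "we're done"; each just calls Cancel().
allDone.Cancel(); // e.g. the consumer stopped early
allDone.Cancel(); // e.g. the server sent its last message
allDone.Cancel(); // e.g. the producer faulted

Console.WriteLine($"{allDone.IsCancellationRequested} {notified}"); // True 1
```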

There's a lot going on here; more than we can reasonably expect consumers to do each and every time they use the API, so this is a perfect scenario for a library method. Let's imagine that we want to encompass all this complexity in a simple single library API that the consumer can access - something like:

public IAsyncEnumerable<TResponse> Duplex(
    IAsyncEnumerable<TRequest> request,
    CancellationToken cancellationToken = default);

This:

  • allows them to pass in a producer
  • optionally allows them to pass in an external cancellation-token
  • makes an async feed of responses available to them

Their usage might be something like:

await foreach (MyResponse item in client.Duplex(ProducerAsync()))
{
    Console.WriteLine(item);
}

where their ProducerAsync() method is (just "because"):

async IAsyncEnumerable<MyRequest> ProducerAsync(
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    for (int i = 0; i < 100; i++)
    {
        yield return new MyRequest(i);
        await Task.Delay(100, cancellationToken);
    }
}

As I discussed in The anatomy of async iterators (aka await, foreach, yield), our call to ProducerAsync() doesn't actually do much yet; it just hands back a placeholder that can be enumerated later, and it is the act of enumerating it that actually invokes the code. Very important point, that.
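The deferred execution is easy to observe. In this sketch (C# 9+ top-level statements; the `started` counter is purely illustrative), calling the iterator method runs none of its body; only the first MoveNextAsync() does:

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

var started = 0;

async IAsyncEnumerable<int> ProducerAsync()
{
    started++; // only runs once enumeration actually begins
    for (int i = 0; i < 3; i++)
    {
        yield return i;
        await Task.Yield();
    }
}

var producer = ProducerAsync();     // nothing has executed yet
var beforeEnumerating = started;    // 0

await using var e = producer.GetAsyncEnumerator();
await e.MoveNextAsync();            // the body runs up to the first yield
var afterFirstMove = started;       // 1

Console.WriteLine($"{beforeEnumerating} {afterFirstMove}"); // 0 1
```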

So; what can our Duplex code do? It already needs to think about at least 2 different kinds of cancellation:

  • the external token that was passed into cancellationToken
  • the potentially different token that could be passed into GetAsyncEnumerator() when it is consumed

but we know from our thoughts earlier that we also have a bunch of other ways of cancelling. We can do something clever here. Recall how the compiler usually combines the above two tokens for us? Well, if we do that ourselves, then instead of getting just a CancellationToken, we find ourselves with a CancellationTokenSource, which gives us lots of control:

public IAsyncEnumerable<TResponse> Duplex(
    IAsyncEnumerable<TRequest> request,
    CancellationToken cancellationToken = default)
    => DuplexImpl(transport, request, cancellationToken);

private static async IAsyncEnumerable<TResponse> DuplexImpl(
    ITransport<TRequest, TResponse> transport,
    IAsyncEnumerable<TRequest> request,
    CancellationToken externalToken,
    [EnumeratorCancellation] CancellationToken enumeratorToken = default)
{
    using var allDone = CancellationTokenSource.CreateLinkedTokenSource(
            externalToken, enumeratorToken);
    // ... todo
}

Our DuplexImpl method here allows the enumerator cancellation to be provided, but (importantly) kept separate from the original external token; this means that it won't yet be combined, and we can do that ourselves using CancellationTokenSource.CreateLinkedTokenSource - much like the compiler would have done for us, but: now we have a CancellationTokenSource that we can cancel when we choose. This means that we can use allDone.Token in all the places we want to ask "are we done yet?", and we're considering everything.
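The linking is one-way, which is exactly what we want: cancelling either input flows into the linked source, but cancelling the linked source ourselves does not flow back out to the caller's tokens. A minimal sketch (C# 9+ top-level statements; `external` and `enumerator` stand in for the two incoming tokens):

```csharp
using System;
using System.Threading;

using var external = new CancellationTokenSource();
using var enumerator = new CancellationTokenSource();

// Case 1: we cancel our combined source directly (e.g. "the server sent its last message").
using var allDone = CancellationTokenSource.CreateLinkedTokenSource(external.Token, enumerator.Token);
allDone.Cancel();
bool externalAffected = external.IsCancellationRequested;    // false: does not flow back to the inputs

// Case 2: one of the inputs is cancelled (e.g. the consumer's WithCancellation token).
using var secondAllDone = CancellationTokenSource.CreateLinkedTokenSource(external.Token, enumerator.Token);
enumerator.Cancel();
bool linkedAffected = secondAllDone.IsCancellationRequested; // true: flows into the linked source

Console.WriteLine($"{externalAffected} {linkedAffected}"); // False True
```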

For starters, let's handle the scenario where the consumer doesn't take all the data (out of choice, or because of a fault). We want to trigger allDone however we exit DuplexImpl. Fortunately, the way that iterator blocks are implemented makes this simple (and we're already using it here, via using): recall (from the previous blog post) that foreach and await foreach both (usually) include a using block that invokes Dispose/DisposeAsync on the enumerator instance? Well: anything we put in a finally essentially relocates to that Dispose/DisposeAsync. The upshot of this is that triggering the cancellation token when the consumer is done with us is trivial:

using var allDone = CancellationTokenSource.CreateLinkedTokenSource(
        externalToken, enumeratorToken);
try
{
    // ... todo
}
finally
{   // cancel allDone however we exit
    allDone.Cancel();
}
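To see the finally relocation in action, here is a small sketch (C# 9+ top-level statements; `Numbers` and `cleanedUp` are illustrative names): the consumer breaks out early, await foreach disposes the enumerator, and the iterator's finally runs at that point, which is where DuplexImpl calls allDone.Cancel():

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

var cleanedUp = false;

async IAsyncEnumerable<int> Numbers()
{
    try
    {
        for (int i = 0; ; i++)
        {
            await Task.Yield();
            yield return i;
        }
    }
    finally
    {
        cleanedUp = true; // in DuplexImpl, this is where allDone.Cancel() lives
    }
}

var taken = 0;
await foreach (var n in Numbers())
{
    if (++taken == 3) break; // consumer stops early, like Take(3)
}
Console.WriteLine($"{taken} {cleanedUp}"); // 3 True
```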

The next step is to get our producer working; that's our SendAsync code. Because this is duplex, sending has no bearing on the incoming messages, so we'll start it as a completely separate code-path via Task.Run. We can still arrange things so that if the producer or the send faults, it stops the entire show. Looking just at our // ... todo code, we can add:

var send = Task.Run(async () =>
{
    try
    {
        await foreach (var message in
            request.WithCancellation(allDone.Token))
        {
            await transport.SendAsync(message, allDone.Token);
        }
    }
    catch
    {   // trigger cancellation if send faults
        allDone.Cancel();
        throw;
    }
}, allDone.Token);

// ... todo: receive

await send; // observe send outcome

This starts a parallel operation that consumes the data from our producer, but notice that we're using allDone.Token to pass our combined cancellation knowledge to the producer. This is very subtle, because it represents a cancellation state that didn't even conceptually exist at the time ProducerAsync() was originally invoked. The fact that GetAsyncEnumerator is deferred has allowed us to give it something much more useful, and as long as ProducerAsync() uses the cancellation-token appropriately, it can now be fully aware of the life-cycle of the composite duplex operation.
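Here is a small sketch of that late binding (C# 9+ top-level statements; the "stop after two items" rule is arbitrary and just stands in for "the server has sent its last message"). The producer is created with no token at all; the CancellationTokenSource is only created afterwards, yet because of [EnumeratorCancellation] the producer's Task.Delay still observes it:

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

static async IAsyncEnumerable<int> ProducerAsync(
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    for (int i = 0; ; i++)
    {
        yield return i;
        await Task.Delay(100, cancellationToken); // honours whatever token arrives at enumeration time
    }
}

var producer = ProducerAsync(); // created with no token at all

// This source did not exist when ProducerAsync() was invoked.
using var allDone = new CancellationTokenSource();

var sent = 0;
var stoppedByCancellation = false;
try
{
    await foreach (var message in producer.WithCancellation(allDone.Token))
    {
        sent++;
        if (sent == 2) allDone.Cancel(); // illustrative stop condition
    }
}
catch (OperationCanceledException)
{
    stoppedByCancellation = true;
}
Console.WriteLine($"{sent} {stoppedByCancellation}"); // 2 True
```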

This just leaves our receive code, which is more or less like it was originally, but again: using allDone.Token:

while (true)
{
    var (success, message) = await transport.TryReceiveAsync(allDone.Token);
    if (!success) break;
    yield return message;
}

// the server's last message stops everything
allDone.Cancel();
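With both halves in place, the send-side fault path can be exercised on its own. This sketch (C# 9+ top-level statements; FaultyProducerAsync is hypothetical) keeps the shape of the send code above but drops the transport, and uses Task.Delay(Timeout.Infinite, allDone.Token) as a stand-in for a receive call that would otherwise wait forever. The producer's fault cancels allDone, which wakes the "receive" side, and the original exception still surfaces when we await send:

```csharp
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

static async IAsyncEnumerable<int> FaultyProducerAsync()
{
    yield return 1;
    await Task.Yield();
    throw new InvalidOperationException("producer failed");
}

using var allDone = new CancellationTokenSource();

var send = Task.Run(async () =>
{
    try
    {
        await foreach (var message in FaultyProducerAsync().WithCancellation(allDone.Token))
        {
            // a real implementation would call transport.SendAsync(message, allDone.Token) here
        }
    }
    catch
    {   // trigger cancellation if send faults
        allDone.Cancel();
        throw;
    }
}, allDone.Token);

// Stand-in for the receive loop: waits until something cancels allDone.
var receiveCancelled = false;
try
{
    await Task.Delay(Timeout.Infinite, allDone.Token);
}
catch (OperationCanceledException)
{
    receiveCancelled = true;
}

var sendError = "(none)";
try
{
    await send; // observe send outcome
}
catch (InvalidOperationException ex)
{
    sendError = ex.Message;
}

Console.WriteLine($"{receiveCancelled} {sendError}"); // True producer failed
```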

Putting all this together gives us a non-trivial library function:

private static async IAsyncEnumerable<TResponse> DuplexImpl(
    ITransport<TRequest, TResponse> transport,
    IAsyncEnumerable<TRequest> request,
    CancellationToken externalToken,
    [EnumeratorCancellation] CancellationToken enumeratorToken = default)
{
    using var allDone = CancellationTokenSource.CreateLinkedTokenSource(
        externalToken, enumeratorToken);
    try
    {
        var send = Task.Run(async () =>
        {
            try
            {
                await foreach (var message in
                    request.WithCancellation(allDone.Token))
                {
                    await transport.SendAsync(message, allDone.Token);
                }
            }
            catch
            {   // trigger cancellation if send faults
                allDone.Cancel();
                throw;
            }
        }, allDone.Token);

        while (true)
        {
            var (success, message) = await transport.TryReceiveAsync(allDone.Token);
            if (!success) break;
            yield return message;
        }

        // the server's last message stops everything
        allDone.Cancel();

        await send; // observe send outcome
    }
    finally
    {   // cancel allDone however we exit
        allDone.Cancel();
    }
}

The key points here are:

  • both the external token and the enumerator token contribute to allDone
  • the transport-level send and receive code uses allDone.Token
  • the producer enumeration uses allDone.Token
  • however we exit our enumerator, allDone is cancelled
    • if transport-receive faults, allDone is cancelled
    • if the consumer terminates early, allDone is cancelled
  • when we receive the last message from the server, allDone is cancelled
  • if the producer or transport-send faults, allDone is cancelled
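To check the happy path end to end, here is a self-contained sketch that wraps DuplexImpl (unchanged in behaviour) in a hypothetical DuplexClient class and drives it with a hypothetical in-memory EchoTransport built on System.Threading.Channels; none of these wrapper names come from gRPC or protobuf-net.Grpc. It assumes C# 9+ top-level statements. The producer delays before yielding rather than after, so nothing in it is still waiting on allDone.Token when the server's last message triggers allDone.Cancel():

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

// Hypothetical: the "server" answers each request with request * 10; its 3rd answer is its last.
await using var transport = new EchoTransport(stopAfter: 3);
var client = new DuplexClient<int, int>(transport);
var responses = new List<int>();
await foreach (var item in client.Duplex(ProducerAsync(3)))
    responses.Add(item);
Console.WriteLine(string.Join(",", responses)); // 0,10,20

static async IAsyncEnumerable<int> ProducerAsync(int count,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    for (int i = 0; i < count; i++)
    {
        await Task.Delay(10, cancellationToken); // delay before yielding: nothing waits after the last item
        yield return i;
    }
}

interface ITransport<TRequest, TResponse> : IAsyncDisposable
{
    ValueTask SendAsync(TRequest request, CancellationToken cancellationToken);
    ValueTask<(bool Success, TResponse Message)> TryReceiveAsync(CancellationToken cancellationToken);
}

// Hypothetical in-memory transport built on System.Threading.Channels.
sealed class EchoTransport : ITransport<int, int>
{
    private readonly Channel<int> _responses = Channel.CreateUnbounded<int>();
    private readonly int _stopAfter;
    private int _sent;
    public EchoTransport(int stopAfter) => _stopAfter = stopAfter;
    public async ValueTask SendAsync(int request, CancellationToken cancellationToken)
    {
        await _responses.Writer.WriteAsync(request * 10, cancellationToken);
        if (++_sent == _stopAfter) _responses.Writer.TryComplete(); // the server's last message
    }
    public async ValueTask<(bool Success, int Message)> TryReceiveAsync(CancellationToken cancellationToken)
    {
        while (await _responses.Reader.WaitToReadAsync(cancellationToken))
            if (_responses.Reader.TryRead(out var message)) return (true, message);
        return (false, 0); // completed and drained: nothing more will arrive
    }
    public ValueTask DisposeAsync() { _responses.Writer.TryComplete(); return default; }
}

sealed class DuplexClient<TRequest, TResponse>
{
    private readonly ITransport<TRequest, TResponse> transport;
    public DuplexClient(ITransport<TRequest, TResponse> transport) => this.transport = transport;

    public IAsyncEnumerable<TResponse> Duplex(IAsyncEnumerable<TRequest> request,
        CancellationToken cancellationToken = default)
        => DuplexImpl(transport, request, cancellationToken);

    private static async IAsyncEnumerable<TResponse> DuplexImpl(
        ITransport<TRequest, TResponse> transport,
        IAsyncEnumerable<TRequest> request,
        CancellationToken externalToken,
        [EnumeratorCancellation] CancellationToken enumeratorToken = default)
    {
        using var allDone = CancellationTokenSource.CreateLinkedTokenSource(
            externalToken, enumeratorToken);
        try
        {
            var send = Task.Run(async () =>
            {
                try
                {
                    await foreach (var message in request.WithCancellation(allDone.Token))
                        await transport.SendAsync(message, allDone.Token);
                }
                catch
                {   // trigger cancellation if send faults
                    allDone.Cancel();
                    throw;
                }
            }, allDone.Token);

            while (true)
            {
                var (success, message) = await transport.TryReceiveAsync(allDone.Token);
                if (!success) break;
                yield return message;
            }
            allDone.Cancel(); // the server's last message stops everything
            await send;       // observe send outcome
        }
        finally { allDone.Cancel(); } // cancel allDone however we exit
    }
}
```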

The one thing it doesn't support well is people using GetAsyncEnumerator() directly and not disposing it. That comes under the heading of "using the API incorrectly", and is self-inflicted.
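The reason is that the finally (and therefore allDone.Cancel()) only runs when the enumerator is disposed, or when the iterator runs to completion. A small sketch (C# 9+ top-level statements; `Responses` and `cleanedUp` are illustrative names) of driving an enumerator by hand:

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

var cleanedUp = false;

async IAsyncEnumerable<int> Responses()
{
    try
    {
        for (int i = 0; ; i++)
        {
            await Task.Yield();
            yield return i;
        }
    }
    finally
    {
        cleanedUp = true; // stands in for allDone.Cancel()
    }
}

var e = Responses().GetAsyncEnumerator();
await e.MoveNextAsync();
var beforeDispose = cleanedUp; // false: abandoning `e` here would skip the finally entirely

await e.DisposeAsync();        // what await foreach does for you
var afterDispose = cleanedUp;  // true

Console.WriteLine($"{beforeDispose} {afterDispose}"); // False True
```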

A side note on ConfigureAwait(false); by default await includes a check on SynchronizationContext.Current; in addition to meaning an extra context-switch, in the case of UI applications this may mean running code on the UI thread that does not need to run on the UI thread. Library code usually does not require this (it isn't as though we're updating form controls here, so we don't need thread-affinity). As such, in library code, it is common to use .ConfigureAwait(false) basically everywhere that you see an await - which bypasses this mechanism. I have not included that in the code above, for readability, but: you should imagine it being there :) By contrast, in application code, you should usually default to just using await without ConfigureAwait, unless you know you're writing something that doesn't need sync-context.
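For illustration, here is roughly what the send loop looks like with ConfigureAwait(false) applied throughout. SendAllAsync and its `sendAsync` delegate are hypothetical stand-ins for the transport call (C# 9+ top-level statements assumed). Note that on await foreach, the ConfigureAwait call hangs off the configured enumerable returned by WithCancellation, and applies to both MoveNextAsync and DisposeAsync:

```csharp
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

// The send loop with ConfigureAwait(false) on every await.
static async Task SendAllAsync<T>(
    IAsyncEnumerable<T> request,
    Func<T, CancellationToken, ValueTask> sendAsync,
    CancellationToken cancellationToken)
{
    await foreach (var message in request
        .WithCancellation(cancellationToken)
        .ConfigureAwait(false))      // applies to MoveNextAsync and DisposeAsync
    {
        await sendAsync(message, cancellationToken).ConfigureAwait(false);
    }
}

static async IAsyncEnumerable<int> ProducerAsync()
{
    for (int i = 0; i < 3; i++)
    {
        yield return i;
        await Task.Yield();
    }
}

var sent = new List<int>();
await SendAllAsync(ProducerAsync(), (m, ct) => { sent.Add(m); return default; }, CancellationToken.None);
Console.WriteLine(string.Join(",", sent)); // 0,1,2
```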

I hope this has been a useful delve into some of the more complex things you can do with cancellation-tokens, and how you can combine them to represent codependent exit conditions.