The problem at hand
I needed to fetch rows from a database and then process them. In its simplest form, the code looks like this:
async for v in read_values():
    await process(v)
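Throughout this post, read_values and process stand in for the real database read and the real processing step. They aren't part of any library; if you want to run the snippets yourself, a hypothetical sketch like this (with made-up sleeps to simulate I/O) is enough:

import asyncio

async def read_values():
    # Pretend each row takes a little while to arrive from the database
    for i in range(10):
        await asyncio.sleep(0.1)
        yield i

async def process(value):
    # Pretend the processing is slow and I/O-bound
    await asyncio.sleep(0.5)
    print(f"processed {value}")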
Now I don’t necessarily need to wait for one value to be processed before moving on to the next one.
This makes it the perfect candidate for concurrent processing.
The promise of async python
Async python promises to make concurrency easier. This is somewhat true: it makes some things easier to do concurrently, but it makes other things more difficult.
The most common way to run async coroutines concurrently is to use python’s asyncio.gather helper, which makes it easy to run a list of coroutines concurrently.
If I added asyncio.gather to the earlier snippet, it would look something like this:
tasks = []

async for value in read_values():
    # Collect all the tasks
    tasks.append(process(value))

# Process them concurrently
await asyncio.gather(*tasks)
And it does make the code run significantly faster. But I wonder if it can go even faster.
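If you want to see the difference yourself, a rough timing harness might look like this. It reuses the read_values and process sketches from earlier and is only meant as an illustration, not a benchmark of the real workload:

import asyncio
import time

async def run_sequential():
    async for v in read_values():
        await process(v)

async def run_gathered():
    tasks = []
    async for value in read_values():
        tasks.append(process(value))
    await asyncio.gather(*tasks)

for runner in (run_sequential, run_gathered):
    start = time.perf_counter()
    asyncio.run(runner())
    print(f"{runner.__name__}: {time.perf_counter() - start:.2f}s")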
Making it go faster
The primary bottleneck here is that I have to wait for all the values to be read before I can start processing them. So even though the processing is parallel (or rather, concurrent), the reads still happen serially, and no processing starts until they are all done.
What if it were possible to start the processing without waiting for all the reads to complete? Luckily, python offers something that makes that possible: asyncio.create_task.
With asyncio.create_task, the task is scheduled as soon as it’s created. So the reads don’t have to be completed before we can start the processing.
Making the change to the snippet:
async for value in read_values():
    # Schedule the task as soon as the value is read
    asyncio.create_task(process(value))
And we get another nice speedup. This becomes most apparent if the reads take a long time to complete.
Bugs lurking in the background
But there’s a problem here. Sometimes a few values at the end don’t get processed. That’s because nothing is being awaited here, so as soon as the reads are complete, the function exits without waiting for any tasks that may still be pending.
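You can see the symptom by wrapping the snippet in a function and running it against the hypothetical stubs from earlier; some of the "processed" lines simply never show up:

async def main():
    async for value in read_values():
        asyncio.create_task(process(value))
    # main() returns here, abandoning any process() tasks still in flight

asyncio.run(main())  # some values get read but never processed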
The fix for this is simple. We can keep track of the pending tasks, and make sure we wait for them to finish before exiting.
pending_tasks = set()

async for value in read_values():
    task = asyncio.create_task(process(value))
    # Add the task to the pending set
    pending_tasks.add(task)
    # Remove the task once it's completed
    task.add_done_callback(pending_tasks.discard)

# Wait for any tasks that are still pending
await asyncio.gather(*pending_tasks)
And that’s it! The processing can run concurrently without waiting for all the reads to complete. And once all the reads are done, the function waits for any processing tasks that are still pending. The entire process is fully concurrent.
Stuffing it into an abstraction
The pattern is helpful enough that it warrants a re-usable abstraction. Enter the AsyncExecutor. Moving the bits we want into a separate class, the snippet becomes:
class AsyncExecutor:
    def __init__(self):
        self.pending_tasks = set()

    async def submit(self, coro):
        # Schedule the coroutine immediately and track it until it completes
        task = asyncio.create_task(coro)
        self.pending_tasks.add(task)
        task.add_done_callback(self.pending_tasks.discard)

    async def join(self):
        # Wait for any tasks that are still pending
        await asyncio.gather(*self.pending_tasks)


executor = AsyncExecutor()

async for value in read_values():
    await executor.submit(process(value))

await executor.join()
And that gives us a nice AsyncExecutor that we can re-use.
Making it more pythonic
Manually having to call .join() at the end of every executor use doesn’t feel very pythonic. Luckily, python’s async world provides something to make this step automatic.
Enter the AsyncContextManager. It’s similar to the standard ContextManager, but async, and it lets the .join() run automatically when the block ends:
class AsyncExecutor:
    def __init__(self):
        self.pending_tasks = set()

    async def submit(self, coro):
        task = asyncio.create_task(coro)
        self.pending_tasks.add(task)
        task.add_done_callback(self.pending_tasks.discard)

    async def join(self):
        await asyncio.gather(*self.pending_tasks)

    # AsyncContextManager interface implementation
    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.join()


async with AsyncExecutor() as executor:
    async for value in read_values():
        await executor.submit(process(value))
Now that looks more pythonic!
Bonus: Concurrency Control
I’ve had quite a few situations where I needed to set a limit on the number of concurrent tasks.
With asyncio.gather, this usually means batching. Unfortunately, it also means that the entire batch needs to be processed at once, and there’s a lot of wasted time just waiting for each batch to be built.
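For reference, the batching approach I’m talking about looks roughly like this; the batch size of 50 is an arbitrary, made-up number:

batch = []

async for value in read_values():
    batch.append(process(value))
    if len(batch) >= 50:
        # Nothing new starts until this entire batch has finished
        await asyncio.gather(*batch)
        batch = []

# Don't forget the final, partial batch
if batch:
    await asyncio.gather(*batch)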
Luckily, the AsyncExecutor can easily be modified to support concurrency control, using asyncio’s built-in Semaphore:
class AsyncExecutor:
    def __init__(self, max_concurrency=None):
        self.pending_tasks = set()
        # Create a semaphore only if a max_concurrency is specified
        self.sem = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def submit(self, coro):
        if self.sem:
            # Don't submit a new task until the semaphore is available
            await self.sem.acquire()
        task = asyncio.create_task(coro)
        self.track(task)

    def track(self, task):
        self.pending_tasks.add(task)
        task.add_done_callback(self.on_task_complete)

    def on_task_complete(self, task):
        self.pending_tasks.discard(task)
        if self.sem:
            # Release the semaphore for another task
            self.sem.release()

    async def join(self):
        await asyncio.gather(*self.pending_tasks)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.join()


async with AsyncExecutor(max_concurrency=50) as executor:
    async for value in read_values():
        await executor.submit(process(value))
No more than max_concurrency tasks are executed at any time. This also means that each task is started as soon as it is ready to be executed (no need to wait for an entire batch).
Further enhancements
I have plans to add some other things to the AsyncExecutor (such as timeouts and cancellation). But for now, this AsyncExecutor has served me well.