A simple threadsafe caching decorator

Recently I encountered a scenario where I needed to cache the results of calls to an I/O-bound function f(). f() took a few seconds to execute, and calls to it could arrive from several threads at once. The builtin functools module provides the lru_cache decorator, which solves half of the problem: once the first call to an lru_cache-decorated function completes, any subsequent call with the same arguments, from any thread, uses the cached result.

In a multi-threaded scenario, though, this is not quite enough. While thread A's call to f() is still in flight, thread B can make an identical call f'(). Since the cache does not yet hold the result of f(), f'() is executed anyway. The desired behaviour is for f'() to wait until f() is done and then use the cached result.

To achieve this, I built a simple threadsafe LRU cache decorator which handles both caching and thread synchronization. The implementation is below, followed by a brief explanation of how it works and why it is threadsafe. For clarity, I've omitted the @wraps decorator, and also the typed and maxsize optional arguments which are present in the original lru_cache implementation.

import threading
from collections import defaultdict
from functools import lru_cache, _make_key

def threadsafe_lru(func):
    # Reuse the stock LRU cache for storage; this wrapper only adds
    # synchronization on top of it.
    func = lru_cache()(func)
    # One lock per distinct set of arguments, created lazily on first use.
    lock_dict = defaultdict(threading.Lock)

    def _thread_lru(*args, **kwargs):
        # Build the cache key exactly the way lru_cache does internally.
        key = _make_key(args, kwargs, typed=False)
        # Identical concurrent calls serialize here; distinct calls do not.
        with lock_dict[key]:
            return func(*args, **kwargs)

    return _thread_lru

The first thing to notice is that no wheels are reinvented: I'm still using the lru_cache and _make_key functions provided by functools. All that is left to do is to block the execution of a function if an identical call is already in flight. Deciding whether two function calls are identical is done the same way lru_cache() does it - by using the _make_key function to create a unique key based on the function arguments. The corresponding lock in lock_dict is then acquired (creating it if it does not yet exist), making all other calls with the same arguments wait until the result of the first call is cached and ready. Calls with different arguments produce a different key and consequently acquire a different lock, so they do not interfere with each other.
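
To make the keying concrete, here is a small illustrative snippet. Keep in mind that _make_key is a private functools helper, so its exact return value is a CPython implementation detail:

from functools import _make_key

# Identical argument sets produce equal, hashable keys, so they map to
# the same entry in lock_dict; different arguments get a different key.
k1 = _make_key((1, 2), {"z": 3}, typed=False)
k2 = _make_key((1, 2), {"z": 3}, typed=False)
k3 = _make_key((1, 99), {"z": 3}, typed=False)

print(k1 == k2)  # True - same arguments, same key, same lock
print(k1 == k3)  # False - different arguments, different lock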

An important nuance comes up at this point: in CPython, dict lookups and insertions are threadsafe. However, lock_dict is a defaultdict - while the lookups and insertions are still threadsafe, the callable defaultdict uses to create a missing value might not be, opening the door to a race condition where two different threads each create a new threading.Lock and assign it to the same key. In this particular case, though, and as long as you are running CPython, defaultdict(threading.Lock) is threadsafe (see here for an explanation).
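
If you would rather not depend on that CPython implementation detail, a minimal sketch of an explicitly guarded lock registry could look like this (get_lock and _registry_lock are names of my own choosing, not part of functools):

import threading

_registry_lock = threading.Lock()  # guards creation of the per-key locks
_locks = {}

def get_lock(key):
    # setdefault under the meta-lock guarantees that exactly one Lock
    # is ever created per key, on any Python implementation.
    with _registry_lock:
        return _locks.setdefault(key, threading.Lock())

The price is that every call briefly serializes on _registry_lock, even calls with different arguments - a point I return to below.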

The following example submits 10 jobs to a thread pool in two identical batches - batch 0 and batch 1 in the output below. Batch 1 is submitted with a slight delay, while batch 0 is still in flight. Note both the run timings and the 'Executing..' messages. In a naive cache implementation, the calls in batch 1 would be executed anyway, since the lru_cache would not be ready yet.

import threading
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from functools import lru_cache, partial, _make_key
from random import randint, seed
from time import sleep

sleep_length = 5
seed(123)
arg_range = 3
num_tasks = 5

def threadsafe_lru(func):
    func = lru_cache()(func)
    lock_dict = defaultdict(threading.Lock)

    def _thread_lru(*args, **kwargs):
        key = _make_key(args, kwargs, typed=False)  
        with lock_dict[key]:
            return func(*args, **kwargs)

    return _thread_lru




"""
comment out this decorator (or replace it with @lru_cache) 
and take a look at the results
"""
@threadsafe_lru  
def long_running_function(x, y):
    print("Executing with x={}, y={}".format(x, y))
    sleep(sleep_length)



func_list = [partial(long_running_function, x=randint(0, arg_range),
                     y=randint(0, arg_range)) for _ in range(num_tasks)]


def create_timing_callback(fn_id, *args, **kwargs):
    # start_time is captured at submission time, so the reported duration
    # includes any time spent waiting on an identical in-flight call.
    start_time = datetime.now()

    def timer(future):
        exec_time = round((datetime.now() - start_time).total_seconds())
        print("Total run time of {} seconds for function #{} called with {}, {}"
              .format(exec_time, fn_id, args, kwargs))

    return timer


with ThreadPoolExecutor(max_workers=10) as executor:
    fn_id = 0
    for batch in [0, 1]:
        print("Batch {}".format(batch))
        for fn in func_list:
            print("Submitting job #{} with arguments {}"
            .format(fn_id, fn.keywords))
            fut = executor.submit(fn)
            fut.add_done_callback(create_timing_callback(fn_id, fn.keywords))
            fn_id += 1
        sleep(1)
    print("Waiting for results..")
    executor.shutdown(wait=True)

Running this code (tested with Python 3.5.1) should produce output similar to:

Batch 0
Submitting function #0 with arguments {'x': 0, 'y': 2}
Executing with x=0, y=2
Submitting function #1 with arguments {'x': 0, 'y': 3}
Executing with x=0, y=3
Submitting function #2 with arguments {'x': 2, 'y': 0}
Executing with x=2, y=0
Submitting function #3 with arguments {'x': 0, 'y': 3}
Submitting function #4 with arguments {'x': 2, 'y': 2}
Executing with x=2, y=2
Batch 1
Submitting function #5 with arguments {'x': 0, 'y': 2}
Submitting function #6 with arguments {'x': 0, 'y': 3}
Submitting function #7 with arguments {'x': 2, 'y': 0}
Submitting function #8 with arguments {'x': 0, 'y': 3}
Submitting function #9 with arguments {'x': 2, 'y': 2}
Waiting for results..
Total run time of 5 seconds for function #1 called with ({'x': 0, 'y': 3},), {}
Total run time of 5 seconds for function #3 called with ({'x': 0, 'y': 3},), {}
Total run time of 4 seconds for function #6 called with ({'x': 0, 'y': 3},), {}
Total run time of 4 seconds for function #8 called with ({'x': 0, 'y': 3},), {}
Total run time of 5 seconds for function #2 called with ({'x': 2, 'y': 0},), {}
Total run time of 4 seconds for function #7 called with ({'x': 2, 'y': 0},), {}
Total run time of 5 seconds for function #0 called with ({'x': 0, 'y': 2},), {}
Total run time of 4 seconds for function #5 called with ({'x': 0, 'y': 2},), {}
Total run time of 5 seconds for function #4 called with ({'x': 2, 'y': 2},), {}
Total run time of 4 seconds for function #9 called with ({'x': 2, 'y': 2},), {}

In the first batch, there are two calls to long_running_function(x=0, y=3). The output shows that only one of them is actually executed, while the other waits five seconds for the result to become available in the cache. The second batch of jobs is launched a second later, and so has only four seconds left to wait until the cache is ready.

Why you probably shouldn't use this code as-is

Adding the missing 'typed' decorator argument (when typed=True, _make_key treats 3 and 3.0 as different calls) is relatively straightforward - I just wanted to avoid the extra level of indentation that comes with implementing a parametrized decorator. The trickier part is implementing 'maxsize'. When len(lock_dict) > maxsize, old locks need to be discarded. Therefore, we need to know which locks are 'old' (an OrderedDict might help), and we need to check len(lock_dict) each time we insert a new key/lock pair. Since several threads might attempt to add different keys at the same time, another lock is needed to make access to lock_dict itself threadsafe - see the sketch below. This might not affect performance in a significant way in most scenarios, but it bugs me on a personal level, since function calls might briefly wait on each other even when they are not duplicates and a cached result is already available. A variation of a reader-writer lock might bring the time each thread spends holding the lock unnecessarily down to a minimum.
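
For what it's worth, here is a rough, untested sketch of the bounded lock registry described above - max_locks is a hypothetical parameter of mine, loosely mirroring lru_cache's maxsize:

import threading
from collections import OrderedDict

def make_lock_registry(max_locks=128):
    locks = OrderedDict()  # recency order tells us which locks are 'old'
    registry_lock = threading.Lock()  # guards the registry itself

    def get_lock(key):
        with registry_lock:
            if key in locks:
                locks.move_to_end(key)  # mark as recently used
            else:
                locks[key] = threading.Lock()
                if len(locks) > max_locks:
                    locks.popitem(last=False)  # discard the least recently used
            return locks[key]

    return get_lock

Note how every get_lock call, duplicate or not, now contends on registry_lock - exactly the unnecessary waiting mentioned above that a reader-writer lock could help minimize.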
