Recently I encountered a scenario where I needed to cache the results
of calls to an I/O-related function f(). f() took a few seconds to execute,
and I had to account for the possibility that f() could be called from several threads at once.
The standard library's functools module provides the lru_cache decorator, which solves
half of my problem: once the first call to an lru_cache-decorated function
is complete, any subsequent call with the same arguments, from any
thread, will use the cached result. In a multi-threaded scenario this is
not quite enough. If thread A calls f() and, while that call is still in flight,
thread B makes an identical call (call it f'()), then f'() is executed as well,
because the cache does not hold the result of f() yet.
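The problem is easy to reproduce. The sketch below is only an illustration: slow_square, worker and the 0.5-second sleep are made-up stand-ins for f() and its I/O. A threading.Barrier releases two threads at the same moment, so the second call very likely misses the cache while the first is still running.

```python
import threading
import time
from functools import lru_cache

call_count = 0                 # how many times the function body actually ran
count_lock = threading.Lock()  # protects call_count itself

@lru_cache()
def slow_square(x):
    global call_count
    with count_lock:
        call_count += 1
    time.sleep(0.5)  # hypothetical stand-in for a slow I/O call
    return x * x

# The barrier makes both threads call slow_square(3) at (almost) the same time.
barrier = threading.Barrier(2)

def worker(results, idx):
    barrier.wait()
    results[idx] = slow_square(3)

results = [None, None]
threads = [threading.Thread(target=worker, args=(results, i)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results)     # [9, 9]: both threads get the correct value
print(call_count)  # typically 2: the second call missed the cache
```

Both results are correct; the cost is purely the duplicated slow call, which is exactly what the rest of this post tries to avoid.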
The desired behaviour is to make f'() wait until f() is done and
then use the cached result. To achieve this, I've built a simple threadsafe
LRU cache decorator that handles both caching and thread synchronization.
The implementation is below, followed by a brief explanation of how it works
and why it is threadsafe. For clarity, I've omitted the @wraps decorator,
as well as the typed and maxsize optional arguments that the
original lru_cache accepts.
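import threading
from collections import defaultdict
from functools import lru_cache, _make_key


def threadsafe_lru(func):
    func = lru_cache()(func)                 # handles the actual caching
    lock_dict = defaultdict(threading.Lock)  # one lock per distinct call key

    def _thread_lru(*args, **kwargs):
        # Build a key from the arguments the same way lru_cache does.
        key = _make_key(args, kwargs, typed=False)
        # Identical calls share one lock: only one of them runs func at a
        # time, and the others then find the result already cached.
        with lock_dict[key]:
            return func(*args, **kwargs)
    return _thread_lru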
The first thing to notice here is that no wheels are reinvented:
I'm still using the lru_cache and _make_key functions provided by functools
(note that _make_key is a private, undocumented helper, so it could change
between Python versions). All that is left to do is to block the execution
of a function if an identical call is already in flight. Deciding whether two
function calls are identical is done the same way lru_cache() does it: the
_make_key function builds a hashable key from the function arguments. The
corresponding lock in lock_dict is then acquired (and created, if it does not
exist yet), which makes all other calls with the same arguments wait until
the result of the first call is cached and ready. Calls with different
arguments have a different key and therefore acquire a different lock, so
calls with different argument values do not interfere with each other.
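A small check of that behaviour, assuming the decorator exactly as defined above (repeated so the snippet runs on its own). The names slow_add and call, and the 0.5-second sleep, are made up for illustration. Two threads make the identical call slow_add(1, 2) and a third calls slow_add(3, 4): the duplicate waits for the cached result, while the call with different arguments runs in parallel.

```python
import threading
import time
from collections import defaultdict
from functools import lru_cache, _make_key

def threadsafe_lru(func):
    # Same decorator as above, repeated so this snippet is self-contained.
    func = lru_cache()(func)
    lock_dict = defaultdict(threading.Lock)

    def _thread_lru(*args, **kwargs):
        key = _make_key(args, kwargs, typed=False)
        with lock_dict[key]:
            return func(*args, **kwargs)
    return _thread_lru

executed = []  # arguments of every call whose body actually ran

@threadsafe_lru
def slow_add(x, y):
    executed.append((x, y))
    time.sleep(0.5)  # hypothetical slow I/O
    return x + y

results = {}

def call(name, *args):
    results[name] = slow_add(*args)

threads = [
    threading.Thread(target=call, args=("a", 1, 2)),
    threading.Thread(target=call, args=("b", 1, 2)),  # same key as "a": waits for it
    threading.Thread(target=call, args=("c", 3, 4)),  # different key: runs in parallel
]
start = time.monotonic()
for t in threads:
    t.start()
for t in threads:
    t.join()
elapsed = time.monotonic() - start

print(sorted(results.items()))  # [('a', 3), ('b', 3), ('c', 7)]
print(sorted(executed))         # [(1, 2), (3, 4)]: (1, 2) ran only once
print(round(elapsed, 1))        # typically about 0.5, not 1.0
```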
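An important nuance comes up at this point. In CPython, individual dict lookups and
insertions are threadsafe. However, lock_dict is a defaultdict: while lookups
and insertions are still threadsafe, the callable that defaultdict uses to create
the value for a missing key might not be, which could cause a race condition where two
threads each create a new threading.Lock and assign it to the same key.
In this particular case, and as long as you are running CPython,
defaultdict(threading.Lock) is threadsafe (see here for an explanation):
threading.Lock is implemented in C, so creating one does not run Python code
that could let another thread in. One caveat: for most argument lists,
_make_key returns a functools._HashedSeq, whose __hash__ method is written in
Python, so storing the new key does run a little Python bytecode and could,
in principle, give another thread a chance to run in between. If you'd rather
not depend on such implementation details, guard lock_dict with a separate lock.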
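Here is a sketch of that explicit-lock alternative, assuming you are willing to pay for one brief, global lock acquisition per call. threadsafe_lru_explicit, fetch and the sleep are hypothetical names for illustration. A plain dict replaces the defaultdict, and dict_guard serializes every lookup and insertion, so correctness no longer depends on how CPython schedules threads. The guard is released before the slow call, so only callers with the same key block each other for long.

```python
import threading
import time
from functools import lru_cache, _make_key, wraps

def threadsafe_lru_explicit(func):
    cached = lru_cache()(func)
    lock_dict = {}                 # key -> per-key lock (plain dict, no defaultdict)
    dict_guard = threading.Lock()  # serializes every read/write of lock_dict

    @wraps(func)
    def _thread_lru(*args, **kwargs):
        key = _make_key(args, kwargs, typed=False)
        # Held only for the lookup/insert, never during the slow call.
        with dict_guard:
            key_lock = lock_dict.get(key)
            if key_lock is None:
                key_lock = lock_dict[key] = threading.Lock()
        # Only callers with the same key block each other here.
        with key_lock:
            return cached(*args, **kwargs)
    return _thread_lru

runs = 0

@threadsafe_lru_explicit
def fetch(name):
    global runs
    # Fine for this demo: with a single key, the per-key lock ensures
    # no two calls execute this line at the same time.
    runs += 1
    time.sleep(0.2)  # pretend this is slow I/O
    return "data for " + name

threads = [threading.Thread(target=fetch, args=("report",)) for _ in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(runs)             # 1: the four duplicate calls waited and hit the cache
print(fetch("report"))  # data for report
```

The same guard lock is also what the maxsize discussion at the end of this post ends up needing.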
The following example submits 10 jobs to a thread pool in two identical
batches, B1 and B2 (printed as Batch 0 and Batch 1). B2 is submitted one second
later, while B1 is still in flight. Note both the run timings and the 'Executing..' messages.
With a naive cache implementation, the calls in B2 would be executed anyway,
since the lru_cache would not hold their results yet.
import threading
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from functools import lru_cache, partial, _make_key
from random import randint, seed
from time import sleep

sleep_length = 5  # seconds each "slow" call takes
seed(123)         # fixed seed so the generated arguments are reproducible
arg_range = 3     # x and y are drawn from 0..arg_range (inclusive)
num_tasks = 5     # jobs per batch


def threadsafe_lru(func):
    func = lru_cache()(func)
    lock_dict = defaultdict(threading.Lock)

    def _thread_lru(*args, **kwargs):
        key = _make_key(args, kwargs, typed=False)
        with lock_dict[key]:
            return func(*args, **kwargs)
    return _thread_lru


# Comment out this decorator (or replace it with @lru_cache())
# and take a look at the results.
@threadsafe_lru
def long_running_function(x, y):
    print("Executing with x={}, y={}".format(x, y))
    sleep(sleep_length)


func_list = [partial(long_running_function,
                     x=randint(0, arg_range),
                     y=randint(0, arg_range))
             for _ in range(num_tasks)]


def create_timing_callback(fn_id, *args, **kwargs):
    start_time = datetime.now()

    def timer(future):
        # Runs when the job finishes; reports seconds since submission.
        exec_time = round((datetime.now() - start_time).total_seconds())
        print("Total run time of {} seconds for function #{} called with {}, {}"
              .format(exec_time, fn_id, args, kwargs))
    return timer


with ThreadPoolExecutor(max_workers=10) as executor:
    fn_id = 0
    for batch in [0, 1]:
        print("Batch {}".format(batch))
        for fn in func_list:
            print("Submitting function #{} with arguments {}"
                  .format(fn_id, fn.keywords))
            fut = executor.submit(fn)
            fut.add_done_callback(create_timing_callback(fn_id, fn.keywords))
            fn_id += 1
        sleep(1)  # start the next batch while this one is still running
    print("Waiting for results..")
    executor.shutdown(wait=True)
Running this code (tested with Python 3.5.1) should produce:
Batch 0
Submitting function #0 with arguments {'x': 0, 'y': 2}
Executing with x=0, y=2
Submitting function #1 with arguments {'x': 0, 'y': 3}
Executing with x=0, y=3
Submitting function #2 with arguments {'x': 2, 'y': 0}
Executing with x=2, y=0
Submitting function #3 with arguments {'x': 0, 'y': 3}
Submitting function #4 with arguments {'x': 2, 'y': 2}
Executing with x=2, y=2
Batch 1
Submitting function #5 with arguments {'x': 0, 'y': 2}
Submitting function #6 with arguments {'x': 0, 'y': 3}
Submitting function #7 with arguments {'x': 2, 'y': 0}
Submitting function #8 with arguments {'x': 0, 'y': 3}
Submitting function #9 with arguments {'x': 2, 'y': 2}
Waiting for results..
Total run time of 5 seconds for function #1 called with ({'x': 0, 'y': 3},), {}
Total run time of 5 seconds for function #3 called with ({'x': 0, 'y': 3},), {}
Total run time of 4 seconds for function #6 called with ({'x': 0, 'y': 3},), {}
Total run time of 4 seconds for function #8 called with ({'x': 0, 'y': 3},), {}
Total run time of 5 seconds for function #2 called with ({'x': 2, 'y': 0},), {}
Total run time of 4 seconds for function #7 called with ({'x': 2, 'y': 0},), {}
Total run time of 5 seconds for function #0 called with ({'x': 0, 'y': 2},), {}
Total run time of 4 seconds for function #5 called with ({'x': 0, 'y': 2},), {}
Total run time of 5 seconds for function #4 called with ({'x': 2, 'y': 2},), {}
Total run time of 4 seconds for function #9 called with ({'x': 2, 'y': 2},), {}
In the first batch there are two calls to long_running_function(x=0, y=3).
The output shows that only one of them is actually executed, while the other
waits five seconds for the result to become available in the cache.
The second batch is submitted one second later, so its jobs wait only four
seconds until the cached results are ready, and none of them prints an
'Executing..' message.
Why you probably shouldn't use this code as-is
Adding the missing typed decorator argument (when typed=True, _make_key
treats 3 and 3.0 as different keys) is relatively straightforward; I just wanted
to avoid the extra level of nesting that a parametrized decorator requires.
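For reference, a minimal sketch of that parametrized version, assuming you only want to thread typed through (the lock table stays unbounded). It also adds the @wraps decorator omitted earlier. identity and calls are hypothetical names for the demo. The key point is that lru_cache and _make_key must receive the same typed value, otherwise the locks and the cache would disagree about which calls are identical.

```python
import threading
from collections import defaultdict
from functools import lru_cache, _make_key, wraps

def threadsafe_lru(typed=False):
    """Parametrized sketch: use as @threadsafe_lru() or @threadsafe_lru(typed=True)."""
    def decorator(func):
        cached = lru_cache(typed=typed)(func)
        lock_dict = defaultdict(threading.Lock)

        @wraps(func)  # keep func's __name__, __doc__, etc.
        def _thread_lru(*args, **kwargs):
            # Same 'typed' setting as lru_cache, so both agree on what
            # counts as an identical call (e.g. 3 vs 3.0).
            key = _make_key(args, kwargs, typed)
            with lock_dict[key]:
                return cached(*args, **kwargs)
        return _thread_lru
    return decorator

calls = []  # records every call whose body actually ran

@threadsafe_lru(typed=True)
def identity(x):
    calls.append(x)
    return x

identity(3)
identity(3.0)
print(calls)  # [3, 3.0]: with typed=True, 3 and 3.0 are cached separately
```

Note that this version must always be applied with parentheses; supporting a bare @threadsafe_lru as well would need an extra check on the first argument.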
The trickier part is implementing maxsize. When len(lock_dict) > maxsize,
old locks need to be discarded. We therefore need to know which locks are
'old' (an OrderedDict might help), and we need to check len(lock_dict)
every time we insert a new key/lock pair. Since several threads might
try to add different keys at the same time, another lock is needed to make access
to lock_dict itself threadsafe. In most scenarios this probably won't affect
performance significantly, but it bugs me on a personal level: function calls
might briefly wait on each other even when they are not duplicates, or when
a cached result is already available. A variation of a reader-writer lock
might keep the time threads spend needlessly waiting on that lock to a minimum.
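A rough sketch of the OrderedDict approach described above, under these assumptions: maxsize is a positive int (None is not supported here), and the names threadsafe_lru_bounded, square and lock_count are hypothetical. It uses a single plain guard lock rather than the reader-writer refinement. One known limitation: if the evicted lock is still held by an in-flight call, a later duplicate call can run in parallel. Results stay correct; you just lose the deduplication for that key.

```python
import threading
from collections import OrderedDict
from functools import lru_cache, _make_key, wraps

def threadsafe_lru_bounded(maxsize=128):
    """Sketch: bound both the result cache and the per-key lock table."""
    def decorator(func):
        cached = lru_cache(maxsize=maxsize)(func)
        locks = OrderedDict()     # key -> lock, least recently used first
        guard = threading.Lock()  # protects 'locks' itself

        @wraps(func)
        def wrapper(*args, **kwargs):
            key = _make_key(args, kwargs, typed=False)
            with guard:
                key_lock = locks.get(key)
                if key_lock is None:
                    key_lock = locks[key] = threading.Lock()
                    if len(locks) > maxsize:
                        # Evict the least recently used lock. If it is still
                        # held by an in-flight call, a later duplicate may run
                        # in parallel; cached results remain correct.
                        locks.popitem(last=False)
                else:
                    locks.move_to_end(key)  # mark as recently used
            with key_lock:
                return cached(*args, **kwargs)

        wrapper.cache_info = cached.cache_info   # expose lru_cache statistics
        wrapper.lock_count = lambda: len(locks)  # hypothetical inspection helper
        return wrapper
    return decorator

@threadsafe_lru_bounded(maxsize=2)
def square(x):
    return x * x

for n in range(5):
    square(n)

print(square.lock_count())  # 2
print(square.cache_info())  # CacheInfo(hits=0, misses=5, maxsize=2, currsize=2)
```

Every call still takes the guard lock briefly, even on a cache hit, which is exactly the cost discussed above.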