Time-Based Memoization

Memoization is a technique for caching the result of a function and returning it when the function is called again with the same inputs. It’s primarily used when a call is time-consuming or complex and, given the same inputs, the result doesn’t change. Time-based memoization is an extension of this idea, where the result is only cached for a brief period of time.

One very useful application is for wrapping functions that make requests to other microservices, especially when we a) might make the same call multiple times from different functions, and b) don’t want to preserve or manage the response in some sort of context in the calling function(s).

Basic usage:

from datetime import timedelta

@memoize(until=timedelta(seconds=30))
def get_features(customer_id, flag=None):
    ...

Behavior:

  • If memoization is disabled, call the function directly.
  • If there is a cached result for this set of inputs and it is still fresh, return it.
  • Otherwise, call the function and record the result along with its expiration.
  • Clean up any stale results on our way out.

Testing code that uses this decorator could be problematic, because a cached result would persist between tests. To address this, the wrapped function exposes a set of management methods:

get_features.clear()  # Clear all cached results
get_features.enable_memoization()
get_features.disable_memoization()
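
In a test suite, these methods are typically called from a setup hook so that every test starts from an empty cache. The sketch below assumes unittest; memoize_stub is a hypothetical, stripped-down stand-in for the decorator (no expiry, positional arguments only) so the snippet runs on its own, and get_features here is a toy function that only records its calls.

```python
import unittest
from functools import wraps


def memoize_stub(fn):
    """Hypothetical stand-in for the real decorator: caches forever."""
    results = {}
    state = {'enabled': True}

    @wraps(fn)
    def wrapper(*args):
        if not state['enabled']:
            return fn(*args)
        if args not in results:
            results[args] = fn(*args)
        return results[args]

    # Same management methods the real decorator attaches
    wrapper.clear = results.clear
    wrapper.enable_memoization = lambda: state.update(enabled=True)
    wrapper.disable_memoization = lambda: state.update(enabled=False)
    return wrapper


calls = []


@memoize_stub
def get_features(customer_id):
    calls.append(customer_id)  # record every real invocation
    return {'customer_id': customer_id}


class GetFeaturesTest(unittest.TestCase):
    def setUp(self):
        # Start every test with an empty cache and an empty call log
        get_features.clear()
        get_features.enable_memoization()
        del calls[:]

    def test_repeat_call_is_cached(self):
        get_features(1)
        get_features(1)
        self.assertEqual(calls, [1])

    def test_disabled_calls_through(self):
        get_features.disable_memoization()
        get_features(1)
        get_features(1)
        self.assertEqual(calls, [1, 1])
```

Resetting in setUp rather than tearDown means a test that fails midway cannot leave a stale cache or a disabled decorator behind for the next one.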

The implementation:

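from collections import OrderedDict
from datetime import datetime, timedelta
from functools import partial, wraps


def memoize(fn=None, until=None):
    """Time-based memoization decorator

    When the wrapped function is called, the result is recorded and
    timestamped. Any subsequent calls within the `until` window *with the same
    function parameters* will return the cached result, rather than executing
    the function again.

    Arguments:
        fn (function): Function to wrap; omitted when the decorator is
            called with arguments, as in `@memoize(until=...)`
        until (timedelta): How long a result is valid
    Returns:
        function: The wrapper function
    """
    # Called as @memoize(until=...): there is no function yet, so return a
    # decorator that remembers `until` and waits for the function
    if fn is None:
        return partial(memoize, until=until)
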
    expirations = OrderedDict()
    results = {}
    state = {
        'enabled': True
    }

    @wraps(fn)
    def wrapper(*args, **kwargs):
        
        # Pass-through to function if not enabled
        if not state['enabled']:
            return fn(*args, **kwargs)

        now = datetime.utcnow()

        # Check if this combination of parameters has a memoized result.
        # Membership is tested separately so a cached None counts as a hit.
        call_key = (args, tuple(kwargs.items()))
        cached = call_key in results
        result, expiration = results.get(call_key, (None, None))

        # If no cached result exists, or if it's expired
        if not cached or (expiration is not None and now >= expiration):
            # call out to the original function
            result = fn(*args, **kwargs)
            new_expiration = now + until if until else None

            # save the result; only timed results get an expiration entry,
            # because a None key cannot be compared against `now` during
            # the cleanup below
            results[call_key] = (result, new_expiration)
            if new_expiration is not None:
                expirations[new_expiration] = call_key

            # remove the old expiration record, if there was one
            if expiration is not None:
                expirations.pop(expiration, None)

        # Clean up all expired results to avoid hogging memory
        removed = []
        for expiration, result_key in expirations.items():
            if expiration < now:
                del results[result_key]

                # Mark them for removal, then do the actual removal after
                # the loop to avoid modifying the same collection we're
                # iterating over
                removed.append(expiration)
            else:
                # Because keys are kept in expiration order, once we find an
                # unexpired record we know we're finished
                break

        for key in removed:
            del expirations[key]

        return result
    
    def clear():
        """Clear all cached results. Callable on the original memoized
        function.
        """
        results.clear()
        expirations.clear()

    def enable_memoization():
        """Enable memoization for the associated function
        """
        state['enabled'] = True

    def disable_memoization():
        """Disable memoization for the associated function
        """
        state['enabled'] = False

    wrapper.clear = clear
    wrapper.enable_memoization = enable_memoization
    wrapper.disable_memoization = disable_memoization

    return wrapper
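
The call_key built in wrapper treats f(a=1, b=2) and f(b=2, a=1) as different calls, because keyword arguments keep the order in which they were passed. A minimal sketch of an order-independent key follows; make_call_key and demo are hypothetical names, not part of the implementation above, and the helper assumes every argument value is hashable.

```python
def make_call_key(args, kwargs):
    """Build a cache key that ignores keyword-argument order.

    Hypothetical helper; assumes every argument value is hashable.
    """
    # Sorting by name makes f(a=1, b=2) and f(b=2, a=1) share one key.
    # Names are unique, so the values themselves are never compared.
    return (tuple(args), tuple(sorted(kwargs.items())))


def demo(*args, **kwargs):
    """Return the key that a call with these arguments would produce."""
    return make_call_key(args, kwargs)


same_key = demo(1, a=1, b=2) == demo(1, b=2, a=1)
```

Swapping something like this in for the call_key line would trade a small sorting cost on every call for more cache hits.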

Final notes:

  • If until=None, results are memoized without the time-based component.
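  • Keyword arguments aren’t sorted when the cache key is built, so get_features(customer_id=1, flag='a') and get_features(flag='a', customer_id=1) are cached separately and the function is run anew for each.
  • While memoization is enabled, expired results are removed after every lookup to avoid excessive memory usage. Because expirations are stored in insertion order, which is also expiration order when until is fixed, we only do the minimum amount of tidying necessary.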
  • I really like using this with an optional argument decorator, where the decorator can be used without being called if expiration isn’t needed, e.g. @memoize instead of @memoize().
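
For readers who haven’t met the pattern, here is one way an optional argument decorator might look. This is a sketch under assumptions, not the helper I actually use: optional_arguments and the toy shout decorator are hypothetical names, and options must be passed by keyword.

```python
from functools import wraps


def optional_arguments(decorator):
    """Let `decorator(fn, **options)` be applied bare or with keyword options.

    Hypothetical helper, shown only to illustrate the pattern.
    """
    @wraps(decorator)
    def wrapped(fn=None, **options):
        if fn is None:
            # Used as @decorator(...): return something that still expects fn
            return lambda real_fn: decorator(real_fn, **options)
        # Used as bare @decorator: fn is the function being decorated
        return decorator(fn, **options)
    return wrapped


@optional_arguments
def shout(fn, suffix='!'):
    """Toy decorator: append `suffix` to the wrapped function's result."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs) + suffix
    return wrapper


@shout
def hello():
    return 'hello'


@shout(suffix='?')
def who():
    return 'who'
```

Applied to a decorator with the signature memoize(fn, until=None), the same helper would make @memoize, @memoize() and @memoize(until=...) all valid.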

Lastly, you can visit the gist to see the most up-to-date implementation and leave feedback!