While building a service I needed to acquire a mutually exclusive lock. A lock in RDS was not sufficient for this, so I wanted to use Redis instead. I looked for a Python library implementing the "Design pattern: Locking with SETNX" described on the Redis SETNX command page, but could not find one that fit, so I implemented it myself. (A Japanese translation of the SETNX page is available in the "String type" section of the Redis 2.0.3 documentation, under the command-SETNX anchor.)
Source Code
mutex.py
from datetime import datetime
import time
from functools import wraps
from .exception import (DuplicateLockError,
HasNotLockError,
ExpiredLockError,
SetnxError,
LockError)
class Mutex(object):
    """Redis-backed mutual-exclusion lock ("Locking with SETNX" pattern).

    The lock is a single Redis key whose value is the Unix time (float
    seconds) at which the lock expires. If a holder dies without unlocking,
    other clients can take the lock over with GETSET once that time has
    passed.

    NOTE(review): expiry is judged against each client's local clock, so
    all hosts sharing a key are assumed to have reasonably synchronized
    clocks.

    Supported usages: explicit ``lock()``/``unlock()``, ``with Mutex(...):``
    and the ``@Mutex(...)`` decorator.
    """

    def __init__(self, client, key,
                 expire=10,
                 retry_count=6,  # retry_count * retry_sleep_sec ~= maximum wait time
                 retry_setnx_count=100,
                 retry_sleep_sec=0.25):
        """
        :param client: Redis client providing setnx/get/getset/delete.
        :param key: Redis key used as the lock.
        :param expire: seconds after acquisition at which the lock expires.
        :param retry_count: acquisition attempts before LockError is raised.
        :param retry_setnx_count: attempts at the SETNX/GET pair before
            SetnxError is raised (the key can vanish between the two calls).
        :param retry_sleep_sec: seconds to sleep between attempts while
            another client holds a still-valid lock.
        """
        # Timestamp this instance used when acquiring the lock, or None when
        # it does not hold the lock. The value stored in Redis is exactly
        # ``self._lock + self._expire``.
        self._lock = None
        self._r = client
        self._key = key
        self._expire = expire
        self._retry_count = retry_count
        self._retry_setnx_count = retry_setnx_count
        self._retry_sleep_sec = retry_sleep_sec

    def _get_now(self):
        """Return the current Unix time in seconds as a float."""
        # time.time() is portable; strftime('%s') is a glibc-only extension.
        return time.time()

    def lock(self):
        """Acquire the lock, retrying up to ``retry_count`` times.

        :raises DuplicateLockError: this instance already holds the lock.
        :raises LockError: the lock could not be acquired in time.
        :raises SetnxError: SETNX/GET kept racing with key deletion.
        """
        if self._lock is not None:
            raise DuplicateLockError(self._key)
        self._do_lock()

    def _do_lock(self):
        """Run the SETNX / GET / GETSET acquisition loop."""
        for _ in range(self._retry_count):
            acquired_at, old_expire = self._setnx()
            if acquired_at is not None:
                self._lock = acquired_at
                return
            # Current holder's lock is still valid: _need_retry slept already.
            if self._need_retry(old_expire):
                continue
            # The holder's lock has expired: try to take it over atomically.
            acquired_at = self._get_now()
            if not self._need_retry(self._getset(acquired_at)):
                self._lock = acquired_at
                return
            # Another client won the GETSET race; _need_retry already slept.
        raise LockError(self._key)

    def _setnx(self):
        """Try SETNX once, re-reading the key if it vanished in between.

        :returns: ``(acquired_at, None)`` when SETNX succeeded, where
            ``acquired_at`` is the timestamp used for the stored expiry, or
            ``(None, old_expire)`` with the current holder's expiry time.
        :raises SetnxError: after ``retry_setnx_count`` failed rounds.
        """
        for _ in range(self._retry_setnx_count):
            now = self._get_now()
            if self._r.setnx(self._key, now + self._expire):
                return now, None
            old_expire = self._r.get(self._key)
            if old_expire is not None:
                return None, float(old_expire)
            # Key was deleted between SETNX and GET; try again.
        raise SetnxError(self._key)

    def _need_retry(self, expire):
        """Return False if ``expire`` has passed; otherwise sleep, return True."""
        if expire < self._get_now():
            return False
        time.sleep(self._retry_sleep_sec)
        return True

    def _getset(self, now=None):
        """GETSET a fresh expiry computed from ``now``; return the old expiry.

        A missing key is reported as 0, i.e. already expired.
        """
        if now is None:
            now = self._get_now()
        old_expire = self._r.getset(self._key, now + self._expire)
        if old_expire is None:
            return 0
        return float(old_expire)

    def unlock(self):
        """Release the lock by deleting the Redis key.

        :raises HasNotLockError: this instance does not hold the lock.
        :raises ExpiredLockError: the lock expired before being released. The
            key is left untouched because another client may own it now.
        """
        if self._lock is None:
            raise HasNotLockError(self._key)
        elapsed_time = self._get_now() - self._lock
        if self._expire <= elapsed_time:
            # We no longer own the lock; forget it so this instance can be
            # reused (e.g. as a decorator) instead of raising
            # DuplicateLockError forever.
            self._lock = None
            raise ExpiredLockError(self._key, elapsed_time)
        self._r.delete(self._key)
        self._lock = None

    def __enter__(self):
        self.lock()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self._lock is not None:
            self.unlock()
        # Never suppress an exception raised inside the with-block.
        return False

    def __call__(self, func):
        """Decorate ``func`` so each call runs while holding the lock."""
        @wraps(func)
        def inner(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return inner
exception.py
class MutexError(Exception):
    """Common base class for all errors raised by the Redis Mutex."""
class DuplicateLockError(MutexError):
    """Raised when lock() is called on a Mutex that already holds the lock.

    Call unlock() first, or use a separate Mutex object.
    """
class HasNotLockError(MutexError):
    """Raised when unlock() is called on a Mutex that does not hold the lock.

    lock() must be called successfully before unlock().
    """
class ExpiredLockError(MutexError):
    """Raised by unlock() when the lock had already expired.

    Mutex raises it with ``(key, elapsed_time)`` and does not delete the
    Redis key in that case.
    """
class SetnxError(MutexError):
    """Raised when SETNX kept failing while the following GET found no key.

    Mutex gives up after ``retry_setnx_count`` such rounds.
    """
class LockError(MutexError):
    """Raised when lock() fails to acquire the lock within ``retry_count`` attempts."""
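The rough flow of the locking logic is as follows:
1. `SETNX key <now + expire>`. If it succeeds, the lock is acquired.
2. Otherwise, `GET key` to read the current holder's expiry time. If that time is still in the future, sleep for `retry_sleep_sec` and start over.
3. If the expiry time has already passed, `GETSET key <now + expire>`. If the old value returned by GETSET is also in the past, the lock is acquired. Otherwise another client took the lock over first, so sleep and start over.
4. After `retry_count` unsuccessful attempts, raise `LockError`.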
The usage is as follows.
usage.py
# Usage examples. Mutex needs a Redis client as its first argument.
# user_id is assumed to be defined by the caller.
import redis
from mutex import Mutex

client = redis.StrictRedis()
key = ':'.join(['EmitAccessToken', user_id])

# 1. As a context manager: the lock is released when the block exits.
with Mutex(client, key):
    # do something ...
    pass


# 2. As a decorator: every call to the function runs under the lock.
@Mutex(client, key)
def emit_access_token():
    # do something ...
    pass


# 3. Explicit lock()/unlock(): use try/finally so the lock is released
#    even if the work raises.
mutex = Mutex(client, key)
mutex.lock()
try:
    # do something ...
    pass
finally:
    mutex.unlock()
test.py
import unittest
import redis
import time
from multiprocessing import Process
from .mutex import Mutex
from .exception import (DuplicateLockError,
HasNotLockError,
ExpiredLockError,
LockError)
class TestMutex(unittest.TestCase):
    """Integration tests for Mutex; they require a Redis server at the
    redis.StrictRedis() defaults."""

    def setUp(self):
        self.key = 'spam'
        self.r = redis.StrictRedis()
        # Start from a clean slate even if a previous run crashed mid-test.
        self.r.delete(self.key, 'ham')
        self.mutex = Mutex(self.r, self.key)

    def tearDown(self):
        mutex = self.mutex
        if mutex._lock:
            mutex.unlock()
        # Tests that let a lock expire leave the lock key behind; remove it
        # together with the counter used by test_multi_process.
        mutex._r.delete(self.key, 'ham')

    def test_lock(self):
        mutex = self.mutex
        mutex.lock()
        self.assertIsNotNone(mutex._r.get(mutex._key))
        with self.assertRaises(DuplicateLockError):
            mutex.lock()

    def test_unlock(self):
        self.test_lock()
        mutex = self.mutex
        mutex.unlock()
        self.assertIsNone(mutex._r.get(mutex._key))
        with self.assertRaises(HasNotLockError):
            mutex.unlock()
        self.test_lock()
        # Wait until just past the lock's expiry.
        time.sleep(mutex._expire + 0.5)
        with self.assertRaises(ExpiredLockError):
            mutex.unlock()
        mutex._lock = None  # Forced reset in case unlock() left it set

    def test_expire(self):
        mutex1 = self.mutex
        mutex2 = Mutex(self.r, self.key, expire=2)
        mutex2.lock()  # Holds the lock for 2 seconds
        with self.assertRaises(LockError):
            mutex1.lock()  # 6 retries * 0.25 s sleep = 1.5 s of waiting
        time.sleep(0.6)  # Extra margin so mutex2's lock has expired
        mutex1.lock()
        self.assertIsNotNone(mutex1._r.get(mutex1._key))

    def test_with(self):
        mutex1 = self.mutex
        with mutex1:
            self.assertIsNotNone(mutex1._r.get(mutex1._key))
        self.assertIsNone(mutex1._r.get(mutex1._key))
        mutex2 = Mutex(self.r, self.key, expire=2)
        mutex2.lock()  # Holds the lock for 2 seconds
        with self.assertRaises(LockError):
            with mutex1:  # 6 retries * 0.25 s sleep = 1.5 s of waiting
                pass
        mutex2.unlock()
        with mutex1:
            with self.assertRaises(DuplicateLockError):
                with mutex1:
                    pass

    def test_decorator(self):
        mutex = self.mutex

        @mutex
        def egg():
            self.assertIsNotNone(mutex._r.get(mutex._key))

        egg()
        self.assertIsNone(mutex._r.get(mutex._key))

    def test_multi_process(self):
        procs = 20
        counter = 100

        # NOTE(review): a nested function as Process target only works with
        # the 'fork' start method (the Linux default); 'spawn' cannot pickle it.
        def incr():
            mutex = Mutex(redis.StrictRedis(), self.key, retry_count=100)
            for _ in range(counter):
                mutex.lock()
                ham = mutex._r.get('ham') or 0
                mutex._r.set('ham', int(ham) + 1)
                mutex.unlock()

        ps = [Process(target=incr) for _ in range(procs)]
        for p in ps:
            p.start()
        for p in ps:
            p.join()
        self.assertEqual(int(self.mutex._r.get('ham')), counter * procs)
Recommended Posts