Completed
Push — master ( c5f31b...4822be )
by Ionel Cristian
03:52
created

src.redis_lock.Lock   B

Complexity

Total Complexity 42

Size/Duplication

Total Lines 178
Duplicated Lines 0 %
Metric Value
dl 0
loc 178
rs 8.2951
wmc 42

12 Methods

Rating   Name   Duplication   Size   Complexity  
A id() 0 3 1
A __enter__() 0 4 2
A release() 0 9 1
A get_owner_id() 0 2 1
A reset() 0 5 1
A _stop_lock_renewer() 0 13 3
A __exit__() 0 10 4
F acquire() 0 41 15
A _start_lock_renewer() 0 18 2
C __init__() 0 34 7
A _lock_renewer() 0 10 2
A extend() 0 13 3

How to fix   Complexity   

Complex Class

Complex classes like src.redis_lock.Lock often do many different things. To break such a class down, we need to identify a cohesive component within it. A common way to find such a component is to look for fields and methods that share the same prefix or suffix — in this class, for example, the `_lock_renewal_*` fields together with `_lock_renewer()`, `_start_lock_renewer()` and `_stop_lock_renewer()`.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster to apply.

1
import threading
2
from logging import getLogger
3
from os import urandom
4
from hashlib import sha1
5
6
from redis import StrictRedis
7
from redis.exceptions import NoScriptError
8
9
__version__ = "2.3.0"
10
11
logger = getLogger(__name__)
12
13
# Release the lock only if it is still held by the caller.
#   KEYS[1] - the lock key, KEYS[2] - its signal list, ARGV[1] - owner id.
# The signal list is recreated with a single element (expiring after 1 second)
# so a waiter blocked in BLPOP on it wakes up.
# Returns the result of deleting the lock key when owned, 0 otherwise.
UNLOCK_SCRIPT = b"""
    if redis.call("get", KEYS[1]) == ARGV[1] then
        redis.call("del", KEYS[2])
        redis.call("lpush", KEYS[2], 1)
        redis.call("expire", KEYS[2], 1)
        return redis.call("del", KEYS[1])
    else
        return 0
    end
"""
UNLOCK_SCRIPT_HASH = sha1(UNLOCK_SCRIPT).hexdigest()

# Same as UNLOCK_SCRIPT but without the ownership check: forcibly deletes the
# lock key KEYS[1] and signals waiters on KEYS[2].
RESET_SCRIPT = b"""
    redis.call('del', KEYS[2])
    redis.call('lpush', KEYS[2], 1)
    redis.call('expire', KEYS[2], 1)
    return redis.call('del', KEYS[1])
"""

RESET_SCRIPT_HASH = sha1(RESET_SCRIPT).hexdigest()

# Forcibly reset every lock: for each key matching ``lock:*`` derive its signal
# key (``string.sub(lock, 6)`` drops the 5-character ``lock:`` prefix), wake
# any waiters and delete the lock. Returns the number of lock keys found.
# NOTE(review): KEYS scans the whole keyspace while the script runs, and the
# script touches keys not passed in KEYS[] -- confirm this is acceptable for
# the target deployment (large keyspaces / Redis Cluster).
RESET_ALL_SCRIPT = b"""
    local locks = redis.call('keys', 'lock:*')
    local signal
    for _, lock in pairs(locks) do
        signal = 'lock-signal:' .. string.sub(lock, 6)
        redis.call('del', signal)
        redis.call('lpush', signal, 1)
        redis.call('expire', signal, 1)
        redis.call('del', lock)
    end
    return #locks
"""

RESET_ALL_SCRIPT_HASH = sha1(RESET_ALL_SCRIPT).hexdigest()
48
49
50
class AlreadyAcquired(RuntimeError):
    """Raised by ``Lock.acquire()`` when this Lock instance already holds the lock."""


class NotAcquired(RuntimeError):
    """Raised when an operation needs the lock to be held by this Lock instance, but it is not."""


class AlreadyStarted(RuntimeError):
    """Raised when a lock renewal thread is started while one is already registered."""


class TimeoutNotUsable(RuntimeError):
    """Raised when a timeout is passed to ``Lock.acquire()`` together with ``blocking=False``."""


class InvalidTimeout(RuntimeError):
    """Raised when the acquire timeout is less than or equal to zero."""


class TimeoutTooLarge(RuntimeError):
    """Raised when the acquire timeout is greater than the lock's expire time."""


class NotExpirable(RuntimeError):
    """Raised when extending a lock that was created without an expire time."""
76
77
78
# Flat lookup table consumed by ``_eval_script``: each bundled script occupies
# three consecutive slots -- its SHA1 hash, its source, and a name used in log
# messages. UNLOCK, RESET and RESET_ALL are the indices of each hash slot
# (0, 3 and 6), so ``SCRIPTS[i + 1]`` is the source and ``SCRIPTS[i + 2]`` the
# name.
SCRIPTS = (
    UNLOCK_SCRIPT_HASH, UNLOCK_SCRIPT, 'UNLOCK_SCRIPT',
    RESET_SCRIPT_HASH, RESET_SCRIPT, 'RESET_SCRIPT',
    RESET_ALL_SCRIPT_HASH, RESET_ALL_SCRIPT, 'RESET_ALL_SCRIPT',
)
UNLOCK, RESET, RESET_ALL = range(0, len(SCRIPTS), 3)
83
84
85
def _eval_script(redis, script_id, *args, **kwargs):
    """Run one of the bundled Lua scripts.

    Tries ``EVALSHA`` with the script's hash first and, if that raises
    ``NoScriptError``, falls back to a regular ``EVAL`` with the full source.

    :param redis:
        The redis client to run the script on.
    :param script_id:
        One of ``UNLOCK``, ``RESET`` or ``RESET_ALL``. The script's hash,
        source and name live at ``SCRIPTS[script_id]``,
        ``SCRIPTS[script_id + 1]`` and ``SCRIPTS[script_id + 2]``.
    :param args, kwargs:
        Forwarded unchanged to ``evalsha``/``eval`` (number of keys, keys,
        then arguments).
    :return:
        Whatever the script returns.
    """
    try:
        return redis.evalsha(SCRIPTS[script_id], *args, **kwargs)
    except NoScriptError:
        # ``Logger.warn`` is a long-deprecated alias of ``warning``.
        logger.warning("%s not cached.", SCRIPTS[script_id + 2])
        return redis.eval(SCRIPTS[script_id + 1], *args, **kwargs)
94
95
96
class Lock(object):
    """
    A Lock context manager implemented via redis SETNX/BLPOP.

    The lock is the key ``lock:<name>`` holding this instance's id. Waiters
    block with BLPOP on the list ``lock-signal:<name>``, which the unlock and
    reset scripts push to when the lock is released.
    """

    # Lua script used by extend(): refresh the TTL of KEYS[1] only if it still
    # holds ARGV[1] (this instance's id). Returns 1 when the key is missing or
    # owned by someone else, 0 after setting the new expiry ARGV[2].
    _EXTEND_SCRIPT = b"""
        if redis.call("get", KEYS[1]) ~= ARGV[1] then
            return 1
        end
        redis.call("expire", KEYS[1], ARGV[2])
        return 0
    """

    def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False):
        """
        :param redis_client:
            An instance of :class:`~StrictRedis`.
        :param name:
            The name (redis key) the lock should have.
        :param expire:
            The lock expiry time in seconds. If left at the default (None)
            the lock will not expire.
        :param id:
            The ID (redis value) the lock should have. A random value is
            generated when left at the default.
        :param auto_renewal:
            If set to True, Lock will automatically renew the lock so that it
            doesn't expire for as long as the lock is held (acquire() called
            or running in a context manager).

            Implementation note: Renewal will happen using a daemon thread with
            an interval of expire*2/3. If wishing to use a different renewal
            time, subclass Lock, call super().__init__() then set
            self._lock_renewal_interval to your desired interval.
        :raises ValueError:
            If ``auto_renewal`` is set without an ``expire``.
        """
        assert isinstance(redis_client, StrictRedis)
        if auto_renewal and expire is None:
            raise ValueError("Expire may not be None when auto_renewal is set")

        self._client = redis_client
        self._expire = expire if expire is None else int(expire)
        self._id = urandom(16) if id is None else id
        self._held = False
        self._name = 'lock:' + name
        self._signal = 'lock-signal:' + name
        # Base the interval on the int-truncated expire that is actually sent
        # to redis (a float expire such as 1.9 would otherwise renew after the
        # key is already gone), and divide by a float so Python 2 integer
        # division cannot yield 0 for small expires -- a 0 interval would make
        # the renewal thread call redis in a tight loop.
        self._lock_renewal_interval = self._expire * 2 / 3.0 if auto_renewal else None
        self._lock_renewal_thread = None

    def reset(self):
        """
        Forcibly deletes the lock. Use this with care.
        """
        _eval_script(self._client, RESET, 2, self._name, self._signal)

    @property
    def id(self):
        """The value this instance stores in the lock key while holding it."""
        return self._id

    def get_owner_id(self):
        """Return the id currently stored in the lock key (None if unset)."""
        return self._client.get(self._name)

    def acquire(self, blocking=True, timeout=None):
        """
        :param blocking:
            Boolean value specifying whether lock should be blocking or not.
        :param timeout:
            An integer value specifying the maximum number of seconds to block.
        :return:
            True if the lock was acquired, False otherwise.
        :raises AlreadyAcquired:
            If this instance already holds the lock.
        :raises TimeoutNotUsable:
            If a timeout is given with ``blocking=False``.
        :raises InvalidTimeout:
            If the timeout is less than or equal to 0.
        :raises TimeoutTooLarge:
            If the timeout exceeds the lock's expire time.
        """
        logger.debug("Getting %r ...", self._name)

        if self._held:
            raise AlreadyAcquired("Already acquired from this Lock instance.")

        if not blocking and timeout is not None:
            raise TimeoutNotUsable("Timeout cannot be used if blocking=False")

        timeout = timeout if timeout is None else int(timeout)
        if timeout is not None and timeout <= 0:
            raise InvalidTimeout("Timeout (%d) cannot be less than or equal to 0" % timeout)

        if timeout and self._expire and timeout > self._expire:
            raise TimeoutTooLarge("Timeout (%d) cannot be greater than expire (%d)" % (timeout, self._expire))

        busy = True
        # Without an explicit timeout, wait at most one expiry period per
        # BLPOP; 0 is passed when neither is set (no BLPOP timeout).
        blpop_timeout = timeout or self._expire or 0
        timed_out = False
        while busy:
            busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
            if busy:
                if timed_out:
                    # The BLPOP already timed out and the retry above failed.
                    return False
                elif blocking:
                    timed_out = not self._client.blpop(self._signal, blpop_timeout)
                else:
                    logger.debug("Failed to get %r.", self._name)
                    return False

        logger.debug("Got lock for %r.", self._name)
        self._held = True
        if self._lock_renewal_interval is not None:
            self._start_lock_renewer()
        return True

    def extend(self, expire=None):
        """Extends expiration time of the lock.

        :param expire:
            New expiration time. If ``None`` - `expire` provided during
            lock initialization will be taken.
        :raises NotExpirable:
            If the lock was created without an expire time.
        :raises NotAcquired:
            If the lock key no longer holds this instance's id (it expired,
            was reset, or was acquired by someone else).
        """
        if self._expire is None:
            raise NotExpirable('The lock has no expiry time, so extending it '
                               'makes no sense.')
        if expire is None:
            expire = self._expire
        # Check ownership and refresh the TTL atomically. A plain SET ... XX
        # would overwrite a lock that another client acquired after ours
        # expired, silently stealing it.
        error = self._client.eval(self._EXTEND_SCRIPT, 1, self._name, self._id, int(expire))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)

    def _lock_renewer(self, interval):
        """
        Renew the lock key in redis every `interval` seconds until the renewal
        thread is asked to exit, or until the lock turns out to be lost.
        """
        # _start_lock_renewer() assigns the attribute before starting this
        # thread; bind it once so the loop never depends on the attribute
        # that _stop_lock_renewer() resets.
        thread = self._lock_renewal_thread
        log = getLogger("%s.lock_refresher" % __name__)
        while not thread.wait_for_exit_request(timeout=interval):
            log.debug("Refreshing lock")
            try:
                self.extend(expire=self._expire)
            except NotAcquired:
                log.warning("Lock %r is no longer held, stopping lock refreshing", self._name)
                return
        log.debug("Exit requested, stopping lock refreshing")

    def _start_lock_renewer(self):
        """
        Starts the lock refresher thread.

        :raises AlreadyStarted:
            If a refresher thread is already registered on this instance.
        """
        if self._lock_renewal_thread is not None:
            raise AlreadyStarted("Lock refresh thread already started")

        logger.debug(
            "Starting thread to refresh lock every %s seconds",
            self._lock_renewal_interval
        )
        self._lock_renewal_thread = InterruptableThread(
            target=self._lock_renewer,
            kwargs={'interval': self._lock_renewal_interval},
        )
        # The ``daemon`` attribute replaces the deprecated setDaemon().
        self._lock_renewal_thread.daemon = True
        self._lock_renewal_thread.start()

    def _stop_lock_renewer(self):
        """
        Stop the lock renewer.

        This signals the renewal thread and waits for its exit, then forgets
        the thread so a later acquire() can start a fresh one.
        """
        thread = self._lock_renewal_thread
        if thread is None:
            return
        if thread.is_alive():
            logger.debug("Signalling the lock refresher to stop")
            thread.request_exit()
            thread.join()
            logger.debug("Lock refresher has stopped")
        # Clear the handle even if the thread had already died (e.g. extend()
        # raised); otherwise the next acquire() would take the lock in redis
        # and then fail with AlreadyStarted.
        self._lock_renewal_thread = None

    def __enter__(self):
        acquired = self.acquire(blocking=True)
        assert acquired, "Lock wasn't acquired, but blocking=True"
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None, force=False):
        """
        Release the lock, stopping the renewal thread first if one is running.

        :param force:
            If True, do not raise when this instance does not hold the lock.
            The unlock script still only deletes the key if it holds this
            instance's id.
        :raises NotAcquired:
            If this instance does not hold the lock and ``force`` is False.
        """
        if not (self._held or force):
            raise NotAcquired("This Lock instance didn't acquire the lock.")
        self._stop_lock_renewer()
        logger.debug("Releasing %r.", self._name)
        released = _eval_script(self._client, UNLOCK,
                                2, self._name, self._signal, self._id)
        if self._held and not released:
            # The key no longer held our id: it expired or was reset/taken
            # over while we believed we held it. Surface that instead of
            # failing silently.
            logger.warning("Lock %r was no longer held by this instance when released.", self._name)

        self._held = False

    def release(self, force=False):
        """Releases the lock, that was acquired in the same Python context.

        :param force:
            If ``False`` - fail with exception if this instance was not in
            acquired state in the same Python context.
            If ``True`` - fail silently.
        """
        return self.__exit__(force=force)
274
275
276
class InterruptableThread(threading.Thread):
277
    """
278
    A Python thread that can be requested to stop by calling request_exit()
279
    on it.
280
281
    Code running inside this thread should periodically check the
282
    `should_exit` property (or use wait_for_exit_request) on the thread
283
    object and stop further processing once it returns True.
284
    """
285
    def __init__(self, *args, **kwargs):
286
        self._should_exit = threading.Event()
287
        super(InterruptableThread, self).__init__(*args, **kwargs)
288
289
    def request_exit(self):
290
        """
291
        Signal the thread that it should stop performing more work and exit.
292
        """
293
        self._should_exit.set()
294
295
    @property
296
    def should_exit(self):
297
        return self._should_exit.isSet()
298
299
    def wait_for_exit_request(self, timeout=None):
300
        """
301
        Wait until the thread has been signalled to exit.
302
303
        If timeout is specified (as a float of seconds to wait) then wait
304
        up to this many seconds before returning the value of `should_exit`.
305
        """
306
        should_exit = self._should_exit.wait(timeout)
307
        if should_exit is None:
308
            # Python 2.6 compatibility which doesn't return self.__flag when
309
            # calling Event.wait()
310
            should_exit = self.should_exit
311
        return should_exit
312
313
314
def reset_all(redis_client):
    """
    Forcibly deletes all locks, e.g. ones left behind by a crashed process.
    Use this with care.

    Every key matching ``lock:*`` is deleted and its ``lock-signal:`` list is
    pushed to, so clients blocked waiting for those locks wake up.

    NOTE(review): the underlying script uses the redis KEYS command over the
    whole keyspace -- confirm this is acceptable on large databases.

    :param redis_client:
        The redis client to run the reset on.
    :return:
        The number of lock keys that were found and deleted.
    """
    return _eval_script(redis_client, RESET_ALL, 0)
319