Completed
Push — master ( 15749c...80666e )
by Ionel Cristian
58s
created

src.redis_lock.Lock   B

Complexity

Total Complexity 46

Size/Duplication

Total Lines 196
Duplicated Lines 0 %
Metric Value
dl 0
loc 196
rs 8.4
wmc 46

12 Methods

Rating   Name   Duplication   Size   Complexity  
A id() 0 3 1
A __enter__() 0 4 2
B release() 0 22 5
A get_owner_id() 0 2 1
A reset() 0 5 1
A _stop_lock_renewer() 0 13 3
A __exit__() 0 2 1
F acquire() 0 41 15
A _start_lock_renewer() 0 18 2
C __init__() 0 37 7
A _lock_renewer() 0 10 2
B extend() 0 23 6

How to fix   Complexity   

Complex Class

Complex classes like src.redis_lock.Lock often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import threading
2
from logging import getLogger
3
from os import urandom
4
from hashlib import sha1
5
6
from redis import StrictRedis
7
from redis.exceptions import NoScriptError
8
9
# Version string of this module.
__version__ = "2.3.0"

# Module-level logger shared by the helpers and the Lock class below.
logger = getLogger(__name__)
12
13
# Release script. KEYS[1] is the lock key, KEYS[2] the signal list and
# ARGV[1] the caller's lock id. If the stored id does not match, return the
# error code 1. Otherwise replace the signal list with a single short-lived
# token (the list blocked acquirers BLPOP on), delete the lock key and
# return 0.
UNLOCK_SCRIPT = b"""
    if redis.call("get", KEYS[1]) ~= ARGV[1] then
        return 1
    else
        redis.call("del", KEYS[2])
        redis.call("lpush", KEYS[2], 1)
        redis.call("expire", KEYS[2], 1)
        redis.call("del", KEYS[1])
        return 0
    end
"""
# SHA1 of the script source, used to address the cached copy via EVALSHA.
UNLOCK_SCRIPT_HASH = sha1(UNLOCK_SCRIPT).hexdigest()
26
27
# Extend script. KEYS[1] is the lock key, ARGV[1] the new expiry in seconds
# and ARGV[2] the caller's lock id. Returns 1 when the stored value differs
# from the id (this covers both a missing key and a key owned by someone
# else), 2 when the key has a negative TTL (no expiry assigned), and 0 after
# the expiry has been updated.
EXTEND_SCRIPT = b"""
    if redis.call("get", KEYS[1]) ~= ARGV[2] then
        return 1
    elseif redis.call("ttl", KEYS[1]) < 0 then
        return 2
    else
        redis.call("expire", KEYS[1], ARGV[1])
        return 0
    end
"""
# SHA1 of the script source, used to address the cached copy via EVALSHA.
EXTEND_SCRIPT_HASH = sha1(EXTEND_SCRIPT).hexdigest()
39
40
# Reset script. KEYS[1] is the lock key and KEYS[2] the signal list.
# Unconditionally replaces the signal list with a single short-lived token,
# then deletes the lock key and returns the result of that DEL.
RESET_SCRIPT = b"""
    redis.call('del', KEYS[2])
    redis.call('lpush', KEYS[2], 1)
    redis.call('expire', KEYS[2], 1)
    return redis.call('del', KEYS[1])
"""

# SHA1 of the script source, used to address the cached copy via EVALSHA.
RESET_SCRIPT_HASH = sha1(RESET_SCRIPT).hexdigest()
48
49
# Reset-all script. Takes no keys: it looks up every key matching 'lock:*',
# and for each one replaces the matching 'lock-signal:<name>' list with a
# single short-lived token before deleting the lock key. Returns the number
# of lock keys found.
RESET_ALL_SCRIPT = b"""
    local locks = redis.call('keys', 'lock:*')
    local signal
    for _, lock in pairs(locks) do
        signal = 'lock-signal:' .. string.sub(lock, 6)
        redis.call('del', signal)
        redis.call('lpush', signal, 1)
        redis.call('expire', signal, 1)
        redis.call('del', lock)
    end
    return #locks
"""

# SHA1 of the script source, used to address the cached copy via EVALSHA.
RESET_ALL_SCRIPT_HASH = sha1(RESET_ALL_SCRIPT).hexdigest()
63
64
65
class AlreadyAcquired(RuntimeError):
    """Raised by ``Lock.acquire()`` when the instance is already marked as held."""
67
68
69
class NotAcquired(RuntimeError):
    """Raised when releasing or extending a lock that is not held or has expired."""
71
72
73
class AlreadyStarted(RuntimeError):
    """Raised when the lock renewal thread is started while one is already registered."""
75
76
77
class TimeoutNotUsable(RuntimeError):
    """Raised by ``Lock.acquire()`` when a timeout is combined with ``blocking=False``."""
79
80
81
class InvalidTimeout(RuntimeError):
    """Raised by ``Lock.acquire()`` when the timeout is less than or equal to zero."""
83
84
85
class TimeoutTooLarge(RuntimeError):
    """Raised by ``Lock.acquire()`` when the timeout exceeds the lock's expire time."""
87
88
89
class NotExpirable(RuntimeError):
    """Raised by ``Lock.extend()`` when the lock key has no expiration time assigned."""
91
92
93
# Flat registry of the Lua scripts: every script occupies three consecutive
# slots -- its SHA1 hash, its source and its human-readable name.
SCRIPTS = (
    UNLOCK_SCRIPT_HASH, UNLOCK_SCRIPT, 'UNLOCK_SCRIPT',
    EXTEND_SCRIPT_HASH, EXTEND_SCRIPT, 'EXTEND_SCRIPT',
    RESET_SCRIPT_HASH, RESET_SCRIPT, 'RESET_SCRIPT',
    RESET_ALL_SCRIPT_HASH, RESET_ALL_SCRIPT, 'RESET_ALL_SCRIPT',
)

# Script ids are the index of each script's hash slot (0, 3, 6, 9); the
# source lives at id + 1 and the name at id + 2.
UNLOCK, EXTEND, RESET, RESET_ALL = range(0, len(SCRIPTS), 3)
103
104
105
def _eval_script(redis, script_id, *keys, **kwargs):
106
    """Tries to call ``EVALSHA`` with the `hash` and then, if it fails, calls
107
    regular ``EVAL`` with the `script`.
108
    """
109
    args = kwargs.pop('args', ())
110
    if kwargs:
111
        raise TypeError("Unexpected keyword arguments %s" % kwargs.keys())
112
    try:
113
        return redis.evalsha(SCRIPTS[script_id], len(keys), *keys + args)
114
    except NoScriptError:
115
        logger.warn("%s not cached.", SCRIPTS[script_id + 2])
116
        return redis.eval(SCRIPTS[script_id + 1], len(keys), *keys + args)
117
118
119
class Lock(object):
    """
    A Lock context manager implemented via redis SETNX/BLPOP.
    """

    def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False):
        """
        :param redis_client:
            An instance of :class:`~StrictRedis`.
        :param name:
            The name (redis key) the lock should have.
        :param expire:
            The lock expiry time in seconds. If left at the default (None)
            the lock will not expire.
        :param id:
            The ID (redis value) the lock should have. A random value is
            generated when left at the default.

            Note that if you specify this then the lock is marked as "held". Acquires
            won't be possible.
        :param auto_renewal:
            If set to True, Lock will automatically renew the lock so that it
            doesn't expire for as long as the lock is held (acquire() called
            or running in a context manager).

            Implementation note: Renewal will happen using a daemon thread with
            an interval of expire*2/3. If wishing to use a different renewal
            time, subclass Lock, call super().__init__() then set
            self._lock_renewal_interval to your desired interval.
        :raises ValueError:
            If ``auto_renewal`` is set while ``expire`` is None.
        """
        assert isinstance(redis_client, StrictRedis)
        if auto_renewal and expire is None:
            raise ValueError("Expire may not be None when auto_renewal is set")

        self._client = redis_client
        self._expire = expire if expire is None else int(expire)
        self._id = urandom(16) if id is None else id
        # A caller-supplied id means the lock is treated as already held.
        self._held = id is not None
        self._name = 'lock:'+name
        self._signal = 'lock-signal:'+name
        # Derive the interval from the normalised integer expiry that is
        # actually sent to redis, not from the raw argument.
        self._lock_renewal_interval = self._expire*2/3 if auto_renewal else None
        self._lock_renewal_thread = None

    def reset(self):
        """
        Forcibly deletes the lock. Use this with care.
        """
        _eval_script(self._client, RESET, self._name, self._signal)

    @property
    def id(self):
        """The ID (redis value) this instance uses for the lock."""
        return self._id

    def get_owner_id(self):
        """Return the value currently stored under the lock key in redis."""
        return self._client.get(self._name)

    def acquire(self, blocking=True, timeout=None):
        """
        :param blocking:
            Boolean value specifying whether lock should be blocking or not.
        :param timeout:
            An integer value specifying the maximum number of seconds to block.
        :returns:
            True once the lock has been acquired; False when the lock is
            busy and either ``blocking`` is false or the timeout elapsed.
        :raises AlreadyAcquired:
            If this instance is already marked as holding the lock.
        :raises TimeoutNotUsable:
            If a timeout is given together with ``blocking=False``.
        :raises InvalidTimeout:
            If the timeout is less than or equal to zero.
        :raises TimeoutTooLarge:
            If the timeout is greater than the lock's expire time.
        """
        logger.debug("Getting %r ...", self._name)

        if self._held:
            raise AlreadyAcquired("Already acquired from this Lock instance.")

        if not blocking and timeout is not None:
            raise TimeoutNotUsable("Timeout cannot be used if blocking=False")

        timeout = timeout if timeout is None else int(timeout)
        if timeout is not None and timeout <= 0:
            raise InvalidTimeout("Timeout (%d) cannot be less than or equal to 0" % timeout)

        if timeout and self._expire and timeout > self._expire:
            raise TimeoutTooLarge("Timeout (%d) cannot be greater than expire (%d)" % (timeout, self._expire))

        busy = True
        # Wait on the signal list for the timeout, else the expiry, else
        # the value 0 when neither is set.
        blpop_timeout = timeout or self._expire or 0
        timed_out = False
        while busy:
            busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
            if busy:
                if timed_out:
                    # The wait already timed out and this final attempt
                    # also failed.
                    return False
                elif blocking:
                    # Only counts as a timeout when the caller asked for
                    # one; otherwise keep looping.
                    timed_out = not self._client.blpop(self._signal, blpop_timeout) and timeout
                else:
                    logger.debug("Failed to get %r.", self._name)
                    return False

        logger.debug("Got lock for %r.", self._name)
        self._held = True
        if self._lock_renewal_interval is not None:
            self._start_lock_renewer()
        return True

    def extend(self, expire=None):
        """Extends expiration time of the lock.

        :param expire:
            New expiration time. If ``None`` - `expire` provided during
            lock initialization will be taken.
        :raises TypeError:
            If no expire is given here and none was set at initialization.
        :raises NotAcquired:
            If the EXTEND script reports that the stored id does not match.
        :raises NotExpirable:
            If the EXTEND script reports that the key has no expiration.
        :raises RuntimeError:
            On any other non-zero code returned by the script.
        """
        if expire is None:
            if self._expire is not None:
                expire = self._expire
            else:
                raise TypeError(
                    "To extend a lock 'expire' must be provided as an "
                    "argument to extend() method or at initialization time."
                )
        error = _eval_script(self._client, EXTEND, self._name, args=(expire, self._id))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error == 2:
            raise NotExpirable("Lock %s has no assigned expiration time" %
                               self._name)
        elif error:
            raise RuntimeError("Unsupported error code %s from EXTEND script" % error)

    def _lock_renewer(self, interval):
        """
        Renew the lock key in redis every `interval` seconds for as long
        as `self._lock_renewal_thread.should_exit` is False.
        """
        log = getLogger("%s.lock_refresher" % __name__)
        while not self._lock_renewal_thread.wait_for_exit_request(timeout=interval):
            log.debug("Refreshing lock")
            self.extend(expire=self._expire)
        log.debug("Exit requested, stopping lock refreshing")

    def _start_lock_renewer(self):
        """
        Starts the lock refresher thread.

        :raises AlreadyStarted:
            If a renewal thread is already registered on this instance.
        """
        if self._lock_renewal_thread is not None:
            raise AlreadyStarted("Lock refresh thread already started")

        logger.debug(
            "Starting thread to refresh lock every %s seconds",
            self._lock_renewal_interval
        )
        self._lock_renewal_thread = InterruptableThread(
            group=None,
            target=self._lock_renewer,
            kwargs={'interval': self._lock_renewal_interval}
        )
        # ``setDaemon()`` is a deprecated alias for the ``daemon`` attribute.
        self._lock_renewal_thread.daemon = True
        self._lock_renewal_thread.start()

    def _stop_lock_renewer(self):
        """
        Stop the lock renewer.

        This signals the renewal thread and waits for its exit.
        """
        renewer = self._lock_renewal_thread
        if renewer is None:
            return
        if renewer.is_alive():
            logger.debug("Signalling the lock refresher to stop")
            renewer.request_exit()
            renewer.join()
            logger.debug("Lock refresher has stopped")
        # Forget the thread even if it had already died (for example because
        # extend() raised inside it); otherwise the next acquire() would
        # fail with AlreadyStarted.
        self._lock_renewal_thread = None

    def __enter__(self):
        acquired = self.acquire(blocking=True)
        assert acquired, "Lock wasn't acquired, but blocking=True"
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        self.release()

    def release(self):
        """Releases the lock, that was acquired with the same object.

        .. note::

            If you want to release a lock that you acquired in a different place you have two choices:

            * Use ``Lock("name", id=id_from_other_place).release()``
            * Use ``Lock("name").reset()``

        :raises NotAcquired:
            If this instance does not hold the lock, or the UNLOCK script
            reports that the stored id does not match.
        :raises RuntimeError:
            On any other non-zero code returned by the script.
        """
        if not self._held:
            raise NotAcquired("This Lock instance didn't acquire the lock.")
        if self._lock_renewal_thread is not None:
            self._stop_lock_renewer()
        logger.debug("Releasing %r.", self._name)
        error = _eval_script(self._client, UNLOCK, self._name, self._signal, args=(self._id,))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error:
            # This code path runs the UNLOCK script, so name it correctly.
            raise RuntimeError("Unsupported error code %s from UNLOCK script." % error)
        else:
            self._held = False
315
316
317
class InterruptableThread(threading.Thread):
    """
    A Python thread that can be requested to stop by calling request_exit()
    on it.

    Code running inside this thread should periodically check the
    `should_exit` property (or use wait_for_exit_request) on the thread
    object and stop further processing once it returns True.
    """
    def __init__(self, *args, **kwargs):
        # Event that is set once an exit has been requested.
        self._should_exit = threading.Event()
        super(InterruptableThread, self).__init__(*args, **kwargs)

    def request_exit(self):
        """
        Signal the thread that it should stop performing more work and exit.
        """
        self._should_exit.set()

    @property
    def should_exit(self):
        """True once request_exit() has been called."""
        # ``isSet()`` is a deprecated alias of ``is_set()``.
        return self._should_exit.is_set()

    def wait_for_exit_request(self, timeout=None):
        """
        Wait until the thread has been signalled to exit.

        If timeout is specified (as a float of seconds to wait) then wait
        up to this many seconds before returning the value of `should_exit`.
        """
        should_exit = self._should_exit.wait(timeout)
        if should_exit is None:
            # Python 2.6 compatibility which doesn't return self.__flag when
            # calling Event.wait()
            should_exit = self.should_exit
        return should_exit
353
354
355
def reset_all(redis_client):
    """
    Forcibly delete every remaining lock (for example, locks left behind
    after a crash). Use this with care.

    :param redis_client:
        Client the RESET_ALL script is executed on.
    """
    _eval_script(redis_client, RESET_ALL)
360