Lock.release()   A
last analyzed

Complexity

Conditions 4

Size

Total Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 5
Bugs 0 Features 0
Metric Value
cc 4
c 5
b 0
f 0
dl 0
loc 20
rs 9.2
1
import threading
2
from logging import getLogger
3
from os import urandom
4
from hashlib import sha1
5
import weakref
6
7
from redis import StrictRedis
8
from redis.exceptions import NoScriptError
9
10
__version__ = "3.2.0"
11
12
logger = getLogger(__name__)
13
14
# Release the lock. If KEYS[1] does not hold our id (ARGV[1]), return 1.
# That includes the case where the key no longer exists.
# Otherwise:
#   - reset the signal list KEYS[2] to a single element, so a client
#     blocked in BLPOP on it is woken;
#   - delete the lock key;
#   - return 0.
UNLOCK_SCRIPT = b"""
    if redis.call("get", KEYS[1]) ~= ARGV[1] then
        return 1
    else
        redis.call("del", KEYS[2])
        redis.call("lpush", KEYS[2], 1)
        redis.call("del", KEYS[1])
        return 0
    end
"""
UNLOCK_SCRIPT_HASH = sha1(UNLOCK_SCRIPT).hexdigest()

# Extend the lock's TTL to ARGV[1] seconds. Return codes:
#   1 - KEYS[1] is missing or does not hold our id (ARGV[2]);
#   2 - the key has no expiry (TTL < 0);
#   0 - EXPIRE was applied.
EXTEND_SCRIPT = b"""
    if redis.call("get", KEYS[1]) ~= ARGV[2] then
        return 1
    elseif redis.call("ttl", KEYS[1]) < 0 then
        return 2
    else
        redis.call("expire", KEYS[1], ARGV[1])
        return 0
    end
"""
EXTEND_SCRIPT_HASH = sha1(EXTEND_SCRIPT).hexdigest()

# Unconditionally reset one lock: signal KEYS[2] (one element pushed after
# clearing it), then delete the lock key KEYS[1]. Returns the DEL reply for
# the lock key.
RESET_SCRIPT = b"""
    redis.call('del', KEYS[2])
    redis.call('lpush', KEYS[2], 1)
    return redis.call('del', KEYS[1])
"""

RESET_SCRIPT_HASH = sha1(RESET_SCRIPT).hexdigest()

# Reset every ``lock:*`` key. For each lock:
#   - signal its ``lock-signal:<name>`` list (``string.sub(lock, 6)`` drops
#     the 5-character ``lock:`` prefix);
#   - delete the lock.
# Returns the number of locks found. KEYS walks the whole keyspace.
RESET_ALL_SCRIPT = b"""
    local locks = redis.call('keys', 'lock:*')
    local signal
    for _, lock in pairs(locks) do
        signal = 'lock-signal:' .. string.sub(lock, 6)
        redis.call('del', signal)
        redis.call('lpush', signal, 1)
        redis.call('del', lock)
    end
    return #locks
"""

RESET_ALL_SCRIPT_HASH = sha1(RESET_ALL_SCRIPT).hexdigest()

# Delete every ``lock-signal:*`` key; returns how many were found.
DELETE_ALL_SIGNAL_KEYS_SCRIPT = b"""
    local signals = redis.call('keys', 'lock-signal:*')
    for _, signal in pairs(signals) do
        redis.call('del', signal)
    end
    return #signals
"""

DELETE_ALL_SIGNAL_KEYS_SCRIPT_HASH = sha1(DELETE_ALL_SIGNAL_KEYS_SCRIPT).hexdigest()
72
73
74
class AlreadyAcquired(RuntimeError):
    """Raised by ``Lock.acquire`` when this instance already holds the lock."""


class NotAcquired(RuntimeError):
    """Raised by ``Lock.extend``/``Lock.release`` when the lock is not held
    (it was never acquired with this id, or it already expired)."""


class AlreadyStarted(RuntimeError):
    """Raised when the lock-renewal thread has already been started."""


class TimeoutNotUsable(RuntimeError):
    """Raised when a timeout is passed together with ``blocking=False``."""


class InvalidTimeout(RuntimeError):
    """Raised when the requested acquire timeout is less than or equal to 0."""


class TimeoutTooLarge(RuntimeError):
    """Raised when the acquire timeout exceeds the lock's expire time."""


class NotExpirable(RuntimeError):
    """Raised by ``Lock.extend`` when the lock key has no expiration time."""
100
101
102
# Flat lookup table used by ``_eval_script``. Each script takes three
# consecutive slots: SHA1 digest, Lua source, human-readable name.
SCRIPTS = (
    UNLOCK_SCRIPT_HASH, UNLOCK_SCRIPT, 'UNLOCK_SCRIPT',
    EXTEND_SCRIPT_HASH, EXTEND_SCRIPT, 'EXTEND_SCRIPT',
    RESET_SCRIPT_HASH, RESET_SCRIPT, 'RESET_SCRIPT',
    RESET_ALL_SCRIPT_HASH, RESET_ALL_SCRIPT, 'RESET_ALL_SCRIPT',
    DELETE_ALL_SIGNAL_KEYS_SCRIPT_HASH, DELETE_ALL_SIGNAL_KEYS_SCRIPT,
    'DELETE_ALL_SIGNAL_KEYS_SCRIPT',
)

# A script id is the index of its digest in SCRIPTS (0, 3, 6, 9, 12).
# ``id + 1`` is its source and ``id + 2`` its name.
UNLOCK, EXTEND, RESET, RESET_ALL, DELETE_ALL_SIGNAL_KEYS = range(0, len(SCRIPTS), 3)
115
116
117
def _eval_script(redis, script_id, *keys, **kwargs):
    """Run one of the registered Lua scripts against ``redis``.

    Tries ``EVALSHA`` with the script's precomputed SHA1 digest first. If the
    server does not have the script cached (``NoScriptError``), it falls back
    to a regular ``EVAL`` with the full script source.

    :param redis: client providing ``evalsha`` and ``eval``.
    :param script_id: one of ``UNLOCK``, ``EXTEND``, ``RESET``, ``RESET_ALL``,
        ``DELETE_ALL_SIGNAL_KEYS``. It is an offset into ``SCRIPTS``.
    :param keys: redis key names, passed to the script as ``KEYS``.
    :param args: (keyword) iterable of values passed as ``ARGV``.
    :returns: the script's reply.
    :raises TypeError: if any other keyword argument is given.
    """
    # Normalize to a tuple so ``keys + args`` works for lists too.
    args = tuple(kwargs.pop('args', ()))
    if kwargs:
        raise TypeError("Unexpected keyword arguments %s" % kwargs.keys())
    keys_and_args = keys + args
    try:
        return redis.evalsha(SCRIPTS[script_id], len(keys), *keys_and_args)
    except NoScriptError:
        # ``Logger.warn`` is a deprecated alias of ``warning``.
        logger.warning("%s not cached.", SCRIPTS[script_id + 2])
        return redis.eval(SCRIPTS[script_id + 1], len(keys), *keys_and_args)
129
130
131
class Lock(object):
    """
    A Lock context manager implemented via redis ``SET NX``/``BLPOP``.
    """

    def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False, strict=True):
        """
        :param redis_client:
            An instance of :class:`~StrictRedis`.
        :param name:
            The name (redis key) the lock should have.
        :param expire:
            The lock expiry time in seconds. If left at the default (None)
            the lock will not expire. Non-integer values are truncated with ``int()``.
        :param id:
            The ID (redis value) the lock should have. A random value is
            generated when left at the default.

            Note that if you specify this then the lock is marked as "held". Acquires
            won't be possible.
        :param auto_renewal:
            If set to ``True``, Lock will automatically renew the lock so that it
            doesn't expire for as long as the lock is held (acquire() called
            or running in a context manager).

            Implementation note: Renewal will happen using a daemon thread with
            an interval of ``expire*2/3``. If wishing to use a different renewal
            time, subclass Lock, call ``super().__init__()`` then set
            ``self._lock_renewal_interval`` to your desired interval.
        :param strict:
            If set ``True`` then the ``redis_client`` needs to be an instance of ``redis.StrictRedis``.
        :raises ValueError: on a non-StrictRedis client with ``strict=True``,
            or ``auto_renewal`` without ``expire``.
        :raises TypeError: if ``id`` is given and is not ``bytes``.
        """
        if strict and not isinstance(redis_client, StrictRedis):
            raise ValueError("redis_client must be instance of StrictRedis. "
                             "Use strict=False if you know what you're doing.")
        if auto_renewal and expire is None:
            raise ValueError("Expire may not be None when auto_renewal is set")

        self._client = redis_client
        self._expire = expire if expire is None else int(expire)
        if id is None:
            self._id = urandom(16)
        elif isinstance(id, bytes):
            self._id = id
        else:
            raise TypeError("Incorrect type for `id`. Must be bytes not %s." % type(id))
        self._name = 'lock:' + name
        self._signal = 'lock-signal:' + name
        # Base the interval on the TTL actually sent to redis (int-converted).
        # Otherwise a fractional expire such as 1.5 would schedule the renewal
        # exactly when the key expires. float() keeps true division on Python 2.
        self._lock_renewal_interval = (float(self._expire) * 2 / 3
                                       if auto_renewal
                                       else None)
        self._lock_renewal_thread = None

    @property
    def _held(self):
        """True when the value stored under the lock key equals our id."""
        return self.id == self.get_owner_id()

    def reset(self):
        """
        Forcibly deletes the lock. Use this with care.

        Signals the lock's waiters, deletes the lock key, then removes the
        signal key.
        """
        _eval_script(self._client, RESET, self._name, self._signal)
        self._delete_signal()

    @property
    def id(self):
        """The value this instance stores under the lock key when it holds the lock."""
        return self._id

    def get_owner_id(self):
        """Return the value currently stored under the lock key, as returned by the client's ``get``."""
        return self._client.get(self._name)

    def acquire(self, blocking=True, timeout=None):
        """
        :param blocking:
            Boolean value specifying whether lock should be blocking or not.
        :param timeout:
            An integer value specifying the maximum number of seconds to block.
        :returns: ``True`` once the lock is acquired. ``False`` if it is busy
            and ``blocking=False``, or if ``timeout`` elapsed.
        :raises AlreadyAcquired: if this instance already holds the lock.
        :raises TimeoutNotUsable: if ``timeout`` is given with ``blocking=False``.
        :raises InvalidTimeout: if ``timeout`` is <= 0.
        :raises TimeoutTooLarge: if ``timeout`` exceeds the lock's expire.
        """
        logger.debug("Getting %r ...", self._name)

        if self._held:
            raise AlreadyAcquired("Already acquired from this Lock instance.")

        if not blocking and timeout is not None:
            raise TimeoutNotUsable("Timeout cannot be used if blocking=False")

        timeout = timeout if timeout is None else int(timeout)
        if timeout is not None and timeout <= 0:
            raise InvalidTimeout("Timeout (%d) cannot be less than or equal to 0" % timeout)

        if timeout and self._expire and timeout > self._expire:
            raise TimeoutTooLarge("Timeout (%d) cannot be greater than expire (%d)" % (timeout, self._expire))

        busy = True
        # Wait on the signal list for at most this long per attempt.
        # 0 makes BLPOP wait indefinitely.
        blpop_timeout = timeout or self._expire or 0
        timed_out = False
        while busy:
            busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
            if busy:
                if timed_out:
                    return False
                elif blocking:
                    # An empty BLPOP only counts as a timeout when the caller
                    # asked for one. An expire-bounded wait just retries SET.
                    timed_out = not self._client.blpop(self._signal, blpop_timeout) and timeout
                else:
                    logger.debug("Failed to get %r.", self._name)
                    return False

        logger.debug("Got lock for %r.", self._name)
        if self._lock_renewal_interval is not None:
            self._start_lock_renewer()
        return True

    def extend(self, expire=None):
        """Extends expiration time of the lock.

        :param expire:
            New expiration time in seconds (truncated with ``int()``).
            If ``None`` - `expire` provided during lock initialization will be taken.
        :raises TypeError: if no expire is available from either source.
        :raises NotAcquired: if the lock is not held with this id.
        :raises NotExpirable: if the lock key has no expiration time.
        """
        if expire is None:
            if self._expire is None:
                raise TypeError(
                    "To extend a lock 'expire' must be provided as an "
                    "argument to extend() method or at initialization time."
                )
            expire = self._expire
        # EXPIRE only accepts integers; coerce like __init__ does.
        error = _eval_script(self._client, EXTEND, self._name, args=(int(expire), self._id))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error == 2:
            raise NotExpirable("Lock %s has no assigned expiration time" %
                               self._name)
        elif error:
            raise RuntimeError("Unsupported error code %s from EXTEND script" % error)

    @staticmethod
    def _lock_renewer(lockref, interval, stop):
        """
        Renew the lock key in redis every ``interval`` seconds.

        Runs until the ``stop`` event is set, or until the Lock behind the
        weak reference ``lockref`` has been garbage collected.
        """
        log = getLogger("%s.lock_refresher" % __name__)
        while not stop.wait(timeout=interval):
            log.debug("Refreshing lock")
            lock = lockref()
            if lock is None:
                log.debug("The lock no longer exists, "
                          "stopping lock refreshing")
                break
            lock.extend(expire=lock._expire)
            # Drop the strong reference so the Lock can be collected while we wait.
            del lock
        log.debug("Exit requested, stopping lock refreshing")

    def _start_lock_renewer(self):
        """
        Starts the lock refresher thread.

        :raises AlreadyStarted: if a refresher thread is already registered.
        """
        if self._lock_renewal_thread is not None:
            raise AlreadyStarted("Lock refresh thread already started")

        logger.debug(
            "Starting thread to refresh lock every %s seconds",
            self._lock_renewal_interval
        )
        self._lock_renewal_stop = threading.Event()
        self._lock_renewal_thread = threading.Thread(
            group=None,
            target=self._lock_renewer,
            # Only a weak reference, so the thread does not keep the Lock alive.
            kwargs={'lockref': weakref.ref(self),
                    'interval': self._lock_renewal_interval,
                    'stop': self._lock_renewal_stop}
        )
        # ``Thread.setDaemon`` is deprecated; the attribute works everywhere.
        self._lock_renewal_thread.daemon = True
        self._lock_renewal_thread.start()

    def _stop_lock_renewer(self):
        """
        Stop the lock renewer.

        This signals the renewal thread and waits for its exit. The thread
        handle is cleared even if the thread already died (e.g. ``extend``
        raised). Otherwise the next ``acquire`` would raise ``AlreadyStarted``.
        """
        thread = self._lock_renewal_thread
        if thread is None:
            return
        if thread.is_alive():
            logger.debug("Signalling the lock refresher to stop")
            self._lock_renewal_stop.set()
            thread.join()
            logger.debug("Lock refresher has stopped")
        self._lock_renewal_thread = None

    def __enter__(self):
        acquired = self.acquire(blocking=True)
        assert acquired, "Lock wasn't acquired, but blocking=True"
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        self.release()

    def release(self):
        """Releases the lock, that was acquired with the same object.

        .. note::

            If you want to release a lock that you acquired in a different place you have two choices:

            * Use ``Lock("name", id=id_from_other_place).release()``
            * Use ``Lock("name").reset()``

        :raises NotAcquired: if the lock is not held with this id.
        """
        if self._lock_renewal_thread is not None:
            self._stop_lock_renewer()
        logger.debug("Releasing %r.", self._name)
        error = _eval_script(self._client, UNLOCK, self._name, self._signal, args=(self._id,))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error:
            raise RuntimeError("Unsupported error code %s from UNLOCK script." % error)
        else:
            self._delete_signal()

    def _delete_signal(self):
        """Delete this lock's signal key."""
        self._client.delete(self._signal)
352
353
354
def reset_all(redis_client):
    """
    Forcibly remove every lock, e.g. leftovers after a crash. Use this with care.

    First every ``lock:*`` key is signalled and deleted, then every
    ``lock-signal:*`` key is deleted.
    """
    for script_id in (RESET_ALL, DELETE_ALL_SIGNAL_KEYS):
        _eval_script(redis_client, script_id)
360