Completed
Push — master ( dc0da3...0600e0 )
by Ionel Cristian
8s
created

Lock._held()   A

Complexity

Conditions 1

Size

Total Lines 3

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 3
rs 10
1
import threading
2
from logging import getLogger
3
from os import urandom
4
from hashlib import sha1
5
import weakref
6
7
from redis import StrictRedis
8
from redis.exceptions import NoScriptError
9
10
__version__ = "3.1.0"

logger = getLogger(__name__)

# Lua scripts executed atomically on the redis server.
#
# Release the lock only if it still holds the caller's id.
#   KEYS[1] = lock key, KEYS[2] = signal list, ARGV[1] = expected id.
# Returns 1 when the stored value differs from ARGV[1] (including a missing
# key). Otherwise it recreates the signal list with a single token, deletes
# the lock key and returns 0.
UNLOCK_SCRIPT = b"""
    if redis.call("get", KEYS[1]) ~= ARGV[1] then
        return 1
    else
        redis.call("del", KEYS[2])
        redis.call("lpush", KEYS[2], 1)
        redis.call("del", KEYS[1])
        return 0
    end
"""

# Refresh the lock's TTL if it is still owned by the caller.
#   KEYS[1] = lock key, ARGV[1] = new expiry in seconds, ARGV[2] = expected id.
# Returns 1 when the key is missing or holds a different id, 2 when the key
# has no expiry (TTL < 0), otherwise sets the expiry and returns 0.
EXTEND_SCRIPT = b"""
    if redis.call("get", KEYS[1]) ~= ARGV[2] then
        return 1
    elseif redis.call("ttl", KEYS[1]) < 0 then
        return 2
    else
        redis.call("expire", KEYS[1], ARGV[1])
        return 0
    end
"""

# Unconditionally drop the lock regardless of owner.
#   KEYS[1] = lock key, KEYS[2] = signal list.
# Recreates the signal list with one token, then returns the DEL count of
# the lock key.
RESET_SCRIPT = b"""
    redis.call('del', KEYS[2])
    redis.call('lpush', KEYS[2], 1)
    return redis.call('del', KEYS[1])
"""

# Reset every 'lock:*' key. The signal name is built by stripping the
# 5-character 'lock:' prefix (string.sub from index 6). Uses KEYS, which
# walks the whole keyspace. Returns the number of lock keys found.
RESET_ALL_SCRIPT = b"""
    local locks = redis.call('keys', 'lock:*')
    local signal
    for _, lock in pairs(locks) do
        signal = 'lock-signal:' .. string.sub(lock, 6)
        redis.call('del', signal)
        redis.call('lpush', signal, 1)
        redis.call('del', lock)
    end
    return #locks
"""

# Delete every 'lock-signal:*' key and return how many were found.
DELETE_ALL_SIGNAL_KEYS_SCRIPT = b"""
    local signals = redis.call('keys', 'lock-signal:*')
    for _, signal in pairs(signals) do
        redis.call('del', signal)
    end
    return #signals
"""

# SHA1 digests of the exact script bytes, used with EVALSHA (redis caches
# scripts under the SHA1 of their source).
UNLOCK_SCRIPT_HASH = sha1(UNLOCK_SCRIPT).hexdigest()
EXTEND_SCRIPT_HASH = sha1(EXTEND_SCRIPT).hexdigest()
RESET_SCRIPT_HASH = sha1(RESET_SCRIPT).hexdigest()
RESET_ALL_SCRIPT_HASH = sha1(RESET_ALL_SCRIPT).hexdigest()
DELETE_ALL_SIGNAL_KEYS_SCRIPT_HASH = sha1(
    DELETE_ALL_SIGNAL_KEYS_SCRIPT
).hexdigest()
72
73
74
class AlreadyAcquired(RuntimeError):
    """Raised by ``Lock.acquire()`` when this instance already holds the lock."""


class NotAcquired(RuntimeError):
    """Raised when releasing or extending a lock this instance does not hold."""


class AlreadyStarted(RuntimeError):
    """Raised when a lock renewal thread is started while one is running."""


class TimeoutNotUsable(RuntimeError):
    """Raised when ``acquire()`` gets a timeout together with blocking=False."""


class InvalidTimeout(RuntimeError):
    """Raised when ``acquire()`` gets a timeout less than or equal to 0."""


class TimeoutTooLarge(RuntimeError):
    """Raised when ``acquire()`` gets a timeout greater than the lock expiry."""


class NotExpirable(RuntimeError):
    """Raised by ``Lock.extend()`` when the lock key has no expiration time."""
100
101
102
# Flat lookup table used by _eval_script(). Every script occupies three
# consecutive slots -- (sha1 hash, source, name) -- and a script id is the
# index of its hash slot. So SCRIPTS[id] is the hash, SCRIPTS[id + 1] the
# source and SCRIPTS[id + 2] the name used in log messages.
SCRIPTS = (
    UNLOCK_SCRIPT_HASH, UNLOCK_SCRIPT, 'UNLOCK_SCRIPT',
    EXTEND_SCRIPT_HASH, EXTEND_SCRIPT, 'EXTEND_SCRIPT',
    RESET_SCRIPT_HASH, RESET_SCRIPT, 'RESET_SCRIPT',
    RESET_ALL_SCRIPT_HASH, RESET_ALL_SCRIPT, 'RESET_ALL_SCRIPT',
    DELETE_ALL_SIGNAL_KEYS_SCRIPT_HASH, DELETE_ALL_SIGNAL_KEYS_SCRIPT,
    'DELETE_ALL_SIGNAL_KEYS_SCRIPT',
)

# Script ids, one per triple above: 0, 3, 6, 9, 12. Deriving them by stride
# keeps them in sync with the table instead of relying on a hand-aligned
# unpacking target.
(UNLOCK, EXTEND, RESET, RESET_ALL,
 DELETE_ALL_SIGNAL_KEYS) = range(0, len(SCRIPTS), 3)
115
116
117
def _eval_script(redis, script_id, *keys, **kwargs):
    """Run one of the bundled Lua scripts, preferring the server-side cache.

    Tries ``EVALSHA`` with the script's `hash` first and, if redis reports
    the script is not cached (``NoScriptError``), falls back to a regular
    ``EVAL`` with the `script` source.

    :param redis:
        Client exposing ``evalsha`` and ``eval``.
    :param script_id:
        One of the script ids (UNLOCK, EXTEND, RESET, RESET_ALL,
        DELETE_ALL_SIGNAL_KEYS) indexing into ``SCRIPTS``.
    :param keys:
        Redis keys passed to the script as ``KEYS``.
    :param args:
        Keyword-only sequence of values passed to the script as ``ARGV``.
    :returns:
        Whatever the script returns.
    :raises TypeError:
        If any keyword argument other than ``args`` is given.
    """
    # Accept any sequence for ``args``; concatenating a tuple with a list
    # would otherwise raise TypeError.
    args = tuple(kwargs.pop('args', ()))
    if kwargs:
        raise TypeError("Unexpected keyword arguments %s" % sorted(kwargs))
    script_args = keys + args
    try:
        return redis.evalsha(SCRIPTS[script_id], len(keys), *script_args)
    except NoScriptError:
        # Script cache was flushed (or never loaded): EVAL sends the source.
        logger.warning("%s not cached.", SCRIPTS[script_id + 2])
        return redis.eval(SCRIPTS[script_id + 1], len(keys), *script_args)
129
130
131
class Lock(object):
    """
    A Lock context manager implemented via redis SET NX/BLPOP.

    The lock is the key ``lock:<name>`` holding this instance's id; waiters
    block with BLPOP on ``lock-signal:<name>``, which is pushed to when the
    lock is released or reset.
    """

    def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False):
        """
        :param redis_client:
            An instance of :class:`~StrictRedis`.
        :param name:
            The name (redis key) the lock should have. It is stored under
            ``lock:<name>``; the wake-up list is ``lock-signal:<name>``.
        :param expire:
            The lock expiry time in seconds. If left at the default (None)
            the lock will not expire.
        :param id:
            The ID (redis value) the lock should have. Must be ``bytes``. A
            random value is generated when left at the default.

            Note that if redis already stores this id under the lock key,
            this instance is considered "held": acquire() will raise
            AlreadyAcquired, and release() is allowed.
        :param auto_renewal:
            If set to True, Lock will automatically renew the lock so that it
            doesn't expire for as long as the lock is held (acquire() called
            or running in a context manager).

            Implementation note: Renewal will happen using a daemon thread with
            an interval of expire*2/3. If wishing to use a different renewal
            time, subclass Lock, call super().__init__() then set
            self._lock_renewal_interval to your desired interval.
        :raises ValueError:
            If ``auto_renewal`` is set while ``expire`` is None.
        :raises TypeError:
            If ``id`` is given and is not ``bytes``.
        """
        assert isinstance(redis_client, StrictRedis)
        if auto_renewal and expire is None:
            raise ValueError("Expire may not be None when auto_renewal is set")

        self._client = redis_client
        self._expire = expire if expire is None else int(expire)
        if id is None:
            self._id = urandom(16)
        elif isinstance(id, bytes):
            self._id = id
        else:
            raise TypeError("Incorrect type for `id`. Must be bytes not %s." % type(id))
        self._name = 'lock:'+name
        self._signal = 'lock-signal:'+name
        self._lock_renewal_interval = (float(expire)*2/3
                                       if auto_renewal
                                       else None)
        self._lock_renewal_thread = None
        # Event used to ask the renewal thread to exit; created per thread.
        self._lock_renewal_stop = None

    @property
    def _held(self):
        """True if redis currently stores this instance's id under the lock key."""
        return self.id == self.get_owner_id()

    def reset(self):
        """
        Forcibly deletes the lock. Use this with care.

        The lock key is deleted regardless of its owner, and a token is
        pushed to the signal list before the signal key is deleted again.
        """
        _eval_script(self._client, RESET, self._name, self._signal)
        self._delete_signal()

    @property
    def id(self):
        """The id (redis value) this instance writes to the lock key."""
        return self._id

    def get_owner_id(self):
        """Return the id currently stored under the lock key, or None."""
        return self._client.get(self._name)

    def acquire(self, blocking=True, timeout=None):
        """
        Take the lock by setting the lock key (SET NX, with the configured
        expiry) to this instance's id.

        :param blocking:
            Boolean value specifying whether lock should be blocking or not.
        :param timeout:
            An integer value specifying the maximum number of seconds to block.
        :returns:
            True once the lock is acquired. False if ``blocking`` is False
            and the lock is busy, or if the lock is still busy after a
            ``timeout``-second wait.
        :raises AlreadyAcquired: If this instance already holds the lock.
        :raises TimeoutNotUsable: If ``timeout`` is used with blocking=False.
        :raises InvalidTimeout: If ``timeout`` is less than or equal to 0.
        :raises TimeoutTooLarge: If ``timeout`` is greater than ``expire``.
        """
        logger.debug("Getting %r ...", self._name)

        if self._held:
            raise AlreadyAcquired("Already acquired from this Lock instance.")

        if not blocking and timeout is not None:
            raise TimeoutNotUsable("Timeout cannot be used if blocking=False")

        timeout = timeout if timeout is None else int(timeout)
        if timeout is not None and timeout <= 0:
            raise InvalidTimeout("Timeout (%d) cannot be less than or equal to 0" % timeout)

        if timeout and self._expire and timeout > self._expire:
            raise TimeoutTooLarge("Timeout (%d) cannot be greater than expire (%d)" % (timeout, self._expire))

        busy = True
        # Without an explicit timeout, wake up at least every `expire`
        # seconds so a lock that expired without a release signal is
        # retried; 0 makes BLPOP block indefinitely.
        blpop_timeout = timeout or self._expire or 0
        timed_out = False
        while busy:
            busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
            if busy:
                if timed_out:
                    return False
                elif blocking:
                    # Truthy only if BLPOP timed out (returned None) and the
                    # caller asked for a timeout; one final SET attempt
                    # follows. NOTE: the timeout applies to each BLPOP wait,
                    # so losing a race after a wake-up restarts the wait.
                    timed_out = not self._client.blpop(self._signal, blpop_timeout) and timeout
                else:
                    logger.debug("Failed to get %r.", self._name)
                    return False

        logger.debug("Got lock for %r.", self._name)
        if self._lock_renewal_interval is not None:
            self._start_lock_renewer()
        return True

    def extend(self, expire=None):
        """Extends expiration time of the lock.

        :param expire:
            New expiration time. If ``None`` - `expire` provided during
            lock initialization will be taken.
        :raises TypeError:
            If no expiry is given here or at initialization time.
        :raises NotAcquired:
            If the lock key is missing or holds a different id.
        :raises NotExpirable:
            If the lock key has no expiration time.
        """
        if expire is None:
            if self._expire is not None:
                expire = self._expire
            else:
                raise TypeError(
                    "To extend a lock 'expire' must be provided as an "
                    "argument to extend() method or at initialization time."
                )
        error = _eval_script(self._client, EXTEND, self._name, args=(expire, self._id))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error == 2:
            raise NotExpirable("Lock %s has no assigned expiration time" %
                               self._name)
        elif error:
            raise RuntimeError("Unsupported error code %s from EXTEND script" % error)

    @staticmethod
    def _lock_renewer(lockref, interval, stop):
        """
        Renew the lock key in redis every `interval` seconds until the
        `stop` event is set or the Lock behind the weak reference `lockref`
        has been garbage collected.

        Only a weak reference is held so that this thread does not keep the
        Lock alive. An exception raised by ``extend()`` (e.g. NotAcquired
        once the key is gone) is not caught and ends the thread.
        """
        log = getLogger("%s.lock_refresher" % __name__)
        while not stop.wait(timeout=interval):
            log.debug("Refreshing lock")
            lock = lockref()
            if lock is None:
                log.debug("The lock no longer exists, "
                          "stopping lock refreshing")
                break
            lock.extend(expire=lock._expire)
            # Drop the strong reference before sleeping again.
            del lock
        log.debug("Exit requested, stopping lock refreshing")

    def _start_lock_renewer(self):
        """
        Starts the lock refresher thread.

        :raises AlreadyStarted: If a refresher thread is still running.
        """
        thread = self._lock_renewal_thread
        if thread is not None and thread.is_alive():
            raise AlreadyStarted("Lock refresh thread already started")
        # A previous refresher that already died (e.g. extend() raised after
        # the key expired or was reset) is simply replaced.

        logger.debug(
            "Starting thread to refresh lock every %s seconds",
            self._lock_renewal_interval
        )
        self._lock_renewal_stop = threading.Event()
        self._lock_renewal_thread = threading.Thread(
            group=None,
            target=self._lock_renewer,
            kwargs={'lockref': weakref.ref(self),
                    'interval': self._lock_renewal_interval,
                    'stop': self._lock_renewal_stop}
        )
        # Daemon thread: it must not keep the interpreter alive at exit.
        self._lock_renewal_thread.daemon = True
        self._lock_renewal_thread.start()

    def _stop_lock_renewer(self):
        """
        Stop the lock renewer.

        This signals the renewal thread and waits for its exit. The thread
        reference is cleared even if the thread had already died, so a
        later acquire() can start a fresh one.
        """
        thread = self._lock_renewal_thread
        if thread is None:
            return
        if thread.is_alive():
            logger.debug("Signalling the lock refresher to stop")
            self._lock_renewal_stop.set()
            thread.join()
            logger.debug("Lock refresher has stopped")
        self._lock_renewal_thread = None

    def __enter__(self):
        acquired = self.acquire(blocking=True)
        assert acquired, "Lock wasn't acquired, but blocking=True"
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        self.release()

    def release(self):
        """Releases the lock, that was acquired with the same object.

        .. note::

            If you want to release a lock that you acquired in a different place you have two choices:

            * Use ``Lock("name", id=id_from_other_place).release()``
            * Use ``Lock("name").reset()``

        :raises NotAcquired:
            If this instance does not hold the lock (or it already expired).
        """
        # Stop renewing first: if the lock turns out to be lost, a running
        # renewer would otherwise be left behind and block future acquires.
        if self._lock_renewal_thread is not None:
            self._stop_lock_renewer()
        if not self._held:
            raise NotAcquired("This Lock instance didn't acquire the lock.")
        logger.debug("Releasing %r.", self._name)
        error = _eval_script(self._client, UNLOCK, self._name, self._signal, args=(self._id,))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error:
            raise RuntimeError("Unsupported error code %s from UNLOCK script." % error)
        else:
            self._delete_signal()

    def _delete_signal(self):
        """Delete the signal list key for this lock."""
        self._client.delete(self._signal)
350
351
352
def reset_all(redis_client):
    """
    Forcibly delete every lock, e.g. to clean up leftovers after a crash.
    Use this with care.

    First the RESET_ALL script runs. It pushes a token onto each lock's
    signal list and deletes every ``lock:*`` key. Then
    DELETE_ALL_SIGNAL_KEYS removes every ``lock-signal:*`` key.
    """
    for script_id in (RESET_ALL, DELETE_ALL_SIGNAL_KEYS):
        _eval_script(redis_client, script_id)
358