Completed
Push — master ( 95d8d5...80ebcd )
by Ionel Cristian
6s
created

Lock.extend()   B

Complexity

Conditions 6

Size

Total Lines 23

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 6
dl 0
loc 23
rs 7.6949
1
import threading
2
from logging import getLogger
3
from os import urandom
4
from hashlib import sha1
5
import weakref
6
7
from redis import StrictRedis
8
from redis.exceptions import NoScriptError
9
10
__version__ = "3.0.0"
11
12
logger = getLogger(__name__)
13
14
# Check if the id match. If not, return an error code.
15
UNLOCK_SCRIPT = b"""
16
    if redis.call("get", KEYS[1]) ~= ARGV[1] then
17
        return 1
18
    else
19
        redis.call("del", KEYS[2])
20
        redis.call("lpush", KEYS[2], 1)
21
        redis.call("del", KEYS[1])
22
        return 0
23
    end
24
"""
25
UNLOCK_SCRIPT_HASH = sha1(UNLOCK_SCRIPT).hexdigest()
26
27
# Covers both cases when key doesn't exist and doesn't equal to lock's id
28
EXTEND_SCRIPT = b"""
29
    if redis.call("get", KEYS[1]) ~= ARGV[2] then
30
        return 1
31
    elseif redis.call("ttl", KEYS[1]) < 0 then
32
        return 2
33
    else
34
        redis.call("expire", KEYS[1], ARGV[1])
35
        return 0
36
    end
37
"""
38
EXTEND_SCRIPT_HASH = sha1(EXTEND_SCRIPT).hexdigest()
39
40
RESET_SCRIPT = b"""
41
    redis.call('del', KEYS[2])
42
    redis.call('lpush', KEYS[2], 1)
43
    return redis.call('del', KEYS[1])
44
"""
45
46
RESET_SCRIPT_HASH = sha1(RESET_SCRIPT).hexdigest()
47
48
RESET_ALL_SCRIPT = b"""
49
    local locks = redis.call('keys', 'lock:*')
50
    local signal
51
    for _, lock in pairs(locks) do
52
        signal = 'lock-signal:' .. string.sub(lock, 6)
53
        redis.call('del', signal)
54
        redis.call('lpush', signal, 1)
55
        redis.call('del', lock)
56
    end
57
    return #locks
58
"""
59
60
RESET_ALL_SCRIPT_HASH = sha1(RESET_ALL_SCRIPT).hexdigest()
61
62
DELETE_ALL_SIGNAL_KEYS_SCRIPT = b"""
63
    local signals = redis.call('keys', 'lock-signal:*')
64
    for _, signal in pairs(signals) do
65
        redis.call('del', signal)
66
    end
67
    return #signals
68
"""
69
70
DELETE_ALL_SIGNAL_KEYS_SCRIPT_HASH = (sha1(DELETE_ALL_SIGNAL_KEYS_SCRIPT)
71
                                      .hexdigest())
72
73
74
class AlreadyAcquired(RuntimeError):
    """Raised by ``Lock.acquire()`` when this Lock instance is already marked as held."""
76
77
78
class NotAcquired(RuntimeError):
    """Raised when releasing or extending a lock that is not held.

    ``Lock.release()`` raises it when the instance never acquired the lock;
    ``Lock.release()`` and ``Lock.extend()`` raise it when the redis key is
    missing or holds a different id.
    """
80
81
82
class AlreadyStarted(RuntimeError):
    """Raised when a lock renewal thread is started while one is already registered."""
84
85
86
class TimeoutNotUsable(RuntimeError):
    """Raised by ``Lock.acquire()`` when a timeout is combined with ``blocking=False``."""
88
89
90
class InvalidTimeout(RuntimeError):
    """Raised by ``Lock.acquire()`` when the timeout is less than or equal to 0."""
92
93
94
class TimeoutTooLarge(RuntimeError):
    """Raised by ``Lock.acquire()`` when the timeout is greater than the lock's expire time."""
96
97
98
class NotExpirable(RuntimeError):
    """Raised by ``Lock.extend()`` when the lock key has no expiration time assigned."""
100
101
102
# Flat registry of the Lua scripts, three consecutive entries per script:
#   SCRIPTS[i]     -> SHA1 hex digest (used with EVALSHA)
#   SCRIPTS[i + 1] -> script source   (used with EVAL as a fallback)
#   SCRIPTS[i + 2] -> human readable name (used in log messages)
SCRIPTS = (
    UNLOCK_SCRIPT_HASH, UNLOCK_SCRIPT, 'UNLOCK_SCRIPT',
    EXTEND_SCRIPT_HASH, EXTEND_SCRIPT, 'EXTEND_SCRIPT',
    RESET_SCRIPT_HASH, RESET_SCRIPT, 'RESET_SCRIPT',
    RESET_ALL_SCRIPT_HASH, RESET_ALL_SCRIPT, 'RESET_ALL_SCRIPT',
    DELETE_ALL_SIGNAL_KEYS_SCRIPT_HASH, DELETE_ALL_SIGNAL_KEYS_SCRIPT,
    'DELETE_ALL_SIGNAL_KEYS_SCRIPT',
)

# Script ids are the index of each script's first entry in SCRIPTS
# (0, 3, 6, 9, 12). Deriving them from the table length avoids hand-counted
# placeholders; the order here must match the order of the table above.
(UNLOCK,
 EXTEND,
 RESET,
 RESET_ALL,
 DELETE_ALL_SIGNAL_KEYS) = range(0, len(SCRIPTS), 3)
115
116
117
def _eval_script(redis, script_id, *keys, **kwargs):
    """Run one of the registered Lua scripts, preferring the cached copy.

    Tries ``EVALSHA`` with the script's hash first and, if redis reports that
    the script is not cached (``NoScriptError``), falls back to a regular
    ``EVAL`` with the script source.

    :param redis:
        Client object providing ``evalsha`` and ``eval``.
    :param script_id:
        Index of the script's first entry in ``SCRIPTS`` (one of the
        ``UNLOCK``, ``EXTEND``, ... constants).
    :param keys:
        Redis keys passed to the script (``KEYS`` in Lua).
    :param args:
        Keyword-only. Iterable of extra arguments passed to the script
        (``ARGV`` in Lua). Defaults to no arguments.
    :returns:
        Whatever the script returns.
    :raises TypeError:
        If any keyword argument other than ``args`` is given.
    """
    # Coerce to a tuple so that lists (or other iterables) can be
    # concatenated with the ``keys`` tuple below.
    args = tuple(kwargs.pop('args', ()))
    if kwargs:
        # sorted() gives a plain, stable list instead of a dict view repr.
        raise TypeError("Unexpected keyword arguments %s" % sorted(kwargs))
    params = keys + args
    try:
        return redis.evalsha(SCRIPTS[script_id], len(keys), *params)
    except NoScriptError:
        logger.warning("%s not cached.", SCRIPTS[script_id + 2])
        return redis.eval(SCRIPTS[script_id + 1], len(keys), *params)
129
130
131
class Lock(object):
    """
    A Lock context manager implemented via redis SETNX/BLPOP.

    The lock is stored under the key ``lock:<name>`` whose value is the lock
    id; waiters block on the list ``lock-signal:<name>``, which is pushed to
    when the lock is released or reset.
    """

    def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False):
        """
        :param redis_client:
            An instance of :class:`~StrictRedis`.
        :param name:
            The name (redis key) the lock should have.
        :param expire:
            The lock expiry time in seconds. If left at the default (None)
            the lock will not expire.
        :param id:
            The ID (redis value) the lock should have. A random value is
            generated when left at the default.

            Note that if you specify this then the lock is marked as "held". Acquires
            won't be possible.
        :param auto_renewal:
            If set to True, Lock will automatically renew the lock so that it
            doesn't expire for as long as the lock is held (acquire() called
            or running in a context manager).

            Implementation note: Renewal will happen using a daemon thread with
            an interval of expire*2/3. If wishing to use a different renewal
            time, subclass Lock, call super().__init__() then set
            self._lock_renewal_interval to your desired interval.
        :raises ValueError:
            If ``auto_renewal`` is set but ``expire`` is None.
        """
        assert isinstance(redis_client, StrictRedis)
        if auto_renewal and expire is None:
            raise ValueError("Expire may not be None when auto_renewal is set")

        self._client = redis_client
        self._expire = expire if expire is None else int(expire)
        self._id = urandom(16) if id is None else id
        # A caller-supplied id means "this lock is already held elsewhere".
        self._held = id is not None
        self._name = 'lock:'+name
        self._signal = 'lock-signal:'+name
        self._lock_renewal_interval = (float(expire)*2/3
                                       if auto_renewal
                                       else None)
        self._lock_renewal_thread = None

    def reset(self):
        """
        Forcibly deletes the lock. Use this with care.
        """
        _eval_script(self._client, RESET, self._name, self._signal)
        self._delete_signal()

    @property
    def id(self):
        """The id (redis value) this Lock instance uses."""
        return self._id

    def get_owner_id(self):
        """Return the value currently stored under the lock key in redis."""
        return self._client.get(self._name)

    def acquire(self, blocking=True, timeout=None):
        """
        Try to acquire the lock.

        :param blocking:
            Boolean value specifying whether lock should be blocking or not.
        :param timeout:
            An integer value specifying the maximum number of seconds to block.
        :returns:
            True when the lock was acquired; False when it is busy and
            ``blocking`` is False, or when waiting timed out.
        :raises AlreadyAcquired:
            If this instance is already marked as holding the lock.
        :raises TimeoutNotUsable:
            If ``timeout`` is given together with ``blocking=False``.
        :raises InvalidTimeout:
            If ``timeout`` is less than or equal to 0.
        :raises TimeoutTooLarge:
            If ``timeout`` is greater than the lock's expire time.
        """
        logger.debug("Getting %r ...", self._name)

        if self._held:
            raise AlreadyAcquired("Already acquired from this Lock instance.")

        if not blocking and timeout is not None:
            raise TimeoutNotUsable("Timeout cannot be used if blocking=False")

        timeout = timeout if timeout is None else int(timeout)
        if timeout is not None and timeout <= 0:
            raise InvalidTimeout("Timeout (%d) cannot be less than or equal to 0" % timeout)

        if timeout and self._expire and timeout > self._expire:
            raise TimeoutTooLarge("Timeout (%d) cannot be greater than expire (%d)" % (timeout, self._expire))

        busy = True
        # Without an explicit timeout wait at most one expiry period per
        # BLPOP (0 when neither timeout nor expire is set).
        blpop_timeout = timeout or self._expire or 0
        timed_out = False
        while busy:
            busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
            if busy:
                if timed_out:
                    return False
                elif blocking:
                    # Only an explicit timeout turns an empty BLPOP result
                    # into a failure; one more SET attempt is made first.
                    timed_out = not self._client.blpop(self._signal, blpop_timeout) and timeout
                else:
                    logger.debug("Failed to get %r.", self._name)
                    return False

        logger.debug("Got lock for %r.", self._name)
        self._held = True
        if self._lock_renewal_interval is not None:
            self._start_lock_renewer()
        return True

    def extend(self, expire=None):
        """Extends expiration time of the lock.

        :param expire:
            New expiration time. If ``None`` - `expire` provided during
            lock initialization will be taken.
        :raises TypeError:
            If no expire value is available from the argument or from
            initialization.
        :raises NotAcquired:
            If the lock key is missing or holds a different id.
        :raises NotExpirable:
            If the lock key has no expiration time assigned.
        :raises RuntimeError:
            If the EXTEND script returns an unknown error code.
        """
        if expire is None:
            if self._expire is not None:
                expire = self._expire
            else:
                raise TypeError(
                    "To extend a lock 'expire' must be provided as an "
                    "argument to extend() method or at initialization time."
                )
        error = _eval_script(self._client, EXTEND, self._name, args=(expire, self._id))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error == 2:
            raise NotExpirable("Lock %s has no assigned expiration time" %
                               self._name)
        elif error:
            raise RuntimeError("Unsupported error code %s from EXTEND script" % error)

    @staticmethod
    def _lock_renewer(lockref, interval, stop):
        """
        Renew the lock key in redis every `interval` seconds until the
        `stop` event is set or the lock object referenced by `lockref`
        no longer exists.

        :param lockref:
            Callable (a weak reference) returning the Lock or None.
        :param interval:
            Seconds to wait between renewals.
        :param stop:
            Event that ends the loop once set.
        """
        log = getLogger("%s.lock_refresher" % __name__)
        # Event.wait() returns False on timeout, i.e. when it is time to renew.
        while not stop.wait(timeout=interval):
            log.debug("Refreshing lock")
            lock = lockref()
            if lock is None:
                log.debug("The lock no longer exists, "
                          "stopping lock refreshing")
                break
            lock.extend(expire=lock._expire)
            # Drop the strong reference so the Lock can be collected while
            # this thread is waiting.
            del lock
        log.debug("Exit requested, stopping lock refreshing")

    def _start_lock_renewer(self):
        """
        Starts the lock refresher thread.

        :raises AlreadyStarted:
            If a refresher thread is already registered on this instance.
        """
        if self._lock_renewal_thread is not None:
            raise AlreadyStarted("Lock refresh thread already started")

        logger.debug(
            "Starting thread to refresh lock every %s seconds",
            self._lock_renewal_interval
        )
        self._lock_renewal_stop = threading.Event()
        self._lock_renewal_thread = threading.Thread(
            group=None,
            target=self._lock_renewer,
            # Only a weak reference is handed over so the thread does not
            # keep this Lock alive.
            kwargs={'lockref': weakref.ref(self),
                    'interval': self._lock_renewal_interval,
                    'stop': self._lock_renewal_stop}
        )
        # The ``daemon`` attribute replaces the deprecated setDaemon().
        self._lock_renewal_thread.daemon = True
        self._lock_renewal_thread.start()

    def _stop_lock_renewer(self):
        """
        Stop the lock renewer.

        This signals the renewal thread and waits for its exit. If the
        thread has already exited on its own, only the stale reference is
        cleared so that a later acquire() can start a new renewer.
        """
        thread = self._lock_renewal_thread
        if thread is None:
            return
        if not thread.is_alive():
            # Nothing to signal or join, but forget the dead thread;
            # otherwise _start_lock_renewer() would raise AlreadyStarted.
            self._lock_renewal_thread = None
            return
        logger.debug("Signalling the lock refresher to stop")
        self._lock_renewal_stop.set()
        thread.join()
        self._lock_renewal_thread = None
        logger.debug("Lock refresher has stopped")

    def __enter__(self):
        acquired = self.acquire(blocking=True)
        assert acquired, "Lock wasn't acquired, but blocking=True"
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        self.release()

    def release(self):
        """Releases the lock, that was acquired with the same object.

        .. note::

            If you want to release a lock that you acquired in a different place you have two choices:

            * Use ``Lock("name", id=id_from_other_place).release()``
            * Use ``Lock("name").reset()``

        :raises NotAcquired:
            If this instance does not hold the lock, or the lock key is
            missing or holds a different id.
        :raises RuntimeError:
            If the UNLOCK script returns an unknown error code.
        """
        if not self._held:
            raise NotAcquired("This Lock instance didn't acquire the lock.")
        if self._lock_renewal_thread is not None:
            self._stop_lock_renewer()
        logger.debug("Releasing %r.", self._name)
        error = _eval_script(self._client, UNLOCK, self._name, self._signal, args=(self._id,))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error:
            raise RuntimeError("Unsupported error code %s from UNLOCK script." % error)
        else:
            self._held = False
            self._delete_signal()

    def _delete_signal(self):
        """Delete this lock's signal list from redis."""
        self._client.delete(self._signal)
344
345
346
def reset_all(redis_client):
    """
    Forcibly delete every lock, e.g. locks left behind after a crash.
    Use this with care.

    Runs the RESET_ALL script followed by the DELETE_ALL_SIGNAL_KEYS script
    against `redis_client`.
    """
    for script_id in (RESET_ALL, DELETE_ALL_SIGNAL_KEYS):
        _eval_script(redis_client, script_id)
352