Completed
Pull Request — master (#46)
by Ionel Cristian
57s
created

Lock.extend()   B

Complexity

Conditions 6

Size

Total Lines 23

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 6
dl 0
loc 23
rs 7.6949
1
import threading
2
from logging import getLogger
3
from os import urandom
4
from hashlib import sha1
5
import weakref
6
7
from redis import StrictRedis
8
from redis.exceptions import NoScriptError
9
10
__version__ = "3.0.0"

# Module-level logger. Lock's renewal thread logs to the
# "<module>.lock_refresher" child logger instead.
logger = getLogger(__name__)

# Check if the id matches. If not, return an error code (1).
# On a match, refill the signal list (KEYS[2]) with a single element that
# expires after 1 second so a waiter blocked in BLPOP can wake up, then
# delete the lock key (KEYS[1]) and return 0.
UNLOCK_SCRIPT = b"""
    if redis.call("get", KEYS[1]) ~= ARGV[1] then
        return 1
    else
        redis.call("del", KEYS[2])
        redis.call("lpush", KEYS[2], 1)
        redis.call("expire", KEYS[2], 1)
        redis.call("del", KEYS[1])
        return 0
    end
"""
UNLOCK_SCRIPT_HASH = sha1(UNLOCK_SCRIPT).hexdigest()


# Return 1 when the lock key (KEYS[1]) is missing or holds a different id
# than ARGV[2], 2 when the key exists but has no TTL, otherwise set its TTL
# to ARGV[1] seconds and return 0.
EXTEND_SCRIPT = b"""
    if redis.call("get", KEYS[1]) ~= ARGV[2] then
        return 1
    elseif redis.call("ttl", KEYS[1]) < 0 then
        return 2
    else
        redis.call("expire", KEYS[1], ARGV[1])
        return 0
    end
"""
EXTEND_SCRIPT_HASH = sha1(EXTEND_SCRIPT).hexdigest()


# Unconditionally delete the lock key (KEYS[1]) after refilling the signal
# list (KEYS[2]) so blocked waiters wake up; returns the result of DEL.
RESET_SCRIPT = b"""
    redis.call('del', KEYS[2])
    redis.call('lpush', KEYS[2], 1)
    redis.call('expire', KEYS[2], 1)
    return redis.call('del', KEYS[1])
"""

RESET_SCRIPT_HASH = sha1(RESET_SCRIPT).hexdigest()


# Reset every 'lock:*' key the way RESET_SCRIPT does and return how many
# locks were found. string.sub(lock, 6) strips the 5-character 'lock:'
# prefix to derive the matching 'lock-signal:' key.
# Note: KEYS walks the entire keyspace.
RESET_ALL_SCRIPT = b"""
    local locks = redis.call('keys', 'lock:*')
    local signal
    for _, lock in pairs(locks) do
        signal = 'lock-signal:' .. string.sub(lock, 6)
        redis.call('del', signal)
        redis.call('lpush', signal, 1)
        redis.call('expire', signal, 1)
        redis.call('del', lock)
    end
    return #locks
"""

RESET_ALL_SCRIPT_HASH = sha1(RESET_ALL_SCRIPT).hexdigest()


class AlreadyAcquired(RuntimeError):
    """Raised by ``Lock.acquire()`` when the lock key already holds this Lock's id."""


class NotAcquired(RuntimeError):
    """Raised by ``Lock.release()``/``Lock.extend()`` when the lock is not held by this id."""


class AlreadyStarted(RuntimeError):
    """Raised by Lock when asked to start a lock renewal thread that was already started."""


class TimeoutNotUsable(RuntimeError):
    """Raised by ``Lock.acquire()`` when a timeout is combined with ``blocking=False``."""


class InvalidTimeout(RuntimeError):
    """Raised by ``Lock.acquire()`` when the timeout is less than or equal to 0."""


class TimeoutTooLarge(RuntimeError):
    """Raised by ``Lock.acquire()`` when the timeout exceeds the lock's expire time."""


class NotExpirable(RuntimeError):
    """Raised by ``Lock.extend()`` when the lock key exists but has no TTL."""


# Flat table of (hash, source, name) triples, one per Lua script. Each script
# id below is the starting offset of its triple, which is how _eval_script()
# indexes into it (+0 hash, +1 source, +2 name).
SCRIPTS = (
    UNLOCK_SCRIPT_HASH, UNLOCK_SCRIPT, 'UNLOCK_SCRIPT',
    EXTEND_SCRIPT_HASH, EXTEND_SCRIPT, 'EXTEND_SCRIPT',
    RESET_SCRIPT_HASH, RESET_SCRIPT, 'RESET_SCRIPT',
    RESET_ALL_SCRIPT_HASH, RESET_ALL_SCRIPT, 'RESET_ALL_SCRIPT',
)
UNLOCK, EXTEND, RESET, RESET_ALL = range(0, len(SCRIPTS), 3)


def _eval_script(redis, script_id, *keys, **kwargs):
    """Tries to call ``EVALSHA`` with the `hash` and then, if it fails, calls
    regular ``EVAL`` with the `script`.

    :param redis:
        The client the script is run on.
    :param script_id:
        One of ``UNLOCK``/``EXTEND``/``RESET``/``RESET_ALL``: the offset of
        the script's (hash, source, name) triple in ``SCRIPTS``.
    :param keys:
        Redis keys, passed to the script as ``KEYS``.
    :param args:
        Keyword-only; values passed to the script as ``ARGV``.
    :returns:
        Whatever the script returns.
    :raises TypeError:
        On any keyword argument other than ``args``.
    """
    # tuple() so ``keys + args`` also works when a list is passed.
    args = tuple(kwargs.pop('args', ()))
    if kwargs:
        raise TypeError("Unexpected keyword arguments %s" % kwargs.keys())
    try:
        return redis.evalsha(SCRIPTS[script_id], len(keys), *keys + args)
    except NoScriptError:
        # Logger.warn() is a deprecated alias of Logger.warning().
        logger.warning("%s not cached.", SCRIPTS[script_id + 2])
        return redis.eval(SCRIPTS[script_id + 1], len(keys), *keys + args)


class Lock(object):
    """
    A Lock context manager implemented via redis SETNX/BLPOP.

    The lock is the key ``lock:<name>`` holding this instance's id. Waiters
    block on the list ``lock-signal:<name>``, which the UNLOCK/RESET scripts
    push to when the lock is released.
    """

    def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False):
        """
        :param redis_client:
            An instance of :class:`~StrictRedis`.
        :param name:
            The name (redis key) the lock should have.
        :param expire:
            The lock expiry time in seconds (converted with ``int()``). If
            left at the default (None) the lock will not expire.
        :param id:
            The ID (redis value) the lock should have. A random value is
            generated when left at the default.

            If the lock key in redis already holds this ID, this instance is
            treated as holding the lock: ``acquire()`` raises
            :class:`AlreadyAcquired` and ``release()`` can release it.
        :param auto_renewal:
            If set to True, Lock will automatically renew the lock so that it
            doesn't expire for as long as the lock is held (acquire() called
            or running in a context manager).

            Implementation note: Renewal will happen using a daemon thread with
            an interval of expire*2/3. If wishing to use a different renewal
            time, subclass Lock, call super().__init__() then set
            self._lock_renewal_interval to your desired interval.
        :raises ValueError:
            If `auto_renewal` is set without an `expire`.
        """
        assert isinstance(redis_client, StrictRedis)
        if auto_renewal and expire is None:
            raise ValueError("Expire may not be None when auto_renewal is set")

        self._client = redis_client
        self._expire = expire if expire is None else int(expire)
        self._id = urandom(16) if id is None else id
        self._name = 'lock:' + name
        self._signal = 'lock-signal:' + name
        self._lock_renewal_interval = (float(expire) * 2 / 3
                                       if auto_renewal
                                       else None)
        self._lock_renewal_thread = None
        # Replaced by a threading.Event in _start_lock_renewer(); setting it
        # asks the renewal thread to exit.
        self._lock_renewal_stop = None

    @property
    def _held(self):
        # True when the redis key currently stores this instance's id.
        return self.id == self.get_owner_id()

    def reset(self):
        """
        Forcibly deletes the lock. Use this with care.

        Waiters blocked in :meth:`acquire` are signalled as well.
        """
        _eval_script(self._client, RESET, self._name, self._signal)

    @property
    def id(self):
        """The value this instance stores in the lock key."""
        return self._id

    def get_owner_id(self):
        """Return the id currently stored in the lock key (None if unset)."""
        return self._client.get(self._name)

    def acquire(self, blocking=True, timeout=None):
        """
        :param blocking:
            Boolean value specifying whether lock should be blocking or not.
        :param timeout:
            An integer value specifying the maximum number of seconds to block.
        :returns:
            ``True`` if the lock was acquired, ``False`` if a non-blocking
            attempt found it busy or `timeout` elapsed.
        :raises AlreadyAcquired: the lock key already holds this instance's id.
        :raises TimeoutNotUsable: `timeout` given together with ``blocking=False``.
        :raises InvalidTimeout: `timeout` is less than or equal to 0.
        :raises TimeoutTooLarge: `timeout` is greater than the lock's `expire`.
        """
        logger.debug("Getting %r ...", self._name)

        if self._held:
            raise AlreadyAcquired("Already acquired from this Lock instance.")

        if not blocking and timeout is not None:
            raise TimeoutNotUsable("Timeout cannot be used if blocking=False")

        timeout = timeout if timeout is None else int(timeout)
        if timeout is not None and timeout <= 0:
            raise InvalidTimeout("Timeout (%d) cannot be less than or equal to 0" % timeout)

        if timeout and self._expire and timeout > self._expire:
            raise TimeoutTooLarge("Timeout (%d) cannot be greater than expire (%d)" % (timeout, self._expire))

        busy = True
        # Without an explicit timeout, still wake up every `expire` seconds:
        # a holder's key can expire without anything being pushed to the
        # signal list, so SET must be retried periodically.
        blpop_timeout = timeout or self._expire or 0
        timed_out = False
        while busy:
            busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
            if busy:
                if timed_out:
                    return False
                elif blocking:
                    # Only a caller-supplied timeout can end the wait; one
                    # last SET attempt is made after BLPOP times out.
                    timed_out = not self._client.blpop(self._signal, blpop_timeout) and timeout
                else:
                    logger.debug("Failed to get %r.", self._name)
                    return False

        logger.debug("Got lock for %r.", self._name)
        if self._lock_renewal_interval is not None:
            self._start_lock_renewer()
        return True

    def extend(self, expire=None):
        """Extends expiration time of the lock.

        :param expire:
            New expiration time in seconds. If ``None`` - `expire` provided
            during lock initialization will be taken. Converted with
            ``int()``, like the `expire` given at initialization, because
            redis ``EXPIRE`` takes whole seconds.
        :raises TypeError: no `expire` was given here or at initialization.
        :raises NotAcquired: the lock key is missing or holds another id.
        :raises NotExpirable: the lock key exists but has no TTL.
        """
        if expire is None:
            if self._expire is None:
                raise TypeError(
                    "To extend a lock 'expire' must be provided as an "
                    "argument to extend() method or at initialization time."
                )
            expire = self._expire
        else:
            expire = int(expire)
        error = _eval_script(self._client, EXTEND, self._name, args=(expire, self._id))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error == 2:
            raise NotExpirable("Lock %s has no assigned expiration time" %
                               self._name)
        elif error:
            raise RuntimeError("Unsupported error code %s from EXTEND script" % error)

    @staticmethod
    def _lock_renewer(lockref, interval, stop):
        """
        Renew the lock key in redis every `interval` seconds until the `stop`
        event is set or the Lock object behind `lockref` (a weak reference)
        has been garbage collected.
        """
        log = getLogger("%s.lock_refresher" % __name__)
        while not stop.wait(timeout=interval):
            log.debug("Refreshing lock")
            lock = lockref()
            if lock is None:
                log.debug("The lock no longer exists, "
                          "stopping lock refreshing")
                break
            lock.extend(expire=lock._expire)
            # Drop the strong reference before sleeping again so the Lock
            # can still be collected.
            del lock
        log.debug("Exit requested, stopping lock refreshing")

    def _start_lock_renewer(self):
        """
        Starts the lock refresher thread.

        A previous refresher that already died (e.g. because ``extend()``
        raised inside it) is replaced instead of blocking the restart.

        :raises AlreadyStarted: a refresher thread for this Lock is still running.
        """
        if self._lock_renewal_thread is not None and self._lock_renewal_thread.is_alive():
            raise AlreadyStarted("Lock refresh thread already started")

        logger.debug(
            "Starting thread to refresh lock every %s seconds",
            self._lock_renewal_interval
        )
        self._lock_renewal_stop = threading.Event()
        self._lock_renewal_thread = threading.Thread(
            group=None,
            target=self._lock_renewer,
            # Only a weak reference to self is handed over, so the running
            # thread does not keep this Lock alive.
            kwargs={'lockref': weakref.ref(self),
                    'interval': self._lock_renewal_interval,
                    'stop': self._lock_renewal_stop}
        )
        # Daemon so the refresher never keeps the interpreter alive; the
        # attribute replaces the deprecated Thread.setDaemon().
        self._lock_renewal_thread.daemon = True
        self._lock_renewal_thread.start()

    def _stop_lock_renewer(self):
        """
        Stop the lock renewer.

        This signals the renewal thread and waits for its exit. The thread
        reference is cleared even if the thread had already died on its own,
        so a later ``acquire()`` can start a fresh renewer.
        """
        if self._lock_renewal_thread is None:
            return
        if self._lock_renewal_thread.is_alive():
            logger.debug("Signalling the lock refresher to stop")
            self._lock_renewal_stop.set()
            self._lock_renewal_thread.join()
            logger.debug("Lock refresher has stopped")
        self._lock_renewal_thread = None

    def __enter__(self):
        acquired = self.acquire(blocking=True)
        assert acquired, "Lock wasn't acquired, but blocking=True"
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        self.release()

    def release(self):
        """Releases the lock, that was acquired with the same object.

        .. note::

            If you want to release a lock that you acquired in a different place you have two choices:

            * Use ``Lock(redis_client, "name", id=id_from_other_place).release()``
            * Use ``Lock(redis_client, "name").reset()``

        :raises NotAcquired:
            The lock key does not hold this instance's id (never acquired,
            expired, or taken over).
        """
        # Stop renewing first: if the lock turns out not to be ours any more,
        # the renewer has nothing valid left to extend.
        if self._lock_renewal_thread is not None:
            self._stop_lock_renewer()
        if not self._held:
            raise NotAcquired("This Lock instance didn't acquire the lock.")
        logger.debug("Releasing %r.", self._name)
        error = _eval_script(self._client, UNLOCK, self._name, self._signal, args=(self._id,))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error:
            raise RuntimeError("Unsupported error code %s from UNLOCK script." % error)


def reset_all(redis_client):
    """
    Forcibly deletes all locks, e.g. ones left behind after a crash.
    Use this with care.

    Every ``lock:*`` key is deleted, and its ``lock-signal:*`` list is
    refilled so that waiters blocked in ``Lock.acquire()`` wake up.

    :returns:
        The number of lock keys that were removed.
    """
    return _eval_script(redis_client, RESET_ALL)