Completed
Push — master ( 949a86...c5f0ee )
by Ionel Cristian
01:04
created

src.redis_lock.Lock.extend()   B

Complexity

Conditions 6

Size

Total Lines 24

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 6
dl 0
loc 24
rs 7.6129
1
import threading
2
from logging import getLogger
3
from os import urandom
4
from hashlib import sha1
5
6
from redis import StrictRedis
7
from redis.exceptions import NoScriptError
8
9
__version__ = "2.3.0"

# Module-level logger shared by the helpers and the Lock class below.
logger = getLogger(__name__)
12
13
# Lua: release the lock only if it is still owned by the caller.
#   KEYS[1] = lock key, KEYS[2] = signal list key, ARGV[1] = owner id.
# When the stored id matches, the signal list is recreated with a single
# element (expiring after 1 second) and the lock key is deleted; the script
# returns the result of that DEL. Otherwise it returns 0 and changes nothing.
UNLOCK_SCRIPT = b"""
    if redis.call("get", KEYS[1]) == ARGV[1] then
        redis.call("del", KEYS[2])
        redis.call("lpush", KEYS[2], 1)
        redis.call("expire", KEYS[2], 1)
        return redis.call("del", KEYS[1])
    else
        return 0
    end
"""
# SHA1 of the script body, used as the first argument to EVALSHA.
UNLOCK_SCRIPT_HASH = sha1(UNLOCK_SCRIPT).hexdigest()
24
25
# Lua: reset the TTL of a lock that is still owned by the caller.
#   KEYS[1] = lock key, ARGV[1] = new expiry in seconds, ARGV[2] = owner id.
# Returns -1 when the key is missing or holds a different id, -2 when the
# key has no TTL, and 0 after the expiry has been updated.
EXTEND_SCRIPT = b"""
    -- Covers both cases when key doesn't exist and doesn't equal to lock's id
    if redis.call("get", KEYS[1]) ~= ARGV[2] then
        return -1
    elseif redis.call("ttl", KEYS[1]) < 0 then
        return -2
    else
        redis.call("expire", KEYS[1], ARGV[1])
        return 0
    end
"""
# SHA1 of the script body, used as the first argument to EVALSHA.
EXTEND_SCRIPT_HASH = sha1(EXTEND_SCRIPT).hexdigest()
37
38
# Lua: unconditionally drop one lock, regardless of who owns it.
#   KEYS[1] = lock key, KEYS[2] = signal list key.
# The signal list is recreated with a single element (expiring after
# 1 second) and the lock key is deleted; returns the result of that DEL.
RESET_SCRIPT = b"""
    redis.call('del', KEYS[2])
    redis.call('lpush', KEYS[2], 1)
    redis.call('expire', KEYS[2], 1)
    return redis.call('del', KEYS[1])
"""

# SHA1 of the script body, used as the first argument to EVALSHA.
RESET_SCRIPT_HASH = sha1(RESET_SCRIPT).hexdigest()
46
47
# Lua: drop every key matching ``lock:*``. Takes no KEYS/ARGV.
# For each lock key the matching ``lock-signal:<name>`` list is recreated
# with a single element (expiring after 1 second) and the lock key is
# deleted. ``string.sub(lock, 6)`` strips the 5-character ``lock:`` prefix.
# Returns the number of lock keys that were found.
RESET_ALL_SCRIPT = b"""
    local locks = redis.call('keys', 'lock:*')
    local signal
    for _, lock in pairs(locks) do
        signal = 'lock-signal:' .. string.sub(lock, 6)
        redis.call('del', signal)
        redis.call('lpush', signal, 1)
        redis.call('expire', signal, 1)
        redis.call('del', lock)
    end
    return #locks
"""

# SHA1 of the script body, used as the first argument to EVALSHA.
RESET_ALL_SCRIPT_HASH = sha1(RESET_ALL_SCRIPT).hexdigest()
61
62
63
class AlreadyAcquired(RuntimeError):
    """Raised by ``Lock.acquire`` when this Lock instance already holds the lock."""
65
66
67
class NotAcquired(RuntimeError):
    """Raised when releasing or extending a lock that is not held.

    ``Lock.release`` / ``Lock.__exit__`` raise it when the instance did not
    acquire the lock (and ``force`` is false); ``Lock.extend`` raises it when
    the key is missing or stores a different id.
    """
69
70
71
class AlreadyStarted(RuntimeError):
    """Raised when the lock renewal thread is started while one is already set."""
73
74
75
class TimeoutNotUsable(RuntimeError):
    """Raised by ``Lock.acquire`` when ``timeout`` is given with ``blocking=False``."""
77
78
79
class InvalidTimeout(RuntimeError):
    """Raised by ``Lock.acquire`` when ``timeout`` is less than or equal to 0."""
81
82
83
class TimeoutTooLarge(RuntimeError):
    """Raised by ``Lock.acquire`` when ``timeout`` exceeds the lock's ``expire``."""
85
86
87
class NotExpirable(RuntimeError):
    """Raised by ``Lock.extend`` when the lock key has no expiration time set."""
89
90
91
# Flat table of (hash, source, name) triples, one triple per Lua script.
# ``_eval_script`` indexes it as SCRIPTS[id], SCRIPTS[id + 1], SCRIPTS[id + 2].
SCRIPTS = (
    UNLOCK_SCRIPT_HASH, UNLOCK_SCRIPT, 'UNLOCK_SCRIPT',
    EXTEND_SCRIPT_HASH, EXTEND_SCRIPT, 'EXTEND_SCRIPT',
    RESET_SCRIPT_HASH, RESET_SCRIPT, 'RESET_SCRIPT',
    RESET_ALL_SCRIPT_HASH, RESET_ALL_SCRIPT, 'RESET_ALL_SCRIPT',
)

# Script ids are the offsets of each triple's hash inside SCRIPTS: 0, 3, 6, 9.
UNLOCK, EXTEND, RESET, RESET_ALL = range(0, len(SCRIPTS), 3)
101
102
103
def _eval_script(redis, script_id, *args, **kwargs):
    """Run one of the module's Lua scripts, preferring the cached version.

    Tries ``EVALSHA`` with the script's hash first and, if the server reports
    that the script is not cached, falls back to a regular ``EVAL`` with the
    full script source.

    :param redis:
        Client object exposing ``evalsha`` and ``eval``.
    :param script_id:
        Offset of the script's hash in ``SCRIPTS`` (one of ``UNLOCK``,
        ``EXTEND``, ``RESET``, ``RESET_ALL``). The source and the script's
        name live at the next two offsets.
    :param args, kwargs:
        Passed through unchanged to ``evalsha`` / ``eval`` (number of keys,
        then keys and arguments).
    :returns:
        Whatever the script returns.
    """
    try:
        return redis.evalsha(SCRIPTS[script_id], *args, **kwargs)
    except NoScriptError:
        # ``Logger.warn`` is a deprecated alias; use ``warning``.
        logger.warning("%s not cached.", SCRIPTS[script_id + 2])
        return redis.eval(SCRIPTS[script_id + 1], *args, **kwargs)
112
113
114
class Lock(object):
    """
    A Lock context manager implemented via redis SETNX/BLPOP.

    The lock itself is the key ``lock:<name>`` holding this instance's id;
    waiters block on the list ``lock-signal:<name>``, which the unlock/reset
    scripts push to when the lock is freed.
    """

    def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False):
        """
        :param redis_client:
            An instance of :class:`~StrictRedis`.
        :param name:
            The name (redis key) the lock should have.
        :param expire:
            The lock expiry time in seconds. If left at the default (None)
            the lock will not expire.
        :param id:
            The ID (redis value) the lock should have. A random value is
            generated when left at the default.
        :param auto_renewal:
            If set to True, Lock will automatically renew the lock so that it
            doesn't expire for as long as the lock is held (acquire() called
            or running in a context manager).

            Implementation note: Renewal will happen using a daemon thread with
            an interval of expire*2/3. If wishing to use a different renewal
            time, subclass Lock, call super().__init__() then set
            self._lock_renewal_interval to your desired interval.
        :raises ValueError:
            If ``auto_renewal`` is set but ``expire`` is None.
        """
        assert isinstance(redis_client, StrictRedis)
        if auto_renewal and expire is None:
            raise ValueError("Expire may not be None when auto_renewal is set")

        self._client = redis_client
        self._expire = expire if expire is None else int(expire)
        self._id = urandom(16) if id is None else id
        self._held = False
        self._name = 'lock:'+name
        self._signal = 'lock-signal:'+name
        # Derive the interval from the int-coerced expiry that is actually
        # sent to redis, and force float arithmetic so that Python 2 integer
        # division cannot round the interval down (e.g. to 0 for expire=1).
        self._lock_renewal_interval = self._expire * 2.0 / 3 if auto_renewal else None
        self._lock_renewal_thread = None

    def reset(self):
        """
        Forcibly deletes the lock. Use this with care.
        """
        _eval_script(self._client, RESET, 2, self._name, self._signal)

    @property
    def id(self):
        """The id (redis value) this instance stores in the lock key."""
        return self._id

    def get_owner_id(self):
        """Return the value currently stored in the lock key in redis."""
        return self._client.get(self._name)

    def acquire(self, blocking=True, timeout=None):
        """
        :param blocking:
            Boolean value specifying whether lock should be blocking or not.
        :param timeout:
            An integer value specifying the maximum number of seconds to block.
        :returns:
            True once the lock is held, False if it could not be obtained
            (non-blocking call, or the wait timed out).
        :raises AlreadyAcquired:
            If this instance already holds the lock.
        :raises TimeoutNotUsable:
            If ``timeout`` is given together with ``blocking=False``.
        :raises InvalidTimeout:
            If ``timeout`` is less than or equal to 0.
        :raises TimeoutTooLarge:
            If ``timeout`` is greater than the lock's ``expire``.
        """
        logger.debug("Getting %r ...", self._name)

        if self._held:
            raise AlreadyAcquired("Already acquired from this Lock instance.")

        if not blocking and timeout is not None:
            raise TimeoutNotUsable("Timeout cannot be used if blocking=False")

        timeout = timeout if timeout is None else int(timeout)
        if timeout is not None and timeout <= 0:
            raise InvalidTimeout("Timeout (%d) cannot be less than or equal to 0" % timeout)

        if timeout and self._expire and timeout > self._expire:
            raise TimeoutTooLarge("Timeout (%d) cannot be greater than expire (%d)" % (timeout, self._expire))

        busy = True
        # A BLPOP timeout of 0 means "block until signalled".
        blpop_timeout = timeout or self._expire or 0
        timed_out = False
        while busy:
            busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
            if busy:
                if timed_out:
                    # The wait already timed out and one final SET attempt
                    # (above) failed as well.
                    return False
                elif blocking:
                    timed_out = not self._client.blpop(self._signal, blpop_timeout)
                else:
                    logger.debug("Failed to get %r.", self._name)
                    return False

        logger.debug("Got lock for %r.", self._name)
        self._held = True
        if self._lock_renewal_interval is not None:
            self._start_lock_renewer()
        return True

    def extend(self, expire=None):
        """Extends expiration time of the lock.

        :param expire:
            New expiration time. If ``None`` - `expire` provided during
            lock initialization will be taken.
        :returns:
            0 when the expiration time was updated.
        :raises TypeError:
            If no expiry was given here nor at initialization time.
        :raises NotAcquired:
            If the lock key is missing or is held under a different id.
        :raises NotExpirable:
            If the lock key has no expiration time assigned.
        :raises RuntimeError:
            If the script returns a code this method does not know about.
        """
        if expire is None:
            if self._expire is not None:
                expire = self._expire
            else:
                raise TypeError(
                    "To extend a lock 'expire' must be provided as an "
                    "argument to extend() method or at initialization time."
                )
        # ARGV[1] is the new expiry, ARGV[2] the owner id (see EXTEND_SCRIPT).
        result = _eval_script(self._client, EXTEND, 1,
                              self._name, expire, self._id)
        if result == 0:
            return result
        elif result == -1:
            raise NotAcquired("Lock %s is not acquired" % self._name)
        elif result == -2:
            raise NotExpirable("Lock %s has no assigned expiration time" %
                               self._name)
        # Do not let an unknown result code pass silently as ``None``.
        raise RuntimeError("Unsupported result %r from EXTEND script" % (result,))

    def _lock_renewer(self, interval):
        """
        Renew the lock key in redis every `interval` seconds for as long
        as `self._lock_renewal_thread.should_exit` is False.

        An exception raised by ``extend()`` is not caught here and ends the
        renewal thread.
        """
        log = getLogger("%s.lock_refresher" % __name__)
        while not self._lock_renewal_thread.wait_for_exit_request(timeout=interval):
            log.debug("Refreshing lock")
            self.extend(expire=self._expire)
        log.debug("Exit requested, stopping lock refreshing")

    def _start_lock_renewer(self):
        """
        Starts the lock refresher thread.

        :raises AlreadyStarted:
            If a renewal thread is already registered on this instance.
        """
        if self._lock_renewal_thread is not None:
            raise AlreadyStarted("Lock refresh thread already started")

        logger.debug(
            "Starting thread to refresh lock every %s seconds",
            self._lock_renewal_interval
        )
        self._lock_renewal_thread = InterruptableThread(
            group=None,
            target=self._lock_renewer,
            kwargs={'interval': self._lock_renewal_interval}
        )
        # ``setDaemon`` is deprecated; the attribute is the supported spelling.
        self._lock_renewal_thread.daemon = True
        self._lock_renewal_thread.start()

    def _stop_lock_renewer(self):
        """
        Stop the lock renewer.

        This signals the renewal thread and waits for its exit. A renewal
        thread that has already died is simply forgotten, so that a later
        ``acquire()`` can start a fresh one instead of failing with
        ``AlreadyStarted``.
        """
        renewer = self._lock_renewal_thread
        if renewer is None:
            return
        if renewer.is_alive():
            logger.debug("Signalling the lock refresher to stop")
            renewer.request_exit()
            renewer.join()
            logger.debug("Lock refresher has stopped")
        self._lock_renewal_thread = None

    def __enter__(self):
        acquired = self.acquire(blocking=True)
        assert acquired, "Lock wasn't acquired, but blocking=True"
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None, force=False):
        if not (self._held or force):
            raise NotAcquired("This Lock instance didn't acquire the lock.")
        if self._lock_renewal_thread is not None:
            self._stop_lock_renewer()
        logger.debug("Releasing %r.", self._name)
        # The script only deletes the key if it still stores this instance's id.
        _eval_script(self._client, UNLOCK,
                     2, self._name, self._signal, self._id)

        self._held = False

    def release(self, force=False):
        """Releases the lock, that was acquired in the same Python context.

        :param force:
            If ``False`` - fail with exception if this instance was not in
            acquired state in the same Python context.
            If ``True`` - fail silently.
        """
        return self.__exit__(force=force)
303
304
305
class InterruptableThread(threading.Thread):
    """
    A Python thread that can be requested to stop by calling request_exit()
    on it.

    Code running inside this thread should periodically check the
    `should_exit` property (or use wait_for_exit_request) on the thread
    object and stop further processing once it returns True.
    """
    def __init__(self, *args, **kwargs):
        # Event that is set once an exit has been requested.
        self._should_exit = threading.Event()
        super(InterruptableThread, self).__init__(*args, **kwargs)

    def request_exit(self):
        """
        Signal the thread that it should stop performing more work and exit.
        """
        self._should_exit.set()

    @property
    def should_exit(self):
        """True once ``request_exit()`` has been called."""
        # ``isSet`` is a deprecated alias of ``is_set``.
        return self._should_exit.is_set()

    def wait_for_exit_request(self, timeout=None):
        """
        Wait until the thread has been signalled to exit.

        If timeout is specified (as a float of seconds to wait) then wait
        up to this many seconds before returning the value of `should_exit`.
        """
        should_exit = self._should_exit.wait(timeout)
        if should_exit is None:
            # Python 2.6 compatibility which doesn't return self.__flag when
            # calling Event.wait()
            should_exit = self.should_exit
        return should_exit
341
342
343
def reset_all(redis_client):
    """
    Forcibly delete all locks that were left behind (for example after a
    crash). Use this with care.

    Runs the ``RESET_ALL`` script, which removes every ``lock:*`` key and
    signals the matching ``lock-signal:*`` lists.

    :param redis_client:
        Client the script is evaluated on.
    """
    _eval_script(redis_client, RESET_ALL, 0)
348