Completed
Pull Request — master (#33)
by
unknown
51s
created

src.redis_lock.InterruptableThread.wait_for_exit_request()   A

Complexity

Conditions 2

Size

Total Lines 13

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 2
dl 0
loc 13
rs 9.4285
1
import threading
2
from logging import getLogger
3
from os import urandom
4
from hashlib import sha1
5
import weakref
6
7
from redis import StrictRedis
8
from redis.exceptions import NoScriptError
9
10
__version__ = "3.0.0"

# Module-level logger; the lock renewer thread logs to a ".lock_refresher"
# child of this logger.
logger = getLogger(name=__name__)
13
14
# Unlock script: delete the lock only if it still holds the caller's id.
#   KEYS[1] - lock key, KEYS[2] - signal key, ARGV[1] - lock id
#   (as passed by Lock.release()).
# Returns 0 on success, after pushing a wake-up token onto the signal list
# (which is set to expire after 1 second), or 1 when the id does not match.
UNLOCK_SCRIPT = b"""
    if redis.call("get", KEYS[1]) ~= ARGV[1] then
        return 1
    else
        redis.call("del", KEYS[2])
        redis.call("lpush", KEYS[2], 1)
        redis.call("expire", KEYS[2], 1)
        redis.call("del", KEYS[1])
        return 0
    end
"""
UNLOCK_SCRIPT_HASH = sha1(UNLOCK_SCRIPT).hexdigest()
27
28
# Extend script: reset the TTL of a lock that is still held by the caller.
#   KEYS[1] - lock key, ARGV[1] - new expiry in seconds, ARGV[2] - lock id
#   (as passed by Lock.extend()).
# Returns 1 when the stored value differs from the id (this also covers a
# missing key), 2 when the key has no TTL, and 0 on success.
EXTEND_SCRIPT = b"""
    if redis.call("get", KEYS[1]) ~= ARGV[2] then
        return 1
    elseif redis.call("ttl", KEYS[1]) < 0 then
        return 2
    else
        redis.call("expire", KEYS[1], ARGV[1])
        return 0
    end
"""
EXTEND_SCRIPT_HASH = sha1(EXTEND_SCRIPT).hexdigest()
40
41
# Reset script: unconditionally delete a lock and wake up its waiters.
#   KEYS[1] - lock key, KEYS[2] - signal key (as passed by Lock.reset()).
# Returns the result of deleting the lock key.
RESET_SCRIPT = b"""
    redis.call('del', KEYS[2])
    redis.call('lpush', KEYS[2], 1)
    redis.call('expire', KEYS[2], 1)
    return redis.call('del', KEYS[1])
"""
RESET_SCRIPT_HASH = sha1(RESET_SCRIPT).hexdigest()
49
50
# Reset-all script: delete every "lock:*" key and wake up its waiters.
# The matching signal key is derived by replacing the 5-character "lock:"
# prefix with "lock-signal:" (hence string.sub(lock, 6) in 1-based Lua).
# Returns the number of lock keys that were found.
# NOTE: uses the KEYS command, which walks the entire keyspace in one call.
RESET_ALL_SCRIPT = b"""
    local locks = redis.call('keys', 'lock:*')
    local signal
    for _, lock in pairs(locks) do
        signal = 'lock-signal:' .. string.sub(lock, 6)
        redis.call('del', signal)
        redis.call('lpush', signal, 1)
        redis.call('expire', signal, 1)
        redis.call('del', lock)
    end
    return #locks
"""
RESET_ALL_SCRIPT_HASH = sha1(RESET_ALL_SCRIPT).hexdigest()
64
65
66
class AlreadyAcquired(RuntimeError):
    """Raised by ``Lock.acquire()`` when this Lock instance is already
    marked as holding the lock."""
68
69
70
class NotAcquired(RuntimeError):
    """Raised when the lock is not held by this instance: ``release()`` on
    an instance that never acquired it, or ``release()``/``extend()`` when
    the id stored in redis no longer matches (missing or expired lock)."""
72
73
74
class AlreadyStarted(RuntimeError):
    """Raised when a lock renewal thread is started while one is already
    registered on the Lock instance."""
76
77
78
class TimeoutNotUsable(RuntimeError):
    """Raised by ``Lock.acquire()`` when a timeout is combined with
    ``blocking=False``."""
80
81
82
class InvalidTimeout(RuntimeError):
    """Raised by ``Lock.acquire()`` when the timeout is less than or equal
    to zero."""
84
85
86
class TimeoutTooLarge(RuntimeError):
    """Raised by ``Lock.acquire()`` when the timeout exceeds the lock's
    expiry time."""
88
89
90
class NotExpirable(RuntimeError):
    """Raised by ``Lock.extend()`` when the lock key exists but has no
    expiration time assigned in redis."""
92
93
94
# ``SCRIPTS`` is a flat tuple holding, for each bundled Lua script, three
# consecutive entries: its SHA1 hash, its source and its name (for logging).
# Each script id is the offset of its triple's first entry, so for an id
# ``sid``: SCRIPTS[sid] is the hash, SCRIPTS[sid + 1] the source and
# SCRIPTS[sid + 2] the name.
SCRIPTS = (
    UNLOCK_SCRIPT_HASH, UNLOCK_SCRIPT, 'UNLOCK_SCRIPT',
    EXTEND_SCRIPT_HASH, EXTEND_SCRIPT, 'EXTEND_SCRIPT',
    RESET_SCRIPT_HASH, RESET_SCRIPT, 'RESET_SCRIPT',
    RESET_ALL_SCRIPT_HASH, RESET_ALL_SCRIPT, 'RESET_ALL_SCRIPT',
)
UNLOCK, EXTEND, RESET, RESET_ALL = range(0, len(SCRIPTS), 3)
104
105
106
def _eval_script(redis, script_id, *keys, **kwargs):
    """Run one of the bundled Lua scripts, preferring the server-side cache.

    Tries ``EVALSHA`` with the script's SHA1 hash first. If Redis reports that
    the script is not cached (``NoScriptError``), it logs a warning and falls
    back to a regular ``EVAL`` with the full script source.

    :param redis:
        The redis client to run the script on.
    :param script_id:
        Offset of the script's entry in ``SCRIPTS`` (``UNLOCK``, ``EXTEND``,
        ``RESET`` or ``RESET_ALL``).
    :param keys:
        Redis key names passed to the script as ``KEYS``.
    :param args:
        Keyword-only. Values passed to the script as ``ARGV``. Any iterable is
        accepted (tuple, list, ...); defaults to no arguments.
    :returns: Whatever the script returns.
    :raises TypeError: If a keyword argument other than ``args`` is given.
    """
    # Normalise to a tuple so ``keys + args`` also works for lists.
    args = tuple(kwargs.pop('args', ()))
    if kwargs:
        raise TypeError("Unexpected keyword arguments %s" % ', '.join(sorted(kwargs)))
    script_args = keys + args
    try:
        return redis.evalsha(SCRIPTS[script_id], len(keys), *script_args)
    except NoScriptError:
        # ``Logger.warn`` is a deprecated alias of ``Logger.warning``.
        logger.warning("%s not cached.", SCRIPTS[script_id + 2])
        return redis.eval(SCRIPTS[script_id + 1], len(keys), *script_args)
118
119
120
class Lock(object):
    """
    A Lock context manager implemented via redis ``SET NX``/``BLPOP``.

    The lock is a redis key (``lock:<name>``) whose value is this instance's
    id. Waiters block on a companion list (``lock-signal:<name>``), onto which
    a token is pushed whenever the lock is released or reset.
    """

    def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False):
        """
        :param redis_client:
            An instance of :class:`~StrictRedis`.
        :param name:
            The name (redis key) the lock should have.
        :param expire:
            The lock expiry time in seconds. If left at the default (None)
            the lock will not expire.
        :param id:
            The ID (redis value) the lock should have. A random value is
            generated when left at the default.

            Note that if you specify this then the lock is marked as "held". Acquires
            won't be possible.
        :param auto_renewal:
            If set to True, Lock will automatically renew the lock so that it
            doesn't expire for as long as the lock is held (acquire() called
            or running in a context manager).

            Implementation note: Renewal will happen using a daemon thread with
            an interval of expire*2/3. If wishing to use a different renewal
            time, subclass Lock, call super().__init__() then set
            self._lock_renewal_interval to your desired interval.
        :raises ValueError:
            If ``auto_renewal`` is set without an ``expire``.
        """
        assert isinstance(redis_client, StrictRedis)
        if auto_renewal and expire is None:
            raise ValueError("Expire may not be None when auto_renewal is set")

        self._client = redis_client
        self._expire = expire if expire is None else int(expire)
        self._id = urandom(16) if id is None else id
        self._held = id is not None
        self._name = 'lock:' + name
        self._signal = 'lock-signal:' + name
        # Derive the interval from the normalised (int) expiry so it matches
        # the TTL actually set in redis, whatever type ``expire`` was given as.
        self._lock_renewal_interval = self._expire * 2 / 3 if auto_renewal else None
        self._lock_renewal_thread = None
        self._lock_renewal_stop = None

    def reset(self):
        """
        Forcibly deletes the lock and wakes up its waiters. Use this with care.
        """
        _eval_script(self._client, RESET, self._name, self._signal)

    @property
    def id(self):
        """The ID (redis value) this instance uses for the lock."""
        return self._id

    def get_owner_id(self):
        """Return the value currently stored in the lock key, as returned by
        the client's ``get``. Presumably ``None`` when nobody holds the lock;
        confirm against the client in use."""
        return self._client.get(self._name)

    def acquire(self, blocking=True, timeout=None):
        """
        Acquire the lock, optionally waiting for it to be released.

        :param blocking:
            Boolean value specifying whether lock should be blocking or not.
        :param timeout:
            An integer value specifying the maximum number of seconds to block.
            Only usable with ``blocking=True``. It must be positive and, when
            the lock has an expiry, not greater than it.
        :returns:
            True if the lock was acquired. False if a non-blocking attempt
            found it busy, or if the timeout ran out.
        :raises AlreadyAcquired: If this instance already holds the lock.
        :raises TimeoutNotUsable: If a timeout is given with ``blocking=False``.
        :raises InvalidTimeout: If the timeout is less than or equal to 0.
        :raises TimeoutTooLarge: If the timeout is greater than the expiry.
        """
        logger.debug("Getting %r ...", self._name)

        if self._held:
            raise AlreadyAcquired("Already acquired from this Lock instance.")

        if not blocking and timeout is not None:
            raise TimeoutNotUsable("Timeout cannot be used if blocking=False")

        timeout = timeout if timeout is None else int(timeout)
        if timeout is not None and timeout <= 0:
            raise InvalidTimeout("Timeout (%d) cannot be less than or equal to 0" % timeout)

        if timeout and self._expire and timeout > self._expire:
            raise TimeoutTooLarge("Timeout (%d) cannot be greater than expire (%d)" % (timeout, self._expire))

        busy = True
        # Without an explicit timeout, wake up at least every `expire` seconds.
        # A lock that simply expires pushes no signal, so it must be re-checked.
        blpop_timeout = timeout or self._expire or 0
        timed_out = False
        while busy:
            busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
            if busy:
                if timed_out:
                    return False
                elif blocking:
                    # Wait for a release/reset token on the signal list. An
                    # empty BLPOP only counts as a time-out if the caller gave
                    # a timeout; otherwise we just retry the SET.
                    timed_out = not self._client.blpop(self._signal, blpop_timeout) and timeout
                else:
                    logger.debug("Failed to get %r.", self._name)
                    return False

        logger.debug("Got lock for %r.", self._name)
        self._held = True
        if self._lock_renewal_interval is not None:
            self._start_lock_renewer()
        return True

    def extend(self, expire=None):
        """Extends expiration time of the lock.

        :param expire:
            New expiration time. If ``None`` - `expire` provided during
            lock initialization will be taken.
        :raises TypeError:
            If no expiry is given here or at initialization time.
        :raises NotAcquired:
            If the lock key is missing or holds a different id.
        :raises NotExpirable:
            If the lock key has no expiration time in redis.
        """
        if expire is None:
            if self._expire is not None:
                expire = self._expire
            else:
                raise TypeError(
                    "To extend a lock 'expire' must be provided as an "
                    "argument to extend() method or at initialization time."
                )
        error = _eval_script(self._client, EXTEND, self._name, args=(expire, self._id))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error == 2:
            raise NotExpirable("Lock %s has no assigned expiration time" %
                               self._name)
        elif error:
            raise RuntimeError("Unsupported error code %s from EXTEND script" % error)

    @staticmethod
    def _lock_renewer(lockref, interval, stop):
        """
        Renew the lock key in redis every `interval` seconds until `stop` is
        set or the Lock object has been garbage collected.

        :param lockref:
            Weak reference to the Lock; a weakref is used so that the thread
            does not keep the Lock alive.
        :param interval:
            Seconds to wait between renewals.
        :param stop:
            ``threading.Event`` that, once set, makes the loop exit.
        """
        log = getLogger("%s.lock_refresher" % __name__)
        while not stop.wait(timeout=interval):
            log.debug("Refreshing lock")
            lock = lockref()
            if lock is None:
                log.debug("The lock no longer exists, "
                          "stopping lock refreshing")
                break
            lock.extend(expire=lock._expire)
            # Drop the strong reference before waiting again, so this thread
            # doesn't keep the Lock alive between renewals.
            del lock
        log.debug("Exit requested, stopping lock refreshing")

    def _start_lock_renewer(self):
        """
        Starts the lock refresher thread.

        :raises AlreadyStarted: If a renewal thread is already registered.
        """
        if self._lock_renewal_thread is not None:
            raise AlreadyStarted("Lock refresh thread already started")

        logger.debug(
            "Starting thread to refresh lock every %s seconds",
            self._lock_renewal_interval
        )
        self._lock_renewal_stop = threading.Event()
        self._lock_renewal_thread = threading.Thread(
            group=None,
            target=self._lock_renewer,
            kwargs={'lockref': weakref.ref(self),
                    'interval': self._lock_renewal_interval,
                    'stop': self._lock_renewal_stop}
        )
        # Daemon thread, so a forgotten lock never blocks interpreter exit.
        # The attribute form replaces the deprecated setDaemon().
        self._lock_renewal_thread.daemon = True
        self._lock_renewal_thread.start()

    def _stop_lock_renewer(self):
        """
        Stop the lock renewer.

        This signals the renewal thread and waits for its exit. The thread
        reference is always cleared, even when the thread had already died
        (e.g. because extend() raised). Otherwise the next acquire() would
        fail with AlreadyStarted after having taken the lock in redis.
        """
        thread = self._lock_renewal_thread
        if thread is None:
            return
        if thread.is_alive():
            logger.debug("Signalling the lock refresher to stop")
            self._lock_renewal_stop.set()
            # __del__ can run on the renewer thread itself (when it drops the
            # last reference to this Lock), and a thread cannot join itself.
            if thread is not threading.current_thread():
                thread.join()
                logger.debug("Lock refresher has stopped")
        self._lock_renewal_thread = None

    def __enter__(self):
        acquired = self.acquire(blocking=True)
        assert acquired, "Lock wasn't acquired, but blocking=True"
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        self.release()

    def release(self):
        """Releases the lock, that was acquired with the same object.

        .. note::

            If you want to release a lock that you acquired in a different place you have two choices:

            * Use ``Lock(conn, "name", id=id_from_other_place).release()``
            * Use ``Lock(conn, "name").reset()``

        :raises NotAcquired:
            If this instance doesn't hold the lock, or the id stored in redis
            no longer matches (the lock expired or was taken over).
        """
        if not self._held:
            raise NotAcquired("This Lock instance didn't acquire the lock.")
        if self._lock_renewal_thread is not None:
            self._stop_lock_renewer()
        logger.debug("Releasing %r.", self._name)
        error = _eval_script(self._client, UNLOCK, self._name, self._signal, args=(self._id,))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error:
            raise RuntimeError("Unsupported error code %s from UNLOCK script." % error)
        else:
            self._held = False

    def __del__(self):
        # __init__ may have raised before this attribute was assigned; don't
        # let the finalizer fail with AttributeError in that case.
        if getattr(self, '_lock_renewal_thread', None) is not None:
            self._stop_lock_renewer()
329
330
331
def reset_all(redis_client):
    """
    Forcibly delete every ``lock:*`` key and wake up its waiters.

    Meant for cleaning up locks left behind, e.g. after a crash. Use this with
    care: locks currently held by live processes are deleted as well.

    :param redis_client:
        The redis client to run the reset script on.
    """
    _eval_script(redis_client, script_id=RESET_ALL)
336