Completed
Push — master ( ead540...a8df81 )
by Ionel Cristian
9s
created

src.redis_lock.InterruptableThread   A

Complexity

Total Complexity 5

Size/Duplication

Total Lines 36
Duplicated Lines 0 %
Metric Value
dl 0
loc 36
rs 10
wmc 5

1 Method

Rating   Name   Duplication   Size   Complexity  
A src.redis_lock.reset_all() 0 5 1
1
import threading
2
from logging import getLogger
3
from os import urandom
4
from hashlib import sha1
5
import weakref
6
7
from redis import StrictRedis
8
from redis.exceptions import NoScriptError
9
10
__version__ = "3.0.0"
11
12
logger = getLogger(__name__)
13
14
# Check if the id match. If not, return an error code.
15
UNLOCK_SCRIPT = b"""
16
    if redis.call("get", KEYS[1]) ~= ARGV[1] then
17
        return 1
18
    else
19
        redis.call("del", KEYS[2])
20
        redis.call("lpush", KEYS[2], 1)
21
        redis.call("expire", KEYS[2], 1)
22
        redis.call("del", KEYS[1])
23
        return 0
24
    end
25
"""
26
UNLOCK_SCRIPT_HASH = sha1(UNLOCK_SCRIPT).hexdigest()
27
28
# Covers both cases when key doesn't exist and doesn't equal to lock's id
29
EXTEND_SCRIPT = b"""
30
    if redis.call("get", KEYS[1]) ~= ARGV[2] then
31
        return 1
32
    elseif redis.call("ttl", KEYS[1]) < 0 then
33
        return 2
34
    else
35
        redis.call("expire", KEYS[1], ARGV[1])
36
        return 0
37
    end
38
"""
39
EXTEND_SCRIPT_HASH = sha1(EXTEND_SCRIPT).hexdigest()
40
41
RESET_SCRIPT = b"""
42
    redis.call('del', KEYS[2])
43
    redis.call('lpush', KEYS[2], 1)
44
    redis.call('expire', KEYS[2], 1)
45
    return redis.call('del', KEYS[1])
46
"""
47
48
RESET_SCRIPT_HASH = sha1(RESET_SCRIPT).hexdigest()
49
50
RESET_ALL_SCRIPT = b"""
51
    local locks = redis.call('keys', 'lock:*')
52
    local signal
53
    for _, lock in pairs(locks) do
54
        signal = 'lock-signal:' .. string.sub(lock, 6)
55
        redis.call('del', signal)
56
        redis.call('lpush', signal, 1)
57
        redis.call('expire', signal, 1)
58
        redis.call('del', lock)
59
    end
60
    return #locks
61
"""
62
63
RESET_ALL_SCRIPT_HASH = sha1(RESET_ALL_SCRIPT).hexdigest()
64
65
66
class AlreadyAcquired(RuntimeError):
    """Raised by ``Lock.acquire`` when this Lock instance already holds the lock."""


class NotAcquired(RuntimeError):
    """Raised when releasing or extending a lock this instance does not hold,
    or whose key has already expired."""


class AlreadyStarted(RuntimeError):
    """Raised when a lock renewal thread is started while one is already registered."""


class TimeoutNotUsable(RuntimeError):
    """Raised by ``Lock.acquire`` when a timeout is combined with ``blocking=False``."""


class InvalidTimeout(RuntimeError):
    """Raised by ``Lock.acquire`` when the timeout is zero or negative."""


class TimeoutTooLarge(RuntimeError):
    """Raised by ``Lock.acquire`` when the timeout exceeds the lock's expire time."""


class NotExpirable(RuntimeError):
    """Raised by ``Lock.extend`` when the lock key has no expiration time set."""
92
93
94
# Flat lookup table read by ``_eval_script``. Each script occupies three
# consecutive slots: (SHA1 hex digest, Lua source, human-readable name).
# A script id ``i`` therefore addresses the digest at ``i``, the source at
# ``i + 1`` and the name at ``i + 2``.
SCRIPTS = (
    UNLOCK_SCRIPT_HASH, UNLOCK_SCRIPT, 'UNLOCK_SCRIPT',
    EXTEND_SCRIPT_HASH, EXTEND_SCRIPT, 'EXTEND_SCRIPT',
    RESET_SCRIPT_HASH, RESET_SCRIPT, 'RESET_SCRIPT',
    RESET_ALL_SCRIPT_HASH, RESET_ALL_SCRIPT, 'RESET_ALL_SCRIPT',
)

# Script ids are the offsets of each triple's first slot: 0, 3, 6 and 9.
UNLOCK, EXTEND, RESET, RESET_ALL = range(0, len(SCRIPTS), 3)
104
105
106
def _eval_script(redis, script_id, *keys, **kwargs):
    """Run one of the bundled Lua scripts, preferring the server-side cache.

    Tries ``EVALSHA`` with the script's SHA1 digest first. If the server
    raises ``NoScriptError`` (script not cached), falls back to a regular
    ``EVAL`` with the full script source.

    :param redis:
        Client exposing ``evalsha`` and ``eval``.
    :param script_id:
        Offset of the script in ``SCRIPTS`` (``UNLOCK``, ``EXTEND``, ``RESET``
        or ``RESET_ALL``).
    :param keys:
        Redis keys passed to the script as ``KEYS``.
    :param args:
        Keyword-only; values passed to the script as ``ARGV`` (any iterable,
        defaults to none).
    :returns: Whatever the script returns.
    :raises TypeError: If any keyword other than ``args`` is given.
    """
    # Normalise to a tuple so a list works too (tuple + list raises TypeError).
    args = tuple(kwargs.pop('args', ()))
    if kwargs:
        raise TypeError("Unexpected keyword arguments %s" % kwargs.keys())
    call_args = keys + args
    try:
        return redis.evalsha(SCRIPTS[script_id], len(keys), *call_args)
    except NoScriptError:
        # Logger.warn is a deprecated alias that was removed in Python 3.13.
        logger.warning("%s not cached.", SCRIPTS[script_id + 2])
        return redis.eval(SCRIPTS[script_id + 1], len(keys), *call_args)
118
119
120
class Lock(object):
    """
    A Lock context manager implemented via redis ``SET NX``/``BLPOP``.

    The lock itself is the key ``lock:<name>`` holding this instance's id.
    Waiters block with ``BLPOP`` on the list ``lock-signal:<name>``, which the
    unlock/reset scripts push to when the lock is freed.
    """

    def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False):
        """
        :param redis_client:
            An instance of :class:`~StrictRedis`.
        :param name:
            The name the lock should have. The redis key used is
            ``'lock:' + name``.
        :param expire:
            The lock expiry time in seconds. If left at the default (None)
            the lock will not expire.
        :param id:
            The ID (redis value) the lock should have. A random value is
            generated when left at the default.

            Note that if you specify this then the lock is marked as "held". Acquires
            won't be possible.
        :param auto_renewal:
            If set to True, Lock will automatically renew the lock so that it
            doesn't expire for as long as the lock is held (acquire() called
            or running in a context manager).

            Implementation note: Renewal will happen using a daemon thread with
            an interval of expire*2/3. If wishing to use a different renewal
            time, subclass Lock, call super().__init__() then set
            self._lock_renewal_interval to your desired interval.
        :raises ValueError:
            If ``auto_renewal`` is set without an ``expire``.
        """
        assert isinstance(redis_client, StrictRedis)
        if auto_renewal and expire is None:
            raise ValueError("Expire may not be None when auto_renewal is set")

        self._client = redis_client
        self._expire = expire if expire is None else int(expire)
        self._id = urandom(16) if id is None else id
        # A caller-supplied id means "this lock is already held by that id".
        self._held = id is not None
        self._name = 'lock:' + name
        self._signal = 'lock-signal:' + name
        self._lock_renewal_interval = (float(expire) * 2 / 3
                                       if auto_renewal
                                       else None)
        self._lock_renewal_thread = None
        self._lock_renewal_stop = None

    def reset(self):
        """
        Forcibly deletes the lock. Use this with care.

        The lock key is deleted regardless of which id holds it, and the
        signal list is pushed so that a waiter blocked in :meth:`acquire`
        wakes up.

        :returns: The result of the script's final ``DEL`` on the lock key
            (the number of keys removed).
        """
        return _eval_script(self._client, RESET, self._name, self._signal)

    @property
    def id(self):
        """The id (redis value) this instance stores under the lock key."""
        return self._id

    def get_owner_id(self):
        """Return the value currently stored under the lock key (via ``GET``)."""
        return self._client.get(self._name)

    def acquire(self, blocking=True, timeout=None):
        """
        :param blocking:
            Boolean value specifying whether lock should be blocking or not.
        :param timeout:
            An integer value specifying the maximum number of seconds to block.
        :returns: True if the lock was acquired, False otherwise.
        :raises AlreadyAcquired: If this instance already holds the lock.
        :raises TimeoutNotUsable: If ``timeout`` is given with ``blocking=False``.
        :raises InvalidTimeout: If ``timeout`` is zero or negative.
        :raises TimeoutTooLarge: If ``timeout`` exceeds the lock's ``expire``.
        """
        logger.debug("Getting %r ...", self._name)

        if self._held:
            raise AlreadyAcquired("Already acquired from this Lock instance.")

        if not blocking and timeout is not None:
            raise TimeoutNotUsable("Timeout cannot be used if blocking=False")

        timeout = timeout if timeout is None else int(timeout)
        if timeout is not None and timeout <= 0:
            raise InvalidTimeout("Timeout (%d) cannot be less than or equal to 0" % timeout)

        if timeout and self._expire and timeout > self._expire:
            raise TimeoutTooLarge("Timeout (%d) cannot be greater than expire (%d)" % (timeout, self._expire))

        busy = True
        # Without a caller timeout, wait at most `expire` seconds per BLPOP
        # (the holder's key would have expired by then); with neither, 0 makes
        # BLPOP block indefinitely.
        blpop_timeout = timeout or self._expire or 0
        timed_out = False
        while busy:
            busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
            if busy:
                if timed_out:
                    return False
                elif blocking:
                    # Truthy only when BLPOP returned nothing (timed out) AND
                    # the caller gave a timeout; we then retry SET once more
                    # before giving up.
                    timed_out = not self._client.blpop(self._signal, blpop_timeout) and timeout
                else:
                    logger.debug("Failed to get %r.", self._name)
                    return False

        logger.debug("Got lock for %r.", self._name)
        self._held = True
        if self._lock_renewal_interval is not None:
            self._start_lock_renewer()
        return True

    def extend(self, expire=None):
        """Extends expiration time of the lock.

        :param expire:
            New expiration time. If ``None`` - `expire` provided during
            lock initialization will be taken.
        :raises TypeError: If no expire is available from either source.
        :raises NotAcquired: If the key is missing or holds another id.
        :raises NotExpirable: If the key has no expiration time set.
        """
        if expire is None:
            if self._expire is not None:
                expire = self._expire
            else:
                raise TypeError(
                    "To extend a lock 'expire' must be provided as an "
                    "argument to extend() method or at initialization time."
                )
        error = _eval_script(self._client, EXTEND, self._name, args=(expire, self._id))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error == 2:
            raise NotExpirable("Lock %s has no assigned expiration time" %
                               self._name)
        elif error:
            raise RuntimeError("Unsupported error code %s from EXTEND script" % error)

    @staticmethod
    def _lock_renewer(lockref, interval, stop):
        """
        Renew the lock key in redis every `interval` seconds until the `stop`
        event is set or the Lock object behind the weak reference `lockref`
        has been garbage collected.
        """
        log = getLogger("%s.lock_refresher" % __name__)
        while not stop.wait(timeout=interval):
            log.debug("Refreshing lock")
            lock = lockref()
            if lock is None:
                log.debug("The lock no longer exists, "
                          "stopping lock refreshing")
                break
            lock.extend(expire=lock._expire)
            # Drop the strong reference before sleeping again so the Lock
            # can still be collected while this thread waits.
            del lock
        log.debug("Exit requested, stopping lock refreshing")

    def _start_lock_renewer(self):
        """
        Starts the lock refresher thread.

        :raises AlreadyStarted: If a renewal thread is already registered.
        """
        if self._lock_renewal_thread is not None:
            raise AlreadyStarted("Lock refresh thread already started")

        logger.debug(
            "Starting thread to refresh lock every %s seconds",
            self._lock_renewal_interval
        )
        self._lock_renewal_stop = threading.Event()
        self._lock_renewal_thread = threading.Thread(
            group=None,
            target=self._lock_renewer,
            # A weak reference keeps the thread from pinning this Lock alive.
            kwargs={'lockref': weakref.ref(self),
                    'interval': self._lock_renewal_interval,
                    'stop': self._lock_renewal_stop}
        )
        # Thread.setDaemon() is deprecated since Python 3.10; the ``daemon``
        # attribute is the portable spelling.
        self._lock_renewal_thread.daemon = True
        self._lock_renewal_thread.start()

    def _stop_lock_renewer(self):
        """
        Stop the lock renewer.

        This signals the renewal thread and waits for its exit. The thread
        reference is always cleared, even if the thread had already died.
        """
        if self._lock_renewal_thread is None:
            return
        if self._lock_renewal_thread.is_alive():
            logger.debug("Signalling the lock refresher to stop")
            self._lock_renewal_stop.set()
            self._lock_renewal_thread.join()
            logger.debug("Lock refresher has stopped")
        # Clear the reference unconditionally. A renewer that exited on its
        # own (e.g. extend() raised) would otherwise leave a stale thread
        # behind, making the next acquire() fail with AlreadyStarted.
        self._lock_renewal_thread = None

    def __enter__(self):
        acquired = self.acquire(blocking=True)
        assert acquired, "Lock wasn't acquired, but blocking=True"
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        # Returns None, so exceptions raised inside the block propagate.
        self.release()

    def release(self):
        """Releases the lock, that was acquired with the same object.

        .. note::

            If you want to release a lock that you acquired in a different place you have two choices:

            * Use ``Lock("name", id=id_from_other_place).release()``
            * Use ``Lock("name").reset()``

        :raises NotAcquired: If this instance does not hold the lock, or the
            key is missing / holds another id.
        """
        if not self._held:
            raise NotAcquired("This Lock instance didn't acquire the lock.")
        if self._lock_renewal_thread is not None:
            self._stop_lock_renewer()
        logger.debug("Releasing %r.", self._name)
        error = _eval_script(self._client, UNLOCK, self._name, self._signal, args=(self._id,))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error:
            raise RuntimeError("Unsupported error code %s from UNLOCK script." % error)
        else:
            self._held = False
328
329
330
def reset_all(redis_client):
    """
    Forcibly deletes all locks (every key matching ``lock:*``), e.g. leftovers
    from a crashed process. Use this with care.

    Each deleted lock also has its ``lock-signal:`` list pushed, so that a
    waiter blocked on it wakes up and retries.

    :param redis_client:
        Client used to run the reset script.
    :returns: The number of lock keys that were found and deleted.
    """
    return _eval_script(redis_client, RESET_ALL)
335