Completed
Pull Request — master (#38)
by
unknown
47s
created

src.redis_lock.Lock._delete_signal()   A

Complexity

Conditions 1

Size

Total Lines 2

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 2
rs 10
1
import threading
2
from logging import getLogger
3
from os import urandom
4
from hashlib import sha1
5
6
from redis import StrictRedis
7
from redis.exceptions import NoScriptError
8
9
# Package version string.
__version__ = '3.0.0'

# Module-wide logger, named after this module.
logger = getLogger(name=__name__)
12
13
# Release the lock only if it is still owned by the caller.
#   KEYS[1] = lock key, KEYS[2] = signal list, ARGV[1] = lock id.
# Returns 1 (error) when the key's value does not equal the id; otherwise
# resets the signal list to a single wake-up token, deletes the lock key and
# returns 0.
UNLOCK_SCRIPT = b"""
    if redis.call("get", KEYS[1]) ~= ARGV[1] then
        return 1
    else
        redis.call("del", KEYS[2])
        redis.call("lpush", KEYS[2], 1)
        redis.call("del", KEYS[1])
        return 0
    end
"""
UNLOCK_SCRIPT_HASH = sha1(UNLOCK_SCRIPT).hexdigest()
25
26
# Extend the lock's TTL only if the caller still owns it.
#   KEYS[1] = lock key, ARGV[1] = new expire value, ARGV[2] = lock id.
# Returns 1 when the key is missing or holds a different id (the single GET
# comparison covers both cases), 2 when the key has no TTL set, 0 on success.
EXTEND_SCRIPT = b"""
    if redis.call("get", KEYS[1]) ~= ARGV[2] then
        return 1
    elseif redis.call("ttl", KEYS[1]) < 0 then
        return 2
    else
        redis.call("expire", KEYS[1], ARGV[1])
        return 0
    end
"""
EXTEND_SCRIPT_HASH = sha1(EXTEND_SCRIPT).hexdigest()
38
39
# Unconditionally release a lock, regardless of its owner.
#   KEYS[1] = lock key, KEYS[2] = signal list.
# Resets the signal list to one wake-up token, then returns the DEL result
# for the lock key (1 if it existed, 0 otherwise).
RESET_SCRIPT = b"""
    redis.call('del', KEYS[2])
    redis.call('lpush', KEYS[2], 1)
    return redis.call('del', KEYS[1])
"""

RESET_SCRIPT_HASH = sha1(RESET_SCRIPT).hexdigest()
46
47
# Forcibly release every ``lock:*`` key. For each lock, the matching signal
# list (``lock-signal:`` + the name after the 5-character ``lock:`` prefix)
# is reset to one wake-up token before the lock key is deleted.
# Returns the number of lock keys found.
# NOTE(review): uses KEYS inside a script, which scans the whole keyspace and
# touches keys not declared via KEYS[]; confirm this is acceptable for the
# deployment (e.g. not a Redis Cluster / very large keyspace).
RESET_ALL_SCRIPT = b"""
    local locks = redis.call('keys', 'lock:*')
    local signal
    for _, lock in pairs(locks) do
        signal = 'lock-signal:' .. string.sub(lock, 6)
        redis.call('del', signal)
        redis.call('lpush', signal, 1)
        redis.call('del', lock)
    end
    return #locks
"""

RESET_ALL_SCRIPT_HASH = sha1(RESET_ALL_SCRIPT).hexdigest()
60
61
# Delete every ``lock-signal:*`` list and return how many were found.
# NOTE(review): like RESET_ALL_SCRIPT this relies on KEYS inside a script;
# confirm that is acceptable for the deployment.
DELETE_ALL_SIGNAL_KEYS_SCRIPT = b"""
    local signals = redis.call('keys', 'lock-signal:*')
    for _, signal in pairs(signals) do
        redis.call('del', signal)
    end
    return #signals
"""

DELETE_ALL_SIGNAL_KEYS_SCRIPT_HASH = sha1(DELETE_ALL_SIGNAL_KEYS_SCRIPT).hexdigest()
71
72
73
class AlreadyAcquired(RuntimeError):
    """Raised by ``Lock.acquire()`` when this Lock instance already holds the lock."""
75
76
77
class NotAcquired(RuntimeError):
    """Raised by ``Lock.extend()``/``Lock.release()`` when the lock is not held (or has expired)."""
79
80
81
class AlreadyStarted(RuntimeError):
    """Raised when a Lock's renewal thread is started while one already exists."""
83
84
85
class TimeoutNotUsable(RuntimeError):
    """Raised by ``Lock.acquire()`` when a timeout is given together with ``blocking=False``."""
87
88
89
class InvalidTimeout(RuntimeError):
    """Raised by ``Lock.acquire()`` when the timeout is less than or equal to 0."""
91
92
93
class TimeoutTooLarge(RuntimeError):
    """Raised by ``Lock.acquire()`` when the timeout exceeds the lock's expire time."""
95
96
97
class NotExpirable(RuntimeError):
    """Raised by ``Lock.extend()`` when the lock key has no expiration time set."""
99
100
101
# Flat table of (hash, source, name) triples. ``_eval_script`` receives the
# offset of a triple and reads SCRIPTS[i] (SHA1 for EVALSHA), SCRIPTS[i + 1]
# (Lua source for the EVAL fallback) and SCRIPTS[i + 2] (name for logging).
SCRIPTS = (
    UNLOCK_SCRIPT_HASH, UNLOCK_SCRIPT, 'UNLOCK_SCRIPT',
    EXTEND_SCRIPT_HASH, EXTEND_SCRIPT, 'EXTEND_SCRIPT',
    RESET_SCRIPT_HASH, RESET_SCRIPT, 'RESET_SCRIPT',
    RESET_ALL_SCRIPT_HASH, RESET_ALL_SCRIPT, 'RESET_ALL_SCRIPT',
    DELETE_ALL_SIGNAL_KEYS_SCRIPT_HASH, DELETE_ALL_SIGNAL_KEYS_SCRIPT,
    'DELETE_ALL_SIGNAL_KEYS_SCRIPT',
)

# Offset of the first element of each triple above: 0, 3, 6, 9, 12.
(UNLOCK, EXTEND, RESET, RESET_ALL,
 DELETE_ALL_SIGNAL_KEYS) = range(0, len(SCRIPTS), 3)
114
115
116
def _eval_script(redis, script_id, *keys, **kwargs):
    """Run one of the registered Lua scripts.

    Tries to call ``EVALSHA`` with the script's hash and then, if the server
    reports it is not cached (``NoScriptError``), calls regular ``EVAL`` with
    the full script source.

    :param redis: client used for ``evalsha``/``eval``.
    :param script_id: offset of the script's triple in ``SCRIPTS``
        (``UNLOCK``, ``EXTEND``, ``RESET``, ``RESET_ALL`` or
        ``DELETE_ALL_SIGNAL_KEYS``).
    :param keys: key names passed to the script as ``KEYS``.
    :param args: (keyword only) values passed to the script as ``ARGV``.
    :returns: the script's return value.
    :raises TypeError: on any keyword argument other than ``args``.
    """
    args = kwargs.pop('args', ())
    if kwargs:
        # sorted() gives a readable list instead of Python 3's dict_keys(...).
        raise TypeError("Unexpected keyword arguments %s" % sorted(kwargs))
    # Normalise both to tuples so a list passed as ``args`` still concatenates.
    script_args = tuple(keys) + tuple(args)
    try:
        return redis.evalsha(SCRIPTS[script_id], len(keys), *script_args)
    except NoScriptError:
        # Logger.warn is a deprecated alias of Logger.warning.
        logger.warning("%s not cached.", SCRIPTS[script_id + 2])
        return redis.eval(SCRIPTS[script_id + 1], len(keys), *script_args)
128
129
130
class Lock(object):
    """
    A Lock context manager implemented via redis ``SET NX``/``BLPOP``.

    The lock is the key ``lock:<name>`` holding this instance's id. Waiters
    block with ``BLPOP`` on ``lock-signal:<name>``; the unlock and reset
    scripts push a token onto that list when the lock is released.
    """

    def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False):
        """
        :param redis_client:
            An instance of :class:`~StrictRedis`.
        :param name:
            The name (redis key) the lock should have.
        :param expire:
            The lock expiry time in seconds. If left at the default (None)
            the lock will not expire.
        :param id:
            The ID (redis value) the lock should have. A random value is
            generated when left at the default.

            Note that if you specify this then the lock is marked as "held". Acquires
            won't be possible.
        :param auto_renewal:
            If set to True, Lock will automatically renew the lock so that it
            doesn't expire for as long as the lock is held (acquire() called
            or running in a context manager).

            Implementation note: Renewal will happen using a daemon thread with
            an interval of expire*2/3. If wishing to use a different renewal
            time, subclass Lock, call super().__init__() then set
            self._lock_renewal_interval to your desired interval.
        :raises ValueError:
            If ``auto_renewal`` is set while ``expire`` is None.
        """
        assert isinstance(redis_client, StrictRedis)
        if auto_renewal and expire is None:
            raise ValueError("Expire may not be None when auto_renewal is set")

        self._client = redis_client
        self._expire = expire if expire is None else int(expire)
        self._id = urandom(16) if id is None else id
        # A caller-supplied id marks the lock as already held by this instance.
        self._held = id is not None
        self._name = 'lock:' + name
        self._signal = 'lock-signal:' + name
        # Renew at 2/3 of the TTL actually sent to Redis (the int-normalised
        # expire). Float division keeps a small expire (e.g. 1) from giving a
        # 0-second interval, i.e. a busy renewal loop, under Python 2.
        self._lock_renewal_interval = (self._expire * 2.0 / 3
                                       if auto_renewal else None)
        self._lock_renewal_thread = None

    def reset(self):
        """
        Forcibly deletes the lock. Use this with care.

        Runs the RESET script (delete the lock key and push a wake-up token
        onto the signal list), then deletes the signal list - presumably so
        the token does not linger when nobody was waiting for it.
        """
        _eval_script(self._client, RESET, self._name, self._signal)
        self._delete_signal()

    @property
    def id(self):
        """The value stored in the lock key while this instance holds it."""
        return self._id

    def get_owner_id(self):
        """Return the id currently stored in the lock key (None if unlocked)."""
        return self._client.get(self._name)

    def acquire(self, blocking=True, timeout=None):
        """
        :param blocking:
            Boolean value specifying whether lock should be blocking or not.
        :param timeout:
            An integer value specifying the maximum number of seconds to block.
        :returns:
            True if the lock was acquired, False otherwise.
        :raises AlreadyAcquired: if this instance already holds the lock.
        :raises TimeoutNotUsable: if ``timeout`` is given with ``blocking=False``.
        :raises InvalidTimeout: if ``timeout`` is <= 0.
        :raises TimeoutTooLarge: if ``timeout`` is greater than ``expire``.
        """
        logger.debug("Getting %r ...", self._name)

        if self._held:
            raise AlreadyAcquired("Already acquired from this Lock instance.")

        if not blocking and timeout is not None:
            raise TimeoutNotUsable("Timeout cannot be used if blocking=False")

        timeout = timeout if timeout is None else int(timeout)
        if timeout is not None and timeout <= 0:
            raise InvalidTimeout("Timeout (%d) cannot be less than or equal to 0" % timeout)

        if timeout and self._expire and timeout > self._expire:
            raise TimeoutTooLarge("Timeout (%d) cannot be greater than expire (%d)" % (timeout, self._expire))

        busy = True
        # 0 makes BLPOP block indefinitely.
        blpop_timeout = timeout or self._expire or 0
        timed_out = False
        while busy:
            busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
            if busy:
                if timed_out:
                    # One last SET attempt after the wait timed out also failed.
                    return False
                elif blocking:
                    # Only a caller-supplied timeout can end the wait; without
                    # one, ``timed_out`` stays falsy and we keep retrying.
                    timed_out = not self._client.blpop(self._signal, blpop_timeout) and timeout
                else:
                    logger.debug("Failed to get %r.", self._name)
                    return False

        logger.debug("Got lock for %r.", self._name)
        self._held = True
        if self._lock_renewal_interval is not None:
            self._start_lock_renewer()
        return True

    def extend(self, expire=None):
        """Extends expiration time of the lock.

        :param expire:
            New expiration time. If ``None`` - `expire` provided during
            lock initialization will be taken.
        :raises TypeError: if no expire is available from either source.
        :raises NotAcquired: if the lock key is missing or holds another id.
        :raises NotExpirable: if the lock key has no TTL.
        """
        if expire is None:
            if self._expire is not None:
                expire = self._expire
            else:
                raise TypeError(
                    "To extend a lock 'expire' must be provided as an "
                    "argument to extend() method or at initialization time."
                )
        error = _eval_script(self._client, EXTEND, self._name, args=(expire, self._id))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error == 2:
            raise NotExpirable("Lock %s has no assigned expiration time" %
                               self._name)
        elif error:
            raise RuntimeError("Unsupported error code %s from EXTEND script" % error)

    def _lock_renewer(self, interval):
        """
        Renew the lock key in redis every `interval` seconds for as long
        as `self._lock_renewal_thread.should_exit` is False.
        """
        log = getLogger("%s.lock_refresher" % __name__)
        while not self._lock_renewal_thread.wait_for_exit_request(timeout=interval):
            log.debug("Refreshing lock")
            self.extend(expire=self._expire)
        log.debug("Exit requested, stopping lock refreshing")

    def _start_lock_renewer(self):
        """
        Starts the lock refresher thread.

        :raises AlreadyStarted: if a renewal thread is already registered.
        """
        if self._lock_renewal_thread is not None:
            raise AlreadyStarted("Lock refresh thread already started")

        logger.debug(
            "Starting thread to refresh lock every %s seconds",
            self._lock_renewal_interval
        )
        self._lock_renewal_thread = InterruptableThread(
            group=None,
            target=self._lock_renewer,
            kwargs={'interval': self._lock_renewal_interval}
        )
        # ``daemon`` attribute instead of the deprecated setDaemon().
        self._lock_renewal_thread.daemon = True
        self._lock_renewal_thread.start()

    def _stop_lock_renewer(self):
        """
        Stop the lock renewer.

        This signals the renewal thread and waits for its exit. The thread
        reference is always cleared, even when the thread had already died
        (e.g. because extend() raised inside it); otherwise the next acquire()
        would fail with AlreadyStarted.
        """
        thread = self._lock_renewal_thread
        if thread is None:
            return
        if thread.is_alive():
            logger.debug("Signalling the lock refresher to stop")
            thread.request_exit()
            thread.join()
            logger.debug("Lock refresher has stopped")
        self._lock_renewal_thread = None

    def __enter__(self):
        acquired = self.acquire(blocking=True)
        assert acquired, "Lock wasn't acquired, but blocking=True"
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        self.release()

    def release(self):
        """Releases the lock, that was acquired with the same object.

        .. note::

            If you want to release a lock that you acquired in a different place you have two choices:

            * Use ``Lock("name", id=id_from_other_place).release()``
            * Use ``Lock("name").reset()``

        :raises NotAcquired: if this instance does not hold the lock, or the
            lock key is missing / holds another id.
        """
        if not self._held:
            raise NotAcquired("This Lock instance didn't acquire the lock.")
        if self._lock_renewal_thread is not None:
            self._stop_lock_renewer()
        logger.debug("Releasing %r.", self._name)
        error = _eval_script(self._client, UNLOCK, self._name, self._signal, args=(self._id,))
        if error == 1:
            raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
        elif error:
            raise RuntimeError("Unsupported error code %s from UNLOCK script." % error)
        else:
            self._held = False
            self._delete_signal()

    def _delete_signal(self):
        """Delete this lock's signal list key."""
        self._client.delete(self._signal)
331
332
333
class InterruptableThread(threading.Thread):
    """
    A Python thread that can be requested to stop by calling request_exit()
    on it.

    Code running inside this thread should periodically check the
    `should_exit` property (or use wait_for_exit_request) on the thread
    object and stop further processing once it returns True.
    """
    def __init__(self, *args, **kwargs):
        # Created before Thread.__init__ so the flag exists from the start.
        self._should_exit = threading.Event()
        super(InterruptableThread, self).__init__(*args, **kwargs)

    def request_exit(self):
        """
        Signal the thread that it should stop performing more work and exit.
        """
        self._should_exit.set()

    @property
    def should_exit(self):
        """True once request_exit() has been called."""
        # is_set() rather than the deprecated camelCase alias isSet().
        return self._should_exit.is_set()

    def wait_for_exit_request(self, timeout=None):
        """
        Wait until the thread has been signalled to exit.

        If timeout is specified (as a float of seconds to wait) then wait
        up to this many seconds before returning the value of `should_exit`.
        """
        should_exit = self._should_exit.wait(timeout)
        if should_exit is None:
            # Python 2.6 compatibility which doesn't return self.__flag when
            # calling Event.wait()
            should_exit = self.should_exit
        return should_exit
369
370
371
def reset_all(redis_client):
    """
    Forcibly release every lock (e.g. ones left behind after a crash), then
    delete all remaining ``lock-signal:*`` keys. Use this with care.
    """
    for script_id in (RESET_ALL, DELETE_ALL_SIGNAL_KEYS):
        _eval_script(redis_client, script_id)
377