Completed
Pull Request — master (#858)
by Eddie
02:03
created

zipline.utils.TradingDayOfWeekRule   A

Complexity

Total Complexity 9

Size/Duplication

Total Lines 46
Duplicated Lines 0 %
Metric Value
dl 0
loc 46
rs 10
wmc 9

4 Methods

Rating   Name   Duplication   Size   Complexity  
B should_trigger() 0 17 5
A date_func() 0 3 1
A __init__() 0 9 2
A calculate_start_and_end() 0 13 1
1
#
2
# Copyright 2014 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
from abc import ABCMeta, abstractmethod
16
from collections import namedtuple
17
import six
18
19
import datetime
20
import pandas as pd
21
import pytz
22
23
from .context_tricks import nop_context
24
25
26
__all__ = [
27
    'EventManager',
28
    'Event',
29
    'EventRule',
30
    'StatelessRule',
31
    'ComposedRule',
32
    'Always',
33
    'Never',
34
    'AfterOpen',
35
    'BeforeClose',
36
    'NotHalfDay',
37
    'NthTradingDayOfWeek',
38
    'NDaysBeforeLastTradingDayOfWeek',
39
    'NthTradingDayOfMonth',
40
    'NDaysBeforeLastTradingDayOfMonth',
41
    'StatefulRule',
42
    'OncePerDay',
43
44
    # Factory API
45
    'DateRuleFactory',
46
    'TimeRuleFactory',
47
    'date_rules',
48
    'time_rules',
49
    'make_eventrule',
50
]
51
52
53
MAX_MONTH_RANGE = 26
54
MAX_WEEK_RANGE = 5
55
56
57
def naive_to_utc(ts):
    """
    Attach the UTC timezone to a tz-naive timestamp holding a UTC time.
    """
    # Going through a plain datetime deliberately discards the nanoseconds
    # field; warn=False silences the warning about that loss.
    as_datetime = ts.to_pydatetime(warn=False)
    return pd.Timestamp(as_datetime, tz='UTC')
64
65
66
def ensure_utc(time, tz='UTC'):
    """
    Return ``time`` labelled as UTC.

    A tz-naive ``time`` is first tagged with the zone named by ``tz``; the
    result then has its tzinfo replaced with UTC.
    """
    # NOTE(review): both steps use ``replace``, which swaps the tzinfo
    # without shifting the clock fields, so ``tz`` never changes the
    # returned wall-clock value -- confirm this is intended.
    if time.tzinfo:
        tagged = time
    else:
        tagged = time.replace(tzinfo=pytz.timezone(tz))
    return tagged.replace(tzinfo=pytz.utc)
73
74
75
def _coerce_datetime(maybe_dt):
76
    if isinstance(maybe_dt, datetime.datetime):
77
        return maybe_dt
78
    elif isinstance(maybe_dt, datetime.date):
79
        return datetime.datetime(
80
            year=maybe_dt.year,
81
            month=maybe_dt.month,
82
            day=maybe_dt.day,
83
            tzinfo=pytz.utc,
84
        )
85
    elif isinstance(maybe_dt, (tuple, list)) and len(maybe_dt) == 3:
86
        year, month, day = maybe_dt
87
        return datetime.datetime(
88
            year=year,
89
            month=month,
90
            day=day,
91
            tzinfo=pytz.utc,
92
        )
93
    else:
94
        raise TypeError('Cannot coerce %s into a datetime.datetime'
95
                        % type(maybe_dt).__name__)
96
97
98
def _out_of_range_error(a, b=None, var='offset'):
99
    start = 0
100
    if b is None:
101
        end = a - 1
102
    else:
103
        start = a
104
        end = b - 1
105
    return ValueError(
106
        '{var} must be in between {start} and {end} inclusive'.format(
107
            var=var,
108
            start=start,
109
            end=end,
110
        )
111
    )
112
113
114
def _td_check(td):
115
    seconds = td.total_seconds()
116
117
    # 23400 seconds is 6 hours and 30 minutes.
118
    if 60 <= seconds <= 23400:
119
        return td
120
    else:
121
        raise ValueError('offset must be in between 1 minute and 6 hours and'
122
                         ' 30 minutes inclusive')
123
124
125
def _build_offset(offset, kwargs, default):
126
    """
127
    Builds the offset argument for event rules.
128
    """
129
    if offset is None:
130
        if not kwargs:
131
            return default  # use the default.
132
        else:
133
            return _td_check(datetime.timedelta(**kwargs))
134
    elif kwargs:
135
        raise ValueError('Cannot pass kwargs and an offset')
136
    elif isinstance(offset, datetime.timedelta):
137
        return _td_check(offset)
138
    else:
139
        raise TypeError("Must pass 'hours' and/or 'minutes' as keywords")
140
141
142
def _build_date(date, kwargs):
143
    """
144
    Builds the date argument for event rules.
145
    """
146
    if date is None:
147
        if not kwargs:
148
            raise ValueError('Must pass a date or kwargs')
149
        else:
150
            return datetime.date(**kwargs)
151
152
    elif kwargs:
153
        raise ValueError('Cannot pass kwargs and a date')
154
    else:
155
        return date
156
157
158
def _build_time(time, kwargs):
159
    """
160
    Builds the time argument for event rules.
161
    """
162
    tz = kwargs.pop('tz', 'UTC')
163
    if time:
164
        if kwargs:
165
            raise ValueError('Cannot pass kwargs and a time')
166
        else:
167
            return ensure_utc(time, tz)
168
    elif not kwargs:
169
        raise ValueError('Must pass a time or kwargs')
170
    else:
171
        return datetime.time(**kwargs)
172
173
174
class EventManager(object):
    """Holds Event objects and dispatches handle_data to each of them.

    Parameters
    ----------
    create_context : (BarData) -> context manager, optional
        Callback producing a context manager that wraps every round of
        handle_data calls. It is called with the current data. When
        omitted, ``nop_context`` is used.
    """
    def __init__(self, create_context=None):
        self._events = []

        if create_context is None:
            # No wrapper requested: ignore the arguments and hand back the
            # shared nop_context.
            def create_context(*_):
                return nop_context

        self._create_context = create_context

    def add_event(self, event, prepend=False):
        """
        Register an event, at the front of the list when ``prepend`` is
        true and at the back otherwise.
        """
        position = 0 if prepend else len(self._events)
        self._events.insert(position, event)

    def handle_data(self, context, data, dt):
        """
        Forward handle_data to every registered event, in list order,
        inside the context built from ``data``.
        """
        with self._create_context(data):
            for registered in self._events:
                # The environment is read from the context for each event.
                registered.handle_data(
                    context, data, dt, context.trading_environment,
                )
211
212
213
class Event(namedtuple('Event', ['rule', 'callback'])):
    """
    An event is a pairing of an EventRule and a callable that will be invoked
    with the current algorithm context and data only when the rule is
    triggered.
    """
    def __new__(cls, rule=None, callback=None):
        """
        Build the pair, substituting a do-nothing callable when no callback
        (or a falsy one) is supplied.
        """
        callback = callback or (lambda *args, **kwargs: None)
        # Name the class explicitly: ``super(cls, cls)`` is resolved against
        # the runtime class, so from a subclass it would find this very
        # ``__new__`` again and recurse without end.
        return super(Event, cls).__new__(cls, rule=rule, callback=callback)

    def handle_data(self, context, data, dt, env):
        """
        Calls the callable only when the rule is triggered.

        The rule is consulted with ``dt`` and ``env``; the callback receives
        ``context`` and ``data``.
        """
        if self.rule.should_trigger(dt, env):
            self.callback(context, data)
229
230
231
class EventRule(six.with_metaclass(ABCMeta)):
    """
    Abstract base for rules; subclasses must provide should_trigger.
    """
    @abstractmethod
    def should_trigger(self, dt, env):
        """
        Report whether the rule fires for ``dt`` given its current state.
        Implementations are expected to be pure and must NOT mutate any
        state on the object.
        """
        raise NotImplementedError('should_trigger')
239
240
241
class StatelessRule(EventRule):
    """
    A rule meant to carry no state: it is reentrant and gives the same
    answer whenever it is asked about the same datetime.
    That purity is what allows such rules to be composed into new rules.
    """
    def and_(self, rule):
        """
        Combine this rule with ``rule`` so the result triggers only when
        both do. Evaluation short circuits like the ``and`` operator:
        ``rule`` is not consulted when this rule does not trigger.
        """
        return ComposedRule(self, rule, ComposedRule.lazy_and)

    # Makes ``first & second`` equivalent to ``first.and_(second)``.
    __and__ = and_
255
256
257
class ComposedRule(StatelessRule):
    """
    A rule built from two StatelessRules and a composing function.

    The composer is always called lazily: it receives the two bound
    should_trigger methods followed by the datetime and the environment,
    and decides which of them to call. This is what lets the & operator
    keep the usual short circuit behaviour (see lazy_and).
    """
    def __init__(self, first, second, composer):
        if (not isinstance(first, StatelessRule) or
                not isinstance(second, StatelessRule)):
            raise ValueError('Only two StatelessRules can be composed')

        self.first = first
        self.second = second
        self.composer = composer

    def should_trigger(self, dt, env):
        """
        Delegate to the composer, handing it both rules' should_trigger
        methods along with ``dt`` and ``env``.
        """
        first_check = self.first.should_trigger
        second_check = self.second.should_trigger
        return self.composer(first_check, second_check, dt, env)

    @staticmethod
    def lazy_and(first_should_trigger, second_should_trigger, dt, env):
        """
        Short circuiting and of the two rules: second_should_trigger is
        NOT called when the first returns a falsy value, and that value
        is returned instead.
        """
        first_result = first_should_trigger(dt, env)
        if not first_result:
            return first_result
        return second_should_trigger(dt, env)
297
298
299
class Always(StatelessRule):
    """
    A rule that triggers for every datetime.
    """
    @staticmethod
    def always_trigger(dt, env):
        """
        should_trigger implementation that ignores its arguments and
        returns True.
        """
        return True

    # Expose the static implementation under the EventRule interface name.
    should_trigger = always_trigger
310
311
312
class Never(StatelessRule):
    """
    A rule that triggers for no datetime.
    """
    @staticmethod
    def never_trigger(dt, env):
        """
        should_trigger implementation that ignores its arguments and
        returns False.
        """
        return False

    # Expose the static implementation under the EventRule interface name.
    should_trigger = never_trigger
323
324
325
class AfterOpen(StatelessRule):
    """
    A rule that triggers for some offset after the market opens.
    Example that triggers after 30 minutes of the market opening:

    >>> AfterOpen(minutes=30)
    """
    def __init__(self, offset=None, **kwargs):
        # With no arguments the offset is one minute, i.e. the first minute.
        default_offset = datetime.timedelta(minutes=1)
        self.offset = _build_offset(offset, kwargs, default_offset)

        # Bounds of the current blocked window; filled in lazily on the
        # first should_trigger call.
        self._next_period_start = None
        self._next_period_end = None

        self._one_minute = datetime.timedelta(minutes=1)

    def calculate_dates(self, dt, env):
        """
        Cache the open of ``dt``'s session as the period start and
        open + offset - 1 minute as the period end.
        """
        market_open = env.get_open_and_close(dt)[0]
        self._next_period_start = market_open
        self._next_period_end = market_open + self.offset - self._one_minute

    def should_trigger(self, dt, env):
        """
        Return False while ``dt`` lies in [period start, period end) and
        True otherwise. Once ``dt`` reaches the period end the cached
        bounds are moved to the next trading day.
        """
        # NOTE(review): this updates cached attributes even though the
        # class derives from StatelessRule.
        if self._next_period_start is None:
            self.calculate_dates(dt, env)

        if self._next_period_start <= dt < self._next_period_end:
            # haven't made it past the offset yet
            return False

        if dt >= self._next_period_end:
            self.calculate_dates(env.next_trading_day(dt), env)
        return True
362
363
364
class BeforeClose(StatelessRule):
    """
    A rule that triggers for some offset time before the market closes.
    Example that triggers for the last 30 minutes every day:

    >>> BeforeClose(minutes=30)
    """
    def __init__(self, offset=None, **kwargs):
        # With no arguments the offset is one minute, i.e. the last minute.
        default_offset = datetime.timedelta(minutes=1)
        self.offset = _build_offset(offset, kwargs, default_offset)

        # Bounds of the current trigger window; filled in lazily on the
        # first should_trigger call.
        self._next_period_start = None
        self._next_period_end = None

        self._one_minute = datetime.timedelta(minutes=1)

    def calculate_dates(self, dt, env):
        """
        Cache the close of ``dt``'s session as the period end and
        close - offset - 1 minute as the period start.
        """
        market_close = env.get_open_and_close(dt)[1]
        self._next_period_end = market_close
        self._next_period_start = \
            market_close - self.offset - self._one_minute

    def should_trigger(self, dt, env):
        """
        Return False while ``dt`` is at or before the period start and
        True afterwards. Once ``dt`` is past the period end the cached
        bounds are moved to the next trading day.
        """
        # NOTE(review): this updates cached attributes even though the
        # class derives from StatelessRule.
        if self._next_period_start is None:
            self.calculate_dates(dt, env)

        if dt <= self._next_period_start:
            return False

        if dt > self._next_period_end:
            self.calculate_dates(env.next_trading_day(dt), env)
        return True
400
401
402
class NotHalfDay(StatelessRule):
    """
    A rule that triggers only on dates absent from the environment's
    early closes.
    """
    def should_trigger(self, dt, env):
        is_half_day = dt.date() in env.early_closes
        return not is_half_day
408
409
410
class TradingDayOfWeekRule(six.with_metaclass(ABCMeta, StatelessRule)):
    """
    Base for rules that trigger on one trading day, located ``n`` trading
    days away from an anchor date that subclasses supply via date_func.
    """
    def __init__(self, n=0):
        if not 0 <= abs(n) < MAX_WEEK_RANGE:
            raise _out_of_range_error(MAX_WEEK_RANGE)

        self.td_delta = n

        # Open, close and midnight of the next matching day; filled in
        # lazily by should_trigger.
        self.next_date_start = self.next_date_end = None
        self.next_midnight_timestamp = None

    @abstractmethod
    def date_func(self, dt, env):
        """
        Return the anchor date for ``dt``. Must be provided by subclasses.
        """
        raise NotImplementedError

    def calculate_start_and_end(self, dt, env):
        """
        Find the matching day (anchor plus td_delta trading days) and cache
        its open, its close and its UTC midnight timestamp.
        """
        anchor = self.date_func(dt, env)
        matching_day = _coerce_datetime(
            env.add_trading_days(self.td_delta, anchor)
        )

        self.next_date_start, self.next_date_end = \
            env.get_open_and_close(matching_day)
        self.next_midnight_timestamp = \
            pd.Timestamp(matching_day.date(), tz='UTC')

    def should_trigger(self, dt, env):
        """
        Return True when ``dt`` falls between the cached open and close
        (both inclusive) or equals the cached midnight timestamp.
        """
        if self.next_date_start is None:
            # first call: nothing has been cached yet.
            self.calculate_start_and_end(dt, env)

        if dt > self.next_date_end:
            # the cached day is in the past, so look one week further on.
            one_week_later = dt + datetime.timedelta(days=7)
            self.calculate_start_and_end(one_week_later, env)

        within_session = self.next_date_start <= dt <= self.next_date_end
        return bool(within_session or dt == self.next_midnight_timestamp)
456
457
458
class NthTradingDayOfWeek(TradingDayOfWeekRule):
    """
    A rule that triggers on the nth trading day of the week.
    This is zero-indexed, n=0 is the first trading day of the week.
    """
    @staticmethod
    def get_first_trading_day_of_week(dt, env):
        """
        Step back through previous trading days for as long as the weekday
        number keeps decreasing, and return the date of the last day
        reached before it stops decreasing.
        """
        first = dt
        earlier = env.previous_trading_day(first)
        while earlier.date().weekday() < first.date().weekday():
            first, earlier = earlier, env.previous_trading_day(earlier)
        return first.date()

    date_func = get_first_trading_day_of_week
473
474
475
class NDaysBeforeLastTradingDayOfWeek(TradingDayOfWeekRule):
    """
    A rule that triggers n days before the last trading day of the week.
    """
    def __init__(self, n):
        # The base class counts forward from its anchor, so "n days before"
        # is stored as a negative delta.
        super(NDaysBeforeLastTradingDayOfWeek, self).__init__(-n)

    @staticmethod
    def get_last_trading_day_of_week(dt, env):
        """
        Step forward through next trading days for as long as the weekday
        number keeps increasing, and return the date of the last day
        reached before it stops increasing (the week border).
        """
        last = dt
        later = env.next_trading_day(last)
        while later.date().weekday() > last.date().weekday():
            last, later = later, env.next_trading_day(later)
        return last.date()

    date_func = get_last_trading_day_of_week
494
495
496
class NthTradingDayOfMonth(StatelessRule):
    """
    A rule that triggers on the nth trading day of the month.
    This is zero-indexed, n=0 is the first trading day of the month.
    """
    def __init__(self, n=0):
        if not 0 <= n < MAX_MONTH_RANGE:
            raise _out_of_range_error(MAX_MONTH_RANGE)
        self.td_delta = n
        # Month number and matching date most recently computed.
        self.month = None
        self.day = None

    def should_trigger(self, dt, env):
        target = self.get_nth_trading_day_of_month(dt, env)
        return target == dt.date()

    def get_nth_trading_day_of_month(self, dt, env):
        """
        Return the date of the nth trading day of ``dt``'s month, reusing
        the cached value while the month number is unchanged.
        """
        # NOTE(review): the cache is keyed on the month number alone, not
        # on the year -- confirm callers always advance through time.
        if self.month == dt.month:
            return self.day

        first_day = self.get_first_trading_day_of_month(dt, env)
        if self.td_delta:
            self.day = env.add_trading_days(self.td_delta, first_day).date()
        else:
            self.day = first_day

        return self.day

    def get_first_trading_day_of_month(self, dt, env):
        """
        Record ``dt``'s month and return the date of the first trading day
        on or after the 1st of that month (also stored as first_day).
        """
        self.month = dt.month

        month_start = dt.replace(day=1)
        if env.is_trading_day(month_start):
            first_session = month_start
        else:
            first_session = env.next_trading_day(month_start)

        self.first_day = first_session.date()
        return self.first_day
533
534
535
class NDaysBeforeLastTradingDayOfMonth(StatelessRule):
    """
    A rule that triggers n days before the last trading day of the month.
    """
    def __init__(self, n=0):
        if not 0 <= n < MAX_MONTH_RANGE:
            raise _out_of_range_error(MAX_MONTH_RANGE)
        # Stored negated so it can be added to the last trading day.
        self.td_delta = -n
        # Month number and matching date most recently computed.
        self.month = None
        self.day = None

    def should_trigger(self, dt, env):
        target = self.get_nth_to_last_trading_day_of_month(dt, env)
        return target == dt.date()

    def get_nth_to_last_trading_day_of_month(self, dt, env):
        """
        Return the date n trading days before the last trading day of
        ``dt``'s month, reusing the cached value while the month number is
        unchanged.
        """
        # NOTE(review): the cache is keyed on the month number alone, not
        # on the year -- confirm callers always advance through time.
        if self.month == dt.month:
            return self.day

        last_day = self.get_last_trading_day_of_month(dt, env)
        if self.td_delta:
            self.day = env.add_trading_days(self.td_delta, last_day).date()
        else:
            self.day = last_day

        return self.day

    def get_last_trading_day_of_month(self, dt, env):
        """
        Record ``dt``'s month and return the date of the trading day that
        precedes the 1st of the following month (also stored as last_day).
        """
        self.month = dt.month

        if dt.month == 12:
            # December rolls over into January of the next year.
            year, month = dt.year + 1, 1
        else:
            year, month = dt.year, dt.month + 1

        next_month_start = dt.replace(year=year, month=month, day=1)
        self.last_day = env.previous_trading_day(next_month_start).date()
        return self.last_day
580
581
582
# Stateful rules
583
584
585
class StatefulRule(EventRule):
    """
    A rule that holds state, so the same datetime may produce different
    results depending on what the rule has seen before.
    StatefulRules wrap another rule and act as state transformers on it.
    """
    def __init__(self, rule=None):
        # A missing (or falsy) rule is replaced by one that always fires.
        self.rule = rule if rule else Always()

    def new_should_trigger(self, callable_):
        """
        Install ``callable_`` as this instance's should_trigger.
        """
        self.should_trigger = callable_
600
601
602
class OncePerDay(StatefulRule):
    """
    Wraps a rule so that it triggers at most once per one-day window.

    A window opens at the first datetime seen and lasts one day; the first
    datetime at or after its end opens the next window.
    """
    def __init__(self, rule=None):
        self.triggered = False

        # Start of the current window and the instant it expires.
        self.date = None
        self.next_date = None

        super(OncePerDay, self).__init__(rule)

    def should_trigger(self, dt, env):
        """
        Return True the first time the wrapped rule triggers inside the
        current window, and False on every other call.
        """
        if self.date is None or dt >= self.next_date:
            # initialize or reset for new date
            self.triggered = False
            self.date = dt

            # record the timestamp for the next day, so that we can use it
            # to know if we've moved to the next day
            self.next_date = dt + pd.Timedelta(1, unit="d")

        if not self.triggered and self.rule.should_trigger(dt, env):
            self.triggered = True
            return True

        # Answer with an explicit boolean rather than falling off the end
        # and returning None.
        return False
624
625
626
# Factory API
627
628
class DateRuleFactory(object):
    """
    Factory for date rules; ``days_offset`` is forwarded as each rule's
    ``n``.
    """
    every_day = Always

    @staticmethod
    def month_start(days_offset=0):
        return NthTradingDayOfMonth(days_offset)

    @staticmethod
    def month_end(days_offset=0):
        return NDaysBeforeLastTradingDayOfMonth(days_offset)

    @staticmethod
    def week_start(days_offset=0):
        return NthTradingDayOfWeek(days_offset)

    @staticmethod
    def week_end(days_offset=0):
        return NDaysBeforeLastTradingDayOfWeek(days_offset)
646
647
648
class TimeRuleFactory(object):
    """
    Factory for time rules, exposing the rule classes under
    market-oriented names.
    """
    market_open = AfterOpen
    market_close = BeforeClose
651
652
653
# Convenience aliases.
654
date_rules = DateRuleFactory
655
time_rules = TimeRuleFactory
656
657
658
def make_eventrule(date_rule, time_rule, half_days=True):
    """
    Build an event rule from the factory api: the date and time rules are
    and-ed together and wrapped in OncePerDay. When ``half_days`` is false
    a NotHalfDay rule is and-ed in as well.
    """
    combined = date_rule & time_rule
    if not half_days:
        combined = combined & NotHalfDay()

    return OncePerDay(rule=combined)
668