Completed
Pull Request — master (#858)
by Eddie
02:09
created

tests.calculate_results()   C

Complexity

Conditions 7

Size

Total Lines 73

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 7
dl 0
loc 73
rs 5.5062

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

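Commonly applied refactorings include Extract Method: move a cohesive part of the body into a new, well-named method and call it from the original one.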

1
#
2
# Copyright 2013 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
from __future__ import division
17
18
from datetime import (
19
    datetime,
20
    timedelta,
21
)
22
import logging
23
24
from testfixtures import TempDirectory
25
import unittest
26
import nose.tools as nt
27
import pytz
28
29
import pandas as pd
30
import numpy as np
31
from six.moves import range, zip
32
33
from zipline.data.us_equity_pricing import (
34
    SQLiteAdjustmentWriter,
35
    SQLiteAdjustmentReader,
36
)
37
import zipline.utils.factory as factory
38
import zipline.finance.performance as perf
39
from zipline.finance.transaction import create_transaction
40
import zipline.utils.math_utils as zp_math
41
42
from zipline.finance.blotter import Order
43
from zipline.finance.commission import PerShare, PerTrade, PerDollar
44
from zipline.finance.trading import TradingEnvironment
45
from zipline.pipeline.loaders.synthetic import NullAdjustmentReader
46
from zipline.utils.factory import create_simulation_parameters
47
from zipline.utils.serialization_utils import (
48
    loads_with_persistent_ids, dumps_with_persistent_ids
49
)
50
import zipline.protocol as zp
51
from zipline.protocol import Event
52
from zipline.utils.test_utils import create_data_portal_from_trade_history
53
54
logger = logging.getLogger('Test Perf Tracking')

# Reusable time offsets for building synthetic trade histories.
onesec = timedelta(seconds=1)
oneday = timedelta(days=1)
# 6.5 hours -- presumably the length of a regular US equity session.
tradingday = timedelta(hours=6, minutes=30)

# nose.tools changed this assertion's name in Python 3; when only the older
# ``assert_items_equal`` exists, expose it under the newer name as well.
if getattr(nt, 'assert_count_equal', None) is None:
    nt.assert_count_equal = nt.assert_items_equal
63
64
65
def check_perf_period(pp,
                      pt,
                      gross_leverage,
                      net_leverage,
                      long_exposure,
                      longs_count,
                      short_exposure,
                      shorts_count,
                      data_portal):
    """
    Assert that a performance period reports the expected exposure figures.

    ``pt.stats()`` provides the position stats. ``pp.stats`` combines them
    with ``pt.positions`` and ``data_portal`` into period stats, and
    ``pp.to_dict`` renders both into a dict. Each expected value is compared
    against the same-named key of that dict with
    ``np.testing.assert_allclose`` at ``rtol=1e-3``. The first mismatch
    raises ``AssertionError``.
    """
    position_stats = pt.stats()
    period_stats = pp.stats(pt.positions, position_stats, data_portal)
    rendered = pp.to_dict(position_stats, period_stats, pt)

    expectations = (
        ('gross_leverage', gross_leverage),
        ('net_leverage', net_leverage),
        ('long_exposure', long_exposure),
        ('longs_count', longs_count),
        ('short_exposure', short_exposure),
        ('shorts_count', shorts_count),
    )
    for field, wanted in expectations:
        np.testing.assert_allclose(wanted, rendered[field], rtol=1e-3)
90
91
92
def check_account(account,
                  settled_cash,
                  equity_with_loan,
                  total_positions_value,
                  regt_equity,
                  available_funds,
                  excess_liquidity,
                  cushion,
                  leverage,
                  net_leverage,
                  net_liquidation):
    """
    Assert that an account mapping holds the expected balances and ratios.

    Each expected value is compared with the same-named key of ``account``
    via ``np.testing.assert_allclose`` at ``rtol=1e-3``, in the order the
    parameters are declared. The first mismatch raises ``AssertionError``.
    """
    expected_by_key = [
        ('settled_cash', settled_cash),
        ('equity_with_loan', equity_with_loan),
        ('total_positions_value', total_positions_value),
        ('regt_equity', regt_equity),
        ('available_funds', available_funds),
        ('excess_liquidity', excess_liquidity),
        ('cushion', cushion),
        ('leverage', leverage),
        ('net_leverage', net_leverage),
        ('net_liquidation', net_liquidation),
    ]
    for key, expected in expected_by_key:
        np.testing.assert_allclose(expected, account[key], rtol=1e-3)
125
126
127
def create_txn(sid, dt, price, amount):
    """
    Build a synthetic transaction for ``amount`` shares of ``sid`` at
    ``price``, stamped with ``dt``.

    A placeholder ``Order`` (with no id) for the same sid, dt and amount is
    created so that ``create_transaction`` has an order to attach. The
    result is meant to be handed to the performance tracker before the
    trade events for that dt are processed.
    """
    placeholder_order = Order(dt, sid, amount, id=None)
    txn = create_transaction(sid, dt, placeholder_order, price, amount)
    return txn
134
135
136
def benchmark_events_in_range(sim_params, env):
    """
    Build one benchmark ``Event`` per benchmark return dated inside the
    simulation period.

    Parameters
    ----------
    sim_params : object
        Supplies ``period_start`` and ``period_end``; only their calendar
        dates are compared, and both ends are inclusive.
    env : TradingEnvironment
        Supplies ``benchmark_returns``, presumably a pandas Series of returns
        indexed by timestamp (TODO confirm).

    Returns
    -------
    list of Event
        Events in the index order of ``env.benchmark_returns``.
    """
    start_date = sim_params.period_start.date()
    end_date = sim_params.period_end.date()
    returns = env.benchmark_returns

    # ``Series.iteritems`` was deprecated in pandas 1.5 and removed in 2.0.
    # Pairing the index with the values yields the same (dt, value) pairs on
    # every pandas version.
    return [
        Event({'dt': dt,
               'returns': ret,
               'type': zp.DATASOURCE_TYPE.BENCHMARK,
               # We explicitly rely on the behavior that benchmarks sort
               # before any other events.
               'source_id': '1Abenchmarks'})
        for dt, ret in zip(returns.index, returns)
        if start_date <= dt.date() <= end_date
    ]
148
149
150
def _apply_day_events(perf_tracker, date, txns, commissions, splits):
    """
    Feed everything dated ``date`` into ``perf_tracker``.

    Order: every transaction whose ``dt`` equals ``date`` (in input order),
    then the commissions in ``commissions[date]``, then the splits in
    ``splits[date]``.
    """
    # Match transactions by equality (as the original per-day filter did)
    # rather than by hashing them into a dict keyed on dt.
    for txn in txns:
        if txn.dt == date:
            perf_tracker.process_transaction(txn)

    # Look the date up explicitly instead of wrapping the processing in
    # ``try/except KeyError``. That form also swallowed KeyErrors raised
    # *inside* the tracker, silently skipping the rest of the day's events.
    for comm in commissions.get(date, ()):
        perf_tracker.process_commission(comm)

    if date in splits:
        perf_tracker.handle_splits(splits[date])


def _close_day(perf_tracker, date):
    """
    Close ``date`` on ``perf_tracker`` and return the daily message with the
    account snapshot for that date stored under ``'account'``.
    """
    msg = perf_tracker.handle_market_close_daily(date)
    msg['account'] = perf_tracker.get_account(date)
    return msg


def calculate_results(sim_params,
                      env,
                      tempdir,
                      benchmark_events,
                      trade_events,
                      adjustment_reader,
                      splits=None,
                      txns=None,
                      commissions=None):
    """
    Run the given events through a stripped down version of the loop in
    AlgorithmSimulator.transform.

    IMPORTANT NOTE FOR TEST WRITERS/READERS:

    Events are not merged into one sorted stream. For each day in
    ``sim_params.trading_days`` this helper processes, in order:

        1. the transactions in ``txns`` whose ``dt`` equals the day,
        2. the commissions in ``commissions[day]``,
        3. the splits in ``splits[day]``,
        4. the market close, whose message is collected with the day's
           account snapshot attached under ``'account'``.

    Dividends are not passed to this helper. Tests supply them through
    ``adjustment_reader``, which is installed on the data portal given to
    the tracker.

    This ordering exists mostly to accommodate legacy tests that made
    assumptions about how events would be sorted. Tests that use this
    helper should not be considered useful guarantees of the behavior of
    AlgorithmSimulator on a stream containing the same events unless the
    subgroups have been explicitly re-sorted in this way.

    Parameters
    ----------
    sim_params, env, tempdir, trade_events
        Used to build the data portal and the ``PerformanceTracker``.
    benchmark_events
        Accepted for call-site compatibility; not used by this helper.
    adjustment_reader
        Installed on the data portal. A ``NullAdjustmentReader`` is used
        when it is falsy.
    splits : dict, optional
        Maps a trading day to the value passed to ``handle_splits``.
    txns : list, optional
        Transactions, each exposing a ``dt`` attribute.
    commissions : dict, optional
        Maps a trading day to an iterable of commission objects.

    Returns
    -------
    list
        One market-close message per trading day.
    """
    txns = txns or []
    splits = splits or {}
    commissions = commissions or {}
    adjustment_reader = adjustment_reader or NullAdjustmentReader()

    data_portal = create_data_portal_from_trade_history(
        env,
        tempdir,
        sim_params,
        trade_events,
    )
    data_portal._adjustment_reader = adjustment_reader

    perf_tracker = perf.PerformanceTracker(sim_params, env, data_portal)

    results = []
    for date in sim_params.trading_days:
        _apply_day_events(perf_tracker, date, txns, commissions, splits)
        results.append(_close_day(perf_tracker, date))
    return results
223
224
225
def check_perf_tracker_serialization(perf_tracker):
    """
    Round-trip ``perf_tracker`` through the persistent-id serialization
    helpers and assert that its scalar attributes survive unchanged.

    The copy is loaded with ``env=perf_tracker.env``. On a mismatch, the
    attribute name is passed as the assertion message.
    """
    serialized = dumps_with_persistent_ids(perf_tracker)
    round_tripped = loads_with_persistent_ids(serialized,
                                              env=perf_tracker.env)

    for attr in ('emission_rate', 'txn_count', 'market_open', 'last_close',
                 'period_start', 'day_count', 'capital_base', 'market_close',
                 'saved_dt', 'period_end', 'total_days'):
        nt.assert_equal(getattr(round_tripped, attr),
                        getattr(perf_tracker, attr),
                        attr)
245
246
247
def setup_env_data(env, sim_params, sids):
    """
    Register every sid in ``sids`` as an equity on ``env``.

    Each equity trades over the whole simulation: ``start_date`` is the
    first entry of ``sim_params.trading_days`` and ``end_date`` the last.
    The mapping is written with a single ``env.write_data`` call.
    """
    equities = {
        sid: {
            "start_date": sim_params.trading_days[0],
            "end_date": sim_params.trading_days[-1],
        }
        for sid in sids
    }
    env.write_data(equities_data=equities)
256
257
258
class TestSplitPerformance(unittest.TestCase):
    """Performance tracking of a long position that goes through a split."""

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        cls.sim_params = create_simulation_parameters(num_days=2,
                                                      capital_base=10e3)

        setup_env_data(cls.env, cls.sim_params, [1])

        cls.benchmark_events = benchmark_events_in_range(cls.sim_params,
                                                         cls.env)
        cls.tempdir = TempDirectory()

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def test_split_long_position(self):
        events = factory.create_trade_history(
            1,
            # TODO: Should we provide adjusted prices in the tests, or provide
            # raw prices and adjust via DataPortal?
            [20, 60],
            [100, 100],
            oneday,
            self.sim_params,
            env=self.env
        )

        # set up a long position in sid 1
        # 100 shares at $20 apiece = $2000 position
        txns = [create_txn(events[0].sid, events[0].dt, 20, 100)]

        # set up a split with ratio 3 occurring at the start of the second
        # day.
        splits = {
            events[1].dt: [(1, 3)]
        }

        results = calculate_results(self.sim_params, self.env,
                                    self.tempdir,
                                    self.benchmark_events,
                                    {1: events},
                                    NullAdjustmentReader(),
                                    txns=txns, splits=splits)

        self.assertEqual(2, len(results))

        # After the split we should hold 33 shares at $60 apiece; the
        # leftover fractional share comes back as $20 of cash.
        latest_positions = results[1]['daily_perf']['positions']
        self.assertEqual(1, len(latest_positions))

        # check the last position to make sure it's been updated
        position = latest_positions[0]

        self.assertEqual(1, position['sid'])
        self.assertEqual(33, position['amount'])
        self.assertEqual(60, position['cost_basis'])
        self.assertEqual(60, position['last_sale_price'])

        # since we started with $10000, and we spent $2000 on the
        # position, but then got $20 back, we should have $8020
        # (or close to it) in cash.

        # we won't get exactly 8020 because sometimes a split is
        # denoted as a ratio like 0.3333, and we lose some digits
        # of precision.  thus, make sure we're pretty close.
        daily_perf = results[1]['daily_perf']

        self.assertTrue(
            zp_math.tolerant_equals(8020,
                                    daily_perf['ending_cash'], 1),
            "ending_cash was {0}".format(daily_perf['ending_cash']))

        # Validate that the account attributes were updated.
        account = results[1]['account']
        # this is a long only portfolio that is only partially invested
        # so net and gross leverage are equal.
        check_account(account,
                      settled_cash=8020,
                      equity_with_loan=10000,
                      total_positions_value=1980,
                      regt_equity=8020,
                      available_funds=8020,
                      excess_liquidity=8020,
                      cushion=0.802,
                      leverage=0.198,
                      net_leverage=0.198,
                      net_liquidation=10000)
        # Exact-valued fields that ``check_account`` does not cover.
        self.assertEqual(float('inf'), account['day_trades_remaining'])
        self.assertEqual(float('inf'), account['regt_margin'])
        self.assertEqual(0, account['maintenance_margin_requirement'])
        self.assertEqual(float('inf'), account['buying_power'])
        self.assertEqual(0, account['initial_margin_requirement'])
        self.assertEqual(0, account['accrued_interest'])

        # The split leaves the position's value unchanged (the price triples
        # while the share count is divided by three), so pnl and returns
        # should be 0.0 on every day.
        for i, result in enumerate(results):
            for perf_kind in ('daily_perf', 'cumulative_perf'):
                perf_result = result[perf_kind]
                self.assertEqual(0.0, perf_result['pnl'],
                                 "day %s %s pnl %s instead of 0.0" %
                                 (i, perf_kind, perf_result['pnl']))
                self.assertEqual(0.0, perf_result['returns'],
                                 "day %s %s returns %s instead of 0.0" %
                                 (i, perf_kind, perf_result['returns']))
367
368
369
class TestCommissionEvents(unittest.TestCase):
    """Commission models and commission processing by the tracker."""

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        cls.sim_params = create_simulation_parameters(num_days=5,
                                                      capital_base=10e3)
        setup_env_data(cls.env, cls.sim_params, [0, 1, 133])

        cls.benchmark_events = benchmark_events_in_range(cls.sim_params,
                                                         cls.env)
        cls.tempdir = TempDirectory()

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def _flat_trade_history(self):
        """Return five days of sid 1 trades at a constant $10, volume 100."""
        return factory.create_trade_history(
            1,
            [10, 10, 10, 10, 10],
            [100, 100, 100, 100, 100],
            oneday,
            self.sim_params,
            env=self.env
        )

    @staticmethod
    def _cash_adjustment_at(dt):
        """Return a commissions mapping holding a single $300 commission on
        sid 1 at ``dt``, keyed the way ``calculate_results`` expects."""
        return {dt: [factory.create_commission(1, 300.0, dt)]}

    def _run(self, trade_events, **kwargs):
        """Run ``calculate_results`` for sid 1 with no adjustments; extra
        keyword arguments (``txns``, ``commissions``) are forwarded."""
        return calculate_results(self.sim_params,
                                 self.env,
                                 self.tempdir,
                                 self.benchmark_events,
                                 {1: trade_events},
                                 NullAdjustmentReader(),
                                 **kwargs)

    def test_commission_event(self):
        trade_events = self._flat_trade_history()

        # Test commission models and validate result
        # Expected commission amounts:
        # PerShare commission:  1.00, 1.00, 1.50 = $3.50
        # PerTrade commission:  5.00, 5.00, 5.00 = $15.00
        # PerDollar commission: 1.50, 3.00, 4.50 = $9.00
        # Total commission = $3.50 + $15.00 + $9.00 = $27.50

        # Create 3 transactions:  50, 100, 150 shares traded @ $20
        first_trade = trade_events[0]
        transactions = [create_txn(first_trade.sid, first_trade.dt, 20, i)
                        for i in [50, 100, 150]]

        # Create commission models and validate that they produce the
        # expected commissions.
        models = [PerShare(cost=0.01, min_trade_cost=1.00),
                  PerTrade(cost=5.00),
                  PerDollar(cost=0.0015)]
        expected_results = [3.50, 15.0, 9.0]

        for model, expected in zip(models, expected_results):
            total_commission = sum(model.calculate(trade)[1]
                                   for trade in transactions)
            # Rates such as 0.0015 are not exact binary fractions, so the
            # summed commission is compared with a tolerance.
            self.assertAlmostEqual(total_commission, expected)

        # Verify that commission events are handled correctly by
        # PerformanceTracker: a $300 commission plus a 1-share purchase at
        # $20 on the first day.
        cash_adj_dt = first_trade.dt
        commissions = self._cash_adjustment_at(cash_adj_dt)

        # Insert a purchase order.
        txns = [create_txn(1, cash_adj_dt, 20, 1)]
        results = self._run(trade_events, txns=txns, commissions=commissions)

        # Validate that we lost 320 dollars from our cash pool.
        self.assertEqual(results[-1]['cumulative_perf']['ending_cash'],
                         9680, "Should have lost 320 from cash pool.")
        # Validate that the cost basis of our position changed.
        self.assertEqual(results[-1]['daily_perf']['positions']
                         [0]['cost_basis'], 320.0)
        # Validate that the account attributes were updated.
        account = results[1]['account']
        self.assertEqual(float('inf'), account['day_trades_remaining'])
        np.testing.assert_allclose(0.001, account['leverage'], rtol=1e-3,
                                   atol=1e-4)
        np.testing.assert_allclose(9680, account['regt_equity'], rtol=1e-3)
        self.assertEqual(float('inf'), account['regt_margin'])
        np.testing.assert_allclose(9680, account['available_funds'],
                                   rtol=1e-3)
        self.assertEqual(0, account['maintenance_margin_requirement'])
        np.testing.assert_allclose(9690,
                                   account['equity_with_loan'], rtol=1e-3)
        self.assertEqual(float('inf'), account['buying_power'])
        self.assertEqual(0, account['initial_margin_requirement'])
        np.testing.assert_allclose(9680, account['excess_liquidity'],
                                   rtol=1e-3)
        np.testing.assert_allclose(9680, account['settled_cash'],
                                   rtol=1e-3)
        np.testing.assert_allclose(9690, account['net_liquidation'],
                                   rtol=1e-3)
        np.testing.assert_allclose(0.999, account['cushion'], rtol=1e-3)
        np.testing.assert_allclose(10, account['total_positions_value'],
                                   rtol=1e-3)
        self.assertEqual(0, account['accrued_interest'])

    def test_commission_zero_position(self):
        """
        Ensure a commission charged while the position is flat (zero shares)
        causes no division-by-zero errors.
        """
        events = self._flat_trade_history()

        # Buy and sell the same sid so that we have a zero position by the
        # time of events[3].
        txns = [
            create_txn(events[0].sid, events[0].dt, 20, 1),
            create_txn(events[1].sid, events[1].dt, 20, -1),
        ]

        # Add a cash adjustment at the time of event[3].
        commissions = self._cash_adjustment_at(events[3].dt)

        results = self._run(events, txns=txns, commissions=commissions)
        # Validate that we lost 300 dollars from our cash pool.
        self.assertEqual(results[-1]['cumulative_perf']['ending_cash'],
                         9700)

    def test_commission_no_position(self):
        """
        Ensure a commission on a sid that was never held raises no
        position-not-found or sid-not-found errors.
        """
        events = self._flat_trade_history()

        # Add a cash adjustment at the time of event[3].
        commissions = self._cash_adjustment_at(events[3].dt)

        results = self._run(events, commissions=commissions)
        # Validate that we lost 300 dollars from our cash pool.
        self.assertEqual(results[-1]['cumulative_perf']['ending_cash'],
                         9700)
536
537
538
class MockDailyBarSpotReader(object):
    """Stand-in daily bar reader that quotes the same spot price for
    every sid, day and column."""

    def spot_price(self, sid, day, colname):
        """Ignore the arguments and report a flat price of 100.0."""
        flat_price = 100.0
        return flat_price
542
543
544
class TestDividendPerformance(unittest.TestCase):
545
546
    @classmethod
    def setUpClass(cls):
        """Build a six-day simulation with sids 1 and 2 registered, plus
        the benchmark events for that period."""
        cls.env = TradingEnvironment()
        cls.sim_params = create_simulation_parameters(
            num_days=6,
            capital_base=10e3,
        )
        setup_env_data(cls.env, cls.sim_params, [1, 2])
        cls.benchmark_events = benchmark_events_in_range(
            cls.sim_params,
            cls.env,
        )
556
557
    @classmethod
    def tearDownClass(cls):
        """Drop the shared trading environment built in setUpClass."""
        delattr(cls, 'env')
560
561
    def setUp(self):
        """Give each test its own scratch directory."""
        scratch = TempDirectory()
        self.tempdir = scratch
563
564
    def tearDown(self):
        """Clean up the per-test scratch directory created in setUp."""
        scratch = self.tempdir
        scratch.cleanup()
566
567
    def test_market_hours_calculations(self):
        # US/Eastern switched to daylight saving time on Sunday, March 14,
        # 2010. Starting from a Friday 14:31 UTC timestamp, the next trading
        # dt should therefore fall in the 13:00 UTC hour.
        friday_open = datetime(2010, 3, 12, 14, 31, tzinfo=pytz.utc)
        next_dt = factory.get_next_trading_dt(
            friday_open,
            timedelta(days=1),
            self.env,
        )
        self.assertEqual(next_dt.hour, 13)
576
577
    def test_long_position_receives_dividend(self):
        # Six days of flat $10 trading in sid 1.
        events = factory.create_trade_history(
            1,
            [10] * 6,
            [100] * 6,
            oneday,
            self.sim_params,
            env=self.env
        )

        dbpath = self.tempdir.getpath('adjustments.sqlite')
        writer = SQLiteAdjustmentWriter(dbpath, self.env.trading_days,
                                        MockDailyBarSpotReader())

        # No splits or mergers. The explicit empty arrays keep the column
        # dtypes correct on the empty frame.
        no_ratio_adjustments = pd.DataFrame(
            {
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )

        def as_dt64(event):
            return np.array([event.dt], dtype='datetime64[ns]')

        # A 10.00 cash dividend on sid 1: declared on day 0, ex/record on
        # day 1, paid on day 2.
        dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'amount': np.array([10.00], dtype=np.float64),
            'declared_date': as_dt64(events[0]),
            'ex_date': as_dt64(events[1]),
            'record_date': as_dt64(events[1]),
            'pay_date': as_dt64(events[2]),
        })
        writer.write(no_ratio_adjustments, no_ratio_adjustments, dividends)
        adjustment_reader = SQLiteAdjustmentReader(dbpath)

        # Buy 100 shares on day 0, i.e. before the ex_date.
        first_trade = events[0]
        txns = [create_txn(first_trade.sid, first_trade.dt, 10.0, 100)]
        results = calculate_results(
            self.sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            {1: events},
            adjustment_reader,
            txns=txns,
        )

        self.assertEqual(len(results), 6)

        def series(perf_kind, field):
            return [event[perf_kind][field] for event in results]

        self.assertEqual(series('cumulative_perf', 'returns'),
                         [0.0, 0.0, 0.1, 0.1, 0.1, 0.1])
        self.assertEqual(series('daily_perf', 'returns'),
                         [0.0, 0.0, 0.10, 0.0, 0.0, 0.0])
        self.assertEqual(series('daily_perf', 'capital_used'),
                         [-1000, 0, 1000, 0, 0, 0])
        self.assertEqual(series('cumulative_perf', 'capital_used'),
                         [-1000, -1000, 0, 0, 0, 0])
        self.assertEqual(series('cumulative_perf', 'ending_cash'),
                         [9000, 9000, 10000, 10000, 10000, 10000])
641
642
    def test_long_position_receives_stock_dividend(self):
        # Six days of flat $10 trading in both sid 1 and sid 2.
        events = {
            sid: factory.create_trade_history(
                sid,
                [10] * 6,
                [100] * 6,
                oneday,
                self.sim_params,
                env=self.env
            )
            for sid in (1, 2)
        }

        dbpath = self.tempdir.getpath('adjustments.sqlite')
        writer = SQLiteAdjustmentWriter(dbpath, self.env.trading_days,
                                        MockDailyBarSpotReader())

        # No splits or mergers. The explicit empty arrays keep the column
        # dtypes correct on the empty frame.
        no_ratio_adjustments = pd.DataFrame(
            {
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        # No cash dividends either.
        no_cash_dividends = pd.DataFrame({
            'sid': np.array([], dtype=np.uint32),
            'amount': np.array([], dtype=np.float64),
            'declared_date': np.array([], dtype='datetime64[ns]'),
            'ex_date': np.array([], dtype='datetime64[ns]'),
            'pay_date': np.array([], dtype='datetime64[ns]'),
            'record_date': np.array([], dtype='datetime64[ns]'),
        })

        def as_dt64(event):
            return np.array([event.dt], dtype='datetime64[ns]')

        # Stock dividend on sid 1 paid in sid 2 with ratio 2 (presumably two
        # sid 2 shares per sid 1 share, consistent with the 0.2 return
        # asserted below). Declared on day 0, ex/record on day 1, paid on
        # day 2.
        sid_1_days = events[1]
        stock_dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'payment_sid': np.array([2], dtype=np.uint32),
            'ratio': np.array([2], dtype=np.float64),
            'declared_date': as_dt64(sid_1_days[0]),
            'ex_date': as_dt64(sid_1_days[1]),
            'record_date': as_dt64(sid_1_days[1]),
            'pay_date': as_dt64(sid_1_days[2]),
        })
        writer.write(no_ratio_adjustments, no_ratio_adjustments,
                     no_cash_dividends, stock_dividends)
        adjustment_reader = SQLiteAdjustmentReader(dbpath)

        # Buy 100 shares of sid 1 on day 0, before the ex_date.
        first_trade = events[1][0]
        txns = [create_txn(first_trade.sid, first_trade.dt, 10.0, 100)]

        results = calculate_results(
            self.sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            events,
            adjustment_reader,
            txns=txns,
        )

        self.assertEqual(len(results), 6)

        def series(perf_kind, field):
            return [event[perf_kind][field] for event in results]

        self.assertEqual(series('cumulative_perf', 'returns'),
                         [0.0, 0.0, 0.2, 0.2, 0.2, 0.2])
        self.assertEqual(series('daily_perf', 'returns'),
                         [0.0, 0.0, 0.2, 0.0, 0.0, 0.0])
        self.assertEqual(series('daily_perf', 'capital_used'),
                         [-1000, 0, 0, 0, 0, 0])
        self.assertEqual(series('cumulative_perf', 'capital_used'),
                         [-1000] * 6)
        self.assertEqual(series('cumulative_perf', 'ending_cash'),
                         [9000] * 6)
718
719
    def test_long_position_purchased_on_ex_date_receives_no_dividend(self):
        """Buying on the ex_date itself does not earn the dividend."""
        # Six daily trades for sid 1, all at $10 with volume 100.
        events = factory.create_trade_history(
            1,
            [10] * 6,
            [100] * 6,
            oneday,
            self.sim_params,
            env=self.env
        )

        db_path = self.tempdir.getpath('adjustments.sqlite')
        adjustment_writer = SQLiteAdjustmentWriter(
            db_path,
            self.env.trading_days,
            MockDailyBarSpotReader(),
        )

        # Splits and mergers are not under test: a single empty frame serves
        # as both. The explicit dtypes keep the empty columns correctly typed.
        no_ratio_adjustments = pd.DataFrame(
            {
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        # Dividend declared on day 0, ex/record on day 1, paid on day 2.
        dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'amount': np.array([10.00], dtype=np.float64),
            'declared_date': np.array([events[0].dt], dtype='datetime64[ns]'),
            'ex_date': np.array([events[1].dt], dtype='datetime64[ns]'),
            'record_date': np.array([events[1].dt], dtype='datetime64[ns]'),
            'pay_date': np.array([events[2].dt], dtype='datetime64[ns]'),
        })
        adjustment_writer.write(
            no_ratio_adjustments, no_ratio_adjustments, dividends,
        )
        adjustment_reader = SQLiteAdjustmentReader(db_path)

        # The only fill lands on the ex_date.
        ex_day = events[1]
        txns = [create_txn(ex_day.sid, ex_day.dt, 10.0, 100)]

        results = calculate_results(
            self.sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            {1: events},
            adjustment_reader,
            txns=txns,
        )

        self.assertEqual(len(results), 6)

        def series(period, field):
            return [event[period][field] for event in results]

        expected_series = [
            ('cumulative_perf', 'returns', [0] * 6),
            ('daily_perf', 'returns', [0] * 6),
            ('daily_perf', 'capital_used', [0, -1000, 0, 0, 0, 0]),
            ('cumulative_perf', 'capital_used', [0] + [-1000] * 5),
        ]
        for period, field, expected in expected_series:
            self.assertEqual(series(period, field), expected)
780
781
    def test_selling_before_dividend_payment_still_gets_paid(self):
        """Selling after the ex_date still collects the later dividend."""
        # Six daily trades for sid 1, all at $10 with volume 100.
        events = factory.create_trade_history(
            1,
            [10] * 6,
            [100] * 6,
            oneday,
            self.sim_params,
            env=self.env
        )

        db_path = self.tempdir.getpath('adjustments.sqlite')
        adjustment_writer = SQLiteAdjustmentWriter(
            db_path,
            self.env.trading_days,
            MockDailyBarSpotReader(),
        )

        # Splits and mergers are not under test: a single empty frame serves
        # as both. The explicit dtypes keep the empty columns correctly typed.
        no_ratio_adjustments = pd.DataFrame(
            {
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        # Dividend declared on day 0, ex/record on day 1, paid on day 3.
        dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'amount': np.array([10.00], dtype=np.float64),
            'declared_date': np.array([events[0].dt], dtype='datetime64[ns]'),
            'ex_date': np.array([events[1].dt], dtype='datetime64[ns]'),
            'record_date': np.array([events[1].dt], dtype='datetime64[ns]'),
            'pay_date': np.array([events[3].dt], dtype='datetime64[ns]'),
        })
        adjustment_writer.write(
            no_ratio_adjustments, no_ratio_adjustments, dividends,
        )
        adjustment_reader = SQLiteAdjustmentReader(db_path)

        # Open on day 0 (before the ex_date), close on day 2 (after the
        # ex_date but before the pay_date).
        opening, closing = events[0], events[2]
        txns = [
            create_txn(opening.sid, opening.dt, 10.0, 100),
            create_txn(closing.sid, closing.dt, 10.0, -100),
        ]

        results = calculate_results(
            self.sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            {1: events},
            adjustment_reader,
            txns=txns,
        )

        self.assertEqual(len(results), 6)

        def series(period, field):
            return [event[period][field] for event in results]

        expected_series = [
            ('cumulative_perf', 'returns', [0, 0, 0, 0.1, 0.1, 0.1]),
            ('daily_perf', 'returns', [0, 0, 0, 0.1, 0, 0]),
            ('daily_perf', 'capital_used', [-1000, 0, 1000, 1000, 0, 0]),
            ('cumulative_perf', 'capital_used',
             [-1000, -1000, 0, 1000, 1000, 1000]),
        ]
        for period, field, expected in expected_series:
            self.assertEqual(series(period, field), expected)
843
844
    def test_buy_and_sell_before_ex(self):
        """A round trip closed before the ex_date earns no dividend."""
        # Six daily trades for sid 1, all at $10 with volume 100.
        events = factory.create_trade_history(
            1,
            [10] * 6,
            [100] * 6,
            oneday,
            self.sim_params,
            env=self.env
        )
        db_path = self.tempdir.getpath('adjustments.sqlite')

        adjustment_writer = SQLiteAdjustmentWriter(
            db_path,
            self.env.trading_days,
            MockDailyBarSpotReader(),
        )

        # Splits and mergers are not under test: a single empty frame serves
        # as both. The explicit dtypes keep the empty columns correctly typed.
        no_ratio_adjustments = pd.DataFrame(
            {
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )

        # Dividend declared on day 3, ex/record on day 4, paid on day 5.
        dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'amount': np.array([10.0], dtype=np.float64),
            'declared_date': np.array([events[3].dt], dtype='datetime64[ns]'),
            'ex_date': np.array([events[4].dt], dtype='datetime64[ns]'),
            'pay_date': np.array([events[5].dt], dtype='datetime64[ns]'),
            'record_date': np.array([events[4].dt], dtype='datetime64[ns]'),
        })
        adjustment_writer.write(
            no_ratio_adjustments, no_ratio_adjustments, dividends,
        )
        adjustment_reader = SQLiteAdjustmentReader(db_path)

        # Buy on day 1 and sell on day 2, both well before the ex_date.
        opening, closing = events[1], events[2]
        txns = [
            create_txn(opening.sid, opening.dt, 10.0, 100),
            create_txn(closing.sid, closing.dt, 10.0, -100),
        ]

        results = calculate_results(
            self.sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            {1: events},
            txns=txns,
            adjustment_reader=adjustment_reader,
        )

        self.assertEqual(len(results), 6)

        def series(period, field):
            return [event[period][field] for event in results]

        expected_series = [
            ('cumulative_perf', 'returns', [0] * 6),
            ('daily_perf', 'returns', [0] * 6),
            ('daily_perf', 'capital_used', [0, -1000, 1000, 0, 0, 0]),
            ('cumulative_perf', 'capital_used', [0, -1000, 0, 0, 0, 0]),
        ]
        for period, field, expected in expected_series:
            self.assertEqual(series(period, field), expected)
905
906
    def test_ending_before_pay_date(self):
        """A dividend whose pay_date falls after the run is never paid.

        NOTE(review): the ex_date is events[0] while the only buy fills on
        events[1], i.e. after the ex_date, so the position is presumably not
        entitled to this dividend whatever the pay_date is. Buying before the
        ex_date would make this test exercise the late pay_date itself;
        confirm the intended scenario before changing the fixture.
        """
        # Six daily trades for sid 1, all at $10 with volume 100.
        events = factory.create_trade_history(
            1,
            [10, 10, 10, 10, 10, 10],
            [100, 100, 100, 100, 100, 100],
            oneday,
            self.sim_params,
            env=self.env
        )

        # Push the pay_date 30 trading days past the first open, well beyond
        # the six results produced by this simulation.
        pay_date = self.sim_params.first_open
        for _ in range(30):
            pay_date = factory.get_next_trading_dt(pay_date, oneday, self.env)

        dbpath = self.tempdir.getpath('adjustments.sqlite')

        writer = SQLiteAdjustmentWriter(dbpath, self.env.trading_days,
                                        MockDailyBarSpotReader())
        splits = mergers = pd.DataFrame(
            {
                # Hackery to make the dtypes correct on an empty frame.
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'amount': np.array([10.00], dtype=np.float64),
            'declared_date': np.array([events[0].dt], dtype='datetime64[ns]'),
            'ex_date': np.array([events[0].dt], dtype='datetime64[ns]'),
            'record_date': np.array([events[0].dt], dtype='datetime64[ns]'),
            'pay_date': np.array([pay_date], dtype='datetime64[ns]'),
        })
        writer.write(splits, mergers, dividends)
        adjustment_reader = SQLiteAdjustmentReader(dbpath)

        txns = [create_txn(events[1].sid, events[1].dt, 10.0, 100)]

        results = calculate_results(
            self.sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            {1: events},
            txns=txns,
            adjustment_reader=adjustment_reader,
        )

        self.assertEqual(len(results), 6)
        cumulative_returns = \
            [event['cumulative_perf']['returns'] for event in results]
        self.assertEqual(cumulative_returns, [0, 0, 0, 0, 0, 0])
        daily_returns = [event['daily_perf']['returns'] for event in results]
        self.assertEqual(daily_returns, [0, 0, 0, 0, 0, 0])
        cash_flows = [event['daily_perf']['capital_used'] for event in results]
        self.assertEqual(cash_flows, [0, -1000, 0, 0, 0, 0])
        cumulative_cash_flows = \
            [event['cumulative_perf']['capital_used'] for event in results]
        self.assertEqual(
            cumulative_cash_flows,
            [0, -1000, -1000, -1000, -1000, -1000]
        )
973
974
    def test_short_position_pays_dividend(self):
        """A short position held over the ex_date pays out the dividend."""
        # Six daily trades for sid 1, all at $10 with volume 100.
        events = factory.create_trade_history(
            1,
            [10] * 6,
            [100] * 6,
            oneday,
            self.sim_params,
            env=self.env
        )

        db_path = self.tempdir.getpath('adjustments.sqlite')
        adjustment_writer = SQLiteAdjustmentWriter(
            db_path,
            self.env.trading_days,
            MockDailyBarSpotReader(),
        )

        # Splits and mergers are not under test: a single empty frame serves
        # as both. The explicit dtypes keep the empty columns correctly typed.
        no_ratio_adjustments = pd.DataFrame(
            {
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        # Dividend declared on day 0, ex/record on day 2, paid on day 3.
        dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'amount': np.array([10.00], dtype=np.float64),
            'declared_date': np.array([events[0].dt], dtype='datetime64[ns]'),
            'ex_date': np.array([events[2].dt], dtype='datetime64[ns]'),
            'record_date': np.array([events[2].dt], dtype='datetime64[ns]'),
            'pay_date': np.array([events[3].dt], dtype='datetime64[ns]'),
        })
        adjustment_writer.write(
            no_ratio_adjustments, no_ratio_adjustments, dividends,
        )
        adjustment_reader = SQLiteAdjustmentReader(db_path)

        # Sell short 100 shares on day 1, before the ex_date.
        short_day = events[1]
        txns = [create_txn(short_day.sid, short_day.dt, 10.0, -100)]

        results = calculate_results(
            self.sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            {1: events},
            adjustment_reader,
            txns=txns,
        )

        self.assertEqual(len(results), 6)

        def series(period, field):
            return [event[period][field] for event in results]

        expected_series = [
            ('cumulative_perf', 'returns', [0.0, 0.0, 0.0, -0.1, -0.1, -0.1]),
            ('daily_perf', 'returns', [0.0, 0.0, 0.0, -0.1, 0.0, 0.0]),
            ('daily_perf', 'capital_used', [0, 1000, 0, -1000, 0, 0]),
            ('cumulative_perf', 'capital_used', [0, 1000, 1000, 0, 0, 0]),
        ]
        for period, field, expected in expected_series:
            self.assertEqual(series(period, field), expected)
1033
1034
    def test_no_position_receives_no_dividend(self):
        """Without a position, a declared dividend produces no cash flow."""
        # Six daily trades for sid 1, all at $10 with volume 100.
        events = factory.create_trade_history(
            1,
            [10] * 6,
            [100] * 6,
            oneday,
            self.sim_params,
            env=self.env
        )

        db_path = self.tempdir.getpath('adjustments.sqlite')
        adjustment_writer = SQLiteAdjustmentWriter(
            db_path,
            self.env.trading_days,
            MockDailyBarSpotReader(),
        )

        # Splits and mergers are not under test: a single empty frame serves
        # as both. The explicit dtypes keep the empty columns correctly typed.
        no_ratio_adjustments = pd.DataFrame(
            {
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        # Dividend declared on day 0, ex on day 1, paid/recorded on day 2.
        dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'amount': np.array([10.00], dtype=np.float64),
            'declared_date': np.array([events[0].dt], dtype='datetime64[ns]'),
            'ex_date': np.array([events[1].dt], dtype='datetime64[ns]'),
            'pay_date': np.array([events[2].dt], dtype='datetime64[ns]'),
            'record_date': np.array([events[2].dt], dtype='datetime64[ns]'),
        })
        adjustment_writer.write(
            no_ratio_adjustments, no_ratio_adjustments, dividends,
        )
        adjustment_reader = SQLiteAdjustmentReader(db_path)

        # No transactions are supplied, so no position is ever held.
        results = calculate_results(
            self.sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            {1: events},
            adjustment_reader,
        )

        self.assertEqual(len(results), 6)

        def series(period, field):
            return [event[period][field] for event in results]

        expected_series = [
            ('cumulative_perf', 'returns', [0.0] * 6),
            ('daily_perf', 'returns', [0.0] * 6),
            ('daily_perf', 'capital_used', [0] * 6),
            ('cumulative_perf', 'capital_used', [0] * 6),
        ]
        for period, field, expected in expected_series:
            self.assertEqual(series(period, field), expected)
1090
1091
    def test_no_dividend_at_simulation_end(self):
        """No dividend is paid when its pay_date falls after the final day."""
        # Five daily trades for sid 1, all at $10 with volume 100.
        events = factory.create_trade_history(
            1,
            [10] * 5,
            [100] * 5,
            oneday,
            self.sim_params,
            env=self.env
        )

        db_path = self.tempdir.getpath('adjustments.sqlite')
        adjustment_writer = SQLiteAdjustmentWriter(
            db_path,
            self.env.trading_days,
            MockDailyBarSpotReader(),
        )

        # Splits and mergers are not under test: a single empty frame serves
        # as both. The explicit dtypes keep the empty columns correctly typed.
        no_ratio_adjustments = pd.DataFrame(
            {
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        # The pay_date is the trading day after the last event.
        # NOTE(review): record_date (events[0]) precedes declared_date
        # (events[-3]); confirm whether events[-2] was intended.
        dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'amount': np.array([10.00], dtype=np.float64),
            'declared_date': np.array([events[-3].dt], dtype='datetime64[ns]'),
            'ex_date': np.array([events[-2].dt], dtype='datetime64[ns]'),
            'record_date': np.array([events[0].dt], dtype='datetime64[ns]'),
            'pay_date': np.array([self.env.next_trading_day(events[-1].dt)],
                                 dtype='datetime64[ns]'),
        })
        adjustment_writer.write(
            no_ratio_adjustments, no_ratio_adjustments, dividends,
        )
        adjustment_reader = SQLiteAdjustmentReader(db_path)

        # Rebuild the simulation parameters and pull period_end in to the
        # last event, so the pay_date lies beyond the end of the run.
        truncated_sim_params = create_simulation_parameters(
            num_days=6,
            capital_base=10e3,
            start=self.sim_params.period_start,
            end=self.sim_params.period_end
        )
        truncated_sim_params.period_end = events[-1].dt
        truncated_sim_params.update_internal_from_env(self.env)

        # Buy on the first day, ahead of the ex_date.
        first_day = events[0]
        txns = [create_txn(first_day.sid, first_day.dt, 10.0, 100)]
        results = calculate_results(
            truncated_sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            {1: events},
            adjustment_reader=adjustment_reader,
            txns=txns,
        )

        self.assertEqual(len(results), 5)

        def series(period, field):
            return [event[period][field] for event in results]

        expected_series = [
            ('cumulative_perf', 'returns', [0.0] * 5),
            ('daily_perf', 'returns', [0.0] * 5),
            ('daily_perf', 'capital_used', [-1000, 0, 0, 0, 0]),
            ('cumulative_perf', 'capital_used', [-1000] * 5),
        ]
        for period, field, expected in expected_series:
            self.assertEqual(series(period, field), expected)
1163
1164
1165
class TestDividendPerformanceHolidayStyle(TestDividendPerformance):
    """Run the inherited dividend tests over a window containing a holiday.

    The window is meant to open just before a market holiday, so that the
    next trading day is more than one calendar day after the start. Any
    test that hard-codes events at ``start + oneday`` would then fail,
    because the simulation skips those events.

    NOTE(review): 2003-11-30 is a Sunday and Thanksgiving 2003 fell on
    Nov 27, so this start is not "the day before Thanksgiving"; confirm
    the dates still produce the intended gap.
    """

    @classmethod
    def setUpClass(cls):
        # Class-level fixtures read by the inherited tests (self.env,
        # self.sim_params, self.benchmark_events), built over the window.
        cls.env = env = TradingEnvironment()
        cls.sim_params = sim_params = create_simulation_parameters(
            num_days=6,
            capital_base=10e3,
            start=pd.Timestamp("2003-11-30", tz='UTC'),
            end=pd.Timestamp("2003-12-08", tz='UTC')
        )

        setup_env_data(env, sim_params, [1, 2])

        cls.benchmark_events = benchmark_events_in_range(sim_params, env)
1187
1188
1189
class TestPositionPerformance(unittest.TestCase):
1190
1191
    def setUp(self):
        """Create a fresh scratch directory for each test's data files."""
        self.tempdir = TempDirectory()
1193
1194
    def create_environment_stuff(self, num_days=4, sids=None):
        """Build the trading environment and simulation fixtures for a test.

        Populates ``self.env``, ``self.sim_params``, ``self.finder`` and
        ``self.benchmark_events``.

        :param num_days: Number of days passed to
            ``create_simulation_parameters``.
        :param sids: Asset ids passed to ``setup_env_data``; defaults to
            ``[1, 2]``.
        """
        # ``None`` sentinel instead of a list default: a default list is
        # created once and shared by every call.
        if sids is None:
            sids = [1, 2]

        self.env = TradingEnvironment()
        self.sim_params = create_simulation_parameters(num_days=num_days)

        # Use the caller's sids rather than a hard-coded list.
        setup_env_data(self.env, self.sim_params, sids)

        self.finder = self.env.asset_finder

        self.benchmark_events = benchmark_events_in_range(self.sim_params,
                                                          self.env)
1204
1205
    def tearDown(self):
1206
        self.tempdir.cleanup()
1207
        del self.env
1208
1209
    def test_long_short_positions(self):
        """
        Start with $1000, buy 100 shares of stock 1 at $10 and sell short
        100 shares of stock 2 at $10; stock 1 then falls to $9 while
        stock 2 rises to $11.
        """
        self.create_environment_stuff()

        long_leg = factory.create_trade_history(
            1,
            [10, 10, 10, 9],
            [100] * 4,
            oneday,
            self.sim_params,
            env=self.env
        )
        short_leg = factory.create_trade_history(
            2,
            [10, 10, 10, 11],
            [100] * 4,
            onesec,
            self.sim_params,
            env=self.env
        )

        # Both fills are stamped with the long leg's second trade time.
        fill_dt = long_leg[1].dt
        buy = create_txn(long_leg[1].sid, fill_dt, 10.0, 100)
        short_sale = create_txn(short_leg[1].sid, fill_dt, 10.0, -100)

        data_portal = create_data_portal_from_trade_history(
            self.env,
            self.tempdir,
            self.sim_params,
            {1: long_leg, 2: short_leg}
        )

        tracker = perf.PositionTracker(self.env.asset_finder, data_portal,
                                       self.sim_params.data_frequency)
        period = perf.PerformancePeriod(1000.0, self.env.asset_finder,
                                        self.sim_params.data_frequency,
                                        data_portal)
        for txn in (buy, short_sale):
            tracker.execute_transaction(txn)
            period.handle_execution(txn)

        def account_now():
            # Derive the account from the tracker's current position stats.
            position_stats = tracker.stats()
            period_stats = period.stats(tracker.positions, position_stats,
                                        data_portal)
            return period.as_account(position_stats, period_stats)

        check_perf_period(
            period,
            tracker,
            gross_leverage=2.0,
            net_leverage=0.0,
            long_exposure=1000.0,
            longs_count=1,
            short_exposure=-1000.0,
            shorts_count=1,
            data_portal=data_portal
        )

        # Mark positions at the second-to-last long-leg trade time and check
        # the account attributes.
        tracker.sync_last_sale_prices(long_leg[-2].dt)
        check_account(account_now(),
                      settled_cash=1000.0,
                      equity_with_loan=1000.0,
                      total_positions_value=0.0,
                      regt_equity=1000.0,
                      available_funds=1000.0,
                      excess_liquidity=1000.0,
                      cushion=1.0,
                      leverage=2.0,
                      net_leverage=0.0,
                      net_liquidation=1000.0)

        # Mark at the final trade time: stock 1 at $9, stock 2 at $11.
        tracker.sync_last_sale_prices(long_leg[-1].dt)
        account = account_now()

        check_perf_period(
            period,
            tracker,
            gross_leverage=2.5,
            net_leverage=-0.25,
            long_exposure=900.0,
            longs_count=1,
            short_exposure=-1100.0,
            shorts_count=1,
            data_portal=data_portal
        )

        check_account(account,
                      settled_cash=1000.0,
                      equity_with_loan=800.0,
                      total_positions_value=-200.0,
                      regt_equity=1000.0,
                      available_funds=1000.0,
                      excess_liquidity=1000.0,
                      cushion=1.25,
                      leverage=2.5,
                      net_leverage=-0.25,
                      net_liquidation=800.0)
1318
1319
    def test_levered_long_position(self):
        """
            Start with $1,000 and buy 1,000 shares at $10 (10x leverage);
            the price then rises to $11.
        """
        self.create_environment_stuff()

        # Post some trades in the market.
        trades = factory.create_trade_history(
            1,
            [10, 10, 10, 11],
            [100] * 4,
            oneday,
            self.sim_params,
            env=self.env
        )

        data_portal = create_data_portal_from_trade_history(
            self.env,
            self.tempdir,
            self.sim_params,
            {1: trades})

        purchase = create_txn(trades[1].sid, trades[1].dt, 10.0, 1000)
        tracker = perf.PositionTracker(self.env.asset_finder, data_portal,
                                       self.sim_params.data_frequency)
        period = perf.PerformancePeriod(1000.0, self.env.asset_finder,
                                        self.sim_params.data_frequency,
                                        data_portal)

        tracker.execute_transaction(purchase)
        period.handle_execution(purchase)

        def account_now():
            # Derive the account from the tracker's current position stats.
            position_stats = tracker.stats()
            period_stats = period.stats(tracker.positions, position_stats,
                                        data_portal)
            return period.as_account(position_stats, period_stats)

        check_perf_period(
            period,
            tracker,
            gross_leverage=10.0,
            net_leverage=10.0,
            long_exposure=10000.0,
            longs_count=1,
            short_exposure=0.0,
            shorts_count=0,
            data_portal=data_portal
        )

        # Mark at the second-to-last trade ($10) and check the account.
        tracker.sync_last_sale_prices(trades[-2].dt)
        check_account(account_now(),
                      settled_cash=-9000.0,
                      equity_with_loan=1000.0,
                      total_positions_value=10000.0,
                      regt_equity=-9000.0,
                      available_funds=-9000.0,
                      excess_liquidity=-9000.0,
                      cushion=-9.0,
                      leverage=10.0,
                      net_leverage=10.0,
                      net_liquidation=1000.0)

        # Now simulate the price jump to $11.
        tracker.sync_last_sale_prices(trades[-1].dt)

        check_perf_period(
            period,
            tracker,
            gross_leverage=5.5,
            net_leverage=5.5,
            long_exposure=11000.0,
            longs_count=1,
            short_exposure=0.0,
            shorts_count=0,
            data_portal=data_portal
        )

        check_account(account_now(),
                      settled_cash=-9000.0,
                      equity_with_loan=2000.0,
                      total_positions_value=11000.0,
                      regt_equity=-9000.0,
                      available_funds=-9000.0,
                      excess_liquidity=-9000.0,
                      cushion=-4.5,
                      leverage=5.5,
                      net_leverage=5.5,
                      net_liquidation=2000.0)
1413
1414
    def test_long_position(self):
1415
        """
1416
            verify that the performance period calculates properly for a
1417
            single buy transaction
1418
        """
1419
        self.create_environment_stuff()
1420
1421
        # post some trades in the market
1422
        trades = factory.create_trade_history(
1423
            1,
1424
            [10, 10, 10, 11],
1425
            [100, 100, 100, 100],
1426
            oneday,
1427
            self.sim_params,
1428
            env=self.env
1429
        )
1430
1431
        data_portal = create_data_portal_from_trade_history(
1432
            self.env,
1433
            self.tempdir,
1434
            self.sim_params,
1435
            {1: trades})
1436
1437
        txn = create_txn(trades[1].sid, trades[1].dt, 10.0, 100)
1438
        pt = perf.PositionTracker(self.env.asset_finder, data_portal,
1439
                                  self.sim_params.data_frequency)
1440
        pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
1441
                                    self.sim_params.data_frequency,
1442
                                    period_open=self.sim_params.period_start,
1443
                                    period_close=self.sim_params.period_end)
1444
        pt.execute_transaction(txn)
1445
        pp.handle_execution(txn)
1446
1447
        # This verifies that the last sale price is being correctly
1448
        # set in the positions. If this is not the case then returns can
1449
        # incorrectly show as sharply dipping if a transaction arrives
1450
        # before a trade. This is caused by returns being based on holding
1451
        # stocks with a last sale price of 0.
1452
        self.assertEqual(pt.positions[1].last_sale_price, 10.0)
1453
1454
        pt.sync_last_sale_prices(trades[-1].dt)
1455
1456
        self.assertEqual(
1457
            pp.period_cash_flow,
1458
            -1 * txn.price * txn.amount,
1459
            "capital used should be equal to the opposite of the transaction \
1460
            cost of sole txn in test"
1461
        )
1462
1463
        self.assertEqual(len(pt.positions), 1, "should be just one position")
1464
1465
        self.assertEqual(
1466
            pt.positions[1].sid,
1467
            txn.sid,
1468
            "position should be in security with id 1")
1469
1470
        self.assertEqual(
1471
            pt.positions[1].amount,
1472
            txn.amount,
1473
            "should have a position of {sharecount} shares".format(
1474
                sharecount=txn.amount
1475
            )
1476
        )
1477
1478
        self.assertEqual(
1479
            pt.positions[1].cost_basis,
1480
            txn.price,
1481
            "should have a cost basis of 10"
1482
        )
1483
1484
        self.assertEqual(
1485
            pt.positions[1].last_sale_price,
1486
            trades[-1]['price'],
1487
            "last sale should be same as last trade. \
1488
            expected {exp} actual {act}".format(
1489
                exp=trades[-1]['price'],
1490
                act=pt.positions[1].last_sale_price)
1491
        )
1492
1493
        pos_stats = pt.stats()
1494
        pp_stats = pp.stats(pt.positions, pos_stats, data_portal)
1495
1496
        self.assertEqual(
1497
            pos_stats.net_value,
1498
            1100,
1499
            "ending value should be price of last trade times number of \
1500
            shares in position"
1501
        )
1502
1503
        self.assertEqual(pp_stats.pnl, 100,
1504
                         "gain of 1 on 100 shares should be 100")
1505
1506
        check_perf_period(
1507
            pp,
1508
            pt,
1509
            gross_leverage=1.0,
1510
            net_leverage=1.0,
1511
            long_exposure=1100.0,
1512
            longs_count=1,
1513
            short_exposure=0.0,
1514
            shorts_count=0,
1515
            data_portal=data_portal
1516
        )
1517
1518
        # Validate that the account attributes were updated.
1519
        account = pp.as_account(pos_stats, pp_stats)
1520
        check_account(account,
1521
                      settled_cash=0.0,
1522
                      equity_with_loan=1100.0,
1523
                      total_positions_value=1100.0,
1524
                      regt_equity=0.0,
1525
                      available_funds=0.0,
1526
                      excess_liquidity=0.0,
1527
                      cushion=0.0,
1528
                      leverage=1.0,
1529
                      net_leverage=1.0,
1530
                      net_liquidation=1100.0)
1531
1532
    def test_short_position(self):
1533
        """verify that the performance period calculates properly for a \
1534
single short-sale transaction"""
1535
        self.create_environment_stuff(num_days=6)
1536
1537
        trades = factory.create_trade_history(
1538
            1,
1539
            [10, 10, 10, 11, 10, 9],
1540
            [100, 100, 100, 100, 100, 100],
1541
            oneday,
1542
            self.sim_params,
1543
            env=self.env
1544
        )
1545
1546
        trades_1 = trades[:-2]
1547
1548
        data_portal = create_data_portal_from_trade_history(
1549
            self.env,
1550
            self.tempdir,
1551
            self.sim_params,
1552
            {1: trades})
1553
1554
        txn = create_txn(trades[1].sid, trades[1].dt, 10.0, -100)
1555
        pt = perf.PositionTracker(self.env.asset_finder, data_portal,
1556
                                  self.sim_params.data_frequency)
1557
        pp = perf.PerformancePeriod(
1558
            1000.0, self.env.asset_finder,
1559
            self.sim_params.data_frequency,
1560
            data_portal)
1561
1562
        pt.execute_transaction(txn)
1563
        pp.handle_execution(txn)
1564
1565
        pt.sync_last_sale_prices(trades_1[-1].dt)
1566
1567
        self.assertEqual(
1568
            pp.period_cash_flow,
1569
            -1 * txn.price * txn.amount,
1570
            "capital used should be equal to the opposite of the transaction\
1571
             cost of sole txn in test"
1572
        )
1573
1574
        self.assertEqual(
1575
            len(pt.positions),
1576
            1,
1577
            "should be just one position")
1578
1579
        self.assertEqual(
1580
            pt.positions[1].sid,
1581
            txn.sid,
1582
            "position should be in security from the transaction"
1583
        )
1584
1585
        self.assertEqual(
1586
            pt.positions[1].amount,
1587
            -100,
1588
            "should have a position of -100 shares"
1589
        )
1590
1591
        self.assertEqual(
1592
            pt.positions[1].cost_basis,
1593
            txn.price,
1594
            "should have a cost basis of 10"
1595
        )
1596
1597
        self.assertEqual(
1598
            pt.positions[1].last_sale_price,
1599
            trades_1[-1]['price'],
1600
            "last sale should be price of last trade"
1601
        )
1602
1603
        pos_stats = pt.stats()
1604
        pp_stats = pp.stats(pt.positions, pos_stats, data_portal)
1605
1606
        self.assertEqual(
1607
            pos_stats.net_value,
1608
            -1100,
1609
            "ending value should be price of last trade times number of \
1610
            shares in position"
1611
        )
1612
1613
        self.assertEqual(pp_stats.pnl, -100,
1614
                         "gain of 1 on 100 shares should be 100")
1615
1616
        # simulate additional trades, and ensure that the position value
1617
        # reflects the new price
1618
        trades_2 = trades[-2:]
1619
1620
        # simulate a rollover to a new period
1621
        pp.rollover(pos_stats, pp_stats)
1622
1623
        pt.sync_last_sale_prices(trades[-1].dt)
1624
1625
        self.assertEqual(
1626
            pp.period_cash_flow,
1627
            0,
1628
            "capital used should be zero, there were no transactions in \
1629
            performance period"
1630
        )
1631
1632
        self.assertEqual(
1633
            len(pt.positions),
1634
            1,
1635
            "should be just one position"
1636
        )
1637
1638
        self.assertEqual(
1639
            pt.positions[1].sid,
1640
            txn.sid,
1641
            "position should be in security from the transaction"
1642
        )
1643
1644
        self.assertEqual(
1645
            pt.positions[1].amount,
1646
            -100,
1647
            "should have a position of -100 shares"
1648
        )
1649
1650
        self.assertEqual(
1651
            pt.positions[1].cost_basis,
1652
            txn.price,
1653
            "should have a cost basis of 10"
1654
        )
1655
1656
        self.assertEqual(
1657
            pt.positions[1].last_sale_price,
1658
            trades_2[-1].price,
1659
            "last sale should be price of last trade"
1660
        )
1661
1662
        pos_stats = pt.stats()
1663
        pp_stats = pp.stats(pt.positions, pos_stats, data_portal)
1664
1665
        self.assertEqual(
1666
            pos_stats.net_value,
1667
            -900,
1668
            "ending value should be price of last trade times number of \
1669
            shares in position")
1670
1671
        self.assertEqual(
1672
            pp_stats.pnl,
1673
            200,
1674
            "drop of 2 on -100 shares should be 200"
1675
        )
1676
1677
        # now run a performance period encompassing the entire trade sample.
1678
        ptTotal = perf.PositionTracker(self.env.asset_finder, data_portal,
1679
                                       self.sim_params.data_frequency)
1680
        ppTotal = perf.PerformancePeriod(1000.0, self.env.asset_finder,
1681
                                         self.sim_params.data_frequency,
1682
                                         data_portal)
1683
1684
        ptTotal.execute_transaction(txn)
1685
        ppTotal.handle_execution(txn)
1686
1687
        ptTotal.sync_last_sale_prices(trades[-1].dt)
1688
1689
        self.assertEqual(
1690
            ppTotal.period_cash_flow,
1691
            -1 * txn.price * txn.amount,
1692
            "capital used should be equal to the opposite of the transaction \
1693
cost of sole txn in test"
1694
        )
1695
1696
        self.assertEqual(
1697
            len(ptTotal.positions),
1698
            1,
1699
            "should be just one position"
1700
        )
1701
        self.assertEqual(
1702
            ptTotal.positions[1].sid,
1703
            txn.sid,
1704
            "position should be in security from the transaction"
1705
        )
1706
1707
        self.assertEqual(
1708
            ptTotal.positions[1].amount,
1709
            -100,
1710
            "should have a position of -100 shares"
1711
        )
1712
1713
        self.assertEqual(
1714
            ptTotal.positions[1].cost_basis,
1715
            txn.price,
1716
            "should have a cost basis of 10"
1717
        )
1718
1719
        self.assertEqual(
1720
            ptTotal.positions[1].last_sale_price,
1721
            trades_2[-1].price,
1722
            "last sale should be price of last trade"
1723
        )
1724
1725
        pos_total_stats = ptTotal.stats()
1726
        pp_total_stats = ppTotal.stats(ptTotal.positions, pos_total_stats,
1727
                                       data_portal)
1728
1729
        self.assertEqual(
1730
            pos_total_stats.net_value,
1731
            -900,
1732
            "ending value should be price of last trade times number of \
1733
            shares in position")
1734
1735
        self.assertEqual(
1736
            pp_total_stats.pnl,
1737
            100,
1738
            "drop of 1 on -100 shares should be 100"
1739
        )
1740
1741
        check_perf_period(
1742
            pp,
1743
            pt,
1744
            gross_leverage=0.8181,
1745
            net_leverage=-0.8181,
1746
            long_exposure=0.0,
1747
            longs_count=0,
1748
            short_exposure=-900.0,
1749
            shorts_count=1,
1750
            data_portal=data_portal
1751
        )
1752
1753
        # Validate that the account attributes.
1754
        account = ppTotal.as_account(pos_stats, pp_stats)
1755
        check_account(account,
1756
                      settled_cash=2000.0,
1757
                      equity_with_loan=1100.0,
1758
                      total_positions_value=-900.0,
1759
                      regt_equity=2000.0,
1760
                      available_funds=2000.0,
1761
                      excess_liquidity=2000.0,
1762
                      cushion=1.8181,
1763
                      leverage=0.8181,
1764
                      net_leverage=-0.8181,
1765
                      net_liquidation=1100.0)
1766
1767
    def test_covering_short(self):
1768
        """verify performance where short is bought and covered, and shares \
1769
trade after cover"""
1770
        self.create_environment_stuff(num_days=10)
1771
1772
        trades = factory.create_trade_history(
1773
            1,
1774
            [10, 10, 10, 11, 9, 8, 7, 8, 9, 10],
1775
            [100, 100, 100, 100, 100, 100, 100, 100, 100, 100],
1776
            onesec,
1777
            self.sim_params,
1778
            env=self.env
1779
        )
1780
1781
        data_portal = create_data_portal_from_trade_history(
1782
            self.env,
1783
            self.tempdir,
1784
            self.sim_params,
1785
            {1: trades})
1786
1787
        short_txn = create_txn(
1788
            trades[1].sid,
1789
            trades[1].dt,
1790
            10.0,
1791
            -100,
1792
        )
1793
        cover_txn = create_txn(trades[6].sid, trades[6].dt, 7.0, 100)
1794
        pt = perf.PositionTracker(self.env.asset_finder, data_portal,
1795
                                  self.sim_params.data_frequency)
1796
        pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
1797
                                    self.sim_params.data_frequency,
1798
                                    data_portal)
1799
1800
        pt.execute_transaction(short_txn)
1801
        pp.handle_execution(short_txn)
1802
        pt.execute_transaction(cover_txn)
1803
        pp.handle_execution(cover_txn)
1804
1805
        pt.sync_last_sale_prices(trades[-1].dt)
1806
1807
        short_txn_cost = short_txn.price * short_txn.amount
1808
        cover_txn_cost = cover_txn.price * cover_txn.amount
1809
1810
        self.assertEqual(
1811
            pp.period_cash_flow,
1812
            -1 * short_txn_cost - cover_txn_cost,
1813
            "capital used should be equal to the net transaction costs"
1814
        )
1815
1816
        self.assertEqual(
1817
            len(pt.positions),
1818
            1,
1819
            "should be just one position"
1820
        )
1821
1822
        self.assertEqual(
1823
            pt.positions[1].sid,
1824
            short_txn.sid,
1825
            "position should be in security from the transaction"
1826
        )
1827
1828
        self.assertEqual(
1829
            pt.positions[1].amount,
1830
            0,
1831
            "should have a position of -100 shares"
1832
        )
1833
1834
        self.assertEqual(
1835
            pt.positions[1].cost_basis,
1836
            0,
1837
            "a covered position should have a cost basis of 0"
1838
        )
1839
1840
        self.assertEqual(
1841
            pt.positions[1].last_sale_price,
1842
            trades[-1].price,
1843
            "last sale should be price of last trade"
1844
        )
1845
1846
        pos_stats = pt.stats()
1847
        pp_stats = pp.stats(pt.positions, pos_stats, data_portal)
1848
1849
        self.assertEqual(
1850
            pos_stats.net_value,
1851
            0,
1852
            "ending value should be price of last trade times number of \
1853
shares in position"
1854
        )
1855
1856
        self.assertEqual(
1857
            pp_stats.pnl,
1858
            300,
1859
            "gain of 1 on 100 shares should be 300"
1860
        )
1861
1862
        check_perf_period(
1863
            pp,
1864
            pt,
1865
            gross_leverage=0.0,
1866
            net_leverage=0.0,
1867
            long_exposure=0.0,
1868
            longs_count=0,
1869
            short_exposure=0.0,
1870
            shorts_count=0,
1871
            data_portal=data_portal
1872
        )
1873
1874
        account = pp.as_account(pos_stats, pp_stats)
1875
        check_account(account,
1876
                      settled_cash=1300.0,
1877
                      equity_with_loan=1300.0,
1878
                      total_positions_value=0.0,
1879
                      regt_equity=1300.0,
1880
                      available_funds=1300.0,
1881
                      excess_liquidity=1300.0,
1882
                      cushion=1.0,
1883
                      leverage=0.0,
1884
                      net_leverage=0.0,
1885
                      net_liquidation=1300.0)
1886
1887
    def test_cost_basis_calc(self):
        """Verify the average cost basis across several buys and a sale.

        The first four transactions from ``history_args`` (prices 10, 11,
        11, 12) build a position with an average cost of 11. A 100-share
        sale at 10 on the final trade must leave the cost basis unchanged.
        The scenario is checked twice: across a rollover, and in a single
        period that sees every transaction.
        """
        self.create_environment_stuff(num_days=5)

        history_args = (
            1,
            [10, 11, 11, 12, 10],
            [100, 100, 100, 100, 100],
            oneday,
            self.sim_params,
            self.env
        )
        trades = factory.create_trade_history(*history_args)
        transactions = factory.create_txn_history(*history_args)[:4]

        data_portal = create_data_portal_from_trade_history(
            self.env,
            self.tempdir,
            self.sim_params,
            {1: trades})

        def execute_buys_and_check_cost_basis(tracker, period):
            # Apply each buy and compare the cost basis with the running
            # mean of prices; an unweighted mean is the expected basis here
            # because every amount in history_args is 100.
            average_cost = 0
            for i, txn in enumerate(transactions):
                tracker.execute_transaction(txn)
                period.handle_execution(txn)
                average_cost = (average_cost * i + txn.price) / (i + 1)
                self.assertEqual(tracker.positions[1].cost_basis,
                                 average_cost)

        pt = perf.PositionTracker(self.env.asset_finder, data_portal,
                                  self.sim_params.data_frequency)
        pp = perf.PerformancePeriod(
            1000.0,
            self.env.asset_finder,
            self.sim_params.data_frequency,
            period_open=self.sim_params.period_start,
            period_close=self.sim_params.trading_days[-1]
        )
        execute_buys_and_check_cost_basis(pt, pp)

        self.assertEqual(
            pt.positions[1].last_sale_price,
            trades[-2].price,
            "should have a last sale of 12, got {val}".format(
                val=pt.positions[1].last_sale_price)
        )

        self.assertEqual(
            pt.positions[1].cost_basis,
            11,
            "should have a cost basis of 11"
        )

        pt.sync_last_sale_prices(trades[-2].dt)

        pos_stats = pt.stats()
        pp_stats = pp.stats(pt.positions, pos_stats, data_portal)

        self.assertEqual(
            pp_stats.pnl,
            400
        )

        down_tick = trades[-1]

        sale_txn = create_txn(
            down_tick.sid,
            down_tick.dt,
            10.0,
            -100)

        pp.rollover(pos_stats, pp_stats)

        pt.execute_transaction(sale_txn)
        pp.handle_execution(sale_txn)

        pt.sync_last_sale_prices(down_tick.dt)

        self.assertEqual(
            pt.positions[1].last_sale_price,
            10,
            "should have a last sale of 10, was {val}".format(
                val=pt.positions[1].last_sale_price)
        )

        self.assertEqual(
            pt.positions[1].cost_basis,
            11,
            "should have a cost basis of 11"
        )

        pos_stats = pt.stats()
        pp_stats = pp.stats(pt.positions, pos_stats, data_portal)

        self.assertEqual(pp_stats.pnl, -800,
                         "this period goes from +400 to -400")

        # Replay every transaction, including the sale, in one period.
        pt3 = perf.PositionTracker(self.env.asset_finder, data_portal,
                                   self.sim_params.data_frequency)
        pp3 = perf.PerformancePeriod(1000.0, self.env.asset_finder,
                                     self.sim_params.data_frequency)
        execute_buys_and_check_cost_basis(pt3, pp3)

        pt3.execute_transaction(sale_txn)
        pp3.handle_execution(sale_txn)

        pt3.sync_last_sale_prices(down_tick.dt)

        self.assertEqual(
            pt3.positions[1].last_sale_price,
            10,
            "should have a last sale of 10"
        )

        self.assertEqual(
            pt3.positions[1].cost_basis,
            11,
            "should have a cost basis of 11"
        )

        pt3_stats = pt3.stats()
        pp3_stats = pp3.stats(pt3.positions, pt3_stats, data_portal)

        self.assertEqual(
            pp3_stats.pnl,
            -400,
            "should be -400 for all trades and transactions in period"
        )
2021
2022
    def test_cost_basis_calc_close_pos(self):
        """Verify the cost basis as a position is built, reduced, closed,
        reopened, flipped short and flipped back long.

        ``cost_bases[i]`` is the expected cost basis after applying the
        i-th transaction generated from ``history_args``.
        """
        self.create_environment_stuff(num_days=8)

        history_args = (
            1,
            [10, 9, 11, 8, 9, 12, 13, 14],
            [200, -100, -100, 100, -300, 100, 500, 400],
            onesec,
            self.sim_params,
            self.env
        )
        cost_bases = [10, 10, 0, 8, 9, 9, 13, 13.5]

        trades = factory.create_trade_history(*history_args)
        transactions = factory.create_txn_history(*history_args)

        # zip() below silently drops unmatched entries; make sure every
        # transaction has an expected cost basis.
        self.assertEqual(len(transactions), len(cost_bases))

        data_portal = create_data_portal_from_trade_history(
            self.env,
            self.tempdir,
            self.sim_params,
            {1: trades})

        pt = perf.PositionTracker(self.env.asset_finder, data_portal,
                                  self.sim_params.data_frequency)
        # Arguments are (starting_cash, asset_finder, data_frequency), as in
        # the other PerformancePeriod constructions in this module.
        pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
                                    self.sim_params.data_frequency)

        for txn, cb in zip(transactions, cost_bases):
            pt.execute_transaction(txn)
            pp.handle_execution(txn)
            self.assertEqual(pt.positions[1].cost_basis, cb)

        self.assertEqual(pt.positions[1].cost_basis, cost_bases[-1])
2055
2056
2057
class TestPosition(unittest.TestCase):
    """Tests for ``perf.Position``."""

    def setUp(self):
        # No per-test fixtures are needed.
        pass

    def test_serialization(self):
        # A Position must round-trip through the persistent-id pickler
        # with every attribute intact.
        sale_dt = pd.Timestamp("1984/03/06 3:00PM")
        original = perf.Position(
            10,
            amount=np.float64(120.0),
            last_sale_date=sale_dt,
            last_sale_price=3.4,
        )

        pickled = dumps_with_persistent_ids(original)
        restored = loads_with_persistent_ids(pickled, env=None)

        nt.assert_dict_equal(restored.__dict__, original.__dict__)
2070
2071
2072
class TestPositionTracker(unittest.TestCase):
    """Tests for ``perf.PositionTracker`` statistics and serialization."""

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        # sids 1 and 2 are equities; sids 3 and 4 are futures with a
        # contract multiplier of 1000, which (per the assertions in
        # test_position_values_and_exposures) scales their exposure but
        # does not count toward position value.
        futures_metadata = {3: {'contract_multiplier': 1000},
                            4: {'contract_multiplier': 1000}}
        cls.env.write_data(equities_identifiers=[1, 2],
                           futures_data=futures_metadata)

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        self.tempdir = TempDirectory()

    def tearDown(self):
        self.tempdir.cleanup()

    def test_empty_positions(self):
        """
        make sure all the empty position stats return a numeric 0

        Originally this bug was due to np.dot([], []) returning
        np.bool_(False)
        """
        sim_params = factory.create_simulation_parameters(
            num_days=4, env=self.env
        )
        trades = factory.create_trade_history(
            1,
            [10, 10, 10, 11],
            [100, 100, 100, 100],
            oneday,
            sim_params,
            env=self.env
        )

        data_portal = create_data_portal_from_trade_history(
            self.env,
            self.tempdir,
            sim_params,
            {1: trades})

        pt = perf.PositionTracker(self.env.asset_finder, data_portal,
                                  sim_params.data_frequency)
        pos_stats = pt.stats()

        stats = [
            'net_value',
            'net_exposure',
            'gross_value',
            'gross_exposure',
            'short_value',
            'short_exposure',
            'shorts_count',
            'long_value',
            'long_exposure',
            'longs_count',
        ]
        for name in stats:
            val = getattr(pos_stats, name)
            # assertEqual, not the deprecated assertEquals alias (removed
            # in Python 3.12).
            self.assertEqual(val, 0)
            # 0 == False, so also reject boolean results explicitly.
            self.assertNotIsInstance(val, (bool, np.bool_))

    def test_position_values_and_exposures(self):
        pt = perf.PositionTracker(self.env.asset_finder, None, None)
        dt = pd.Timestamp("1984/03/06 3:00PM")
        # Equities: sid 1 long 10, sid 2 short 20.
        # Futures (multiplier 1000): sid 3 long 30, sid 4 short 40.
        pos1 = perf.Position(1, amount=np.float64(10.0),
                             last_sale_date=dt, last_sale_price=10)
        pos2 = perf.Position(2, amount=np.float64(-20.0),
                             last_sale_date=dt, last_sale_price=10)
        pos3 = perf.Position(3, amount=np.float64(30.0),
                             last_sale_date=dt, last_sale_price=10)
        pos4 = perf.Position(4, amount=np.float64(-40.0),
                             last_sale_date=dt, last_sale_price=10)
        pt.update_positions({1: pos1, 2: pos2, 3: pos3, 4: pos4})

        # Test long-only methods

        pos_stats = pt.stats()
        self.assertEqual(100, pos_stats.long_value)
        self.assertEqual(100 + 300000, pos_stats.long_exposure)
        self.assertEqual(2, pos_stats.longs_count)

        # Test short-only methods
        self.assertEqual(-200, pos_stats.short_value)
        self.assertEqual(-200 - 400000, pos_stats.short_exposure)
        self.assertEqual(2, pos_stats.shorts_count)

        # Test gross and net values
        self.assertEqual(100 + 200, pos_stats.gross_value)
        self.assertEqual(100 - 200, pos_stats.net_value)

        # Test gross and net exposures
        self.assertEqual(100 + 200 + 300000 + 400000, pos_stats.gross_exposure)
        self.assertEqual(100 - 200 + 300000 - 400000, pos_stats.net_exposure)

    def test_serialization(self):
        pt = perf.PositionTracker(self.env.asset_finder, None, None)
        dt = pd.Timestamp("1984/03/06 3:00PM")
        pos1 = perf.Position(1, amount=np.float64(120.0),
                             last_sale_date=dt, last_sale_price=3.4)
        pos3 = perf.Position(3, amount=np.float64(100.0),
                             last_sale_date=dt, last_sale_price=3.4)

        pt.update_positions({1: pos1, 3: pos3})
        p_string = dumps_with_persistent_ids(pt)
        test = loads_with_persistent_ids(p_string, env=self.env)
        # Same set of sids, and every position's attributes preserved.
        nt.assert_count_equal(test.positions.keys(), pt.positions.keys())
        for sid in pt.positions:
            nt.assert_dict_equal(test.positions[sid].__dict__,
                                 pt.positions[sid].__dict__)
2186
2187
2188
class TestPerformancePeriod(unittest.TestCase):
2189
2190
    def test_serialization(self):
        """A PerformancePeriod must round-trip through the persistent-id
        pickler.

        The restored object must have exactly the same attribute names.
        Every attribute except '_account_store' and '_portfolio_store' must
        also compare equal.
        """
        env = TradingEnvironment()
        period = perf.PerformancePeriod(100, env.asset_finder, 'minute')

        restored = loads_with_persistent_ids(
            dumps_with_persistent_ids(period), env=env)

        expected = period.__dict__.copy()

        nt.assert_count_equal(restored.__dict__.keys(), expected.keys())

        # list.remove raises ValueError if either excluded key is absent.
        compared_keys = list(expected)
        for excluded in ('_account_store', '_portfolio_store'):
            compared_keys.remove(excluded)

        for key in compared_keys:
            nt.assert_equal(restored.__dict__[key], expected[key])
2207