Completed
Pull Request — master (#858)
by Eddie
01:43
created

tests.calculate_results()   C

Complexity

Conditions 7

Size

Total Lines 73

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 7
dl 0
loc 73
rs 5.5062

How to fix   Long Method   

Long Method

Small methods make code easier to understand, especially when combined with a good name. Moreover, a small method is usually much easier to name well.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as the starting point for that method's name.

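Commonly applied refactorings include Extract Method; when many parameters or temporary variables get in the way, also consider Replace Temp with Query, Introduce Parameter Object, or Replace Method with Method Object.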

1
#
2
# Copyright 2013 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
from __future__ import division
17
18
from datetime import (
19
    datetime,
20
    timedelta,
21
)
22
import logging
23
24
from testfixtures import TempDirectory
25
import unittest
26
import nose.tools as nt
27
import pytz
28
29
import pandas as pd
30
import numpy as np
31
from six.moves import range, zip
32
33
from zipline.data.us_equity_pricing import (
34
    SQLiteAdjustmentWriter,
35
    SQLiteAdjustmentReader,
36
)
37
import zipline.utils.factory as factory
38
import zipline.finance.performance as perf
39
from zipline.finance.transaction import create_transaction
40
import zipline.utils.math_utils as zp_math
41
42
from zipline.finance.blotter import Order
43
from zipline.finance.commission import PerShare, PerTrade, PerDollar
44
from zipline.finance.trading import TradingEnvironment
45
from zipline.pipeline.loaders.synthetic import NullAdjustmentReader
46
from zipline.utils.factory import create_simulation_parameters
47
from zipline.utils.serialization_utils import (
48
    loads_with_persistent_ids, dumps_with_persistent_ids
49
)
50
import zipline.protocol as zp
51
from zipline.protocol import Event
52
from zipline.utils.test_utils import create_data_portal_from_trade_history
53
54
logger = logging.getLogger('Test Perf Tracking')

# Shorthand time offsets used by the tests below.
onesec = timedelta(seconds=1)
oneday = timedelta(days=1)
# 6.5 hours -- presumably one regular US equity session; confirm if reused.
tradingday = timedelta(minutes=6 * 60 + 30)

# nose.tools changed name in python 3: when only the old
# ``assert_items_equal`` spelling exists, alias it as ``assert_count_equal``
# so tests can call the new name unconditionally.
if not hasattr(nt, 'assert_count_equal'):
    nt.assert_count_equal = nt.assert_items_equal
63
64
65
def check_perf_period(pp,
                      pt,
                      gross_leverage,
                      net_leverage,
                      long_exposure,
                      longs_count,
                      short_exposure,
                      shorts_count,
                      data_portal):
    """
    Assert that a performance period reports the expected exposure figures.

    ``pt`` (presumably a position tracker) supplies ``stats()`` and
    ``positions``; ``pp`` turns them into period stats and serializes the
    result with ``to_dict``. Each expected value is compared with the
    matching key using ``np.testing.assert_allclose(..., rtol=1e-3)``.
    The first mismatch raises ``AssertionError``.
    """
    position_stats = pt.stats()
    period_stats = pp.stats(pt.positions, position_stats, data_portal)
    reported = pp.to_dict(position_stats, period_stats, pt)

    expected_by_key = (
        ('gross_leverage', gross_leverage),
        ('net_leverage', net_leverage),
        ('long_exposure', long_exposure),
        ('longs_count', longs_count),
        ('short_exposure', short_exposure),
        ('shorts_count', shorts_count),
    )
    for key, expected in expected_by_key:
        np.testing.assert_allclose(expected, reported[key], rtol=1e-3)
90
91
92
def check_account(account,
                  settled_cash,
                  equity_with_loan,
                  total_positions_value,
                  regt_equity,
                  available_funds,
                  excess_liquidity,
                  cushion,
                  leverage,
                  net_leverage,
                  net_liquidation):
    """
    Assert that an account mapping matches the expected values.

    Each expected value is compared against the ``account`` entry of the
    same name using ``np.testing.assert_allclose`` with ``rtol=1e-3``.
    Fields are checked in parameter order. The first mismatch raises
    ``AssertionError``, and a missing key raises ``KeyError``.
    """
    expected_fields = [
        ('settled_cash', settled_cash),
        ('equity_with_loan', equity_with_loan),
        ('total_positions_value', total_positions_value),
        ('regt_equity', regt_equity),
        ('available_funds', available_funds),
        ('excess_liquidity', excess_liquidity),
        ('cushion', cushion),
        ('leverage', leverage),
        ('net_leverage', net_leverage),
        ('net_liquidation', net_liquidation),
    ]
    for field, expected in expected_fields:
        np.testing.assert_allclose(expected, account[field], rtol=1e-3)
125
126
127
def create_txn(sid, dt, price, amount):
    """
    Build a fake transaction of ``amount`` shares of ``sid`` at ``price``.

    The transaction is backed by a placeholder ``Order`` (``id=None``) dated
    ``dt``. Tests feed it to the tracker before that date's trade event.
    """
    return create_transaction(
        sid, dt, Order(dt, sid, amount, id=None), price, amount
    )
134
135
136
def benchmark_events_in_range(sim_params, env):
    """
    Build benchmark ``Event`` objects for the simulation window.

    Parameters
    ----------
    sim_params : simulation parameters
        Only ``period_start`` and ``period_end`` are read. They are compared
        by calendar date (``.date()``), and both bounds are inclusive.
    env : trading environment
        ``env.benchmark_returns`` is iterated as ``(dt, return)`` pairs
        (presumably a pandas Series indexed by datetime -- confirm).

    Returns
    -------
    list of Event
        One benchmark event per in-range return, in the iteration order of
        ``env.benchmark_returns``.
    """
    # The window bounds are loop-invariant, so compute them once.
    first_day = sim_params.period_start.date()
    last_day = sim_params.period_end.date()
    return [
        Event({'dt': dt,
               'returns': ret,
               'type': zp.DATASOURCE_TYPE.BENCHMARK,
               # We explicitly rely on the behavior that benchmarks sort before
               # any other events.
               'source_id': '1Abenchmarks'})
        # ``items()`` instead of ``iteritems()``: the latter was removed in
        # pandas 2.0, while ``items()`` works on both old and new versions.
        for dt, ret in env.benchmark_returns.items()
        if first_day <= dt.date() <= last_day
    ]
148
149
150
def calculate_results(sim_params,
                      env,
                      tempdir,
                      benchmark_events,
                      trade_events,
                      adjustment_reader,
                      splits=None,
                      txns=None,
                      commissions=None):
    """
    Run the given events through a stripped-down version of the loop in
    AlgorithmSimulator.transform.

    IMPORTANT NOTE FOR TEST WRITERS/READERS:

    This loop has some wonky logic for the order of event processing for
    datasource types.  This exists mostly to accommodate legacy tests that
    were making assumptions about how events would be sorted.

    For each day in ``sim_params.trading_days`` the helper, in this order:

        - processes every transaction in ``txns`` whose ``dt`` equals the day,
        - processes every commission listed under ``commissions[day]``,
        - passes ``splits[day]`` (if present) to ``handle_splits``, i.e.
          splits for a date are applied AFTER that date's other events,
        - emits the day's market-close message and attaches the account
          snapshot under the ``'account'`` key.

    Dividends are not passed to this helper directly. They reach the tracker
    only through ``adjustment_reader``, which is attached to the data portal.

    Tests that use this helper should not be considered useful guarantees of
    the behavior of AlgorithmSimulator on a stream containing the same events
    unless the subgroups have been explicitly re-sorted in this way.

    Parameters
    ----------
    sim_params, env, tempdir
        Forwarded to ``create_data_portal_from_trade_history`` (and
        ``sim_params``/``env`` to the ``PerformanceTracker``).
    benchmark_events
        Not used by this helper; kept for call-site compatibility.
    trade_events
        Forwarded unchanged to ``create_data_portal_from_trade_history``.
        Callers in this file pass a dict mapping sid to trade events.
    adjustment_reader
        Attached to the data portal. A falsy value is replaced by
        ``NullAdjustmentReader()``.
    splits : dict, optional
        Maps a date to the value handed to ``handle_splits`` (callers here
        pass a list of ``(sid, ratio)`` pairs).
    txns : iterable, optional
        Transactions. Each is processed on the day equal to its ``dt``.
    commissions : dict, optional
        Maps a date to an iterable of commission events.

    Returns
    -------
    list of dict
        One market-close message per trading day, each with an extra
        ``'account'`` entry.
    """
    txns = txns or []
    splits = splits or {}
    commissions = commissions or {}

    adjustment_reader = adjustment_reader or NullAdjustmentReader()

    data_portal = create_data_portal_from_trade_history(
        env,
        tempdir,
        sim_params,
        trade_events,
    )
    data_portal._adjustment_reader = adjustment_reader

    perf_tracker = perf.PerformanceTracker(sim_params, env, data_portal)

    results = []

    for date in sim_params.trading_days:
        # Process txns for this date.
        for txn in txns:
            if txn.dt == date:
                perf_tracker.process_transaction(txn)

        # Look the per-date inputs up explicitly. The previous
        # ``try/except KeyError`` wrapped the tracker calls too, so a KeyError
        # raised *inside* the tracker was silently swallowed.
        for comm in commissions.get(date, ()):
            perf_tracker.process_commission(comm)

        if date in splits:
            perf_tracker.handle_splits(splits[date])

        msg = perf_tracker.handle_market_close_daily(date)
        msg['account'] = perf_tracker.get_account(date)
        results.append(msg)

    return results
223
224
225
def check_perf_tracker_serialization(perf_tracker):
    """
    Round-trip ``perf_tracker`` through the persistent-id serializer and
    assert that its scalar attributes survive unchanged.

    The tracker's own ``env`` is handed to the loader. Each compared
    attribute's name is used as the assertion failure message.
    """
    restored = loads_with_persistent_ids(
        dumps_with_persistent_ids(perf_tracker),
        env=perf_tracker.env,
    )
    for attr in ('emission_rate', 'txn_count', 'market_open', 'last_close',
                 'period_start', 'day_count', 'capital_base', 'market_close',
                 'saved_dt', 'period_end', 'total_days'):
        nt.assert_equal(getattr(restored, attr), getattr(perf_tracker, attr),
                        attr)
245
246
247
def setup_env_data(env, sim_params, sids):
    """
    Register every sid in ``sids`` as an equity in ``env``.

    Each equity's start and end dates are the first and last entries of
    ``sim_params.trading_days``. The result is written with a single
    ``env.write_data(equities_data=...)`` call.
    """
    equities = {
        sid: {
            "start_date": sim_params.trading_days[0],
            "end_date": sim_params.trading_days[-1],
        }
        for sid in sids
    }
    env.write_data(equities_data=equities)
256
257
258
class TestSplitPerformance(unittest.TestCase):
    """Performance tracking for a long position that goes through a split."""

    @classmethod
    def setUpClass(cls):
        env = TradingEnvironment()
        sim_params = create_simulation_parameters(num_days=2,
                                                  capital_base=10e3)
        setup_env_data(env, sim_params, [1])

        cls.env = env
        cls.sim_params = sim_params
        cls.benchmark_events = benchmark_events_in_range(sim_params, env)
        cls.tempdir = TempDirectory()

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def test_split_long_position(self):
        trades = factory.create_trade_history(
            1,
            # TODO: Should we provide adjusted prices in the tests, or provide
            # raw prices and adjust via DataPortal?
            [20, 60],
            [100, 100],
            oneday,
            self.sim_params,
            env=self.env
        )
        first_trade, second_trade = trades[0], trades[1]

        # Go long 100 shares of sid 1 at $20 apiece: a $2000 position.
        txns = [create_txn(first_trade.sid, first_trade.dt, 20, 100)]

        # A split of sid 1 with ratio 3 at the start of the second day.
        splits = {second_trade.dt: [(1, 3)]}

        results = calculate_results(
            self.sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            {1: trades},
            NullAdjustmentReader(),
            txns=txns,
            splits=splits,
        )

        # One result per simulated day.
        self.assertEqual(2, len(results))

        last_day = results[1]
        positions = last_day['daily_perf']['positions']
        self.assertEqual(1, len(positions))

        # After the split we expect 33 whole shares at $60 apiece.
        position = positions[0]
        for expected, key in ((1, 'sid'),
                              (33, 'amount'),
                              (60, 'cost_basis'),
                              (60, 'last_sale_price')):
            self.assertEqual(expected, position[key])

        # Starting from $10000, spending $2000 and getting $20 back for the
        # fractional share leaves about $8020 in cash. A split ratio may be
        # stored as something like 0.3333, so some precision is lost and the
        # comparison uses a tolerance.
        ending_cash = last_day['daily_perf']['ending_cash']
        self.assertTrue(
            zp_math.tolerant_equals(8020, ending_cash, 1),
            "ending_cash was {0}".format(ending_cash))

        # The account snapshot must be updated too. The portfolio is long
        # only and partially invested, so gross and net leverage agree.
        account = last_day['account']
        unlimited = float('inf')
        for key in ('day_trades_remaining', 'regt_margin', 'buying_power'):
            self.assertEqual(unlimited, account[key])
        for key in ('maintenance_margin_requirement',
                    'initial_margin_requirement',
                    'accrued_interest'):
            self.assertEqual(0, account[key])
        for key, expected in (('leverage', 0.198),
                              ('net_leverage', 0.198),
                              ('regt_equity', 8020),
                              ('available_funds', 8020),
                              ('equity_with_loan', 10000),
                              ('excess_liquidity', 8020),
                              ('settled_cash', 8020),
                              ('net_liquidation', 10000),
                              ('cushion', 0.802),
                              ('total_positions_value', 1980)):
            np.testing.assert_allclose(expected, account[key], rtol=1e-3)

        # Prices aren't changing, so pnl and returns must be 0.0 every day in
        # both the daily and the cumulative views.
        for day, result in enumerate(results):
            for perf_kind in ('daily_perf', 'cumulative_perf'):
                perf_result = result[perf_kind]
                self.assertEqual(0.0, perf_result['pnl'],
                                 "day %s %s pnl %s instead of 0.0" %
                                 (day, perf_kind, perf_result['pnl']))
                self.assertEqual(0.0, perf_result['returns'],
                                 "day %s %s returns %s instead of 0.0" %
                                 (day, perf_kind, perf_result['returns']))
367
368
369
class TestCommissionEvents(unittest.TestCase):
    """Commission models and commission handling in PerformanceTracker."""

    @classmethod
    def setUpClass(cls):
        env = TradingEnvironment()
        sim_params = create_simulation_parameters(num_days=5,
                                                  capital_base=10e3)
        setup_env_data(env, sim_params, [0, 1, 133])

        cls.env = env
        cls.sim_params = sim_params
        cls.benchmark_events = benchmark_events_in_range(sim_params, env)
        cls.tempdir = TempDirectory()

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def _trade_history(self):
        """Five days of sid-1 trades at $10 with volume 100."""
        return factory.create_trade_history(
            1,
            [10, 10, 10, 10, 10],
            [100, 100, 100, 100, 100],
            oneday,
            self.sim_params,
            env=self.env
        )

    @staticmethod
    def _cash_adjustment_at(dt):
        """A $300 commission on sid 1 at ``dt``, keyed by ``dt``."""
        return {dt: [factory.create_commission(1, 300.0, dt)]}

    def _run(self, trade_events, commissions, txns=None):
        """Run sid-1 trades plus the given txns/commissions through the
        performance tracker."""
        return calculate_results(self.sim_params,
                                 self.env,
                                 self.tempdir,
                                 self.benchmark_events,
                                 {1: trade_events},
                                 NullAdjustmentReader(),
                                 txns=txns,
                                 commissions=commissions)

    def test_commission_event(self):
        trade_events = self._trade_history()

        # Three fills of 50, 100 and 150 shares @ $20. Expected commissions:
        #   PerShare:  1.00 + 1.00 + 1.50 =  $3.50
        #   PerTrade:  5.00 + 5.00 + 5.00 = $15.00
        #   PerDollar: 1.50 + 3.00 + 4.50 =  $9.00
        first_trade = trade_events[0]
        fills = [create_txn(first_trade.sid, first_trade.dt, 20, shares)
                 for shares in [50, 100, 150]]

        expectations = [
            (PerShare(cost=0.01, min_trade_cost=1.00), 3.50),
            (PerTrade(cost=5.00), 15.0),
            (PerDollar(cost=0.0015), 9.0),
        ]
        for model, expected in expectations:
            # Element 1 of calculate()'s result is summed as the commission
            # charged for the fill.
            total_commission = sum(model.calculate(fill)[1]
                                   for fill in fills)
            self.assertEqual(total_commission, expected)

        # Now check that PerformanceTracker handles commission events: a
        # $300 commission plus a one-share purchase at $20.
        cash_adj_dt = first_trade.dt
        results = self._run(trade_events,
                            self._cash_adjustment_at(cash_adj_dt),
                            txns=[create_txn(1, cash_adj_dt, 20, 1)])

        final = results[-1]
        # $20 for the share plus $300 of commission: 10000 - 320.
        self.assertEqual(final['cumulative_perf']['ending_cash'],
                         9680, "Should have lost 320 from cash pool.")
        # The commission is reflected in the position's cost basis.
        self.assertEqual(final['daily_perf']['positions']
                         [0]['cost_basis'], 320.0)

        # The account snapshot is read from the second day's result.
        account = results[1]['account']
        unlimited = float('inf')
        for key in ('day_trades_remaining', 'regt_margin', 'buying_power'):
            self.assertEqual(unlimited, account[key])
        for key in ('maintenance_margin_requirement',
                    'initial_margin_requirement',
                    'accrued_interest'):
            self.assertEqual(0, account[key])
        np.testing.assert_allclose(0.001, account['leverage'], rtol=1e-3,
                                   atol=1e-4)
        for key, expected in (('regt_equity', 9680),
                              ('available_funds', 9680),
                              ('equity_with_loan', 9690),
                              ('excess_liquidity', 9680),
                              ('settled_cash', 9680),
                              ('net_liquidation', 9690),
                              ('cushion', 0.999),
                              ('total_positions_value', 10)):
            np.testing.assert_allclose(expected, account[key], rtol=1e-3)

    def test_commission_zero_position(self):
        """
        A commission on a position that has been closed out (amount zero)
        must not cause a division-by-zero error.
        """
        events = self._trade_history()

        # Buy and then sell one share, so sid 1 is flat by events[3].
        txns = [
            create_txn(events[0].sid, events[0].dt, 20, 1),
            create_txn(events[1].sid, events[1].dt, 20, -1),
        ]
        results = self._run(events, self._cash_adjustment_at(events[3].dt),
                            txns=txns)
        # The round trip nets out, so only the $300 commission is lost.
        self.assertEqual(results[-1]['cumulative_perf']['ending_cash'],
                         9700)

    def test_commission_no_position(self):
        """
        A commission for a sid that was never held must not raise
        position-not-found or sid-not-found errors.
        """
        events = self._trade_history()
        results = self._run(events, self._cash_adjustment_at(events[3].dt))
        # Only the $300 commission leaves the cash pool.
        self.assertEqual(results[-1]['cumulative_perf']['ending_cash'],
                         9700)
536
537
538
class MockDailyBarSpotReader(object):
    """Stand-in daily bar reader that quotes a flat 100.0 for everything."""

    FLAT_PRICE = 100.0

    def spot_price(self, sid, day, colname):
        """Return 100.0 regardless of ``sid``, ``day`` or ``colname``."""
        return self.FLAT_PRICE
542
543
544
class TestDividendPerformance(unittest.TestCase):
545
546
    @classmethod
    def setUpClass(cls):
        """Build a six-day environment with sids 1 and 2 registered."""
        env = TradingEnvironment()
        sim_params = create_simulation_parameters(num_days=6,
                                                  capital_base=10e3)
        setup_env_data(env, sim_params, [1, 2])

        cls.env = env
        cls.sim_params = sim_params
        cls.benchmark_events = benchmark_events_in_range(sim_params, env)
556
557
    @classmethod
    def tearDownClass(cls):
        """Drop the class-level trading environment built in setUpClass."""
        delattr(cls, 'env')
560
561
    def setUp(self):
        """Give each test its own scratch directory (used for the
        adjustments database and by calculate_results)."""
        self.tempdir = TempDirectory()
563
564
    def tearDown(self):
        """Remove the per-test scratch directory created in setUp."""
        scratch = self.tempdir
        scratch.cleanup()
566
567
    def test_market_hours_calculations(self):
        # DST in US/Eastern began on Sunday March 14, 2010. Stepping one day
        # forward from Friday March 12 at 14:31 UTC must land on an hour of
        # 13 UTC.
        friday_before_dst = datetime(2010, 3, 12, 14, 31, tzinfo=pytz.utc)
        next_dt = factory.get_next_trading_dt(friday_before_dst,
                                              timedelta(days=1),
                                              self.env)
        self.assertEqual(next_dt.hour, 13)
576
577
    def test_long_position_receives_dividend(self):
        # Six days of sid-1 trades at a flat $10.
        trades = factory.create_trade_history(
            1,
            [10] * 6,
            [100] * 6,
            oneday,
            self.sim_params,
            env=self.env
        )

        db_path = self.tempdir.getpath('adjustments.sqlite')
        writer = SQLiteAdjustmentWriter(db_path, self.env.trading_days,
                                        MockDailyBarSpotReader())

        # The same empty frame serves as both splits and mergers.
        no_splits = pd.DataFrame(
            {
                # Hackery to make the dtypes correct on an empty frame.
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        # One $10.00 cash dividend on sid 1: declared on day 0, ex/record
        # date on day 1, paid on day 2.
        dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'amount': np.array([10.00], dtype=np.float64),
            'declared_date': np.array([trades[0].dt], dtype='datetime64[ns]'),
            'ex_date': np.array([trades[1].dt], dtype='datetime64[ns]'),
            'record_date': np.array([trades[1].dt], dtype='datetime64[ns]'),
            'pay_date': np.array([trades[2].dt], dtype='datetime64[ns]'),
        })
        writer.write(no_splits, no_splits, dividends)

        # Buy 100 shares at $10 on day 0, i.e. before the ex_date.
        first = trades[0]
        results = calculate_results(
            self.sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            {1: trades},
            SQLiteAdjustmentReader(db_path),
            txns=[create_txn(first.sid, first.dt, 10.0, 100)],
        )

        self.assertEqual(len(results), 6)

        def column(view, field):
            return [result[view][field] for result in results]

        # The $1000 dividend paid on day 2 is a 10% return on $10000.
        self.assertEqual(column('cumulative_perf', 'returns'),
                         [0.0, 0.0, 0.1, 0.1, 0.1, 0.1])
        self.assertEqual(column('daily_perf', 'returns'),
                         [0.0, 0.0, 0.10, 0.0, 0.0, 0.0])
        # -$1000 for the purchase on day 0, +$1000 dividend on day 2.
        self.assertEqual(column('daily_perf', 'capital_used'),
                         [-1000, 0, 1000, 0, 0, 0])
        self.assertEqual(column('cumulative_perf', 'capital_used'),
                         [-1000, -1000, 0, 0, 0, 0])
        self.assertEqual(column('cumulative_perf', 'ending_cash'),
                         [9000, 9000, 10000, 10000, 10000, 10000])
641
642
    def test_long_position_receives_stock_dividend(self):
        # Six days of flat $10 trades for both sid 1 and sid 2.
        trades = {
            sid: factory.create_trade_history(
                sid,
                [10] * 6,
                [100] * 6,
                oneday,
                self.sim_params,
                env=self.env
            )
            for sid in (1, 2)
        }

        db_path = self.tempdir.getpath('adjustments.sqlite')
        writer = SQLiteAdjustmentWriter(db_path, self.env.trading_days,
                                        MockDailyBarSpotReader())

        # The same empty frame serves as both splits and mergers.
        no_splits = pd.DataFrame(
            {
                # Hackery to make the dtypes correct on an empty frame.
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        # No cash dividends: only the stock dividend below is written.
        no_cash_dividends = pd.DataFrame({
            'sid': np.array([], dtype=np.uint32),
            'amount': np.array([], dtype=np.float64),
            'declared_date': np.array([], dtype='datetime64[ns]'),
            'ex_date': np.array([], dtype='datetime64[ns]'),
            'pay_date': np.array([], dtype='datetime64[ns]'),
            'record_date': np.array([], dtype='datetime64[ns]'),
        })
        # sid 1 pays a stock dividend in sid 2 with ratio 2: declared on
        # day 0, ex/record date on day 1, paid on day 2.
        sid_1_trades = trades[1]
        stock_dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'payment_sid': np.array([2], dtype=np.uint32),
            'ratio': np.array([2], dtype=np.float64),
            'declared_date': np.array([sid_1_trades[0].dt],
                                      dtype='datetime64[ns]'),
            'ex_date': np.array([sid_1_trades[1].dt],
                                dtype='datetime64[ns]'),
            'record_date': np.array([sid_1_trades[1].dt],
                                    dtype='datetime64[ns]'),
            'pay_date': np.array([sid_1_trades[2].dt],
                                 dtype='datetime64[ns]'),
        })
        writer.write(no_splits, no_splits, no_cash_dividends,
                     stock_dividends)
        adjustment_reader = SQLiteAdjustmentReader(db_path)

        # Buy 100 shares of sid 1 at $10 on day 0, before the ex_date.
        first = sid_1_trades[0]
        txns = [create_txn(first.sid, first.dt, 10.0, 100)]

        results = calculate_results(
            self.sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            trades,
            adjustment_reader,
            txns=txns,
        )

        self.assertEqual(len(results), 6)

        def column(view, field):
            return [result[view][field] for result in results]

        # Presumably 100 shares x ratio 2 = 200 sid-2 shares worth $2000 at
        # $10, i.e. the 20% gain on $10000 asserted here.
        self.assertEqual(column('cumulative_perf', 'returns'),
                         [0.0, 0.0, 0.2, 0.2, 0.2, 0.2])
        self.assertEqual(column('daily_perf', 'returns'),
                         [0.0, 0.0, 0.2, 0.0, 0.0, 0.0])
        # A stock dividend moves no cash: only the day-0 purchase shows up.
        self.assertEqual(column('daily_perf', 'capital_used'),
                         [-1000, 0, 0, 0, 0, 0])
        self.assertEqual(column('cumulative_perf', 'capital_used'),
                         [-1000] * 6)
        self.assertEqual(column('cumulative_perf', 'ending_cash'),
                         [9000] * 6)
718
719
    def test_long_position_purchased_on_ex_date_receives_no_dividend(self):
        # A buy that fills on the ex_date itself must not earn the dividend:
        # returns stay flat and the only cash flow is the purchase.
        events = factory.create_trade_history(
            1,
            [10] * 6,
            [100] * 6,
            oneday,
            self.sim_params,
            env=self.env
        )

        dbpath = self.tempdir.getpath('adjustments.sqlite')
        writer = SQLiteAdjustmentWriter(dbpath, self.env.trading_days,
                                        MockDailyBarSpotReader())

        # No splits or mergers. The explicit empty arrays pin the column
        # dtypes, which an empty frame would otherwise lose.
        empty_adjustments = pd.DataFrame(
            {
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        declared_dt, ex_dt, pay_dt = events[0].dt, events[1].dt, events[2].dt
        dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'amount': np.array([10.00], dtype=np.float64),
            'declared_date': np.array([declared_dt], dtype='datetime64[ns]'),
            'ex_date': np.array([ex_dt], dtype='datetime64[ns]'),
            'record_date': np.array([ex_dt], dtype='datetime64[ns]'),
            'pay_date': np.array([pay_dt], dtype='datetime64[ns]'),
        })
        writer.write(empty_adjustments, empty_adjustments, dividends)
        adjustment_reader = SQLiteAdjustmentReader(dbpath)

        # The single fill lands exactly on the ex_date.
        ex_day_buy = create_txn(events[1].sid, events[1].dt, 10.0, 100)

        results = calculate_results(
            self.sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            {1: events},
            adjustment_reader,
            txns=[ex_day_buy],
        )

        def series(period, field):
            # One value per simulated day for the given perf period/field.
            return [event[period][field] for event in results]

        self.assertEqual(len(results), 6)
        self.assertEqual(series('cumulative_perf', 'returns'), [0] * 6)
        self.assertEqual(series('daily_perf', 'returns'), [0] * 6)
        self.assertEqual(series('daily_perf', 'capital_used'),
                         [0, -1000, 0, 0, 0, 0])
        self.assertEqual(series('cumulative_perf', 'capital_used'),
                         [0] + [-1000] * 5)
780
781
    def test_selling_before_dividend_payment_still_gets_paid(self):
        # A long held through the ex/record date and sold before the
        # pay_date still receives the dividend cash on the pay_date.
        prices = [10, 10, 10, 10, 10, 10]
        volumes = [100, 100, 100, 100, 100, 100]
        events = factory.create_trade_history(
            1, prices, volumes, oneday, self.sim_params, env=self.env
        )

        dbpath = self.tempdir.getpath('adjustments.sqlite')
        writer = SQLiteAdjustmentWriter(dbpath, self.env.trading_days,
                                        MockDailyBarSpotReader())

        # No splits or mergers; explicit empty arrays keep the dtypes right.
        no_adjustments = pd.DataFrame(
            {
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'amount': np.array([10.00], dtype=np.float64),
            'declared_date': np.array([events[0].dt], dtype='datetime64[ns]'),
            'ex_date': np.array([events[1].dt], dtype='datetime64[ns]'),
            'record_date': np.array([events[1].dt], dtype='datetime64[ns]'),
            'pay_date': np.array([events[3].dt], dtype='datetime64[ns]'),
        })
        writer.write(no_adjustments, no_adjustments, dividends)
        adjustment_reader = SQLiteAdjustmentReader(dbpath)

        # Buy on day 0 (before the ex_date), sell on day 2 (before pay_date).
        txns = [
            create_txn(events[0].sid, events[0].dt, 10.0, 100),
            create_txn(events[2].sid, events[2].dt, 10.0, -100),
        ]

        results = calculate_results(
            self.sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            {1: events},
            adjustment_reader,
            txns=txns,
        )

        self.assertEqual(len(results), 6)
        expectations = (
            ('cumulative_perf', 'returns', [0, 0, 0, 0.1, 0.1, 0.1]),
            ('daily_perf', 'returns', [0, 0, 0, 0.1, 0, 0]),
            ('daily_perf', 'capital_used', [-1000, 0, 1000, 1000, 0, 0]),
            ('cumulative_perf', 'capital_used',
             [-1000, -1000, 0, 1000, 1000, 1000]),
        )
        for period, field, expected in expectations:
            actual = [event[period][field] for event in results]
            self.assertEqual(actual, expected)
843
844
    def test_buy_and_sell_before_ex(self):
        # A round trip completed before the ex_date leaves nothing to be
        # paid on: returns stay at zero and the cash flows net out.
        events = factory.create_trade_history(
            1,
            [10] * 6,
            [100] * 6,
            oneday,
            self.sim_params,
            env=self.env
        )
        dbpath = self.tempdir.getpath('adjustments.sqlite')

        writer = SQLiteAdjustmentWriter(dbpath, self.env.trading_days,
                                        MockDailyBarSpotReader())
        # Empty splits/mergers frame with explicitly typed columns.
        blank = pd.DataFrame(
            {
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )

        declared, ex, pay = events[3].dt, events[4].dt, events[5].dt
        dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'amount': np.array([10.0], dtype=np.float64),
            'declared_date': np.array([declared], dtype='datetime64[ns]'),
            'ex_date': np.array([ex], dtype='datetime64[ns]'),
            'pay_date': np.array([pay], dtype='datetime64[ns]'),
            'record_date': np.array([ex], dtype='datetime64[ns]'),
        })
        writer.write(blank, blank, dividends)
        adjustment_reader = SQLiteAdjustmentReader(dbpath)

        # Open on day 1, close on day 2; the dividend goes ex on day 4.
        opening = create_txn(events[1].sid, events[1].dt, 10.0, 100)
        closing = create_txn(events[2].sid, events[2].dt, 10.0, -100)

        results = calculate_results(
            self.sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            {1: events},
            txns=[opening, closing],
            adjustment_reader=adjustment_reader,
        )

        self.assertEqual(len(results), 6)
        cumulative_returns = [r['cumulative_perf']['returns'] for r in results]
        self.assertEqual(cumulative_returns, [0] * 6)
        daily_returns = [r['daily_perf']['returns'] for r in results]
        self.assertEqual(daily_returns, [0] * 6)
        cash_flows = [r['daily_perf']['capital_used'] for r in results]
        self.assertEqual(cash_flows, [0, -1000, 1000, 0, 0, 0])
        cumulative_cash_flows = [r['cumulative_perf']['capital_used']
                                 for r in results]
        self.assertEqual(cumulative_cash_flows, [0, -1000, 0, 0, 0, 0])
905
906
    def test_ending_before_pay_date(self):
        # The dividend goes ex on the first event day and pays about 30
        # trading days after the simulation start, i.e. after the 6 simulated
        # days; returns and cash flows show only the purchase.
        # NOTE(review): the buy fills on events[1], after the ex_date
        # (events[0]), so the position is not entitled to this dividend in
        # any case; confirm this exercises the intended "ends before pay
        # date" path.
        events = factory.create_trade_history(
            1,
            [10] * 6,
            [100] * 6,
            oneday,
            self.sim_params,
            env=self.env
        )

        # Walk 30 trading days forward from the first open for the pay date.
        late_pay_date = self.sim_params.first_open
        for _ in range(30):
            late_pay_date = factory.get_next_trading_dt(
                late_pay_date, oneday, self.env)

        dbpath = self.tempdir.getpath('adjustments.sqlite')
        writer = SQLiteAdjustmentWriter(dbpath, self.env.trading_days,
                                        MockDailyBarSpotReader())

        # No splits or mergers; explicit empty arrays keep the dtypes right.
        no_adjustments = pd.DataFrame(
            {
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        first_dt = events[0].dt
        dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'amount': np.array([10.00], dtype=np.float64),
            'declared_date': np.array([first_dt], dtype='datetime64[ns]'),
            'ex_date': np.array([first_dt], dtype='datetime64[ns]'),
            'record_date': np.array([first_dt], dtype='datetime64[ns]'),
            'pay_date': np.array([late_pay_date], dtype='datetime64[ns]'),
        })
        writer.write(no_adjustments, no_adjustments, dividends)
        adjustment_reader = SQLiteAdjustmentReader(dbpath)

        purchase = create_txn(events[1].sid, events[1].dt, 10.0, 100)

        results = calculate_results(
            self.sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            {1: events},
            txns=[purchase],
            adjustment_reader=adjustment_reader,
        )

        def column(period, field):
            return [event[period][field] for event in results]

        self.assertEqual(len(results), 6)
        self.assertEqual(column('cumulative_perf', 'returns'),
                         [0, 0, 0, 0.0, 0.0, 0.0])
        self.assertEqual(column('daily_perf', 'returns'), [0] * 6)
        self.assertEqual(column('daily_perf', 'capital_used'),
                         [0, -1000, 0, 0, 0, 0])
        self.assertEqual(
            column('cumulative_perf', 'capital_used'),
            [0, -1000, -1000, -1000, -1000, -1000]
        )
973
974
    def test_short_position_pays_dividend(self):
        # A short held through the ex/record date owes the dividend: on the
        # pay date cash drops by 10.00 * 100 and returns dip by 10%.
        events = factory.create_trade_history(
            1,
            [10, 10, 10, 10, 10, 10],
            [100, 100, 100, 100, 100, 100],
            oneday,
            self.sim_params,
            env=self.env
        )

        dbpath = self.tempdir.getpath('adjustments.sqlite')
        writer = SQLiteAdjustmentWriter(dbpath, self.env.trading_days,
                                        MockDailyBarSpotReader())

        # No splits or mergers; explicit empty arrays keep the dtypes right.
        nothing = pd.DataFrame(
            {
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        ex_dt = events[2].dt
        dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'amount': np.array([10.00], dtype=np.float64),
            'declared_date': np.array([events[0].dt], dtype='datetime64[ns]'),
            'ex_date': np.array([ex_dt], dtype='datetime64[ns]'),
            'record_date': np.array([ex_dt], dtype='datetime64[ns]'),
            'pay_date': np.array([events[3].dt], dtype='datetime64[ns]'),
        })
        writer.write(nothing, nothing, dividends)
        adjustment_reader = SQLiteAdjustmentReader(dbpath)

        # Sell short 100 shares on day 1, ahead of the ex_date on day 2.
        short_sale = create_txn(events[1].sid, events[1].dt, 10.0, -100)

        results = calculate_results(
            self.sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            {1: events},
            adjustment_reader,
            txns=[short_sale],
        )

        self.assertEqual(len(results), 6)
        checks = [
            ('cumulative_perf', 'returns',
             [0.0, 0.0, 0.0, -0.1, -0.1, -0.1]),
            ('daily_perf', 'returns', [0.0, 0.0, 0.0, -0.1, 0.0, 0.0]),
            ('daily_perf', 'capital_used', [0, 1000, 0, -1000, 0, 0]),
            ('cumulative_perf', 'capital_used', [0, 1000, 1000, 0, 0, 0]),
        ]
        for period, field, expected in checks:
            self.assertEqual(
                [event[period][field] for event in results], expected)
1033
1034
    def test_no_position_receives_no_dividend(self):
        # With no transactions at all, a declared dividend must leave
        # returns and cash flows untouched.
        events = factory.create_trade_history(
            1,
            [10] * 6,
            [100] * 6,
            oneday,
            self.sim_params,
            env=self.env
        )

        dbpath = self.tempdir.getpath('adjustments.sqlite')
        writer = SQLiteAdjustmentWriter(dbpath, self.env.trading_days,
                                        MockDailyBarSpotReader())

        # No splits or mergers; explicit empty arrays keep the dtypes right.
        empty_frame = pd.DataFrame(
            {
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'amount': np.array([10.00], dtype=np.float64),
            'declared_date': np.array([events[0].dt], dtype='datetime64[ns]'),
            'ex_date': np.array([events[1].dt], dtype='datetime64[ns]'),
            'pay_date': np.array([events[2].dt], dtype='datetime64[ns]'),
            'record_date': np.array([events[2].dt], dtype='datetime64[ns]'),
        })
        writer.write(empty_frame, empty_frame, dividends)

        results = calculate_results(
            self.sim_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            {1: events},
            SQLiteAdjustmentReader(dbpath),
        )

        self.assertEqual(len(results), 6)
        for period, field, expected in (
                ('cumulative_perf', 'returns', [0.0] * 6),
                ('daily_perf', 'returns', [0.0] * 6),
                ('daily_perf', 'capital_used', [0] * 6),
                ('cumulative_perf', 'capital_used', [0] * 6)):
            self.assertEqual([event[period][field] for event in results],
                             expected)
1090
1091
    def test_no_dividend_at_simulation_end(self):
        # The pay date is the trading day after the last event, and the
        # simulation is cut off at that last event, so the dividend never
        # reaches cash even though the position is held through the ex_date.
        # NOTE(review): record_date here is events[0], before both the
        # declared and ex dates; the other dividend tests in this file use a
        # record_date on or after the ex_date. Confirm this is intentional.
        events = factory.create_trade_history(
            1,
            [10] * 5,
            [100] * 5,
            oneday,
            self.sim_params,
            env=self.env
        )

        dbpath = self.tempdir.getpath('adjustments.sqlite')
        writer = SQLiteAdjustmentWriter(dbpath, self.env.trading_days,
                                        MockDailyBarSpotReader())

        # No splits or mergers; explicit empty arrays keep the dtypes right.
        no_adjustments = pd.DataFrame(
            {
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        after_last_event = self.env.next_trading_day(events[-1].dt)
        dividends = pd.DataFrame({
            'sid': np.array([1], dtype=np.uint32),
            'amount': np.array([10.00], dtype=np.float64),
            'declared_date': np.array([events[-3].dt], dtype='datetime64[ns]'),
            'ex_date': np.array([events[-2].dt], dtype='datetime64[ns]'),
            'record_date': np.array([events[0].dt], dtype='datetime64[ns]'),
            'pay_date': np.array([after_last_event], dtype='datetime64[ns]'),
        })
        writer.write(no_adjustments, no_adjustments, dividends)
        adjustment_reader = SQLiteAdjustmentReader(dbpath)

        # Simulation parameters whose final day is the last event.
        truncated_params = create_simulation_parameters(
            num_days=6,
            capital_base=10e3,
            start=self.sim_params.period_start,
            end=self.sim_params.period_end
        )
        truncated_params.period_end = events[-1].dt
        truncated_params.update_internal_from_env(self.env)

        # Buy on the first day, well before the ex_date.
        opening_buy = create_txn(events[0].sid, events[0].dt, 10.0, 100)
        results = calculate_results(
            truncated_params,
            self.env,
            self.tempdir,
            self.benchmark_events,
            {1: events},
            adjustment_reader=adjustment_reader,
            txns=[opening_buy],
        )

        def values(period, field):
            return [event[period][field] for event in results]

        self.assertEqual(len(results), 5)
        self.assertEqual(values('cumulative_perf', 'returns'), [0.0] * 5)
        self.assertEqual(values('daily_perf', 'returns'), [0.0] * 5)
        self.assertEqual(values('daily_perf', 'capital_used'),
                         [-1000, 0, 0, 0, 0])
        self.assertEqual(values('cumulative_perf', 'capital_used'),
                         [-1000] * 5)
1163
1164
1165
class TestDividendPerformanceHolidayStyle(TestDividendPerformance):
    """Re-run every TestDividendPerformance case over a different date window.

    The intent is for the simulation to start right before a market holiday
    (Thanksgiving), so that the next trading day is more than one calendar
    day after the start. Any test that hard-codes events at
    ``start + oneday`` would then fail, because those events are skipped by
    the simulation.

    NOTE(review): Thanksgiving 2003 was Thursday 2003-11-27, but the
    configured start is 2003-11-30, the following Sunday. This window
    therefore begins on a weekend rather than on the day before
    Thanksgiving. Confirm which date was intended.
    """

    @classmethod
    def setUpClass(cls):
        env = TradingEnvironment()
        cls.env = env

        window_start = pd.Timestamp("2003-11-30", tz='UTC')
        window_end = pd.Timestamp("2003-12-08", tz='UTC')
        cls.sim_params = create_simulation_parameters(
            num_days=6,
            capital_base=10e3,
            start=window_start,
            end=window_end,
        )

        setup_env_data(env, cls.sim_params, [1, 2])

        cls.benchmark_events = benchmark_events_in_range(cls.sim_params, env)
1187
1188
1189
class TestPositionPerformance(unittest.TestCase):
1190
1191
    def setUp(self):
        # Each test gets its own scratch directory; tearDown() cleans it up.
        scratch = TempDirectory()
        self.tempdir = scratch
1193
1194
    def create_environment_stuff(self, num_days=4, sids=None):
        """Build a fresh trading environment for the current test.

        Sets ``self.env``, ``self.sim_params``, ``self.finder`` and
        ``self.benchmark_events``.

        :param num_days: forwarded to ``create_simulation_parameters``.
        :param sids: asset ids registered through ``setup_env_data``;
            defaults to ``[1, 2]``.
        """
        # Use a None sentinel rather than a list default so calls never share
        # one mutable default object, and actually honor the caller's sids
        # (they used to be ignored in favor of a hard-coded [1, 2]).
        if sids is None:
            sids = [1, 2]

        self.env = TradingEnvironment()
        self.sim_params = create_simulation_parameters(num_days=num_days)

        setup_env_data(self.env, self.sim_params, sids)

        self.finder = self.env.asset_finder

        self.benchmark_events = benchmark_events_in_range(self.sim_params,
                                                          self.env)
1204
1205
    def tearDown(self):
        """Remove the scratch directory and drop the per-test environment."""
        self.tempdir.cleanup()
        # ``self.env`` only exists once create_environment_stuff() has
        # assigned it. If the test errored before (or while) creating it, a
        # bare ``del self.env`` would raise AttributeError here and mask the
        # test's real failure.
        self.__dict__.pop('env', None)
1208
1209
    def test_long_short_positions(self):
        """
        start with $1000
        buy 100 stock1 shares at $10
        sell short 100 stock2 shares at $10
        stock1 then goes down to $9
        stock2 goes to $11

        Leverage/exposure and account fields are checked right after the
        fills, on the penultimate day (both stocks still at $10) and on the
        last day (stock1 at $9, stock2 at $11).
        """
        self.create_environment_stuff()

        trades_1 = factory.create_trade_history(
            1,
            [10, 10, 10, 9],
            [100, 100, 100, 100],
            oneday,
            self.sim_params,
            env=self.env
        )

        # stock2 must tick on the same days as stock1. The checks below read
        # both stocks' prices at trades_1[-2].dt (stock2 expected at $10) and
        # trades_1[-1].dt (stock2 expected at $11), which only lines up when
        # stock2's trades use the same daily spacing rather than one second.
        trades_2 = factory.create_trade_history(
            2,
            [10, 10, 10, 11],
            [100, 100, 100, 100],
            oneday,
            self.sim_params,
            env=self.env
        )

        txn1 = create_txn(trades_1[1].sid, trades_1[1].dt, 10.0, 100)
        txn2 = create_txn(trades_2[1].sid, trades_1[1].dt, 10.0, -100)

        data_portal = create_data_portal_from_trade_history(
            self.env,
            self.tempdir,
            self.sim_params,
            {1: trades_1, 2: trades_2}
        )

        pt = perf.PositionTracker(self.env.asset_finder, data_portal)
        pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, data_portal)
        pt.execute_transaction(txn1)
        pp.handle_execution(txn1)
        pt.execute_transaction(txn2)
        pp.handle_execution(txn2)

        # Immediately after the fills: +$1000 long, -$1000 short.
        check_perf_period(
            pp,
            pt,
            gross_leverage=2.0,
            net_leverage=0.0,
            long_exposure=1000.0,
            longs_count=1,
            short_exposure=-1000.0,
            shorts_count=1,
            data_portal=data_portal
        )

        # Penultimate day: prices unchanged, so positions net to zero value.
        dt = trades_1[-2].dt
        pt.sync_last_sale_prices(dt)

        pos_stats = pt.stats()
        pp_stats = pp.stats(pt.positions, pos_stats, data_portal)
        # Validate that the account attributes were updated.
        account = pp.as_account(pos_stats, pp_stats)
        check_account(account,
                      settled_cash=1000.0,
                      equity_with_loan=1000.0,
                      total_positions_value=0.0,
                      regt_equity=1000.0,
                      available_funds=1000.0,
                      excess_liquidity=1000.0,
                      cushion=1.0,
                      leverage=2.0,
                      net_leverage=0.0,
                      net_liquidation=1000.0)

        # Last day: stock1 at $9 (long worth 900), stock2 at $11 (short
        # worth -1100). Validate that the account attributes were updated.
        dt = trades_1[-1].dt
        pt.sync_last_sale_prices(dt)
        pos_stats = pt.stats()
        pp_stats = pp.stats(pt.positions, pos_stats, data_portal)
        account = pp.as_account(pos_stats, pp_stats)

        check_perf_period(
            pp,
            pt,
            gross_leverage=2.5,
            net_leverage=-0.25,
            long_exposure=900.0,
            longs_count=1,
            short_exposure=-1100.0,
            shorts_count=1,
            data_portal=data_portal
        )

        check_account(account,
                      settled_cash=1000.0,
                      equity_with_loan=800.0,
                      total_positions_value=-200.0,
                      regt_equity=1000.0,
                      available_funds=1000.0,
                      excess_liquidity=1000.0,
                      cushion=1.25,
                      leverage=2.5,
                      net_leverage=-0.25,
                      net_liquidation=800.0)
1315
1316
    def test_levered_long_position(self):
        """
            Start with $1,000 and buy 1000 shares at $10 (10x leverage), then
            let the price move to $11 and re-check leverage and account
            fields.
        """
        self.create_environment_stuff()

        # post some trades in the market
        price_path = [10, 10, 10, 11]
        trades = factory.create_trade_history(
            1,
            price_path,
            [100] * len(price_path),
            oneday,
            self.sim_params,
            env=self.env
        )

        data_portal = create_data_portal_from_trade_history(
            self.env, self.tempdir, self.sim_params, {1: trades})

        fill = create_txn(trades[1].sid, trades[1].dt, 10.0, 1000)
        tracker = perf.PositionTracker(self.env.asset_finder, data_portal)
        period = perf.PerformancePeriod(1000.0, self.env.asset_finder,
                                        data_portal)

        tracker.execute_transaction(fill)
        period.handle_execution(fill)

        check_perf_period(
            period, tracker,
            gross_leverage=10.0, net_leverage=10.0,
            long_exposure=10000.0, longs_count=1,
            short_exposure=0.0, shorts_count=0,
            data_portal=data_portal,
        )

        def current_account():
            # Recompute position and period stats and fold them into an
            # account view.
            pos_stats = tracker.stats()
            pp_stats = period.stats(tracker.positions, pos_stats, data_portal)
            return period.as_account(pos_stats, pp_stats)

        # Penultimate day: still $10, so 10000 of stock on 1000 of equity.
        tracker.sync_last_sale_prices(trades[-2].dt)
        check_account(current_account(),
                      settled_cash=-9000.0,
                      equity_with_loan=1000.0,
                      total_positions_value=10000.0,
                      regt_equity=-9000.0,
                      available_funds=-9000.0,
                      excess_liquidity=-9000.0,
                      cushion=-9.0,
                      leverage=10.0,
                      net_leverage=10.0,
                      net_liquidation=1000.0)

        # Last day: the price jumps to $11, doubling equity to 2000.
        tracker.sync_last_sale_prices(trades[-1].dt)

        check_perf_period(
            period, tracker,
            gross_leverage=5.5, net_leverage=5.5,
            long_exposure=11000.0, longs_count=1,
            short_exposure=0.0, shorts_count=0,
            data_portal=data_portal,
        )

        check_account(current_account(),
                      settled_cash=-9000.0,
                      equity_with_loan=2000.0,
                      total_positions_value=11000.0,
                      regt_equity=-9000.0,
                      available_funds=-9000.0,
                      excess_liquidity=-9000.0,
                      cushion=-4.5,
                      leverage=5.5,
                      net_leverage=5.5,
                      net_liquidation=2000.0)
1407
1408
    def test_long_position(self):
        """
            verify that the performance period calculates properly for a
            single buy transaction
        """
        self.create_environment_stuff()

        # post some trades in the market
        trades = factory.create_trade_history(
            1,
            [10, 10, 10, 11],
            [100, 100, 100, 100],
            oneday,
            self.sim_params,
            env=self.env
        )

        data_portal = create_data_portal_from_trade_history(
            self.env,
            self.tempdir,
            self.sim_params,
            {1: trades})

        txn = create_txn(trades[1].sid, trades[1].dt, 10.0, 100)
        pt = perf.PositionTracker(self.env.asset_finder, data_portal)
        # NOTE(review): unlike the other tests in this class, data_portal is
        # not passed to PerformancePeriod here (it is only handed to
        # pp.stats below). Confirm this is intentional.
        pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
                                    period_open=self.sim_params.period_start,
                                    period_close=self.sim_params.period_end)
        pt.execute_transaction(txn)
        pp.handle_execution(txn)

        # This verifies that the last sale price is being correctly
        # set in the positions. If this is not the case then returns can
        # incorrectly show as sharply dipping if a transaction arrives
        # before a trade. This is caused by returns being based on holding
        # stocks with a last sale price of 0.
        self.assertEqual(pt.positions[1].last_sale_price, 10.0)

        pt.sync_last_sale_prices(trades[-1].dt)

        # Failure messages use implicit literal concatenation: the previous
        # backslash continuations inside the string literals embedded the
        # next line's indentation into the message text.
        self.assertEqual(
            pp.period_cash_flow,
            -1 * txn.price * txn.amount,
            "capital used should be equal to the opposite of the "
            "transaction cost of sole txn in test"
        )

        self.assertEqual(len(pt.positions), 1, "should be just one position")

        self.assertEqual(
            pt.positions[1].sid,
            txn.sid,
            "position should be in security with id 1")

        self.assertEqual(
            pt.positions[1].amount,
            txn.amount,
            "should have a position of {sharecount} shares".format(
                sharecount=txn.amount
            )
        )

        self.assertEqual(
            pt.positions[1].cost_basis,
            txn.price,
            "should have a cost basis of 10"
        )

        self.assertEqual(
            pt.positions[1].last_sale_price,
            trades[-1]['price'],
            ("last sale should be same as last trade. "
             "expected {exp} actual {act}").format(
                exp=trades[-1]['price'],
                act=pt.positions[1].last_sale_price)
        )

        pos_stats = pt.stats()
        pp_stats = pp.stats(pt.positions, pos_stats, data_portal)

        self.assertEqual(
            pos_stats.net_value,
            1100,
            "ending value should be price of last trade times number of "
            "shares in position"
        )

        self.assertEqual(pp_stats.pnl, 100,
                         "gain of 1 on 100 shares should be 100")

        check_perf_period(
            pp,
            pt,
            gross_leverage=1.0,
            net_leverage=1.0,
            long_exposure=1100.0,
            longs_count=1,
            short_exposure=0.0,
            shorts_count=0,
            data_portal=data_portal
        )

        # Validate that the account attributes were updated.
        account = pp.as_account(pos_stats, pp_stats)
        check_account(account,
                      settled_cash=0.0,
                      equity_with_loan=1100.0,
                      total_positions_value=1100.0,
                      regt_equity=0.0,
                      available_funds=0.0,
                      excess_liquidity=0.0,
                      cushion=0.0,
                      leverage=1.0,
                      net_leverage=1.0,
                      net_liquidation=1100.0)
1523
1524
    def test_short_position(self):
        """Verify performance-period accounting for a single short sale.

        100 shares are sold short at 10 on the second bar.  The position,
        cash flow and P&L are then checked three ways:

        1. synced to the fourth bar (price 11): P&L of -100;
        2. after rolling the period over and syncing to the last bar
           (price 9): P&L of +200 for the new period;
        3. with a fresh tracker and period spanning the whole sample:
           P&L of +100.
        """
        self.create_environment_stuff(num_days=6)

        trades = factory.create_trade_history(
            1,
            [10, 10, 10, 11, 10, 9],
            [100, 100, 100, 100, 100, 100],
            oneday,
            self.sim_params,
            env=self.env
        )

        # The first period runs through the 11 print; the second period
        # covers the remaining two trades.
        trades_1 = trades[:-2]
        trades_2 = trades[-2:]

        data_portal = create_data_portal_from_trade_history(
            self.env,
            self.tempdir,
            self.sim_params,
            {1: trades})

        txn = create_txn(trades[1].sid, trades[1].dt, 10.0, -100)
        pt = perf.PositionTracker(self.env.asset_finder, data_portal)
        # Construct the period like test_long_position does; the data portal
        # is handed to ``stats`` below rather than to the constructor.
        pp = perf.PerformancePeriod(
            1000.0,
            self.env.asset_finder,
            period_open=self.sim_params.period_start,
            period_close=self.sim_params.period_end,
        )

        pt.execute_transaction(txn)
        pp.handle_execution(txn)

        pt.sync_last_sale_prices(trades_1[-1].dt)

        self.assertEqual(
            pp.period_cash_flow,
            -1 * txn.price * txn.amount,
            "capital used should be equal to the opposite of the transaction "
            "cost of sole txn in test"
        )

        self.assertEqual(
            len(pt.positions),
            1,
            "should be just one position")

        self.assertEqual(
            pt.positions[1].sid,
            txn.sid,
            "position should be in security from the transaction"
        )

        self.assertEqual(
            pt.positions[1].amount,
            -100,
            "should have a position of -100 shares"
        )

        self.assertEqual(
            pt.positions[1].cost_basis,
            txn.price,
            "should have a cost basis of 10"
        )

        self.assertEqual(
            pt.positions[1].last_sale_price,
            trades_1[-1]['price'],
            "last sale should be price of last trade"
        )

        pos_stats = pt.stats()
        pp_stats = pp.stats(pt.positions, pos_stats, data_portal)

        self.assertEqual(
            pos_stats.net_value,
            -1100,
            "ending value should be price of last trade times number of "
            "shares in position"
        )

        # Short 100 at 10, marked at 11: the short has lost 100.
        self.assertEqual(pp_stats.pnl, -100,
                         "rise of 1 on -100 shares should be -100")

        # Simulate a rollover to a new period, then sync to the last trade so
        # the position value reflects the new price.
        pp.rollover(pos_stats, pp_stats)

        pt.sync_last_sale_prices(trades[-1].dt)

        self.assertEqual(
            pp.period_cash_flow,
            0,
            "capital used should be zero, there were no transactions in "
            "performance period"
        )

        self.assertEqual(
            len(pt.positions),
            1,
            "should be just one position"
        )

        self.assertEqual(
            pt.positions[1].sid,
            txn.sid,
            "position should be in security from the transaction"
        )

        self.assertEqual(
            pt.positions[1].amount,
            -100,
            "should have a position of -100 shares"
        )

        self.assertEqual(
            pt.positions[1].cost_basis,
            txn.price,
            "should have a cost basis of 10"
        )

        self.assertEqual(
            pt.positions[1].last_sale_price,
            trades_2[-1].price,
            "last sale should be price of last trade"
        )

        pos_stats = pt.stats()
        pp_stats = pp.stats(pt.positions, pos_stats, data_portal)

        self.assertEqual(
            pos_stats.net_value,
            -900,
            "ending value should be price of last trade times number of "
            "shares in position")

        self.assertEqual(
            pp_stats.pnl,
            200,
            "drop of 2 on -100 shares should be 200"
        )

        # Now run a performance period encompassing the entire trade sample.
        ptTotal = perf.PositionTracker(self.env.asset_finder, data_portal)
        ppTotal = perf.PerformancePeriod(
            1000.0,
            self.env.asset_finder,
            period_open=self.sim_params.period_start,
            period_close=self.sim_params.period_end,
        )

        ptTotal.execute_transaction(txn)
        ppTotal.handle_execution(txn)

        ptTotal.sync_last_sale_prices(trades[-1].dt)

        self.assertEqual(
            ppTotal.period_cash_flow,
            -1 * txn.price * txn.amount,
            "capital used should be equal to the opposite of the transaction "
            "cost of sole txn in test"
        )

        self.assertEqual(
            len(ptTotal.positions),
            1,
            "should be just one position"
        )
        self.assertEqual(
            ptTotal.positions[1].sid,
            txn.sid,
            "position should be in security from the transaction"
        )

        self.assertEqual(
            ptTotal.positions[1].amount,
            -100,
            "should have a position of -100 shares"
        )

        self.assertEqual(
            ptTotal.positions[1].cost_basis,
            txn.price,
            "should have a cost basis of 10"
        )

        self.assertEqual(
            ptTotal.positions[1].last_sale_price,
            trades_2[-1].price,
            "last sale should be price of last trade"
        )

        pos_total_stats = ptTotal.stats()
        pp_total_stats = ppTotal.stats(ptTotal.positions, pos_total_stats,
                                       data_portal)

        self.assertEqual(
            pos_total_stats.net_value,
            -900,
            "ending value should be price of last trade times number of "
            "shares in position")

        self.assertEqual(
            pp_total_stats.pnl,
            100,
            "drop of 1 on -100 shares should be 100"
        )

        check_perf_period(
            pp,
            pt,
            gross_leverage=0.8181,
            net_leverage=-0.8181,
            long_exposure=0.0,
            longs_count=0,
            short_exposure=-900.0,
            shorts_count=1,
            data_portal=data_portal
        )

        # Validate the account attributes of the whole-sample period, built
        # from that period's own position and performance stats.
        account = ppTotal.as_account(pos_total_stats, pp_total_stats)
        check_account(account,
                      settled_cash=2000.0,
                      equity_with_loan=1100.0,
                      total_positions_value=-900.0,
                      regt_equity=2000.0,
                      available_funds=2000.0,
                      excess_liquidity=2000.0,
                      cushion=1.8181,
                      leverage=0.8181,
                      net_leverage=-0.8181,
                      net_liquidation=1100.0)
1752
1753
    def test_covering_short(self):
        """Verify performance where a short is opened and later covered.

        100 shares are shorted at 10 on the second bar and bought back at 7
        on the seventh bar.  The security keeps trading after the cover, so
        the now-flat position must still carry the last sale price.
        """
        self.create_environment_stuff(num_days=10)

        trades = factory.create_trade_history(
            1,
            [10, 10, 10, 11, 9, 8, 7, 8, 9, 10],
            [100, 100, 100, 100, 100, 100, 100, 100, 100, 100],
            onesec,
            self.sim_params,
            env=self.env
        )

        data_portal = create_data_portal_from_trade_history(
            self.env,
            self.tempdir,
            self.sim_params,
            {1: trades})

        short_txn = create_txn(trades[1].sid, trades[1].dt, 10.0, -100)
        cover_txn = create_txn(trades[6].sid, trades[6].dt, 7.0, 100)
        pt = perf.PositionTracker(self.env.asset_finder, data_portal)
        # Construct the period like test_long_position does; the data portal
        # is handed to ``stats`` below rather than to the constructor.
        pp = perf.PerformancePeriod(
            1000.0,
            self.env.asset_finder,
            period_open=self.sim_params.period_start,
            period_close=self.sim_params.period_end,
        )

        for txn in (short_txn, cover_txn):
            pt.execute_transaction(txn)
            pp.handle_execution(txn)

        pt.sync_last_sale_prices(trades[-1].dt)

        short_txn_cost = short_txn.price * short_txn.amount
        cover_txn_cost = cover_txn.price * cover_txn.amount

        self.assertEqual(
            pp.period_cash_flow,
            -1 * short_txn_cost - cover_txn_cost,
            "capital used should be equal to the net transaction costs"
        )

        self.assertEqual(
            len(pt.positions),
            1,
            "should be just one position"
        )

        self.assertEqual(
            pt.positions[1].sid,
            short_txn.sid,
            "position should be in security from the transaction"
        )

        self.assertEqual(
            pt.positions[1].amount,
            0,
            "should have a position of 0 shares"
        )

        self.assertEqual(
            pt.positions[1].cost_basis,
            0,
            "a covered position should have a cost basis of 0"
        )

        self.assertEqual(
            pt.positions[1].last_sale_price,
            trades[-1].price,
            "last sale should be price of last trade"
        )

        pos_stats = pt.stats()
        pp_stats = pp.stats(pt.positions, pos_stats, data_portal)

        self.assertEqual(
            pos_stats.net_value,
            0,
            "ending value should be price of last trade times number of "
            "shares in position"
        )

        # Shorted at 10, covered at 7: a gain of 3 per share.
        self.assertEqual(
            pp_stats.pnl,
            300,
            "drop of 3 on -100 shares should be 300"
        )

        check_perf_period(
            pp,
            pt,
            gross_leverage=0.0,
            net_leverage=0.0,
            long_exposure=0.0,
            longs_count=0,
            short_exposure=0.0,
            shorts_count=0,
            data_portal=data_portal
        )

        account = pp.as_account(pos_stats, pp_stats)
        check_account(account,
                      settled_cash=1300.0,
                      equity_with_loan=1300.0,
                      total_positions_value=0.0,
                      regt_equity=1300.0,
                      available_funds=1300.0,
                      excess_liquidity=1300.0,
                      cushion=1.0,
                      leverage=0.0,
                      net_leverage=0.0,
                      net_liquidation=1300.0)
1870
1871
    def test_cost_basis_calc(self):
        """Verify the running average cost basis of a long position.

        Four 100-share buys at 10, 11, 11 and 12 give a cost basis of 11.
        A later sale of 100 shares at 10 leaves the cost basis unchanged.
        P&L is checked for two consecutive periods (+400, then -800 after a
        rollover) and for one period spanning every transaction (-400).
        """
        self.create_environment_stuff(num_days=5)

        history_args = (
            1,
            [10, 11, 11, 12, 10],
            [100, 100, 100, 100, 100],
            oneday,
            self.sim_params,
            self.env
        )
        trades = factory.create_trade_history(*history_args)
        # Buy on each of the first four bars; the fifth bar is the down tick.
        transactions = factory.create_txn_history(*history_args)[:4]

        data_portal = create_data_portal_from_trade_history(
            self.env,
            self.tempdir,
            self.sim_params,
            {1: trades})

        def make_period():
            # Both periods in this test use the same starting cash and bounds;
            # the data portal is handed to ``stats`` instead.
            return perf.PerformancePeriod(
                1000.0,
                self.env.asset_finder,
                period_open=self.sim_params.period_start,
                period_close=self.sim_params.trading_days[-1]
            )

        def execute_buys(tracker, period):
            # Apply the buys one at a time, checking the cost basis after
            # each.  Every amount in ``history_args`` is 100, so the expected
            # basis is the plain running mean of the fill prices.  Compare
            # approximately, since the mean is not exact in floating point.
            average_cost = 0
            for i, txn in enumerate(transactions):
                tracker.execute_transaction(txn)
                period.handle_execution(txn)
                average_cost = (average_cost * i + txn.price) / (i + 1)
                self.assertAlmostEqual(tracker.positions[1].cost_basis,
                                       average_cost)

        pt = perf.PositionTracker(self.env.asset_finder, data_portal)
        pp = make_period()
        execute_buys(pt, pp)

        self.assertEqual(
            pt.positions[1].last_sale_price,
            trades[-2].price,
            "should have a last sale of 12, got {val}".format(
                val=pt.positions[1].last_sale_price)
        )

        self.assertEqual(
            pt.positions[1].cost_basis,
            11,
            "should have a cost basis of 11"
        )

        pt.sync_last_sale_prices(trades[-2].dt)

        pos_stats = pt.stats()
        pp_stats = pp.stats(pt.positions, pos_stats, data_portal)

        self.assertEqual(
            pp_stats.pnl,
            400
        )

        down_tick = trades[-1]

        sale_txn = create_txn(
            down_tick.sid,
            down_tick.dt,
            10.0,
            -100)

        pp.rollover(pos_stats, pp_stats)

        pt.execute_transaction(sale_txn)
        pp.handle_execution(sale_txn)

        dt = down_tick.dt
        pt.sync_last_sale_prices(dt)

        self.assertEqual(
            pt.positions[1].last_sale_price,
            10,
            "should have a last sale of 10, was {val}".format(
                val=pt.positions[1].last_sale_price)
        )

        self.assertEqual(
            pt.positions[1].cost_basis,
            11,
            "should have a cost basis of 11"
        )

        pos_stats = pt.stats()
        pp_stats = pp.stats(pt.positions, pos_stats, data_portal)

        self.assertEqual(pp_stats.pnl, -800,
                         "this period goes from +400 to -400")

        # Replay every transaction, including the sale, in a single period.
        pt3 = perf.PositionTracker(self.env.asset_finder, data_portal)
        pp3 = make_period()
        execute_buys(pt3, pp3)

        pt3.execute_transaction(sale_txn)
        pp3.handle_execution(sale_txn)

        # ``down_tick`` is already the last element of ``trades``, so sync
        # straight to its timestamp.
        pt3.sync_last_sale_prices(dt)

        self.assertEqual(
            pt3.positions[1].last_sale_price,
            10,
            "should have a last sale of 10"
        )

        self.assertEqual(
            pt3.positions[1].cost_basis,
            11,
            "should have a cost basis of 11"
        )

        pt3_stats = pt3.stats()
        pp3_stats = pp3.stats(pt3.positions, pt3_stats, data_portal)

        self.assertEqual(
            pp3_stats.pnl,
            -400,
            "should be -400 for all trades and transactions in period"
        )
2002
2003
    def test_cost_basis_calc_close_pos(self):
        """Verify cost basis as a position is reduced, closed and flipped.

        Per the expected values below:
        - reducing a position keeps its cost basis;
        - closing it resets the basis to 0;
        - opening or flipping it (long to short or back) sets the basis to
          the flipping fill's price;
        - adding to it averages the prices.
        """
        self.create_environment_stuff(num_days=8)

        history_args = (
            1,
            [10, 9, 11, 8, 9, 12, 13, 14],
            [200, -100, -100, 100, -300, 100, 500, 400],
            onesec,
            self.sim_params,
            self.env
        )
        # Expected cost basis after each transaction in ``history_args``.
        cost_bases = [10, 10, 0, 8, 9, 9, 13, 13.5]

        trades = factory.create_trade_history(*history_args)
        transactions = factory.create_txn_history(*history_args)

        # ``zip`` below would silently skip checks if these ever diverged.
        self.assertEqual(len(transactions), len(cost_bases))

        data_portal = create_data_portal_from_trade_history(
            self.env,
            self.tempdir,
            self.sim_params,
            {1: trades})

        pt = perf.PositionTracker(self.env.asset_finder, data_portal)
        # Construct the period like test_long_position does.  The data portal
        # belongs with the tracker, not with the period's constructor.
        pp = perf.PerformancePeriod(
            1000.0,
            self.env.asset_finder,
            period_open=self.sim_params.period_start,
            period_close=self.sim_params.period_end,
        )

        for txn, cb in zip(transactions, cost_bases):
            pt.execute_transaction(txn)
            pp.handle_execution(txn)
            self.assertEqual(pt.positions[1].cost_basis, cb)

        self.assertEqual(pt.positions[1].cost_basis, cost_bases[-1])
2034
2035
2036
class TestPosition(unittest.TestCase):
    """Round-trip tests for a single ``perf.Position``."""

    def test_serialization(self):
        """A Position survives a persistent-id pickle round trip intact."""
        sale_dt = pd.Timestamp("1984/03/06 3:00PM")
        original = perf.Position(
            10,
            amount=np.float64(120.0),
            last_sale_date=sale_dt,
            last_sale_price=3.4,
        )

        # Serialize and load back; no trading environment is needed here.
        pickled = dumps_with_persistent_ids(original)
        restored = loads_with_persistent_ids(pickled, env=None)

        nt.assert_dict_equal(restored.__dict__, original.__dict__)
2049
2050
2051
class TestPositionTracker(unittest.TestCase):
    """Tests for ``perf.PositionTracker`` statistics and serialization."""

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        # Sids 1 and 2 are equities; sids 3 and 4 are futures with a
        # contract multiplier of 1000.  The exposure assertions below scale
        # the futures positions by it; the value assertions exclude them.
        futures_metadata = {3: {'contract_multiplier': 1000},
                            4: {'contract_multiplier': 1000}}
        cls.env.write_data(equities_identifiers=[1, 2],
                           futures_data=futures_metadata)

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        self.tempdir = TempDirectory()

    def tearDown(self):
        self.tempdir.cleanup()

    def test_empty_positions(self):
        """
        make sure all the empty position stats return a numeric 0

        Originally this bug was due to np.dot([], []) returning
        np.bool_(False)
        """
        sim_params = factory.create_simulation_parameters(
            num_days=4, env=self.env
        )
        trades = factory.create_trade_history(
            1,
            [10, 10, 10, 11],
            [100, 100, 100, 100],
            oneday,
            sim_params,
            env=self.env
        )

        data_portal = create_data_portal_from_trade_history(
            self.env,
            self.tempdir,
            sim_params,
            {1: trades})

        pt = perf.PositionTracker(self.env.asset_finder, data_portal)
        pos_stats = pt.stats()

        stats = [
            'net_value',
            'net_exposure',
            'gross_value',
            'gross_exposure',
            'short_value',
            'short_exposure',
            'shorts_count',
            'long_value',
            'long_exposure',
            'longs_count',
        ]
        for name in stats:
            val = getattr(pos_stats, name)
            # assertEqual, not the deprecated assertEquals alias (removed in
            # Python 3.12).
            self.assertEqual(val, 0)
            self.assertNotIsInstance(val, (bool, np.bool_))

    def test_position_values_and_exposures(self):
        """Check long/short/gross/net values and exposures across sids.

        Every position is marked at 10.  Sids 3 and 4 carry the 1000
        contract multiplier set up in ``setUpClass``.
        """
        pt = perf.PositionTracker(self.env.asset_finder, None)
        dt = pd.Timestamp("1984/03/06 3:00PM")
        pos1 = perf.Position(1, amount=np.float64(10.0),
                             last_sale_date=dt, last_sale_price=10)
        pos2 = perf.Position(2, amount=np.float64(-20.0),
                             last_sale_date=dt, last_sale_price=10)
        pos3 = perf.Position(3, amount=np.float64(30.0),
                             last_sale_date=dt, last_sale_price=10)
        pos4 = perf.Position(4, amount=np.float64(-40.0),
                             last_sale_date=dt, last_sale_price=10)
        pt.update_positions({1: pos1, 2: pos2, 3: pos3, 4: pos4})

        # Test long-only methods

        pos_stats = pt.stats()
        self.assertEqual(100, pos_stats.long_value)
        self.assertEqual(100 + 300000, pos_stats.long_exposure)
        self.assertEqual(2, pos_stats.longs_count)

        # Test short-only methods
        self.assertEqual(-200, pos_stats.short_value)
        self.assertEqual(-200 - 400000, pos_stats.short_exposure)
        self.assertEqual(2, pos_stats.shorts_count)

        # Test gross and net values
        self.assertEqual(100 + 200, pos_stats.gross_value)
        self.assertEqual(100 - 200, pos_stats.net_value)

        # Test gross and net exposures
        self.assertEqual(100 + 200 + 300000 + 400000, pos_stats.gross_exposure)
        self.assertEqual(100 - 200 + 300000 - 400000, pos_stats.net_exposure)

    def test_serialization(self):
        """A tracker's positions survive a persistent-id pickle round trip."""
        pt = perf.PositionTracker(self.env.asset_finder, None)
        dt = pd.Timestamp("1984/03/06 3:00PM")
        pos1 = perf.Position(1, amount=np.float64(120.0),
                             last_sale_date=dt, last_sale_price=3.4)
        pos3 = perf.Position(3, amount=np.float64(100.0),
                             last_sale_date=dt, last_sale_price=3.4)

        pt.update_positions({1: pos1, 3: pos3})
        p_string = dumps_with_persistent_ids(pt)
        test = loads_with_persistent_ids(p_string, env=self.env)
        nt.assert_count_equal(test.positions.keys(), pt.positions.keys())
        for sid in pt.positions:
            nt.assert_dict_equal(test.positions[sid].__dict__,
                                 pt.positions[sid].__dict__)
2164
2165
2166
class TestPerformancePeriod(unittest.TestCase):
2167
2168
    def test_serialization(self):
        """A PerformancePeriod survives a persistent-id pickle round trip.

        Every attribute must come back.  All of them except
        ``_account_store`` and ``_portfolio_store`` must compare equal to
        the original.
        """
        env = TradingEnvironment()
        period = perf.PerformancePeriod(100, env.asset_finder)

        restored = loads_with_persistent_ids(
            dumps_with_persistent_ids(period), env=env)

        expected = period.__dict__.copy()

        nt.assert_count_equal(restored.__dict__.keys(), expected.keys())

        # These two attributes are excluded from the equality comparison;
        # ``list.remove`` also asserts that both are actually present.
        keys_to_compare = list(expected)
        for excluded in ('_account_store', '_portfolio_store'):
            keys_to_compare.remove(excluded)

        for key in keys_to_compare:
            nt.assert_equal(restored.__dict__[key], expected[key])
2185