Passed
Push — beta ( 72a57d...7d0ef0 )
by Dean
03:02
created

SyncTask.abort()   A

Complexity

Conditions 4

Size

Total Lines 13

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 14.7187

Importance

Changes 0
Metric Value
dl 0
loc 13
ccs 1
cts 8
cp 0.125
rs 9.2
c 0
b 0
f 0
cc 4
crap 14.7187
1 1
from plugin.managers.exception import ExceptionManager
2 1
from plugin.models import *
3 1
from plugin.sync.core.enums import SyncData, SyncMode
4 1
from plugin.sync.core.exceptions import SyncAbort
5 1
from plugin.sync.core.task.artifacts import SyncArtifacts
6 1
from plugin.sync.core.task.configuration import SyncConfiguration
7 1
from plugin.sync.core.task.map import SyncMap
8 1
from plugin.sync.core.task.progress import SyncProgress
9 1
from plugin.sync.core.task.profiler import SyncProfiler
10 1
from plugin.sync.core.task.state import SyncState
11
12 1
from datetime import datetime
13 1
from peewee import JOIN_LEFT_OUTER
14 1
import logging
15 1
import time
16
17 1
# Module-level logger named after this module; used by `SyncTask` for its
# debug and warning messages.
log = logging.getLogger(__name__)
18
19
20 1
class SyncTask(object):
    """A single sync run for one account.

    Holds the sync options (`mode`, `data`, `media`), the database rows that
    track the run (`result`, `status`), the abort/finished flags and the
    helper objects constructed in `__init__` (artifacts, configuration, map,
    progress, profiler, state).
    """

    def __init__(self, account, mode, data, media, result, status, **kwargs):
        """Store the task options and construct the child helper objects.

        :param account: account the task runs for
        :param mode: sync mode for this task
        :param data: data types to sync, or `None` to let `load()` derive
                     them from the configuration
        :param media: media types to sync
        :param result: result row for this run (may be `None`)
        :param status: status row for this account/mode/section
        :param kwargs: extra arguments, kept as-is in `self.kwargs`
        """
        self.account = account

        # Sync options
        self.mode = mode
        self.data = data
        self.media = media

        # Extra arguments
        self.kwargs = kwargs

        # Handlers/Modes for task (populated by `construct()`)
        self.handlers = None
        self.modes = None

        # State/Result management
        self.result = result
        self.status = status

        self.exceptions = []

        self.finished = False
        self.started = False
        self.success = None

        self._abort = False

        # Construct children
        self.artifacts = SyncArtifacts(self)
        self.configuration = SyncConfiguration(self)
        self.map = SyncMap(self)
        self.progress = SyncProgress(self)
        self.profiler = SyncProfiler(self)

        self.state = SyncState(self)

    @property
    def id(self):
        """Identifier of the result row, or `None` when there is no result."""
        if self.result is None:
            return None

        return self.result.id

    @property
    def elapsed(self):
        """Seconds since `result.started_at`, or `None` when there is no result."""
        if self.result is None:
            return None

        return (datetime.utcnow() - self.result.started_at).total_seconds()

    def construct(self, handlers, modes):
        """Instantiate the handler and mode modules for this task.

        Handlers are keyed by their `data` attribute and modes by their
        `mode` attribute (see `_construct_modules`).
        """
        log.debug('Constructing %d handlers...', len(handlers))
        self.handlers = dict(self._construct_modules(handlers, 'data'))

        log.debug('Constructing %d modes...', len(modes))
        self.modes = dict(self._construct_modules(modes, 'mode'))

    def load(self):
        """Load the configuration, resolve enabled data and load the children.

        :raises ValueError: if no data types are enabled for this sync
        """
        # Load task configuration
        self.configuration.load(self.account)

        # Automatically determine enabled data types
        if self.data is None:
            self.data = self.get_enabled_data(self.configuration, self.mode)

        log.debug('Sync Data: %r', self.data)
        log.debug('Sync Media: %r', self.media)

        if self.data is None:
            raise ValueError('No collections enabled for sync')

        # Load children
        self.profiler.load()
        self.state.load()

    def abort(self, timeout=None):
        """Request an abort and optionally wait for the task to finish.

        :param timeout: maximum number of seconds to wait for `finished` to
                        become true; `None` returns immediately after setting
                        the flag. Fractional values are accepted.
        """
        # Set `abort` flag, thread will abort on the next `checkpoint()`
        self._abort = True

        if timeout is None:
            return

        # Wait up to `timeout` seconds for the task to finish, polling once
        # per second (the final sleep is shortened for fractional timeouts)
        remaining = timeout

        while remaining > 0:
            if self.finished:
                return

            step = min(1, remaining)

            time.sleep(step)
            remaining -= step

    def checkpoint(self):
        """Raise `SyncAbort` if an abort has been requested, otherwise do nothing."""
        if self._abort:
            raise SyncAbort()

    def finish(self):
        """Persist the outcome of the run and mark the task as finished."""
        # Update result in database
        self.result.ended_at = datetime.utcnow()
        self.result.success = self.success
        self.result.save()

        # Store exceptions in database (best-effort: a failure to store one
        # exception is logged and must not prevent the task from finishing)
        for exc_info in self.exceptions:
            try:
                self.store_exception(self.result, exc_info)
            except Exception as ex:
                log.warning('Unable to store exception: %s', str(ex), exc_info=True)

        # Flush caches to archives
        self.state.flush()

        # Display profiler report
        self.profiler.log_report()

        # Mark finished
        self.finished = True

    @staticmethod
    def store_exception(result, exc_info):
        """Create exception/error records from `exc_info` and link both to `result`."""
        exception, error = ExceptionManager.create.from_exc_info(exc_info)

        # Link error to result
        SyncResultError.create(
            result=result,
            error=error
        )

        # Link exception to result
        SyncResultException.create(
            result=result,
            exception=exception
        )

    @classmethod
    def create(cls, account, mode, data, media, trigger, **kwargs):
        """Create the status/result rows and return a loaded `SyncTask`.

        :param account: an `Account` instance, or an integer account id which
                        is resolved through `get_account`
        :param trigger: stored on the created `SyncResult`
        :param kwargs: forwarded to the task; `section` (if present) is also
                       used to look up the `SyncStatus`
        :raises ValueError: if `account` is neither an integer nor an `Account`
        """
        # Get account (`bool` is an `int` subclass, but never a valid id)
        if isinstance(account, int) and not isinstance(account, bool):
            account = cls.get_account(account)
        elif not isinstance(account, Account):
            raise ValueError('Unexpected value provided for the "account" parameter')

        # Get/Create sync status
        status, created = SyncStatus.get_or_create(
            account=account,
            mode=mode,
            section=kwargs.get('section', None)
        )

        # Create sync result
        result = SyncResult.create(
            status=status,
            trigger=trigger,

            started_at=datetime.utcnow()
        )

        # Create sync task
        task = SyncTask(
            account, mode,
            data, media,
            result, status,
            **kwargs
        )

        # Load sync configuration/state
        task.load()

        return task

    @classmethod
    def get_account(cls, account_id):
        """Fetch the account with `account_id`, joined with its Plex/Trakt rows."""
        # TODO Move account retrieval/join to `Account` class
        return (
            Account.select(
                Account.id,
                Account.name,

                PlexAccount.id,
                PlexAccount.key,
                PlexAccount.username,
                PlexBasicCredential.token_plex,
                PlexBasicCredential.token_server,

                TraktAccount.username,
                TraktBasicCredential.token,

                TraktOAuthCredential.access_token,
                TraktOAuthCredential.refresh_token,
                TraktOAuthCredential.created_at,
                TraktOAuthCredential.expires_in
            )
            # Plex
            .join(
                PlexAccount, JOIN_LEFT_OUTER, on=(
                    PlexAccount.account == Account.id
                ).alias('plex')
            )
            .join(
                PlexBasicCredential, JOIN_LEFT_OUTER, on=(
                    PlexBasicCredential.account == PlexAccount.id
                ).alias('basic')
            )
            # Trakt
            .switch(Account)
            .join(
                TraktAccount, JOIN_LEFT_OUTER, on=(
                    TraktAccount.account == Account.id
                ).alias('trakt')
            )
            .join(
                TraktBasicCredential, JOIN_LEFT_OUTER, on=(
                    TraktBasicCredential.account == TraktAccount.id
                ).alias('basic')
            )
            .switch(TraktAccount)
            .join(
                TraktOAuthCredential, JOIN_LEFT_OUTER, on=(
                    TraktOAuthCredential.account == TraktAccount.id
                ).alias('oauth')
            )
            .where(Account.id == account_id)
            .get()
        )

    @classmethod
    def get_enabled_data(cls, config, mode):
        """Combine the `SyncData` flags whose configured mode matches `mode`.

        :param config: mapping of option keys (e.g. `sync.watched.mode`) to
                       their configured sync mode
        :param mode: sync mode of the task
        :return: the enabled `SyncData` values OR-ed together, or `None` when
                 nothing is enabled
        """
        # Determine accepted modes
        modes = [SyncMode.Full]

        if mode == SyncMode.Full:
            modes.extend([
                SyncMode.FastPull,
                SyncMode.Pull,
                SyncMode.Push
            ])
        elif mode == SyncMode.FastPull:
            modes.extend([
                mode,
                SyncMode.Pull
            ])
        else:
            modes.append(mode)

        # Option key -> data flag, in the order the options are checked
        options = (
            ('sync.watched.mode', SyncData.Watched),
            ('sync.ratings.mode', SyncData.Ratings),
            ('sync.playback.mode', SyncData.Playback),
            ('sync.collection.mode', SyncData.Collection),

            # Lists
            ('sync.lists.watchlist.mode', SyncData.Watchlist),
            ('sync.lists.liked.mode', SyncData.Liked),
            ('sync.lists.personal.mode', SyncData.Personal)
        )

        # Merge enabled flags into a single value
        result = None

        for key, data in options:
            if config[key] not in modes:
                continue

            if result is None:
                result = data
                continue

            result |= data

        return result

    def _construct_modules(self, modules, attribute):
        """Instantiate each module class and yield `(key, instance)` pairs.

        The keys are read from the class attribute named `attribute`; a list
        yields one pair per key (all sharing one instance), any other value
        is used as a single key. Classes without the attribute are skipped
        with a warning.
        """
        for cls in modules:
            keys = getattr(cls, attribute, None)

            if keys is None:
                log.warning('Module %r is missing a valid %r attribute', cls, attribute)
                continue

            # Convert `keys` to list
            if not isinstance(keys, list):
                keys = [keys]

            # Construct module
            obj = cls(self)

            # Return module with keys
            for key in keys:
                yield key, obj
320