Completed
Pull Request — develop (#28)
by Fabian
01:36
created

EcsTaskDefinitionDiff._get_environment_diffs()   A

Complexity

Conditions 4

Size

Total Lines 10

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 9
CRAP Score 4

Importance

Changes 0
Metric Value
c 0
b 0
f 0
dl 0
loc 10
ccs 9
cts 9
cp 1
rs 9.2
cc 4
crap 4
1 1
from datetime import datetime
2
3 1
from boto3.session import Session
4 1
from botocore.exceptions import ClientError, NoCredentialsError
5 1
from dateutil.tz.tz import tzlocal
6
7
8 1
class EcsClient(object):
    """Thin wrapper around a boto3 ``ecs`` client.

    Each method forwards to the ECS API call of the same name, translating
    this module's snake_case arguments into the camelCase keyword arguments
    boto3 expects.
    """

    def __init__(self, access_key_id=None, secret_access_key=None,
                 region=None, profile=None):
        # All four settings are passed straight through to boto3's Session;
        # any left as None are up to boto3 to resolve.
        session_kwargs = dict(
            aws_access_key_id=access_key_id,
            aws_secret_access_key=secret_access_key,
            region_name=region,
            profile_name=profile,
        )
        self.boto = Session(**session_kwargs).client(u'ecs')

    def describe_services(self, cluster_name, service_name):
        """Describe the single service ``service_name`` in ``cluster_name``."""
        wanted = [service_name]
        return self.boto.describe_services(cluster=cluster_name,
                                           services=wanted)

    def describe_task_definition(self, task_definition_arn):
        """Describe a task definition.

        Raises:
            UnknownTaskDefinitionError: if boto3 raises ``ClientError``.
        """
        try:
            response = self.boto.describe_task_definition(
                taskDefinition=task_definition_arn
            )
        except ClientError:
            raise UnknownTaskDefinitionError(
                u'Unknown task definition arn: %s' % task_definition_arn
            )
        return response

    def list_tasks(self, cluster_name, service_name):
        """List the task ARNs belonging to ``service_name``."""
        return self.boto.list_tasks(cluster=cluster_name,
                                    serviceName=service_name)

    def describe_tasks(self, cluster_name, task_arns):
        """Describe the tasks identified by ``task_arns``."""
        return self.boto.describe_tasks(cluster=cluster_name,
                                        tasks=task_arns)

    def register_task_definition(self, family, containers, volumes, role_arn):
        """Register a new revision of ``family``.

        A falsy ``role_arn`` is sent as an empty string.
        """
        task_role = role_arn if role_arn else u''
        return self.boto.register_task_definition(
            family=family,
            containerDefinitions=containers,
            volumes=volumes,
            taskRoleArn=task_role,
        )

    def deregister_task_definition(self, task_definition_arn):
        """Deregister the given task definition revision."""
        return self.boto.deregister_task_definition(
            taskDefinition=task_definition_arn
        )

    def update_service(self, cluster, service, desired_count, task_definition):
        """Update the desired count and task definition of a service."""
        return self.boto.update_service(
            cluster=cluster,
            service=service,
            desiredCount=desired_count,
            taskDefinition=task_definition,
        )

    def run_task(self, cluster, task_definition, count, started_by, overrides):
        """Start ``count`` one-off tasks of ``task_definition``."""
        return self.boto.run_task(
            cluster=cluster,
            taskDefinition=task_definition,
            count=count,
            startedBy=started_by,
            overrides=overrides,
        )
71
72
73 1
class EcsService(dict):
    """An ECS service description (one entry of DescribeServices' ``services``).

    The service payload is stored as the dict contents; the cluster name the
    service belongs to is passed in separately and exposed as ``cluster``.
    """

    def __init__(self, cluster, service_definition=None, **kwargs):
        self._cluster = cluster
        # dict(None) raises TypeError, so a missing definition becomes empty.
        super(EcsService, self).__init__(service_definition or {}, **kwargs)

    def set_desired_count(self, desired_count):
        self[u'desiredCount'] = desired_count

    def set_task_definition(self, task_definition):
        """Point the service at ``task_definition`` (anything with ``.arn``)."""
        self[u'taskDefinition'] = task_definition.arn

    @property
    def cluster(self):
        return self._cluster

    @property
    def name(self):
        return self.get(u'serviceName')

    @property
    def task_definition(self):
        return self.get(u'taskDefinition')

    @property
    def desired_count(self):
        return self.get(u'desiredCount')

    def _primary_deployment_field(self, field):
        """Return ``field`` of the PRIMARY deployment, or "now" if none exists."""
        for deployment in self.get(u'deployments') or []:
            if deployment.get(u'status') == u'PRIMARY':
                return deployment.get(field)
        # Use a timezone-aware "now" so the fallback can be compared with the
        # aware default ``until`` in get_warnings(); comparing naive and aware
        # datetimes raises TypeError.
        return datetime.now(tz=tzlocal())

    @property
    def deployment_created_at(self):
        """Creation time of the PRIMARY deployment (or now, if there is none)."""
        return self._primary_deployment_field(u'createdAt')

    @property
    def deployment_updated_at(self):
        """Last update time of the PRIMARY deployment (or now, if there is none)."""
        return self._primary_deployment_field(u'updatedAt')

    @property
    def errors(self):
        """"unable" events since the PRIMARY deployment was last updated."""
        return self.get_warnings(
            since=self.deployment_updated_at
        )

    @property
    def older_errors(self):
        """"unable" events between PRIMARY deployment creation and last update."""
        return self.get_warnings(
            since=self.deployment_created_at,
            until=self.deployment_updated_at
        )

    def get_warnings(self, since=None, until=None):
        """Collect service events whose message contains ``'unable'``.

        Args:
            since: exclusive lower bound; defaults to deployment_created_at.
            until: exclusive upper bound; defaults to the current local time.

        Returns:
            dict mapping each matching event's ``createdAt`` to its message.
            A service without an ``events`` entry yields an empty dict.
        """
        since = since or self.deployment_created_at
        until = until or datetime.now(tz=tzlocal())
        errors = {}
        for event in self.get(u'events') or []:
            if u'unable' not in event[u'message']:
                continue
            if since < event[u'createdAt'] < until:
                errors[event[u'createdAt']] = event[u'message']
        return errors
137
138
139 1
class EcsTaskDefinition(dict):
    """An ECS task definition payload together with a log of local edits.

    The dict contents are the task definition payload. Each ``set_*`` method
    mutates that payload in place and appends an ``EcsTaskDefinitionDiff`` to
    :attr:`diff`; get_overrides() turns those diffs into RunTask container
    overrides.
    """

    def __init__(self, task_definition=None, **kwargs):
        # dict(None) raises TypeError, so a missing payload becomes empty.
        super(EcsTaskDefinition, self).__init__(task_definition or {},
                                                **kwargs)
        self._diff = []

    @property
    def containers(self):
        return self.get(u'containerDefinitions')

    @property
    def container_names(self):
        for container in self.get(u'containerDefinitions'):
            yield container[u'name']

    @property
    def volumes(self):
        return self.get(u'volumes')

    @property
    def arn(self):
        return self.get(u'taskDefinitionArn')

    @property
    def family(self):
        return self.get(u'family')

    @property
    def role_arn(self):
        return self.get(u'taskRoleArn')

    @property
    def revision(self):
        return self.get(u'revision')

    @property
    def family_revision(self):
        """``family:revision`` string, e.g. ``'web:3'``."""
        return '%s:%d' % (self.get(u'family'), self.get(u'revision'))

    @property
    def diff(self):
        """List of recorded diffs, in the order the changes were made."""
        return self._diff

    def get_overrides(self):
        """Build RunTask ``containerOverrides`` from the recorded diffs.

        Returns one dict per container that has at least one recorded diff,
        in order of first appearance. ``command`` and ``environment`` diffs
        become override entries; other container diffs (e.g. ``image``) only
        contribute the container name. Task-level diffs (``container`` is
        None, e.g. the role ARN) are skipped: they belong to no container.
        """
        overrides = []
        by_container = {}
        for diff in self.diff:
            if diff.container is None:
                continue
            override = by_container.get(diff.container)
            if override is None:
                # Merge every diff of the same container into one override,
                # even when diffs of other containers were recorded between.
                override = dict(name=diff.container)
                by_container[diff.container] = override
                overrides.append(override)
            if diff.field == 'command':
                override['command'] = self.get_overrides_command(diff.value)
            elif diff.field == 'environment':
                override['environment'] = self.get_overrides_env(diff.value)
        return overrides

    @staticmethod
    def get_overrides_command(command):
        """Split a command string on single spaces into an argument list."""
        return command.split(' ')

    @staticmethod
    def get_overrides_env(env):
        """Convert ``{name: value}`` into ECS ``[{name, value}, ...]`` form."""
        return [{"name": e, "value": env[e]} for e in env]

    def set_images(self, tag=None, **images):
        """Change container images.

        Args:
            tag: if truthy, containers not named in ``images`` keep their
                repository but get this (stripped) tag.
            **images: container name -> full new image.

        Raises:
            UnknownContainerError: if a key of ``images`` is not a container.
        """
        self.validate_container_options(**images)
        for container in self.containers:
            name = container[u'name']
            if name in images:
                self._set_container_image(container, images[name])
            elif tag:
                self._set_container_image(
                    container,
                    self._image_with_tag(container[u'image'], tag)
                )

    def _set_container_image(self, container, new_image):
        """Record an image diff for ``container`` and apply ``new_image``."""
        diff = EcsTaskDefinitionDiff(
            container=container[u'name'],
            field=u'image',
            value=new_image,
            old_value=container[u'image']
        )
        self._diff.append(diff)
        container[u'image'] = new_image

    @staticmethod
    def _image_with_tag(image, tag):
        """Return ``image`` with its tag (and any ``@digest``) replaced by ``tag``.

        A ``:`` followed by a path segment (``registry:5000/app``) is a
        registry port rather than a tag, so it is kept.
        """
        name = image.split(u'@', 1)[0]
        repository, separator, current_tag = name.rpartition(u':')
        if not separator or u'/' in current_tag:
            repository = name
        return u'%s:%s' % (repository, tag.strip())

    def set_commands(self, **commands):
        """Replace container commands (container name -> command string).

        The command string is stored as a single-element list.
        """
        self.validate_container_options(**commands)
        for container in self.containers:
            if container[u'name'] in commands:
                new_command = commands[container[u'name']]
                diff = EcsTaskDefinitionDiff(
                    container=container[u'name'],
                    field=u'command',
                    value=new_command,
                    old_value=container.get(u'command')
                )
                self._diff.append(diff)
                container[u'command'] = [new_command]

    def set_environment(self, environment_list):
        """Merge environment variables into containers.

        Args:
            environment_list: iterable of ``(container, name, value)``
                sequences; only the first three items are used.
        """
        environment = {}

        for env in environment_list:
            environment.setdefault(env[0], {})
            environment[env[0]][env[1]] = env[2]

        self.validate_container_options(**environment)
        for container in self.containers:
            if container[u'name'] in environment:
                self.apply_container_environment(
                    container=container,
                    new_environment=environment[container[u'name']]
                )

    def apply_container_environment(self, container, new_environment):
        """Merge ``new_environment`` (name -> value) into ``container``.

        Existing variables not mentioned are kept; a diff holding the merged
        and previous mappings is recorded.
        """
        environment = container.get('environment', [])
        old_environment = {env['name']: env['value'] for env in environment}
        merged = old_environment.copy()
        merged.update(new_environment)

        diff = EcsTaskDefinitionDiff(
            container=container[u'name'],
            field=u'environment',
            value=merged,
            old_value=old_environment
        )
        self._diff.append(diff)

        container[u'environment'] = [
            {"name": e, "value": merged[e]} for e in merged
        ]

    def validate_container_options(self, **container_options):
        """Raise UnknownContainerError for any key that is not a container name."""
        if not container_options:
            return
        # Build the name set once instead of re-walking the generator per key.
        known_names = set(self.container_names)
        for container_name in container_options:
            if container_name not in known_names:
                raise UnknownContainerError(
                    u'Unknown container: %s' % container_name
                )

    def set_role_arn(self, role_arn):
        """Set the task role ARN (ignored when ``role_arn`` is falsy)."""
        if role_arn:
            diff = EcsTaskDefinitionDiff(
                container=None,
                field=u'role_arn',
                value=role_arn,
                # Task definitions without a role have no taskRoleArn key.
                old_value=self.get(u'taskRoleArn')
            )
            self[u'taskRoleArn'] = role_arn
            self._diff.append(diff)
291
292
293 1
class EcsTaskDefinitionDiff(object):
    """One recorded change to a task definition, rendered via ``repr()``.

    ``container`` is the affected container name, or None for task-level
    fields such as ``role_arn``. ``value``/``old_value`` are the new and
    previous values; for ``environment`` diffs they are dict-like
    name -> value mappings.
    """

    def __init__(self, container, field, value, old_value):
        self.container, self.field = container, field
        self.value, self.old_value = value, old_value

    def __repr__(self):
        if self.field == u'environment':
            lines = self._get_environment_diffs(
                self.container, self.value, self.old_value,
            )
            return '\n'.join(lines)
        if not self.container:
            # Task-level change: no container to mention.
            return u'Changed %s to: "%s" (was: "%s")' % (
                self.field, self.value, self.old_value,
            )
        return u'Changed %s of container "%s" to: "%s" (was: "%s")' % (
            self.field, self.container, self.value, self.old_value,
        )

    @staticmethod
    def _get_environment_diffs(container, env, old_env):
        """Describe each variable in ``env`` that differs from ``old_env``.

        A variable is also reported when its previous value is falsy (missing,
        None or empty), even if the new value equals it.
        """
        template = u'Changed environment "%s" of container "%s" to: "%s"'
        changes = []
        for key, new_value in env.items():
            previous = old_env.get(key)
            if previous and previous == new_value:
                continue  # unchanged, non-empty value: nothing to report
            changes.append(template % (key, container, new_value))
        return changes
331
332
333 1
class EcsAction(object):
    """Base class for actions on one ECS cluster and, optionally, one service.

    When ``service_name`` is truthy the service is looked up in the
    constructor, and lookup failures are re-raised as ``EcsConnectionError``.
    """

    def __init__(self, client, cluster_name, service_name):
        self._client = client
        self._cluster_name = cluster_name
        self._service_name = service_name
        # Define _service up front so the ``service`` property returns None
        # instead of raising AttributeError when no service is looked up
        # (RunAction passes service_name=None).
        self._service = None

        if not service_name:
            return
        try:
            self._service = self.get_service()
        except IndexError:
            # describe_services returned an empty ``services`` list.
            raise EcsConnectionError(
                u'An error occurred when calling the DescribeServices '
                u'operation: Service not found.'
            )
        except ClientError as e:
            raise EcsConnectionError(str(e))
        except NoCredentialsError:
            raise EcsConnectionError(
                u'Unable to locate credentials. Configure credentials '
                u'by running "aws configure".'
            )

    def get_service(self):
        """Fetch this action's service as an ``EcsService``.

        Raises IndexError if the response's ``services`` list is empty.
        """
        services_definition = self._client.describe_services(
            cluster_name=self._cluster_name,
            service_name=self._service_name
        )
        return EcsService(
            cluster=self._cluster_name,
            service_definition=services_definition[u'services'][0]
        )

    def get_current_task_definition(self, service):
        """Fetch the task definition ``service`` currently points at."""
        return self.get_task_definition(service.task_definition)

    def get_task_definition(self, task_definition):
        """Fetch ``task_definition`` (ARN or family:revision) as EcsTaskDefinition."""
        task_definition_payload = self._client.describe_task_definition(
            task_definition_arn=task_definition
        )
        return EcsTaskDefinition(
            task_definition=task_definition_payload[u'taskDefinition']
        )

    def update_task_definition(self, task_definition):
        """Register ``task_definition`` as a new revision and deregister the old one.

        Returns the newly registered revision as an ``EcsTaskDefinition``.
        """
        response = self._client.register_task_definition(
            family=task_definition.family,
            containers=task_definition.containers,
            volumes=task_definition.volumes,
            role_arn=task_definition.role_arn
        )
        new_task_definition = EcsTaskDefinition(response[u'taskDefinition'])
        self._client.deregister_task_definition(task_definition.arn)
        return new_task_definition

    def update_service(self, service):
        """Push ``service``'s desired count and task definition to ECS."""
        response = self._client.update_service(
            cluster=service.cluster,
            service=service.name,
            desired_count=service.desired_count,
            task_definition=service.task_definition
        )
        return EcsService(self._cluster_name, response[u'service'])

    def is_deployed(self, service):
        """Return True once ``service`` has settled on its task definition.

        That means exactly one deployment, and as many RUNNING tasks of the
        service's task definition as its desired count (or no tasks at all
        when the desired count is 0).
        """
        if len(service[u'deployments']) != 1:
            return False
        running_tasks = self._client.list_tasks(
            cluster_name=service.cluster,
            service_name=service.name
        )
        if not running_tasks[u'taskArns']:
            return service.desired_count == 0
        running_count = self.get_running_tasks_count(
            service=service,
            task_arns=running_tasks[u'taskArns']
        )
        return service.desired_count == running_count

    def get_running_tasks_count(self, service, task_arns):
        """Count RUNNING tasks among ``task_arns`` that use the service's task definition."""
        running_count = 0
        tasks_details = self._client.describe_tasks(
            cluster_name=self._cluster_name,
            task_arns=task_arns
        )
        for task in tasks_details[u'tasks']:
            arn = task[u'taskDefinitionArn']
            status = task[u'lastStatus']
            if arn == service.task_definition and status == u'RUNNING':
                running_count += 1
        return running_count

    @property
    def client(self):
        return self._client

    @property
    def service(self):
        return self._service

    @property
    def cluster_name(self):
        return self._cluster_name

    @property
    def service_name(self):
        return self._service_name
446
447
448 1
class DeployAction(EcsAction):
    """Switch the looked-up service to a different task definition."""

    def deploy(self, task_definition):
        """Point the service at ``task_definition`` and push the change.

        Returns the ``EcsService`` produced by ``update_service``.
        """
        target = self._service
        target.set_task_definition(task_definition)
        return self.update_service(target)
452
453
454 1
class ScaleAction(EcsAction):
    """Change the desired task count of the looked-up service."""

    def scale(self, desired_count):
        """Set the service's desired count and push the change.

        Returns the ``EcsService`` produced by ``update_service``.
        """
        target = self._service
        target.set_desired_count(desired_count)
        return self.update_service(target)
458
459
460 1
class RunAction(EcsAction):
    """Start one-off tasks in a cluster (no service involved)."""

    def __init__(self, client, cluster_name):
        # The base constructor stores client and cluster name; passing None
        # as the service name skips the service lookup.
        super(RunAction, self).__init__(client, cluster_name, None)
        self.started_tasks = []

    def run(self, task_definition, count, started_by):
        """Run ``count`` tasks of ``task_definition`` with its recorded overrides.

        The ``tasks`` list of the RunTask response is stored in
        ``started_tasks``; always returns True.
        """
        response = self._client.run_task(
            cluster=self._cluster_name,
            task_definition=task_definition.family_revision,
            count=count,
            started_by=started_by,
            overrides=dict(containerOverrides=task_definition.get_overrides()),
        )
        self.started_tasks = response['tasks']
        return True
477
478
479 1
class EcsError(Exception):
    """Base class for the errors raised by this module."""
481
482
483 1
class EcsConnectionError(EcsError):
    """Raised when an ECS lookup fails (missing service, client or credential errors)."""
485
486
487 1
class UnknownContainerError(EcsError):
    """Raised when a container name does not exist in the task definition."""
489
490
491 1
class TaskPlacementError(EcsError):
    """Task placement error; not raised by the code visible in this module."""
493
494
495 1
class UnknownTaskDefinitionError(EcsError):
    """Raised when a task definition ARN cannot be described."""
497