Passed
Pull Request — master (#3179)
by Matěj
02:26
created

DockerTestEnv._create_new_image()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 5
nop 3
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
from __future__ import print_function
2
3
import contextlib
4
import sys
5
6
import docker
7
8
import ssg_test_suite
9
from ssg_test_suite.virt import SnapshotStack
10
11
12
class TestEnv(object):
    """Base class for test environments that support layered snapshots.

    Subclasses implement ``_get_snapshot`` and ``_revert_snapshot``;
    ``in_layer`` uses them to run a block of code on top of a new snapshot
    and to revert that snapshot afterwards.
    """

    def __init__(self):
        # Address of the machine under test; subclasses fill it in.
        self.domain_ip = ""

    def start(self):
        """Bring the environment up. The base implementation does nothing."""
        pass

    def finalize(self):
        """Tear the environment down. The base implementation does nothing."""
        pass

    def _get_snapshot(self, snapshot_name):
        """Create a snapshot called ``snapshot_name`` and return it."""
        raise NotImplementedError()

    def _revert_snapshot(self, snapshot_name):
        """Revert a snapshot.

        Despite the parameter name, ``in_layer`` passes the snapshot object
        returned by ``_get_snapshot``, not its name.
        """
        raise NotImplementedError()

    @contextlib.contextmanager
    def in_layer(self, snapshot_name, delete=True):
        """Run the ``with`` body on top of a new snapshot, then revert it.

        :param snapshot_name: name passed to ``_get_snapshot``.
        :param delete: accepted but not used by this implementation.
        :yields: the snapshot object returned by ``_get_snapshot``.

        The snapshot is reverted even if the body raises. A KeyboardInterrupt
        from the body, or one that interrupts the cleanup itself, does not
        stop the revert from completing. It is re-raised once cleanup is
        done, so the user's request to stop is never lost.
        """
        snapshot = self._get_snapshot(snapshot_name)
        cleanup_msg = ("Hang on for a minute, cleaning up the snapshot '{0}'."
                       .format(snapshot_name))
        pending_interrupt = None
        try:
            yield snapshot
        except KeyboardInterrupt as exc:
            print(cleanup_msg, file=sys.stderr)
            pending_interrupt = exc
        finally:
            try:
                self._revert_snapshot(snapshot)
            except KeyboardInterrupt as exc:
                # Cleanup was interrupted: retry it so the environment is not
                # left half-reverted. Remember the interrupt (unless one is
                # already pending) so it is re-raised instead of swallowed.
                print(cleanup_msg, file=sys.stderr)
                if pending_interrupt is None:
                    pending_interrupt = exc
                self._revert_snapshot(snapshot)
            finally:
                if pending_interrupt is not None:
                    raise pending_interrupt
48
49
50
class VMTestEnv(TestEnv):
    """Test environment backed by a libvirt domain and its snapshots."""

    name = "libvirt-based"

    def __init__(self, hypervisor, domain_name):
        super(VMTestEnv, self).__init__()
        # Connection parameters; the domain handle is obtained in start().
        self.hypervisor = hypervisor
        self.domain_name = domain_name
        # Built in start() once a domain handle is available.
        self.snapshot_stack = None

    def start(self):
        """Connect to and start the domain, then record its IP address."""
        virt = ssg_test_suite.virt
        domain = virt.connect_domain(self.hypervisor, self.domain_name)
        self.snapshot_stack = SnapshotStack(domain)

        virt.start_domain(domain)
        self.domain_ip = virt.determine_ip(domain)

    def _get_snapshot(self, snapshot_name):
        """Create a snapshot named ``snapshot_name`` via the snapshot stack."""
        stack = self.snapshot_stack
        return stack.create(snapshot_name)

    def _revert_snapshot(self, snapshot):
        """Ask the snapshot stack to revert; ``snapshot`` itself is not used."""
        stack = self.snapshot_stack
        stack.revert()
73
74
75
class DockerTestEnv(TestEnv):
    """Test environment where each snapshot layer is a Docker container.

    The first layer runs the base image directly. Each further layer is
    started from an image committed from the container below it. Containers
    and the images created for them are tracked on two stacks, so reverting
    a layer removes both.
    """

    name = "container-based"

    def __init__(self, image_name):
        super(DockerTestEnv, self).__init__()

        # Prefix for the names of containers started by this environment.
        self._name_stem = "ssg_test"

        try:
            self.client = docker.from_env(version="auto")
            self.client.ping()
        except Exception as exc:
            # Include the underlying error so the user can see the cause.
            msg = (
                "Unable to start the Docker test environment, "
                "is the Docker service started "
                "and do you have rights to access it? "
                "Error: {0}"
                .format(str(exc)))
            raise RuntimeError(msg)
        self.base_image = image_name
        # Names of images committed by _create_new_image, most recent last.
        self.created_images = []
        # Started containers, one per snapshot layer, most recent last.
        self.containers = []

    @property
    def current_container(self):
        """The most recently started container, or None if there is none."""
        if self.containers:
            return self.containers[-1]
        return None

    def _create_new_image(self, from_container, name):
        """Commit ``from_container`` as a new image and return its name.

        The image is named ``<base_image>_<name>``. It is recorded in
        ``created_images`` so that ``_revert_snapshot`` can remove it later.
        """
        new_image_name = "{0}_{1}".format(self.base_image, name)
        from_container.commit(repository=new_image_name)
        self.created_images.append(new_image_name)
        return new_image_name

    def _new_container(self, name):
        """Start a detached container for layer ``name`` running ``sshd -D``.

        The first layer uses the base image. Later layers use an image
        committed from the current top container. Container port 22 is
        published (``None`` lets Docker choose the host port).
        """
        parent = self.current_container
        if parent is not None:
            img = self._create_new_image(parent, name)
        else:
            img = self.base_image
        return self.client.containers.run(
            img, "/usr/sbin/sshd -D",
            name="{0}_{1}".format(self._name_stem, name), ports={"22": None},
            detach=True)

    def _get_snapshot(self, snapshot_name):
        """Start a new layer and point ``domain_ip`` at its container."""
        new_container = self._new_container(snapshot_name)
        self.containers.append(new_container)

        # Re-read the container's attributes from the daemon; the IP address
        # below is taken from this refreshed data.
        new_container.reload()
        self.domain_ip = new_container.attrs["NetworkSettings"]["Networks"]["bridge"]["IPAddress"]

        return new_container

    def _revert_snapshot(self, snapshot):
        """Stop and remove ``snapshot`` (the top container) and its image.

        Raises AssertionError (when assertions are enabled) if ``snapshot``
        is not the most recently started container.
        """
        snapshot.stop()
        snapshot.remove()

        # Pop outside the assert: under ``python -O`` the whole assert
        # statement is stripped, which would leave the stack unpopped.
        top = self.containers.pop()
        assert snapshot == top

        if self.created_images:
            associated_image = self.created_images.pop()
            self.client.images.remove(associated_image)
137