Passed
Push — master ( 42467a...9d011f )
by Matěj
03:19 queued 11s
created

RuleResult.relative_conditions_to()   A

Complexity

Conditions 2

Size

Total Lines 5
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 3
CRAP Score 2.0625

Importance

Changes 0
Metric Value
cc 2
eloc 4
nop 2
dl 0
loc 5
ccs 3
cts 4
cp 0.75
crap 2.0625
rs 10
c 0
b 0
f 0
1 1
from __future__ import print_function
2
3 1
import os
4 1
import logging
5 1
import subprocess
6 1
from collections import namedtuple
7 1
import functools
8 1
import tarfile
9 1
import tempfile
10 1
import re
11
12 1
from ssg.constants import MULTI_PLATFORM_MAPPING
13 1
from ssg.constants import PRODUCT_TO_CPE_MAPPING
14 1
from ssg.constants import FULL_NAME_TO_PRODUCT_MAPPING
15 1
from ssg.constants import OSCAP_RULE
16 1
from ssg_test_suite.log import LogHelper
17
18 1
# Identity of a single test run: the rule and the scenario script used.
Scenario_run = namedtuple("Scenario_run", "rule_id script")

# Circumstances under which a scenario was evaluated.
Scenario_conditions = namedtuple(
    "Scenario_conditions", "backend scanning_mode remediated_by datastream")

# A rule directory that carries test scenarios.
Rule = namedtuple("Rule", "directory id short_id files")
26
27 1
_BENCHMARK_DIRS = [
28
        os.path.abspath(os.path.join(os.path.dirname(__file__), '../../linux_os/guide')),
29
        os.path.abspath(os.path.join(os.path.dirname(__file__), '../../applications')),
30
        ]
31
32 1
_SHARED_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../shared'))
33
34 1
REMOTE_USER = "root"
35 1
REMOTE_USER_HOME_DIRECTORY = "/root"
36 1
REMOTE_TEST_SCENARIOS_DIRECTORY = os.path.join(REMOTE_USER_HOME_DIRECTORY, "ssgts")
37
38 1
# Options passed to ssh/scp: host key checking is disabled, followed by any
# extra options the user supplies as a whitespace-separated string in the
# SSH_ADDITIONAL_OPTIONS environment variable (empty when it is unset).
SSH_ADDITIONAL_OPTS = (
    "-o", "StrictHostKeyChecking=no",
    "-o", "UserKnownHostsFile=/dev/null",
) + tuple(os.environ.get('SSH_ADDITIONAL_OPTIONS', '').split())
48
49
50 1
def walk_through_benchmark_dirs():
    """Walk every benchmark root in turn.

    Yields the ``(dirpath, dirnames, filenames)`` triples produced by
    ``os.walk`` for each entry of _BENCHMARK_DIRS, root after root.
    """
    for benchmark_root in _BENCHMARK_DIRS:
        for walk_entry in os.walk(benchmark_root):
            yield walk_entry
54
55
56 1
class Stage(object):
    """Integer identifiers of the test-scenario stages, numbered 0-4 in the
    order listed."""
    NONE, PREPARATION, INITIAL_SCAN, REMEDIATION, FINAL_SCAN = range(5)
62
63
64 1
@functools.total_ordering
class RuleResult(object):
    """
    Result of a test suite testing rule under a scenario.

    Supports ordering by success - the most successful run orders first:
    a result with more passed stages compares as "less than" one with fewer,
    so sorting puts it first. Equality compares ``success`` and the recorded
    per-stage outcomes.
    """

    # Stage names accepted by record_stage_result() and persisted by
    # save_to_dict() / load_from_dict().
    STAGE_STRINGS = {
        "preparation",
        "initial_scan",
        "remediation",
        "final_scan",
    }

    def __init__(self, result_dict=None):
        """Create an empty result; populate it from *result_dict* if given.

        :param result_dict: mapping with the keys written by save_to_dict().
        """
        self.scenario = Scenario_run("", "")
        self.conditions = Scenario_conditions("", "", "", "")
        self.when = ""
        # stage name -> outcome of that stage
        self.passed_stages = dict()
        self.passed_stages_count = 0
        self.success = False

        if result_dict:
            self.load_from_dict(result_dict)

    def _count_passed_stages(self):
        """Return how many recorded stages have a truthy outcome."""
        return sum(1 for passed in self.passed_stages.values() if passed)

    def load_from_dict(self, data):
        """Populate this result from *data*, a mapping as produced by save_to_dict()."""
        self.scenario = Scenario_run(data["rule_id"], data["scenario_script"])
        self.conditions = Scenario_conditions(
            data["backend"], data["scanning_mode"],
            data["remediated_by"], data["datastream"])
        self.when = data["run_timestamp"]

        self.passed_stages = {key: data[key] for key in self.STAGE_STRINGS if key in data}
        self.passed_stages_count = self._count_passed_stages()

        # The run succeeded if the final scan passed or - when no remediation
        # outcome was recorded - if the initial scan passed.
        self.success = data.get("final_scan", False)
        if not self.success:
            self.success = (
                "remediation" not in data
                and data.get("initial_scan", False))

    def save_to_dict(self):
        """Serialize this result into a flat dict readable by load_from_dict()."""
        data = dict()
        data["rule_id"] = self.scenario.rule_id
        data["scenario_script"] = self.scenario.script

        data["backend"] = self.conditions.backend
        data["scanning_mode"] = self.conditions.scanning_mode
        data["remediated_by"] = self.conditions.remediated_by
        data["datastream"] = self.conditions.datastream

        data["run_timestamp"] = self.when

        for stage_str, result in self.passed_stages.items():
            data[stage_str] = result

        return data

    def record_stage_result(self, stage, successful):
        """Record the outcome of *stage* (one of STAGE_STRINGS).

        ``passed_stages_count`` is refreshed so that ordering reflects the
        recorded stages; ``success`` is not recomputed here.
        """
        assert stage in self.STAGE_STRINGS, (
            "Stage name {name} is invalid, choose one from {choices}"
            .format(name=stage, choices=", ".join(sorted(self.STAGE_STRINGS)))
        )
        self.passed_stages[stage] = successful
        self.passed_stages_count = self._count_passed_stages()

    def relative_conditions_to(self, other):
        """Return what distinguishes this result from *other*.

        The pair of timestamps when both ran under identical conditions,
        otherwise the pair of condition tuples.
        """
        if self.conditions == other.conditions:
            return self.when, other.when
        else:
            return tuple(self.conditions), tuple(other.conditions)

    def __eq__(self, other):
        if not isinstance(other, RuleResult):
            return NotImplemented
        return (self.success == other.success
                and self.passed_stages == other.passed_stages)

    def __lt__(self, other):
        if not isinstance(other, RuleResult):
            return NotImplemented
        # Inverted on purpose: more passed stages sorts first.
        return self.passed_stages_count > other.passed_stages_count
141
142
143 1
def run_cmd_local(command, verbose_path, env=None):
    """Run *command* (a sequence of arguments) on the local machine.

    The command line is logged at DEBUG level; execution is delegated to
    _run_cmd(), whose ``(returncode, output)`` pair is returned as-is.
    """
    logging.debug('Running {}'.format(' '.join(command)))
    return _run_cmd(command, verbose_path, env)
148
149
150 1
def _run_cmd(command_list, verbose_path, env=None):
151
    returncode = 0
152
    output = b""
153
    try:
154
        with open(verbose_path, 'w') as verbose_file:
155
            output = subprocess.check_output(
156
                command_list, stderr=verbose_file, env=env)
157
    except subprocess.CalledProcessError as e:
158
        returncode = e.returncode
159
        output = e.output
160
    return returncode, output.decode('utf-8')
161
162
163 1
def _get_platform_cpes(platform):
    """Return the set of CPEs that correspond to a scenario *platform*.

    *platform* is either a ``multi_platform_*`` specifier (a key of
    MULTI_PLATFORM_MAPPING, expanded to all of its products) or a full
    product name (a key of FULL_NAME_TO_PRODUCT_MAPPING).

    :raises ValueError: if *platform* is not recognized; the message is
        also logged.
    """
    if platform.startswith("multi_platform_"):
        try:
            products = MULTI_PLATFORM_MAPPING[platform]
        except KeyError:
            msg = (
                "Unknown multi_platform specifier: %s is not from %s"
                % (platform, ", ".join(MULTI_PLATFORM_MAPPING.keys())))
            logging.error(msg)
            raise ValueError(msg)
        platform_cpes = set()
        for p in products:
            platform_cpes |= set(PRODUCT_TO_CPE_MAPPING[p])
        return platform_cpes
    else:
        # scenario platform is specified by a full product name
        try:
            product = FULL_NAME_TO_PRODUCT_MAPPING[platform]
        except KeyError:
            msg = (
                "Unknown product name: %s is not from %s"
                % (platform, ", ".join(FULL_NAME_TO_PRODUCT_MAPPING.keys())))
            logging.error(msg)
            raise ValueError(msg)
        platform_cpes = set(PRODUCT_TO_CPE_MAPPING[product])
        return platform_cpes
187
188
189 1
def matches_platform(scenario_platforms, benchmark_cpes):
    """Tell whether a scenario meant for *scenario_platforms* applies here.

    ``multi_platform_all`` matches unconditionally. Otherwise every listed
    platform is expanded to its CPEs, and the result is True when any of them
    is among *benchmark_cpes*.
    """
    if "multi_platform_all" in scenario_platforms:
        return True
    scenario_cpes = set().union(
        *(_get_platform_cpes(platform) for platform in scenario_platforms))
    return bool(scenario_cpes & benchmark_cpes)
196
197
198 1
def _log_process_output(stdout, stderr, log_file):
    """Append the captured stdout/stderr of a process to *log_file*, if non-empty."""
    if stdout:
        log_file.write("STDOUT: ")
        log_file.write(stdout)
    if stderr:
        log_file.write("STDERR: ")
        log_file.write(stderr)


def run_with_stdout_logging(command, args, log_file):
    """Run ``command args...`` locally, logging the command line and its output.

    :param command: program to execute.
    :param args: iterable of string arguments (tuple, list, ...).
    :param log_file: writable text file receiving the command line and the
        captured output. The output is logged even when the command fails.
    :returns: the command's standard output as text. Output is decoded as
        UTF-8, with undecodable bytes replaced.
    :raises subprocess.CalledProcessError: if the command exits non-zero.
    """
    args = tuple(args)
    log_file.write("{0} {1}\n".format(command, " ".join(args)))
    try:
        result = subprocess.run(
                (command,) + args, encoding="utf-8", errors="replace",
                stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
    except subprocess.CalledProcessError as exc:
        # Keep the failing command's output for diagnosis before propagating.
        _log_process_output(exc.stdout, exc.stderr, log_file)
        raise
    _log_process_output(result.stdout, result.stderr, log_file)
    return result.stdout
210
211
212 1
def _exclude_garbage(tarinfo):
213
    file_name = tarinfo.name
214
    if file_name.endswith('pyc'):
215
        return None
216
    if file_name.endswith('swp'):
217
        return None
218
    return tarinfo
219
220
221 1
def _make_file_root_owned(tarinfo):
222
    if tarinfo:
223
        tarinfo.uid = 0
224
        tarinfo.gid = 0
225
    return tarinfo
226
227
228 1
def create_tarball():
    """Create a tarball which contains all test scenarios for every rule.

    Tarball contains directories with the test scenarios. The name of the
    directories is the same as short rule ID. There is no tree structure.
    The shared helper directory is stored as ``shared``. Every member is made
    root-owned, and ``*pyc`` / ``*swp`` files are left out.

    The archive is gzip-compressed to match its ``.tar.gz`` name. It is
    created with ``delete=False``, so the caller must remove it; if creating
    it fails, the partial file is removed here.

    :returns: path of the created archive.
    """
    def _member_filter(tinfo):
        # Same filtering for the shared dir and for every rule's tests dir.
        return _exclude_garbage(_make_file_root_owned(tinfo))

    with tempfile.NamedTemporaryFile(
            "wb", suffix=".tar.gz", delete=False) as fp:
        try:
            with tarfile.open(fileobj=fp, mode="w:gz") as tarball:
                tarball.add(_SHARED_DIR, arcname="shared", filter=_member_filter)
                for dirpath, dirnames, _ in walk_through_benchmark_dirs():
                    if "tests" not in dirnames:
                        continue
                    rule_id = os.path.basename(dirpath)
                    tests_dir_path = os.path.join(dirpath, "tests")
                    tarball.add(
                        tests_dir_path, arcname=rule_id,
                        filter=_member_filter
                    )
        except Exception:
            # Don't leave a half-written archive in the temp directory.
            fp.close()
            os.unlink(fp.name)
            raise
    return fp.name
246
247
248 1
def send_scripts(test_env):
    """Upload all test scenarios to the machine under test and unpack them.

    The tarball made by create_tarball() is copied into
    REMOTE_TEST_SCENARIOS_DIRECTORY on the target and extracted there. The
    steps are logged to ``env-preparation.log`` in LogHelper.LOG_DIR. The
    local tarball is always removed afterwards.

    :returns: the remote directory holding the unpacked scenarios.
    """
    remote_dir = REMOTE_TEST_SCENARIOS_DIRECTORY
    archive_file = create_tarball()
    try:
        archive_file_basename = os.path.basename(archive_file)
        remote_archive_file = os.path.join(remote_dir, archive_file_basename)
        logging.debug("Uploading scripts.")
        log_file_name = os.path.join(LogHelper.LOG_DIR, "env-preparation.log")

        with open(log_file_name, 'a') as log_file:
            print("Setting up test setup scripts", file=log_file)

            test_env.execute_ssh_command(
                "mkdir -p {remote_dir}".format(remote_dir=remote_dir),
                log_file, "Cannot create directory {0}".format(remote_dir))
            test_env.scp_upload_file(
                archive_file, remote_dir,
                log_file, "Cannot copy archive {0} to the target machine's directory {1}"
                .format(archive_file, remote_dir))
            test_env.execute_ssh_command(
                "tar xf {remote_archive_file} -C {remote_dir}"
                .format(remote_dir=remote_dir, remote_archive_file=remote_archive_file),
                log_file, "Cannot extract data tarball {0}".format(remote_archive_file))
    finally:
        # The local archive only serves the upload; remove it even when one
        # of the steps above raises.
        os.unlink(archive_file)
    return remote_dir
272
273
274 1
def iterate_over_rules():
    """Iterate over rule directories which have test scenarios.

    Only directories that contain both a ``rule.yml`` file and a ``tests``
    subdirectory are reported.

    Yields:
        Named tuple Rule having these fields:
            directory -- absolute path to the rule "tests" subdirectory
                         containing the test scenarios in Bash
            id -- full rule id as it is present in datastream
            short_id -- short rule ID, the same as basename of the directory
                        containing the test scenarios in Bash
            files -- list of the names (not paths) of the .sh files in the
                     "tests" directory
    """
    for dirpath, dirnames, filenames in walk_through_benchmark_dirs():
        if "rule.yml" in filenames and "tests" in dirnames:
            short_rule_id = os.path.basename(dirpath)
            tests_dir = os.path.join(dirpath, "tests")
            # Filter out everything except the shell test scenarios.
            # Other files in rule directories are editor swap files
            # or other content than a test case.
            # Build a real list so callers can iterate it more than once.
            scripts = [
                file_name for file_name in os.listdir(tests_dir)
                if file_name.endswith(".sh")]
            full_rule_id = OSCAP_RULE + short_rule_id
            yield Rule(
                directory=tests_dir, id=full_rule_id, short_id=short_rule_id,
                files=scripts)
300
301
302 1
def get_cpe_of_tested_os(test_env, log_file):
    """Return the CPE of the operating system on the machine under test.

    Greps ``CPE_NAME`` from ``/etc/os-release`` on the target through
    ``test_env.execute_ssh_command``. The value may be quoted with ``"`` or
    ``'``, or unquoted.

    :raises RuntimeError: if no CPE could be obtained or parsed.
    """
    os_release_file = "/etc/os-release"
    cpe_line = test_env.execute_ssh_command(
        "grep CPE_NAME {os_release_file}".format(os_release_file=os_release_file),
        log_file)
    # We are parsing an assignment that is possibly quoted; an empty/None
    # reply falls through to the "couldn't get" error below.
    cpe = re.match(r'''CPE_NAME=(["']?)(.*)\1''', cpe_line or "")
    if cpe and cpe.group(2):
        return cpe.group(2)
    msg = ["Unable to get a CPE of the system running tests"]
    if cpe_line:
        msg.append(
            "Retrieved a CPE line that we couldn't parse: {cpe_line}"
            .format(cpe_line=cpe_line))
    else:
        msg.append(
            "Couldn't get CPE entry from '{os_release_file}'"
            .format(os_release_file=os_release_file))
    raise RuntimeError("\n".join(msg))
321
322
323 1
# Package-manager command prefix (package names get appended) for each
# platform identifier that cpes_to_platform() can return.
INSTALL_COMMANDS = {
    "fedora": ("dnf", "install", "-y"),
    "rhel7": ("yum", "install", "-y"),
    "rhel8": ("yum", "install", "-y"),
}
328
329
330 1
def install_packages(test_env, packages):
    """Install *packages* on the machine under test.

    The target platform is derived from its CPE and mapped to a package-manager
    command through INSTALL_COMMANDS. The ``env-preparation.log`` file in
    LogHelper.LOG_DIR is handed to the remote commands as their log.
    """
    log_path = os.path.join(LogHelper.LOG_DIR, "env-preparation.log")

    with open(log_path, "a") as cpe_log:
        tested_os_cpe = get_cpe_of_tested_os(test_env, cpe_log)

    install_prefix = INSTALL_COMMANDS[cpes_to_platform([tested_os_cpe])]
    install_cmd = " ".join(install_prefix + tuple(packages))
    failure_msg = "Couldn't install required packages {packages}".format(
        packages=packages)

    with open(log_path, 'a') as install_log:
        print("Installing packages", file=install_log)
        # Flush our own line before the remote command writes to the log.
        install_log.flush()
        test_env.execute_ssh_command(install_cmd, install_log, failure_msg)
345
346
347 1
def cpes_to_platform(cpes):
    """Map CPEs to a platform identifier as used by INSTALL_COMMANDS.

    The first CPE identifying Fedora or Red Hat Enterprise Linux decides.
    Fedora maps to ``"fedora"``; RHEL maps to ``"rhel<major version>"``.

    :raises ValueError: if none of *cpes* identifies a known platform.
    """
    for cpe in cpes:
        if "fedora" in cpe:
            return "fedora"
        if "redhat:enterprise_linux" in cpe:
            # The version may be the last CPE field (".../enterprise_linux:8"),
            # so do not require a colon after it.
            match = re.search(r":enterprise_linux:([^:]+)", cpe)
            if match:
                major_version = match.group(1).split(".")[0]
                return "rhel" + major_version
    msg = "Unable to deduce a platform from these CPEs: {cpes}".format(cpes=cpes)
    raise ValueError(msg)
358