Passed
Push — master ( 041d2f...63c4fa )
by Jan
02:27 queued 11s
created

ssg_test_suite.common.run_with_stdout_logging()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 1.2963

Importance

Changes 0
Metric Value
cc 1
eloc 4
nop 3
dl 0
loc 4
ccs 1
cts 3
cp 0.3333
crap 1.2963
rs 10
c 0
b 0
f 0
1 1
from __future__ import print_function
2
3 1
import os
4 1
import logging
5 1
import subprocess
6 1
from collections import namedtuple
7 1
import functools
8 1
import tarfile
9 1
import tempfile
10 1
import re
11
12 1
from ssg.constants import MULTI_PLATFORM_MAPPING
13 1
from ssg.constants import PRODUCT_TO_CPE_MAPPING
14 1
from ssg.constants import FULL_NAME_TO_PRODUCT_MAPPING
15 1
from ssg.constants import OSCAP_RULE
16 1
from ssg_test_suite.log import LogHelper
17
18 1
# Named tuples describing one scenario run, the conditions it ran under,
# and a rule directory together with its test scripts.
Scenario_run = namedtuple("Scenario_run", ("rule_id", "script"))
Scenario_conditions = namedtuple(
    "Scenario_conditions",
    ("backend", "scanning_mode", "remediated_by", "datastream"))
Rule = namedtuple("Rule", ["directory", "id", "short_id", "files"])
26
27 1
_BENCHMARK_DIRS = [
28
        os.path.abspath(os.path.join(os.path.dirname(__file__), '../../linux_os/guide')),
29
        os.path.abspath(os.path.join(os.path.dirname(__file__), '../../applications')),
30
        ]
31
32 1
_SHARED_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../shared'))
33
34 1
REMOTE_USER = "root"
35 1
REMOTE_USER_HOME_DIRECTORY = "/root"
36 1
REMOTE_TEST_SCENARIOS_DIRECTORY = os.path.join(REMOTE_USER_HOME_DIRECTORY, "ssgts")
37
38 1
# Options passed to every ssh/scp invocation.  Users may extend them through
# the SSH_ADDITIONAL_OPTIONS environment variable (whitespace-separated).
# Host-key checking is disabled because the tested VMs are short-lived and
# get fresh host keys on every provisioning.
# Improvement: the original wrapped `os.environ.get(...).split()` in a
# try/except AttributeError to handle an unset variable; supplying a default
# of "" expresses the same thing directly.
SSH_ADDITIONAL_OPTS = (
    "-o", "StrictHostKeyChecking=no",
    "-o", "UserKnownHostsFile=/dev/null",
) + tuple(os.environ.get('SSH_ADDITIONAL_OPTIONS', '').split())
48
49
50 1
def walk_through_benchmark_dirs():
    """Yield (dirpath, dirnames, filenames) os.walk entries for every
    configured benchmark directory."""
    for benchmark_dir in _BENCHMARK_DIRS:
        for walk_entry in os.walk(benchmark_dir):
            yield walk_entry
54
55
56 1
class Stage(object):
    """Enumeration of the test-run stages, in execution order."""

    NONE = 0          # no stage reached yet
    PREPARATION = 1   # environment being prepared
    INITIAL_SCAN = 2  # scan before any remediation
    REMEDIATION = 3   # remediation being applied
    FINAL_SCAN = 4    # verification scan after remediation
62
63
64 1
@functools.total_ordering
class RuleResult(object):
    """
    Result of a test suite testing rule under a scenario.

    Supports ordering by success - the most successful run orders first.
    """
    # Note: this docstring was originally placed below STAGE_STRINGS, where
    # it was a no-op string statement instead of the class docstring.

    # Stage names whose pass/fail outcome may be recorded for a run.
    STAGE_STRINGS = {
        "preparation",
        "initial_scan",
        "remediation",
        "final_scan",
    }

    def __init__(self, result_dict=None):
        # Identification of the scenario and the conditions it ran under.
        self.scenario = Scenario_run("", "")
        self.conditions = Scenario_conditions("", "", "", "")
        # Timestamp of the run, taken verbatim from the result dictionary.
        self.when = ""
        # Maps stage name -> bool outcome, only for stages that were reached.
        self.passed_stages = dict()
        self.passed_stages_count = 0
        self.success = False

        if result_dict:
            self.load_from_dict(result_dict)

    def load_from_dict(self, data):
        """Populate this result from a flat dict (inverse of save_to_dict)."""
        self.scenario = Scenario_run(data["rule_id"], data["scenario_script"])
        self.conditions = Scenario_conditions(
            data["backend"], data["scanning_mode"],
            data["remediated_by"], data["datastream"])
        self.when = data["run_timestamp"]

        self.passed_stages = {key: data[key] for key in self.STAGE_STRINGS if key in data}
        self.passed_stages_count = sum(self.passed_stages.values())

        # A run succeeds when the final scan passed, or - for scenarios that
        # have no remediation stage - when already the initial scan passed.
        self.success = data.get("final_scan", False)
        if not self.success:
            self.success = (
                "remediation" not in data
                and data.get("initial_scan", False))

    def save_to_dict(self):
        """Serialize this result into a flat dict (inverse of load_from_dict)."""
        data = dict()
        data["rule_id"] = self.scenario.rule_id
        data["scenario_script"] = self.scenario.script

        data["backend"] = self.conditions.backend
        data["scanning_mode"] = self.conditions.scanning_mode
        data["remediated_by"] = self.conditions.remediated_by
        data["datastream"] = self.conditions.datastream

        data["run_timestamp"] = self.when

        for stage_str, result in self.passed_stages.items():
            data[stage_str] = result

        return data

    def record_stage_result(self, stage, successful):
        """Record the pass/fail outcome of a single named stage."""
        assert stage in self.STAGE_STRINGS, (
            "Stage name {name} is invalid, choose one from {choices}"
            .format(name=stage, choices=", ".join(self.STAGE_STRINGS))
        )
        self.passed_stages[stage] = successful

    def relative_conditions_to(self, other):
        """Return what distinguishes two results: the timestamps when the
        conditions match, the full condition tuples otherwise."""
        if self.conditions == other.conditions:
            return self.when, other.when
        else:
            return tuple(self.conditions), tuple(other.conditions)

    def __eq__(self, other):
        # Bugfix: the original compared tuple(self.passed_stages) with
        # itself, so the stage comparison was always true.
        return (self.success == other.success
                and tuple(self.passed_stages) == tuple(other.passed_stages))

    def __lt__(self, other):
        # More passed stages means "more successful", which orders first.
        return self.passed_stages_count > other.passed_stages_count
141
142
143 1
def run_cmd_local(command, verbose_path, env=None):
    """Execute `command` (a sequence of arguments) on the local machine.

    stderr is written to the file at verbose_path.
    Returns a (returncode, decoded stdout) tuple.
    """
    logging.debug('Running {}'.format(' '.join(command)))
    return _run_cmd(command, verbose_path, env)
148
149
150 1
def run_cmd_remote(command_string, domain_ip, verbose_path, env=None):
    """Execute `command_string` over ssh on domain_ip as REMOTE_USER.

    stderr is written to the file at verbose_path.
    Returns a (returncode, decoded stdout) tuple.
    """
    target = '{0}@{1}'.format(REMOTE_USER, domain_ip)
    ssh_command = ['ssh'] + list(SSH_ADDITIONAL_OPTS) + [target, command_string]
    logging.debug('Running {}'.format(command_string))
    return _run_cmd(ssh_command, verbose_path, env)
156
157
158 1
def _run_cmd(command_list, verbose_path, env=None):
159
    returncode = 0
160
    output = b""
161
    try:
162
        with open(verbose_path, 'w') as verbose_file:
163
            output = subprocess.check_output(
164
                command_list, stderr=verbose_file, env=env)
165
    except subprocess.CalledProcessError as e:
166
        returncode = e.returncode
167
        output = e.output
168
    return returncode, output.decode('utf-8')
169
170
171 1
def _get_platform_cpes(platform):
172 1
    if platform.startswith("multi_platform_"):
173 1
        try:
174 1
            products = MULTI_PLATFORM_MAPPING[platform]
175 1
        except KeyError:
176 1
            logging.error(
177
                "Unknown multi_platform specifier: %s is not from %s"
178
                % (platform, ", ".join(MULTI_PLATFORM_MAPPING.keys())))
179 1
            raise ValueError
180 1
        platform_cpes = set()
181 1
        for p in products:
182 1
            platform_cpes |= set(PRODUCT_TO_CPE_MAPPING[p])
183 1
        return platform_cpes
184
    else:
185
        # scenario platform is specified by a full product name
186 1
        try:
187 1
            product = FULL_NAME_TO_PRODUCT_MAPPING[platform]
188 1
        except KeyError:
189 1
            logging.error(
190
                "Unknown product name: %s is not from %s"
191
                % (platform, ", ".join(FULL_NAME_TO_PRODUCT_MAPPING.keys())))
192 1
            raise ValueError
193 1
        platform_cpes = set(PRODUCT_TO_CPE_MAPPING[product])
194 1
        return platform_cpes
195
196
197 1
def matches_platform(scenario_platforms, benchmark_cpes):
    """Decide whether a scenario's platforms overlap the benchmark's CPEs.

    The "multi_platform_all" specifier matches any benchmark.
    """
    if "multi_platform_all" in scenario_platforms:
        return True
    applicable_cpes = set()
    for platform in scenario_platforms:
        applicable_cpes |= _get_platform_cpes(platform)
    return bool(applicable_cpes & benchmark_cpes)
204
205
206 1
def run_with_stdout_logging(command, args, log_file):
    """Run command with args, appending the command line and its combined
    stdout/stderr to the open file object log_file.

    Raises subprocess.CalledProcessError on a non-zero exit status.
    """
    log_file.write("{0} {1}\n".format(command, " ".join(args)))
    full_command = (command,) + args
    subprocess.check_call(full_command, stdout=log_file, stderr=subprocess.STDOUT)
210
211
212 1
def _exclude_garbage(tarinfo):
213
    file_name = tarinfo.name
214
    if file_name.endswith('pyc'):
215
        return None
216
    if file_name.endswith('swp'):
217
        return None
218
    return tarinfo
219
220
221 1
def _make_file_root_owned(tarinfo):
222
    if tarinfo:
223
        tarinfo.uid = 0
224
        tarinfo.gid = 0
225
    return tarinfo
226
227
228 1
def create_tarball():
    """Create a tarball which contains all test scenarios for every rule.

    Tarball contains directories with the test scenarios. The name of the
    directories is the same as short rule ID. There is no tree structure.
    Returns the path to the created archive.
    """
    # NOTE(review): the suffix says .tar.gz but mode "w" writes an
    # uncompressed tar; the remote "tar xf" auto-detects either format, so
    # extraction still works -- confirm before renaming or compressing.
    with tempfile.NamedTemporaryFile(
            "wb", suffix=".tar.gz", delete=False) as fp:
        with tarfile.TarFile.open(fileobj=fp, mode="w") as tarball:
            # Shared helpers are available to every scenario under "shared".
            tarball.add(_SHARED_DIR, arcname="shared", filter=_make_file_root_owned)
            for dirpath, dirnames, _ in walk_through_benchmark_dirs():
                if "tests" not in dirnames:
                    continue
                rule_id = os.path.basename(dirpath)
                # Drop editor/compilation junk and normalize ownership.
                tarball.add(
                    os.path.join(dirpath, "tests"), arcname=rule_id,
                    filter=lambda tinfo: _exclude_garbage(_make_file_root_owned(tinfo)),
                )
        return fp.name
246
247
248 1
def execute_remote_command(machine, args, log_file, error_msg=""):
    """Run args on machine over ssh, logging all output into log_file.

    error_msg overrides the default failure description.  Logs and raises
    RuntimeError when the command fails.
    """
    if not error_msg:
        error_msg = (
            "Failed to execute '{cmd}' on {machine}"
            .format(cmd=" ".join(args), machine=machine))
    ssh_args = SSH_ADDITIONAL_OPTS + (machine,) + args
    try:
        run_with_stdout_logging("ssh", ssh_args, log_file)
    except Exception as exc:
        logging.error(error_msg + ": " + str(exc))
        raise RuntimeError(error_msg)
258
259
260 1
def copy_file_to(machine, what, dest, log_file, error_msg=""):
    """Copy the local file `what` to machine:dest using scp, logging into
    log_file.

    error_msg overrides the default failure description.  Logs and raises
    RuntimeError (including the underlying error text) when the copy fails.
    """
    scp_dest = "{machine}:{dest}".format(machine=machine, dest=dest)
    if not error_msg:
        error_msg = (
            "Failed to copy {what} to {scp_dest}"
            .format(what=what, scp_dest=scp_dest))
    try:
        run_with_stdout_logging("scp", SSH_ADDITIONAL_OPTS + (what, scp_dest), log_file)
    except Exception as exc:
        detailed_msg = error_msg + ": " + str(exc)
        logging.error(detailed_msg)
        raise RuntimeError(detailed_msg)
272
273
274 1
def send_scripts(domain_ip):
    """Pack all test scenarios into a tarball and unpack them on the target.

    Returns the remote directory the scenarios were extracted into.
    Raises RuntimeError when any of the remote steps fails.
    """
    remote_dir = REMOTE_TEST_SCENARIOS_DIRECTORY
    tarball_path = create_tarball()
    remote_tarball_path = os.path.join(
        remote_dir, os.path.basename(tarball_path))
    machine = "{0}@{1}".format(REMOTE_USER, domain_ip)
    logging.debug("Uploading scripts.")
    log_file_name = os.path.join(LogHelper.LOG_DIR, "env-preparation.log")

    with open(log_file_name, 'a') as log_file:
        print("Setting up test setup scripts", file=log_file)

        execute_remote_command(
            machine, ("mkdir", "-p", remote_dir),
            log_file, "Cannot create directory {0}".format(remote_dir))

        copy_file_to(
            machine, tarball_path, remote_dir,
            log_file, "Cannot copy archive {0} to the target machine's directory {1}"
            .format(tarball_path, remote_dir))

        execute_remote_command(
            machine, ("tar", "xf", remote_tarball_path, "-C", remote_dir),
            log_file, "Cannot extract data tarball {0}".format(remote_tarball_path))

    # The local archive is no longer needed once extracted remotely.
    os.unlink(tarball_path)
    return remote_dir
300
301
302 1
def iterate_over_rules():
    """Iterate over rule directories which have test scenarios.

    Yields:
        Named tuple Rule having these fields:
            directory -- absolute path to the rule "tests" subdirectory
                         containing the test scenarios in Bash
            id -- full rule id as it is present in datastream
            short_id -- short rule ID, the same as basename of the directory
                        containing the test scenarios in Bash
            files -- list of executable .sh files in the "tests" directory
    """
    for dirpath, dirnames, filenames in walk_through_benchmark_dirs():
        if "rule.yml" in filenames and "tests" in dirnames:
            short_rule_id = os.path.basename(dirpath)
            tests_dir = os.path.join(dirpath, "tests")
            # Filter out everything except the shell test scenarios.
            # Other files in rule directories are editor swap files
            # or other content than a test case.
            # Bugfix: materialized as a list so Rule.files can be iterated
            # repeatedly (Python 3's filter() is a one-shot iterator).
            scripts = [
                file_name for file_name in os.listdir(tests_dir)
                if file_name.endswith(".sh")
            ]
            yield Rule(
                directory=tests_dir, id=OSCAP_RULE + short_rule_id,
                short_id=short_rule_id, files=scripts)
328
329
330 1
def get_cpe_of_tested_os(domain_ip, logfile_name):
    """Return the CPE string of the OS running on the tested machine.

    The CPE is read over ssh from the CPE_NAME entry of /etc/os-release.
    Raises RuntimeError when the entry is missing or cannot be parsed.
    """
    os_release_file = "/etc/os-release"
    ret, cpe_line = run_cmd_remote(
        "grep CPE_NAME {os_release_file}".format(os_release_file=os_release_file),
        domain_ip, logfile_name)
    # We are parsing an assignment that is possibly quoted
    cpe = re.match(r'''CPE_NAME=(["']?)(.*)\1''', cpe_line)
    if cpe and cpe.groups()[1]:
        return cpe.groups()[1]
    msg = ["Unable to get a CPE of the system running tests"]
    if cpe_line:
        msg.append(
            # Bugfix: the message previously misspelled "Retrieved".
            "Retrieved a CPE line that we couldn't parse: {cpe_line}"
            .format(cpe_line=cpe_line))
    else:
        msg.append(
            "Couldn't get CPE entry from '{os_release_file}'"
            .format(os_release_file=os_release_file))
    raise RuntimeError("\n".join(msg))
349
350
351 1
# Package-manager invocation per platform keyword (see cpes_to_platform).
INSTALL_COMMANDS = {
    "fedora": ("dnf", "install", "-y"),
    "rhel7": ("yum", "install", "-y"),
    "rhel8": ("yum", "install", "-y"),
}
356
357
358 1
def install_packages(domain_ip, packages):
    """Install packages on the tested machine with its native package manager.

    The manager is picked according to the machine's CPE.  Logs into the
    env-preparation log and raises RuntimeError on failure.
    """
    machine = "{0}@{1}".format(REMOTE_USER, domain_ip)
    log_file_name = os.path.join(LogHelper.LOG_DIR, "env-preparation.log")

    platform = cpes_to_platform(
        [get_cpe_of_tested_os(domain_ip, log_file_name)])

    with open(log_file_name, 'a') as log_file:
        print("Installing packages", file=log_file)
        log_file.flush()
        install_command = INSTALL_COMMANDS[platform] + tuple(packages)
        execute_remote_command(
            machine, install_command, log_file,
            "Couldn't install required packages {packages}".format(packages=packages))
372
373 1
def cpes_to_platform(cpes):
    """Deduce a platform keyword ("fedora" or "rhelN") from a list of CPEs.

    Raises ValueError when none of the CPEs maps to a known platform.
    """
    for cpe in cpes:
        if "fedora" in cpe:
            return "fedora"
        if "redhat:enterprise_linux" in cpe:
            version_match = re.search(r":enterprise_linux:([^:]+):", cpe)
            if version_match:
                major_version = version_match.group(1).split(".")[0]
                return "rhel" + major_version
    raise ValueError(
        "Unable to deduce a platform from these CPEs: {cpes}".format(cpes=cpes))
384