Passed
Pull Request — master (#311)
by
unknown
01:13
created

HelpFormattingParserTestCase.assert_snapshot()   A

Complexity

Conditions 3

Size

Total Lines 16
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 11
nop 3
dl 0
loc 16
rs 9.85
c 0
b 0
f 0
1
# -*- coding: utf-8 -*-
2
# Copyright (C) 2019 Greenbone Networks GmbH
3
#
4
# SPDX-License-Identifier: GPL-3.0-or-later
5
#
6
# This program is free software: you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation, either version 3 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful,
12
# but WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14
# GNU General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18
19
import os
20
import sys
21
import unittest
22
23
from unittest import mock
24
from pathlib import Path
25
26
from gvm.connections import (
27
    DEFAULT_UNIX_SOCKET_PATH,
28
    DEFAULT_TIMEOUT,
29
    UnixSocketConnection,
30
    TLSConnection,
31
    SSHConnection,
32
)
33
34
from gvmtools.parser import CliParser, create_parser, create_connection
35
from gvmtools.config import Config
36
37
from . import SuppressOutput
38
39
# Resolved directory containing this test module; the tests below build the
# paths of their config and snapshot files relative to it.
__here__ = Path(__file__).parent.resolve()
40
41
42
class ConfigParserTestCase(unittest.TestCase):
    """Tests for CliParser when a config file is passed via --config."""

    def setUp(self):
        """Create the parser and verify the test.cfg fixture is present."""
        self.test_config_path = __here__ / 'test.cfg'

        self.assertTrue(self.test_config_path.is_file())

        self.parser = CliParser('TestParser', 'test.log')

    def test_socket_defaults_from_config(self):
        """Parse the socket sub-command with the test.cfg fixture."""
        args = self.parser.parse_args(
            ['--config', str(self.test_config_path), 'socket']
        )

        self.assertEqual(args.foo, 'bar')
        self.assertEqual(args.timeout, 1000)
        self.assertEqual(args.gmp_password, 'bar')
        self.assertEqual(args.gmp_username, 'bar')
        self.assertEqual(args.socketpath, '/foo/bar.sock')

    def test_ssh_defaults_from_config(self):
        """Parse the ssh sub-command with the test.cfg fixture."""
        args = self.parser.parse_args(
            ['--config', str(self.test_config_path), 'ssh', '--hostname', 'foo']
        )

        self.assertEqual(args.foo, 'bar')
        self.assertEqual(args.timeout, 1000)
        self.assertEqual(args.gmp_password, 'bar')
        self.assertEqual(args.gmp_username, 'bar')
        self.assertEqual(args.ssh_password, 'lorem')
        self.assertEqual(args.ssh_username, 'ipsum')
        self.assertEqual(args.port, 123)

    def test_tls_defaults_from_config(self):
        """Parse the tls sub-command with the test.cfg fixture."""
        args = self.parser.parse_args(
            ['--config', str(self.test_config_path), 'tls', '--hostname', 'foo']
        )

        self.assertEqual(args.foo, 'bar')
        self.assertEqual(args.timeout, 1000)
        self.assertEqual(args.gmp_password, 'bar')
        self.assertEqual(args.gmp_username, 'bar')
        self.assertEqual(args.certfile, 'foo.cert')
        self.assertEqual(args.keyfile, 'foo.key')
        self.assertEqual(args.cafile, 'foo.ca')
        self.assertEqual(args.port, 123)

    @mock.patch('gvmtools.parser.Path')
    def test_resolve_file_not_found_error(self, path_mock):
        """_load_config returns a Config if resolve() raises FileNotFoundError."""
        configpath = mock.MagicMock()
        # An exception class as side_effect is raised on every call. A
        # zero-argument helper function would instead receive the call's
        # arguments and raise TypeError as soon as any were passed.
        configpath.expanduser().resolve = mock.MagicMock(
            side_effect=FileNotFoundError
        )
        path_mock.return_value = configpath

        # pylint: disable=protected-access
        return_value = self.parser._load_config('foobar')

        self.assertIsInstance(return_value, Config)

    @mock.patch('gvmtools.parser.Path')
    @mock.patch('gvmtools.parser.Config')
    def test_config_load_raises_error(self, config_mock, path_mock):
        """_load_config raises RuntimeError if Config.load() raises."""
        config = mock.MagicMock()
        # Raise the intended Exception regardless of the arguments load() is
        # called with (see the note on side_effect in the test above).
        config.load = mock.MagicMock(side_effect=Exception)
        config_mock.return_value = config

        configpath = mock.Mock()
        configpath.expanduser().resolve().exists = mock.MagicMock(
            return_value=True
        )
        path_mock.return_value = configpath

        configfile = 'configfile'

        # pylint: disable=protected-access
        self.assertRaises(RuntimeError, self.parser._load_config, configfile)
124
125
126
class IgnoreConfigParserTestCase(unittest.TestCase):
    """Tests for parsing with a --config path that is not an existing file."""

    def _parse_socket_with_missing_config(self, config_path):
        """Check config_path is no file, then parse 'socket' using it."""
        self.assertFalse(config_path.is_file())

        self.parser = CliParser('TestParser', 'test.log')

        return self.parser.parse_args(['--config', str(config_path), 'socket'])

    def _assert_default_values(self, parsed):
        """Assert the parsed namespace carries the expected default values."""
        self.assertEqual(parsed.timeout, DEFAULT_TIMEOUT)
        self.assertEqual(parsed.gmp_password, '')
        self.assertEqual(parsed.gmp_username, '')
        self.assertEqual(parsed.socketpath, DEFAULT_UNIX_SOCKET_PATH)

    def test_unkown_config_file(self):
        """Config file missing from the directory of this module."""
        missing_config = __here__ / 'foo.cfg'

        parsed = self._parse_socket_with_missing_config(missing_config)

        self._assert_default_values(parsed)

    def test_unkown_config_file_in_unkown_dir(self):
        """Config file located in a 'foo' sub-directory of this module's dir."""
        missing_config = __here__ / 'foo' / 'foo.cfg'

        parsed = self._parse_socket_with_missing_config(missing_config)

        self._assert_default_values(parsed)
158
159
160
class ParserTestCase(unittest.TestCase):
    """Base test case that provides a fresh CliParser as self.parser."""

    def setUp(self):
        """Create the parser with ignore_config=True and a fixed prog name."""
        parser_options = {'ignore_config': True, 'prog': 'gvm-test-cli'}
        self.parser = CliParser('TestParser', 'test.log', **parser_options)
165
166
167
class RootArgumentsParserTest(ParserTestCase):
    """Tests for the arguments given before the connection sub-command."""

    def _assert_parse_exits(self, argv):
        """Expect parse_args(argv) to raise SystemExit.

        The call is wrapped in SuppressOutput(suppress_stderr=True).
        """
        with SuppressOutput(suppress_stderr=True):
            with self.assertRaises(SystemExit):
                self.parser.parse_args(argv)

    def test_config(self):
        args = self.parser.parse_args(['--config', 'foo.cfg', 'socket'])
        self.assertEqual(args.config, 'foo.cfg')

    def test_defaults(self):
        args = self.parser.parse_args(['socket'])
        self.assertEqual(args.config, '~/.config/gvm-tools.conf')
        self.assertEqual(args.gmp_password, '')
        self.assertEqual(args.gmp_username, '')
        self.assertEqual(args.timeout, 60)
        self.assertIsNone(args.loglevel)

    def test_loglevel(self):
        args = self.parser.parse_args(['--log', 'ERROR', 'socket'])
        self.assertEqual(args.loglevel, 'ERROR')

    def test_loglevel_after_subparser(self):
        self._assert_parse_exits(['socket', '--log', 'ERROR'])

    def test_timeout(self):
        args = self.parser.parse_args(['--timeout', '1000', 'socket'])
        self.assertEqual(args.timeout, 1000)

    def test_timeout_after_subparser(self):
        self._assert_parse_exits(['socket', '--timeout', '1000'])

    def test_gmp_username(self):
        args = self.parser.parse_args(['--gmp-username', 'foo', 'socket'])
        self.assertEqual(args.gmp_username, 'foo')

    def test_gmp_username_after_subparser(self):
        self._assert_parse_exits(['socket', '--gmp-username', 'foo'])

    def test_gmp_password(self):
        args = self.parser.parse_args(['--gmp-password', 'foo', 'socket'])
        self.assertEqual(args.gmp_password, 'foo')

    def test_gmp_password_after_subparser(self):
        self._assert_parse_exits(['socket', '--gmp-password', 'foo'])

    def test_with_unknown_args(self):
        """Unrecognized options are returned separately by parse_known_args."""
        args, script_args = self.parser.parse_known_args(
            ['--gmp-password', 'foo', 'socket', '--bar', '--bar2']
        )
        self.assertEqual(args.gmp_password, 'foo')
        self.assertEqual(script_args, ['--bar', '--bar2'])

    @mock.patch('gvmtools.parser.logging')
    def test_socket_has_no_timeout(
        self, logging_mock
    ):  # pylint: disable=unused-argument
        """A timeout of -1 from the inner parser is reported as None."""
        # pylint: disable=protected-access
        # Replace the wrapped parser so that it hands back a namespace whose
        # timeout is -1, independent of the argv passed below.
        self.parser._parser = mock.MagicMock()
        args_mock = mock.MagicMock()
        args_mock.timeout = -1
        self.parser._parser.parse_known_args = mock.MagicMock(
            return_value=(args_mock, mock.MagicMock())
        )

        args, _ = self.parser.parse_known_args(
            ['socket', '--timeout', '--', '-1']
        )

        self.assertIsNone(args.timeout)

    @mock.patch('gvmtools.parser.logging')
    @mock.patch('gvmtools.parser.argparse.ArgumentParser.print_usage')
    @mock.patch('gvmtools.parser.argparse.ArgumentParser._print_message')
    def test_no_args_provided(
        self, print_message_mock, print_usage_mock, logging_mock
    ):  # pylint: disable=unused-argument
        """parse_known_args(None) raises SystemExit."""
        # Stacked mock.patch decorators inject their mocks bottom-up, so the
        # first parameter belongs to the decorator closest to the function.
        # pylint: disable=protected-access
        self.parser._set_defaults = mock.MagicMock()

        self.assertRaises(SystemExit, self.parser.parse_known_args, None)
251
252
253
class SocketParserTestCase(ParserTestCase):
    """Tests for the arguments of the 'socket' sub-command."""

    def _parse_socket(self, *options):
        """Parse a command line made of 'socket' followed by options."""
        return self.parser.parse_args(['socket'] + list(options))

    def test_defaults(self):
        parsed = self._parse_socket()
        self.assertEqual(parsed.socketpath, DEFAULT_UNIX_SOCKET_PATH)

    def test_connection_type(self):
        parsed = self._parse_socket()
        self.assertEqual(parsed.connection_type, 'socket')

    def test_sockpath(self):
        # --sockpath ends up in the same attribute as --socketpath
        parsed = self._parse_socket('--sockpath', 'foo.sock')
        self.assertEqual(parsed.socketpath, 'foo.sock')

    def test_socketpath(self):
        parsed = self._parse_socket('--socketpath', 'foo.sock')
        self.assertEqual(parsed.socketpath, 'foo.sock')
269
270
271
class SshParserTestCase(ParserTestCase):
    """Tests for the arguments of the 'ssh' sub-command."""

    def _parse_ssh(self, *options):
        """Parse a command line made of 'ssh' followed by options."""
        return self.parser.parse_args(['ssh'] + list(options))

    def test_defaults(self):
        parsed = self._parse_ssh('--hostname=foo')
        self.assertEqual(parsed.port, 22)
        self.assertEqual(parsed.ssh_username, 'gmp')
        self.assertEqual(parsed.ssh_password, 'gmp')

    def test_connection_type(self):
        parsed = self._parse_ssh('--hostname=foo')
        self.assertEqual(parsed.connection_type, 'ssh')

    def test_hostname(self):
        parsed = self._parse_ssh('--hostname', 'foo')
        self.assertEqual(parsed.hostname, 'foo')

    def test_port(self):
        parsed = self._parse_ssh('--hostname', 'foo', '--port', '123')
        self.assertEqual(parsed.port, 123)

    def test_ssh_username(self):
        parsed = self._parse_ssh('--hostname', 'foo', '--ssh-username', 'foo')
        self.assertEqual(parsed.ssh_username, 'foo')

    def test_ssh_password(self):
        parsed = self._parse_ssh('--hostname', 'foo', '--ssh-password', 'foo')
        self.assertEqual(parsed.ssh_password, 'foo')
303
304
305
class TlsParserTestCase(ParserTestCase):
    """Tests for the arguments of the 'tls' sub-command."""

    def _parse_tls(self, *options):
        """Parse a command line made of 'tls' followed by options."""
        return self.parser.parse_args(['tls'] + list(options))

    def test_defaults(self):
        parsed = self._parse_tls('--hostname=foo')
        self.assertIsNone(parsed.certfile)
        self.assertIsNone(parsed.keyfile)
        self.assertIsNone(parsed.cafile)
        self.assertEqual(parsed.port, 9390)

    def test_connection_type(self):
        parsed = self._parse_tls('--hostname=foo')
        self.assertEqual(parsed.connection_type, 'tls')

    def test_hostname(self):
        parsed = self._parse_tls('--hostname', 'foo')
        self.assertEqual(parsed.hostname, 'foo')

    def test_port(self):
        parsed = self._parse_tls('--hostname', 'foo', '--port', '123')
        self.assertEqual(parsed.port, 123)

    def test_certfile(self):
        parsed = self._parse_tls('--hostname', 'foo', '--certfile', 'foo.cert')
        self.assertEqual(parsed.certfile, 'foo.cert')

    def test_keyfile(self):
        parsed = self._parse_tls('--hostname', 'foo', '--keyfile', 'foo.key')
        self.assertEqual(parsed.keyfile, 'foo.key')

    def test_cafile(self):
        parsed = self._parse_tls('--hostname', 'foo', '--cafile', 'foo.ca')
        self.assertEqual(parsed.cafile, 'foo.ca')

    def test_no_credentials(self):
        parsed = self._parse_tls('--hostname', 'foo', '--no-credentials')
        self.assertTrue(parsed.no_credentials)
350
351
352
class CustomizeParserTestCase(ParserTestCase):
    """Tests for arguments added to the parser after its creation."""

    def test_add_optional_argument(self):
        """An added --foo option is accepted after every sub-command."""
        self.parser.add_argument('--foo', type=int)

        command_lines = (
            ['socket', '--foo', '123'],
            ['ssh', '--hostname', 'bar', '--foo', '123'],
            ['tls', '--hostname', 'bar', '--foo', '123'],
        )
        for command_line in command_lines:
            parsed = self.parser.parse_args(command_line)
            self.assertEqual(parsed.foo, 123)

    def test_add_positional_argument(self):
        """An added positional 'foo' is parsed after the sub-command."""
        self.parser.add_argument('foo', type=int)

        parsed = self.parser.parse_args(['socket', '123'])

        self.assertEqual(parsed.foo, 123)

    def test_add_protocol_argument(self):
        """--protocol yields 'GMP' when omitted and the given value otherwise."""
        self.parser.add_protocol_argument()

        without_option = self.parser.parse_args(['socket'])
        self.assertEqual(without_option.protocol, 'GMP')

        with_option = self.parser.parse_args(['--protocol', 'OSP', 'socket'])
        self.assertEqual(with_option.protocol, 'OSP')
384
385
386
class HelpFormattingParserTestCase(ParserTestCase):
    """Compare the parsers' formatted help output against snapshot files.

    Snapshots are stored next to this module as '<name>.snap' or, if present,
    as '<name>.<major>.<minor>.snap' for the running Python version.
    """

    # pylint: disable=protected-access
    maxDiff = None
    python_version = '.'.join([str(i) for i in sys.version_info[:2]])

    def setUp(self):
        """Remember the current COLUMNS value and force a width of 80."""
        super().setUp()

        # ensure all tests are using the same terminal width
        self.columns = os.environ.get('COLUMNS')
        os.environ['COLUMNS'] = '80'

    def tearDown(self):
        """Put COLUMNS back to the state it had before setUp ran."""
        super().tearDown()

        if self.columns is None:
            # COLUMNS was unset before the test. Compare against None so that
            # a previously empty value ('') is restored instead of removed,
            # and use pop() so an already removed variable is no error.
            os.environ.pop('COLUMNS', None)
        else:
            os.environ['COLUMNS'] = self.columns

    def _snapshot_specific_path(self, name):
        """Return the snapshot path specific to the running Python version."""
        return __here__ / '{}.{}.snap'.format(name, self.python_version)

    def _snapshot_generic_path(self, name):
        """Return the snapshot path shared by all Python versions."""
        return __here__ / '{}.snap'.format(name)

    def _snapshot_failed_path(self, name):
        """Return the path the output of a failed comparison is written to."""
        return __here__ / '{}.{}-failed.snap'.format(name, self.python_version)

    def _snapshot_path(self, name):
        """Return the version specific snapshot if it exists, else the generic."""
        snapshot_specific_path = self._snapshot_specific_path(name)

        if snapshot_specific_path.exists():
            return snapshot_specific_path

        return self._snapshot_generic_path(name)

    def assert_snapshot(self, name, output):
        """Assert that output equals the stored snapshot called name.

        A missing snapshot file is created from output, so the first run
        passes. On a mismatch the output is written to the '-failed' snapshot
        file and the AssertionError is re-raised.
        """
        path = self._snapshot_path(name)

        if not path.exists():
            path.write_text(output)

        content = path.read_text()

        try:
            self.assertEqual(content, output, 'Snapshot differs from output')
        except AssertionError:
            # keep the differing output for inspection in a separate file
            # and reraise the error afterwards
            path = self._snapshot_failed_path(name)
            path.write_text(output)
            raise

    def test_root_help(self):
        help_output = self.parser._parser.format_help()
        self.assert_snapshot('root_help', help_output)

    def test_socket_help(self):
        help_output = self.parser._parser_socket.format_help()
        self.assert_snapshot('socket_help', help_output)

    def test_ssh_help(self):
        self.parser._set_defaults(None)
        help_output = self.parser._parser_ssh.format_help()
        self.assert_snapshot('ssh_help', help_output)

    def test_tls_help(self):
        self.parser._set_defaults(None)
        help_output = self.parser._parser_tls.format_help()
        self.assert_snapshot('tls_help', help_output)
457
458
459
class ParserModuleFunctionTestCase(unittest.TestCase):
    """Tests for the module level functions of gvmtools.parser."""

    # pylint: disable=protected-access
    def test_create_parser(self):
        """create_parser returns a CliParser carrying the passed values."""
        description = 'parser description'
        logfilename = 'logfilename'

        parser = create_parser(description, logfilename)

        # assertIsInstance reports the actual object on failure, unlike
        # assertTrue(isinstance(...)) which only says "False is not true"
        self.assertIsInstance(parser, CliParser)
        self.assertEqual(parser._logfilename, logfilename)
        self.assertEqual(parser._bootstrap_parser.description, description)

    @mock.patch('gvmtools.parser.TLSConnection')
    @mock.patch('gvmtools.parser.SSHConnection')
    def test_create_unix_socket_connection(
        self, *args
    ):  # pylint: disable=unused-argument
        self.perform_create_connection_test()

    @mock.patch('gvmtools.parser.UnixSocketConnection')
    @mock.patch('gvmtools.parser.SSHConnection')
    def test_create_tls_connection(
        self, *args
    ):  # pylint: disable=unused-argument
        self.perform_create_connection_test('tls', TLSConnection)

    @mock.patch('gvmtools.parser.UnixSocketConnection')
    @mock.patch('gvmtools.parser.TLSConnection')
    def test_create_ssh_connection(
        self, *args
    ):  # pylint: disable=unused-argument
        self.perform_create_connection_test('ssh', SSHConnection)

        # also exercise create_connection with an explicit port
        connection = create_connection('ssh', port=123)
        self.assertIsInstance(connection, SSHConnection)

    def perform_create_connection_test(
        self, connection_type='socket', connection_class=UnixSocketConnection
    ):
        """Assert create_connection(connection_type) yields connection_class.

        Args:
            connection_type: Name passed to create_connection.
            connection_class: Class the returned connection must be an
                instance of.
        """
        connection = create_connection(connection_type)
        self.assertIsInstance(connection, connection_class)
500