Completed
Push — master ( 8e798a...2462d8 )
by Björn
28s queued 12s
created

CreateParserFunctionTestCase.test_create_parser()   A

Complexity

Conditions 1

Size

Total Lines 9
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 7
nop 1
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
# -*- coding: utf-8 -*-
2
# Copyright (C) 2019 Greenbone Networks GmbH
3
#
4
# SPDX-License-Identifier: GPL-3.0-or-later
5
#
6
# This program is free software: you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation, either version 3 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful,
12
# but WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14
# GNU General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18
19
import os
20
import sys
21
import unittest
22
23
from unittest.mock import patch
24
from pathlib import Path
25
26
from argparse import Namespace
27
from gvm.connections import (
28
    DEFAULT_UNIX_SOCKET_PATH,
29
    DEFAULT_TIMEOUT,
30
    UnixSocketConnection,
31
    TLSConnection,
32
    SSHConnection,
33
)
34
35
from gvmtools.parser import CliParser, create_parser, create_connection
36
37
from . import SuppressOutput
38
39
# Absolute path of the directory containing this test module. The config
# fixture (test.cfg) and the help-output snapshots (*.snap) used below are
# looked up relative to it.
__here__ = Path(__file__).parent.resolve()
40
41
42
class ConfigParserTestCase(unittest.TestCase):
    """Defaults that ``CliParser`` takes from a config file (``test.cfg``)."""

    def setUp(self):
        self.test_config_path = __here__ / 'test.cfg'

        # The fixture must exist; otherwise the tests below would exercise
        # the built-in defaults instead of the values from the config file.
        self.assertTrue(self.test_config_path.is_file())

        self.parser = CliParser('TestParser', 'test.log')

    def test_socket_defaults_from_config(self):
        args = self.parser.parse_args(
            ['--config', str(self.test_config_path), 'socket']
        )

        self.assertEqual(args.foo, 'bar')
        self.assertEqual(args.timeout, 1000)
        self.assertEqual(args.gmp_password, 'bar')
        self.assertEqual(args.gmp_username, 'bar')
        self.assertEqual(args.socketpath, '/foo/bar.sock')

    def test_ssh_defaults_from_config(self):
        args = self.parser.parse_args(
            ['--config', str(self.test_config_path), 'ssh', '--hostname', 'foo']
        )

        self.assertEqual(args.foo, 'bar')
        self.assertEqual(args.timeout, 1000)
        self.assertEqual(args.gmp_password, 'bar')
        self.assertEqual(args.gmp_username, 'bar')
        self.assertEqual(args.ssh_password, 'lorem')
        self.assertEqual(args.ssh_username, 'ipsum')
        self.assertEqual(args.port, 123)

    def test_tls_defaults_from_config(self):
        args = self.parser.parse_args(
            ['--config', str(self.test_config_path), 'tls', '--hostname', 'foo']
        )

        self.assertEqual(args.foo, 'bar')
        self.assertEqual(args.timeout, 1000)
        self.assertEqual(args.gmp_password, 'bar')
        self.assertEqual(args.gmp_username, 'bar')
        self.assertEqual(args.certfile, 'foo.cert')
        self.assertEqual(args.keyfile, 'foo.key')
        self.assertEqual(args.cafile, 'foo.ca')
        self.assertEqual(args.port, 123)

    @patch('gvmtools.parser.logger')
    @patch('gvmtools.parser.Path')
    def test_resolve_file_not_found_error(self, path_mock, logger_mock):
        """A config path whose resolution fails is logged and ignored."""
        # Make Path(...).expanduser().resolve() raise FileNotFoundError.
        # Assigning the exception class as side_effect raises an instance
        # of it on every call.
        resolve_mock = path_mock.return_value.expanduser.return_value.resolve
        resolve_mock.side_effect = FileNotFoundError

        args = self.parser.parse_args(['socket'])

        self.assertIsInstance(args, Namespace)
        self.assertEqual(args.connection_type, 'socket')
        self.assertEqual(args.config, '~/.config/gvm-tools.conf')
        # logger_mock.debug is an auto-created child mock of the patched
        # logger, so its calls are recorded without extra setup.
        logger_mock.debug.assert_any_call(
            'Ignoring non existing config file %s', '~/.config/gvm-tools.conf'
        )

    @patch('gvmtools.parser.Path')
    @patch('gvmtools.parser.Config')
    def test_config_load_raises_error(self, config_mock, path_mock):
        """An error while loading an existing config becomes RuntimeError."""
        config_mock.return_value.load.side_effect = Exception

        # Making sure that the function thinks the config file exists
        configpath = unittest.mock.Mock()
        resolved = configpath.expanduser.return_value.resolve.return_value
        resolved.exists.return_value = True
        path_mock.return_value = configpath

        with self.assertRaises(RuntimeError):
            self.parser.parse_args(['socket'])
130
131
132
class IgnoreConfigParserTestCase(unittest.TestCase):
    """A ``--config`` path that does not exist falls back to the defaults."""

    def test_unkown_config_file(self):
        self._assert_config_ignored(__here__ / 'foo.cfg')

    def test_unkown_config_file_in_unkown_dir(self):
        self._assert_config_ignored(__here__ / 'foo' / 'foo.cfg')

    def _assert_config_ignored(self, config_path):
        """Parse ``socket`` with *config_path* and check the built-in defaults.

        Args:
            config_path: a ``pathlib.Path`` that must not point to a file.
        """
        # Guard against the path accidentally becoming a real fixture file.
        self.assertFalse(config_path.is_file())

        parser = CliParser('TestParser', 'test.log')
        args = parser.parse_args(['--config', str(config_path), 'socket'])

        self.assertEqual(args.timeout, DEFAULT_TIMEOUT)
        self.assertEqual(args.gmp_password, '')
        self.assertEqual(args.gmp_username, '')
        self.assertEqual(args.socketpath, DEFAULT_UNIX_SOCKET_PATH)
164
165
166
class ParserTestCase(unittest.TestCase):
    """Shared fixture exposing a ready-made ``CliParser`` as ``self.parser``.

    The parser is built with ``ignore_config=True`` (presumably so that no
    config file on the test machine affects the parsed defaults) and a fixed
    program name.
    """

    def setUp(self):
        parser_name, log_file = 'TestParser', 'test.log'
        self.parser = CliParser(
            parser_name,
            log_file,
            ignore_config=True,
            prog='gvm-test-cli',
        )
171
172
173
class RootArgumentsParserTest(ParserTestCase):
    """Options of the root parser.

    These options must precede the connection sub-command; passing them after
    it makes argument parsing exit.
    """

    def test_config(self):
        args = self.parser.parse_args(['--config', 'foo.cfg', 'socket'])
        self.assertEqual(args.config, 'foo.cfg')

    def test_defaults(self):
        args = self.parser.parse_args(['socket'])
        self.assertEqual(args.config, '~/.config/gvm-tools.conf')
        self.assertEqual(args.gmp_password, '')
        self.assertEqual(args.gmp_username, '')
        self.assertEqual(args.timeout, DEFAULT_TIMEOUT)
        self.assertIsNone(args.loglevel)

    def test_loglevel(self):
        args = self.parser.parse_args(['--log', 'ERROR', 'socket'])
        self.assertEqual(args.loglevel, 'ERROR')

    def test_loglevel_after_subparser(self):
        with SuppressOutput(suppress_stderr=True):
            with self.assertRaises(SystemExit):
                self.parser.parse_args(['socket', '--log', 'ERROR'])

    def test_timeout(self):
        args = self.parser.parse_args(['--timeout', '1000', 'socket'])
        self.assertEqual(args.timeout, 1000)

    def test_timeout_after_subparser(self):
        with SuppressOutput(suppress_stderr=True):
            with self.assertRaises(SystemExit):
                self.parser.parse_args(['socket', '--timeout', '1000'])

    def test_gmp_username(self):
        args = self.parser.parse_args(['--gmp-username', 'foo', 'socket'])
        self.assertEqual(args.gmp_username, 'foo')

    def test_gmp_username_after_subparser(self):
        with SuppressOutput(suppress_stderr=True):
            with self.assertRaises(SystemExit):
                self.parser.parse_args(['socket', '--gmp-username', 'foo'])

    def test_gmp_password(self):
        args = self.parser.parse_args(['--gmp-password', 'foo', 'socket'])
        self.assertEqual(args.gmp_password, 'foo')

    def test_gmp_password_after_subparser(self):
        with SuppressOutput(suppress_stderr=True):
            with self.assertRaises(SystemExit):
                self.parser.parse_args(['socket', '--gmp-password', 'foo'])

    def test_with_unknown_args(self):
        args, script_args = self.parser.parse_known_args(
            ['--gmp-password', 'foo', 'socket', '--bar', '--bar2']
        )
        self.assertEqual(args.gmp_password, 'foo')
        self.assertEqual(script_args, ['--bar', '--bar2'])

    @patch('gvmtools.parser.logging')
    def test_socket_has_no_timeout(self, _logging_mock):
        """A timeout of -1 is turned into ``None``."""
        # pylint: disable=protected-access
        args_mock = unittest.mock.MagicMock()
        args_mock.timeout = -1
        # Bypass real parsing: the wrapped argparse parser hands back
        # args_mock, so only CliParser's post-processing is exercised.
        self.parser._parser.parse_known_args = unittest.mock.MagicMock(
            return_value=(args_mock, unittest.mock.MagicMock())
        )

        args, _ = self.parser.parse_known_args(
            ['socket', '--timeout', '--', '-1']
        )

        self.assertIsNone(args.timeout)

    # ``patch`` decorators are applied bottom-up, so the mock created by the
    # decorator closest to the function is passed as the first argument.
    @patch('gvmtools.parser.logging')
    @patch('gvmtools.parser.argparse.ArgumentParser.print_usage')
    @patch('gvmtools.parser.argparse.ArgumentParser._print_message')
    def test_no_args_provided(
        self, _print_message_mock, _print_usage_mock, _logging_mock
    ):
        # pylint: disable=protected-access
        self.parser._set_defaults = unittest.mock.MagicMock()

        self.assertRaises(SystemExit, self.parser.parse_known_args, None)
254
255
256
class SocketParserTestCase(ParserTestCase):
    """Arguments of the ``socket`` connection sub-command."""

    def _parse_socket(self, *extra):
        """Parse ``socket`` followed by the given extra arguments."""
        return self.parser.parse_args(['socket', *extra])

    def test_defaults(self):
        parsed = self._parse_socket()
        self.assertEqual(parsed.socketpath, DEFAULT_UNIX_SOCKET_PATH)

    def test_connection_type(self):
        parsed = self._parse_socket()
        self.assertEqual(parsed.connection_type, 'socket')

    def test_sockpath(self):
        # --sockpath is stored in the same destination as --socketpath
        parsed = self._parse_socket('--sockpath', 'foo.sock')
        self.assertEqual(parsed.socketpath, 'foo.sock')

    def test_socketpath(self):
        parsed = self._parse_socket('--socketpath', 'foo.sock')
        self.assertEqual(parsed.socketpath, 'foo.sock')
272
273
274
class SshParserTestCase(ParserTestCase):
    """Arguments of the ``ssh`` connection sub-command."""

    def _parse_ssh(self, *extra):
        """Parse ``ssh`` followed by the given extra arguments."""
        return self.parser.parse_args(['ssh', *extra])

    def test_defaults(self):
        parsed = self._parse_ssh('--hostname=foo')
        self.assertEqual(parsed.port, 22)
        self.assertEqual(parsed.ssh_username, 'gmp')
        self.assertEqual(parsed.ssh_password, 'gmp')

    def test_connection_type(self):
        parsed = self._parse_ssh('--hostname=foo')
        self.assertEqual(parsed.connection_type, 'ssh')

    def test_hostname(self):
        parsed = self._parse_ssh('--hostname', 'foo')
        self.assertEqual(parsed.hostname, 'foo')

    def test_port(self):
        # the port is converted from its string form to an int
        parsed = self._parse_ssh('--hostname', 'foo', '--port', '123')
        self.assertEqual(parsed.port, 123)

    def test_ssh_username(self):
        parsed = self._parse_ssh('--hostname', 'foo', '--ssh-username', 'foo')
        self.assertEqual(parsed.ssh_username, 'foo')

    def test_ssh_password(self):
        parsed = self._parse_ssh('--hostname', 'foo', '--ssh-password', 'foo')
        self.assertEqual(parsed.ssh_password, 'foo')
306
307
308
class TlsParserTestCase(ParserTestCase):
    """Arguments of the ``tls`` connection sub-command."""

    def _parse_tls(self, *extra):
        """Parse ``tls`` followed by the given extra arguments."""
        return self.parser.parse_args(['tls', *extra])

    def test_defaults(self):
        parsed = self._parse_tls('--hostname=foo')
        for unset in (parsed.certfile, parsed.keyfile, parsed.cafile):
            self.assertIsNone(unset)
        self.assertEqual(parsed.port, 9390)

    def test_connection_type(self):
        parsed = self._parse_tls('--hostname=foo')
        self.assertEqual(parsed.connection_type, 'tls')

    def test_hostname(self):
        parsed = self._parse_tls('--hostname', 'foo')
        self.assertEqual(parsed.hostname, 'foo')

    def test_port(self):
        parsed = self._parse_tls('--hostname', 'foo', '--port', '123')
        self.assertEqual(parsed.port, 123)

    def test_certfile(self):
        parsed = self._parse_tls('--hostname', 'foo', '--certfile', 'foo.cert')
        self.assertEqual(parsed.certfile, 'foo.cert')

    def test_keyfile(self):
        parsed = self._parse_tls('--hostname', 'foo', '--keyfile', 'foo.key')
        self.assertEqual(parsed.keyfile, 'foo.key')

    def test_cafile(self):
        parsed = self._parse_tls('--hostname', 'foo', '--cafile', 'foo.ca')
        self.assertEqual(parsed.cafile, 'foo.ca')

    def test_no_credentials(self):
        parsed = self._parse_tls('--hostname', 'foo', '--no-credentials')
        self.assertTrue(parsed.no_credentials)
353
354
355
class CustomizeParserTestCase(ParserTestCase):
    """Arguments added to ``CliParser`` by scripts after construction."""

    def test_add_optional_argument(self):
        self.parser.add_argument('--foo', type=int)

        # the custom option is accepted after each connection sub-command
        argv_variants = (
            ['socket', '--foo', '123'],
            ['ssh', '--hostname', 'bar', '--foo', '123'],
            ['tls', '--hostname', 'bar', '--foo', '123'],
        )
        for argv in argv_variants:
            parsed = self.parser.parse_args(argv)
            self.assertEqual(parsed.foo, 123)

    def test_add_positional_argument(self):
        self.parser.add_argument('foo', type=int)
        parsed = self.parser.parse_args(['socket', '123'])

        self.assertEqual(parsed.foo, 123)

    def test_add_protocol_argument(self):
        self.parser.add_protocol_argument()

        # without --protocol the default is GMP
        self.assertEqual(self.parser.parse_args(['socket']).protocol, 'GMP')

        with_osp = self.parser.parse_args(['--protocol', 'OSP', 'socket'])

        self.assertEqual(with_osp.protocol, 'OSP')
387
388
389
class HelpFormattingParserTestCase(ParserTestCase):
    """Compare the parsers' help output against stored snapshot files.

    A Python-version-specific snapshot (``<name>.<major>.<minor>.snap``) takes
    precedence over the generic ``<name>.snap``. If no snapshot exists yet,
    the current output is written as the new snapshot. On a mismatch, the
    actual output is saved to ``<name>.<major>.<minor>-failed.snap`` and the
    assertion error is re-raised.
    """

    # pylint: disable=protected-access
    maxDiff = None
    python_version = '{}.{}'.format(*sys.version_info[:2])

    def setUp(self):
        super().setUp()

        # ensure all tests are using the same terminal width; argparse's
        # help formatter derives its line width from COLUMNS
        self.columns = os.environ.get('COLUMNS')
        os.environ['COLUMNS'] = '80'

    def tearDown(self):
        super().tearDown()

        # Restore COLUMNS exactly as setUp found it. Compare against None
        # (not truthiness) so that a previously set empty value is restored
        # instead of being deleted.
        if self.columns is None:
            os.environ.pop('COLUMNS', None)
        else:
            os.environ['COLUMNS'] = self.columns

    def _snapshot_specific_path(self, name):
        """Snapshot path for the running Python version."""
        return __here__ / '{}.{}.snap'.format(name, self.python_version)

    def _snapshot_generic_path(self, name):
        """Snapshot path shared by all Python versions."""
        return __here__ / '{}.snap'.format(name)

    def _snapshot_failed_path(self, name):
        """Path where the output of a failed comparison is written."""
        return __here__ / '{}.{}-failed.snap'.format(name, self.python_version)

    def _snapshot_path(self, name):
        """Prefer the version-specific snapshot, else the generic one."""
        snapshot_specific_path = self._snapshot_specific_path(name)

        if snapshot_specific_path.exists():
            return snapshot_specific_path

        return self._snapshot_generic_path(name)

    def assert_snapshot(self, name, output):
        """Assert that *output* equals the stored snapshot called *name*.

        Snapshot files are read and written as UTF-8, so the comparison does
        not depend on the platform's default encoding.
        """
        path = self._snapshot_path(name)

        if not path.exists():
            path.write_text(output, encoding='utf-8')

        content = path.read_text(encoding='utf-8')

        try:
            self.assertEqual(content, output, 'Snapshot differs from output')
        except AssertionError:
            # write new output to snapshot file
            # reraise error afterwards
            path = self._snapshot_failed_path(name)
            path.write_text(output, encoding='utf-8')
            raise

    def test_root_help(self):
        help_output = self.parser._parser.format_help()
        self.assert_snapshot('root_help', help_output)

    def test_socket_help(self):
        help_output = self.parser._parser_socket.format_help()
        self.assert_snapshot('socket_help', help_output)

    def test_ssh_help(self):
        self.parser._set_defaults(None)
        help_output = self.parser._parser_ssh.format_help()
        self.assert_snapshot('ssh_help', help_output)

    def test_tls_help(self):
        self.parser._set_defaults(None)
        help_output = self.parser._parser_tls.format_help()
        self.assert_snapshot('tls_help', help_output)
460
461
462
class CreateParserFunctionTestCase(unittest.TestCase):
    """Tests for the ``create_parser`` factory function."""

    # pylint: disable=protected-access
    def test_create_parser(self):
        expected_description = 'parser description'
        expected_logfilename = 'logfilename'

        cli_parser = create_parser(expected_description, expected_logfilename)

        self.assertIsInstance(cli_parser, CliParser)
        # the log file name is kept on the CliParser itself, while the
        # description ends up on its bootstrap argument parser
        self.assertEqual(cli_parser._logfilename, expected_logfilename)
        self.assertEqual(
            cli_parser._bootstrap_parser.description, expected_description
        )
473
474
475
class CreateConnectionTestCase(unittest.TestCase):
    """``create_connection`` returns the class matching the connection type."""

    def test_create_unix_socket_connection(self):
        self.perform_create_connection_test()

    def test_create_tls_connection(self):
        self.perform_create_connection_test(
            connection_type='tls', connection_class=TLSConnection
        )

    def test_create_ssh_connection(self):
        self.perform_create_connection_test(
            connection_type='ssh', connection_class=SSHConnection, port=22
        )

    def perform_create_connection_test(
        self,
        connection_type='socket',
        connection_class=UnixSocketConnection,
        port=None,
    ):
        """Create a connection and check the type of the returned object.

        Args:
            connection_type: name passed to ``create_connection``.
            connection_class: class the result must be an instance of.
            port: port forwarded to ``create_connection``.
        """
        created = create_connection(connection_type, port=port)
        self.assertIsInstance(created, connection_class)
493