Passed
Pull Request — master (#311)
by
unknown
01:15
created

RootArgumentsParserTest.test_gmp_password_after_subparser()   A

Complexity

Conditions 3

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 3
nop 1
1
# -*- coding: utf-8 -*-
2
# Copyright (C) 2019 Greenbone Networks GmbH
3
#
4
# SPDX-License-Identifier: GPL-3.0-or-later
5
#
6
# This program is free software: you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation, either version 3 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful,
12
# but WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14
# GNU General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18
19
import os
20
import sys
21
import unittest
22
23
from unittest.mock import patch
24
from pathlib import Path
25
26
from argparse import Namespace
27
from gvm.connections import (
28
    DEFAULT_UNIX_SOCKET_PATH,
29
    DEFAULT_TIMEOUT,
30
    UnixSocketConnection,
31
    TLSConnection,
32
    SSHConnection,
33
)
34
35
from gvmtools.parser import CliParser, create_parser, create_connection
36
37
from . import SuppressOutput
38
39
# Directory containing this test module. Fixture files (test.cfg, the
# *.snap help-output snapshots) are looked up relative to it.
__here__ = Path(__file__).parent.resolve()
40
41
42
class ConfigParserTestCase(unittest.TestCase):
43
    def setUp(self):
44
        self.test_config_path = __here__ / 'test.cfg'
45
46
        self.assertTrue(self.test_config_path.is_file())
47
48
        self.parser = CliParser('TestParser', 'test.log')
49
50
    def test_socket_defaults_from_config(self):
51
        args = self.parser.parse_args(
52
            ['--config', str(self.test_config_path), 'socket']
53
        )
54
55
        self.assertEqual(args.foo, 'bar')
56
        self.assertEqual(args.timeout, 1000)
57
        self.assertEqual(args.gmp_password, 'bar')
58
        self.assertEqual(args.gmp_username, 'bar')
59
        self.assertEqual(args.socketpath, '/foo/bar.sock')
60
61
    def test_ssh_defaults_from_config(self):
62
        args = self.parser.parse_args(
63
            ['--config', str(self.test_config_path), 'ssh', '--hostname', 'foo']
64
        )
65
66
        self.assertEqual(args.foo, 'bar')
67
        self.assertEqual(args.timeout, 1000)
68
        self.assertEqual(args.gmp_password, 'bar')
69
        self.assertEqual(args.gmp_username, 'bar')
70
        self.assertEqual(args.ssh_password, 'lorem')
71
        self.assertEqual(args.ssh_username, 'ipsum')
72
        self.assertEqual(args.port, 123)
73
74
    def test_tls_defaults_from_config(self):
75
        args = self.parser.parse_args(
76
            ['--config', str(self.test_config_path), 'tls', '--hostname', 'foo']
77
        )
78
79
        self.assertEqual(args.foo, 'bar')
80
        self.assertEqual(args.timeout, 1000)
81
        self.assertEqual(args.gmp_password, 'bar')
82
        self.assertEqual(args.gmp_username, 'bar')
83
        self.assertEqual(args.certfile, 'foo.cert')
84
        self.assertEqual(args.keyfile, 'foo.key')
85
        self.assertEqual(args.cafile, 'foo.ca')
86
        self.assertEqual(args.port, 123)
87
88
    @patch('gvmtools.parser.logger')
89
    @patch('gvmtools.parser.Path')
90
    def test_resolve_file_not_found_error(self, path_mock, logger_mock):
91
        # Making sure that resolve raises an error
92
        def resolve_raises_error():
93
            raise FileNotFoundError()
94
95
        configpath = unittest.mock.MagicMock()
96
        configpath.expanduser().resolve = unittest.mock.MagicMock(
97
            side_effect=resolve_raises_error
98
        )
99
        path_mock.return_value = configpath
100
101
        logger_mock.debug = unittest.mock.MagicMock()
102
103
        args = self.parser.parse_args(['socket'])
104
105
        self.assertIsInstance(args, Namespace)
106
        self.assertEqual(args.connection_type, 'socket')
107
        self.assertEqual(args.config, '~/.config/gvm-tools.conf')
108
        logger_mock.debug.assert_any_call(
109
            'Ignoring non existing config file %s', '~/.config/gvm-tools.conf'
110
        )
111
112
    @patch('gvmtools.parser.Path')
113
    @patch('gvmtools.parser.Config')
114
    def test_config_load_raises_error(self, config_mock, path_mock):
115
        def config_load_error():
116
            raise Exception
117
118
        config = unittest.mock.MagicMock()
119
        config.load = unittest.mock.MagicMock(side_effect=config_load_error)
120
        config_mock.return_value = config
121
122
        # Making sure that the function thinks the config file exists
123
        configpath_exists = unittest.mock.Mock()
124
        configpath_exists.expanduser().resolve().exists = (
125
            unittest.mock.MagicMock(return_value=True)
126
        )
127
        path_mock.return_value = configpath_exists
128
129
        self.assertRaises(RuntimeError, self.parser.parse_args, ['socket'])
130
131
132
class IgnoreConfigParserTestCase(unittest.TestCase):
133
    def test_unkown_config_file(self):
134
        test_config_path = __here__ / 'foo.cfg'
135
136
        self.assertFalse(test_config_path.is_file())
137
138
        self.parser = CliParser('TestParser', 'test.log')
139
140
        args = self.parser.parse_args(
141
            ['--config', str(test_config_path), 'socket']
142
        )
143
144
        self.assertEqual(args.timeout, DEFAULT_TIMEOUT)
145
        self.assertEqual(args.gmp_password, '')
146
        self.assertEqual(args.gmp_username, '')
147
        self.assertEqual(args.socketpath, DEFAULT_UNIX_SOCKET_PATH)
148
149
    def test_unkown_config_file_in_unkown_dir(self):
150
        test_config_path = __here__ / 'foo' / 'foo.cfg'
151
152
        self.assertFalse(test_config_path.is_file())
153
154
        self.parser = CliParser('TestParser', 'test.log')
155
156
        args = self.parser.parse_args(
157
            ['--config', str(test_config_path), 'socket']
158
        )
159
160
        self.assertEqual(args.timeout, DEFAULT_TIMEOUT)
161
        self.assertEqual(args.gmp_password, '')
162
        self.assertEqual(args.gmp_username, '')
163
        self.assertEqual(args.socketpath, DEFAULT_UNIX_SOCKET_PATH)
164
165
166
class ParserTestCase(unittest.TestCase):
    """Base fixture: a CliParser that does not read any config file."""

    def setUp(self):
        parser_options = dict(ignore_config=True, prog='gvm-test-cli')
        self.parser = CliParser('TestParser', 'test.log', **parser_options)
171
172
173
class RootArgumentsParserTest(ParserTestCase):
    """Options handled by the root parser, i.e. given before the sub command."""

    def _assert_exits(self, argv):
        """Assert that parsing ``argv`` raises SystemExit.

        stderr is suppressed so argparse's error output does not clutter the
        test output.
        """
        with SuppressOutput(suppress_stderr=True):
            with self.assertRaises(SystemExit):
                self.parser.parse_args(argv)

    def test_config(self):
        args = self.parser.parse_args(['--config', 'foo.cfg', 'socket'])
        self.assertEqual(args.config, 'foo.cfg')

    def test_defaults(self):
        args = self.parser.parse_args(['socket'])
        self.assertEqual(args.config, '~/.config/gvm-tools.conf')
        self.assertEqual(args.gmp_password, '')
        self.assertEqual(args.gmp_username, '')
        self.assertEqual(args.timeout, 60)
        self.assertIsNone(args.loglevel)

    def test_loglevel(self):
        args = self.parser.parse_args(['--log', 'ERROR', 'socket'])
        self.assertEqual(args.loglevel, 'ERROR')

    def test_loglevel_after_subparser(self):
        # Root options are rejected when given after the sub command.
        self._assert_exits(['socket', '--log', 'ERROR'])

    def test_timeout(self):
        args = self.parser.parse_args(['--timeout', '1000', 'socket'])
        self.assertEqual(args.timeout, 1000)

    def test_timeout_after_subparser(self):
        self._assert_exits(['socket', '--timeout', '1000'])

    def test_gmp_username(self):
        args = self.parser.parse_args(['--gmp-username', 'foo', 'socket'])
        self.assertEqual(args.gmp_username, 'foo')

    def test_gmp_username_after_subparser(self):
        self._assert_exits(['socket', '--gmp-username', 'foo'])

    def test_gmp_password(self):
        args = self.parser.parse_args(['--gmp-password', 'foo', 'socket'])
        self.assertEqual(args.gmp_password, 'foo')

    def test_gmp_password_after_subparser(self):
        self._assert_exits(['socket', '--gmp-password', 'foo'])

    def test_with_unknown_args(self):
        # Unrecognised arguments after the sub command are handed back to
        # the caller instead of causing an error.
        args, script_args = self.parser.parse_known_args(
            ['--gmp-password', 'foo', 'socket', '--bar', '--bar2']
        )
        self.assertEqual(args.gmp_password, 'foo')
        self.assertEqual(script_args, ['--bar', '--bar2'])

    @patch('gvmtools.parser.logging')
    def test_socket_has_no_timeout(self, _logging_mock):
        # pylint: disable=protected-access
        # Replace the wrapped argparse parser so it yields timeout == -1;
        # CliParser.parse_known_args must turn that into None.
        self.parser._parser = unittest.mock.MagicMock()
        args_mock = unittest.mock.MagicMock()
        args_mock.timeout = -1
        self.parser._parser.parse_known_args = unittest.mock.MagicMock(
            return_value=(args_mock, unittest.mock.MagicMock())
        )

        args, _ = self.parser.parse_known_args(
            ['socket', '--timeout', '--', '-1']
        )

        self.assertIsNone(args.timeout)

    # patch decorators are applied bottom-up, so the mocks arrive in reverse
    # order: _print_message first, then print_usage, then logging.
    @patch('gvmtools.parser.logging')
    @patch('gvmtools.parser.argparse.ArgumentParser.print_usage')
    @patch('gvmtools.parser.argparse.ArgumentParser._print_message')
    def test_no_args_provided(
        self, _print_message_mock, _print_usage_mock, _logging_mock
    ):
        # pylint: disable=protected-access
        self.parser._set_defaults = unittest.mock.MagicMock()

        self.assertRaises(SystemExit, self.parser.parse_known_args, None)
255
256
257
class SocketParserTestCase(ParserTestCase):
    """Options of the ``socket`` sub command."""

    def _socket(self, *argv):
        """Parse ``socket`` followed by ``argv``."""
        return self.parser.parse_args(['socket', *argv])

    def test_defaults(self):
        self.assertEqual(self._socket().socketpath, DEFAULT_UNIX_SOCKET_PATH)

    def test_connection_type(self):
        self.assertEqual(self._socket().connection_type, 'socket')

    def test_sockpath(self):
        parsed = self._socket('--sockpath', 'foo.sock')
        self.assertEqual(parsed.socketpath, 'foo.sock')

    def test_socketpath(self):
        parsed = self._socket('--socketpath', 'foo.sock')
        self.assertEqual(parsed.socketpath, 'foo.sock')
273
274
275
class SshParserTestCase(ParserTestCase):
    """Options of the ``ssh`` sub command."""

    def _ssh(self, *argv):
        """Parse ``ssh`` followed by ``argv``."""
        return self.parser.parse_args(['ssh', *argv])

    def test_defaults(self):
        parsed = self._ssh('--hostname=foo')
        self.assertEqual(parsed.port, 22)
        self.assertEqual(parsed.ssh_username, 'gmp')
        self.assertEqual(parsed.ssh_password, 'gmp')

    def test_connection_type(self):
        self.assertEqual(self._ssh('--hostname=foo').connection_type, 'ssh')

    def test_hostname(self):
        self.assertEqual(self._ssh('--hostname', 'foo').hostname, 'foo')

    def test_port(self):
        parsed = self._ssh('--hostname', 'foo', '--port', '123')
        self.assertEqual(parsed.port, 123)

    def test_ssh_username(self):
        parsed = self._ssh('--hostname', 'foo', '--ssh-username', 'foo')
        self.assertEqual(parsed.ssh_username, 'foo')

    def test_ssh_password(self):
        parsed = self._ssh('--hostname', 'foo', '--ssh-password', 'foo')
        self.assertEqual(parsed.ssh_password, 'foo')
307
308
309
class TlsParserTestCase(ParserTestCase):
    """Options of the ``tls`` sub command."""

    def _tls(self, *argv):
        """Parse ``tls`` followed by ``argv``."""
        return self.parser.parse_args(['tls', *argv])

    def test_defaults(self):
        parsed = self._tls('--hostname=foo')
        for unset in ('certfile', 'keyfile', 'cafile'):
            self.assertIsNone(getattr(parsed, unset))
        self.assertEqual(parsed.port, 9390)

    def test_connection_type(self):
        self.assertEqual(self._tls('--hostname=foo').connection_type, 'tls')

    def test_hostname(self):
        self.assertEqual(self._tls('--hostname', 'foo').hostname, 'foo')

    def test_port(self):
        parsed = self._tls('--hostname', 'foo', '--port', '123')
        self.assertEqual(parsed.port, 123)

    def test_certfile(self):
        parsed = self._tls('--hostname', 'foo', '--certfile', 'foo.cert')
        self.assertEqual(parsed.certfile, 'foo.cert')

    def test_keyfile(self):
        parsed = self._tls('--hostname', 'foo', '--keyfile', 'foo.key')
        self.assertEqual(parsed.keyfile, 'foo.key')

    def test_cafile(self):
        parsed = self._tls('--hostname', 'foo', '--cafile', 'foo.ca')
        self.assertEqual(parsed.cafile, 'foo.ca')

    def test_no_credentials(self):
        parsed = self._tls('--hostname', 'foo', '--no-credentials')
        self.assertTrue(parsed.no_credentials)
354
355
356
class CustomizeParserTestCase(ParserTestCase):
    """Arguments added by scripts on top of the built-in parser options."""

    def test_add_optional_argument(self):
        self.parser.add_argument('--foo', type=int)

        # The custom option must be accepted after every sub command.
        command_lines = (
            ['socket', '--foo', '123'],
            ['ssh', '--hostname', 'bar', '--foo', '123'],
            ['tls', '--hostname', 'bar', '--foo', '123'],
        )
        for argv in command_lines:
            self.assertEqual(self.parser.parse_args(argv).foo, 123)

    def test_add_positional_argument(self):
        self.parser.add_argument('foo', type=int)
        parsed = self.parser.parse_args(['socket', '123'])

        self.assertEqual(parsed.foo, 123)

    def test_add_protocol_argument(self):
        self.parser.add_protocol_argument()

        default_protocol = self.parser.parse_args(['socket']).protocol
        self.assertEqual(default_protocol, 'GMP')

        chosen_protocol = self.parser.parse_args(
            ['--protocol', 'OSP', 'socket']
        ).protocol
        self.assertEqual(chosen_protocol, 'OSP')
388
389
390
class HelpFormattingParserTestCase(ParserTestCase):
    """Compare generated help texts against snapshot files.

    Snapshots live next to this module. A version-specific
    ``<name>.<major>.<minor>.snap`` takes precedence over the generic
    ``<name>.snap`` when it exists. A missing snapshot is created from the
    current output. On a mismatch, the current output is written to
    ``<name>.<major>.<minor>-failed.snap`` and the assertion is re-raised.
    """

    # pylint: disable=protected-access
    maxDiff = None
    python_version = '{}.{}'.format(*sys.version_info[:2])

    def setUp(self):
        super().setUp()

        # ensure all tests are using the same terminal width; remember the
        # previous value (None when unset) so tearDown can restore it
        self.columns = os.environ.get('COLUMNS')
        os.environ['COLUMNS'] = '80'

    def tearDown(self):
        super().tearDown()

        # Compare against None explicitly: a COLUMNS value that was set to
        # an empty string before the test must be restored, not removed.
        if self.columns is None:
            os.environ.pop('COLUMNS', None)
        else:
            os.environ['COLUMNS'] = self.columns

    def _snapshot_specific_path(self, name):
        return __here__ / '{}.{}.snap'.format(name, self.python_version)

    def _snapshot_generic_path(self, name):
        return __here__ / '{}.snap'.format(name)

    def _snapshot_failed_path(self, name):
        return __here__ / '{}.{}-failed.snap'.format(name, self.python_version)

    def _snapshot_path(self, name):
        """Return the version-specific snapshot if present, else the generic."""
        snapshot_specific_path = self._snapshot_specific_path(name)

        if snapshot_specific_path.exists():
            return snapshot_specific_path

        return self._snapshot_generic_path(name)

    def assert_snapshot(self, name, output):
        """Assert ``output`` equals the stored snapshot ``name``."""
        path = self._snapshot_path(name)

        if not path.exists():
            path.write_text(output)

        content = path.read_text()

        try:
            self.assertEqual(content, output, 'Snapshot differs from output')
        except AssertionError:
            # write new output to snapshot file
            # reraise error afterwards
            path = self._snapshot_failed_path(name)
            path.write_text(output)
            raise

    def test_root_help(self):
        help_output = self.parser._parser.format_help()
        self.assert_snapshot('root_help', help_output)

    def test_socket_help(self):
        help_output = self.parser._parser_socket.format_help()
        self.assert_snapshot('socket_help', help_output)

    def test_ssh_help(self):
        self.parser._set_defaults(None)
        help_output = self.parser._parser_ssh.format_help()
        self.assert_snapshot('ssh_help', help_output)

    def test_tls_help(self):
        self.parser._set_defaults(None)
        help_output = self.parser._parser_tls.format_help()
        self.assert_snapshot('tls_help', help_output)
461
462
463
class CreateParserFunctionTestCase(unittest.TestCase):
464
    # pylint: disable=protected-access
465
    def test_create_parser(self):
466
        description = 'parser description'
467
        logfilename = 'logfilename'
468
469
        parser = create_parser(description, logfilename)
470
471
        self.assertIsInstance(parser, CliParser)
472
        self.assertEqual(parser._logfilename, logfilename)
473
        self.assertEqual(parser._bootstrap_parser.description, description)
474
475
476
class CreateConnectionTestCase(unittest.TestCase):
    """create_connection() must return the class matching the given type."""

    def perform_create_connection_test(
        self, connection_type='socket', connection_class=UnixSocketConnection
    ):
        """Create a ``connection_type`` connection and check its class."""
        self.assertIsInstance(
            create_connection(connection_type), connection_class
        )

    def test_create_unix_socket_connection(self):
        self.perform_create_connection_test()

    def test_create_tls_connection(self):
        self.perform_create_connection_test('tls', TLSConnection)

    def test_create_ssh_connection(self):
        self.perform_create_connection_test('ssh', SSHConnection)
491