Passed
Pull Request — master (#311)
by
unknown
01:17
created

CreateConnectionTestCase.test_create_ssh_connection()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
# -*- coding: utf-8 -*-
2
# Copyright (C) 2019 Greenbone Networks GmbH
3
#
4
# SPDX-License-Identifier: GPL-3.0-or-later
5
#
6
# This program is free software: you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation, either version 3 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful,
12
# but WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14
# GNU General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18
19
import os
20
import sys
21
import unittest
22
23
from unittest.mock import patch
24
from pathlib import Path
25
26
from gvm.connections import (
27
    DEFAULT_UNIX_SOCKET_PATH,
28
    DEFAULT_TIMEOUT,
29
    UnixSocketConnection,
30
    TLSConnection,
31
    SSHConnection,
32
)
33
34
from gvmtools.parser import CliParser, create_parser, create_connection
35
from gvmtools.config import Config
36
37
from . import SuppressOutput
38
39
# Absolute directory of this test module; fixtures such as test.cfg and the
# *.snap help-output snapshots are looked up relative to it.
__here__ = Path(__file__).parent.resolve()
40
41
42
class ConfigParserTestCase(unittest.TestCase):
43
    def setUp(self):
44
        self.test_config_path = __here__ / 'test.cfg'
45
46
        self.assertTrue(self.test_config_path.is_file())
47
48
        self.parser = CliParser('TestParser', 'test.log')
49
50
    def test_socket_defaults_from_config(self):
51
        args = self.parser.parse_args(
52
            ['--config', str(self.test_config_path), 'socket']
53
        )
54
55
        self.assertEqual(args.foo, 'bar')
56
        self.assertEqual(args.timeout, 1000)
57
        self.assertEqual(args.gmp_password, 'bar')
58
        self.assertEqual(args.gmp_username, 'bar')
59
        self.assertEqual(args.socketpath, '/foo/bar.sock')
60
61
    def test_ssh_defaults_from_config(self):
62
        args = self.parser.parse_args(
63
            ['--config', str(self.test_config_path), 'ssh', '--hostname', 'foo']
64
        )
65
66
        self.assertEqual(args.foo, 'bar')
67
        self.assertEqual(args.timeout, 1000)
68
        self.assertEqual(args.gmp_password, 'bar')
69
        self.assertEqual(args.gmp_username, 'bar')
70
        self.assertEqual(args.ssh_password, 'lorem')
71
        self.assertEqual(args.ssh_username, 'ipsum')
72
        self.assertEqual(args.port, 123)
73
74
    def test_tls_defaults_from_config(self):
75
        args = self.parser.parse_args(
76
            ['--config', str(self.test_config_path), 'tls', '--hostname', 'foo']
77
        )
78
79
        self.assertEqual(args.foo, 'bar')
80
        self.assertEqual(args.timeout, 1000)
81
        self.assertEqual(args.gmp_password, 'bar')
82
        self.assertEqual(args.gmp_username, 'bar')
83
        self.assertEqual(args.certfile, 'foo.cert')
84
        self.assertEqual(args.keyfile, 'foo.key')
85
        self.assertEqual(args.cafile, 'foo.ca')
86
        self.assertEqual(args.port, 123)
87
88
    @patch('gvmtools.parser.Path')
89
    def test_resolve_file_not_found_error(self, path_mock):
90
        def resolve_raises_error():
91
            raise FileNotFoundError
92
93
        configpath = unittest.mock.MagicMock()
94
        configpath.expanduser().resolve = unittest.mock.MagicMock(
95
            side_effect=resolve_raises_error
96
        )
97
        path_mock.return_value = configpath
98
99
        # pylint: disable=protected-access
100
        return_value = self.parser._load_config('foobar')
101
102
        self.assertIsInstance(return_value, Config)
103
104
    @patch('gvmtools.parser.Path')
105
    @patch('gvmtools.parser.Config')
106
    def test_config_load_raises_error(self, config_mock, path_mock):
107
        def config_load_error():
108
            raise Exception
109
110
        config = unittest.mock.MagicMock()
111
        config.load = unittest.mock.MagicMock(side_effect=config_load_error)
112
        config_mock.return_value = config
113
114
        configpath = unittest.mock.Mock()
115
        configpath.expanduser().resolve().exists = unittest.mock.MagicMock(
116
            return_value=True
117
        )
118
        path_mock.return_value = configpath
119
120
        configfile = 'configfile'
121
122
        # pylint: disable=protected-access
123
        self.assertRaises(RuntimeError, self.parser._load_config, configfile)
124
125
126
class IgnoreConfigParserTestCase(unittest.TestCase):
127
    def test_unkown_config_file(self):
128
        test_config_path = __here__ / 'foo.cfg'
129
130
        self.assertFalse(test_config_path.is_file())
131
132
        self.parser = CliParser('TestParser', 'test.log')
133
134
        args = self.parser.parse_args(
135
            ['--config', str(test_config_path), 'socket']
136
        )
137
138
        self.assertEqual(args.timeout, DEFAULT_TIMEOUT)
139
        self.assertEqual(args.gmp_password, '')
140
        self.assertEqual(args.gmp_username, '')
141
        self.assertEqual(args.socketpath, DEFAULT_UNIX_SOCKET_PATH)
142
143
    def test_unkown_config_file_in_unkown_dir(self):
144
        test_config_path = __here__ / 'foo' / 'foo.cfg'
145
146
        self.assertFalse(test_config_path.is_file())
147
148
        self.parser = CliParser('TestParser', 'test.log')
149
150
        args = self.parser.parse_args(
151
            ['--config', str(test_config_path), 'socket']
152
        )
153
154
        self.assertEqual(args.timeout, DEFAULT_TIMEOUT)
155
        self.assertEqual(args.gmp_password, '')
156
        self.assertEqual(args.gmp_username, '')
157
        self.assertEqual(args.socketpath, DEFAULT_UNIX_SOCKET_PATH)
158
159
160
class ParserTestCase(unittest.TestCase):
    """Base class giving each test a fresh ``CliParser`` as ``self.parser``.

    The parser is built with ``ignore_config=True`` and the program name
    ``gvm-test-cli``.
    """

    def setUp(self):
        parser_options = dict(ignore_config=True, prog='gvm-test-cli')
        self.parser = CliParser('TestParser', 'test.log', **parser_options)
165
166
167
class RootArgumentsParserTest(ParserTestCase):
    """Tests for options handled by the root parser.

    Root options are accepted before the connection sub-command; passing
    them after it makes argument parsing exit.
    """

    def _assert_parse_exits(self, argv):
        """Assert that parsing *argv* raises SystemExit, hiding stderr."""
        with SuppressOutput(suppress_stderr=True):
            with self.assertRaises(SystemExit):
                self.parser.parse_args(argv)

    def test_config(self):
        args = self.parser.parse_args(['--config', 'foo.cfg', 'socket'])
        self.assertEqual(args.config, 'foo.cfg')

    def test_defaults(self):
        args = self.parser.parse_args(['socket'])
        self.assertEqual(args.config, '~/.config/gvm-tools.conf')
        self.assertEqual(args.gmp_password, '')
        self.assertEqual(args.gmp_username, '')
        self.assertEqual(args.timeout, 60)
        self.assertIsNone(args.loglevel)

    def test_loglevel(self):
        args = self.parser.parse_args(['--log', 'ERROR', 'socket'])
        self.assertEqual(args.loglevel, 'ERROR')

    def test_loglevel_after_subparser(self):
        self._assert_parse_exits(['socket', '--log', 'ERROR'])

    def test_timeout(self):
        args = self.parser.parse_args(['--timeout', '1000', 'socket'])
        self.assertEqual(args.timeout, 1000)

    def test_timeout_after_subparser(self):
        self._assert_parse_exits(['socket', '--timeout', '1000'])

    def test_gmp_username(self):
        args = self.parser.parse_args(['--gmp-username', 'foo', 'socket'])
        self.assertEqual(args.gmp_username, 'foo')

    def test_gmp_username_after_subparser(self):
        self._assert_parse_exits(['socket', '--gmp-username', 'foo'])

    def test_gmp_password(self):
        args = self.parser.parse_args(['--gmp-password', 'foo', 'socket'])
        self.assertEqual(args.gmp_password, 'foo')

    def test_gmp_password_after_subparser(self):
        self._assert_parse_exits(['socket', '--gmp-password', 'foo'])

    def test_with_unknown_args(self):
        # unknown arguments are handed back instead of causing an error
        args, script_args = self.parser.parse_known_args(
            ['--gmp-password', 'foo', 'socket', '--bar', '--bar2']
        )
        self.assertEqual(args.gmp_password, 'foo')
        self.assertEqual(script_args, ['--bar', '--bar2'])

    @patch('gvmtools.parser.logging')
    def test_socket_has_no_timeout(self, _logging_mock):
        # pylint: disable=protected-access
        # Replace the wrapped parser so it reports timeout == -1, then check
        # that CliParser.parse_known_args turns that into None.
        self.parser._parser = unittest.mock.MagicMock()
        args_mock = unittest.mock.MagicMock()
        args_mock.timeout = -1
        self.parser._parser.parse_known_args = unittest.mock.MagicMock(
            return_value=(args_mock, unittest.mock.MagicMock())
        )

        args, _ = self.parser.parse_known_args(
            ['socket', '--timeout', '--', '-1']
        )

        self.assertIsNone(args.timeout)

    @patch('gvmtools.parser.logging')
    @patch('gvmtools.parser.argparse.ArgumentParser.print_usage')
    @patch('gvmtools.parser.argparse.ArgumentParser._print_message')
    def test_no_args_provided(
        self, _print_message_mock, _print_usage_mock, _logging_mock
    ):
        # Stacked @patch decorators inject their mocks bottom-up: the
        # innermost (_print_message) arrives first, the outermost (logging)
        # last. The parameter names follow that order.
        # pylint: disable=protected-access
        self.parser._set_defaults = unittest.mock.MagicMock()

        self.assertRaises(SystemExit, self.parser.parse_known_args, None)
249
250
251
class SocketParserTestCase(ParserTestCase):
    """Tests for options of the ``socket`` sub-command."""

    def _parse_socket(self, *extra_args):
        """Parse ``socket`` followed by *extra_args* and return the result."""
        return self.parser.parse_args(['socket'] + list(extra_args))

    def test_defaults(self):
        parsed = self._parse_socket()
        self.assertEqual(parsed.socketpath, DEFAULT_UNIX_SOCKET_PATH)

    def test_connection_type(self):
        parsed = self._parse_socket()
        self.assertEqual(parsed.connection_type, 'socket')

    def test_sockpath(self):
        # --sockpath fills the same destination as --socketpath
        parsed = self._parse_socket('--sockpath', 'foo.sock')
        self.assertEqual(parsed.socketpath, 'foo.sock')

    def test_socketpath(self):
        parsed = self._parse_socket('--socketpath', 'foo.sock')
        self.assertEqual(parsed.socketpath, 'foo.sock')
267
268
269
class SshParserTestCase(ParserTestCase):
    """Tests for options of the ``ssh`` sub-command."""

    def _parse_ssh(self, *argv):
        """Parse ``ssh`` followed by *argv* and return the result."""
        return self.parser.parse_args(['ssh'] + list(argv))

    def test_defaults(self):
        parsed = self._parse_ssh('--hostname=foo')
        self.assertEqual(parsed.port, 22)
        self.assertEqual(parsed.ssh_username, 'gmp')
        self.assertEqual(parsed.ssh_password, 'gmp')

    def test_connection_type(self):
        parsed = self._parse_ssh('--hostname=foo')
        self.assertEqual(parsed.connection_type, 'ssh')

    def test_hostname(self):
        parsed = self._parse_ssh('--hostname', 'foo')
        self.assertEqual(parsed.hostname, 'foo')

    def test_port(self):
        parsed = self._parse_ssh('--hostname', 'foo', '--port', '123')
        # the port string is converted to an int
        self.assertEqual(parsed.port, 123)

    def test_ssh_username(self):
        parsed = self._parse_ssh('--hostname', 'foo', '--ssh-username', 'foo')
        self.assertEqual(parsed.ssh_username, 'foo')

    def test_ssh_password(self):
        parsed = self._parse_ssh('--hostname', 'foo', '--ssh-password', 'foo')
        self.assertEqual(parsed.ssh_password, 'foo')
301
302
303
class TlsParserTestCase(ParserTestCase):
    """Tests for options of the ``tls`` sub-command."""

    def _parse_tls(self, *argv):
        """Parse ``tls`` followed by *argv* and return the result."""
        return self.parser.parse_args(['tls'] + list(argv))

    def test_defaults(self):
        parsed = self._parse_tls('--hostname=foo')
        # no certificate material is configured by default
        for attribute in ('certfile', 'keyfile', 'cafile'):
            self.assertIsNone(getattr(parsed, attribute))
        self.assertEqual(parsed.port, 9390)

    def test_connection_type(self):
        parsed = self._parse_tls('--hostname=foo')
        self.assertEqual(parsed.connection_type, 'tls')

    def test_hostname(self):
        parsed = self._parse_tls('--hostname', 'foo')
        self.assertEqual(parsed.hostname, 'foo')

    def test_port(self):
        parsed = self._parse_tls('--hostname', 'foo', '--port', '123')
        self.assertEqual(parsed.port, 123)

    def test_certfile(self):
        parsed = self._parse_tls('--hostname', 'foo', '--certfile', 'foo.cert')
        self.assertEqual(parsed.certfile, 'foo.cert')

    def test_keyfile(self):
        parsed = self._parse_tls('--hostname', 'foo', '--keyfile', 'foo.key')
        self.assertEqual(parsed.keyfile, 'foo.key')

    def test_cafile(self):
        parsed = self._parse_tls('--hostname', 'foo', '--cafile', 'foo.ca')
        self.assertEqual(parsed.cafile, 'foo.ca')

    def test_no_credentials(self):
        parsed = self._parse_tls('--hostname', 'foo', '--no-credentials')
        self.assertTrue(parsed.no_credentials)
348
349
350
class CustomizeParserTestCase(ParserTestCase):
    """Tests for arguments added to a CliParser after construction."""

    def test_add_optional_argument(self):
        self.parser.add_argument('--foo', type=int)

        # the added option is accepted after every connection sub-command
        argv_variants = (
            ['socket', '--foo', '123'],
            ['ssh', '--hostname', 'bar', '--foo', '123'],
            ['tls', '--hostname', 'bar', '--foo', '123'],
        )
        for argv in argv_variants:
            parsed = self.parser.parse_args(argv)
            self.assertEqual(parsed.foo, 123)

    def test_add_positional_argument(self):
        self.parser.add_argument('foo', type=int)

        parsed = self.parser.parse_args(['socket', '123'])
        self.assertEqual(parsed.foo, 123)

    def test_add_protocol_argument(self):
        self.parser.add_protocol_argument()

        without_option = self.parser.parse_args(['socket'])
        self.assertEqual(without_option.protocol, 'GMP')

        with_option = self.parser.parse_args(['--protocol', 'OSP', 'socket'])
        self.assertEqual(with_option.protocol, 'OSP')
382
383
384
class HelpFormattingParserTestCase(ParserTestCase):
    """Snapshot tests for the help output of the root and sub-parsers.

    Snapshots are stored next to this module. A snapshot specific to the
    running Python version (``<name>.<major>.<minor>.snap``) takes
    precedence over the generic ``<name>.snap``. A missing snapshot is
    created from the current output. On a mismatch the actual output is
    written to ``<name>.<major>.<minor>-failed.snap`` and the assertion
    error is re-raised.
    """

    # pylint: disable=protected-access
    maxDiff = None
    python_version = '{}.{}'.format(*sys.version_info[:2])

    def setUp(self):
        super().setUp()

        # Ensure all tests use the same terminal width. Remember the previous
        # value (None when unset) so tearDown can restore it.
        self.columns = os.environ.get('COLUMNS')
        os.environ['COLUMNS'] = '80'

    def tearDown(self):
        super().tearDown()

        # Compare against None so that an originally empty COLUMNS value is
        # restored rather than deleted. pop() tolerates the variable having
        # been removed in the meantime.
        if self.columns is None:
            os.environ.pop('COLUMNS', None)
        else:
            os.environ['COLUMNS'] = self.columns

    def _snapshot_specific_path(self, name):
        """Return the Python-version specific snapshot path for *name*."""
        return __here__ / '{}.{}.snap'.format(name, self.python_version)

    def _snapshot_generic_path(self, name):
        """Return the version independent snapshot path for *name*."""
        return __here__ / '{}.snap'.format(name)

    def _snapshot_failed_path(self, name):
        """Return the path the actual output is written to on mismatch."""
        return __here__ / '{}.{}-failed.snap'.format(name, self.python_version)

    def _snapshot_path(self, name):
        """Prefer the version specific snapshot, else the generic one."""
        snapshot_specific_path = self._snapshot_specific_path(name)

        if snapshot_specific_path.exists():
            return snapshot_specific_path

        return self._snapshot_generic_path(name)

    def assert_snapshot(self, name, output):
        """Assert *output* equals the stored snapshot *name*.

        Creates the snapshot from *output* when it does not exist yet.
        """
        path = self._snapshot_path(name)

        if not path.exists():
            path.write_text(output)

        content = path.read_text()

        try:
            self.assertEqual(content, output, 'Snapshot differs from output')
        except AssertionError:
            # write the new output to a "-failed" snapshot for inspection,
            # then reraise the original error
            path = self._snapshot_failed_path(name)
            path.write_text(output)
            raise

    def test_root_help(self):
        help_output = self.parser._parser.format_help()
        self.assert_snapshot('root_help', help_output)

    def test_socket_help(self):
        help_output = self.parser._parser_socket.format_help()
        self.assert_snapshot('socket_help', help_output)

    def test_ssh_help(self):
        # NOTE(review): presumably _set_defaults(None) fills in the defaults
        # shown in the ssh/tls help text; confirm in gvmtools.parser
        self.parser._set_defaults(None)
        help_output = self.parser._parser_ssh.format_help()
        self.assert_snapshot('ssh_help', help_output)

    def test_tls_help(self):
        self.parser._set_defaults(None)
        help_output = self.parser._parser_tls.format_help()
        self.assert_snapshot('tls_help', help_output)
455
456
457
class CreateParserFunctionTestCase(unittest.TestCase):
458
    # pylint: disable=protected-access
459
    def test_create_parser(self):
460
        description = 'parser description'
461
        logfilename = 'logfilename'
462
463
        parser = create_parser(description, logfilename)
464
465
        self.assertIsInstance(parser, CliParser)
466
        self.assertEqual(parser._logfilename, logfilename)
467
        self.assertEqual(parser._bootstrap_parser.description, description)
468
469
470
class CreateConnectionTestCase(unittest.TestCase):
    """Tests that create_connection() returns the expected connection class."""

    def test_create_unix_socket_connection(self):
        # the helper's defaults cover the 'socket' connection type
        self.perform_create_connection_test()

    def test_create_tls_connection(self):
        self.perform_create_connection_test(
            connection_type='tls', connection_class=TLSConnection
        )

    def test_create_ssh_connection(self):
        self.perform_create_connection_test(
            connection_type='ssh', connection_class=SSHConnection
        )

    def perform_create_connection_test(
        self, connection_type='socket', connection_class=UnixSocketConnection
    ):
        """Create a *connection_type* connection and check its class."""
        created = create_connection(connection_type)
        self.assertIsInstance(created, connection_class)
485