Issues (17)

tests/command/test_commands.py (4 issues)

1
# Copyright (C) 2014-2021 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: AGPL-3.0-or-later
4
#
5
# This program is free software: you can redistribute it and/or modify
6
# it under the terms of the GNU Affero General Public License as
7
# published by the Free Software Foundation, either version 3 of the
8
# License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU Affero General Public License for more details.
14
#
15
# You should have received a copy of the GNU Affero General Public License
16
# along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
import time
19
20
from unittest import TestCase
21
from unittest.mock import patch
22
23
from xml.etree import ElementTree as et
24
25
from ospd.command.command import (
26
    GetPerformance,
27
    StartScan,
28
    StopScan,
29
    GetMemoryUsage,
30
)
31
from ospd.errors import OspdCommandError, OspdError
32
from ospd.misc import create_process
33
34
from ..helper import DummyWrapper, assert_called, FakeStream, FakeDataManager
35
36
37
class GetPerformanceTestCase(TestCase):
    """Tests for the ``<get_performance>`` command handler."""

    def _assert_rejected(self, xml_text):
        """Assert that handling ``xml_text`` raises ``OspdCommandError``."""
        command = GetPerformance(None)
        parsed_request = et.fromstring(xml_text)

        with self.assertRaises(OspdCommandError):
            command.handle_xml(parsed_request)

    @patch('ospd.command.command.subprocess')
    def test_get_performance(self, mock_subproc):
        command = GetPerformance(None)
        # Let the mocked subprocess call succeed with dummy output.
        mock_subproc.check_output.return_value = b'foo'

        raw_response = command.handle_xml(
            et.fromstring('<get_performance start="0" end="0" titles="mem"/>')
        )
        response = et.fromstring(raw_response)

        self.assertEqual(response.get('status'), '200')
        self.assertEqual(response.tag, 'get_performance_response')

    def test_get_performance_fail_int(self):
        # A non-integer ``start`` attribute must be rejected.
        self._assert_rejected(
            '<get_performance start="a" end="0" titles="mem"/>'
        )

    def test_get_performance_fail_regex(self):
        # A ``titles`` value containing ``|`` must be rejected.
        self._assert_rejected(
            '<get_performance start="0" end="0" titles="mem|bar"/>'
        )

    def test_get_performance_fail_cmd(self):
        # ``titles="mem1"`` must be rejected; per the test name, the failure
        # is expected from running the underlying command.
        self._assert_rejected(
            '<get_performance start="0" end="0" titles="mem1"/>'
        )
79
80
81
class StartScanTestCase(TestCase):
    """Tests for the ``<start_scan>`` command handler."""

    # <vt_selection> fragment selecting a single VT without parameters.
    _SINGLE_VT_SELECTION = (
        '<vt_selection>'
        '<vt_single id="1.2.3.4" />'
        '</vt_selection>'
    )

    @staticmethod
    def _build_request(
        vt_selection='', hosts='localhost', ports='80, 443', start_attrs=''
    ):
        """Return a parsed ``<start_scan>`` request with a single target.

        Args:
            vt_selection: XML fragment inserted verbatim right after
                ``<scanner_params />``; empty for a request without VTs.
            hosts: Text of the target's ``<hosts>`` element.
            ports: Text of the target's ``<ports>`` element.
            start_attrs: Attribute text inserted verbatim into the opening
                ``<start_scan>`` tag; must begin with a space if non-empty.

        Returns:
            The parsed ``Element`` for the request.
        """
        xml_text = (
            '<start_scan{attrs}>'
            '<targets>'
            '<target>'
            '<hosts>{hosts}</hosts>'
            '<ports>{ports}</ports>'
            '</target>'
            '</targets>'
            '<scanner_params />'
            '{vt_selection}'
            '</start_scan>'
        ).format(
            attrs=start_attrs,
            hosts=hosts,
            ports=ports,
            vt_selection=vt_selection,
        )
        return et.fromstring(xml_text)

    def test_scan_with_vts_empty_vt_list(self):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)
        request = self._build_request(vt_selection='<vt_selection />')

        with self.assertRaises(OspdCommandError):
            cmd.handle_xml(request)

    @patch("ospd.ospd.create_process")
    def test_scan_with_vts(self, mock_create_process):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)
        request = self._build_request(vt_selection=self._SINGLE_VT_SELECTION)

        # With one vt, without params
        response = et.fromstring(cmd.handle_xml(request))
        daemon.start_queued_scans()
        scan_id = response.findtext('id')

        vts_collection = daemon.get_scan_vts(scan_id)
        self.assertEqual(vts_collection, {'1.2.3.4': {}, 'vt_groups': []})
        self.assertNotEqual(vts_collection, {'1.2.3.6': {}})

        daemon.start_queued_scans()
        assert_called(mock_create_process)

    def test_scan_pop_vts(self):
        # NOTE(review): unlike test_scan_with_vts, create_process is not
        # patched here, so start_queued_scans may spawn a real process --
        # confirm this is intended.
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)
        request = self._build_request(vt_selection=self._SINGLE_VT_SELECTION)

        # With one vt, without params
        response = et.fromstring(cmd.handle_xml(request))
        scan_id = response.findtext('id')
        daemon.start_queued_scans()
        vts_collection = daemon.get_scan_vts(scan_id)
        self.assertEqual(vts_collection, {'1.2.3.4': {}, 'vt_groups': []})
        # The VTs were popped by the first read; a second read must fail.
        self.assertRaises(KeyError, daemon.get_scan_vts, scan_id)

    def test_scan_pop_ports(self):
        # NOTE(review): create_process is not patched here either -- see
        # test_scan_pop_vts.
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)
        request = self._build_request(vt_selection=self._SINGLE_VT_SELECTION)

        # With one vt, without params
        response = et.fromstring(cmd.handle_xml(request))
        daemon.start_queued_scans()
        scan_id = response.findtext('id')

        ports = daemon.scan_collection.get_ports(scan_id)
        self.assertEqual(ports, '80, 443')
        # The ports were popped by the first read; a second read must fail.
        self.assertRaises(KeyError, daemon.scan_collection.get_ports, scan_id)

    @patch("ospd.ospd.create_process")
    def test_scan_without_vts(self, mock_create_process):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)

        # Without vts
        request = self._build_request()

        response = et.fromstring(cmd.handle_xml(request))
        daemon.start_queued_scans()

        scan_id = response.findtext('id')
        self.assertEqual(daemon.get_scan_vts(scan_id), {})

        assert_called(mock_create_process)

    def test_scan_with_vts_and_param_missing_vt_param_id(self):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)

        # Raise because no vt_param id attribute
        request = self._build_request(
            vt_selection=(
                '<vt_selection>'
                '<vt_single id="1234"><vt_value>200</vt_value></vt_single>'
                '</vt_selection>'
            )
        )

        with self.assertRaises(OspdError):
            cmd.handle_xml(request)

    @patch("ospd.ospd.create_process")
    def test_scan_with_vts_and_param(self, mock_create_process):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)

        # No error
        request = self._build_request(
            vt_selection=(
                '<vt_selection>'
                '<vt_single id="1234">'
                '<vt_value id="ABC">200</vt_value>'
                '</vt_single>'
                '</vt_selection>'
            )
        )
        response = et.fromstring(cmd.handle_xml(request))
        daemon.start_queued_scans()

        scan_id = response.findtext('id')

        self.assertEqual(
            daemon.get_scan_vts(scan_id),
            {'1234': {'ABC': '200'}, 'vt_groups': []},
        )
        daemon.start_queued_scans()
        assert_called(mock_create_process)

    def test_scan_with_vts_and_param_missing_vt_group_filter(self):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)

        # Raise because no vtgroup filter attribute
        request = self._build_request(
            vt_selection='<vt_selection><vt_group/></vt_selection>'
        )
        daemon.start_queued_scans()

        with self.assertRaises(OspdError):
            cmd.handle_xml(request)

    @patch("ospd.ospd.create_process")
    def test_scan_with_vts_and_param_with_vt_group_filter(
        self, mock_create_process
    ):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)

        # No error
        request = self._build_request(
            vt_selection=(
                '<vt_selection>'
                '<vt_group filter="a"/>'
                '</vt_selection>'
            )
        )
        response = et.fromstring(cmd.handle_xml(request))
        daemon.start_queued_scans()
        scan_id = response.findtext('id')

        self.assertEqual(daemon.get_scan_vts(scan_id), {'vt_groups': ['a']})

        assert_called(mock_create_process)

    @patch("ospd.ospd.create_process")
    @patch("ospd.command.command.logger")
    def test_scan_ignore_multi_target(self, mock_logger, mock_create_process):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)
        request = self._build_request(
            hosts='localhosts', ports='22', start_attrs=' parallel="100a"'
        )

        cmd.handle_xml(request)
        daemon.start_queued_scans()
        assert_called(mock_logger.warning)
        assert_called(mock_create_process)

    def test_max_queued_scans_reached(self):
        daemon = DummyWrapper([])
        daemon.max_queued_scans = 1
        cmd = StartScan(daemon)
        request = self._build_request(
            hosts='localhosts', ports='22', start_attrs=' parallel="100a"'
        )

        # create first scan
        response = et.fromstring(cmd.handle_xml(request))
        scan_id_1 = response.findtext('id')
        # Registered as a cleanup so the pickled scan info is removed even
        # when the assertion below fails.
        self.addCleanup(
            daemon.scan_collection.remove_file_pickled_scan_info, scan_id_1
        )

        # With max_queued_scans = 1, a second scan must be rejected.
        with self.assertRaises(OspdCommandError):
            cmd.handle_xml(request)

    @patch("ospd.ospd.create_process")
    @patch("ospd.command.command.logger")
    def test_scan_use_legacy_target_and_port(
        self, mock_logger, mock_create_process
    ):
        daemon = DummyWrapper([])
        daemon.scan_collection.datamanager = FakeDataManager()

        cmd = StartScan(daemon)
        # Legacy form: target and ports as attributes, no <targets> element.
        request = et.fromstring(
            '<start_scan target="localhost" ports="22">'
            '<scanner_params />'
            '</start_scan>'
        )

        response = et.fromstring(cmd.handle_xml(request))
        daemon.start_queued_scans()
        scan_id = response.findtext('id')

        self.assertIsNotNone(scan_id)

        self.assertEqual(daemon.get_scan_host(scan_id), 'localhost')
        self.assertEqual(daemon.get_scan_ports(scan_id), '22')

        assert_called(mock_logger.warning)
        assert_called(mock_create_process)
395
396
397
class StopCommandTestCase(TestCase):
    """Tests for the ``<stop_scan>`` command handler."""

    @patch("ospd.ospd.os")
    @patch("ospd.ospd.create_process")
    def test_stop_scan(self, mock_create_process, mock_os):
        # The scan "process" is a mock that always reports itself alive.
        scan_process = mock_create_process.return_value
        scan_process.is_alive.return_value = True
        scan_process.pid = "foo"

        stream = FakeStream()
        daemon = DummyWrapper([])
        daemon.scan_collection.datamanager = FakeDataManager()

        start_scan_xml = (
            '<start_scan>'
            '<targets>'
            '<target>'
            '<hosts>localhosts</hosts>'
            '<ports>22</ports>'
            '</target>'
            '</targets>'
            '<scanner_params />'
            '</start_scan>'
        )
        daemon.handle_command(start_scan_xml, stream)
        start_response = stream.get_response()

        daemon.start_queued_scans()

        assert_called(mock_create_process)
        assert_called(scan_process.start)

        scan_id = start_response.findtext('id')
        stop_request = et.fromstring(
            '<stop_scan scan_id="{}" />'.format(scan_id)
        )
        StopScan(daemon).handle_xml(stop_request)

        # Stopping must terminate the process and query its process group.
        assert_called(scan_process.terminate)
        mock_os.getpgid.assert_called_with('foo')

    def test_unknown_scan_id(self):
        stop_command = StopScan(DummyWrapper([]))
        unknown_request = et.fromstring('<stop_scan scan_id="foo" />')

        with self.assertRaises(OspdCommandError):
            stop_command.handle_xml(unknown_request)

    def test_missing_scan_id(self):
        # No scan_id attribute at all; no daemon is needed to reject it.
        bare_request = et.fromstring('<stop_scan />')
        stop_command = StopScan(None)

        with self.assertRaises(OspdCommandError):
            stop_command.handle_xml(bare_request)
450
451
452
class GetMemoryUsageTestCase(TestCase):
    """Tests for the ``<get_memory_usage>`` command handler."""

    def _get_process_elements(self, cmd):
        """Run ``<get_memory_usage />`` through ``cmd``.

        Returns:
            The list of ``<process>`` elements below ``<processes>``.
        """
        request = et.fromstring('<get_memory_usage />')
        response = et.fromstring(cmd.handle_xml(request))
        processes_element = response.find('processes')
        # Fail with a clear assertion rather than an AttributeError on None.
        self.assertIsNotNone(processes_element)
        return processes_element.findall('process')

    def _assert_memory_values(self, process_element):
        """Assert ``process_element`` has rss/vms/shared children with text."""
        for tag in ('rss', 'vms', 'shared'):
            child = process_element.find(tag)
            self.assertIsNotNone(child, msg=tag)
            self.assertIsNotNone(child.text, msg=tag)

    def test_with_main_process_only(self):
        cmd = GetMemoryUsage(None)

        process_elements = self._get_process_elements(cmd)

        # At least the main process must be reported. (Note that
        # ``assertTrue(len(...), n)`` would treat ``n`` as the failure
        # message, not as an expected count.)
        self.assertGreaterEqual(len(process_elements), 1)

        main_process_element = process_elements[0]
        self._assert_memory_values(main_process_element)

    def test_with_subprocess(self):
        cmd = GetMemoryUsage(None)

        def foo():  # pylint: disable=blacklisted-name
            time.sleep(60)

        # NOTE(review): the returned process is never started here; if
        # create_process only constructs it, it will not be reported and this
        # test effectively covers the main process only -- confirm.
        create_process(foo, args=[])

        process_elements = self._get_process_elements(cmd)

        self.assertGreaterEqual(len(process_elements), 1)

        for process_element in process_elements:
            self._assert_memory_values(process_element)

    def test_with_subsubprocess(self):
        cmd = GetMemoryUsage(None)

        def bar():  # pylint: disable=blacklisted-name
            create_process(foo, args=[])

        def foo():  # pylint: disable=blacklisted-name
            time.sleep(60)

        # NOTE(review): as in test_with_subprocess, the process is never
        # explicitly started here -- confirm.
        create_process(bar, args=[])

        process_elements = self._get_process_elements(cmd)

        # sub-sub-processes aren't listed; only a non-empty list is checked.
        self.assertGreaterEqual(len(process_elements), 1)

        for process_element in process_elements:
            self._assert_memory_values(process_element)
545