Completed
Push — master ( 745830...54fe4d )
by Juan José
17s queued 13s
created

StartScanTestCase.test_scan_pop_ports()   A

Complexity

Conditions 1

Size

Total Lines 26
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 10
nop 1
dl 0
loc 26
rs 9.9
c 0
b 0
f 0
1
# Copyright (C) 2014-2020 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: AGPL-3.0-or-later
4
#
5
# This program is free software: you can redistribute it and/or modify
6
# it under the terms of the GNU Affero General Public License as
7
# published by the Free Software Foundation, either version 3 of the
8
# License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU Affero General Public License for more details.
14
#
15
# You should have received a copy of the GNU Affero General Public License
16
# along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
import time
19
20
from unittest import TestCase
21
from unittest.mock import patch
22
23
from xml.etree import ElementTree as et
24
25
from ospd.command.command import (
26
    GetPerformance,
27
    StartScan,
28
    StopScan,
29
    GetMemoryUsage,
30
)
31
from ospd.errors import OspdCommandError, OspdError
32
from ospd.misc import create_process
33
34
from ..helper import DummyWrapper, assert_called, FakeStream, FakeDataManager
35
36
37
class GetPerformanceTestCase(TestCase):
    """Tests for the ``<get_performance>`` command handler."""

    @patch('ospd.command.command.subprocess')
    def test_get_performance(self, mock_subproc):
        # The external command is mocked, so its output is a fixed byte string.
        cmd = GetPerformance(None)
        mock_subproc.check_output.return_value = b'foo'
        response = et.fromstring(
            cmd.handle_xml(
                et.fromstring(
                    '<get_performance start="0" end="0" titles="mem"/>'
                )
            )
        )

        self.assertEqual(response.get('status'), '200')
        self.assertEqual(response.tag, 'get_performance_response')

    def test_get_performance_fail_int(self):
        # A non-integer ``start`` attribute is expected to be rejected.
        cmd = GetPerformance(None)
        request = et.fromstring(
            '<get_performance start="a" end="0" titles="mem"/>'
        )

        with self.assertRaises(OspdCommandError):
            cmd.handle_xml(request)

    def test_get_performance_fail_regex(self):
        # A ``titles`` value containing ``|`` is expected to be rejected.
        cmd = GetPerformance(None)
        request = et.fromstring(
            '<get_performance start="0" end="0" titles="mem|bar"/>'
        )

        with self.assertRaises(OspdCommandError):
            cmd.handle_xml(request)

    @patch('ospd.command.command.subprocess.check_output')
    def test_get_performance_fail_cmd(self, mock_check_output):
        # Make the external command fail deterministically (as if it were not
        # installed) instead of depending on the state of the test machine.
        mock_check_output.side_effect = FileNotFoundError('command not found')

        cmd = GetPerformance(None)
        request = et.fromstring(
            '<get_performance start="0" end="0" titles="mem1"/>'
        )

        with self.assertRaises(OspdCommandError):
            cmd.handle_xml(request)
79
80
81
class StartScanTestCase(TestCase):
    """Tests for the ``<start_scan>`` command handler."""

    # VT selection with a single VT and no VT parameters.
    _SINGLE_VT_SELECTION = (
        '<vt_selection><vt_single id="1.2.3.4" /></vt_selection>'
    )

    @staticmethod
    def _start_scan_request(vt_selection=''):
        """Return a parsed ``<start_scan>`` request with a single target.

        The target is ``localhost`` with ports ``80, 443`` and the scanner
        params are empty. *vt_selection* is inserted verbatim right after
        ``<scanner_params />``; the default ``''`` omits the VT selection.
        """
        return et.fromstring(
            '<start_scan>'
            '<targets>'
            '<target>'
            '<hosts>localhost</hosts>'
            '<ports>80, 443</ports>'
            '</target>'
            '</targets>'
            '<scanner_params />' + vt_selection + '</start_scan>'
        )

    def test_scan_with_vts_empty_vt_list(self):
        # An empty <vt_selection> is expected to be rejected.
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)
        request = self._start_scan_request('<vt_selection />')

        with self.assertRaises(OspdCommandError):
            cmd.handle_xml(request)

    @patch("ospd.ospd.create_process")
    def test_scan_with_vts(self, mock_create_process):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)
        request = self._start_scan_request(self._SINGLE_VT_SELECTION)

        # With one vt, without params
        response = et.fromstring(cmd.handle_xml(request))
        scan_id = response.findtext('id')

        vts_collection = daemon.get_scan_vts(scan_id)
        self.assertEqual(vts_collection, {'1.2.3.4': {}, 'vt_groups': []})
        self.assertNotEqual(vts_collection, {'1.2.3.6': {}})

        daemon.start_pending_scans()
        assert_called(mock_create_process)

    def test_scan_pop_vts(self):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)
        request = self._start_scan_request(self._SINGLE_VT_SELECTION)

        # With one vt, without params
        response = et.fromstring(cmd.handle_xml(request))
        scan_id = response.findtext('id')

        vts_collection = daemon.get_scan_vts(scan_id)
        self.assertEqual(vts_collection, {'1.2.3.4': {}, 'vt_groups': []})
        # The first lookup consumes the VTs, so a second one must fail.
        self.assertRaises(KeyError, daemon.get_scan_vts, scan_id)

    def test_scan_pop_ports(self):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)
        request = self._start_scan_request(self._SINGLE_VT_SELECTION)

        # With one vt, without params
        response = et.fromstring(cmd.handle_xml(request))
        scan_id = response.findtext('id')

        ports = daemon.scan_collection.get_ports(scan_id)
        self.assertEqual(ports, '80, 443')
        # The first lookup consumes the ports, so a second one must fail.
        self.assertRaises(KeyError, daemon.scan_collection.get_ports, scan_id)

    def test_is_new_scan_allowed_false(self):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)

        # Two running scans with a limit of one: no new scan allowed.
        daemon.scan_processes = {'a': 1, 'b': 2}
        daemon.max_scans = 1

        self.assertFalse(cmd.is_new_scan_allowed())

    def test_is_new_scan_allowed_true(self):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)

        # Two running scans with a limit of three: a new scan is allowed.
        daemon.scan_processes = {'a': 1, 'b': 2}
        daemon.max_scans = 3

        self.assertTrue(cmd.is_new_scan_allowed())

    @patch("ospd.ospd.create_process")
    def test_scan_without_vts(self, mock_create_process):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)

        # With out vts
        request = self._start_scan_request()
        response = et.fromstring(cmd.handle_xml(request))

        scan_id = response.findtext('id')
        self.assertEqual(daemon.get_scan_vts(scan_id), {})

        daemon.start_pending_scans()
        assert_called(mock_create_process)

    def test_scan_with_vts_and_param_missing_vt_param_id(self):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)

        # Raise because no vt_param id attribute
        request = self._start_scan_request(
            '<vt_selection>'
            '<vt_single id="1234"><vt_value>200</vt_value></vt_single>'
            '</vt_selection>'
        )

        with self.assertRaises(OspdError):
            cmd.handle_xml(request)

    @patch("ospd.ospd.create_process")
    def test_scan_with_vts_and_param(self, mock_create_process):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)

        # No error
        request = self._start_scan_request(
            '<vt_selection>'
            '<vt_single id="1234">'
            '<vt_value id="ABC">200</vt_value>'
            '</vt_single>'
            '</vt_selection>'
        )
        response = et.fromstring(cmd.handle_xml(request))
        scan_id = response.findtext('id')

        self.assertEqual(
            daemon.get_scan_vts(scan_id),
            {'1234': {'ABC': '200'}, 'vt_groups': []},
        )
        daemon.start_pending_scans()
        assert_called(mock_create_process)

    def test_scan_with_vts_and_param_missing_vt_group_filter(self):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)

        # Raise because no vtgroup filter attribute
        request = self._start_scan_request(
            '<vt_selection><vt_group/></vt_selection>'
        )

        with self.assertRaises(OspdError):
            cmd.handle_xml(request)

    @patch("ospd.ospd.create_process")
    def test_scan_with_vts_and_param_with_vt_group_filter(
        self, mock_create_process
    ):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)

        # No error
        request = self._start_scan_request(
            '<vt_selection><vt_group filter="a"/></vt_selection>'
        )
        response = et.fromstring(cmd.handle_xml(request))
        scan_id = response.findtext('id')

        self.assertEqual(daemon.get_scan_vts(scan_id), {'vt_groups': ['a']})

        daemon.start_pending_scans()
        assert_called(mock_create_process)

    @patch("ospd.ospd.create_process")
    @patch("ospd.command.command.logger")
    def test_scan_ignore_multi_target(self, mock_logger, mock_create_process):
        daemon = DummyWrapper([])
        cmd = StartScan(daemon)
        # The ``parallel`` attribute is expected to produce a logged warning
        # while the scan is still accepted and started.
        request = et.fromstring(
            '<start_scan parallel="100a">'
            '<targets>'
            '<target>'
            '<hosts>localhosts</hosts>'
            '<ports>22</ports>'
            '</target>'
            '</targets>'
            '<scanner_params />'
            '</start_scan>'
        )

        cmd.handle_xml(request)
        daemon.start_pending_scans()
        assert_called(mock_logger.warning)
        assert_called(mock_create_process)

    @patch("ospd.ospd.create_process")
    @patch("ospd.command.command.logger")
    def test_scan_use_legacy_target_and_port(
        self, mock_logger, mock_create_process
    ):
        daemon = DummyWrapper([])
        daemon.scan_collection.datamanager = FakeDataManager()

        cmd = StartScan(daemon)
        # Target and ports given as attributes of <start_scan> instead of a
        # <targets> element; a warning is expected to be logged.
        request = et.fromstring(
            '<start_scan target="localhost" ports="22">'
            '<scanner_params />'
            '</start_scan>'
        )

        response = et.fromstring(cmd.handle_xml(request))
        scan_id = response.findtext('id')

        self.assertIsNotNone(scan_id)

        self.assertEqual(daemon.get_scan_host(scan_id), 'localhost')
        self.assertEqual(daemon.get_scan_ports(scan_id), '22')

        daemon.start_pending_scans()

        assert_called(mock_logger.warning)
        assert_called(mock_create_process)
389
390
391
class StopCommandTestCase(TestCase):
    """Tests for the ``<stop_scan>`` command handler."""

    @patch("ospd.ospd.os")
    @patch("ospd.ospd.create_process")
    def test_stop_scan(self, mock_create_process, mock_os):
        # Every process handed out by the patched factory reports itself as
        # alive and carries the pid "foo".
        scan_process = mock_create_process.return_value
        scan_process.is_alive.return_value = True
        scan_process.pid = "foo"

        stream = FakeStream()
        daemon = DummyWrapper([])
        daemon.scan_collection.datamanager = FakeDataManager()

        # Start a scan through the daemon's command dispatcher.
        start_request = (
            '<start_scan>'
            '<targets>'
            '<target>'
            '<hosts>localhosts</hosts>'
            '<ports>22</ports>'
            '</target>'
            '</targets>'
            '<scanner_params />'
            '</start_scan>'
        )
        daemon.handle_command(start_request, stream)
        start_response = stream.get_response()

        daemon.start_pending_scans()

        assert_called(mock_create_process)
        assert_called(scan_process.start)

        # Now stop the scan that was just started.
        stop_request = et.fromstring(
            '<stop_scan scan_id="%s" />' % start_response.findtext('id')
        )
        StopScan(daemon).handle_xml(stop_request)

        assert_called(scan_process.terminate)

        mock_os.getpgid.assert_called_with('foo')

    def test_unknown_scan_id(self):
        # Stopping a scan id the daemon never saw must fail.
        daemon = DummyWrapper([])
        cmd = StopScan(daemon)
        request = et.fromstring('<stop_scan scan_id="foo" />')

        self.assertRaises(OspdCommandError, cmd.handle_xml, request)

    def test_missing_scan_id(self):
        # A <stop_scan> without a scan_id attribute must fail.
        request = et.fromstring('<stop_scan />')
        cmd = StopScan(None)

        self.assertRaises(OspdCommandError, cmd.handle_xml, request)
444
445
446
class GetMemoryUsageTestCase(TestCase):
    """Tests for the ``<get_memory_usage>`` command handler."""

    # How long helper processes stay alive; far longer than a test needs.
    _SLEEP_SECONDS = 60

    @staticmethod
    def _sleep():
        """Process target that only stays alive for ``_SLEEP_SECONDS``."""
        time.sleep(GetMemoryUsageTestCase._SLEEP_SECONDS)

    @staticmethod
    def _spawn_sleeper():
        """Process target that starts a ``_sleep`` child and waits for it.

        Terminating this process does not stop its child; the orphaned child
        exits by itself after ``_SLEEP_SECONDS``.
        """
        child = create_process(GetMemoryUsageTestCase._sleep, args=())
        child.start()
        child.join()

    def _start_process(self, target):
        """Start *target* in a new process that is stopped after the test.

        ``create_process`` returns a process that still has to be started
        (the daemon calls ``start()`` on it itself), so start it here.
        Targets are staticmethods rather than closures so that they remain
        picklable under non-fork start methods.
        """
        process = create_process(target, args=())
        process.start()
        # Cleanups run last-in-first-out: terminate first, then reap.
        self.addCleanup(process.join)
        self.addCleanup(process.terminate)
        return process

    @staticmethod
    def _request_process_elements():
        """Send ``<get_memory_usage />``; return the ``<process>`` elements."""
        cmd = GetMemoryUsage(None)
        request = et.fromstring('<get_memory_usage />')
        response = et.fromstring(cmd.handle_xml(request))
        return response.find('processes').findall('process')

    def _assert_has_memory_info(self, process_element):
        """Assert *process_element* has rss, vms and shared values."""
        for tag in ('rss', 'vms', 'shared'):
            element = process_element.find(tag)
            self.assertIsNotNone(element, tag)
            self.assertIsNotNone(element.text, tag)

    def test_with_main_process_only(self):
        process_elements = self._request_process_elements()

        # Only the current process is reported when no child is running.
        self.assertEqual(len(process_elements), 1)
        self._assert_has_memory_info(process_elements[0])

    def test_with_subprocess(self):
        self._start_process(self._sleep)

        process_elements = self._request_process_elements()

        # The current process plus its running child.
        self.assertEqual(len(process_elements), 2)
        for process_element in process_elements:
            self._assert_has_memory_info(process_element)

    def test_with_subsubprocess(self):
        self._start_process(self._spawn_sleeper)

        process_elements = self._request_process_elements()

        # sub-sub-processes aren't listed: only the current process and its
        # direct child are expected.
        self.assertEqual(len(process_elements), 2)
        for process_element in process_elements:
            self._assert_has_memory_info(process_element)
539