Completed
Push — master ( f571c7...585a62 )
by
unknown
17s queued 12s
created

StartScanTestCase.test_is_new_scan_allowed_false()   A

Complexity

Conditions 1

Size

Total Lines 8
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 6
nop 1
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
# Copyright (C) 2014-2020 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: AGPL-3.0-or-later
4
#
5
# This program is free software: you can redistribute it and/or modify
6
# it under the terms of the GNU Affero General Public License as
7
# published by the Free Software Foundation, either version 3 of the
8
# License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU Affero General Public License for more details.
14
#
15
# You should have received a copy of the GNU Affero General Public License
16
# along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
import time
19
20
from unittest import TestCase
21
from unittest.mock import patch
22
23
from xml.etree import ElementTree as et
24
25
from ospd.command.command import (
26
    GetPerformance,
27
    StartScan,
28
    StopScan,
29
    GetMemoryUsage,
30
)
31
from ospd.errors import OspdCommandError, OspdError
32
from ospd.misc import create_process
33
34
from ..helper import DummyWrapper, assert_called, FakeStream
35
36
37
class GetPerformanceTestCase(TestCase):
    """Tests for the ``<get_performance>`` command."""

    # Every request in this class differs only in these three attributes.
    _REQUEST_TEMPLATE = '<get_performance start="{}" end="{}" titles="{}"/>'

    def _assert_rejected(self, start, end, titles):
        """Assert that the given request attributes raise OspdCommandError."""
        request = et.fromstring(self._REQUEST_TEMPLATE.format(start, end, titles))
        command = GetPerformance(None)

        with self.assertRaises(OspdCommandError):
            command.handle_xml(request)

    @patch('ospd.command.command.subprocess')
    def test_get_performance(self, mock_subproc):
        command = GetPerformance(None)
        # Canned output for the external performance command.
        mock_subproc.check_output.return_value = b'foo'

        request = et.fromstring(self._REQUEST_TEMPLATE.format(0, 0, 'mem'))
        response = et.fromstring(command.handle_xml(request))

        self.assertEqual(response.get('status'), '200')
        self.assertEqual(response.tag, 'get_performance_response')

    def test_get_performance_fail_int(self):
        # start is not an integer
        self._assert_rejected('a', 0, 'mem')

    def test_get_performance_fail_regex(self):
        # titles contains characters that the command refuses
        self._assert_rejected(0, 0, 'mem|bar')

    def test_get_performance_fail_cmd(self):
        # titles passes validation but does not name a known graph
        self._assert_rejected(0, 0, 'mem1')
79
80
81
class StartScanTestCase(TestCase):
    """Tests for the ``<start_scan>`` command."""

    # <targets> element shared by most requests: localhost, ports 80 and 443.
    _LOCALHOST_TARGETS = (
        '<targets>'
        '<target>'
        '<hosts>localhost</hosts>'
        '<ports>80, 443</ports>'
        '</target>'
        '</targets>'
    )

    @classmethod
    def _localhost_request(cls, vt_selection=''):
        """Parse a <start_scan> request for the shared localhost target.

        ``vt_selection`` is raw XML placed right after ``<scanner_params />``;
        leave it empty to send the request without any VT selection.
        """
        return et.fromstring(
            '<start_scan>'
            + cls._LOCALHOST_TARGETS
            + '<scanner_params />'
            + vt_selection
            + '</start_scan>'
        )

    @staticmethod
    def _new_command():
        """Return a fresh ``(daemon, StartScan command)`` pair."""
        daemon = DummyWrapper([])
        return daemon, StartScan(daemon)

    def _is_new_scan_allowed_with(self, max_scans):
        """Ask the command whether a new scan may start.

        ``scan_processes`` holds two entries, and the daemon's
        ``max_scans`` is set to the given value.
        """
        daemon, cmd = self._new_command()

        cmd._daemon.scan_processes = {'a': 1, 'b': 2}
        daemon.max_scans = max_scans

        return cmd.is_new_scan_allowed()

    def test_scan_with_vts_empty_vt_list(self):
        _, cmd = self._new_command()
        request = self._localhost_request('<vt_selection />')

        with self.assertRaises(OspdCommandError):
            cmd.handle_xml(request)

    @patch("ospd.command.command.create_process")
    def test_scan_with_vts(self, mock_create_process):
        daemon, cmd = self._new_command()
        request = self._localhost_request(
            '<vt_selection><vt_single id="1.2.3.4" /></vt_selection>'
        )

        # One VT, no VT parameters.
        scan_id = et.fromstring(cmd.handle_xml(request)).findtext('id')

        self.assertEqual(
            daemon.get_scan_vts(scan_id), {'1.2.3.4': {}, 'vt_groups': []}
        )
        self.assertNotEqual(daemon.get_scan_vts(scan_id), {'1.2.3.6': {}})

        assert_called(mock_create_process)

    def test_is_new_scan_allowed_false(self):
        self.assertFalse(self._is_new_scan_allowed_with(max_scans=1))

    def test_is_new_scan_allowed_true(self):
        self.assertTrue(self._is_new_scan_allowed_with(max_scans=3))

    @patch("ospd.command.command.create_process")
    def test_scan_without_vts(self, mock_create_process):
        daemon, cmd = self._new_command()

        # No <vt_selection> element at all.
        request = self._localhost_request()
        scan_id = et.fromstring(cmd.handle_xml(request)).findtext('id')

        self.assertEqual(daemon.get_scan_vts(scan_id), {})

        assert_called(mock_create_process)

    def test_scan_with_vts_and_param_missing_vt_param_id(self):
        _, cmd = self._new_command()

        # Expected to fail: the <vt_value> carries no id attribute.
        request = self._localhost_request(
            '<vt_selection>'
            '<vt_single id="1234"><vt_value>200</vt_value></vt_single>'
            '</vt_selection>'
        )

        with self.assertRaises(OspdError):
            cmd.handle_xml(request)

    @patch("ospd.command.command.create_process")
    def test_scan_with_vts_and_param(self, mock_create_process):
        daemon, cmd = self._new_command()

        # A well-formed VT parameter: no error expected.
        request = self._localhost_request(
            '<vt_selection>'
            '<vt_single id="1234">'
            '<vt_value id="ABC">200</vt_value>'
            '</vt_single>'
            '</vt_selection>'
        )
        scan_id = et.fromstring(cmd.handle_xml(request)).findtext('id')

        self.assertEqual(
            daemon.get_scan_vts(scan_id),
            {'1234': {'ABC': '200'}, 'vt_groups': []},
        )

        assert_called(mock_create_process)

    def test_scan_with_vts_and_param_missing_vt_group_filter(self):
        _, cmd = self._new_command()

        # Expected to fail: the <vt_group> carries no filter attribute.
        request = self._localhost_request(
            '<vt_selection><vt_group/></vt_selection>'
        )

        with self.assertRaises(OspdError):
            cmd.handle_xml(request)

    @patch("ospd.command.command.create_process")
    def test_scan_with_vts_and_param_with_vt_group_filter(
        self, mock_create_process
    ):
        daemon, cmd = self._new_command()

        # A VT group with a filter: no error expected.
        request = self._localhost_request(
            '<vt_selection><vt_group filter="a"/></vt_selection>'
        )
        scan_id = et.fromstring(cmd.handle_xml(request)).findtext('id')

        self.assertEqual(daemon.get_scan_vts(scan_id), {'vt_groups': ['a']})

        assert_called(mock_create_process)

    @patch("ospd.command.command.create_process")
    @patch("ospd.command.command.logger")
    def test_scan_ignore_multi_target(self, mock_logger, mock_create_process):
        _, cmd = self._new_command()
        # Non-numeric "parallel" attribute; a warning is expected.
        request = et.fromstring(
            '<start_scan parallel="100a">'
            '<targets>'
            '<target>'
            '<hosts>localhosts</hosts>'
            '<ports>22</ports>'
            '</target>'
            '</targets>'
            '<scanner_params />'
            '</start_scan>'
        )

        cmd.handle_xml(request)

        assert_called(mock_logger.warning)
        assert_called(mock_create_process)

    @patch("ospd.command.command.create_process")
    @patch("ospd.command.command.logger")
    def test_scan_use_legacy_target_and_port(
        self, mock_logger, mock_create_process
    ):
        daemon, cmd = self._new_command()
        # Target and ports given as attributes instead of a <targets> element.
        request = et.fromstring(
            '<start_scan target="localhost" ports="22">'
            '<scanner_params />'
            '</start_scan>'
        )

        scan_id = et.fromstring(cmd.handle_xml(request)).findtext('id')
        self.assertIsNotNone(scan_id)

        self.assertEqual(daemon.get_scan_host(scan_id), 'localhost')
        self.assertEqual(daemon.get_scan_ports(scan_id), '22')

        assert_called(mock_logger.warning)
        assert_called(mock_create_process)
324
325
326
class StopCommandTestCase(TestCase):
    """Tests for the ``<stop_scan>`` command."""

    @patch("ospd.ospd.os")
    @patch("ospd.command.command.create_process")
    def test_stop_scan(self, mock_create_process, mock_os):
        # Fake scan process that reports itself as still running.
        fake_process = mock_create_process.return_value
        fake_process.is_alive.return_value = True
        fake_process.pid = "foo"

        # Start a scan through the daemon so there is something to stop.
        daemon = DummyWrapper([])
        stream = FakeStream()
        start_request = (
            '<start_scan>'
            '<targets>'
            '<target>'
            '<hosts>localhosts</hosts>'
            '<ports>22</ports>'
            '</target>'
            '</targets>'
            '<scanner_params />'
            '</start_scan>'
        )
        daemon.handle_command(start_request, stream)
        scan_id = stream.get_response().findtext('id')

        assert_called(mock_create_process)
        assert_called(fake_process.start)

        stop_request = et.fromstring(
            '<stop_scan scan_id="{}" />'.format(scan_id)
        )
        StopScan(daemon).handle_xml(stop_request)

        assert_called(fake_process.terminate)
        mock_os.getpgid.assert_called_with('foo')

    def test_unknown_scan_id(self):
        stop_scan = StopScan(DummyWrapper([]))
        request = et.fromstring('<stop_scan scan_id="foo" />')

        with self.assertRaises(OspdCommandError):
            stop_scan.handle_xml(request)

    def test_missing_scan_id(self):
        # No daemon is passed; the request is expected to be rejected anyway.
        stop_scan = StopScan(None)
        request = et.fromstring('<stop_scan />')

        with self.assertRaises(OspdCommandError):
            stop_scan.handle_xml(request)
377
378
379
def _start_sleeping_grandchild():
    """Start a sleeping child of the current process and wait for it.

    Meant to run as the target of a child process, so that the test process
    ends up with a grandchild as well as a direct child. Defined at module
    level (not as a local function) so it can also be used with start methods
    that pickle the process target.
    """
    grandchild = create_process(time.sleep, args=[60])
    grandchild.start()
    grandchild.join()


class GetMemoryUsageTestCase(TestCase):
    """Tests for the ``<get_memory_usage>`` command.

    ``create_process`` returns a process object that still has to be started
    (StartScan calls ``start()`` on it separately). The helpers below therefore
    start every process explicitly and terminate it again during cleanup, so
    no sleeping children outlive a test.
    """

    def _start_process(self, target, args):
        """Start ``target(*args)`` in a child process, stopped on cleanup."""
        process = create_process(target, args=args)
        process.start()
        self.addCleanup(self._stop_process, process)
        return process

    @staticmethod
    def _stop_process(process):
        """Terminate ``process`` and reap it."""
        process.terminate()
        process.join()

    def _get_process_elements(self):
        """Run ``<get_memory_usage />`` and return its <process> elements."""
        cmd = GetMemoryUsage(None)
        request = et.fromstring('<get_memory_usage />')

        response = et.fromstring(cmd.handle_xml(request))
        processes_element = response.find('processes')
        self.assertIsNotNone(processes_element)

        return processes_element.findall('process')

    def _assert_has_memory_values(self, process_element):
        """Check that ``process_element`` reports rss, vms and shared values."""
        for tag in ('rss', 'vms', 'shared'):
            element = process_element.find(tag)
            self.assertIsNotNone(element, tag)
            self.assertIsNotNone(element.text, tag)

    def test_with_main_process_only(self):
        process_elements = self._get_process_elements()

        # No children were started, so only the test process is listed.
        # (assertEqual, not assertTrue: assertTrue's second argument is only
        # the failure message.)
        self.assertEqual(len(process_elements), 1)
        self._assert_has_memory_values(process_elements[0])

    def test_with_subprocess(self):
        # The child only sleeps; cleanup terminates it long before it ends.
        self._start_process(time.sleep, [60])

        process_elements = self._get_process_elements()

        # The main process plus the child started above.
        self.assertEqual(len(process_elements), 2)
        for process_element in process_elements:
            self._assert_has_memory_values(process_element)

    def test_with_subsubprocess(self):
        # NOTE(review): cleanup terminates only the direct child; the
        # grandchild presumably keeps sleeping until its own timer runs out.
        self._start_process(_start_sleeping_grandchild, [])

        process_elements = self._get_process_elements()

        # Sub-sub-processes aren't listed: only the main process and its
        # direct child are expected.
        self.assertEqual(len(process_elements), 2)
        for process_element in process_elements:
            self._assert_has_memory_values(process_element)
472