Passed
Pull Request — master (#207)
by
unknown
07:09
created

GetMemoryUsageTestCase.test_with_main_process_only()   A

Complexity

Conditions 1

Size

Total Lines 26
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 17
nop 1
dl 0
loc 26
rs 9.55
c 0
b 0
f 0
1
# Copyright (C) 2020 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: GPL-2.0-or-later
4
#
5
# This program is free software; you can redistribute it and/or
6
# modify it under the terms of the GNU General Public License
7
# as published by the Free Software Foundation; either version 2
8
# of the License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU General Public License for more details.
14
#
15
# You should have received a copy of the GNU General Public License
16
# along with this program; if not, write to the Free Software
17
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18
19
import time
20
21
from unittest import TestCase
22
from unittest.mock import patch
23
24
from xml.etree import ElementTree as et
25
26
from ospd.command.command import (
27
    GetPerformance,
28
    StartScan,
29
    StopScan,
30
    GetMemoryUsage,
31
)
32
from ospd.errors import OspdCommandError, OspdError
33
from ospd.misc import create_process
34
35
from ..helper import DummyWrapper, assert_called, FakeStream
36
37
38
class GetPerformanceTestCase(TestCase):
    """Tests for ``GetPerformance.handle_xml``."""

    @patch('ospd.command.command.subprocess')
    def test_get_performance(self, mock_subproc):
        """A request with integer bounds and titles="mem" returns status 200."""
        command = GetPerformance(None)
        # The subprocess module is patched, so no external tool is executed.
        mock_subproc.check_output.return_value = b'foo'

        request = et.fromstring(
            '<get_performance start="0" end="0" titles="mem"/>'
        )
        raw_response = command.handle_xml(request)
        response = et.fromstring(raw_response)

        self.assertEqual(response.get('status'), '200')
        self.assertEqual(response.tag, 'get_performance_response')

    def test_get_performance_fail_int(self):
        """A request with start="a" raises OspdCommandError."""
        command = GetPerformance(None)
        bad_request = et.fromstring(
            '<get_performance start="a" end="0" titles="mem"/>'
        )

        self.assertRaises(OspdCommandError, command.handle_xml, bad_request)

    def test_get_performance_fail_regex(self):
        """A request with titles="mem|bar" raises OspdCommandError."""
        command = GetPerformance(None)
        bad_request = et.fromstring(
            '<get_performance start="0" end="0" titles="mem|bar"/>'
        )

        self.assertRaises(OspdCommandError, command.handle_xml, bad_request)

    def test_get_performance_fail_cmd(self):
        """A request with titles="mem1" raises OspdCommandError.

        The subprocess module is not patched in this test.
        """
        command = GetPerformance(None)
        bad_request = et.fromstring(
            '<get_performance start="0" end="0" titles="mem1"/>'
        )

        self.assertRaises(OspdCommandError, command.handle_xml, bad_request)
80
81
82
class StartScanTestCase(TestCase):
    """Tests for ``StartScan.handle_xml``."""

    def test_scan_with_vts_empty_vt_list(self):
        """An empty <vt_selection /> element raises OspdCommandError."""
        wrapper = DummyWrapper([])
        command = StartScan(wrapper)
        request = et.fromstring(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /><vt_selection />'
            '</start_scan>'
        )

        self.assertRaises(OspdCommandError, command.handle_xml, request)

    @patch("ospd.command.command.create_process")
    def test_scan_with_vts(self, mock_create_process):
        """A single vt without parameters is stored for the new scan."""
        wrapper = DummyWrapper([])
        command = StartScan(wrapper)

        request = et.fromstring(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params />'
            '<vt_selection>'
            '<vt_single id="1.2.3.4" />'
            '</vt_selection>'
            '</start_scan>'
        )

        # One vt, no vt parameters.
        raw_response = command.handle_xml(request)
        scan_id = et.fromstring(raw_response).findtext('id')

        expected_vts = {'1.2.3.4': {}, 'vt_groups': []}
        self.assertEqual(wrapper.get_scan_vts(scan_id), expected_vts)
        self.assertNotEqual(wrapper.get_scan_vts(scan_id), {'1.2.3.6': {}})

        assert_called(mock_create_process)

    @patch("ospd.command.command.create_process")
    def test_scan_without_vts(self, mock_create_process):
        """A request without a vt selection stores an empty vts dict."""
        wrapper = DummyWrapper([])
        command = StartScan(wrapper)

        # Request without any vts.
        request = et.fromstring(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params />'
            '</start_scan>'
        )
        raw_response = command.handle_xml(request)

        scan_id = et.fromstring(raw_response).findtext('id')

        self.assertEqual(wrapper.get_scan_vts(scan_id), {})

        assert_called(mock_create_process)

    def test_scan_with_vts_and_param_missing_vt_param_id(self):
        """A <vt_value> without an id attribute raises OspdError."""
        wrapper = DummyWrapper([])
        command = StartScan(wrapper)

        # The vt_value element below carries no id attribute.
        request = et.fromstring(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params />'
            '<vt_selection>'
            '<vt_single id="1234"><vt_value>200</vt_value></vt_single>'
            '</vt_selection>'
            '</start_scan>'
        )

        self.assertRaises(OspdError, command.handle_xml, request)

    @patch("ospd.command.command.create_process")
    def test_scan_with_vts_and_param(self, mock_create_process):
        """A vt with an identified <vt_value> is stored with that value."""
        wrapper = DummyWrapper([])
        command = StartScan(wrapper)

        # Well-formed request: no error expected.
        request = et.fromstring(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params />'
            '<vt_selection>'
            '<vt_single id="1234">'
            '<vt_value id="ABC">200</vt_value>'
            '</vt_single>'
            '</vt_selection>'
            '</start_scan>'
        )
        raw_response = command.handle_xml(request)
        scan_id = et.fromstring(raw_response).findtext('id')

        expected_vts = {'1234': {'ABC': '200'}, 'vt_groups': []}
        self.assertEqual(wrapper.get_scan_vts(scan_id), expected_vts)

        assert_called(mock_create_process)

    def test_scan_with_vts_and_param_missing_vt_group_filter(self):
        """A <vt_group/> without a filter attribute raises OspdError."""
        wrapper = DummyWrapper([])
        command = StartScan(wrapper)

        # The vt_group element below carries no filter attribute.
        request = et.fromstring(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params />'
            '<vt_selection><vt_group/></vt_selection>'
            '</start_scan>'
        )

        self.assertRaises(OspdError, command.handle_xml, request)

    @patch("ospd.command.command.create_process")
    def test_scan_with_vts_and_param_with_vt_group_filter(
        self, mock_create_process
    ):
        """A <vt_group> filter is recorded under the 'vt_groups' key."""
        wrapper = DummyWrapper([])
        command = StartScan(wrapper)

        # Well-formed request: no error expected.
        request = et.fromstring(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params />'
            '<vt_selection>'
            '<vt_group filter="a"/>'
            '</vt_selection>'
            '</start_scan>'
        )
        raw_response = command.handle_xml(request)
        scan_id = et.fromstring(raw_response).findtext('id')

        self.assertEqual(wrapper.get_scan_vts(scan_id), {'vt_groups': ['a']})

        assert_called(mock_create_process)

    def test_scan_multi_target_parallel_with_error(self):
        """A request with parallel="100a" raises OspdCommandError."""
        wrapper = DummyWrapper([])
        command = StartScan(wrapper)
        request = et.fromstring(
            '<start_scan parallel="100a">'
            '<scanner_params />'
            '<targets>'
            '<target>'
            '<hosts>localhosts</hosts>'
            '<ports>22</ports>'
            '</target>'
            '</targets>'
            '</start_scan>'
        )

        self.assertRaises(OspdCommandError, command.handle_xml, request)

    @patch("ospd.ospd.OSPDaemon")
    @patch("ospd.command.command.create_process")
    def test_scan_multi_target_parallel_100(
        self, mock_create_process, mock_daemon
    ):
        """A request with parallel="100" returns status 200.

        The daemon is a mock whose ``create_scan`` returns '1'.
        """
        fake_daemon = mock_daemon()
        fake_daemon.create_scan.return_value = '1'
        command = StartScan(fake_daemon)
        request = et.fromstring(
            '<start_scan parallel="100">'
            '<scanner_params />'
            '<targets>'
            '<target>'
            '<hosts>localhosts</hosts>'
            '<ports>22</ports>'
            '</target>'
            '</targets>'
            '</start_scan>'
        )
        raw_response = command.handle_xml(request)
        response = et.fromstring(raw_response)

        self.assertEqual(response.get('status'), '200')

        assert_called(mock_create_process)
262
263
264
class StopCommandTestCase(TestCase):
    """Tests for ``StopScan.handle_xml``."""

    @patch("ospd.ospd.os")
    @patch("ospd.command.command.create_process")
    def test_stop_scan(self, mock_create_process, mock_os):
        """Stopping a started scan terminates its (mocked) process."""
        # Pose as a live scan process with a recognisable pid.
        fake_process = mock_create_process.return_value
        fake_process.is_alive.return_value = True
        fake_process.pid = "foo"

        stream = FakeStream()
        wrapper = DummyWrapper([])
        start_request = (
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params />'
            '</start_scan>'
        )
        wrapper.handle_command(start_request, stream)
        start_response = stream.get_response()

        assert_called(mock_create_process)
        assert_called(fake_process.start)

        scan_id = start_response.findtext('id')

        stop_request = et.fromstring('<stop_scan scan_id="%s" />' % scan_id)
        command = StopScan(wrapper)
        command.handle_xml(stop_request)

        assert_called(fake_process.terminate)

        # The patched os module must have been queried with the fake pid.
        mock_os.getpgid.assert_called_with('foo')

    def test_unknown_scan_id(self):
        """A scan_id of "foo" on a fresh daemon raises OspdCommandError."""
        wrapper = DummyWrapper([])
        command = StopScan(wrapper)
        stop_request = et.fromstring('<stop_scan scan_id="foo" />')

        self.assertRaises(OspdCommandError, command.handle_xml, stop_request)

    def test_missing_scan_id(self):
        """A <stop_scan /> without scan_id raises OspdCommandError."""
        stop_request = et.fromstring('<stop_scan />')
        command = StopScan(None)

        self.assertRaises(OspdCommandError, command.handle_xml, stop_request)
309
310
311
class GetMemoryUsageTestCase(TestCase):
    """Tests for ``GetMemoryUsage.handle_xml``.

    Each test issues ``<get_memory_usage />`` and checks the ``process``
    elements found below the response's ``processes`` element.
    """

    def _get_process_elements(self, cmd):
        """Run ``<get_memory_usage />`` on *cmd* and return its processes.

        Returns the list of ``process`` elements found below the
        ``processes`` element of the parsed response.
        """
        request = et.fromstring('<get_memory_usage />')

        response = et.fromstring(cmd.handle_xml(request))
        processes_element = response.find('processes')

        # Fail with a proper assertion instead of an AttributeError on None.
        self.assertIsNotNone(processes_element)

        return processes_element.findall('process')

    def _assert_memory_fields(self, process_element):
        """Assert *process_element* has non-empty rss, vms and shared."""
        for tag in ('rss', 'vms', 'shared'):
            field = process_element.find(tag)

            self.assertIsNotNone(field, 'missing <%s> element' % tag)
            self.assertIsNotNone(field.text, 'empty <%s> element' % tag)

    def test_with_main_process_only(self):
        """Without child processes exactly one process is reported."""
        cmd = GetMemoryUsage(None)

        process_elements = self._get_process_elements(cmd)

        # assertTrue(len(...), 1) used 1 as the failure message and never
        # compared the count; assertEqual performs the intended check.
        self.assertEqual(len(process_elements), 1)

        self._assert_memory_fields(process_elements[0])

    def test_with_subprocess(self):
        """With one child process two processes are reported."""
        cmd = GetMemoryUsage(None)

        def foo():  # pylint: disable=blacklisted-name
            time.sleep(60)

        # NOTE(review): the object returned by create_process is discarded;
        # the expected count of 2 presumes the child is running at this
        # point and nothing here stops it afterwards -- confirm against
        # ospd.misc.create_process.
        create_process(foo, args=[])

        process_elements = self._get_process_elements(cmd)

        self.assertEqual(len(process_elements), 2)

        for process_element in process_elements:
            self._assert_memory_fields(process_element)

    def test_with_subsubprocess(self):
        """A child that creates its own child still yields two processes."""
        cmd = GetMemoryUsage(None)

        def foo():  # pylint: disable=blacklisted-name
            time.sleep(60)

        def bar():  # pylint: disable=blacklisted-name
            create_process(foo, args=[])

        # NOTE(review): same caveat as in test_with_subprocess -- the
        # returned process object is discarded; confirm it is started and
        # cleaned up by create_process or its callers.
        create_process(bar, args=[])

        process_elements = self._get_process_elements(cmd)

        # sub-sub-processes aren't listed
        self.assertEqual(len(process_elements), 2)

        for process_element in process_elements:
            self._assert_memory_fields(process_element)
404