Completed
Push — master ( 9305a2...0d408b )
by Björn
19s queued 12s
created

ScanTestCase.test_set_status_interrupted()   A

Complexity

Conditions 1

Size

Total Lines 22
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 13
nop 1
dl 0
loc 22
rs 9.75
c 0
b 0
f 0
1
# Copyright (C) 2014-2021 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: AGPL-3.0-or-later
4
#
5
# This program is free software: you can redistribute it and/or modify
6
# it under the terms of the GNU Affero General Public License as
7
# published by the Free Software Foundation, either version 3 of the
8
# License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU Affero General Public License for more details.
14
#
15
# You should have received a copy of the GNU Affero General Public License
16
# along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
# pylint: disable=too-many-lines
19
20
""" Test module for scan runs
21
"""
22
23
import time
24
import unittest
25
26
from unittest.mock import patch, MagicMock, Mock
27
28
import logging
29
import xml.etree.ElementTree as ET
30
31
from defusedxml.common import EntitiesForbidden
32
33
from ospd.resultlist import ResultList
34
from ospd.errors import OspdCommandError
35
from ospd.scan import ScanStatus
36
37
from .helper import (
38
    DummyWrapper,
39
    assert_called,
40
    FakeStream,
41
    FakeDataManager,
42
    FakePsutil,
43
)
44
45
46
class FakeStartProcess:
    """Stand-in for a process factory that records what it would start.

    Calling the instance stores the target callable together with its
    positional and keyword arguments and hands back ``call_mock``.
    ``run`` later invokes the stored callable synchronously and returns
    ``run_mock``.
    """

    def __init__(self):
        # Mocks returned to callers so tests can inspect how they were used.
        self.run_mock = MagicMock()
        self.call_mock = MagicMock()

        # Populated by __call__.
        self.func = None
        self.args = None
        self.kwargs = None

    def __call__(self, func, *, args=None, kwargs=None):
        self.func = func
        # Falsy values (None, (), {}) are normalised to empty containers.
        self.args, self.kwargs = args or [], kwargs or {}
        return self.call_mock

    def run(self):
        target, positional, named = self.func, self.args, self.kwargs
        target(*positional, **named)
        return self.run_mock

    def __repr__(self):
        template = "<FakeProcess func={func} args={args} kwargs={kwargs}>"
        return template.format(
            func=self.func, args=self.args, kwargs=self.kwargs
        )
69
70
71
class Result:
    """Lightweight result record used as test input.

    Every known result field starts out as an empty string; keyword
    arguments override those defaults or attach extra attributes.
    """

    def __init__(self, type_, **kwargs):
        self.result_type = type_
        # Default every standard field to an empty string, in a fixed order.
        for field in (
            'host',
            'hostname',
            'name',
            'value',
            'port',
            'test_id',
            'severity',
            'qod',
            'uri',
        ):
            setattr(self, field, '')
        # Caller-supplied values win over the defaults above.
        for attr, val in kwargs.items():
            setattr(self, attr, val)
85
86
87
class ScanTestCase(unittest.TestCase):
88
    def setUp(self):
        """Create a fresh DummyWrapper daemon backed by fake storage."""
        daemon = DummyWrapper([])
        collection = daemon.scan_collection
        collection.datamanager = FakeDataManager()
        collection.file_storage_dir = '/tmp'
        self.daemon = daemon
92
93
    def test_get_default_scanner_params(self):
        """<get_scanner_details/> answers 200 and lists scanner_params."""
        stream = FakeStream()
        self.daemon.handle_command('<get_scanner_details />', stream)
        resp = stream.get_response()

        # A successful command is reported with status 200.
        self.assertEqual(resp.get('status'), '200')
        # The root element is named after the command.
        self.assertEqual(resp.tag, 'get_scanner_details_response')
        # The scanner parameters are included in the reply.
        self.assertIsNotNone(resp.find('scanner_params'))
105
106
    def test_get_default_help(self):
        """<help/> succeeds both in the default and in the XML format."""
        plain_stream = FakeStream()
        self.daemon.handle_command('<help />', plain_stream)
        self.assertEqual(plain_stream.get_response().get('status'), '200')

        xml_stream = FakeStream()
        self.daemon.handle_command('<help format="xml" />', xml_stream)
        xml_resp = xml_stream.get_response()

        self.assertEqual(xml_resp.get('status'), '200')
        self.assertEqual(xml_resp.tag, 'help_response')
119
120
    def test_get_default_scanner_version(self):
        """<get_version/> succeeds and includes a protocol element."""
        stream = FakeStream()
        self.daemon.handle_command('<get_version />', stream)
        resp = stream.get_response()

        self.assertEqual(resp.get('status'), '200')
        self.assertIsNotNone(resp.find('protocol'))
127
128
    def test_get_vts_no_vt(self):
        """<get_vts/> without registered VTs still returns a vts element."""
        stream = FakeStream()
        self.daemon.handle_command('<get_vts />', stream)
        resp = stream.get_response()

        self.assertEqual(resp.get('status'), '200')
        self.assertIsNotNone(resp.find('vts'))
136
137
    def test_get_vt_xml_no_dict(self):
        """A VT tuple without a data dict yields an element without an id."""
        vt_xml = self.daemon.get_vt_xml(('1234', None))
        self.assertFalse(vt_xml.get('id'))
141
142
    def test_get_vts_single_vt(self):
        """A single registered VT is returned with its id."""
        stream = FakeStream()
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test')
        self.daemon.handle_command('<get_vts />', stream)
        resp = stream.get_response()

        self.assertEqual(resp.get('status'), '200')

        vt_elem = resp.find('vts').find('vt')
        self.assertIsNotNone(vt_elem)
        self.assertEqual(vt_elem.get('id'), '1.2.3.4')
155
156
    def test_get_vts_version(self):
        """The vts element carries the feed version alongside the VTs."""
        stream = FakeStream()
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test')
        self.daemon.set_vts_version('today')
        self.daemon.handle_command('<get_vts />', stream)
        resp = stream.get_response()

        self.assertEqual(resp.get('status'), '200')

        vts_elem = resp.find('vts')
        self.assertEqual(
            vts_elem.attrib['vts_version'], self.daemon.get_vts_version()
        )

        vt_elem = vts_elem.find('vt')
        self.assertIsNotNone(vt_elem)
        self.assertEqual(vt_elem.get('id'), '1.2.3.4')
173
174
    def test_get_vts_version_only(self):
        """version_only="1" reports the feed version but omits the VTs."""
        stream = FakeStream()
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test')
        self.daemon.set_vts_version('today')
        self.daemon.handle_command('<get_vts version_only="1"/>', stream)
        resp = stream.get_response()

        self.assertEqual(resp.get('status'), '200')

        vts_elem = resp.find('vts')
        self.assertEqual(
            vts_elem.attrib['vts_version'], self.daemon.get_vts_version()
        )
        self.assertIsNone(vts_elem.find('vt'))
188
189
    def test_get_vts_still_not_init(self):
        """<get_vts/> is rejected with 400 while the daemon is uninitialized."""
        self.daemon.initialized = False
        stream = FakeStream()
        self.daemon.handle_command('<get_vts />', stream)

        self.assertEqual(stream.get_response().get('status'), '400')
196
197
    def test_get_help_still_not_init(self):
        """<help/> keeps working while the daemon is uninitialized."""
        self.daemon.initialized = False
        stream = FakeStream()
        self.daemon.handle_command('<help/>', stream)

        self.assertEqual(stream.get_response().get('status'), '200')
204
205 View Code Duplication
    def test_get_vts_filter_positive(self):
        """A 'greater than' modification_time filter keeps a newer VT."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_modification_time='19000202',
        )
        stream = FakeStream()
        request = '<get_vts filter="modification_time&gt;19000201"></get_vts>'

        self.daemon.handle_command(request, stream)
        resp = stream.get_response()

        self.assertEqual(resp.get('status'), '200')

        vt_elem = resp.find('vts').find('vt')
        self.assertIsNotNone(vt_elem)
        self.assertEqual(vt_elem.get('id'), '1.2.3.4')

        mtimes = resp.findall('vts/vt/modification_time')
        self.assertEqual(
            '<modification_time>19000202</modification_time>',
            ET.tostring(mtimes[0]).decode('utf-8'),
        )
231
232 View Code Duplication
    def test_get_vts_filter_negative(self):
        """A 'less than' modification_time filter keeps an older VT."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_modification_time='19000202',
        )
        stream = FakeStream()
        request = '<get_vts filter="modification_time&lt;19000203"></get_vts>'
        self.daemon.handle_command(request, stream)
        resp = stream.get_response()

        self.assertEqual(resp.get('status'), '200')

        vt_elem = resp.find('vts').find('vt')
        self.assertIsNotNone(vt_elem)
        self.assertEqual(vt_elem.get('id'), '1.2.3.4')

        mtimes = resp.findall('vts/vt/modification_time')
        self.assertEqual(
            '<modification_time>19000202</modification_time>',
            ET.tostring(mtimes[0]).decode('utf-8'),
        )
259
260
    def test_get_vts_bad_filter(self):
        """A malformed filter raises and leaves the VT cache available."""
        stream = FakeStream()
        bad_request = '<get_vts filter="modification_time"/>'

        with self.assertRaises(OspdCommandError):
            self.daemon.handle_command(bad_request, stream)
        self.assertTrue(self.daemon.vts.is_cache_available)
266
267
    def test_get_vtss_multiple_vts(self):
        """A plain <get_vts/> returns every registered VT."""
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test')
        self.daemon.add_vt('1.2.3.5', 'Another vulnerability test')
        self.daemon.add_vt('123456789', 'Yet another vulnerability test')

        fs = FakeStream()

        self.daemon.handle_command('<get_vts />', fs)
        response = fs.get_response()
        self.assertEqual(response.get('status'), '200')

        vts = response.find('vts')
        self.assertIsNotNone(vts.find('vt'))

        # Checking only for "at least one <vt>" would still pass if two of
        # the three registered VTs were silently dropped; pin the exact ids.
        vt_ids = {vt.get('id') for vt in vts.findall('vt')}
        self.assertEqual(vt_ids, {'1.2.3.4', '1.2.3.5', '123456789'})
280
281
    def test_get_vts_multiple_vts_with_custom(self):
        """Each of several VTs registered with custom data has a custom element."""
        vt_specs = [
            ('1.2.3.4', 'A vulnerability test'),
            ('4.3.2.1', 'Another vulnerability test with custom info'),
            ('123456789', 'Yet another vulnerability test'),
        ]
        for vt_id, vt_name in vt_specs:
            self.daemon.add_vt(vt_id, vt_name, custom='b')
        stream = FakeStream()

        self.daemon.handle_command('<get_vts />', stream)
        customs = stream.get_response().findall('vts/vt/custom')

        self.assertEqual(3, len(customs))
297
298 View Code Duplication
    def test_get_vts_vts_with_params(self):
        """A VT registered with parameters exposes them via get_vts."""
        self.daemon.add_vt(
            '1.2.3.4', 'A vulnerability test', vt_params="a", custom="b"
        )
        fs = FakeStream()

        self.daemon.handle_command('<get_vts vt_id="1.2.3.4"></get_vts>', fs)
        response = fs.get_response()

        # The status of the response must be success (i.e. 200)
        self.assertEqual(response.get('status'), '200')

        # The response root element must have the correct name
        self.assertEqual(response.tag, 'get_vts_response')
        # The response must contain a 'vts' element
        self.assertIsNotNone(response.find('vts'))

        # Look the VT up by path instead of by child position so the test
        # cannot silently inspect the wrong element if the layout changes.
        vt = response.find('vts/vt')
        self.assertIsNotNone(vt)

        vt_params = vt.findall('params')
        self.assertEqual(1, len(vt_params))

        custom = vt.findall('custom')
        self.assertEqual(1, len(custom))

        params = response.findall('vts/vt/params/param')
        self.assertEqual(2, len(params))
323
324 View Code Duplication
    def test_get_vts_vts_with_refs(self):
        """A VT registered with references exposes two ref elements."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            vt_refs="c",
        )
        stream = FakeStream()
        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(request, stream)
        resp = stream.get_response()

        # A successful command is reported with status 200.
        self.assertEqual(resp.get('status'), '200')
        # The root element is named after the command.
        self.assertEqual(resp.tag, 'get_vts_response')
        # The reply wraps the VTs in a 'vts' element.
        self.assertIsNotNone(resp.find('vts'))

        first_vt = resp[0][0]
        self.assertEqual(1, len(first_vt.findall('params')))
        self.assertEqual(1, len(first_vt.findall('custom')))

        self.assertEqual(2, len(resp.findall('vts/vt/refs/ref')))
354
355
    def test_get_vts_vts_with_dependencies(self):
        """A VT registered with dependencies exposes two dependency elements."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            vt_dependencies="c",
        )
        stream = FakeStream()
        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(request, stream)

        dependencies = stream.get_response().findall(
            'vts/vt/dependencies/dependency'
        )
        self.assertEqual(2, len(dependencies))
371
372
    def test_get_vts_vts_with_severities(self):
        """A VT registered with severities exposes one severity element."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            severities="c",
        )
        stream = FakeStream()
        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(request, stream)

        severities = stream.get_response().findall(
            'vts/vt/severities/severity'
        )
        self.assertEqual(1, len(severities))
387
388
    def test_get_vts_vts_with_detection_qodt(self):
        """Detection info with a QoD type yields one detection element."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            detection="c",
            qod_t="d",
        )
        stream = FakeStream()
        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(request, stream)

        detections = stream.get_response().findall('vts/vt/detection')
        self.assertEqual(1, len(detections))
404
405
    def test_get_vts_vts_with_detection_qodv(self):
        """Detection info with a QoD value yields one detection element."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            detection="c",
            qod_v="d",
        )
        stream = FakeStream()
        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(request, stream)

        detections = stream.get_response().findall('vts/vt/detection')
        self.assertEqual(1, len(detections))
421
422
    def test_get_vts_vts_with_summary(self):
        """A VT registered with a summary exposes one summary element."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            summary="c",
        )
        stream = FakeStream()
        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(request, stream)

        summaries = stream.get_response().findall('vts/vt/summary')
        self.assertEqual(1, len(summaries))
437
438
    def test_get_vts_vts_with_impact(self):
        """A VT registered with impact info exposes one impact element."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            impact="c",
        )
        stream = FakeStream()
        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(request, stream)

        impacts = stream.get_response().findall('vts/vt/impact')
        self.assertEqual(1, len(impacts))
453
454
    def test_get_vts_vts_with_affected(self):
        """A VT registered with affected info exposes one affected element."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            affected="c",
        )
        stream = FakeStream()
        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(request, stream)

        affected = stream.get_response().findall('vts/vt/affected')
        self.assertEqual(1, len(affected))
469
470
    def test_get_vts_vts_with_insight(self):
        """A VT registered with insight info exposes one insight element."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            insight="c",
        )
        stream = FakeStream()
        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(request, stream)

        insights = stream.get_response().findall('vts/vt/insight')
        self.assertEqual(1, len(insights))
485
486
    def test_get_vts_vts_with_solution(self):
        """A VT registered with solution data exposes one solution element."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            solution="c",
            solution_t="d",
            solution_m="e",
        )
        stream = FakeStream()
        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(request, stream)

        solutions = stream.get_response().findall('vts/vt/solution')
        self.assertEqual(1, len(solutions))
503
504
    def test_get_vts_vts_with_ctime(self):
        """The VT creation time is echoed verbatim in creation_time."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_creation_time='01-01-1900',
        )
        stream = FakeStream()
        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(request, stream)

        ctimes = stream.get_response().findall('vts/vt/creation_time')
        rendered = ET.tostring(ctimes[0]).decode('utf-8')
        self.assertEqual('<creation_time>01-01-1900</creation_time>', rendered)
521
522
    def test_get_vts_vts_with_mtime(self):
        """The VT modification time is echoed verbatim in modification_time."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_modification_time='02-01-1900',
        )
        stream = FakeStream()
        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(request, stream)

        mtimes = stream.get_response().findall('vts/vt/modification_time')
        rendered = ET.tostring(mtimes[0]).decode('utf-8')
        self.assertEqual(
            '<modification_time>02-01-1900</modification_time>', rendered
        )
539
540
    def test_clean_forgotten_scans(self):
        """A finished scan is removed only once scaninfo_store_time expires."""
        fs = FakeStream()

        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, '
            '443"><scanner_params /></start_scan>',
            fs,
        )
        response = fs.get_response()

        scan_id = response.findtext('id')

        self.daemon.start_queued_scans()

        # Poll until the scan reports an end time. The wait is bounded so a
        # scan that never finishes fails the test instead of hanging the
        # whole suite. (The previous version also issued a second get_scans
        # per iteration whose response was never read; it has been dropped.)
        deadline = time.monotonic() + 60
        while True:
            fs = FakeStream()
            self.daemon.handle_command(
                '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
            )
            response = fs.get_response()

            scans = response.findall('scan')
            self.assertEqual(1, len(scans))

            if scans[0].get('end_time') != '0':
                break
            if time.monotonic() > deadline:
                self.fail('scan %s did not finish in time' % scan_id)
            time.sleep(0.01)

        self.assertEqual(
            len(list(self.daemon.scan_collection.ids_iterator())), 1
        )

        # Set an old end_time
        self.daemon.scan_collection.scans_table[scan_id]['end_time'] = 123456
        # Run the check
        self.daemon.clean_forgotten_scans()
        # Not removed
        self.assertEqual(
            len(list(self.daemon.scan_collection.ids_iterator())), 1
        )

        # Set the max time and run again
        self.daemon.scaninfo_store_time = 1
        self.daemon.clean_forgotten_scans()
        # Now is removed
        self.assertEqual(
            len(list(self.daemon.scan_collection.ids_iterator())), 0
        )
598
599
    def test_scan_with_error(self):
        """A scan error added during a scan shows up in the scan results."""
        fs = FakeStream()

        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, '
            '443"><scanner_params /></start_scan>',
            fs,
        )

        response = fs.get_response()
        scan_id = response.findtext('id')
        self.daemon.start_queued_scans()
        self.daemon.add_scan_error(
            scan_id, host='a', value='something went wrong'
        )

        # Poll until the scan leaves the init/running states. The wait is
        # bounded so a stuck scan fails the test instead of hanging it. The
        # response of the final poll is used below; the previous version
        # issued a second, otherwise unused get_scans per iteration for that.
        deadline = time.monotonic() + 60
        while True:
            fs = FakeStream()
            self.daemon.handle_command(
                '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
            )
            response = fs.get_response()

            scans = response.findall('scan')
            self.assertEqual(1, len(scans))

            scan = scans[0]
            if scan.get('status') not in ('init', 'running'):
                break
            # An active scan must not have an end time yet.
            self.assertEqual('0', scan.get('end_time'))
            if time.monotonic() > deadline:
                self.fail('scan %s did not finish in time' % scan_id)
            time.sleep(0.010)

        self.assertEqual(
            response.findtext('scan/results/result'), 'something went wrong'
        )
        fs = FakeStream()
        self.daemon.handle_command('<delete_scan scan_id="%s" />' % scan_id, fs)
        response = fs.get_response()

        self.assertEqual(response.get('status'), '200')
650
651
    def test_get_scan_pop(self):
        """A host detail is returned until it has been popped."""
        start_stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /></start_scan>',
            start_stream,
        )
        self.daemon.start_queued_scans()
        scan_id = start_stream.get_response().findtext('id')
        self.daemon.add_scan_host_detail(
            scan_id, host='a', value='Some Host Detail'
        )

        time.sleep(1)

        def first_result(command):
            # Run a command and return the text of the first result, if any.
            stream = FakeStream()
            self.daemon.handle_command(command, stream)
            return stream.get_response().findtext('scan/results/result')

        # A plain query leaves the result in place.
        self.assertEqual(
            first_result('<get_scans scan_id="%s"/>' % scan_id),
            'Some Host Detail',
        )
        # Popping returns the result one more time ...
        self.assertEqual(
            first_result('<get_scans scan_id="%s" pop_results="1"/>' % scan_id),
            'Some Host Detail',
        )
        # ... after which nothing is left.
        self.assertIsNone(
            first_result(
                '<get_scans scan_id="%s" details="0" pop_results="1"/>'
                % scan_id
            )
        )
694
695
    def test_get_scan_pop_max_res(self):
        """max_results limits how many results a single pop returns."""
        start_stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /></start_scan>',
            start_stream,
        )
        self.daemon.start_queued_scans()
        scan_id = start_stream.get_response().findtext('id')

        for host in ('a', 'c', 'b'):
            self.daemon.add_scan_log(scan_id, host=host, name=host)

        limited = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" pop_results="1" max_results="1"/>'
            % scan_id,
            limited,
        )
        limited_results = limited.get_response().findall('scan/results/result')
        self.assertEqual(len(limited_results), 1)

        rest = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" pop_results="1"/>' % scan_id, rest
        )
        remaining = rest.get_response().findall('scan/results/result')
        self.assertEqual(len(remaining), 2)
727
728 View Code Duplication
    def test_get_scan_results_clean(self):
        """Popping results empties both the result and temp result stores."""
        start_stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /></start_scan>',
            start_stream,
        )
        self.daemon.start_queued_scans()
        scan_id = start_stream.get_response().findtext('id')

        for host in ('a', 'c', 'b'):
            self.daemon.add_scan_log(scan_id, host=host, name=host)

        pop_stream = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" pop_results="1"/>' % scan_id,
            pop_stream,
        )

        scans_table = self.daemon.scan_collection.scans_table
        self.assertEqual(len(scans_table[scan_id]['results']), 0)
        self.assertEqual(len(scans_table[scan_id]['temp_results']), 0)
758
759 View Code Duplication
    def test_get_scan_results_restore(self):
        """Popped results are restored when the stream reports a failure."""
        start_stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /></start_scan>',
            start_stream,
        )
        self.daemon.start_queued_scans()
        scan_id = start_stream.get_response().findtext('id')

        for host in ('a', 'c', 'b'):
            self.daemon.add_scan_log(scan_id, host=host, name=host)

        # This stream is configured with return_value=False.
        failing_stream = FakeStream(return_value=False)
        self.daemon.handle_command(
            '<get_scans scan_id="%s" pop_results="1"/>' % scan_id,
            failing_stream,
        )

        scans_table = self.daemon.scan_collection.scans_table
        self.assertEqual(len(scans_table[scan_id]['results']), 3)
        self.assertEqual(len(scans_table[scan_id]['temp_results']), 0)
789
790
    def test_billon_laughs(self):
        """An entity-expansion ("billion laughs") payload is rejected."""
        # pylint: disable=line-too-long

        payload = (
            '<?xml version="1.0"?>'
            '<!DOCTYPE lolz ['
            ' <!ENTITY lol "lol">'
            ' <!ELEMENT lolz (#PCDATA)>'
            ' <!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">'
            ' <!ENTITY lol2 "&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;">'
            ' <!ENTITY lol3 "&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;">'
            ' <!ENTITY lol4 "&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;">'
            ' <!ENTITY lol5 "&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;">'
            ' <!ENTITY lol6 "&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;">'
            ' <!ENTITY lol7 "&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;">'
            ' <!ENTITY lol8 "&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;">'
            ' <!ENTITY lol9 "&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;">'
            ']>'
        )
        stream = FakeStream()
        with self.assertRaises(EntitiesForbidden):
            self.daemon.handle_command(payload, stream)
813
814
    def test_target_with_credentials(self):
        """Per-service credentials of a target are stored for the scan."""
        start_cmd = (
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>192.168.0.0/24</hosts><ports>22'
            '</ports><credentials>'
            '<credential type="up" service="ssh" port="22">'
            '<username>scanuser</username>'
            '<password>mypass</password>'
            '</credential><credential type="up" service="smb">'
            '<username>smbuser</username>'
            '<password>mypass</password></credential>'
            '</credentials>'
            '</target></targets>'
            '</start_scan>'
        )
        stream = FakeStream()
        self.daemon.handle_command(start_cmd, stream)
        self.daemon.start_queued_scans()
        resp = stream.get_response()

        self.assertEqual(resp.get('status'), '200')

        expected_credentials = {
            'ssh': {
                'type': 'up',
                'password': 'mypass',
                'port': '22',
                'username': 'scanuser',
            },
            'smb': {'type': 'up', 'password': 'mypass', 'username': 'smbuser'},
        }
        scan_id = resp.findtext('id')
        self.assertEqual(
            self.daemon.get_scan_credentials(scan_id), expected_credentials
        )
851
852
    def test_target_with_credential_empty_community(self):
        """An empty SNMP community is stored as an empty string."""
        start_cmd = (
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>192.168.0.0/24</hosts><ports>22'
            '</ports><credentials>'
            '<credential type="up" service="snmp">'
            '<community></community></credential>'
            '</credentials>'
            '</target></targets>'
            '</start_scan>'
        )
        stream = FakeStream()
        self.daemon.handle_command(start_cmd, stream)
        self.daemon.start_queued_scans()
        resp = stream.get_response()

        self.assertEqual(resp.get('status'), '200')

        expected_credentials = {'snmp': {'type': 'up', 'community': ''}}
        scan_id = resp.findtext('id')
        self.assertEqual(
            self.daemon.get_scan_credentials(scan_id), expected_credentials
        )
879
880
    def test_scan_get_target(self):
        """get_scans reports the target hosts string exactly as submitted."""
        start_stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>localhosts,192.168.0.0/24</hosts>'
            '<ports>80,443</ports>'
            '</target></targets>'
            '</start_scan>',
            start_stream,
        )
        self.daemon.start_queued_scans()
        scan_id = start_stream.get_response().findtext('id')

        query_stream = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s"/>' % scan_id, query_stream
        )
        scan_elem = query_stream.get_response().find('scan')
        self.assertEqual(scan_elem.get('target'), 'localhosts,192.168.0.0/24')
904
905
    def test_scan_get_target_options(self):
        """A target's <alive_test> value shows up in its target options."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets>'
            '<target><hosts>192.168.0.1</hosts>'
            '<ports>22</ports><alive_test>0</alive_test></target>'
            '</targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()

        scan_id = stream.get_response().findtext('id')
        time.sleep(1)

        options = self.daemon.get_scan_target_options(scan_id)
        self.assertEqual(options, {'alive_test': '0'})
926
927
    def test_scan_get_target_options_alive_test_methods(self):
        """The children of <alive_test_methods> appear in the target options
        alongside an 'alive_test_methods' entry, all with the value '1'."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets>'
            '<target><hosts>192.168.0.1</hosts>'
            '<ports>22</ports>'
            '<alive_test_methods>'
            '<icmp>1</icmp>'
            '<tcp_syn>1</tcp_syn>'
            '<tcp_ack>1</tcp_ack>'
            '<arp>1</arp>'
            '<consider_alive>1</consider_alive>'
            '</alive_test_methods>'
            '</target>'
            '</targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()

        scan_id = stream.get_response().findtext('id')
        time.sleep(1)

        option_names = (
            'alive_test_methods',
            'icmp',
            'tcp_syn',
            'tcp_ack',
            'arp',
            'consider_alive',
        )
        expected_options = dict.fromkeys(option_names, '1')
        self.assertEqual(
            self.daemon.get_scan_target_options(scan_id), expected_options
        )
966
967
    def test_scan_get_target_options_alive_test_methods_dont_add_empty_or_missing(  # pylint: disable=line-too-long
        self,
    ):
        """Alive-test methods given as empty elements (<arp>,
        <consider_alive>) or not given at all are absent from the options."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets>'
            '<target><hosts>192.168.0.1</hosts>'
            '<ports>22</ports>'
            '<alive_test_methods>'
            '<icmp>1</icmp>'
            '<arp></arp>'
            '<consider_alive></consider_alive>'
            '</alive_test_methods>'
            '</target>'
            '</targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()

        scan_id = stream.get_response().findtext('id')
        time.sleep(1)

        expected_options = dict(alive_test_methods='1', icmp='1')
        self.assertEqual(
            self.daemon.get_scan_target_options(scan_id), expected_options
        )
1002
1003
    def test_progress(self):
        """Two hosts at 75 and 25 give an overall target progress of 50."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        for host, percent in (('localhost1', 75), ('localhost2', 25)):
            self.daemon.set_scan_host_progress(scan_id, host, percent)

        progress = self.daemon.scan_collection.calculate_target_progress(
            scan_id
        )
        self.assertEqual(progress, 50)
1026
1027
    def test_progress_all_host_dead(self):
        """With every host marked dead (-1) and sorted out as finished, the
        target progress is 100."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        dead_hosts = ['localhost1', 'localhost2']
        for host in dead_hosts:
            self.daemon.set_scan_host_progress(scan_id, host, -1)

        self.daemon.sort_host_finished(scan_id, dead_hosts)
        progress = self.daemon.scan_collection.calculate_target_progress(
            scan_id
        )
        self.assertEqual(progress, 100)
1051
1052
    @patch('ospd.ospd.os')
    def test_interrupted_scan(self, mock_os):
        """After exec_scan is replaced by a no-op Mock and two hosts report
        partial progress, polling get_scans until the status leaves INIT
        ends with the scan reported as 'interrupted'."""
        mock_os.setsid.return_value = None
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()

        scan_id = stream.get_response().findtext('id')

        self.daemon.exec_scan = Mock(return_value=None)
        self.daemon.set_scan_host_progress(scan_id, 'localhost1', 5)
        self.daemon.set_scan_host_progress(scan_id, 'localhost2', 14)

        # Poll get_scans while the status is still INIT. The stream from the
        # last poll is the one whose response is inspected below.
        # NOTE(review): if the status is not INIT on entry, no poll happens
        # and `stream` still holds the start_scan reply, whose find('scan')
        # is None. Confirm this path cannot occur.
        poll_cmd = (
            '<get_scans scan_id="%s" details="0" progress="0"/>' % scan_id
        )
        while self.daemon.get_scan_status(scan_id) == ScanStatus.INIT:
            stream = FakeStream()
            self.daemon.handle_command(poll_cmd, stream)

        scan_status = stream.get_response().find('scan').attrib['status']
        self.assertEqual(scan_status, ScanStatus.INTERRUPTED.name.lower())
1084
1085
    def test_sort_host_finished(self):
        """After localhost3 (dead, -1) and localhost4 (100) are sorted out as
        finished, with localhost1 at 75 and localhost2 at 25, the target
        progress is expected to be 66."""
        fs = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            fs,
        )
        self.daemon.start_queued_scans()

        response = fs.get_response()

        scan_id = response.findtext('id')
        self.daemon.set_scan_host_progress(scan_id, 'localhost3', -1)
        self.daemon.set_scan_host_progress(scan_id, 'localhost1', 75)
        self.daemon.set_scan_host_progress(scan_id, 'localhost4', 100)
        self.daemon.set_scan_host_progress(scan_id, 'localhost2', 25)

        self.daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4'])

        # The shorter local name keeps this call within the line limit. The
        # previous pragma ('line-too-long)', with a stray ')') is no longer
        # needed.
        progress = self.daemon.scan_collection.calculate_target_progress(
            scan_id
        )
        self.assertEqual(progress, 66)
1114
1115
    def test_set_status_interrupted(self):
        """interrupt_scan() gives a scan whose end time was 0 a non-zero
        end time."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        self.assertEqual(self.daemon.scan_collection.get_end_time(scan_id), 0)

        self.daemon.interrupt_scan(scan_id)
        self.assertNotEqual(
            self.daemon.scan_collection.get_end_time(scan_id), 0
        )
1137
1138
    def test_set_status_stopped(self):
        """Setting the status to STOPPED gives a scan whose end time was 0 a
        non-zero end time."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        self.assertEqual(self.daemon.scan_collection.get_end_time(scan_id), 0)

        self.daemon.set_scan_status(scan_id, ScanStatus.STOPPED)
        self.assertNotEqual(
            self.daemon.scan_collection.get_end_time(scan_id), 0
        )
1160
1161
    def test_calculate_progress_without_current_hosts(self):
        """Only localhost3 (-1) and localhost4 (100) receive a progress value
        and are sorted out as finished. The computed progress truncates to
        33, and storing it makes get_scan_progress() return 33."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        self.daemon.set_scan_host_progress(scan_id)
        self.daemon.set_scan_host_progress(scan_id, 'localhost3', -1)
        self.daemon.set_scan_host_progress(scan_id, 'localhost4', 100)
        self.daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4'])

        raw_progress = self.daemon.scan_collection.calculate_target_progress(
            scan_id
        )
        self.assertEqual(int(raw_progress), 33)

        self.daemon.scan_collection.set_progress(scan_id, raw_progress)
        self.assertEqual(self.daemon.get_scan_progress(scan_id), 33)
1192
1193
    def test_get_scan_without_scanid(self):
        """A get_scans command without a scan_id attribute raises
        OspdCommandError."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()

        query_stream = FakeStream()
        with self.assertRaises(OspdCommandError):
            self.daemon.handle_command(
                '<get_scans details="0" progress="1"/>', query_stream
            )
1215
1216
    def test_set_scan_total_hosts(self):
        """set_scan_total_hosts() replaces the host count taken from the
        target (4 hosts) with the given value."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        def total_hosts():
            return self.daemon.scan_collection.get_count_total(scan_id)

        self.assertEqual(total_hosts(), 4)

        self.daemon.set_scan_total_hosts(scan_id, 3)
        self.assertEqual(total_hosts(), 3)
1240
1241
    def test_set_scan_total_hosts_invalid_target(self):
        """A total host count of -1 is stored as 0."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        def total_hosts():
            return self.daemon.scan_collection.get_count_total(scan_id)

        self.assertEqual(total_hosts(), 4)

        # The server sets the total host count to -1 when the target is
        # invalid.
        self.daemon.set_scan_total_hosts(scan_id, -1)
        self.assertEqual(total_hosts(), 0)
1266
1267
    def test_get_scan_progress_xml(self):
        """get_scans with progress="1" reports the overall progress (66),
        alive/dead/excluded counters (1/1/0) and two <host> entries."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        # First finish one dead and one complete host...
        self.daemon.set_scan_host_progress(scan_id, 'localhost3', -1)
        self.daemon.set_scan_host_progress(scan_id, 'localhost4', 100)
        self.daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4'])

        # ...then leave two hosts with partial progress.
        self.daemon.set_scan_host_progress(scan_id, 'localhost1', 75)
        self.daemon.set_scan_host_progress(scan_id, 'localhost2', 25)

        stream = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" details="0" progress="1"/>' % scan_id,
            stream,
        )
        progress_elem = stream.get_response().find('scan/progress')

        self.assertEqual(int(float(progress_elem.findtext('overall'))), 66)

        expected_counters = {
            'count_alive': '1',
            'count_dead': '1',
            'count_excluded': '0',
        }
        for tag, expected in expected_counters.items():
            self.assertEqual(progress_elem.findtext(tag), expected)

        self.assertEqual(len(progress_elem.findall('host')), 2)
1315
1316
    def test_set_get_vts_version(self):
        """A VTs version set via set_vts_version() is returned unchanged."""
        expected_version = '1234'
        self.daemon.set_vts_version(expected_version)
        self.assertEqual(self.daemon.get_vts_version(), expected_version)
1321
1322
    def test_set_get_vts_version_error(self):
        """Calling set_vts_version() without an argument raises TypeError."""
        setter = self.daemon.set_vts_version
        self.assertRaises(TypeError, setter)
1324
1325
    @patch("ospd.ospd.os")
    @patch("ospd.ospd.create_process")
    def test_scan_exists(self, mock_create_process, _mock_os):
        """A new start_scan reusing the id of a stopped scan is answered
        with status_text 'Continue'; the first start_scan gets 'OK'."""
        fake_start = FakeStartProcess()
        mock_create_process.side_effect = fake_start

        scan_process = fake_start.call_mock
        scan_process.start.side_effect = fake_start.run
        scan_process.is_alive.return_value = True
        scan_process.pid = "main-scan-process"

        # Everything after the opening <start_scan ...> tag is shared by
        # both start requests below.
        scan_body = (
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>'
        )

        stream = FakeStream()
        self.daemon.handle_command('<start_scan>' + scan_body, stream)
        reply = stream.get_response()
        scan_id = reply.findtext('id')
        self.assertIsNotNone(scan_id)
        self.assertEqual(reply.get('status_text'), 'OK')

        self.daemon.start_queued_scans()

        assert_called(mock_create_process)
        assert_called(scan_process.start)

        self.daemon.handle_command(
            '<stop_scan scan_id="%s" />' % scan_id, stream
        )

        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan scan_id="' + scan_id + '">' + scan_body, stream
        )
        self.daemon.start_queued_scans()

        reply = stream.get_response()
        self.assertEqual(reply.get('status_text'), 'Continue')
1380
1381
    def test_result_order(self):
        """Scan logs added for hosts a, c, b come back from get_scans
        (details="1") in exactly that order."""
        fs = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="1">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>a</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            fs,
        )
        self.daemon.start_queued_scans()
        response = fs.get_response()

        scan_id = response.findtext('id')

        self.daemon.add_scan_log(scan_id, host='a', name='a')
        self.daemon.add_scan_log(scan_id, host='c', name='c')
        self.daemon.add_scan_log(scan_id, host='b', name='b')
        hosts = ['a', 'c', 'b']

        fs = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
        )
        response = fs.get_response()

        # A trailing '/' in an ElementPath means '/*': all children of
        # <results>.
        results = response.findall("scan/results/")

        # Compare the whole list. The former per-index loop passed
        # vacuously when fewer results (or none) were returned.
        result_names = [res.attrib['name'] for res in results]
        self.assertEqual(result_names, hosts)
1415
1416
    def test_batch_result(self):
        """Results added as one ResultList (a, c, b) come back from
        get_scans (details="1") in insertion order."""
        reslist = ResultList()
        fs = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="1">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>a</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            fs,
        )
        self.daemon.start_queued_scans()
        response = fs.get_response()

        scan_id = response.findtext('id')
        reslist.add_scan_log_to_list(host='a', name='a')
        reslist.add_scan_log_to_list(host='c', name='c')
        reslist.add_scan_log_to_list(host='b', name='b')
        self.daemon.scan_collection.add_result_list(scan_id, reslist)

        hosts = ['a', 'c', 'b']

        fs = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
        )
        response = fs.get_response()

        # A trailing '/' in an ElementPath means '/*': all children of
        # <results>.
        results = response.findall("scan/results/")

        # Compare the whole list. The former per-index loop passed
        # vacuously when fewer results (or none) were returned.
        result_names = [res.attrib['name'] for res in results]
        self.assertEqual(result_names, hosts)
1451
1452
    def test_is_new_scan_allowed_false(self):
        """With two entries in scan_processes and max_scans set to 1,
        is_new_scan_allowed() returns False."""
        # scan_processes is a public attribute, so the former
        # protected-access suppression was unnecessary.
        self.daemon.scan_processes = {
            'a': 1,
            'b': 2,
        }
        self.daemon.max_scans = 1

        self.assertFalse(self.daemon.is_new_scan_allowed())
1460
1461
    def test_is_new_scan_allowed_true(self):
        """With two entries in scan_processes and max_scans set to 3,
        is_new_scan_allowed() returns True."""
        # scan_processes is a public attribute, so the former
        # protected-access suppression was unnecessary.
        self.daemon.scan_processes = {
            'a': 1,
            'b': 2,
        }
        self.daemon.max_scans = 3

        self.assertTrue(self.daemon.is_new_scan_allowed())
1469
1470
    def test_start_queue_scan_daemon_not_init(self):
        """With queued scans pending and the daemon not initialized,
        start_queued_scans() logs that a feed update is being performed."""
        self.daemon.get_count_queued_scans = MagicMock(return_value=10)
        self.daemon.initialized = False

        # Patch Logger.info only for the duration of this call. Assigning
        # `logging.Logger.info = Mock()` directly replaced the method on the
        # global Logger class and never restored it, which affected every
        # test that ran afterwards in the same process.
        with patch.object(logging.Logger, 'info') as mock_info:
            self.daemon.start_queued_scans()

        mock_info.assert_called_with(
            "Queued task can not be started because a "
            "feed update is being performed."
        )
1480
1481
    @patch("ospd.ospd.psutil")
    def test_free_memory_true(self, mock_psutil):
        """1.5 GB of available memory satisfies a min_free_mem_scan_queue
        of 1000 (presumably MB)."""
        available_bytes = 1500000000  # 1.5 GB free
        mock_psutil.virtual_memory.return_value = FakePsutil(
            available=available_bytes
        )
        self.daemon.min_free_mem_scan_queue = 1000

        self.assertTrue(self.daemon.is_enough_free_memory())
1490
1491
    @patch("ospd.ospd.psutil")
    def test_free_memory_false(self, mock_psutil):
        """1.5 GB of available memory does not satisfy a
        min_free_mem_scan_queue of 2000 (presumably MB)."""
        available_bytes = 1500000000  # 1.5 GB free
        mock_psutil.virtual_memory.return_value = FakePsutil(
            available=available_bytes
        )
        self.daemon.min_free_mem_scan_queue = 2000

        self.assertFalse(self.daemon.is_enough_free_memory())
1500
1501
    def test_count_queued_scans(self):
        """A submitted scan is counted as queued until start_queued_scans()
        runs, after which the queued count drops to 0."""
        stream = FakeStream()
        start_cmd = (
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>localhosts,192.168.0.0/24</hosts>'
            '<ports>80,443</ports>'
            '</target></targets>'
            '</start_scan>'
        )
        self.daemon.handle_command(start_cmd, stream)

        queued_before = self.daemon.get_count_queued_scans()
        self.assertEqual(queued_before, 1)

        self.daemon.start_queued_scans()

        queued_after = self.daemon.get_count_queued_scans()
        self.assertEqual(queued_after, 0)
1518
1519
    def test_ids_iterator_dict_modified(self):
        """Adding an entry to scans_table inside an ids_iterator() loop
        completes without error and leaves three entries."""
        collection = self.daemon.scan_collection
        collection.scans_table = {'a': 1, 'b': 2}

        for _ in collection.ids_iterator():
            collection.scans_table['c'] = 3

        self.assertEqual(len(collection.scans_table), 3)
1526