Passed
Push — master ( eefdf0...e0b344 )
by Juan José
01:39 queued 11s
created

ScanTestCase.test_wait_between_scan_allow()   A

Complexity

Conditions 1

Size

Total Lines 29
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 12
nop 2
dl 0
loc 29
rs 9.8
c 0
b 0
f 0
1
# Copyright (C) 2014-2021 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: AGPL-3.0-or-later
4
#
5
# This program is free software: you can redistribute it and/or modify
6
# it under the terms of the GNU Affero General Public License as
7
# published by the Free Software Foundation, either version 3 of the
8
# License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU Affero General Public License for more details.
14
#
15
# You should have received a copy of the GNU Affero General Public License
16
# along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
# pylint: disable=too-many-lines
19
20
""" Test module for scan runs
21
"""
22
23
import time
24
import unittest
25
26
from unittest.mock import patch, MagicMock, Mock
27
28
import logging
29
import xml.etree.ElementTree as ET
30
31
from defusedxml.common import EntitiesForbidden
32
33
from ospd.resultlist import ResultList
34
from ospd.errors import OspdCommandError
35
from ospd.scan import ScanStatus
36
37
from .helper import (
38
    DummyWrapper,
39
    assert_called,
40
    FakeStream,
41
    FakeDataManager,
42
    FakePsutil,
43
)
44
45
46
class FakeStartProcess:
    """Stand-in for a process launcher that records what it would start.

    Calling the instance stores the target callable plus its positional and
    keyword arguments and hands back ``call_mock``. ``run()`` then invokes
    the stored callable synchronously and hands back ``run_mock``.
    """

    def __init__(self):
        self.run_mock, self.call_mock = MagicMock(), MagicMock()
        # Nothing has been "started" yet.
        self.func = self.args = self.kwargs = None

    def __call__(self, func, *, args=None, kwargs=None):
        self.func = func
        # Falsy (missing or empty) arguments are normalised to empties.
        self.args = [] if not args else args
        self.kwargs = {} if not kwargs else kwargs
        return self.call_mock

    def run(self):
        target = self.func
        target(*self.args, **self.kwargs)
        return self.run_mock

    def __repr__(self):
        template = "<FakeProcess func={} args={} kwargs={}>"
        return template.format(self.func, self.args, self.kwargs)
69
70
71
class Result:
    """Minimal result record used to feed fake scan results.

    ``result_type`` holds the positional ``type_`` argument. All other known
    fields start out as empty strings; every keyword argument then sets (or
    adds) the attribute of the same name.
    """

    def __init__(self, type_, **kwargs):
        self.result_type = type_
        for field in (
            'host',
            'hostname',
            'name',
            'value',
            'port',
            'test_id',
            'severity',
            'qod',
            'uri',
        ):
            setattr(self, field, '')
        for attr, val in kwargs.items():
            setattr(self, attr, val)
85
86
87
class ScanTestCase(unittest.TestCase):
88
    def setUp(self):
        """Create a fresh DummyWrapper daemon backed by fake storage."""
        daemon = DummyWrapper([])
        collection = daemon.scan_collection
        collection.datamanager = FakeDataManager()
        collection.file_storage_dir = '/tmp'
        self.daemon = daemon
92
93
    def test_get_default_scanner_params(self):
        """get_scanner_details succeeds and includes <scanner_params>."""
        stream = FakeStream()
        self.daemon.handle_command('<get_scanner_details />', stream)
        resp = stream.get_response()

        # Success status, correctly named root, and a scanner_params child.
        self.assertEqual(resp.get('status'), '200')
        self.assertEqual(resp.tag, 'get_scanner_details_response')
        self.assertIsNotNone(resp.find('scanner_params'))
105
106
    def test_get_default_help(self):
        """Both the default and the XML help formats answer with 200."""
        plain = FakeStream()
        self.daemon.handle_command('<help />', plain)
        self.assertEqual(plain.get_response().get('status'), '200')

        as_xml = FakeStream()
        self.daemon.handle_command('<help format="xml" />', as_xml)
        xml_resp = as_xml.get_response()

        self.assertEqual(xml_resp.get('status'), '200')
        self.assertEqual(xml_resp.tag, 'help_response')
119
120
    def test_get_default_scanner_version(self):
        """get_version succeeds and reports a <protocol> element."""
        stream = FakeStream()
        self.daemon.handle_command('<get_version />', stream)
        resp = stream.get_response()

        self.assertEqual(resp.get('status'), '200')
        self.assertIsNotNone(resp.find('protocol'))
127
128
    def test_get_vts_no_vt(self):
        """get_vts without any added VT still succeeds and returns <vts>."""
        stream = FakeStream()
        self.daemon.handle_command('<get_vts />', stream)
        resp = stream.get_response()

        self.assertEqual(resp.get('status'), '200')
        self.assertIsNotNone(resp.find('vts'))
136
137
    def test_get_vt_xml_no_dict(self):
        """A VT tuple whose details are None yields an element without id."""
        vt_xml = self.daemon.get_vt_xml(('1234', None))
        self.assertFalse(vt_xml.get('id'))
141
142
    def test_get_vts_single_vt(self):
        """A single registered VT is listed under <vts> with its id."""
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test')
        stream = FakeStream()
        self.daemon.handle_command('<get_vts />', stream)
        resp = stream.get_response()

        self.assertEqual(resp.get('status'), '200')

        first_vt = resp.find('vts').find('vt')
        self.assertIsNotNone(first_vt)
        self.assertEqual(first_vt.get('id'), '1.2.3.4')
155
156
    def test_get_vts_version(self):
        """get_vts carries the feed version attribute plus the VT list."""
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test')
        self.daemon.set_vts_version('today')
        stream = FakeStream()
        self.daemon.handle_command('<get_vts />', stream)
        resp = stream.get_response()

        self.assertEqual(resp.get('status'), '200')

        vts_elem = resp.find('vts')
        reported = vts_elem.attrib['vts_version']
        self.assertEqual(reported, self.daemon.get_vts_version())

        first_vt = vts_elem.find('vt')
        self.assertIsNotNone(first_vt)
        self.assertEqual(first_vt.get('id'), '1.2.3.4')
173
174
    def test_get_vts_version_only(self):
        """version_only="1" reports the feed version but omits the VTs."""
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test')
        self.daemon.set_vts_version('today')
        stream = FakeStream()
        self.daemon.handle_command('<get_vts version_only="1"/>', stream)
        resp = stream.get_response()

        self.assertEqual(resp.get('status'), '200')

        vts_elem = resp.find('vts')
        reported = vts_elem.attrib['vts_version']
        self.assertEqual(reported, self.daemon.get_vts_version())
        self.assertIsNone(vts_elem.find('vt'))
188
189
    def test_get_vts_still_not_init(self):
        """get_vts answers 400 while the daemon is not initialized."""
        self.daemon.initialized = False
        stream = FakeStream()
        self.daemon.handle_command('<get_vts />', stream)

        self.assertEqual(stream.get_response().get('status'), '400')
196
197
    def test_get_help_still_not_init(self):
        """help is still served (200) while the daemon is not initialized."""
        self.daemon.initialized = False
        stream = FakeStream()
        self.daemon.handle_command('<help/>', stream)

        self.assertEqual(stream.get_response().get('status'), '200')
204
205 View Code Duplication
    def test_get_vts_filter_positive(self):
        """A ``modification_time>`` filter keeps a VT modified later."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_modification_time='19000202',
        )
        stream = FakeStream()
        command = '<get_vts filter="modification_time&gt;19000201"></get_vts>'
        self.daemon.handle_command(command, stream)
        resp = stream.get_response()

        self.assertEqual(resp.get('status'), '200')

        matched = resp.find('vts').find('vt')
        self.assertIsNotNone(matched)
        self.assertEqual(matched.get('id'), '1.2.3.4')

        mtimes = resp.findall('vts/vt/modification_time')
        self.assertEqual(
            '<modification_time>19000202</modification_time>',
            ET.tostring(mtimes[0]).decode('utf-8'),
        )
231
232 View Code Duplication
    def test_get_vts_filter_negative(self):
        """A ``modification_time<`` filter keeps a VT modified earlier."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_modification_time='19000202',
        )
        stream = FakeStream()
        command = '<get_vts filter="modification_time&lt;19000203"></get_vts>'
        self.daemon.handle_command(command, stream)
        resp = stream.get_response()

        self.assertEqual(resp.get('status'), '200')

        matched = resp.find('vts').find('vt')
        self.assertIsNotNone(matched)
        self.assertEqual(matched.get('id'), '1.2.3.4')

        mtimes = resp.findall('vts/vt/modification_time')
        self.assertEqual(
            '<modification_time>19000202</modification_time>',
            ET.tostring(mtimes[0]).decode('utf-8'),
        )
259
260
    def test_get_vts_bad_filter(self):
        """A filter without operator raises and leaves the VT cache usable."""
        stream = FakeStream()
        bad_cmd = '<get_vts filter="modification_time"/>'

        with self.assertRaises(OspdCommandError):
            self.daemon.handle_command(bad_cmd, stream)

        self.assertTrue(self.daemon.vts.is_cache_available)
266
267
    def test_get_vtss_multiple_vts(self):
        """get_vts without a filter lists every registered VT.

        The test registers three VTs. It checks that the response succeeds
        and that the ids of all three VTs appear as ``<vt id="...">`` entries
        under ``<vts>``.
        """
        vt_ids = ['1.2.3.4', '1.2.3.5', '123456789']
        self.daemon.add_vt(vt_ids[0], 'A vulnerability test')
        self.daemon.add_vt(vt_ids[1], 'Another vulnerability test')
        self.daemon.add_vt(vt_ids[2], 'Yet another vulnerability test')

        fs = FakeStream()

        self.daemon.handle_command('<get_vts />', fs)
        response = fs.get_response()
        self.assertEqual(response.get('status'), '200')

        vts = response.find('vts')
        self.assertIsNotNone(vts.find('vt'))

        # Checking for "some" <vt> would pass even if only one of the three
        # VTs was returned; compare the full id set instead.
        returned_ids = sorted(vt.get('id') for vt in vts.findall('vt'))
        self.assertEqual(returned_ids, sorted(vt_ids))
280
281
    def test_get_vts_multiple_vts_with_custom(self):
        """Each of three VTs with custom data yields one <custom> element."""
        for vt_id, vt_name in (
            ('1.2.3.4', 'A vulnerability test'),
            ('4.3.2.1', 'Another vulnerability test with custom info'),
            ('123456789', 'Yet another vulnerability test'),
        ):
            self.daemon.add_vt(vt_id, vt_name, custom='b')

        stream = FakeStream()
        self.daemon.handle_command('<get_vts />', stream)
        customs = stream.get_response().findall('vts/vt/custom')

        self.assertEqual(3, len(customs))
297
298 View Code Duplication
    def test_get_vts_vts_with_params(self):
        """A VT registered with params and custom data exposes both.

        Requests the single VT by id and checks for:
        - a successful ``get_vts_response``;
        - one ``<params>`` and one ``<custom>`` element on the VT;
        - two ``<param>`` entries under ``vts/vt/params``.
        """
        self.daemon.add_vt(
            '1.2.3.4', 'A vulnerability test', vt_params="a", custom="b"
        )
        fs = FakeStream()

        self.daemon.handle_command('<get_vts vt_id="1.2.3.4"></get_vts>', fs)
        response = fs.get_response()

        # The status of the response must be success (i.e. 200)
        self.assertEqual(response.get('status'), '200')

        # The response root element must have the correct name
        self.assertEqual(response.tag, 'get_vts_response')
        # The response must contain a 'vts' element (the old comment wrongly
        # said 'scanner_params', which is not what is asserted here)
        self.assertIsNotNone(response.find('vts'))

        # response[0][0] is presumably the first <vt> inside <vts>
        vt_params = response[0][0].findall('params')
        self.assertEqual(1, len(vt_params))

        custom = response[0][0].findall('custom')
        self.assertEqual(1, len(custom))

        params = response.findall('vts/vt/params/param')
        self.assertEqual(2, len(params))
323
324 View Code Duplication
    def test_get_vts_vts_with_refs(self):
        """References attached to a VT are rendered as two <ref> elements."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            vt_refs="c",
        )
        stream = FakeStream()
        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(query, stream)
        resp = stream.get_response()

        # Success, correct root name, and a <vts> container.
        self.assertEqual(resp.get('status'), '200')
        self.assertEqual(resp.tag, 'get_vts_response')
        self.assertIsNotNone(resp.find('vts'))

        first_vt = resp[0][0]
        self.assertEqual(1, len(first_vt.findall('params')))
        self.assertEqual(1, len(first_vt.findall('custom')))

        self.assertEqual(2, len(resp.findall('vts/vt/refs/ref')))
354
355
    def test_get_vts_vts_with_dependencies(self):
        """VT dependencies are rendered as two <dependency> elements."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            vt_dependencies="c",
        )
        stream = FakeStream()
        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(query, stream)
        resp = stream.get_response()

        found = resp.findall('vts/vt/dependencies/dependency')
        self.assertEqual(len(found), 2)
371
372
    def test_get_vts_vts_with_severities(self):
        """VT severities are rendered as exactly one <severity> element."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            severities="c",
        )
        stream = FakeStream()
        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(query, stream)
        resp = stream.get_response()

        found = resp.findall('vts/vt/severities/severity')
        self.assertEqual(len(found), 1)
387
388
    def test_get_vts_vts_with_detection_qodt(self):
        """Detection info with a QoD type yields one <detection> element."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            detection="c",
            qod_t="d",
        )
        stream = FakeStream()
        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(query, stream)
        resp = stream.get_response()

        found = resp.findall('vts/vt/detection')
        self.assertEqual(len(found), 1)
404
405
    def test_get_vts_vts_with_detection_qodv(self):
        """Detection info with a QoD value yields one <detection> element."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            detection="c",
            qod_v="d",
        )
        stream = FakeStream()
        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(query, stream)
        resp = stream.get_response()

        found = resp.findall('vts/vt/detection')
        self.assertEqual(len(found), 1)
421
422
    def test_get_vts_vts_with_summary(self):
        """A VT summary is rendered as exactly one <summary> element."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            summary="c",
        )
        stream = FakeStream()
        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(query, stream)
        resp = stream.get_response()

        found = resp.findall('vts/vt/summary')
        self.assertEqual(len(found), 1)
437
438
    def test_get_vts_vts_with_impact(self):
        """A VT impact text is rendered as exactly one <impact> element."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            impact="c",
        )
        stream = FakeStream()
        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(query, stream)
        resp = stream.get_response()

        found = resp.findall('vts/vt/impact')
        self.assertEqual(len(found), 1)
453
454
    def test_get_vts_vts_with_affected(self):
        """A VT 'affected' text is rendered as exactly one <affected>."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            affected="c",
        )
        stream = FakeStream()
        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(query, stream)
        resp = stream.get_response()

        found = resp.findall('vts/vt/affected')
        self.assertEqual(len(found), 1)
469
470
    def test_get_vts_vts_with_insight(self):
        """A VT insight text is rendered as exactly one <insight> element."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            insight="c",
        )
        stream = FakeStream()
        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(query, stream)
        resp = stream.get_response()

        found = resp.findall('vts/vt/insight')
        self.assertEqual(len(found), 1)
485
486
    def test_get_vts_vts_with_solution(self):
        """A VT solution (with type and method) yields one <solution>."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            solution="c",
            solution_t="d",
            solution_m="e",
        )
        stream = FakeStream()
        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(query, stream)
        resp = stream.get_response()

        found = resp.findall('vts/vt/solution')
        self.assertEqual(len(found), 1)
503
504
    def test_get_vts_vts_with_ctime(self):
        """The VT creation time is serialized verbatim."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_creation_time='01-01-1900',
        )
        stream = FakeStream()
        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(query, stream)
        created = stream.get_response().findall('vts/vt/creation_time')

        rendered = ET.tostring(created[0]).decode('utf-8')
        self.assertEqual('<creation_time>01-01-1900</creation_time>', rendered)
521
522
    def test_get_vts_vts_with_mtime(self):
        """The VT modification time is serialized verbatim."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_modification_time='02-01-1900',
        )
        stream = FakeStream()
        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        self.daemon.handle_command(query, stream)
        modified = stream.get_response().findall('vts/vt/modification_time')

        rendered = ET.tostring(modified[0]).decode('utf-8')
        self.assertEqual(
            '<modification_time>02-01-1900</modification_time>', rendered
        )
539
540
    def test_clean_forgotten_scans(self):
        """Finished scans are only purged once ``scaninfo_store_time`` allows.

        Steps:
        - start a scan and poll get_scans until it reports a non-zero
          ``end_time``;
        - backdate that end time and call clean_forgotten_scans(): the scan
          is kept at first;
        - set ``scaninfo_store_time`` to 1 and call it again: the scan is
          removed.
        """
        fs = FakeStream()

        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, '
            '443"><scanner_params /></start_scan>',
            fs,
        )
        response = fs.get_response()

        scan_id = response.findtext('id')

        self.daemon.start_queued_scans()

        # Poll until the scan reports an end time. The wait is bounded so a
        # scan that never finishes fails this test instead of hanging the
        # whole suite. (The previous version also sent a second get_scans
        # per iteration whose response was never read; it has been dropped.)
        deadline = time.monotonic() + 60
        while True:
            fs = FakeStream()
            self.daemon.handle_command(
                '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
            )
            response = fs.get_response()

            scans = response.findall('scan')
            self.assertEqual(1, len(scans))

            if scans[0].get('end_time') != '0':
                break
            if time.monotonic() > deadline:
                self.fail('scan %s did not finish in time' % scan_id)
            time.sleep(0.01)

        self.assertEqual(
            len(list(self.daemon.scan_collection.ids_iterator())), 1
        )

        # Set an old end_time
        self.daemon.scan_collection.scans_table[scan_id]['end_time'] = 123456
        # Run the check
        self.daemon.clean_forgotten_scans()
        # Not removed yet
        self.assertEqual(
            len(list(self.daemon.scan_collection.ids_iterator())), 1
        )

        # Set the max store time and run again
        self.daemon.scaninfo_store_time = 1
        self.daemon.clean_forgotten_scans()
        # Now it is removed
        self.assertEqual(
            len(list(self.daemon.scan_collection.ids_iterator())), 0
        )
598
599
    def test_scan_with_error(self):
        """A scan error added via add_scan_error is reported as a result.

        Steps:
        - start a scan and record an error for host ``a``;
        - poll get_scans until the status leaves ``init``/``running``;
        - check that the error text is returned as the scan result;
        - delete the scan and expect status 200.
        """
        fs = FakeStream()

        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, '
            '443"><scanner_params /></start_scan>',
            fs,
        )

        response = fs.get_response()
        scan_id = response.findtext('id')
        self.daemon.start_queued_scans()
        self.daemon.add_scan_error(
            scan_id, host='a', value='something went wrong'
        )

        # Bounded polling: a scan stuck in init/running fails the test
        # instead of hanging the suite. The extra get_scans that the old loop
        # sent after every check was redundant and has been removed.
        deadline = time.monotonic() + 60
        while True:
            fs = FakeStream()
            self.daemon.handle_command(
                '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
            )
            response = fs.get_response()

            scans = response.findall('scan')
            self.assertEqual(1, len(scans))

            scan = scans[0]
            if scan.get('status') not in ("init", "running"):
                break

            # An active scan must not have an end time yet.
            self.assertEqual('0', scan.get('end_time'))
            if time.monotonic() > deadline:
                self.fail('scan %s did not finish in time' % scan_id)
            time.sleep(0.010)

        # ``response`` is the get_scans reply that saw the finished scan.
        self.assertEqual(
            response.findtext('scan/results/result'), 'something went wrong'
        )
        fs = FakeStream()
        self.daemon.handle_command('<delete_scan scan_id="%s" />' % scan_id, fs)
        response = fs.get_response()

        self.assertEqual(response.get('status'), '200')
650
651
    def test_get_scan_pop(self):
        """Results survive a plain get_scans and are drained by popping."""
        start = FakeStream()
        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /></start_scan>',
            start,
        )
        self.daemon.start_queued_scans()
        scan_id = start.get_response().findtext('id')
        self.daemon.add_scan_host_detail(
            scan_id, host='a', value='Some Host Detail'
        )

        time.sleep(1)

        def first_result(command):
            # Send one command and return the text of its first result.
            stream = FakeStream()
            self.daemon.handle_command(command, stream)
            return stream.get_response().findtext('scan/results/result')

        # A plain request returns the result.
        self.assertEqual(
            first_result('<get_scans scan_id="%s"/>' % scan_id),
            'Some Host Detail',
        )
        # The first pop still returns it...
        self.assertEqual(
            first_result('<get_scans scan_id="%s" pop_results="1"/>' % scan_id),
            'Some Host Detail',
        )
        # ...and a second pop finds nothing left.
        self.assertIsNone(
            first_result(
                '<get_scans scan_id="%s" details="0" pop_results="1"/>'
                % scan_id
            )
        )
694
695
    def test_get_scan_pop_max_res(self):
        """max_results caps a pop; the next pop returns the remainder."""
        start = FakeStream()
        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /></start_scan>',
            start,
        )
        self.daemon.start_queued_scans()
        scan_id = start.get_response().findtext('id')

        for label in ('a', 'c', 'b'):
            self.daemon.add_scan_log(scan_id, host=label, name=label)

        limited = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" pop_results="1" max_results="1"/>'
            % scan_id,
            limited,
        )
        limited_results = limited.get_response().findall('scan/results/result')
        self.assertEqual(len(limited_results), 1)

        rest = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" pop_results="1"/>' % scan_id, rest
        )
        rest_results = rest.get_response().findall('scan/results/result')
        self.assertEqual(len(rest_results), 2)
727
728 View Code Duplication
    def test_get_scan_results_clean(self):
        """After a successful pop, results and temp_results are both empty."""
        start = FakeStream()
        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /></start_scan>',
            start,
        )
        self.daemon.start_queued_scans()
        scan_id = start.get_response().findtext('id')

        for label in ('a', 'c', 'b'):
            self.daemon.add_scan_log(scan_id, host=label, name=label)

        self.daemon.handle_command(
            '<get_scans scan_id="%s" pop_results="1"/>' % scan_id,
            FakeStream(),
        )

        table = self.daemon.scan_collection.scans_table
        self.assertEqual(len(table[scan_id]['results']), 0)
        self.assertEqual(len(table[scan_id]['temp_results']), 0)
758
759 View Code Duplication
    def test_get_scan_results_restore(self):
        """If the stream reports failure, popped results are restored."""
        start = FakeStream()
        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /></start_scan>',
            start,
        )
        self.daemon.start_queued_scans()
        scan_id = start.get_response().findtext('id')

        for label in ('a', 'c', 'b'):
            self.daemon.add_scan_log(scan_id, host=label, name=label)

        # A stream constructed with return_value=False stands in for a
        # failed delivery.
        self.daemon.handle_command(
            '<get_scans scan_id="%s" pop_results="1"/>' % scan_id,
            FakeStream(return_value=False),
        )

        table = self.daemon.scan_collection.scans_table
        self.assertEqual(len(table[scan_id]['results']), 3)
        self.assertEqual(len(table[scan_id]['temp_results']), 0)
789
790
    def test_billon_laughs(self):
        """A 'billion laughs' entity-expansion payload is rejected."""
        # Entities lol1..lol9 each reference the previous level ten times;
        # level 1 references the base entity "lol".
        nested_entities = ''.join(
            ' <!ENTITY lol%d "%s">'
            % (level, ('&lol%s;' % (level - 1 or '')) * 10)
            for level in range(1, 10)
        )
        lol = (
            '<?xml version="1.0"?>'
            '<!DOCTYPE lolz ['
            ' <!ENTITY lol "lol">'
            ' <!ELEMENT lolz (#PCDATA)>' + nested_entities + ']>'
        )
        stream = FakeStream()

        with self.assertRaises(EntitiesForbidden):
            self.daemon.handle_command(lol, stream)
813
814
    def test_target_with_credentials(self):
        """Target credentials are parsed into a per-service dictionary."""
        command = (
            '<start_scan>'
            '<scanner_params />'
            '<vts><vt id="1.2.3.4" /></vts>'
            '<targets><target>'
            '<hosts>192.168.0.0/24</hosts>'
            '<ports>22</ports>'
            '<credentials>'
            '<credential type="up" service="ssh" port="22">'
            '<username>scanuser</username>'
            '<password>mypass</password>'
            '</credential>'
            '<credential type="up" service="smb">'
            '<username>smbuser</username>'
            '<password>mypass</password>'
            '</credential>'
            '</credentials>'
            '</target></targets>'
            '</start_scan>'
        )
        stream = FakeStream()
        self.daemon.handle_command(command, stream)
        self.daemon.start_queued_scans()
        started = stream.get_response()

        self.assertEqual(started.get('status'), '200')

        expected = dict(
            ssh={
                'type': 'up',
                'password': 'mypass',
                'port': '22',
                'username': 'scanuser',
            },
            smb={'type': 'up', 'password': 'mypass', 'username': 'smbuser'},
        )
        scan_id = started.findtext('id')
        self.assertEqual(self.daemon.get_scan_credentials(scan_id), expected)
851
852
    def test_target_with_credential_empty_community(self):
        """An empty SNMP community is kept as an empty string."""
        command = (
            '<start_scan>'
            '<scanner_params />'
            '<vts><vt id="1.2.3.4" /></vts>'
            '<targets><target>'
            '<hosts>192.168.0.0/24</hosts>'
            '<ports>22</ports>'
            '<credentials>'
            '<credential type="up" service="snmp">'
            '<community></community>'
            '</credential>'
            '</credentials>'
            '</target></targets>'
            '</start_scan>'
        )
        stream = FakeStream()
        self.daemon.handle_command(command, stream)
        self.daemon.start_queued_scans()
        started = stream.get_response()

        self.assertEqual(started.get('status'), '200')

        expected = {'snmp': {'type': 'up', 'community': ''}}
        scan_id = started.findtext('id')
        self.assertEqual(self.daemon.get_scan_credentials(scan_id), expected)
879
880
    def test_scan_get_target(self):
        """get_scans reports the target host string exactly as submitted."""
        start_stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>localhosts,192.168.0.0/24</hosts>'
            '<ports>80,443</ports>'
            '</target></targets>'
            '</start_scan>',
            start_stream,
        )
        self.daemon.start_queued_scans()

        scan_id = start_stream.get_response().findtext('id')

        query_stream = FakeStream()
        self.daemon.handle_command(
            f'<get_scans scan_id="{scan_id}"/>', query_stream
        )
        scan_elem = query_stream.get_response().find('scan')

        self.assertEqual(scan_elem.get('target'), 'localhosts,192.168.0.0/24')
904
905
    def test_scan_get_target_options(self):
        """The <alive_test> element is exposed through the target options."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets>'
            '<target><hosts>192.168.0.1</hosts>'
            '<ports>22</ports><alive_test>0</alive_test></target>'
            '</targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()

        scan_id = stream.get_response().findtext('id')
        time.sleep(1)

        self.assertEqual(
            self.daemon.get_scan_target_options(scan_id), {'alive_test': '0'}
        )
926
927
    def test_scan_get_target_options_alive_test_methods(self):
        """Each <alive_test_methods> child set to 1 becomes a target option.

        The options also carry an ``alive_test_methods`` entry of '1'.
        """
        expected_options = {
            'alive_test_methods': '1',
            'icmp': '1',
            'tcp_syn': '1',
            'tcp_ack': '1',
            'arp': '1',
            'consider_alive': '1',
        }
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets>'
            '<target><hosts>192.168.0.1</hosts>'
            '<ports>22</ports>'
            '<alive_test_methods>'
            '<icmp>1</icmp>'
            '<tcp_syn>1</tcp_syn>'
            '<tcp_ack>1</tcp_ack>'
            '<arp>1</arp>'
            '<consider_alive>1</consider_alive>'
            '</alive_test_methods>'
            '</target>'
            '</targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()

        scan_id = stream.get_response().findtext('id')
        time.sleep(1)

        actual_options = self.daemon.get_scan_target_options(scan_id)
        self.assertEqual(actual_options, expected_options)
966
967
    def test_scan_get_target_options_alive_test_methods_dont_add_empty_or_missing(  # pylint: disable=line-too-long
        self,
    ):
        """Empty or absent alive-test methods do not appear in the options.

        Only ``icmp`` is set to 1; ``arp`` and ``consider_alive`` are empty,
        and ``tcp_syn``/``tcp_ack`` are missing entirely.
        """
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets>'
            '<target><hosts>192.168.0.1</hosts>'
            '<ports>22</ports>'
            '<alive_test_methods>'
            '<icmp>1</icmp>'
            '<arp></arp>'
            '<consider_alive></consider_alive>'
            '</alive_test_methods>'
            '</target>'
            '</targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()

        scan_id = stream.get_response().findtext('id')
        time.sleep(1)

        self.assertEqual(
            self.daemon.get_scan_target_options(scan_id),
            {'alive_test_methods': '1', 'icmp': '1'},
        )
1002
1003
    def test_progress(self):
        """Two hosts at 75 and 25 give a target progress of 50."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        for host, value in (('localhost1', 75), ('localhost2', 25)):
            self.daemon.set_scan_host_progress(scan_id, host, value)

        collection = self.daemon.scan_collection
        self.assertEqual(collection.calculate_target_progress(scan_id), 50)
1026
1027
    def test_progress_all_host_dead(self):
        """A target whose hosts are all dead (-1) and finished reaches 100."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        dead_hosts = ['localhost1', 'localhost2']
        for host in dead_hosts:
            self.daemon.set_scan_host_progress(scan_id, host, -1)
        self.daemon.sort_host_finished(scan_id, dead_hosts)

        collection = self.daemon.scan_collection
        self.assertEqual(collection.calculate_target_progress(scan_id), 100)
1051
1052
    @patch('ospd.ospd.os')
    def test_interrupted_scan(self, mock_os):
        """A scan that stops progressing is reported as ``interrupted``.

        ``exec_scan`` is swapped for a no-op mock after the scan is started,
        then ``get_scans`` is polled until the scan status leaves INIT and
        the status attribute of the last reply is checked.
        """
        mock_os.setsid.return_value = None
        fs = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            fs,
        )
        self.daemon.start_queued_scans()

        response = fs.get_response()
        scan_id = response.findtext('id')

        self.daemon.exec_scan = Mock(return_value=None)
        self.daemon.set_scan_host_progress(scan_id, 'localhost1', 5)
        self.daemon.set_scan_host_progress(scan_id, 'localhost2', 14)

        # Always issue at least one <get_scans> so ``fs`` holds its reply
        # (a pure while-loop could skip entirely and leave the <start_scan>
        # reply, which has no <scan> child). The deadline turns a status
        # that never changes into a test failure instead of a hung run.
        deadline = time.time() + 30
        while True:
            fs = FakeStream()
            self.daemon.handle_command(
                '<get_scans scan_id="%s" details="0" progress="0"/>' % scan_id,
                fs,
            )
            if self.daemon.get_scan_status(scan_id) != ScanStatus.INIT:
                break
            if time.time() > deadline:
                self.fail('scan %s never left the INIT status' % scan_id)

        response = fs.get_response()
        status = response.find('scan').attrib['status']

        self.assertEqual(status, ScanStatus.INTERRUPTED.name.lower())
1084
1085
    def test_sort_host_finished(self):
        """Target progress after sorting two of four hosts as finished.

        localhost3 (-1) and localhost4 (100) are sorted as finished while
        localhost1 and localhost2 sit at 75 and 25; the target progress is
        expected to be 66.
        """
        fs = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            fs,
        )
        self.daemon.start_queued_scans()

        response = fs.get_response()

        scan_id = response.findtext('id')
        self.daemon.set_scan_host_progress(scan_id, 'localhost3', -1)
        self.daemon.set_scan_host_progress(scan_id, 'localhost1', 75)
        self.daemon.set_scan_host_progress(scan_id, 'localhost4', 100)
        self.daemon.set_scan_host_progress(scan_id, 'localhost2', 25)

        self.daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4'])

        progress = self.daemon.scan_collection.calculate_target_progress(
            scan_id
        )
        self.assertEqual(progress, 66)
1114
1115
    def test_set_status_interrupted(self):
        """interrupt_scan() sets a non-zero end time on a started scan."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')
        collection = self.daemon.scan_collection

        self.assertEqual(collection.get_end_time(scan_id), 0)

        self.daemon.interrupt_scan(scan_id)
        self.assertNotEqual(collection.get_end_time(scan_id), 0)
1137
1138
    def test_set_status_stopped(self):
        """Setting the STOPPED status gives the scan a non-zero end time."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')
        collection = self.daemon.scan_collection

        self.assertEqual(collection.get_end_time(scan_id), 0)

        self.daemon.set_scan_status(scan_id, ScanStatus.STOPPED)
        self.assertNotEqual(collection.get_end_time(scan_id), 0)
1160
1161
    def test_calculate_progress_without_current_hosts(self):
        """Target progress when only finished hosts have explicit values.

        localhost3 (-1) and localhost4 (100) are sorted as finished and no
        per-host value is set for localhost1/localhost2. The computed value
        truncates to 33 and is read back as 33 after set_progress().
        """
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        self.daemon.set_scan_host_progress(scan_id)
        self.daemon.set_scan_host_progress(scan_id, 'localhost3', -1)
        self.daemon.set_scan_host_progress(scan_id, 'localhost4', 100)
        self.daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4'])

        collection = self.daemon.scan_collection
        target_progress = collection.calculate_target_progress(scan_id)
        self.assertEqual(int(target_progress), 33)

        collection.set_progress(scan_id, target_progress)
        self.assertEqual(self.daemon.get_scan_progress(scan_id), 33)
1192
1193
    def test_get_scan_host_progress(self):
        """A per-host progress value is read back unchanged."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        self.daemon.set_scan_host_progress(scan_id, 'localhost', 45)
        host_progress = self.daemon.get_scan_host_progress(
            scan_id, 'localhost'
        )
        self.assertEqual(host_progress, 45)
1213
1214
    def test_get_scan_without_scanid(self):
        """<get_scans> without a scan_id attribute raises OspdCommandError."""
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            FakeStream(),
        )
        self.daemon.start_queued_scans()

        with self.assertRaises(OspdCommandError):
            self.daemon.handle_command(
                '<get_scans details="0" progress="1"/>', FakeStream()
            )
1236
1237
    def test_set_scan_total_hosts(self):
        """set_scan_total_hosts() replaces the count derived from the target."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')
        get_total = self.daemon.scan_collection.get_count_total

        self.assertEqual(get_total(scan_id), 4)

        self.daemon.set_scan_total_hosts(scan_id, 3)
        self.assertEqual(get_total(scan_id), 3)
1261
1262
    def test_set_scan_total_hosts_zero(self):
        """A server-reported total of 0 is kept rather than recalculated."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')
        get_total = self.daemon.scan_collection.get_count_total

        # Before any override, ospd derives the total from the host list.
        self.assertEqual(get_total(scan_id), 4)

        # A 0 from the server (all hosts unresolved, dead or an invalid
        # target) takes priority and must not be replaced by the count
        # calculated from the host list.
        self.daemon.set_scan_total_hosts(scan_id, 0)
        self.assertEqual(get_total(scan_id), 0)
1290
1291
    def test_set_scan_total_hosts_invalid_target(self):
        """A total of -1 (sent for an invalid target) is stored as 0."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')
        get_total = self.daemon.scan_collection.get_count_total

        self.assertEqual(get_total(scan_id), 4)

        # The server reports -1 as the host total when the target is invalid.
        self.daemon.set_scan_total_hosts(scan_id, -1)
        self.assertEqual(get_total(scan_id), 0)
1316
1317
    def test_scan_invalid_excluded_hosts(self):
        """Excluded hosts outside the target are not counted, and a warning is logged.

        Only the 200 excluded hosts inside 192.168.0.0/24 are counted. The
        10.0.0.0/24 entries, which are not part of the target, lead to the
        warning asserted below.
        """
        # patch.object restores Logger.warning when the block exits.
        # Assigning a Mock to logging.Logger.warning directly would leak the
        # mock into every test that runs afterwards.
        with patch.object(logging.Logger, 'warning') as mock_warning:
            fs = FakeStream()
            self.daemon.handle_command(
                '<start_scan parallel="2">'
                '<scanner_params />'
                '<targets><target>'
                '<hosts>192.168.0.0/24</hosts>'
                '<exclude_hosts>192.168.0.1-192.168.0.200,10.0.0.0/24'
                '</exclude_hosts>'
                '<ports>22</ports>'
                '</target></targets>'
                '</start_scan>',
                fs,
            )
            self.daemon.start_queued_scans()

            response = fs.get_response()
            scan_id = response.findtext('id')

            # Count only the excluded hosts present in the original target.
            collection = self.daemon.scan_collection
            count = collection.get_simplified_exclude_host_count(scan_id)
            self.assertEqual(count, 200)

            mock_warning.assert_called_with(
                "Please check the excluded host list. It contains hosts "
                "which do not belong to the target. This warning can be ignored if "
                "this was done on purpose (e.g. to exclude specific hostname)."
            )
1349
1350
    def test_get_scan_progress_xml(self):
        """The <progress> element of get_scans reflects per-host state.

        localhost3 (-1) and localhost4 (100) are sorted as finished, and
        localhost1/localhost2 are left at 75/25. The XML is expected to show
        66 overall, one alive, one dead, two <host> entries and no
        excluded hosts.
        """
        start_stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            start_stream,
        )
        self.daemon.start_queued_scans()
        scan_id = start_stream.get_response().findtext('id')

        set_progress = self.daemon.set_scan_host_progress
        set_progress(scan_id, 'localhost3', -1)
        set_progress(scan_id, 'localhost4', 100)
        self.daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4'])
        set_progress(scan_id, 'localhost1', 75)
        set_progress(scan_id, 'localhost2', 25)

        query_stream = FakeStream()
        self.daemon.handle_command(
            f'<get_scans scan_id="{scan_id}" details="0" progress="1"/>',
            query_stream,
        )
        progress_elem = query_stream.get_response().find('scan/progress')

        self.assertEqual(int(float(progress_elem.findtext('overall'))), 66)
        self.assertEqual(progress_elem.findtext('count_alive'), '1')
        self.assertEqual(progress_elem.findtext('count_dead'), '1')
        self.assertEqual(len(progress_elem.findall('host')), 2)
        self.assertEqual(progress_elem.findtext('count_excluded'), '0')
1398
1399
    def test_set_get_vts_version(self):
        """A VTs version that has been set is returned unchanged."""
        expected_version = '1234'
        self.daemon.set_vts_version(expected_version)

        self.assertEqual(expected_version, self.daemon.get_vts_version())
1404
1405
    def test_set_get_vts_version_error(self):
        """Calling set_vts_version() without a version raises TypeError."""
        with self.assertRaises(TypeError):
            self.daemon.set_vts_version()  # pylint: disable=no-value-for-parameter
1407
1408
    @patch("ospd.ospd.os")
    @patch("ospd.ospd.create_process")
    def test_scan_exists(self, mock_create_process, _mock_os):
        """Re-submitting a stopped scan's id is answered with 'Continue'."""
        fake_start = FakeStartProcess()
        mock_create_process.side_effect = fake_start
        proc = fake_start.call_mock
        proc.start.side_effect = fake_start.run
        proc.is_alive.return_value = True
        proc.pid = "main-scan-process"

        first_stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan>'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            first_stream,
        )
        first_reply = first_stream.get_response()
        scan_id = first_reply.findtext('id')
        self.assertIsNotNone(scan_id)
        self.assertEqual(first_reply.get('status_text'), 'OK')

        self.daemon.start_queued_scans()

        assert_called(mock_create_process)
        assert_called(proc.start)

        self.daemon.handle_command(
            '<stop_scan scan_id="%s" />' % scan_id, first_stream
        )

        resume_stream = FakeStream()
        resume_cmd = (
            '<start_scan scan_id="' + scan_id + '">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>'
        )
        self.daemon.handle_command(resume_cmd, resume_stream)
        self.daemon.start_queued_scans()

        self.assertEqual(
            resume_stream.get_response().get('status_text'), 'Continue'
        )
1463
1464
    def test_result_order(self):
        """get_scans returns log results in the order they were added."""
        fs = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="1">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>a</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            fs,
        )
        self.daemon.start_queued_scans()
        response = fs.get_response()

        scan_id = response.findtext('id')

        self.daemon.add_scan_log(scan_id, host='a', name='a')
        self.daemon.add_scan_log(scan_id, host='c', name='c')
        self.daemon.add_scan_log(scan_id, host='b', name='b')
        hosts = ['a', 'c', 'b']

        fs = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
        )
        response = fs.get_response()

        results = response.findall("scan/results/")

        # Compare the whole name sequence. Checking each element inside an
        # enumerate() loop would pass vacuously if no results came back.
        self.assertEqual([res.attrib.get('name') for res in results], hosts)
1498
1499
    def test_batch_result(self):
        """A ResultList added in one batch keeps its insertion order."""
        reslist = ResultList()
        fs = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="1">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>a</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            fs,
        )
        self.daemon.start_queued_scans()
        response = fs.get_response()

        scan_id = response.findtext('id')
        reslist.add_scan_log_to_list(host='a', name='a')
        reslist.add_scan_log_to_list(host='c', name='c')
        reslist.add_scan_log_to_list(host='b', name='b')
        self.daemon.scan_collection.add_result_list(scan_id, reslist)

        hosts = ['a', 'c', 'b']

        fs = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
        )
        response = fs.get_response()

        results = response.findall("scan/results/")

        # Compare the whole name sequence. Checking each element inside an
        # enumerate() loop would pass vacuously if no results came back.
        self.assertEqual([res.attrib.get('name') for res in results], hosts)
1534
1535
    def test_is_new_scan_allowed_false(self):
        """No new scan is allowed with two scan processes and max_scans=1."""
        self.daemon.scan_processes = {'a': 1, 'b': 2}
        self.daemon.max_scans = 1

        self.assertFalse(self.daemon.is_new_scan_allowed())
1543
1544
    def test_is_new_scan_allowed_true(self):
        """A new scan is allowed with two scan processes and max_scans=3."""
        self.daemon.scan_processes = {'a': 1, 'b': 2}
        self.daemon.max_scans = 3

        self.assertTrue(self.daemon.is_new_scan_allowed())
1552
1553
    def test_start_queue_scan_daemon_not_init(self):
        """An uninitialized daemon logs instead of starting queued scans.

        With ``initialized`` set to False and 10 scans reported as queued,
        start_queued_scans() is expected to log the feed-update message.
        """
        self.daemon.get_count_queued_scans = MagicMock(return_value=10)
        self.daemon.initialized = False

        # patch.object restores Logger.info when the block exits. Assigning
        # a Mock to logging.Logger.info directly would leak into later tests.
        with patch.object(logging.Logger, 'info') as mock_info:
            self.daemon.start_queued_scans()

        mock_info.assert_called_with(
            "Queued task can not be started because a "
            "feed update is being performed."
        )
        )
1563
1564
    @patch("ospd.ospd.psutil")
    def test_free_memory_true(self, mock_psutil):
        """1.5 GB available satisfies a min_free_mem_scan_queue of 1000."""
        mock_psutil.virtual_memory.return_value = FakePsutil(
            available=1500000000
        )
        self.daemon.min_free_mem_scan_queue = 1000

        self.assertTrue(self.daemon.is_enough_free_memory())
1573
1574
    @patch("ospd.ospd.psutil")
    def test_wait_between_scan_no_scans(self, mock_psutil):
        """With no running scan, a recent scan start does not block a new one."""
        # 1.5 GB reported as available.
        mock_psutil.virtual_memory.return_value = FakePsutil(
            available=1500000000
        )
        # A non-zero minimum enables the free-memory check.
        self.daemon.min_free_mem_scan_queue = 1000
        # Last start only 20 s ago, but nothing is running.
        self.daemon.last_scan_start_time = time.time() - 20

        self.assertTrue(self.daemon.is_enough_free_memory())
1586
1587
    @patch("ospd.ospd.psutil")
    def test_wait_between_scan_run_scans_not_allow(self, mock_psutil):
        """A running scan started 20 s ago blocks the next scan."""
        # A non-zero minimum enables the free-memory check.
        self.daemon.min_free_mem_scan_queue = 1000
        # 1.5 GB reported as available.
        mock_psutil.virtual_memory.return_value = FakePsutil(
            available=1500000000
        )

        self.daemon.handle_command(
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>localhosts,192.168.0.0/24</hosts>'
            '<ports>80,443</ports>'
            '</target></targets>'
            '</start_scan>',
            FakeStream(),
        )
        # Starting the queued scan leaves one scan running.
        self.daemon.start_queued_scans()
        # The previous start is too recent.
        self.daemon.last_scan_start_time = time.time() - 20

        self.assertFalse(self.daemon.is_enough_free_memory())
1616
1617
    @patch("ospd.ospd.psutil")
    def test_wait_between_scan_allow(self, mock_psutil):
        """A running scan started 65 s ago no longer blocks the next scan."""
        # A non-zero minimum enables the free-memory check.
        self.daemon.min_free_mem_scan_queue = 1000
        # 1.5 GB reported as available.
        mock_psutil.virtual_memory.return_value = FakePsutil(
            available=1500000000
        )

        self.daemon.handle_command(
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>localhosts,192.168.0.0/24</hosts>'
            '<ports>80,443</ports>'
            '</target></targets>'
            '</start_scan>',
            FakeStream(),
        )
        # One scan running, enough memory, and enough time since its start.
        self.daemon.start_queued_scans()
        self.daemon.last_scan_start_time = time.time() - 65

        self.assertTrue(self.daemon.is_enough_free_memory())
1646
1647
    @patch("ospd.ospd.psutil")
    def test_free_memory_false(self, mock_psutil):
        """1.5 GB available does not satisfy a min_free_mem_scan_queue of 2000."""
        mock_psutil.virtual_memory.return_value = FakePsutil(
            available=1500000000
        )
        self.daemon.min_free_mem_scan_queue = 2000

        self.assertFalse(self.daemon.is_enough_free_memory())
1656
1657
    def test_count_queued_scans(self):
        """A submitted scan counts as queued until start_queued_scans() runs."""
        self.daemon.handle_command(
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>localhosts,192.168.0.0/24</hosts>'
            '<ports>80,443</ports>'
            '</target></targets>'
            '</start_scan>',
            FakeStream(),
        )
        count_queued = self.daemon.get_count_queued_scans

        self.assertEqual(count_queued(), 1)
        self.daemon.start_queued_scans()
        self.assertEqual(count_queued(), 0)
1674
1675
    def test_count_running_scans(self):
        """A submitted scan counts as running only after it is started."""
        self.daemon.handle_command(
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>localhosts,192.168.0.0/24</hosts>'
            '<ports>80,443</ports>'
            '</target></targets>'
            '</start_scan>',
            FakeStream(),
        )
        count_running = self.daemon.get_count_running_scans

        self.assertEqual(count_running(), 0)
        self.daemon.start_queued_scans()
        self.assertEqual(count_running(), 1)
1692
1693
    def test_ids_iterator_dict_modified(self):
        """Adding a key to scans_table while iterating ids_iterator() must not raise."""
        collection = self.daemon.scan_collection
        collection.scans_table = {'a': 1, 'b': 2}

        for _ in collection.ids_iterator():
            collection.scans_table['c'] = 3

        self.assertEqual(len(collection.scans_table), 3)
1700