Completed
Push — master ( dd3ea0...651e0a )
by Björn
20s queued 12s
created

ScanTestCase.test_get_scan_host_progress()   A

Complexity

Conditions 1

Size

Total Lines 19
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 11
nop 1
dl 0
loc 19
rs 9.85
c 0
b 0
f 0
1
# Copyright (C) 2014-2021 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: AGPL-3.0-or-later
4
#
5
# This program is free software: you can redistribute it and/or modify
6
# it under the terms of the GNU Affero General Public License as
7
# published by the Free Software Foundation, either version 3 of the
8
# License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU Affero General Public License for more details.
14
#
15
# You should have received a copy of the GNU Affero General Public License
16
# along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
# pylint: disable=too-many-lines
19
20
""" Test module for scan runs
21
"""
22
23
import time
24
import unittest
25
26
from unittest.mock import patch, MagicMock, Mock
27
28
import logging
29
import xml.etree.ElementTree as ET
30
31
from defusedxml.common import EntitiesForbidden
32
33
from ospd.resultlist import ResultList
34
from ospd.errors import OspdCommandError
35
from ospd.scan import ScanStatus
36
37
from .helper import (
38
    DummyWrapper,
39
    assert_called,
40
    FakeStream,
41
    FakeDataManager,
42
    FakePsutil,
43
)
44
45
46
class FakeStartProcess:
    """Test double replacing a "start a process" callable.

    Calling an instance records the target function together with its
    positional and keyword arguments and returns ``call_mock`` instead of
    spawning anything.  A later :meth:`run` executes the recorded function
    synchronously in the current process and returns ``run_mock``.
    """

    def __init__(self):
        # Handed back by __call__ and run() respectively, so tests can
        # inspect how the returned "process" objects were used.
        self.run_mock = MagicMock()
        self.call_mock = MagicMock()

        # Filled in by __call__.
        self.func = None
        self.args = None
        self.kwargs = None

    def __call__(self, func, *, args=None, kwargs=None):
        """Remember ``func`` and its arguments; return ``call_mock``.

        Falsy ``args`` / ``kwargs`` (including ``None``) are normalised to
        an empty list / dict.
        """
        self.func = func
        self.args = [] if not args else args
        self.kwargs = {} if not kwargs else kwargs
        return self.call_mock

    def run(self):
        """Call the recorded function with its arguments; return ``run_mock``."""
        target, pos_args, kw_args = self.func, self.args, self.kwargs
        target(*pos_args, **kw_args)
        return self.run_mock

    def __repr__(self):
        template = "<FakeProcess func={} args={} kwargs={}>"
        return template.format(self.func, self.args, self.kwargs)
69
70
71
class Result:
    """Minimal stand-in for a scan result record used by the tests.

    ``type_`` is stored as ``result_type``.  The usual result fields all
    start as empty strings; keyword arguments then override them or add
    further attributes.
    """

    def __init__(self, type_, **kwargs):
        self.result_type = type_
        # Well-known fields, initialised in a fixed order to ''.
        for field in (
            'host',
            'hostname',
            'name',
            'value',
            'port',
            'test_id',
            'severity',
            'qod',
            'uri',
        ):
            setattr(self, field, '')
        # Caller-supplied values win over the defaults above.
        for attr, val in kwargs.items():
            setattr(self, attr, val)
85
86
87
class ScanTestCase(unittest.TestCase):
88
    def setUp(self):
        # Build a fresh dummy daemon for every test and point its scan
        # collection at the helper's fake data manager and at /tmp for
        # file storage.
        daemon = DummyWrapper([])
        collection = daemon.scan_collection
        collection.datamanager = FakeDataManager()
        collection.file_storage_dir = '/tmp'
        self.daemon = daemon
92
93
    def test_get_default_scanner_params(self):
        stream = FakeStream()
        self.daemon.handle_command('<get_scanner_details />', stream)
        resp = stream.get_response()

        # Expect a 200 reply whose root is get_scanner_details_response and
        # which carries a <scanner_params> child.
        self.assertEqual(resp.get('status'), '200')
        self.assertEqual(resp.tag, 'get_scanner_details_response')
        self.assertIsNotNone(resp.find('scanner_params'))
105
106
    def test_get_default_help(self):
        # Help without an explicit format must succeed ...
        default_stream = FakeStream()
        self.daemon.handle_command('<help />', default_stream)
        self.assertEqual(default_stream.get_response().get('status'), '200')

        # ... and so must the XML-formatted variant, with the right root tag.
        xml_stream = FakeStream()
        self.daemon.handle_command('<help format="xml" />', xml_stream)
        xml_resp = xml_stream.get_response()
        self.assertEqual(xml_resp.get('status'), '200')
        self.assertEqual(xml_resp.tag, 'help_response')
119
120
    def test_get_default_scanner_version(self):
        stream = FakeStream()
        self.daemon.handle_command('<get_version />', stream)
        reply = stream.get_response()

        # A successful get_version reply includes a <protocol> element.
        self.assertEqual(reply.get('status'), '200')
        self.assertIsNotNone(reply.find('protocol'))
127
128
    def test_get_vts_no_vt(self):
        # With no VT registered, get_vts still succeeds and returns a
        # <vts> container element.
        stream = FakeStream()
        self.daemon.handle_command('<get_vts />', stream)
        reply = stream.get_response()

        self.assertEqual(reply.get('status'), '200')
        self.assertIsNotNone(reply.find('vts'))
136
137
    def test_get_vt_xml_no_dict(self):
        # A VT entry whose data is None produces an element without an id.
        vt_xml = self.daemon.get_vt_xml(('1234', None))
        self.assertFalse(vt_xml.get('id'))
141
142
    def test_get_vts_single_vt(self):
        stream = FakeStream()
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test')
        self.daemon.handle_command('<get_vts />', stream)
        reply = stream.get_response()
        self.assertEqual(reply.get('status'), '200')

        # The one registered VT is listed under <vts> with its id.
        first_vt = reply.find('vts').find('vt')
        self.assertIsNotNone(first_vt)
        self.assertEqual(first_vt.get('id'), '1.2.3.4')
155
156
    def test_get_vts_version(self):
        stream = FakeStream()
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test')
        self.daemon.set_vts_version('today')
        self.daemon.handle_command('<get_vts />', stream)
        reply = stream.get_response()
        self.assertEqual(reply.get('status'), '200')

        vts_elem = reply.find('vts')
        # The configured feed version is echoed as a <vts> attribute ...
        self.assertEqual(
            vts_elem.attrib['vts_version'], self.daemon.get_vts_version()
        )
        # ... and the VT listing itself is still present.
        vt_elem = vts_elem.find('vt')
        self.assertIsNotNone(vt_elem)
        self.assertEqual(vt_elem.get('id'), '1.2.3.4')
173
174
    def test_get_vts_version_only(self):
        stream = FakeStream()
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test')
        self.daemon.set_vts_version('today')
        self.daemon.handle_command('<get_vts version_only="1"/>', stream)
        reply = stream.get_response()
        self.assertEqual(reply.get('status'), '200')

        vts_elem = reply.find('vts')
        # With version_only="1" the version attribute is reported ...
        self.assertEqual(
            vts_elem.attrib['vts_version'], self.daemon.get_vts_version()
        )
        # ... but no individual <vt> entries are included.
        self.assertIsNone(vts_elem.find('vt'))
188
189
    def test_get_vts_still_not_init(self):
        # While the daemon is flagged as not initialized, get_vts is
        # answered with status 400.
        self.daemon.initialized = False
        stream = FakeStream()
        self.daemon.handle_command('<get_vts />', stream)
        self.assertEqual(stream.get_response().get('status'), '400')
196
197
    def test_get_help_still_not_init(self):
        # Help stays available (status 200) even before initialization.
        self.daemon.initialized = False
        stream = FakeStream()
        self.daemon.handle_command('<help/>', stream)
        self.assertEqual(stream.get_response().get('status'), '200')
204
205 View Code Duplication
    def test_get_vts_filter_positive(self):
        # A VT modified on 19000202 must match "modification_time>19000201"
        # (the '>' is XML-escaped inside the attribute).
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_modification_time='19000202',
        )
        stream = FakeStream()
        self.daemon.handle_command(
            '<get_vts filter="modification_time&gt;19000201"></get_vts>',
            stream,
        )
        reply = stream.get_response()
        self.assertEqual(reply.get('status'), '200')

        matched = reply.find('vts').find('vt')
        self.assertIsNotNone(matched)
        self.assertEqual(matched.get('id'), '1.2.3.4')

        mtimes = reply.findall('vts/vt/modification_time')
        self.assertEqual(
            '<modification_time>19000202</modification_time>',
            ET.tostring(mtimes[0]).decode('utf-8'),
        )
231
232 View Code Duplication
    def test_get_vts_filter_negative(self):
        # Despite the name, this exercises the less-than operator: the VT
        # modified on 19000202 satisfies "modification_time<19000203" and
        # is therefore expected in the reply.
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_modification_time='19000202',
        )
        stream = FakeStream()
        self.daemon.handle_command(
            '<get_vts filter="modification_time&lt;19000203"></get_vts>',
            stream,
        )
        reply = stream.get_response()
        self.assertEqual(reply.get('status'), '200')

        matched = reply.find('vts').find('vt')
        self.assertIsNotNone(matched)
        self.assertEqual(matched.get('id'), '1.2.3.4')

        mtimes = reply.findall('vts/vt/modification_time')
        self.assertEqual(
            '<modification_time>19000202</modification_time>',
            ET.tostring(mtimes[0]).decode('utf-8'),
        )
259
260
    def test_get_vts_bad_filter(self):
        # A filter lacking an operator and value is rejected with
        # OspdCommandError, and afterwards the VT cache must still be
        # reported as available.
        stream = FakeStream()
        command = '<get_vts filter="modification_time"/>'
        with self.assertRaises(OspdCommandError):
            self.daemon.handle_command(command, stream)
        self.assertTrue(self.daemon.vts.is_cache_available)
266
267
    def test_get_vtss_multiple_vts(self):
        """An unfiltered get_vts lists every registered VT."""
        expected_ids = ['1.2.3.4', '1.2.3.5', '123456789']
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test')
        self.daemon.add_vt('1.2.3.5', 'Another vulnerability test')
        self.daemon.add_vt('123456789', 'Yet another vulnerability test')

        fs = FakeStream()

        self.daemon.handle_command('<get_vts />', fs)
        response = fs.get_response()
        self.assertEqual(response.get('status'), '200')

        vts = response.find('vts')
        self.assertIsNotNone(vts.find('vt'))
        # Checking only that *a* <vt> exists would pass even if two of the
        # three VTs were dropped; verify every registered id is reported
        # (order-independent).
        self.assertCountEqual(
            [vt.get('id') for vt in vts.findall('vt')], expected_ids
        )
280
281
    def test_get_vts_multiple_vts_with_custom(self):
        # Each VT registered with custom data exposes its own <custom>.
        for vt_id, vt_name in (
            ('1.2.3.4', 'A vulnerability test'),
            ('4.3.2.1', 'Another vulnerability test with custom info'),
            ('123456789', 'Yet another vulnerability test'),
        ):
            self.daemon.add_vt(vt_id, vt_name, custom='b')

        stream = FakeStream()
        self.daemon.handle_command('<get_vts />', stream)
        customs = stream.get_response().findall('vts/vt/custom')
        self.assertEqual(len(customs), 3)
297
298 View Code Duplication
    def test_get_vts_vts_with_params(self):
        """A VT registered with params and custom data reports both."""
        self.daemon.add_vt(
            '1.2.3.4', 'A vulnerability test', vt_params="a", custom="b"
        )
        fs = FakeStream()

        self.daemon.handle_command('<get_vts vt_id="1.2.3.4"></get_vts>', fs)
        response = fs.get_response()

        # The status of the response must be success (i.e. 200)
        self.assertEqual(response.get('status'), '200')

        # The response root element must have the correct name
        self.assertEqual(response.tag, 'get_vts_response')
        # The response must contain a 'vts' element
        self.assertIsNotNone(response.find('vts'))

        # Locate the VT by path instead of by child position, so the test
        # does not silently depend on element ordering in the response.
        vt = response.find('vts/vt')
        self.assertIsNotNone(vt)

        vt_params = vt.findall('params')
        self.assertEqual(1, len(vt_params))

        custom = vt.findall('custom')
        self.assertEqual(1, len(custom))

        # The "a" params fixture is expected to yield two <param> entries.
        params = response.findall('vts/vt/params/param')
        self.assertEqual(2, len(params))
323
324 View Code Duplication
    def test_get_vts_vts_with_refs(self):
        # A VT registered with refs reports two <ref> entries alongside
        # one <params> and one <custom> block.
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            vt_refs="c",
        )
        stream = FakeStream()
        self.daemon.handle_command(
            '<get_vts vt_id="1.2.3.4"></get_vts>', stream
        )
        reply = stream.get_response()

        # Successful, correctly named reply containing a <vts> element.
        self.assertEqual(reply.get('status'), '200')
        self.assertEqual(reply.tag, 'get_vts_response')
        self.assertIsNotNone(reply.find('vts'))

        # First grandchild of the reply root, i.e. the reported VT.
        first_vt = reply[0][0]
        self.assertEqual(len(first_vt.findall('params')), 1)
        self.assertEqual(len(first_vt.findall('custom')), 1)

        self.assertEqual(len(reply.findall('vts/vt/refs/ref')), 2)
354
355
    def test_get_vts_vts_with_dependencies(self):
        # The "c" dependencies fixture is expected to produce two
        # <dependency> entries under the VT.
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            vt_dependencies="c",
        )
        stream = FakeStream()
        self.daemon.handle_command(
            '<get_vts vt_id="1.2.3.4"></get_vts>', stream
        )
        deps = stream.get_response().findall(
            'vts/vt/dependencies/dependency'
        )
        self.assertEqual(len(deps), 2)
371
372
    def test_get_vts_vts_with_severities(self):
        # Severities supplied at registration appear as one <severity>.
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            severities="c",
        )
        stream = FakeStream()
        self.daemon.handle_command(
            '<get_vts vt_id="1.2.3.4"></get_vts>', stream
        )
        found = stream.get_response().findall('vts/vt/severities/severity')
        self.assertEqual(len(found), 1)
387
388
    def test_get_vts_vts_with_detection_qodt(self):
        # Detection info registered together with qod_t (presumably the
        # QoD type) yields exactly one <detection> element.
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            detection="c",
            qod_t="d",
        )
        stream = FakeStream()
        self.daemon.handle_command(
            '<get_vts vt_id="1.2.3.4"></get_vts>', stream
        )
        found = stream.get_response().findall('vts/vt/detection')
        self.assertEqual(len(found), 1)
404
405
    def test_get_vts_vts_with_detection_qodv(self):
        # Detection info registered together with qod_v (presumably the
        # QoD value) yields exactly one <detection> element.
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            detection="c",
            qod_v="d",
        )
        stream = FakeStream()
        self.daemon.handle_command(
            '<get_vts vt_id="1.2.3.4"></get_vts>', stream
        )
        found = stream.get_response().findall('vts/vt/detection')
        self.assertEqual(len(found), 1)
421
422
    def test_get_vts_vts_with_summary(self):
        # A summary supplied at registration is reported as one <summary>.
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            summary="c",
        )
        stream = FakeStream()
        self.daemon.handle_command(
            '<get_vts vt_id="1.2.3.4"></get_vts>', stream
        )
        found = stream.get_response().findall('vts/vt/summary')
        self.assertEqual(len(found), 1)
437
438
    def test_get_vts_vts_with_impact(self):
        # Impact text supplied at registration is reported as one <impact>.
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            impact="c",
        )
        stream = FakeStream()
        self.daemon.handle_command(
            '<get_vts vt_id="1.2.3.4"></get_vts>', stream
        )
        found = stream.get_response().findall('vts/vt/impact')
        self.assertEqual(len(found), 1)
453
454
    def test_get_vts_vts_with_affected(self):
        # "affected" info supplied at registration becomes one <affected>.
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            affected="c",
        )
        stream = FakeStream()
        self.daemon.handle_command(
            '<get_vts vt_id="1.2.3.4"></get_vts>', stream
        )
        found = stream.get_response().findall('vts/vt/affected')
        self.assertEqual(len(found), 1)
469
470
    def test_get_vts_vts_with_insight(self):
        # Insight text supplied at registration becomes one <insight>.
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            insight="c",
        )
        stream = FakeStream()
        self.daemon.handle_command(
            '<get_vts vt_id="1.2.3.4"></get_vts>', stream
        )
        found = stream.get_response().findall('vts/vt/insight')
        self.assertEqual(len(found), 1)
485
486
    def test_get_vts_vts_with_solution(self):
        # A solution registered with solution_t / solution_m is reported as
        # a single <solution> element.
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            solution="c",
            solution_t="d",
            solution_m="e",
        )
        stream = FakeStream()
        self.daemon.handle_command(
            '<get_vts vt_id="1.2.3.4"></get_vts>', stream
        )
        found = stream.get_response().findall('vts/vt/solution')
        self.assertEqual(len(found), 1)
503
504
    def test_get_vts_vts_with_ctime(self):
        # The creation time is serialized verbatim into <creation_time>.
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_creation_time='01-01-1900',
        )
        stream = FakeStream()
        self.daemon.handle_command(
            '<get_vts vt_id="1.2.3.4"></get_vts>', stream
        )
        ctimes = stream.get_response().findall('vts/vt/creation_time')
        self.assertEqual(
            '<creation_time>01-01-1900</creation_time>',
            ET.tostring(ctimes[0]).decode('utf-8'),
        )
521
522
    def test_get_vts_vts_with_mtime(self):
        # The modification time is serialized verbatim into
        # <modification_time>.
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_modification_time='02-01-1900',
        )
        stream = FakeStream()
        self.daemon.handle_command(
            '<get_vts vt_id="1.2.3.4"></get_vts>', stream
        )
        mtimes = stream.get_response().findall('vts/vt/modification_time')
        self.assertEqual(
            '<modification_time>02-01-1900</modification_time>',
            ET.tostring(mtimes[0]).decode('utf-8'),
        )
539
540
    def test_clean_forgotten_scans(self):
        """A finished scan is purged only once older than the store time."""
        fs = FakeStream()

        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, '
            '443"><scanner_params /></start_scan>',
            fs,
        )
        response = fs.get_response()

        scan_id = response.findtext('id')

        self.daemon.start_queued_scans()

        # Poll until the scan reports a non-zero end time.  The wait is
        # bounded so a scan that never finishes fails the test instead of
        # hanging the whole suite.  (The former loop also issued a second,
        # unused get_scans per iteration; that request is dropped.)
        deadline = time.monotonic() + 30
        while True:
            fs = FakeStream()
            self.daemon.handle_command(
                '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
            )
            response = fs.get_response()

            scans = response.findall('scan')
            self.assertEqual(1, len(scans))

            if scans[0].get('end_time') != '0':
                break
            if time.monotonic() > deadline:
                self.fail('scan %s did not finish in time' % scan_id)
            time.sleep(0.01)

        self.assertEqual(
            len(list(self.daemon.scan_collection.ids_iterator())), 1
        )

        # Set an old end_time
        self.daemon.scan_collection.scans_table[scan_id]['end_time'] = 123456
        # Run the check
        self.daemon.clean_forgotten_scans()
        # Not removed
        self.assertEqual(
            len(list(self.daemon.scan_collection.ids_iterator())), 1
        )

        # Set the max time and run again
        self.daemon.scaninfo_store_time = 1
        self.daemon.clean_forgotten_scans()
        # Now is removed
        self.assertEqual(
            len(list(self.daemon.scan_collection.ids_iterator())), 0
        )
598
599
    def test_scan_with_error(self):
        """A scan error is reported as a result and the scan can be deleted."""
        fs = FakeStream()

        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, '
            '443"><scanner_params /></start_scan>',
            fs,
        )

        response = fs.get_response()
        scan_id = response.findtext('id')
        self.daemon.start_queued_scans()
        self.daemon.add_scan_error(
            scan_id, host='a', value='something went wrong'
        )

        # Poll until the scan leaves the init/running states.  The wait is
        # bounded so a stuck scan fails the test rather than hanging it.
        # The final poll already returns the scan details, so no extra
        # get_scans request is needed after the loop.
        deadline = time.monotonic() + 30
        while True:
            fs = FakeStream()
            self.daemon.handle_command(
                '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
            )
            response = fs.get_response()

            scans = response.findall('scan')
            self.assertEqual(1, len(scans))

            scan = scans[0]
            if scan.get('status') not in ('init', 'running'):
                break

            # A scan that is still active must not report an end time.
            self.assertEqual('0', scan.get('end_time'))
            if time.monotonic() > deadline:
                self.fail('scan %s did not finish in time' % scan_id)
            time.sleep(0.010)

        self.assertEqual(
            response.findtext('scan/results/result'), 'something went wrong'
        )
        fs = FakeStream()
        self.daemon.handle_command('<delete_scan scan_id="%s" />' % scan_id, fs)
        response = fs.get_response()

        self.assertEqual(response.get('status'), '200')
650
651
    def test_get_scan_pop(self):
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /></start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')
        self.daemon.add_scan_host_detail(
            scan_id, host='a', value='Some Host Detail'
        )

        # Pause one second before querying the scan.
        time.sleep(1)

        def first_result(command):
            # Send one command; return the text of its first result.
            out = FakeStream()
            self.daemon.handle_command(command, out)
            return out.get_response().findtext('scan/results/result')

        # A plain get_scans reports the host detail ...
        self.assertEqual(
            first_result('<get_scans scan_id="%s"/>' % scan_id),
            'Some Host Detail',
        )
        # ... so does the first popping request ...
        self.assertEqual(
            first_result(
                '<get_scans scan_id="%s" pop_results="1"/>' % scan_id
            ),
            'Some Host Detail',
        )
        # ... after which nothing is left to pop.
        self.assertIsNone(
            first_result(
                '<get_scans scan_id="%s" details="0" pop_results="1"/>'
                % scan_id
            )
        )
694
695
    def test_get_scan_pop_max_res(self):
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /></start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        for host in ('a', 'c', 'b'):
            self.daemon.add_scan_log(scan_id, host=host, name=host)

        # max_results="1" caps a popping request to a single result ...
        capped = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" pop_results="1" max_results="1"/>'
            % scan_id,
            capped,
        )
        self.assertEqual(
            len(capped.get_response().findall('scan/results/result')), 1
        )

        # ... leaving the remaining two for the next pop.
        remainder = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" pop_results="1"/>' % scan_id, remainder
        )
        self.assertEqual(
            len(remainder.get_response().findall('scan/results/result')), 2
        )
727
728 View Code Duplication
    def test_get_scan_results_clean(self):
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /></start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        for host in ('a', 'c', 'b'):
            self.daemon.add_scan_log(scan_id, host=host, name=host)

        # Pop the results through a default FakeStream (presumably one that
        # reports a successful write); both stored and temporary result
        # lists must be empty afterwards.
        self.daemon.handle_command(
            '<get_scans scan_id="%s" pop_results="1"/>' % scan_id,
            FakeStream(),
        )

        table = self.daemon.scan_collection.scans_table
        for key in ('results', 'temp_results'):
            self.assertEqual(len(table[scan_id][key]), 0)
758
759 View Code Duplication
    def test_get_scan_results_restore(self):
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /></start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        for host in ('a', 'c', 'b'):
            self.daemon.add_scan_log(scan_id, host=host, name=host)

        # Pop through a FakeStream built with return_value=False (presumably
        # simulating a failed write): all three results must be restored
        # to the stored list and none left in the temporary list.
        self.daemon.handle_command(
            '<get_scans scan_id="%s" pop_results="1"/>' % scan_id,
            FakeStream(return_value=False),
        )

        table = self.daemon.scan_collection.scans_table
        self.assertEqual(len(table[scan_id]['results']), 3)
        self.assertEqual(len(table[scan_id]['temp_results']), 0)
789
790
    def test_billon_laughs(self):
        # "Billion laughs" payload: nested entity definitions that would
        # expand exponentially.  The command parser must refuse the entity
        # declarations with defusedxml's EntitiesForbidden.
        # pylint: disable=line-too-long

        payload = (
            '<?xml version="1.0"?>'
            '<!DOCTYPE lolz ['
            ' <!ENTITY lol "lol">'
            ' <!ELEMENT lolz (#PCDATA)>'
            ' <!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">'
            ' <!ENTITY lol2 "&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;">'
            ' <!ENTITY lol3 "&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;">'
            ' <!ENTITY lol4 "&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;">'
            ' <!ENTITY lol5 "&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;">'
            ' <!ENTITY lol6 "&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;">'
            ' <!ENTITY lol7 "&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;">'
            ' <!ENTITY lol8 "&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;">'
            ' <!ENTITY lol9 "&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;">'
            ']>'
        )
        stream = FakeStream()
        with self.assertRaises(EntitiesForbidden):
            self.daemon.handle_command(payload, stream)
813
814
    def test_target_with_credentials(self):
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>192.168.0.0/24</hosts><ports>22'
            '</ports><credentials>'
            '<credential type="up" service="ssh" port="22">'
            '<username>scanuser</username>'
            '<password>mypass</password>'
            '</credential><credential type="up" service="smb">'
            '<username>smbuser</username>'
            '<password>mypass</password></credential>'
            '</credentials>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        reply = stream.get_response()
        self.assertEqual(reply.get('status'), '200')

        # Stored credentials are keyed by service; 'port' is present only
        # for the credential that specified one in the request (ssh).
        expected = {
            'ssh': {
                'type': 'up',
                'password': 'mypass',
                'port': '22',
                'username': 'scanuser',
            },
            'smb': {'type': 'up', 'password': 'mypass', 'username': 'smbuser'},
        }
        stored = self.daemon.get_scan_credentials(reply.findtext('id'))
        self.assertEqual(stored, expected)
851
852
    def test_target_with_credential_empty_community(self):
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>192.168.0.0/24</hosts><ports>22'
            '</ports><credentials>'
            '<credential type="up" service="snmp">'
            '<community></community></credential>'
            '</credentials>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        reply = stream.get_response()
        self.assertEqual(reply.get('status'), '200')

        # An empty <community> element is kept as an empty string rather
        # than being dropped.
        expected = {
            'snmp': {'type': 'up', 'community': ''},
        }
        stored = self.daemon.get_scan_credentials(reply.findtext('id'))
        self.assertEqual(stored, expected)
879
880
    def test_scan_get_target(self):
        """get_scans echoes the target host list exactly as it was submitted."""
        start_stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>localhosts,192.168.0.0/24</hosts>'
            '<ports>80,443</ports>'
            '</target></targets>'
            '</start_scan>',
            start_stream,
        )
        self.daemon.start_queued_scans()

        new_scan_id = start_stream.get_response().findtext('id')

        query_stream = FakeStream()
        query = '<get_scans scan_id="{}"/>'.format(new_scan_id)
        self.daemon.handle_command(query, query_stream)

        scan_elem = query_stream.get_response().find('scan')
        self.assertEqual(scan_elem.get('target'), 'localhosts,192.168.0.0/24')

    def test_scan_get_target_options(self):
        """A plain <alive_test> value shows up in the scan's target options."""
        stream = FakeStream()
        start_cmd = (
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets>'
            '<target><hosts>192.168.0.1</hosts>'
            '<ports>22</ports><alive_test>0</alive_test></target>'
            '</targets>'
            '</start_scan>'
        )
        self.daemon.handle_command(start_cmd, stream)
        self.daemon.start_queued_scans()

        scan_id = stream.get_response().findtext('id')

        # NOTE(review): the purpose of this one-second pause is not evident
        # from this test alone; it is kept unchanged.
        time.sleep(1)

        self.assertEqual(
            self.daemon.get_scan_target_options(scan_id), {'alive_test': '0'}
        )

    def test_scan_get_target_options_alive_test_methods(self):
        """Each enabled <alive_test_methods> child becomes a target option.

        Besides one entry per method, the options also carry an
        ``alive_test_methods`` flag set to '1'.
        """
        methods = ('icmp', 'tcp_syn', 'tcp_ack', 'arp', 'consider_alive')

        # Builds '<icmp>1</icmp><tcp_syn>1</tcp_syn>...' wrapped in the
        # <alive_test_methods> element.
        alive_methods_xml = (
            '<alive_test_methods>'
            + ''.join('<{0}>1</{0}>'.format(m) for m in methods)
            + '</alive_test_methods>'
        )
        start_cmd = (
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets>'
            '<target><hosts>192.168.0.1</hosts>'
            '<ports>22</ports>'
            + alive_methods_xml
            + '</target>'
            '</targets>'
            '</start_scan>'
        )

        stream = FakeStream()
        self.daemon.handle_command(start_cmd, stream)
        self.daemon.start_queued_scans()

        scan_id = stream.get_response().findtext('id')
        time.sleep(1)

        expected = dict.fromkeys(methods, '1')
        expected['alive_test_methods'] = '1'
        self.assertEqual(
            self.daemon.get_scan_target_options(scan_id), expected
        )

    def test_scan_get_target_options_alive_test_methods_dont_add_empty_or_missing(  # pylint: disable=line-too-long
        self,
    ):
        """Empty or absent alive test methods are not added as options.

        Only ``icmp`` carries a value here. The empty ``arp`` and
        ``consider_alive`` elements, and the omitted ``tcp_syn`` and
        ``tcp_ack`` elements, must not appear in the target options.
        """
        stream = FakeStream()
        start_cmd = (
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets>'
            '<target><hosts>192.168.0.1</hosts>'
            '<ports>22</ports>'
            '<alive_test_methods>'
            '<icmp>1</icmp>'
            '<arp></arp>'
            '<consider_alive></consider_alive>'
            '</alive_test_methods>'
            '</target>'
            '</targets>'
            '</start_scan>'
        )
        self.daemon.handle_command(start_cmd, stream)
        self.daemon.start_queued_scans()

        scan_id = stream.get_response().findtext('id')
        time.sleep(1)

        options = self.daemon.get_scan_target_options(scan_id)
        self.assertEqual(options, {'alive_test_methods': '1', 'icmp': '1'})

    def test_progress(self):
        """Two hosts at 75% and 25% give an overall target progress of 50."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        for host, host_progress in (('localhost1', 75), ('localhost2', 25)):
            self.daemon.set_scan_host_progress(scan_id, host, host_progress)

        collection = self.daemon.scan_collection
        self.assertEqual(collection.calculate_target_progress(scan_id), 50)

    def test_progress_all_host_dead(self):
        """A target whose hosts are all dead and finished is at 100%."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        # -1 is the progress value used for dead hosts in these tests.
        for host in ('localhost1', 'localhost2'):
            self.daemon.set_scan_host_progress(scan_id, host, -1)

        self.daemon.sort_host_finished(scan_id, ['localhost1', 'localhost2'])

        collection = self.daemon.scan_collection
        self.assertEqual(collection.calculate_target_progress(scan_id), 100)

    @patch('ospd.ospd.os')
    def test_interrupted_scan(self, mock_os):
        """A scan whose ``exec_scan`` does nothing is reported as interrupted.

        ``get_scans`` is polled until the scan leaves the INIT state. The
        polling is bounded by a deadline, so a regression fails the test
        instead of hanging the suite. It always issues at least one
        ``get_scans``, so the final response always comes from that command
        and never from the earlier ``start_scan`` stream.
        """
        mock_os.setsid.return_value = None
        fs = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            fs,
        )
        self.daemon.start_queued_scans()

        response = fs.get_response()
        scan_id = response.findtext('id')

        self.daemon.exec_scan = Mock(return_value=None)
        self.daemon.set_scan_host_progress(scan_id, 'localhost1', 5)
        self.daemon.set_scan_host_progress(scan_id, 'localhost2', 14)

        deadline = time.monotonic() + 10
        while True:
            fs = FakeStream()
            self.daemon.handle_command(
                '<get_scans scan_id="%s" details="0" progress="0"/>' % scan_id,
                fs,
            )
            if self.daemon.get_scan_status(scan_id) != ScanStatus.INIT:
                break
            if time.monotonic() > deadline:
                self.fail('scan %s never left the INIT state' % scan_id)

        response = fs.get_response()
        status = response.find('scan').attrib['status']

        self.assertEqual(status, ScanStatus.INTERRUPTED.name.lower())

    def test_sort_host_finished(self):
        """Dead and finished hosts are taken into account by target progress.

        localhost3 is dead (-1) and localhost4 finished (100), and both are
        passed to ``sort_host_finished``. localhost1 and localhost2 are still
        at 75 and 25. The expected overall value is 66. Presumably this is
        (75 + 25 + 100) averaged over the three non-dead hosts; confirm
        against ``calculate_target_progress``.
        """
        fs = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            fs,
        )
        self.daemon.start_queued_scans()

        response = fs.get_response()

        scan_id = response.findtext('id')
        self.daemon.set_scan_host_progress(scan_id, 'localhost3', -1)
        self.daemon.set_scan_host_progress(scan_id, 'localhost1', 75)
        self.daemon.set_scan_host_progress(scan_id, 'localhost4', 100)
        self.daemon.set_scan_host_progress(scan_id, 'localhost2', 25)

        self.daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4'])

        # Wrapped in parentheses so the line fits without a pylint
        # suppression. The previous inline "disable=line-too-long)" comment
        # had a stray ")" that corrupted the message name.
        rounded_progress = (
            self.daemon.scan_collection.calculate_target_progress(scan_id)
        )
        self.assertEqual(rounded_progress, 66)

    def test_set_status_interrupted(self):
        """Interrupting a scan records a non-zero end time."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        # No end time is recorded while the scan has not finished.
        self.assertEqual(
            self.daemon.scan_collection.get_end_time(scan_id), 0
        )

        self.daemon.interrupt_scan(scan_id)
        self.assertNotEqual(
            self.daemon.scan_collection.get_end_time(scan_id), 0
        )

    def test_set_status_stopped(self):
        """Setting a scan's status to STOPPED records a non-zero end time."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        self.assertEqual(
            self.daemon.scan_collection.get_end_time(scan_id), 0
        )

        self.daemon.set_scan_status(scan_id, ScanStatus.STOPPED)
        self.assertNotEqual(
            self.daemon.scan_collection.get_end_time(scan_id), 0
        )

    def test_calculate_progress_without_current_hosts(self):
        """Target progress is computed when no host reports partial progress.

        Only a dead host (-1) and a finished host (100) are reported. The
        resulting progress truncates to 33, both directly and after being
        stored with ``set_progress``.
        """
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        # Deliberately called without a host argument.
        self.daemon.set_scan_host_progress(scan_id)
        self.daemon.set_scan_host_progress(scan_id, 'localhost3', -1)
        self.daemon.set_scan_host_progress(scan_id, 'localhost4', 100)

        self.daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4'])

        target_progress = (
            self.daemon.scan_collection.calculate_target_progress(scan_id)
        )
        self.assertEqual(int(target_progress), 33)

        self.daemon.scan_collection.set_progress(scan_id, target_progress)
        self.assertEqual(self.daemon.get_scan_progress(scan_id), 33)

    def test_get_scan_host_progress(self):
        """A host's progress can be read back after it has been set."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        self.daemon.set_scan_host_progress(scan_id, 'localhost', 45)
        reported = self.daemon.get_scan_host_progress(scan_id, 'localhost')
        self.assertEqual(reported, 45)

    def test_get_scan_without_scanid(self):
        """get_scans with progress="1" but no scan_id raises an error.

        The error raised is OspdCommandError.
        """
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            FakeStream(),
        )
        self.daemon.start_queued_scans()

        query_stream = FakeStream()
        with self.assertRaises(OspdCommandError):
            self.daemon.handle_command(
                '<get_scans details="0" progress="1"/>', query_stream
            )

    def test_set_scan_total_hosts(self):
        """An explicitly set total host count replaces the calculated one."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        # Four hosts are listed in the target.
        self.assertEqual(
            self.daemon.scan_collection.get_count_total(scan_id), 4
        )

        self.daemon.set_scan_total_hosts(scan_id, 3)
        self.assertEqual(
            self.daemon.scan_collection.get_count_total(scan_id), 3
        )

    def test_set_scan_total_hosts_zero(self):
        """A total host count of 0 set via the daemon is kept as 0."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        # Initially ospd derives the count from the target's host list.
        self.assertEqual(
            self.daemon.scan_collection.get_count_total(scan_id), 4
        )

        # The server reports 0 (e.g. every host unresolved or dead, or an
        # invalid target). That value takes precedence and must not be
        # replaced by the count derived from the host list.
        self.daemon.set_scan_total_hosts(scan_id, 0)
        self.assertEqual(
            self.daemon.scan_collection.get_count_total(scan_id), 0
        )

    def test_set_scan_total_hosts_invalid_target(self):
        """A total host count of -1 (invalid target) is stored as 0."""
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        self.assertEqual(
            self.daemon.scan_collection.get_count_total(scan_id), 4
        )

        # The server signals an invalid target with a total of -1.
        self.daemon.set_scan_total_hosts(scan_id, -1)
        self.assertEqual(
            self.daemon.scan_collection.get_count_total(scan_id), 0
        )

    def test_get_scan_progress_xml(self):
        """get_scans with progress="1" returns a detailed <progress> element.

        It covers one dead host, one finished host and two hosts that are
        still running.
        """
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        self.daemon.start_queued_scans()
        scan_id = stream.get_response().findtext('id')

        # One dead and one finished host...
        self.daemon.set_scan_host_progress(scan_id, 'localhost3', -1)
        self.daemon.set_scan_host_progress(scan_id, 'localhost4', 100)
        self.daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4'])

        # ...and two hosts still in progress.
        self.daemon.set_scan_host_progress(scan_id, 'localhost1', 75)
        self.daemon.set_scan_host_progress(scan_id, 'localhost2', 25)

        query_stream = FakeStream()
        query = '<get_scans scan_id="%s" details="0" progress="1"/>' % scan_id
        self.daemon.handle_command(query, query_stream)
        progress_elem = query_stream.get_response().find('scan/progress')

        self.assertEqual(int(float(progress_elem.findtext('overall'))), 66)
        self.assertEqual(progress_elem.findtext('count_alive'), '1')
        self.assertEqual(progress_elem.findtext('count_dead'), '1')
        self.assertEqual(len(progress_elem.findall('host')), 2)
        self.assertEqual(progress_elem.findtext('count_excluded'), '0')

    def test_set_get_vts_version(self):
        """A VTs version that has been set can be read back unchanged."""
        expected_version = '1234'
        self.daemon.set_vts_version(expected_version)
        self.assertEqual(expected_version, self.daemon.get_vts_version())

    def test_set_get_vts_version_error(self):
        """Calling set_vts_version without a version raises TypeError."""
        with self.assertRaises(TypeError):
            self.daemon.set_vts_version()

    @patch("ospd.ospd.os")
    @patch("ospd.ospd.create_process")
    def test_scan_exists(self, mock_create_process, _mock_os):
        """Re-submitting the id of a stopped scan is answered with "Continue".

        The first start_scan is answered with "OK". The scan is then stopped,
        and a start_scan that reuses the same scan id gets "Continue".
        """
        fake_start = FakeStartProcess()
        mock_create_process.side_effect = fake_start

        scan_process = fake_start.call_mock
        scan_process.start.side_effect = fake_start.run
        scan_process.is_alive.return_value = True
        scan_process.pid = "main-scan-process"

        # Everything after the opening <start_scan> tag is shared by both
        # start requests below.
        scan_body = (
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>'
        )

        first_stream = FakeStream()
        self.daemon.handle_command('<start_scan>' + scan_body, first_stream)
        first_resp = first_stream.get_response()
        scan_id = first_resp.findtext('id')
        self.assertIsNotNone(scan_id)
        self.assertEqual(first_resp.get('status_text'), 'OK')

        self.daemon.start_queued_scans()

        assert_called(mock_create_process)
        assert_called(scan_process.start)

        self.daemon.handle_command(
            '<stop_scan scan_id="{}" />'.format(scan_id), first_stream
        )

        second_stream = FakeStream()
        resume_cmd = '<start_scan scan_id="' + scan_id + '">' + scan_body
        self.daemon.handle_command(resume_cmd, second_stream)
        self.daemon.start_queued_scans()

        second_resp = second_stream.get_response()
        self.assertEqual(second_resp.get('status_text'), 'Continue')

    def test_result_order(self):
        """get_scans returns scan log results in the order they were added.

        The full list of result names is compared, instead of looping over
        whatever came back. This way the test also fails when results are
        missing, rather than passing vacuously on an empty result set.
        """
        fs = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="1">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>a</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            fs,
        )
        self.daemon.start_queued_scans()
        response = fs.get_response()

        scan_id = response.findtext('id')

        # Deliberately not in sorted order.
        hosts = ['a', 'c', 'b']
        for host in hosts:
            self.daemon.add_scan_log(scan_id, host=host, name=host)

        fs = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
        )
        response = fs.get_response()

        # A trailing "/" in an ElementTree path selects all child elements.
        results = response.findall("scan/results/")
        self.assertEqual([res.get('name') for res in results], hosts)

    def test_batch_result(self):
        """Results added as a batch via ResultList keep their insertion order.

        The full list of result names is compared, so the test fails when
        results are missing instead of passing vacuously on an empty set.
        """
        reslist = ResultList()
        fs = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="1">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>a</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            fs,
        )
        self.daemon.start_queued_scans()
        response = fs.get_response()

        scan_id = response.findtext('id')

        # Deliberately not in sorted order.
        hosts = ['a', 'c', 'b']
        for host in hosts:
            reslist.add_scan_log_to_list(host=host, name=host)
        self.daemon.scan_collection.add_result_list(scan_id, reslist)

        fs = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
        )
        response = fs.get_response()

        # A trailing "/" in an ElementTree path selects all child elements.
        results = response.findall("scan/results/")
        self.assertEqual([res.get('name') for res in results], hosts)

    def test_is_new_scan_allowed_false(self):
        """With two scan processes and max_scans=1 no new scan is allowed."""
        self.daemon.scan_processes = dict(a=1, b=2)
        self.daemon.max_scans = 1

        self.assertFalse(self.daemon.is_new_scan_allowed())

    def test_is_new_scan_allowed_true(self):
        """With two scan processes and max_scans=3 a new scan is allowed."""
        self.daemon.scan_processes = dict(a=1, b=2)
        self.daemon.max_scans = 3

        self.assertTrue(self.daemon.is_new_scan_allowed())

    def test_start_queue_scan_daemon_not_init(self):
        """Queued scans are not started while the daemon is uninitialized.

        The daemon is expected to log an info message explaining that a feed
        update is in progress.
        """
        self.daemon.get_count_queued_scans = MagicMock(return_value=10)
        self.daemon.initialized = False

        # Patch Logger.info only for the duration of this call. Assigning
        # to the class attribute directly would leave every Logger's info()
        # mocked for all tests that run afterwards.
        with patch.object(logging.Logger, 'info') as mock_info:
            self.daemon.start_queued_scans()

        mock_info.assert_called_with(
            "Queued task can not be started because a "
            "feed update is being performed."
        )

    @patch("ospd.ospd.psutil")
    def test_free_memory_true(self, mock_psutil):
        """Free memory above the queue threshold counts as enough.

        The threshold is presumably in MB; confirm against
        is_enough_free_memory.
        """
        available_bytes = 1500000000  # 1.5 GB free
        mock_psutil.virtual_memory.return_value = FakePsutil(
            available=available_bytes
        )
        self.daemon.min_free_mem_scan_queue = 1000

        self.assertTrue(self.daemon.is_enough_free_memory())

    @patch("ospd.ospd.psutil")
    def test_free_memory_false(self, mock_psutil):
        """Free memory below the queue threshold counts as not enough.

        The threshold is presumably in MB; confirm against
        is_enough_free_memory.
        """
        available_bytes = 1500000000  # 1.5 GB free
        mock_psutil.virtual_memory.return_value = FakePsutil(
            available=available_bytes
        )
        self.daemon.min_free_mem_scan_queue = 2000

        self.assertFalse(self.daemon.is_enough_free_memory())

    def test_count_queued_scans(self):
        """A submitted scan stays queued until start_queued_scans runs."""
        start_cmd = (
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>localhosts,192.168.0.0/24</hosts>'
            '<ports>80,443</ports>'
            '</target></targets>'
            '</start_scan>'
        )
        self.daemon.handle_command(start_cmd, FakeStream())

        self.assertEqual(self.daemon.get_count_queued_scans(), 1)
        self.daemon.start_queued_scans()
        self.assertEqual(self.daemon.get_count_queued_scans(), 0)

    def test_ids_iterator_dict_modified(self):
        """ids_iterator tolerates scans_table growing during iteration.

        Iterating the live dict directly would raise RuntimeError once an
        entry is added inside the loop.
        """
        collection = self.daemon.scan_collection
        collection.scans_table = {'a': 1, 'b': 2}

        for _scan_id in collection.ids_iterator():
            collection.scans_table['c'] = 3

        self.assertEqual(len(collection.scans_table), 3)
1576