Completed
Push — master ( 1ed431...35dbab )
by
unknown
17s queued 13s
created

tests.test_scan_and_result.ScanTestCase.setUp()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 4
nop 1
dl 0
loc 4
rs 10
c 0
b 0
f 0
1
# Copyright (C) 2014-2020 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: AGPL-3.0-or-later
4
#
5
# This program is free software: you can redistribute it and/or modify
6
# it under the terms of the GNU Affero General Public License as
7
# published by the Free Software Foundation, either version 3 of the
8
# License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU Affero General Public License for more details.
14
#
15
# You should have received a copy of the GNU Affero General Public License
16
# along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
# pylint: disable=too-many-lines
19
20
""" Test module for scan runs
21
"""
22
23
import time
24
import unittest
25
26
from unittest.mock import patch, MagicMock
27
28
import xml.etree.ElementTree as ET
29
30
from defusedxml.common import EntitiesForbidden
31
32
from ospd.resultlist import ResultList
33
from ospd.errors import OspdCommandError
34
35
from .helper import (
36
    DummyWrapper,
37
    assert_called,
38
    FakeStream,
39
    FakeDataManager,
40
    FakePsutil,
41
)
42
43
44
class FakeStartProcess:
    """Test double standing in for a process starter.

    Calling the instance only records the target callable and its
    arguments and hands back ``call_mock``; ``run()`` then invokes the
    recorded callable synchronously and returns ``run_mock``.
    """

    def __init__(self):
        # Mocks returned to callers so tests can inspect interactions.
        self.run_mock = MagicMock()
        self.call_mock = MagicMock()

        # Filled in by __call__.
        self.func = None
        self.args = None
        self.kwargs = None

    def __call__(self, func, *, args=None, kwargs=None):
        """Remember ``func`` and its arguments; return ``call_mock``.

        Falsy ``args``/``kwargs`` are normalised to an empty list/dict.
        """
        self.func = func
        self.args = [] if not args else args
        self.kwargs = {} if not kwargs else kwargs
        return self.call_mock

    def run(self):
        """Call the recorded callable in-process and return ``run_mock``."""
        target = self.func
        target(*self.args, **self.kwargs)
        return self.run_mock

    def __repr__(self):
        template = "<FakeProcess func={} args={} kwargs={}>"
        return template.format(self.func, self.args, self.kwargs)
67
68
69
class Result:
    """Plain result record used by the scan tests.

    ``type_`` is stored as ``result_type``; the remaining known fields
    start as empty strings, and every keyword argument sets (or adds) the
    attribute of the same name afterwards.
    """

    def __init__(self, type_, **kwargs):
        self.result_type = type_
        for field in (
            'host',
            'hostname',
            'name',
            'value',
            'port',
            'test_id',
            'severity',
            'qod',
        ):
            setattr(self, field, '')
        # Caller-supplied values take precedence over the defaults above.
        for attr, val in kwargs.items():
            setattr(self, attr, val)
82
83
84
class ScanTestCase(unittest.TestCase):
85
    def setUp(self):
        """Create a fresh DummyWrapper daemon backed by fake storage."""
        wrapper = DummyWrapper([])
        scans = wrapper.scan_collection
        # Swap in the fake data manager and point file storage at /tmp.
        scans.datamanager = FakeDataManager()
        scans.file_storage_dir = '/tmp'
        self.daemon = wrapper
89
90
    def test_get_default_scanner_params(self):
        """<get_scanner_details/> succeeds and lists scanner parameters."""
        sink = FakeStream()
        self.daemon.handle_command('<get_scanner_details />', sink)
        reply = sink.get_response()

        # Successful status, correct root tag and a scanner_params child.
        self.assertEqual('200', reply.get('status'))
        self.assertEqual('get_scanner_details_response', reply.tag)
        self.assertIsNotNone(reply.find('scanner_params'))
102
103
    def test_get_default_help(self):
        """<help/> succeeds both in default form and with format="xml"."""
        for command in ('<help />', '<help format="xml" />'):
            sink = FakeStream()
            self.daemon.handle_command(command, sink)
            reply = sink.get_response()
            self.assertEqual('200', reply.get('status'))

        # The last reply is the one for the XML-formatted request.
        self.assertEqual('help_response', reply.tag)
116
117
    def test_get_default_scanner_version(self):
        """<get_version/> succeeds and reports a protocol element."""
        sink = FakeStream()
        self.daemon.handle_command('<get_version />', sink)
        answer = sink.get_response()

        self.assertEqual('200', answer.get('status'))
        self.assertIsNotNone(answer.find('protocol'))
124
125
    def test_get_vts_no_vt(self):
        """<get_vts/> succeeds with a 'vts' element before any add_vt call."""
        sink = FakeStream()
        self.daemon.handle_command('<get_vts />', sink)
        answer = sink.get_response()

        self.assertEqual('200', answer.get('status'))
        self.assertIsNotNone(answer.find('vts'))
133
134
    def test_get_vt_xml_no_dict(self):
        """get_vt_xml() on an (id, None) pair yields no 'id' attribute."""
        vt_xml = self.daemon.get_vt_xml(('1234', None))
        self.assertFalse(vt_xml.get('id'))
138
139
    def test_get_vts_single_vt(self):
        """A single added VT is returned by <get_vts/> with its id."""
        sink = FakeStream()
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test')
        self.daemon.handle_command('<get_vts />', sink)
        reply = sink.get_response()

        self.assertEqual('200', reply.get('status'))

        first_vt = reply.find('vts').find('vt')
        self.assertIsNotNone(first_vt)
        self.assertEqual('1.2.3.4', first_vt.get('id'))
152
153
    def test_get_vts_still_not_init(self):
        """<get_vts/> answers with status 400 while the daemon is not initialized."""
        self.daemon.initialized = False
        sink = FakeStream()
        self.daemon.handle_command('<get_vts />', sink)

        self.assertEqual('400', sink.get_response().get('status'))
160
161
    def test_get_help_still_not_init(self):
        """<help/> still succeeds while the daemon is not initialized."""
        self.daemon.initialized = False
        sink = FakeStream()
        self.daemon.handle_command('<help/>', sink)

        self.assertEqual('200', sink.get_response().get('status'))
168
169 View Code Duplication
    def test_get_vts_filter_positive(self):
        """A 'modification_time>' filter returns a VT modified later."""
        extras = dict(vt_params="a", vt_modification_time='19000202')
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **extras)

        sink = FakeStream()
        self.daemon.handle_command(
            '<get_vts filter="modification_time&gt;19000201"></get_vts>', sink
        )
        reply = sink.get_response()

        self.assertEqual('200', reply.get('status'))

        matched_vt = reply.find('vts').find('vt')
        self.assertIsNotNone(matched_vt)
        self.assertEqual('1.2.3.4', matched_vt.get('id'))

        mtime_nodes = reply.findall('vts/vt/modification_time')
        rendered = ET.tostring(mtime_nodes[0]).decode('utf-8')
        self.assertEqual(
            '<modification_time>19000202</modification_time>', rendered
        )
195
196 View Code Duplication
    def test_get_vts_filter_negative(self):
        """A 'modification_time<' filter returns a VT modified earlier."""
        extras = dict(vt_params="a", vt_modification_time='19000202')
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **extras)

        sink = FakeStream()
        self.daemon.handle_command(
            '<get_vts filter="modification_time&lt;19000203"></get_vts>', sink,
        )
        reply = sink.get_response()

        self.assertEqual('200', reply.get('status'))

        matched_vt = reply.find('vts').find('vt')
        self.assertIsNotNone(matched_vt)
        self.assertEqual('1.2.3.4', matched_vt.get('id'))

        mtime_nodes = reply.findall('vts/vt/modification_time')
        rendered = ET.tostring(mtime_nodes[0]).decode('utf-8')
        self.assertEqual(
            '<modification_time>19000202</modification_time>', rendered
        )
222
223
    def test_get_vts_bad_filter(self):
        """A filter without operator/value raises and keeps the VT cache."""
        with self.assertRaises(OspdCommandError):
            self.daemon.handle_command(
                '<get_vts filter="modification_time"/>', FakeStream()
            )
        self.assertTrue(self.daemon.vts.is_cache_available)
229
230
    def test_get_vtss_multiple_vts(self):
        """<get_vts/> without a filter returns every registered VT."""
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test')
        self.daemon.add_vt('1.2.3.5', 'Another vulnerability test')
        self.daemon.add_vt('123456789', 'Yet another vulnerability test')

        fs = FakeStream()

        self.daemon.handle_command('<get_vts />', fs)
        response = fs.get_response()
        self.assertEqual(response.get('status'), '200')

        vts = response.find('vts')
        self.assertIsNotNone(vts.find('vt'))
        # Checking only the first <vt> would still pass if VTs were dropped;
        # make sure every added VT id is present exactly once.
        self.assertCountEqual(
            [vt.get('id') for vt in vts.findall('vt')],
            ['1.2.3.4', '1.2.3.5', '123456789'],
        )
243
244
    def test_get_vts_multiple_vts_with_custom(self):
        """Each of three VTs added with custom data yields a <custom> element."""
        for vt_id, vt_name in (
            ('1.2.3.4', 'A vulnerability test'),
            ('4.3.2.1', 'Another vulnerability test with custom info'),
            ('123456789', 'Yet another vulnerability test'),
        ):
            self.daemon.add_vt(vt_id, vt_name, custom='b')

        sink = FakeStream()
        self.daemon.handle_command('<get_vts />', sink)

        customs = sink.get_response().findall('vts/vt/custom')
        self.assertEqual(len(customs), 3)
260
261 View Code Duplication
    def test_get_vts_vts_with_params(self):
        """Requesting a VT by id includes its params and custom data."""
        self.daemon.add_vt(
            '1.2.3.4', 'A vulnerability test', vt_params="a", custom="b"
        )
        cmd = '<get_vts vt_id="1.2.3.4"></get_vts>'
        sink = FakeStream()
        self.daemon.handle_command(cmd, sink)
        reply = sink.get_response()

        # Successful reply with the expected root tag and a <vts> child.
        self.assertEqual('200', reply.get('status'))
        self.assertEqual('get_vts_response', reply.tag)
        self.assertIsNotNone(reply.find('vts'))

        # First child of the root's first child, i.e. the first VT entry.
        first_vt = reply[0][0]
        self.assertEqual(len(first_vt.findall('params')), 1)
        self.assertEqual(len(first_vt.findall('custom')), 1)
        self.assertEqual(len(reply.findall('vts/vt/params/param')), 2)
286
287 View Code Duplication
    def test_get_vts_vts_with_refs(self):
        """Requesting a VT by id includes params, custom data and refs."""
        extras = dict(vt_params="a", custom="b", vt_refs="c")
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **extras)
        cmd = '<get_vts vt_id="1.2.3.4"></get_vts>'
        sink = FakeStream()
        self.daemon.handle_command(cmd, sink)
        reply = sink.get_response()

        # Successful reply with the expected root tag and a <vts> child.
        self.assertEqual('200', reply.get('status'))
        self.assertEqual('get_vts_response', reply.tag)
        self.assertIsNotNone(reply.find('vts'))

        # First child of the root's first child, i.e. the first VT entry.
        first_vt = reply[0][0]
        self.assertEqual(len(first_vt.findall('params')), 1)
        self.assertEqual(len(first_vt.findall('custom')), 1)
        self.assertEqual(len(reply.findall('vts/vt/refs/ref')), 2)
317
318
    def test_get_vts_vts_with_dependencies(self):
        """A VT added with dependencies lists two <dependency> entries."""
        extras = dict(vt_params="a", custom="b", vt_dependencies="c")
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **extras)
        cmd = '<get_vts vt_id="1.2.3.4"></get_vts>'
        sink = FakeStream()
        self.daemon.handle_command(cmd, sink)

        dependencies = sink.get_response().findall(
            'vts/vt/dependencies/dependency'
        )
        self.assertEqual(len(dependencies), 2)
334
335
    def test_get_vts_vts_with_severities(self):
        """A VT added with severities lists exactly one <severity>."""
        extras = dict(vt_params="a", custom="b", severities="c")
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **extras)
        cmd = '<get_vts vt_id="1.2.3.4"></get_vts>'
        sink = FakeStream()
        self.daemon.handle_command(cmd, sink)

        severities = sink.get_response().findall('vts/vt/severities/severity')
        self.assertEqual(len(severities), 1)
350
351
    def test_get_vts_vts_with_detection_qodt(self):
        """A VT with detection and qod_t yields one <detection> element."""
        extras = dict(vt_params="a", custom="b", detection="c", qod_t="d")
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **extras)
        cmd = '<get_vts vt_id="1.2.3.4"></get_vts>'
        sink = FakeStream()
        self.daemon.handle_command(cmd, sink)

        detections = sink.get_response().findall('vts/vt/detection')
        self.assertEqual(len(detections), 1)
367
368
    def test_get_vts_vts_with_detection_qodv(self):
        """A VT with detection and qod_v yields one <detection> element."""
        extras = dict(vt_params="a", custom="b", detection="c", qod_v="d")
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **extras)
        cmd = '<get_vts vt_id="1.2.3.4"></get_vts>'
        sink = FakeStream()
        self.daemon.handle_command(cmd, sink)

        detections = sink.get_response().findall('vts/vt/detection')
        self.assertEqual(len(detections), 1)
384
385
    def test_get_vts_vts_with_summary(self):
        """A VT added with a summary yields exactly one <summary> element."""
        extras = dict(vt_params="a", custom="b", summary="c")
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **extras)
        cmd = '<get_vts vt_id="1.2.3.4"></get_vts>'
        sink = FakeStream()
        self.daemon.handle_command(cmd, sink)

        summaries = sink.get_response().findall('vts/vt/summary')
        self.assertEqual(len(summaries), 1)
400
401
    def test_get_vts_vts_with_impact(self):
        """A VT added with an impact yields exactly one <impact> element."""
        extras = dict(vt_params="a", custom="b", impact="c")
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **extras)
        cmd = '<get_vts vt_id="1.2.3.4"></get_vts>'
        sink = FakeStream()
        self.daemon.handle_command(cmd, sink)

        impacts = sink.get_response().findall('vts/vt/impact')
        self.assertEqual(len(impacts), 1)
416
417
    def test_get_vts_vts_with_affected(self):
        """A VT added with 'affected' yields exactly one <affected> element."""
        extras = dict(vt_params="a", custom="b", affected="c")
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **extras)
        cmd = '<get_vts vt_id="1.2.3.4"></get_vts>'
        sink = FakeStream()
        self.daemon.handle_command(cmd, sink)

        affected_nodes = sink.get_response().findall('vts/vt/affected')
        self.assertEqual(len(affected_nodes), 1)
432
433
    def test_get_vts_vts_with_insight(self):
        """A VT added with an insight yields exactly one <insight> element."""
        extras = dict(vt_params="a", custom="b", insight="c")
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **extras)
        cmd = '<get_vts vt_id="1.2.3.4"></get_vts>'
        sink = FakeStream()
        self.daemon.handle_command(cmd, sink)

        insights = sink.get_response().findall('vts/vt/insight')
        self.assertEqual(len(insights), 1)
448
449
    def test_get_vts_vts_with_solution(self):
        """A VT with solution text/type/method yields one <solution>."""
        extras = dict(
            vt_params="a",
            custom="b",
            solution="c",
            solution_t="d",
            solution_m="e",
        )
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **extras)
        cmd = '<get_vts vt_id="1.2.3.4"></get_vts>'
        sink = FakeStream()
        self.daemon.handle_command(cmd, sink)

        solutions = sink.get_response().findall('vts/vt/solution')
        self.assertEqual(len(solutions), 1)
466
467
    def test_get_vts_vts_with_ctime(self):
        """The VT creation time is echoed verbatim in <creation_time>."""
        extras = dict(vt_params="a", vt_creation_time='01-01-1900')
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **extras)
        cmd = '<get_vts vt_id="1.2.3.4"></get_vts>'
        sink = FakeStream()
        self.daemon.handle_command(cmd, sink)

        ctime_nodes = sink.get_response().findall('vts/vt/creation_time')
        rendered = ET.tostring(ctime_nodes[0]).decode('utf-8')
        self.assertEqual('<creation_time>01-01-1900</creation_time>', rendered)
484
485
    def test_get_vts_vts_with_mtime(self):
        """The VT modification time is echoed verbatim in <modification_time>."""
        extras = dict(vt_params="a", vt_modification_time='02-01-1900')
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **extras)
        cmd = '<get_vts vt_id="1.2.3.4"></get_vts>'
        sink = FakeStream()
        self.daemon.handle_command(cmd, sink)

        mtime_nodes = sink.get_response().findall('vts/vt/modification_time')
        rendered = ET.tostring(mtime_nodes[0]).decode('utf-8')
        self.assertEqual(
            '<modification_time>02-01-1900</modification_time>', rendered
        )
502
503
    def test_clean_forgotten_scans(self):
        """clean_forgotten_scans() drops a finished scan only when allowed.

        Runs a scan until it reports an end_time, back-dates that end_time,
        and checks that clean_forgotten_scans() keeps the scan with the
        daemon's current scaninfo_store_time but removes it once
        scaninfo_store_time is set to 1.
        """
        fs = FakeStream()

        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, '
            '443"><scanner_params /></start_scan>',
            fs,
        )
        response = fs.get_response()

        scan_id = response.findtext('id')

        self.daemon.start_queued_scans()

        # Poll until the scan reports a non-zero end_time. The wait is
        # bounded so a scan that never finishes fails the test instead of
        # hanging the whole suite.
        deadline = time.monotonic() + 30
        while True:
            fs = FakeStream()
            self.daemon.handle_command(
                '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
            )
            response = fs.get_response()

            scans = response.findall('scan')
            self.assertEqual(1, len(scans))

            if scans[0].get('end_time') != '0':
                break
            if time.monotonic() > deadline:
                self.fail('scan %s did not finish in time' % scan_id)
            time.sleep(0.01)

        self.assertEqual(
            len(list(self.daemon.scan_collection.ids_iterator())), 1
        )

        # Set an old end_time
        self.daemon.scan_collection.scans_table[scan_id]['end_time'] = 123456
        # Run the check
        self.daemon.clean_forgotten_scans()
        # Not removed with the current store time
        self.assertEqual(
            len(list(self.daemon.scan_collection.ids_iterator())), 1
        )

        # Lower the max store time and run again
        self.daemon.scaninfo_store_time = 1
        self.daemon.clean_forgotten_scans()
        # Now it is removed
        self.assertEqual(
            len(list(self.daemon.scan_collection.ids_iterator())), 0
        )
561
562
    def test_scan_with_error(self):
        """A scan error shows up as a result and the scan can be deleted.

        Polls the scan until its status leaves 'init'/'running' (end_time
        must stay '0' meanwhile), checks the error value is returned in the
        scan results, and checks <delete_scan> answers with status 200.
        """
        fs = FakeStream()

        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, '
            '443"><scanner_params /></start_scan>',
            fs,
        )

        response = fs.get_response()
        scan_id = response.findtext('id')
        self.daemon.start_queued_scans()
        self.daemon.add_scan_error(
            scan_id, host='a', value='something went wrong'
        )

        # The wait is bounded so a scan that never finishes fails the test
        # instead of hanging the whole suite.
        deadline = time.monotonic() + 30
        while True:
            fs = FakeStream()
            self.daemon.handle_command(
                '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
            )
            response = fs.get_response()

            scans = response.findall('scan')
            self.assertEqual(1, len(scans))

            scan = scans[0]
            if scan.get('status') not in ('init', 'running'):
                break

            # While the scan is still active it has no end time yet.
            self.assertEqual('0', scan.get('end_time'))
            if time.monotonic() > deadline:
                self.fail('scan %s did not finish in time' % scan_id)
            time.sleep(0.010)

        # This request did not set pop_results, so the reply that showed
        # the finished scan also carries its results.
        self.assertEqual(
            response.findtext('scan/results/result'), 'something went wrong'
        )
        fs = FakeStream()
        self.daemon.handle_command('<delete_scan scan_id="%s" />' % scan_id, fs)
        response = fs.get_response()

        self.assertEqual(response.get('status'), '200')
613
614
    def test_get_scan_pop(self):
        """Host details are returned by plain and popping <get_scans>.

        A plain fetch and a pop_results="1" fetch both return the detail;
        a later pop with details="0" returns no result.
        """

        def query(command):
            """Send *command* to the daemon and return the parsed reply."""
            sink = FakeStream()
            self.daemon.handle_command(command, sink)
            return sink.get_response()

        start = FakeStream()
        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /></start_scan>',
            start,
        )
        self.daemon.start_queued_scans()
        scan_id = start.get_response().findtext('id')
        self.daemon.add_scan_host_detail(
            scan_id, host='a', value='Some Host Detail'
        )

        # NOTE(review): presumably gives the queued scan time to start;
        # confirm whether this sleep is still required.
        time.sleep(1)

        reply = query('<get_scans scan_id="%s"/>' % scan_id)
        self.assertEqual(
            reply.findtext('scan/results/result'), 'Some Host Detail'
        )

        reply = query('<get_scans scan_id="%s" pop_results="1"/>' % scan_id)
        self.assertEqual(
            reply.findtext('scan/results/result'), 'Some Host Detail'
        )

        reply = query(
            '<get_scans scan_id="%s" details="0" pop_results="1"/>' % scan_id
        )
        self.assertIsNone(reply.findtext('scan/results/result'))
657
658
    def test_get_scan_pop_max_res(self):
        """max_results caps a pop; the next pop returns the remaining two."""

        def query(command):
            """Send *command* to the daemon and return the parsed reply."""
            sink = FakeStream()
            self.daemon.handle_command(command, sink)
            return sink.get_response()

        start = FakeStream()
        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /></start_scan>',
            start,
        )
        self.daemon.start_queued_scans()
        scan_id = start.get_response().findtext('id')

        for host in ('a', 'c', 'b'):
            self.daemon.add_scan_log(scan_id, host=host, name=host)

        capped = query(
            '<get_scans scan_id="%s" pop_results="1" max_results="1"/>'
            % scan_id
        )
        self.assertEqual(1, len(capped.findall('scan/results/result')))

        rest = query('<get_scans scan_id="%s" pop_results="1"/>' % scan_id)
        self.assertEqual(2, len(rest.findall('scan/results/result')))
690
691
    def test_billon_laughs(self):
        """A "billion laughs" nested-entity payload is refused.

        handle_command must raise defusedxml's EntitiesForbidden for a
        document that declares entities.
        """
        # pylint: disable=line-too-long

        payload = (
            '<?xml version="1.0"?>'
            '<!DOCTYPE lolz ['
            ' <!ENTITY lol "lol">'
            ' <!ELEMENT lolz (#PCDATA)>'
            ' <!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">'
            ' <!ENTITY lol2 "&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;">'
            ' <!ENTITY lol3 "&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;">'
            ' <!ENTITY lol4 "&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;">'
            ' <!ENTITY lol5 "&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;">'
            ' <!ENTITY lol6 "&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;">'
            ' <!ENTITY lol7 "&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;">'
            ' <!ENTITY lol8 "&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;">'
            ' <!ENTITY lol9 "&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;">'
            ']>'
        )
        with self.assertRaises(EntitiesForbidden):
            self.daemon.handle_command(payload, FakeStream())
714
715
    def test_target_with_credentials(self):
        """Credentials given in a <target> are stored keyed by service."""
        request = (
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>192.168.0.0/24</hosts><ports>22'
            '</ports><credentials>'
            '<credential type="up" service="ssh" port="22">'
            '<username>scanuser</username>'
            '<password>mypass</password>'
            '</credential><credential type="up" service="smb">'
            '<username>smbuser</username>'
            '<password>mypass</password></credential>'
            '</credentials>'
            '</target></targets>'
            '</start_scan>'
        )
        sink = FakeStream()
        self.daemon.handle_command(request, sink)
        self.daemon.start_queued_scans()
        reply = sink.get_response()

        self.assertEqual('200', reply.get('status'))

        expected = {
            'ssh': {
                'type': 'up',
                'password': 'mypass',
                'port': '22',
                'username': 'scanuser',
            },
            'smb': {'type': 'up', 'password': 'mypass', 'username': 'smbuser'},
        }
        scan_id = reply.findtext('id')
        self.assertEqual(self.daemon.get_scan_credentials(scan_id), expected)
752
753
    def test_scan_get_target(self):
        """<get_scans> reports the hosts string of the scan target."""
        request = (
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>localhosts,192.168.0.0/24</hosts>'
            '<ports>80,443</ports>'
            '</target></targets>'
            '</start_scan>'
        )
        sink = FakeStream()
        self.daemon.handle_command(request, sink)
        self.daemon.start_queued_scans()
        scan_id = sink.get_response().findtext('id')

        sink = FakeStream()
        self.daemon.handle_command('<get_scans scan_id="%s"/>' % scan_id, sink)
        scan_element = sink.get_response().find('scan')

        self.assertEqual(
            'localhosts,192.168.0.0/24', scan_element.get('target')
        )
777
778
    def test_scan_get_target_options(self):
        """An <alive_test> inside the target is exposed as a target option."""
        request = (
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets>'
            '<target><hosts>192.168.0.1</hosts>'
            '<ports>22</ports><alive_test>0</alive_test></target>'
            '</targets>'
            '</start_scan>'
        )
        sink = FakeStream()
        self.daemon.handle_command(request, sink)
        self.daemon.start_queued_scans()
        scan_id = sink.get_response().findtext('id')

        # NOTE(review): presumably waits for the queued scan to start;
        # confirm whether this sleep is still required.
        time.sleep(1)
        options = self.daemon.get_scan_target_options(scan_id)
        self.assertEqual({'alive_test': '0'}, options)
799
800
    def test_progress(self):
        """Target progress of two hosts at 75 and 25 is reported as 50."""
        request = (
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>'
        )
        sink = FakeStream()
        self.daemon.handle_command(request, sink)
        self.daemon.start_queued_scans()
        scan_id = sink.get_response().findtext('id')

        for host, value in (('localhost1', 75), ('localhost2', 25)):
            self.daemon.set_scan_host_progress(scan_id, host, value)

        collection = self.daemon.scan_collection
        self.assertEqual(collection.calculate_target_progress(scan_id), 50)
823
824
    def test_sort_host_finished(self):
        """Finished and dead hosts are sorted out before computing progress.

        Marks localhost3 with -1 and localhost4 with 100, sorts both as
        finished, and expects a target progress of 66 with the remaining
        hosts at 75 and 25.
        """

        fs = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            fs,
        )
        self.daemon.start_queued_scans()

        response = fs.get_response()

        scan_id = response.findtext('id')
        self.daemon.set_scan_host_progress(scan_id, 'localhost3', -1)
        self.daemon.set_scan_host_progress(scan_id, 'localhost1', 75)
        self.daemon.set_scan_host_progress(scan_id, 'localhost4', 100)
        self.daemon.set_scan_host_progress(scan_id, 'localhost2', 25)

        self.daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4'])

        # Binding the collection keeps the call within the line limit, so
        # the previous (malformed) line-too-long pragma is not needed.
        collection = self.daemon.scan_collection
        rounded_progress = collection.calculate_target_progress(scan_id)
        self.assertEqual(rounded_progress, 66)
853
854
    def test_calculate_progress_without_current_hosts(self):
        """Progress is computed when only finished/dead hosts have values.

        After sorting localhost3 (-1) and localhost4 (100) as finished, the
        target progress truncates to 33, and storing it via set_progress
        makes get_scan_progress report 33.
        """
        request = (
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>'
        )
        sink = FakeStream()
        self.daemon.handle_command(request, sink)
        self.daemon.start_queued_scans()
        scan_id = sink.get_response().findtext('id')

        set_progress = self.daemon.set_scan_host_progress
        set_progress(scan_id)
        set_progress(scan_id, 'localhost3', -1)
        set_progress(scan_id, 'localhost4', 100)

        self.daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4'])

        collection = self.daemon.scan_collection
        target_progress = collection.calculate_target_progress(scan_id)
        self.assertEqual(33, int(target_progress))

        collection.set_progress(scan_id, target_progress)
        self.assertEqual(33, self.daemon.get_scan_progress(scan_id))
885
886
    def test_get_scan_without_scanid(self):
        """A progress ``get_scans`` request lacking ``scan_id`` is rejected.

        A scan is started first so that at least one scan exists. A
        ``get_scans`` command asking for progress without naming a scan
        must still raise OspdCommandError.
        """
        start_stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            start_stream,
        )
        self.daemon.start_queued_scans()

        with self.assertRaises(OspdCommandError):
            self.daemon.handle_command(
                '<get_scans details="0" progress="1"/>', FakeStream()
            )
908
909
    def test_get_scan_progress_xml(self):
        """The ``scan/progress`` element of ``get_scans`` reports host counts.

        Of four hosts:

        - localhost3 (-1) and localhost4 (100) are sorted as finished.
        - localhost1 (75) and localhost2 (25) still have progress set.

        The progress reply must show:

        - an overall value truncating to 66;
        - one alive and one dead host;
        - two ``host`` entries;
        - zero excluded hosts.
        """
        start_stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            start_stream,
        )
        self.daemon.start_queued_scans()
        scan_id = start_stream.get_response().findtext('id')

        for host, value in (('localhost3', -1), ('localhost4', 100)):
            self.daemon.set_scan_host_progress(scan_id, host, value)
        self.daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4'])

        for host, value in (('localhost1', 75), ('localhost2', 25)):
            self.daemon.set_scan_host_progress(scan_id, host, value)

        progress_stream = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" details="0" progress="1"/>' % scan_id,
            progress_stream,
        )
        progress = progress_stream.get_response().find('scan/progress')

        self.assertEqual(int(float(progress.findtext('overall'))), 66)
        self.assertEqual(progress.findtext('count_alive'), '1')
        self.assertEqual(progress.findtext('count_dead'), '1')
        self.assertEqual(len(progress.findall('host')), 2)
        self.assertEqual(progress.findtext('count_excluded'), '0')
956
957
    def test_set_get_vts_version(self):
        """The VTs version passed to the setter is returned by the getter."""
        expected = '1234'
        self.daemon.set_vts_version(expected)
        self.assertEqual(expected, self.daemon.get_vts_version())
962
963
    def test_set_get_vts_version_error(self):
        """Calling ``set_vts_version`` with no argument raises TypeError."""
        with self.assertRaises(TypeError):
            self.daemon.set_vts_version()
965
966
    @patch("ospd.ospd.os")
    @patch("ospd.ospd.create_process")
    def test_scan_exists(self, mock_create_process, _mock_os):
        """Re-submitting an existing scan id is answered with 'Continue'.

        ``create_process`` is patched so that it returns the
        FakeStartProcess ``call_mock`` as the scan process. The sequence:

        1. The first start_scan must answer 'OK' and must create and start
           that process.
        2. A stop_scan is sent for the scan.
        3. A second start_scan carrying the same ``scan_id`` must answer
           'Continue'.
        """
        fake_start = FakeStartProcess()
        mock_create_process.side_effect = fake_start

        process = fake_start.call_mock
        process.start.side_effect = fake_start.run
        process.is_alive.return_value = True
        process.pid = "main-scan-process"

        # Everything after the opening <start_scan ...> tag is shared by
        # both start_scan commands.
        scan_body = (
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>'
        )

        first_stream = FakeStream()
        self.daemon.handle_command('<start_scan>' + scan_body, first_stream)
        first_response = first_stream.get_response()

        scan_id = first_response.findtext('id')
        self.assertIsNotNone(scan_id)
        self.assertEqual(first_response.get('status_text'), 'OK')

        self.daemon.start_queued_scans()

        assert_called(mock_create_process)
        assert_called(process.start)

        # The stop_scan reply goes to the first stream and is not inspected.
        self.daemon.handle_command(
            '<stop_scan scan_id="%s" />' % scan_id, first_stream
        )

        second_stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan scan_id="' + scan_id + '">' + scan_body,
            second_stream,
        )
        self.daemon.start_queued_scans()

        second_status = second_stream.get_response().get('status_text')
        self.assertEqual(second_status, 'Continue')
1020
1021
    def test_result_order(self):
        """Results come back in insertion order, not sorted by host.

        Logs for hosts 'a', 'c' and 'b' are added one by one. ``get_scans``
        with ``details="1"`` must return exactly three results in that order.
        """
        fs = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="1">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>a</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            fs,
        )
        self.daemon.start_queued_scans()
        response = fs.get_response()

        scan_id = response.findtext('id')

        hosts = ['a', 'c', 'b']
        for host in hosts:
            self.daemon.add_scan_log(scan_id, host=host, name=host)

        fs = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
        )
        response = fs.get_response()

        # A trailing '/' in an ElementPath selects all children of <results>.
        results = response.findall("scan/results/")

        # Without the length check an empty result list would pass vacuously.
        self.assertEqual(len(results), len(hosts))
        for expected_host, res in zip(hosts, results):
            self.assertEqual(expected_host, res.attrib['name'])
1055
1056
    def test_batch_result(self):
        """Results added as one ResultList keep their insertion order.

        Logs for hosts 'a', 'c' and 'b' are collected in a ResultList, which
        is attached to the scan in a single ``add_result_list`` call.
        ``get_scans`` with ``details="1"`` must return exactly those three
        results in that order.
        """
        reslist = ResultList()
        fs = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="1">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>a</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            fs,
        )
        self.daemon.start_queued_scans()
        response = fs.get_response()

        scan_id = response.findtext('id')

        hosts = ['a', 'c', 'b']
        for host in hosts:
            reslist.add_scan_log_to_list(host=host, name=host)
        self.daemon.scan_collection.add_result_list(scan_id, reslist)

        fs = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
        )
        response = fs.get_response()

        # A trailing '/' in an ElementPath selects all children of <results>.
        results = response.findall("scan/results/")

        # Without the length check an empty result list would pass vacuously.
        self.assertEqual(len(results), len(hosts))
        for expected_host, res in zip(hosts, results):
            self.assertEqual(expected_host, res.attrib['name'])
1091
1092
    def test_is_new_scan_allowed_false(self):
        """No new scan is allowed with two scan processes and max_scans=1."""
        self.daemon.scan_processes = dict(a=1, b=2)
        self.daemon.max_scans = 1

        allowed = self.daemon.is_new_scan_allowed()
        self.assertFalse(allowed)
1100
1101
    def test_is_new_scan_allowed_true(self):
        """A new scan is allowed with two scan processes and max_scans=3."""
        self.daemon.scan_processes = dict(a=1, b=2)
        self.daemon.max_scans = 3

        allowed = self.daemon.is_new_scan_allowed()
        self.assertTrue(allowed)
1109
1110
    @patch("ospd.ospd.psutil")
    def test_free_memory_true(self, mock_psutil):
        """Enough free memory is reported when 1.5 GB is free.

        The required minimum is set to 1000. Presumably that unit is MB;
        confirm against is_enough_free_memory.
        """
        fake_memory = FakePsutil(free=1500000000)  # 1.5 GB free, in bytes
        mock_psutil.virtual_memory.return_value = fake_memory
        self.daemon.min_free_mem_scan_queue = 1000

        self.assertTrue(self.daemon.is_enough_free_memory())
1117
1118
    @patch("ospd.ospd.psutil")
    def test_free_memory_false(self, mock_psutil):
        """Not enough free memory is reported when 1.5 GB is free.

        The required minimum is set to 2000. Presumably that unit is MB;
        confirm against is_enough_free_memory.
        """
        fake_memory = FakePsutil(free=1500000000)  # 1.5 GB free, in bytes
        mock_psutil.virtual_memory.return_value = fake_memory
        self.daemon.min_free_mem_scan_queue = 2000

        self.assertFalse(self.daemon.is_enough_free_memory())
1125
1126
    def test_count_queued_scans(self):
        """A submitted scan counts as queued until start_queued_scans runs."""
        start_cmd = (
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>localhosts,192.168.0.0/24</hosts>'
            '<ports>80,443</ports>'
            '</target></targets>'
            '</start_scan>'
        )
        fs = FakeStream()
        self.daemon.handle_command(start_cmd, fs)

        # Before dispatch the scan sits in the queue; afterwards it does not.
        self.assertEqual(self.daemon.get_count_queued_scans(), 1)
        self.daemon.start_queued_scans()
        self.assertEqual(self.daemon.get_count_queued_scans(), 0)
1143