Completed
Push — master ( 745830...54fe4d )
by Juan José
17s queued 13s
created

tests.test_scan_and_result.ScanTestCase.setUp()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
# Copyright (C) 2014-2020 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: AGPL-3.0-or-later
4
#
5
# This program is free software: you can redistribute it and/or modify
6
# it under the terms of the GNU Affero General Public License as
7
# published by the Free Software Foundation, either version 3 of the
8
# License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU Affero General Public License for more details.
14
#
15
# You should have received a copy of the GNU Affero General Public License
16
# along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
# pylint: disable=too-many-lines
19
20
""" Test module for scan runs
21
"""
22
23
import time
24
import unittest
25
26
from unittest.mock import patch, MagicMock
27
28
import xml.etree.ElementTree as ET
29
30
from defusedxml.common import EntitiesForbidden
31
32
from ospd.resultlist import ResultList
33
from ospd.errors import OspdCommandError
34
35
from .helper import DummyWrapper, assert_called, FakeStream, FakeDataManager
36
37
38
class FakeStartProcess:
    """Stand-in for a process factory used by the tests.

    Calling an instance records the target callable together with its
    positional and keyword arguments and hands back ``call_mock``.
    ``run()`` later invokes the recorded callable synchronously and hands
    back ``run_mock``.
    """

    def __init__(self):
        # Objects handed back by run() and __call__() respectively.
        self.run_mock = MagicMock()
        self.call_mock = MagicMock()

        # Filled in by __call__(); None until the instance is called.
        self.func = self.args = self.kwargs = None

    def __call__(self, func, *, args=None, kwargs=None):
        self.func = func
        # Falsy argument containers are normalised to fresh empty ones.
        self.args = args if args else []
        self.kwargs = kwargs if kwargs else {}
        return self.call_mock

    def run(self):
        target = self.func
        target(*self.args, **self.kwargs)
        return self.run_mock

    def __repr__(self):
        template = "<FakeProcess func={} args={} kwargs={}>"
        return template.format(self.func, self.args, self.kwargs)
61
62
63
class Result:
    """Minimal result record used by the tests.

    ``type_`` is stored as ``result_type``. Every other well-known field
    defaults to an empty string. Keyword arguments override those defaults
    (including ``result_type``) or add extra attributes.
    """

    def __init__(self, type_, **kwargs):
        self.result_type = type_
        attributes = dict.fromkeys(
            (
                'host',
                'hostname',
                'name',
                'value',
                'port',
                'test_id',
                'severity',
                'qod',
            ),
            '',
        )
        attributes.update(kwargs)
        for attr_name, attr_value in attributes.items():
            setattr(self, attr_name, attr_value)
76
77
78
class ScanTestCase(unittest.TestCase):
79
    def setUp(self):
        """Create a fresh dummy daemon that uses a FakeDataManager."""
        daemon = DummyWrapper([])
        daemon.scan_collection.datamanager = FakeDataManager()
        self.daemon = daemon
82
83
    def test_get_default_scanner_params(self):
        """<get_scanner_details/> succeeds and lists scanner parameters."""
        stream = FakeStream()
        self.daemon.handle_command('<get_scanner_details />', stream)
        reply = stream.get_response()

        # Successful status, correct root tag, and a scanner_params child.
        self.assertEqual(reply.get('status'), '200')
        self.assertEqual(reply.tag, 'get_scanner_details_response')
        self.assertIsNotNone(reply.find('scanner_params'))
95
96
    def test_get_default_help(self):
        """<help/> succeeds both with the default and the XML format."""
        reply = None
        for command in ('<help />', '<help format="xml" />'):
            stream = FakeStream()
            self.daemon.handle_command(command, stream)
            reply = stream.get_response()
            self.assertEqual(reply.get('status'), '200')

        # Only the XML-formatted reply (the last one) is checked for its tag.
        self.assertEqual(reply.tag, 'help_response')
109
110
    def test_get_default_scanner_version(self):
        """<get_version/> succeeds and includes protocol information."""
        stream = FakeStream()
        self.daemon.handle_command('<get_version />', stream)
        reply = stream.get_response()

        self.assertEqual(reply.get('status'), '200')
        self.assertIsNotNone(reply.find('protocol'))
117
118
    def test_get_vts_no_vt(self):
        """<get_vts/> still returns a <vts> element when no VT was added."""
        stream = FakeStream()
        self.daemon.handle_command('<get_vts />', stream)
        reply = stream.get_response()

        self.assertEqual(reply.get('status'), '200')
        self.assertIsNotNone(reply.find('vts'))
126
127
    def test_get_vt_xml_no_dict(self):
        """A VT whose data is None produces XML without an id attribute."""
        vt_id, vt_data = '1234', None
        vt_xml = self.daemon.get_vt_xml((vt_id, vt_data))
        self.assertFalse(vt_xml.get('id'))
131
132
    def test_get_vts_single_vt(self):
        """A single added VT is reported with its id."""
        stream = FakeStream()
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test')
        self.daemon.handle_command('<get_vts />', stream)
        reply = stream.get_response()

        self.assertEqual(reply.get('status'), '200')

        first_vt = reply.find('vts').find('vt')
        self.assertIsNotNone(first_vt)
        self.assertEqual(first_vt.get('id'), '1.2.3.4')
145
146
    def test_get_vts_still_not_init(self):
        """<get_vts/> answers with status 400 while not initialized."""
        stream = FakeStream()
        self.daemon.initialized = False
        self.daemon.handle_command('<get_vts />', stream)

        self.assertEqual(stream.get_response().get('status'), '400')
153
154
    def test_get_help_still_not_init(self):
        """<help/> keeps working (status 200) while not initialized."""
        stream = FakeStream()
        self.daemon.initialized = False
        self.daemon.handle_command('<help/>', stream)

        self.assertEqual(stream.get_response().get('status'), '200')
161
162 View Code Duplication
    def test_get_vts_filter_positive(self):
        """A 'modification_time >' filter matching the VT returns it."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_modification_time='19000202',
        )
        query = '<get_vts filter="modification_time&gt;19000201"></get_vts>'
        stream = FakeStream()
        self.daemon.handle_command(query, stream)
        reply = stream.get_response()

        self.assertEqual(reply.get('status'), '200')

        first_vt = reply.find('vts').find('vt')
        self.assertIsNotNone(first_vt)
        self.assertEqual(first_vt.get('id'), '1.2.3.4')

        mtime_elements = reply.findall('vts/vt/modification_time')
        self.assertEqual(
            '<modification_time>19000202</modification_time>',
            ET.tostring(mtime_elements[0]).decode('utf-8'),
        )
188
189 View Code Duplication
    def test_get_vts_filter_negative(self):
        """A 'modification_time <' filter matching the VT returns it."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_modification_time='19000202',
        )
        query = '<get_vts filter="modification_time&lt;19000203"></get_vts>'
        stream = FakeStream()
        self.daemon.handle_command(query, stream)
        reply = stream.get_response()

        self.assertEqual(reply.get('status'), '200')

        first_vt = reply.find('vts').find('vt')
        self.assertIsNotNone(first_vt)
        self.assertEqual(first_vt.get('id'), '1.2.3.4')

        mtime_elements = reply.findall('vts/vt/modification_time')
        self.assertEqual(
            '<modification_time>19000202</modification_time>',
            ET.tostring(mtime_elements[0]).decode('utf-8'),
        )
215
216
    def test_get_vts_bad_filter(self):
        """A filter without operator and value is rejected."""
        stream = FakeStream()
        bad_command = '<get_vts filter="modification_time"/>'

        with self.assertRaises(OspdCommandError):
            self.daemon.handle_command(bad_command, stream)

        # The VT cache must still be reported as available afterwards.
        self.assertTrue(self.daemon.vts.is_cache_available)
222
223
    def test_get_vtss_multiple_vts(self):
        """<get_vts/> reports every VT that was added, not just one."""
        added = (
            ('1.2.3.4', 'A vulnerability test'),
            ('1.2.3.5', 'Another vulnerability test'),
            ('123456789', 'Yet another vulnerability test'),
        )
        for vt_id, vt_name in added:
            self.daemon.add_vt(vt_id, vt_name)

        fs = FakeStream()

        self.daemon.handle_command('<get_vts />', fs)
        response = fs.get_response()
        self.assertEqual(response.get('status'), '200')

        vts = response.find('vts')
        self.assertIsNotNone(vts.find('vt'))

        # Checking only for the presence of one <vt> would pass even if
        # just a single VT were returned; verify all added ids are listed.
        returned_ids = {vt.get('id') for vt in vts.findall('vt')}
        expected_ids = {vt_id for vt_id, _ in added}
        self.assertTrue(
            expected_ids.issubset(returned_ids),
            'missing VT ids: %s' % sorted(expected_ids - returned_ids),
        )
236
237
    def test_get_vts_multiple_vts_with_custom(self):
        """Each VT added with custom data yields its own <custom> element."""
        for vt_id, vt_name in (
            ('1.2.3.4', 'A vulnerability test'),
            ('4.3.2.1', 'Another vulnerability test with custom info'),
            ('123456789', 'Yet another vulnerability test'),
        ):
            self.daemon.add_vt(vt_id, vt_name, custom='b')

        stream = FakeStream()
        self.daemon.handle_command('<get_vts />', stream)
        customs = stream.get_response().findall('vts/vt/custom')

        self.assertEqual(3, len(customs))
253
254 View Code Duplication
    def test_get_vts_vts_with_params(self):
        """A VT requested by id includes its params and custom elements."""
        self.daemon.add_vt(
            '1.2.3.4', 'A vulnerability test', vt_params="a", custom="b"
        )
        fs = FakeStream()

        self.daemon.handle_command('<get_vts vt_id="1.2.3.4"></get_vts>', fs)
        response = fs.get_response()

        # The status of the response must be success (i.e. 200)
        self.assertEqual(response.get('status'), '200')

        # The response root element must have the correct name
        self.assertEqual(response.tag, 'get_vts_response')
        # The response must contain a 'vts' element (the old comment wrongly
        # said 'scanner_params')
        self.assertIsNotNone(response.find('vts'))

        # Address the VT by path instead of by positional index so the
        # test does not depend on the order of the reply's children.
        vt = response.find('vts/vt')
        self.assertIsNotNone(vt)

        vt_params = vt.findall('params')
        self.assertEqual(1, len(vt_params))

        custom = vt.findall('custom')
        self.assertEqual(1, len(custom))

        params = response.findall('vts/vt/params/param')
        self.assertEqual(2, len(params))
279
280 View Code Duplication
    def test_get_vts_vts_with_refs(self):
        """A VT requested by id includes params, custom data and refs."""
        self.daemon.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            vt_refs="c",
        )
        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        stream = FakeStream()
        self.daemon.handle_command(query, stream)
        reply = stream.get_response()

        # Reply envelope: success status, correct root tag, a <vts> child.
        self.assertEqual(reply.get('status'), '200')
        self.assertEqual(reply.tag, 'get_vts_response')
        self.assertIsNotNone(reply.find('vts'))

        # First child of the reply's first child.
        first_vt = reply[0][0]
        self.assertEqual(1, len(first_vt.findall('params')))
        self.assertEqual(1, len(first_vt.findall('custom')))

        self.assertEqual(2, len(reply.findall('vts/vt/refs/ref')))
310
311
    def test_get_vts_vts_with_dependencies(self):
        """A VT added with dependencies reports them in its XML."""
        vt_fields = {
            'vt_params': "a",
            'custom': "b",
            'vt_dependencies': "c",
        }
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **vt_fields)

        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        stream = FakeStream()
        self.daemon.handle_command(query, stream)
        dependencies = stream.get_response().findall(
            'vts/vt/dependencies/dependency'
        )

        self.assertEqual(2, len(dependencies))
327
328
    def test_get_vts_vts_with_severities(self):
        """A VT added with severities reports one <severity> element."""
        vt_fields = {
            'vt_params': "a",
            'custom': "b",
            'severities': "c",
        }
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **vt_fields)

        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        stream = FakeStream()
        self.daemon.handle_command(query, stream)
        severities = stream.get_response().findall(
            'vts/vt/severities/severity'
        )

        self.assertEqual(1, len(severities))
343
344
    def test_get_vts_vts_with_detection_qodt(self):
        """Detection info combined with a QoD type yields one <detection>."""
        vt_fields = {
            'vt_params': "a",
            'custom': "b",
            'detection': "c",
            'qod_t': "d",
        }
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **vt_fields)

        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        stream = FakeStream()
        self.daemon.handle_command(query, stream)
        detections = stream.get_response().findall('vts/vt/detection')

        self.assertEqual(1, len(detections))
360
361
    def test_get_vts_vts_with_detection_qodv(self):
        """Detection info combined with a QoD value yields one <detection>."""
        vt_fields = {
            'vt_params': "a",
            'custom': "b",
            'detection': "c",
            'qod_v': "d",
        }
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **vt_fields)

        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        stream = FakeStream()
        self.daemon.handle_command(query, stream)
        detections = stream.get_response().findall('vts/vt/detection')

        self.assertEqual(1, len(detections))
377
378
    def test_get_vts_vts_with_summary(self):
        """A VT added with a summary reports one <summary> element."""
        vt_fields = {
            'vt_params': "a",
            'custom': "b",
            'summary': "c",
        }
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **vt_fields)

        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        stream = FakeStream()
        self.daemon.handle_command(query, stream)
        summaries = stream.get_response().findall('vts/vt/summary')

        self.assertEqual(1, len(summaries))
393
394
    def test_get_vts_vts_with_impact(self):
        """A VT added with impact text reports one <impact> element."""
        vt_fields = {
            'vt_params': "a",
            'custom': "b",
            'impact': "c",
        }
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **vt_fields)

        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        stream = FakeStream()
        self.daemon.handle_command(query, stream)
        impacts = stream.get_response().findall('vts/vt/impact')

        self.assertEqual(1, len(impacts))
409
410
    def test_get_vts_vts_with_affected(self):
        """A VT added with affected info reports one <affected> element."""
        vt_fields = {
            'vt_params': "a",
            'custom': "b",
            'affected': "c",
        }
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **vt_fields)

        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        stream = FakeStream()
        self.daemon.handle_command(query, stream)
        affected = stream.get_response().findall('vts/vt/affected')

        self.assertEqual(1, len(affected))
425
426
    def test_get_vts_vts_with_insight(self):
        """A VT added with insight text reports one <insight> element."""
        vt_fields = {
            'vt_params': "a",
            'custom': "b",
            'insight': "c",
        }
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **vt_fields)

        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        stream = FakeStream()
        self.daemon.handle_command(query, stream)
        insights = stream.get_response().findall('vts/vt/insight')

        self.assertEqual(1, len(insights))
441
442
    def test_get_vts_vts_with_solution(self):
        """A VT added with solution data reports one <solution> element."""
        vt_fields = {
            'vt_params': "a",
            'custom': "b",
            'solution': "c",
            'solution_t': "d",
            'solution_m': "e",
        }
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **vt_fields)

        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        stream = FakeStream()
        self.daemon.handle_command(query, stream)
        solutions = stream.get_response().findall('vts/vt/solution')

        self.assertEqual(1, len(solutions))
459
460
    def test_get_vts_vts_with_ctime(self):
        """The VT creation time is serialised verbatim."""
        vt_fields = {
            'vt_params': "a",
            'vt_creation_time': '01-01-1900',
        }
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **vt_fields)

        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        stream = FakeStream()
        self.daemon.handle_command(query, stream)
        ctime_elements = stream.get_response().findall(
            'vts/vt/creation_time'
        )

        serialised = ET.tostring(ctime_elements[0]).decode('utf-8')
        self.assertEqual(
            '<creation_time>01-01-1900</creation_time>', serialised
        )
477
478
    def test_get_vts_vts_with_mtime(self):
        """The VT modification time is serialised verbatim."""
        vt_fields = {
            'vt_params': "a",
            'vt_modification_time': '02-01-1900',
        }
        self.daemon.add_vt('1.2.3.4', 'A vulnerability test', **vt_fields)

        query = '<get_vts vt_id="1.2.3.4"></get_vts>'
        stream = FakeStream()
        self.daemon.handle_command(query, stream)
        mtime_elements = stream.get_response().findall(
            'vts/vt/modification_time'
        )

        serialised = ET.tostring(mtime_elements[0]).decode('utf-8')
        self.assertEqual(
            '<modification_time>02-01-1900</modification_time>', serialised
        )
495
496
    def test_clean_forgotten_scans(self):
        """A finished scan is purged only once its end_time has expired."""
        fs = FakeStream()

        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, '
            '443"><scanner_params /></start_scan>',
            fs,
        )
        response = fs.get_response()

        scan_id = response.findtext('id')

        self.daemon.start_pending_scans()

        # Poll until the scan reports a non-zero end_time. Fail the test
        # instead of spinning forever if the scan never finishes. (The old
        # loop also issued a second get_scans per iteration whose reply was
        # never used.)
        deadline = time.monotonic() + 60
        while True:
            fs = FakeStream()
            self.daemon.handle_command(
                '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
            )
            response = fs.get_response()

            scans = response.findall('scan')
            self.assertEqual(1, len(scans))

            if scans[0].get('end_time') != '0':
                break

            if time.monotonic() > deadline:
                self.fail('Scan %s did not finish in time' % scan_id)
            time.sleep(0.01)

        self.assertEqual(
            len(list(self.daemon.scan_collection.ids_iterator())), 1
        )

        # Set an old end_time
        self.daemon.scan_collection.scans_table[scan_id]['end_time'] = 123456
        # Run the check: the scan must survive while scaninfo_store_time
        # has not been lowered yet.
        self.daemon.clean_forgotten_scans()
        self.assertEqual(
            len(list(self.daemon.scan_collection.ids_iterator())), 1
        )

        # Lower the store time and run again; now the scan is removed.
        self.daemon.scaninfo_store_time = 1
        self.daemon.clean_forgotten_scans()
        self.assertEqual(
            len(list(self.daemon.scan_collection.ids_iterator())), 0
        )
554
555
    def test_scan_with_error(self):
        """A scan error is reported as a result and the scan can be deleted."""
        fs = FakeStream()

        self.daemon.handle_command(
            '<start_scan target="localhost" ports="80, '
            '443"><scanner_params /></start_scan>',
            fs,
        )

        response = fs.get_response()
        scan_id = response.findtext('id')
        self.daemon.start_pending_scans()
        self.daemon.add_scan_error(
            scan_id, host='a', value='something went wrong'
        )

        get_scans_cmd = '<get_scans scan_id="%s" details="1"/>' % scan_id

        # Poll until the scan leaves the init/running states. Fail instead
        # of hanging if it never does. The reply from the final poll already
        # reflects the finished scan, so no extra request is needed.
        deadline = time.monotonic() + 60
        while True:
            fs = FakeStream()
            self.daemon.handle_command(get_scans_cmd, fs)
            response = fs.get_response()

            scans = response.findall('scan')
            self.assertEqual(1, len(scans))

            scan = scans[0]
            if scan.get('status') not in ("init", "running"):
                break

            # While still in progress the scan must not have an end time.
            self.assertEqual('0', scan.get('end_time'))
            if time.monotonic() > deadline:
                self.fail('Scan %s did not finish in time' % scan_id)
            time.sleep(0.010)

        self.assertEqual(
            response.findtext('scan/results/result'), 'something went wrong'
        )
        fs = FakeStream()
        self.daemon.handle_command('<delete_scan scan_id="%s" />' % scan_id, fs)
        response = fs.get_response()

        self.assertEqual(response.get('status'), '200')
606
607
    def test_get_scan_pop(self):
        """pop_results returns results once; a later pop returns none."""
        start_cmd = (
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /></start_scan>'
        )
        stream = FakeStream()
        self.daemon.handle_command(start_cmd, stream)
        scan_id = stream.get_response().findtext('id')

        self.daemon.add_scan_host_detail(
            scan_id, host='a', value='Some Host Detail'
        )

        time.sleep(1)

        def fetch(command):
            # Run a single command and return the parsed reply.
            out = FakeStream()
            self.daemon.handle_command(command, out)
            return out.get_response()

        plain = fetch('<get_scans scan_id="%s"/>' % scan_id)
        self.assertEqual(
            plain.findtext('scan/results/result'), 'Some Host Detail'
        )

        popped = fetch('<get_scans scan_id="%s" pop_results="1"/>' % scan_id)
        self.assertEqual(
            popped.findtext('scan/results/result'), 'Some Host Detail'
        )

        popped_again = fetch(
            '<get_scans scan_id="%s" details="0" pop_results="1"/>' % scan_id
        )
        self.assertIsNone(popped_again.findtext('scan/results/result'))
649
650
    def test_get_scan_pop_max_res(self):
        """max_results caps how many results a single pop returns."""
        start_cmd = (
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /></start_scan>'
        )
        stream = FakeStream()
        self.daemon.handle_command(start_cmd, stream)
        scan_id = stream.get_response().findtext('id')

        for host in ('a', 'c', 'b'):
            self.daemon.add_scan_log(scan_id, host=host, name=host)

        limited_cmd = (
            '<get_scans scan_id="%s" pop_results="1" max_results="1"/>'
            % scan_id
        )
        stream = FakeStream()
        self.daemon.handle_command(limited_cmd, stream)
        first_batch = stream.get_response().findall('scan/results/result')
        self.assertEqual(len(first_batch), 1)

        rest_cmd = '<get_scans scan_id="%s" pop_results="1"/>' % scan_id
        stream = FakeStream()
        self.daemon.handle_command(rest_cmd, stream)
        second_batch = stream.get_response().findall('scan/results/result')
        self.assertEqual(len(second_batch), 2)
681
682
    def test_billon_laughs(self):
        """An exponential entity-expansion payload is refused."""
        # Build the classic "billion laughs" DTD: lol1 expands to ten
        # &lol; references, lol2 to ten &lol1; references, ... up to lol9.
        declarations = [' <!ENTITY lol "lol">', ' <!ELEMENT lolz (#PCDATA)>']
        for level in range(1, 10):
            previous = '' if level == 1 else str(level - 1)
            reference = '&lol' + previous + ';'
            declarations.append(
                ' <!ENTITY lol%d "%s">' % (level, reference * 10)
            )
        lol = (
            '<?xml version="1.0"?>'
            + '<!DOCTYPE lolz ['
            + ''.join(declarations)
            + ']>'
        )

        stream = FakeStream()
        with self.assertRaises(EntitiesForbidden):
            self.daemon.handle_command(lol, stream)
705
706
    def test_target_with_credentials(self):
        """Target credentials are parsed into a per-service dictionary."""
        start_cmd = (
            '<start_scan>'
            '<scanner_params />'
            '<vts><vt id="1.2.3.4" /></vts>'
            '<targets><target>'
            '<hosts>192.168.0.0/24</hosts>'
            '<ports>22</ports>'
            '<credentials>'
            '<credential type="up" service="ssh" port="22">'
            '<username>scanuser</username>'
            '<password>mypass</password>'
            '</credential>'
            '<credential type="up" service="smb">'
            '<username>smbuser</username>'
            '<password>mypass</password>'
            '</credential>'
            '</credentials>'
            '</target></targets>'
            '</start_scan>'
        )
        stream = FakeStream()
        self.daemon.handle_command(start_cmd, stream)
        reply = stream.get_response()

        self.assertEqual(reply.get('status'), '200')

        expected = {
            'ssh': {
                'type': 'up',
                'password': 'mypass',
                'port': '22',
                'username': 'scanuser',
            },
            'smb': {'type': 'up', 'password': 'mypass', 'username': 'smbuser'},
        }
        scan_id = reply.findtext('id')
        credentials = self.daemon.get_scan_credentials(scan_id)
        self.assertEqual(credentials, expected)
742
743
    def test_scan_get_target(self):
        """get_scans echoes the target hosts string of the scan."""
        start_cmd = (
            '<start_scan>'
            '<scanner_params />'
            '<vts><vt id="1.2.3.4" /></vts>'
            '<targets><target>'
            '<hosts>localhosts,192.168.0.0/24</hosts>'
            '<ports>80,443</ports>'
            '</target></targets>'
            '</start_scan>'
        )
        stream = FakeStream()
        self.daemon.handle_command(start_cmd, stream)
        scan_id = stream.get_response().findtext('id')

        stream = FakeStream()
        self.daemon.handle_command('<get_scans scan_id="%s"/>' % scan_id, stream)
        scan_element = stream.get_response().find('scan')

        self.assertEqual(
            scan_element.get('target'), 'localhosts,192.168.0.0/24'
        )
765
766
    def test_scan_get_target_options(self):
        """Extra target elements such as alive_test become target options."""
        start_cmd = (
            '<start_scan>'
            '<scanner_params />'
            '<vts><vt id="1.2.3.4" /></vts>'
            '<targets>'
            '<target><hosts>192.168.0.1</hosts>'
            '<ports>22</ports><alive_test>0</alive_test></target>'
            '</targets>'
            '</start_scan>'
        )
        stream = FakeStream()
        self.daemon.handle_command(start_cmd, stream)
        scan_id = stream.get_response().findtext('id')

        time.sleep(1)

        options = self.daemon.get_scan_target_options(scan_id)
        self.assertEqual(options, {'alive_test': '0'})
785
786
    def test_progress(self):
        """Target progress is computed from the per-host progress values."""
        start_cmd = (
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>'
        )
        stream = FakeStream()
        self.daemon.handle_command(start_cmd, stream)
        scan_id = stream.get_response().findtext('id')

        for host, value in (('localhost1', 75), ('localhost2', 25)):
            self.daemon.set_scan_host_progress(scan_id, host, value)

        collection = self.daemon.scan_collection
        self.assertEqual(collection.calculate_target_progress(scan_id), 50)
808
809
    def test_sort_host_finished(self):
        """Finished and dead hosts count toward target progress once sorted."""
        fs = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            fs,
        )
        response = fs.get_response()

        scan_id = response.findtext('id')
        self.daemon.set_scan_host_progress(scan_id, 'localhost3', -1)
        self.daemon.set_scan_host_progress(scan_id, 'localhost1', 75)
        self.daemon.set_scan_host_progress(scan_id, 'localhost4', 100)
        self.daemon.set_scan_host_progress(scan_id, 'localhost2', 25)

        self.daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4'])

        # Bind the collection locally so the call fits on one line. This
        # replaces the malformed "pylint: disable=line-too-long)" directive,
        # whose stray ')' made pylint report an unknown option.
        collection = self.daemon.scan_collection
        rounded_progress = collection.calculate_target_progress(scan_id)
        self.assertEqual(rounded_progress, 66)
836
837
    def test_calculate_progress_without_current_hosts(self):
        """Progress is derived from finished hosts when none are running."""
        start_cmd = (
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>'
        )
        stream = FakeStream()
        self.daemon.handle_command(start_cmd, stream)
        scan_id = stream.get_response().findtext('id')

        daemon = self.daemon
        daemon.set_scan_host_progress(scan_id)
        daemon.set_scan_host_progress(scan_id, 'localhost3', -1)
        daemon.set_scan_host_progress(scan_id, 'localhost4', 100)

        daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4'])

        collection = daemon.scan_collection
        float_progress = collection.calculate_target_progress(scan_id)
        self.assertEqual(int(float_progress), 33)

        collection.set_progress(scan_id, float_progress)
        self.assertEqual(daemon.get_scan_progress(scan_id), 33)
867
868
    def test_get_scan_without_scanid(self):
        """get_scans asking for progress without a scan_id is rejected."""
        start_cmd = (
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>'
        )
        self.daemon.handle_command(start_cmd, FakeStream())

        stream = FakeStream()
        with self.assertRaises(OspdCommandError):
            self.daemon.handle_command(
                '<get_scans details="0" progress="1"/>', stream
            )
889
890
    def test_get_scan_progress_xml(self):
        """get_scans with progress="1" returns a populated <progress> element.

        localhost3 gets progress -1 and localhost4 progress 100, and both
        are moved to the finished list. localhost1 and localhost2 then get
        partial progress (75 and 25). The <progress> element in the
        get_scans reply is expected to have:

        - an <overall> value that truncates to 66;
        - count_alive and count_dead of 1 each;
        - two <host> children;
        - count_excluded of 0.
        """
        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            stream,
        )
        scan_id = stream.get_response().findtext('id')

        daemon = self.daemon
        daemon.set_scan_host_progress(scan_id, 'localhost3', -1)
        daemon.set_scan_host_progress(scan_id, 'localhost4', 100)
        daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4'])
        daemon.set_scan_host_progress(scan_id, 'localhost1', 75)
        daemon.set_scan_host_progress(scan_id, 'localhost2', 25)

        stream = FakeStream()
        daemon.handle_command(
            '<get_scans scan_id="%s" details="0" progress="1"/>' % scan_id,
            stream,
        )
        progress = stream.get_response().find('scan/progress')

        self.assertEqual(int(float(progress.findtext('overall'))), 66)
        self.assertEqual(progress.findtext('count_alive'), '1')
        self.assertEqual(progress.findtext('count_dead'), '1')
        self.assertEqual(len(progress.findall('host')), 2)
        self.assertEqual(progress.findtext('count_excluded'), '0')
935
936
    def test_set_get_vts_version(self):
        """get_vts_version returns the value given to set_vts_version."""
        self.daemon.set_vts_version('1234')
        self.assertEqual('1234', self.daemon.get_vts_version())
941
942
    def test_set_get_vts_version_error(self):
        """Calling set_vts_version without a version raises TypeError."""
        with self.assertRaises(TypeError):
            self.daemon.set_vts_version()
944
945
    @patch("ospd.ospd.os")
    @patch("ospd.ospd.create_process")
    def test_scan_exists(self, mock_create_process, _mock_os):
        """A start_scan reusing a known scan_id is answered with 'Continue'.

        create_process is replaced by a FakeStartProcess. The mock process
        it hands out runs the scan function synchronously on start() and
        always reports itself alive. The scan is started and then stopped.
        A second start_scan carrying the same scan_id must be answered with
        status_text 'Continue'.
        """
        fake_start = FakeStartProcess()
        mock_create_process.side_effect = fake_start
        scan_process = fake_start.call_mock
        scan_process.start.side_effect = fake_start.run
        scan_process.is_alive.return_value = True
        scan_process.pid = "main-scan-process"

        # Everything after the opening <start_scan ...> tag; both commands
        # below share it.
        target_xml = (
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>'
        )

        stream = FakeStream()
        self.daemon.handle_command('<start_scan>' + target_xml, stream)
        response = stream.get_response()
        scan_id = response.findtext('id')
        self.assertIsNotNone(scan_id)
        self.assertEqual(response.get('status_text'), 'OK')

        self.daemon.start_pending_scans()

        assert_called(mock_create_process)
        assert_called(scan_process.start)

        # The stop_scan reply goes to the stream that already holds the
        # start_scan reply; it is never inspected.
        stop_cmd = '<stop_scan scan_id="%s" />' % scan_id
        self.daemon.handle_command(stop_cmd, stream)

        stream = FakeStream()
        self.daemon.handle_command(
            '<start_scan scan_id="' + scan_id + '">' + target_xml, stream,
        )
        self.assertEqual(stream.get_response().get('status_text'), 'Continue')
997
998
    def test_result_order(self):
        """Results come back from get_scans in the order they were added.

        Log results are added through add_scan_log for hosts a, c and b, in
        that order. get_scans with details="1" must return exactly these
        three results, in insertion order. Results are compared by their
        ``name`` attribute.
        """
        fs = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="1">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>a</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            fs,
        )

        response = fs.get_response()

        scan_id = response.findtext('id')

        hosts = ['a', 'c', 'b']
        for host in hosts:
            self.daemon.add_scan_log(scan_id, host=host, name=host)

        fs = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
        )
        response = fs.get_response()

        # ElementTree treats a trailing '/' as '/*', so this selects every
        # child element of <results>.
        results = response.findall("scan/results/")

        # Compare the whole list instead of looping over whatever came back:
        # a loop over the returned results passes silently when results are
        # missing, including when none are returned at all.
        self.assertEqual([res.attrib['name'] for res in results], hosts)
1032
1033
    def test_batch_result(self):
        """Results attached as a batch keep their insertion order.

        A ResultList holding log results for hosts a, c and b is attached
        to the scan through scan_collection.add_result_list. get_scans with
        details="1" must return exactly these three results, in that order.
        Results are compared by their ``name`` attribute.
        """
        reslist = ResultList()
        fs = FakeStream()
        self.daemon.handle_command(
            '<start_scan parallel="1">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>a</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>',
            fs,
        )

        response = fs.get_response()

        scan_id = response.findtext('id')

        hosts = ['a', 'c', 'b']
        for host in hosts:
            reslist.add_scan_log_to_list(host=host, name=host)
        self.daemon.scan_collection.add_result_list(scan_id, reslist)

        fs = FakeStream()
        self.daemon.handle_command(
            '<get_scans scan_id="%s" details="1"/>' % scan_id, fs
        )
        response = fs.get_response()

        # ElementTree treats a trailing '/' as '/*', so this selects every
        # child element of <results>.
        results = response.findall("scan/results/")

        # Compare the whole list instead of looping over whatever came back:
        # a loop over the returned results passes silently when results are
        # missing, including when none are returned at all.
        self.assertEqual([res.attrib['name'] for res in results], hosts)
1068