Passed
Pull Request — master (#131)
by Juan José
02:17
created

ScanTestCase.test_scan_multi_target()   A

Complexity

Conditions 1

Size

Total Lines 17
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 6
nop 1
dl 0
loc 17
rs 10
c 0
b 0
f 0
1
# Copyright (C) 2015-2018 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: GPL-2.0-or-later
4
#
5
# This program is free software; you can redistribute it and/or
6
# modify it under the terms of the GNU General Public License
7
# as published by the Free Software Foundation; either version 2
8
# of the License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU General Public License for more details.
14
#
15
# You should have received a copy of the GNU General Public License
16
# along with this program; if not, write to the Free Software
17
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18
19
# pylint: disable=too-many-lines
20
21
""" Test module for scan runs
22
"""
23
24
import time
25
import unittest
26
27
from unittest.mock import patch
28
29
import xml.etree.ElementTree as ET
30
import defusedxml.lxml as secET
31
32
from defusedxml.common import EntitiesForbidden
33
34
from ospd.ospd import OSPDaemon
35
from ospd.errors import OspdCommandError
36
37
38
class Result:
    """Plain record describing one canned scan result.

    All reportable fields start out as empty strings; every keyword
    argument then overrides (or adds) the attribute of the same name.

    :param type_: result category, stored as ``result_type``.
    """

    def __init__(self, type_, **kwargs):
        self.result_type = type_
        fields = dict.fromkeys(
            (
                'host',
                'hostname',
                'name',
                'value',
                'port',
                'test_id',
                'severity',
                'qod',
            ),
            '',
        )
        # Caller-supplied values win over the empty defaults.
        fields.update(kwargs)
        for attr, val in fields.items():
            setattr(self, attr, val)
51
52
53
class DummyWrapper(OSPDaemon):
    """Test double for ``OSPDaemon``.

    The VT detail getters ignore their arguments and return fixed XML
    snippets, and ``exec_scan`` reports the ``Result`` objects given to the
    constructor instead of running a real scan.
    """

    def __init__(self, results, checkresult=True):
        """
        :param results: iterable of ``Result`` objects replayed by
            ``exec_scan``.
        :param checkresult: value returned by ``check()``.
        """
        super().__init__()
        self.checkresult = checkresult
        self.results = results

    def check(self):
        """Return the configured ``checkresult``."""
        return self.checkresult

    @staticmethod
    def get_custom_vt_as_xml_str(vt_id, custom):
        """Return a fixed <custom> element."""
        return '<custom><mytest>static test</mytest></custom>'

    @staticmethod
    def get_params_vt_as_xml_str(vt_id, vt_params):
        """Return a fixed <params> element holding two <param> children."""
        return (
            '<params><param id="abc" type="string">'
            '<name>ABC</name><description>Test ABC</description>'
            '<default>yes</default></param>'
            '<param id="def" type="string">'
            '<name>DEF</name><description>Test DEF</description>'
            '<default>no</default></param></params>'
        )

    @staticmethod
    def get_refs_vt_as_xml_str(vt_id, vt_refs):
        """Return a fixed <refs> element holding two <ref> children."""
        response = (
            '<refs><ref type="cve" id="CVE-2010-4480"/>'
            '<ref type="url" id="http://example.com"/></refs>'
        )
        return response

    @staticmethod
    def get_dependencies_vt_as_xml_str(vt_id, vt_dependencies):
        """Return a fixed <dependencies> element with two dependencies."""
        response = (
            '<dependencies>'
            '<dependency vt_id="1.3.6.1.4.1.25623.1.0.50282" />'
            '<dependency vt_id="1.3.6.1.4.1.25623.1.0.50283" />'
            '</dependencies>'
        )

        return response

    @staticmethod
    def get_severities_vt_as_xml_str(vt_id, severities):
        """Return a fixed <severities> element with one <severity>."""
        response = (
            '<severities><severity cvss_base="5.0" cvss_'
            'type="cvss_base_v2">AV:N/AC:L/Au:N/C:N/I:N/'
            'A:P</severity></severities>'
        )

        return response

    @staticmethod
    def get_detection_vt_as_xml_str(
        vt_id, detection=None, qod_type=None, qod=None
    ):
        """Return a fixed <detection> element."""
        response = '<detection qod_type="package">some detection</detection>'

        return response

    @staticmethod
    def get_summary_vt_as_xml_str(vt_id, summary):
        """Return a fixed <summary> element."""
        response = '<summary>Some summary</summary>'

        return response

    @staticmethod
    def get_affected_vt_as_xml_str(vt_id, affected):
        """Return a fixed <affected> element."""
        response = '<affected>Some affected</affected>'

        return response

    @staticmethod
    def get_impact_vt_as_xml_str(vt_id, impact):
        """Return a fixed <impact> element."""
        response = '<impact>Some impact</impact>'

        return response

    @staticmethod
    def get_insight_vt_as_xml_str(vt_id, insight):
        """Return a fixed <insight> element."""
        response = '<insight>Some insight</insight>'

        return response

    @staticmethod
    def get_solution_vt_as_xml_str(vt_id, solution, solution_type=None):
        """Return a fixed <solution> element."""
        response = '<solution>Some solution</solution>'

        return response

    @staticmethod
    def get_creation_time_vt_as_xml_str(
        vt_id, creation_time
    ):  # pylint: disable=arguments-differ
        """Wrap ``creation_time`` verbatim in a <creation_time> element."""
        response = '<creation_time>%s</creation_time>' % creation_time

        return response

    @staticmethod
    def get_modification_time_vt_as_xml_str(
        vt_id, modification_time
    ):  # pylint: disable=arguments-differ
        """Wrap ``modification_time`` verbatim in <modification_time>."""
        response = (
            '<modification_time>%s</modification_time>' % modification_time
        )

        return response

    def exec_scan(self, scan_id, target):
        """Report every canned result for ``scan_id``.

        Each result's own ``host`` is used when set, otherwise ``target``.

        :raises ValueError: for a ``result_type`` other than 'log', 'error',
            'host-detail' or 'alarm'.
        """
        time.sleep(0.01)
        for res in self.results:
            # A single if/elif chain: 'log' used to be tested in a separate
            # ``if``, so a log result then fell through to the final
            # ``else`` below and raised ValueError after being recorded.
            if res.result_type == 'log':
                self.add_scan_log(
                    scan_id,
                    res.host or target,
                    res.hostname,
                    res.name,
                    res.value,
                    res.port,
                )
            elif res.result_type == 'error':
                self.add_scan_error(
                    scan_id,
                    res.host or target,
                    res.hostname,
                    res.name,
                    res.value,
                    res.port,
                )
            elif res.result_type == 'host-detail':
                self.add_scan_host_detail(
                    scan_id,
                    res.host or target,
                    res.hostname,
                    res.name,
                    res.value,
                )
            elif res.result_type == 'alarm':
                self.add_scan_alarm(
                    scan_id,
                    res.host or target,
                    res.hostname,
                    res.name,
                    res.value,
                    res.port,
                    res.test_id,
                    res.severity,
                    res.qod,
                )
            else:
                raise ValueError(res.result_type)
205
206
207
class ScanTestCase(unittest.TestCase):
208
    def test_get_default_scanner_params(self):
        """<get_scanner_details/> succeeds and lists scanner params."""
        wrapper = DummyWrapper([])
        raw_reply = wrapper.handle_command('<get_scanner_details />')
        reply = secET.fromstring(raw_reply)

        # A successful command carries status "200".
        self.assertEqual(reply.get('status'), '200')
        # The root element is named after the command.
        self.assertEqual(reply.tag, 'get_scanner_details_response')
        # A <scanner_params> child must be present.
        self.assertIsNotNone(reply.find('scanner_params'))
220
221
    def test_get_default_help(self):
        """<help/> succeeds in the default format and in XML format."""
        wrapper = DummyWrapper([])

        default_reply = secET.fromstring(wrapper.handle_command('<help />'))
        self.assertEqual(default_reply.get('status'), '200')

        xml_reply = secET.fromstring(
            wrapper.handle_command('<help format="xml" />')
        )
        self.assertEqual(xml_reply.get('status'), '200')
        self.assertEqual(xml_reply.tag, 'help_response')
233
234
    @patch('ospd.ospd.subprocess')
    def test_get_performance(self, mock_subproc):
        """A valid <get_performance/> request yields a 200 response.

        ``subprocess`` is patched inside ``ospd.ospd`` so no external
        command is executed.
        """
        daemon = DummyWrapper([])
        # Assign the stub's result. Calling ``return_value("foo")`` only
        # invokes the mock and leaves the output unconfigured. Without a
        # text mode, ``check_output`` returns bytes, so the stub does too.
        mock_subproc.check_output.return_value = b'foo'
        response = secET.fromstring(
            daemon.handle_command(
                '<get_performance start="0" end="0" titles="mem"/>'
            )
        )

        self.assertEqual(response.get('status'), '200')
        self.assertEqual(response.tag, 'get_performance_response')
245
246
    def test_get_performance_fail_int(self):
        """A non-integer ``start`` attribute raises OspdCommandError."""
        wrapper = DummyWrapper([])
        request = secET.fromstring(
            '<get_performance start="a" end="0" titles="mem"/>'
        )

        with self.assertRaises(OspdCommandError):
            wrapper.handle_get_performance(request)
254
255
    def test_get_performance_fail_regex(self):
        """A ``titles`` value containing '|' raises OspdCommandError."""
        wrapper = DummyWrapper([])
        request = secET.fromstring(
            '<get_performance start="0" end="0" titles="mem|bar"/>'
        )

        with self.assertRaises(OspdCommandError):
            wrapper.handle_get_performance(request)
263
264
    def test_get_performance_fail_cmd(self):
        """Requesting the title "mem1" raises OspdCommandError."""
        wrapper = DummyWrapper([])
        request = secET.fromstring(
            '<get_performance start="0" end="0" titles="mem1"/>'
        )

        with self.assertRaises(OspdCommandError):
            wrapper.handle_get_performance(request)
272
273
    def test_get_default_scanner_version(self):
        """<get_version/> succeeds and reports a <protocol> element."""
        wrapper = DummyWrapper([])
        raw_reply = wrapper.handle_command('<get_version />')
        reply = secET.fromstring(raw_reply)

        self.assertEqual(reply.get('status'), '200')
        self.assertIsNotNone(reply.find('protocol'))
279
280
    def test_get_vts_no_vt(self):
        """<get_vts/> on an empty daemon still returns a <vts> element."""
        wrapper = DummyWrapper([])
        raw_reply = wrapper.handle_command('<get_vts />')
        reply = secET.fromstring(raw_reply)

        self.assertEqual(reply.get('status'), '200')
        self.assertIsNotNone(reply.find('vts'))
286
287
    def test_get_vts_single_vt(self):
        """A single registered VT is returned with its id."""
        wrapper = DummyWrapper([])
        wrapper.add_vt('1.2.3.4', 'A vulnerability test')
        reply = secET.fromstring(wrapper.handle_command('<get_vts />'))

        self.assertEqual(reply.get('status'), '200')

        vts_elem = reply.find('vts')
        self.assertIsNotNone(vts_elem.find('vt'))

        first_vt = vts_elem.find('vt')
        self.assertEqual(first_vt.get('id'), '1.2.3.4')
299
300 View Code Duplication
    def test_get_vts_filter_positive(self):
        """Filter "modification_time > 19000201" keeps a VT from 19000202."""
        wrapper = DummyWrapper([])
        wrapper.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_modification_time='19000202',
        )

        request = '<get_vts filter="modification_time&gt;19000201"></get_vts>'
        reply = secET.fromstring(wrapper.handle_command(request))

        self.assertEqual(reply.get('status'), '200')

        matched_vt = reply.find('vts').find('vt')
        self.assertIsNotNone(matched_vt)
        self.assertEqual(matched_vt.get('id'), '1.2.3.4')

        mtimes = reply.findall('vts/vt/modification_time')
        serialized = ET.tostring(mtimes[0]).decode('utf-8')
        self.assertEqual(
            '<modification_time>19000202</modification_time>', serialized
        )
327
328 View Code Duplication
    def test_get_vts_filter_negative(self):
        """Filter "modification_time < 19000203" keeps a VT from 19000202."""
        wrapper = DummyWrapper([])
        wrapper.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_modification_time='19000202',
        )

        request = '<get_vts filter="modification_time&lt;19000203"></get_vts>'
        reply = secET.fromstring(wrapper.handle_command(request))
        self.assertEqual(reply.get('status'), '200')

        matched_vt = reply.find('vts').find('vt')
        self.assertIsNotNone(matched_vt)
        self.assertEqual(matched_vt.get('id'), '1.2.3.4')

        mtimes = reply.findall('vts/vt/modification_time')
        serialized = ET.tostring(mtimes[0]).decode('utf-8')
        self.assertEqual(
            '<modification_time>19000202</modification_time>', serialized
        )
355
356
    def test_get_vtss_multiple_vts(self):
        """<get_vts/> returns every registered VT.

        Three VTs are registered. The response must contain exactly those
        three ids, not merely a first <vt> element.
        """
        daemon = DummyWrapper([])
        daemon.add_vt('1.2.3.4', 'A vulnerability test')
        daemon.add_vt('1.2.3.5', 'Another vulnerability test')
        daemon.add_vt('123456789', 'Yet another vulnerability test')

        response = secET.fromstring(daemon.handle_command('<get_vts />'))

        self.assertEqual(response.get('status'), '200')

        vts = response.find('vts')
        self.assertIsNotNone(vts.find('vt'))

        # Compare as a set: this test does not pin the ordering of VTs.
        returned_ids = {vt.get('id') for vt in vts.findall('vt')}
        self.assertEqual(3, len(vts.findall('vt')))
        self.assertEqual({'1.2.3.4', '1.2.3.5', '123456789'}, returned_ids)
368
369
    def test_get_vts_multiple_vts_with_custom(self):
        """Each of three VTs with custom data yields its own <custom>."""
        wrapper = DummyWrapper([])
        for vt_id, vt_name in (
            ('1.2.3.4', 'A vulnerability test'),
            ('4.3.2.1', 'Another vulnerability test with custom info'),
            ('123456789', 'Yet another vulnerability test'),
        ):
            wrapper.add_vt(vt_id, vt_name, custom='b')

        reply = secET.fromstring(wrapper.handle_command('<get_vts />'))
        custom_elems = reply.findall('vts/vt/custom')

        self.assertEqual(3, len(custom_elems))
381
382 View Code Duplication
    def test_get_vts_vts_with_params(self):
        """<get_vts vt_id=...> includes the VT's params and custom data."""
        daemon = DummyWrapper([])
        daemon.add_vt(
            '1.2.3.4', 'A vulnerability test', vt_params="a", custom="b"
        )

        response = secET.fromstring(
            daemon.handle_command('<get_vts vt_id="1.2.3.4"></get_vts>')
        )
        # The status of the response must be success (i.e. 200)
        self.assertEqual(response.get('status'), '200')

        # The response root element must have the correct name
        self.assertEqual(response.tag, 'get_vts_response')
        # The response must contain a 'vts' element
        self.assertIsNotNone(response.find('vts'))

        # response[0][0] is expected to be the requested <vt> (first child
        # of <vts>); it must carry exactly one <params> and one <custom>.
        vt_params = response[0][0].findall('params')
        self.assertEqual(1, len(vt_params))

        custom = response[0][0].findall('custom')
        self.assertEqual(1, len(custom))

        # DummyWrapper.get_params_vt_as_xml_str always yields two <param>s.
        params = response.findall('vts/vt/params/param')
        self.assertEqual(2, len(params))
407
408 View Code Duplication
    def test_get_vts_vts_with_refs(self):
        """A VT registered with refs exposes params, custom and two refs."""
        wrapper = DummyWrapper([])
        wrapper.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            vt_refs="c",
        )

        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        reply = secET.fromstring(wrapper.handle_command(request))

        # Success status and correctly named root element.
        self.assertEqual(reply.get('status'), '200')
        self.assertEqual(reply.tag, 'get_vts_response')

        # A <vts> container must be present.
        self.assertIsNotNone(reply.find('vts'))

        vt_elem = reply[0][0]
        self.assertEqual(1, len(vt_elem.findall('params')))
        self.assertEqual(1, len(vt_elem.findall('custom')))

        # The dummy refs getter emits exactly two <ref> entries.
        self.assertEqual(2, len(reply.findall('vts/vt/refs/ref')))
438
439
    def test_get_vts_vts_with_dependencies(self):
        """A VT with dependencies lists both dummy <dependency> entries."""
        wrapper = DummyWrapper([])
        wrapper.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            vt_dependencies="c",
        )

        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        reply = secET.fromstring(wrapper.handle_command(request))

        dependency_elems = reply.findall('vts/vt/dependencies/dependency')
        self.assertEqual(len(dependency_elems), 2)
455
456
    def test_get_vts_vts_with_severities(self):
        """A VT with severities exposes exactly one <severity> entry."""
        wrapper = DummyWrapper([])
        wrapper.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            severities="c",
        )

        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        reply = secET.fromstring(wrapper.handle_command(request))

        severity_elems = reply.findall('vts/vt/severities/severity')
        self.assertEqual(len(severity_elems), 1)
472
473
    def test_get_vts_vts_with_detection_qodt(self):
        """detection plus qod_t yields exactly one <detection> element."""
        wrapper = DummyWrapper([])
        wrapper.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            detection="c",
            qod_t="d",
        )

        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        reply = secET.fromstring(wrapper.handle_command(request))

        detection_elems = reply.findall('vts/vt/detection')
        self.assertEqual(len(detection_elems), 1)
490
491
    def test_get_vts_vts_with_detection_qodv(self):
        """detection plus qod_v yields exactly one <detection> element."""
        wrapper = DummyWrapper([])
        wrapper.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            detection="c",
            qod_v="d",
        )

        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        reply = secET.fromstring(wrapper.handle_command(request))

        detection_elems = reply.findall('vts/vt/detection')
        self.assertEqual(len(detection_elems), 1)
508
509
    def test_get_vts_vts_with_summary(self):
        """A VT with a summary exposes exactly one <summary> element."""
        wrapper = DummyWrapper([])
        wrapper.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            summary="c",
        )

        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        reply = secET.fromstring(wrapper.handle_command(request))

        self.assertEqual(len(reply.findall('vts/vt/summary')), 1)
525
526
    def test_get_vts_vts_with_impact(self):
        """A VT with an impact text exposes exactly one <impact> element."""
        wrapper = DummyWrapper([])
        wrapper.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            impact="c",
        )

        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        reply = secET.fromstring(wrapper.handle_command(request))

        self.assertEqual(len(reply.findall('vts/vt/impact')), 1)
542
543
    def test_get_vts_vts_with_affected(self):
        """A VT with affected info exposes exactly one <affected> element."""
        wrapper = DummyWrapper([])
        wrapper.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            affected="c",
        )

        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        reply = secET.fromstring(wrapper.handle_command(request))

        self.assertEqual(len(reply.findall('vts/vt/affected')), 1)
559
560
    def test_get_vts_vts_with_insight(self):
        """A VT with insight text exposes exactly one <insight> element."""
        wrapper = DummyWrapper([])
        wrapper.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            insight="c",
        )

        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        reply = secET.fromstring(wrapper.handle_command(request))

        self.assertEqual(len(reply.findall('vts/vt/insight')), 1)
576
577
    def test_get_vts_vts_with_solution(self):
        """solution plus solution_t yields exactly one <solution> element."""
        wrapper = DummyWrapper([])
        wrapper.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            custom="b",
            solution="c",
            solution_t="d",
        )

        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        reply = secET.fromstring(wrapper.handle_command(request))

        self.assertEqual(len(reply.findall('vts/vt/solution')), 1)
594
595
    def test_get_vts_vts_with_ctime(self):
        """The VT creation time is serialized verbatim."""
        wrapper = DummyWrapper([])
        wrapper.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_creation_time='01-01-1900',
        )

        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        reply = secET.fromstring(wrapper.handle_command(request))

        ctimes = reply.findall('vts/vt/creation_time')
        serialized = ET.tostring(ctimes[0]).decode('utf-8')
        self.assertEqual('<creation_time>01-01-1900</creation_time>', serialized)
613
614
    def test_get_vts_vts_with_mtime(self):
        """The VT modification time is serialized verbatim."""
        wrapper = DummyWrapper([])
        wrapper.add_vt(
            '1.2.3.4',
            'A vulnerability test',
            vt_params="a",
            vt_modification_time='02-01-1900',
        )

        request = '<get_vts vt_id="1.2.3.4"></get_vts>'
        reply = secET.fromstring(wrapper.handle_command(request))

        mtimes = reply.findall('vts/vt/modification_time')
        serialized = ET.tostring(mtimes[0]).decode('utf-8')
        self.assertEqual(
            '<modification_time>02-01-1900</modification_time>', serialized
        )
632
633
    def test_scan_with_error(self):
        """An 'error' result shows up in the scan results.

        Starts a scan whose only canned result is an error and polls
        <get_scans> until the scan is no longer 'init' or 'running'. It then
        checks the reported error text and deletes the scan.
        """
        daemon = DummyWrapper([Result('error', value='something went wrong')])

        response = secET.fromstring(
            daemon.handle_command(
                '<start_scan target="localhost" ports="80, '
                '443"><scanner_params /></start_scan>'
            )
        )
        scan_id = response.findtext('id')

        finished = False
        while not finished:
            response = secET.fromstring(
                daemon.handle_command(
                    '<get_scans scan_id="%s" details="1"/>' % scan_id
                )
            )
            scans = response.findall('scan')
            self.assertEqual(1, len(scans))

            scan = scans[0]
            status = scan.get('status')

            if status in ('init', 'running'):
                # A scan still in progress must not report an end time yet.
                self.assertEqual('0', scan.get('end_time'))
                time.sleep(0.010)
            else:
                finished = True

        response = secET.fromstring(
            daemon.handle_command(
                '<get_scans scan_id="%s" details="1"/>' % scan_id
            )
        )

        self.assertEqual(
            response.findtext('scan/results/result'), 'something went wrong'
        )

        response = secET.fromstring(
            daemon.handle_command('<delete_scan scan_id="%s" />' % scan_id)
        )

        self.assertEqual(response.get('status'), '200')
678
679
    def test_get_scan_pop(self):
        """pop_results returns results once and removes them afterwards.

        A plain <get_scans> and a first pop both show the host detail. A
        later pop over all scans must find no result left.
        """
        daemon = DummyWrapper([Result('host-detail', value='Some Host Detail')])

        response = secET.fromstring(
            daemon.handle_command(
                '<start_scan target="localhost" ports="80, 443">'
                '<scanner_params /></start_scan>'
            )
        )

        scan_id = response.findtext('id')
        # Give the dummy scan time to record its result.
        time.sleep(1)

        response = secET.fromstring(
            daemon.handle_command('<get_scans scan_id="%s"/>' % scan_id)
        )
        self.assertEqual(
            response.findtext('scan/results/result'), 'Some Host Detail'
        )

        response = secET.fromstring(
            daemon.handle_command(
                '<get_scans scan_id="%s" pop_results="1"/>' % scan_id
            )
        )
        self.assertEqual(
            response.findtext('scan/results/result'), 'Some Host Detail'
        )

        response = secET.fromstring(
            daemon.handle_command('<get_scans details="0" pop_results="1"/>')
        )
        self.assertIsNone(response.findtext('scan/results/result'))
712
713
    def test_stop_scan(self):
        """<stop_scan> raises for a scan that has already ended.

        It also raises when no scan_id is given.
        """
        wrapper = DummyWrapper([])
        start_reply = secET.fromstring(
            wrapper.handle_command(
                '<start_scan target="localhost" ports="80, 443">'
                '<scanner_params /></start_scan>'
            )
        )
        scan_id = start_reply.findtext('id')

        # Give the (very short) dummy scan time to end before stopping it.
        # On a slow system the scan may still be running when <stop_scan>
        # arrives, which makes this test timing-dependent.
        time.sleep(3)

        stop_known = secET.fromstring('<stop_scan scan_id="%s" />' % scan_id)
        with self.assertRaises(OspdCommandError):
            wrapper.handle_stop_scan_command(stop_known)

        stop_without_id = secET.fromstring('<stop_scan />')
        with self.assertRaises(OspdCommandError):
            wrapper.handle_stop_scan_command(stop_without_id)
738
739 View Code Duplication
    def test_scan_with_vts(self):
        """VT selection is stored per scan.

        An empty <vt_selection> is rejected, a single selected VT is
        recorded, and a scan without any selection has no VTs.
        """
        wrapper = DummyWrapper([])
        head = (
            '<start_scan target="localhost" ports="80, 443"><scanner_params />'
        )

        # An empty <vt_selection> must be rejected.
        empty_selection = secET.fromstring(
            head + '<vt_selection /></start_scan>'
        )
        with self.assertRaises(OspdCommandError):
            wrapper.handle_start_scan_command(empty_selection)

        # One VT, no params.
        one_vt_cmd = (
            head + '<vt_selection><vt_single id="1.2.3.4" />'
            '</vt_selection></start_scan>'
        )
        scan_id = secET.fromstring(wrapper.handle_command(one_vt_cmd)).findtext(
            'id'
        )
        time.sleep(0.01)

        self.assertEqual(
            wrapper.get_scan_vts(scan_id), {'1.2.3.4': {}, 'vt_groups': []}
        )
        self.assertNotEqual(wrapper.get_scan_vts(scan_id), {'1.2.3.6': {}})

        # Without any VT selection.
        no_vts_cmd = head + '</start_scan>'
        scan_id = secET.fromstring(wrapper.handle_command(no_vts_cmd)).findtext(
            'id'
        )
        time.sleep(0.01)
        self.assertEqual(wrapper.get_scan_vts(scan_id), {})
781
782 View Code Duplication
    def test_scan_with_vts_and_param(self):
        """VT values and VT groups are validated and stored per scan."""
        wrapper = DummyWrapper([])
        head = (
            '<start_scan target="localhost" ports="80, 443">'
            '<scanner_params /><vt_selection>'
        )
        tail = '</vt_selection></start_scan>'

        # A <vt_value> without an id attribute is rejected.
        missing_value_id = secET.fromstring(
            head + '<vt_single id="1234"><vt_value>200</vt_value>'
            '</vt_single>' + tail
        )
        with self.assertRaises(OspdCommandError):
            wrapper.handle_start_scan_command(missing_value_id)

        # A well-formed VT value is accepted and stored.
        valued_cmd = (
            head + '<vt_single id="1234"><vt_value id="ABC">200</vt_value>'
            '</vt_single>' + tail
        )
        scan_id = secET.fromstring(wrapper.handle_command(valued_cmd)).findtext(
            'id'
        )
        time.sleep(0.01)
        self.assertEqual(
            wrapper.get_scan_vts(scan_id),
            {'1234': {'ABC': '200'}, 'vt_groups': []},
        )

        # A <vt_group> without a filter attribute is rejected.
        missing_filter = secET.fromstring(head + '<vt_group/>' + tail)
        with self.assertRaises(OspdCommandError):
            wrapper.handle_start_scan_command(missing_filter)

        # A filtered VT group is accepted and stored.
        group_cmd = head + '<vt_group filter="a"/>' + tail
        scan_id = secET.fromstring(wrapper.handle_command(group_cmd)).findtext(
            'id'
        )
        time.sleep(0.01)
        self.assertEqual(wrapper.get_scan_vts(scan_id), {'vt_groups': ['a']})
839
840
    def test_billon_laughs(self):
        """A "billion laughs" entity-expansion payload is refused.

        The command is expected to raise ``EntitiesForbidden`` before any
        entity is expanded.
        """
        # pylint: disable=line-too-long
        wrapper = DummyWrapper([])
        payload = (
            '<?xml version="1.0"?>'
            '<!DOCTYPE lolz ['
            ' <!ENTITY lol "lol">'
            ' <!ELEMENT lolz (#PCDATA)>'
            ' <!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">'
            ' <!ENTITY lol2 "&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;">'
            ' <!ENTITY lol3 "&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;">'
            ' <!ENTITY lol4 "&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;">'
            ' <!ENTITY lol5 "&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;">'
            ' <!ENTITY lol6 "&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;">'
            ' <!ENTITY lol7 "&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;">'
            ' <!ENTITY lol8 "&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;">'
            ' <!ENTITY lol9 "&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;">'
            ']>'
        )
        with self.assertRaises(EntitiesForbidden):
            wrapper.handle_command(payload)
860
861
    def test_scan_multi_target(self):
        """A <start_scan> with two <target> entries is accepted."""
        wrapper = DummyWrapper([])
        request = (
            '<start_scan>'
            '<scanner_params />'
            '<vts><vt id="1.2.3.4" /></vts>'
            '<targets>'
            '<target><hosts>localhosts</hosts><ports>80,443</ports></target>'
            '<target><hosts>192.168.0.0/24</hosts><ports>22</ports></target>'
            '</targets>'
            '</start_scan>'
        )
        reply = secET.fromstring(wrapper.handle_command(request))

        self.assertEqual(reply.get('status'), '200')
878
879
    def test_multi_target_with_credentials(self):
        """Credentials attached to one target are stored per service.

        The second target carries an SSH and an SMB credential. Both must be
        returned by ``get_scan_credentials`` for that target's host string.
        """
        daemon = DummyWrapper([])
        start_cmd = (
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target><hosts>localhosts</hosts>'
            '<ports>80,443</ports></target><target>'
            '<hosts>192.168.0.0/24</hosts><ports>22'
            '</ports><credentials>'
            '<credential type="up" service="ssh" port="22">'
            '<username>scanuser</username>'
            '<password>mypass</password>'
            '</credential><credential type="up" service="smb">'
            '<username>smbuser</username>'
            '<password>mypass</password></credential>'
            '</credentials>'
            '</target></targets>'
            '</start_scan>'
        )
        reply = secET.fromstring(daemon.handle_command(start_cmd))
        self.assertEqual(reply.get('status'), '200')

        expected_creds = {
            'ssh': {
                'type': 'up',
                'password': 'mypass',
                'port': '22',
                'username': 'scanuser',
            },
            'smb': {'type': 'up', 'password': 'mypass', 'username': 'smbuser'},
        }
        stored_creds = daemon.get_scan_credentials(
            reply.findtext('id'), "192.168.0.0/24"
        )
        self.assertEqual(stored_creds, expected_creds)
916
917
    def test_scan_get_target(self):
        """get_scans reports every target's hosts, joined by a comma."""
        daemon = DummyWrapper([])
        start_cmd = (
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>localhosts</hosts>'
            '<ports>80,443</ports>'
            '</target>'
            '<target><hosts>192.168.0.0/24</hosts>'
            '<ports>22</ports></target></targets>'
            '</start_scan>'
        )
        start_reply = secET.fromstring(daemon.handle_command(start_cmd))
        scan_id = start_reply.findtext('id')

        scans_reply = secET.fromstring(
            daemon.handle_command('<get_scans scan_id="%s"/>' % scan_id)
        )
        scan_elem = scans_reply.find('scan')
        self.assertEqual(scan_elem.get('target'), 'localhosts,192.168.0.0/24')
939
940
    def test_scan_get_exclude_hosts(self):
        """Hosts listed in <exclude_hosts> are reported as finished.

        The first target covers 192.168.10.20-25 and excludes
        192.168.10.23-24. After a short wait, exactly those two excluded
        addresses are returned by ``get_scan_finished_hosts``.
        """
        daemon = DummyWrapper([])
        start_cmd = (
            '<start_scan>'
            '<scanner_params /><vts><vt id="1.2.3.4" />'
            '</vts>'
            '<targets><target>'
            '<hosts>192.168.10.20-25</hosts>'
            '<ports>80,443</ports>'
            '<exclude_hosts>192.168.10.23-24'
            '</exclude_hosts>'
            '</target>'
            '<target><hosts>192.168.0.0/24</hosts>'
            '<ports>22</ports></target>'
            '</targets>'
            '</start_scan>'
        )
        start_reply = secET.fromstring(daemon.handle_command(start_cmd))
        scan_id = start_reply.findtext('id')

        # Presumably gives the scan time to record the excluded hosts
        # (TODO confirm the minimum wait needed).
        time.sleep(1)
        self.assertEqual(
            daemon.get_scan_finished_hosts(scan_id),
            ['192.168.10.23', '192.168.10.24'],
        )
963
964
    def test_scan_multi_target_parallel_with_error(self):
        """A non-integer ``parallel`` attribute is rejected.

        ``handle_start_scan_command`` must raise ``OspdCommandError`` when the
        ``parallel`` attribute is "100a".
        """
        daemon = DummyWrapper([])
        cmd = secET.fromstring(
            '<start_scan parallel="100a">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhosts</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>'
        )
        # Nothing has been started yet and the rejection is raised
        # synchronously by the call below, so no waiting is needed first.
        with self.assertRaises(OspdCommandError):
            daemon.handle_start_scan_command(cmd)
979
980
    def test_scan_multi_target_parallel_100(self):
        """An integer ``parallel`` attribute is accepted with status 200."""
        daemon = DummyWrapper([])
        start_cmd = (
            '<start_scan parallel="100">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhosts</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>'
        )
        reply = secET.fromstring(daemon.handle_command(start_cmd))
        # NOTE(review): the reply is already parsed at this point; the pause
        # presumably lets the started scan settle before teardown - confirm.
        time.sleep(1)
        self.assertEqual(reply.get('status'), '200')
995
996
    def test_progress(self):
        """Overall scan progress is derived from per-target progress.

        With two targets at 75 and 25, ``calculate_progress`` returns 50.
        """
        daemon = DummyWrapper([])
        start_cmd = (
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost1</hosts>'
            '<ports>22</ports>'
            '</target><target>'
            '<hosts>localhost2</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>'
        )
        reply = secET.fromstring(daemon.handle_command(start_cmd))
        scan_id = reply.findtext('id')

        # Each target here consists of a single host of the same name.
        for target, percent in (('localhost1', 75), ('localhost2', 25)):
            daemon.set_scan_target_progress(scan_id, target, target, percent)

        self.assertEqual(daemon.calculate_progress(scan_id), 50)
1020
1021
    def test_set_get_vts_version(self):
        """A VTs version stored via set_vts_version is read back unchanged."""
        daemon = DummyWrapper([])
        daemon.set_vts_version('1234')
        self.assertEqual('1234', daemon.get_vts_version())
1027
1028
    def test_set_get_vts_version_error(self):
        """Calling set_vts_version without a version raises TypeError."""
        daemon = DummyWrapper([])
        with self.assertRaises(TypeError):
            daemon.set_vts_version()
1031
1032
    def test_resume_task(self):
        """Resuming a scan keeps its id and re-queues the unfinished host.

        A scan that yields two host-detail results is started and given time
        to run, and a stop request is then expected to fail. Re-issuing
        start_scan with the same scan_id must keep the id, report 'localhost'
        as unfinished, and leave no results in the detailed report once the
        host has been marked finished.
        """
        host_details = [
            Result('host-detail', host='localhost', value='Some Host Detail'),
            Result('host-detail', host='localhost', value='Some Host Detail2'),
        ]
        daemon = DummyWrapper(host_details)

        start_cmd = (
            '<start_scan parallel="2">'
            '<scanner_params />'
            '<targets><target>'
            '<hosts>localhost</hosts>'
            '<ports>22</ports>'
            '</target></targets>'
            '</start_scan>'
        )
        start_reply = secET.fromstring(daemon.handle_command(start_cmd))
        scan_id = start_reply.findtext('id')

        def detailed_results():
            # Fetch the detailed report for this scan; return its <result>s.
            report = secET.fromstring(
                daemon.handle_command(
                    '<get_scans scan_id="%s" details="1"/>' % scan_id
                )
            )
            return report.findall('scan/results/result')

        # Let the scan run; presumably it has ended by now, which is why the
        # stop request below is expected to fail (TODO confirm).
        time.sleep(3)
        stop_cmd = secET.fromstring('<stop_scan scan_id="%s" />' % scan_id)
        with self.assertRaises(OspdCommandError):
            daemon.handle_stop_scan_command(stop_cmd)

        self.assertEqual(len(detailed_results()), 2)

        # Resume: re-issue start_scan with the existing scan_id.
        resume_cmd = (
            '<start_scan scan_id="%s" target="localhost" ports="80, 443">'
            '<scanner_params /></start_scan>' % scan_id
        )
        resume_reply = secET.fromstring(daemon.handle_command(resume_cmd))

        # The resumed scan keeps its id and still has localhost pending.
        self.assertEqual(resume_reply.findtext('id'), scan_id)
        self.assertEqual(
            daemon.get_scan_unfinished_hosts(scan_id), ['localhost']
        )

        # Once the host is marked finished, nothing is left unfinished ...
        daemon.set_scan_host_finished(scan_id, "localhost", "localhost")
        self.assertEqual(daemon.get_scan_unfinished_hosts(scan_id), [])

        # ... and it shows up in the finished-hosts list.
        self.assertEqual(
            daemon.scan_collection.get_hosts_finished(scan_id), ['localhost']
        )

        # The results gathered before the resume must be gone.
        self.assertEqual(len(detailed_results()), 0)
1102
1103
    def test_result_order(self):
        """Results are reported in insertion order, not sorted by host.

        Logs are added for hosts 'a', 'c' and 'b' in that order. The detailed
        get_scans report must list exactly those results, in the same order.
        """
        daemon = DummyWrapper([])
        response = secET.fromstring(
            daemon.handle_command(
                '<start_scan parallel="1">'
                '<scanner_params />'
                '<targets><target>'
                '<hosts>a</hosts>'
                '<ports>22</ports>'
                '</target></targets>'
                '</start_scan>'
            )
        )
        scan_id = response.findtext('id')

        # Deliberately not alphabetical, so a sorted report would fail.
        hosts = ['a', 'c', 'b']
        for host in hosts:
            daemon.add_scan_log(scan_id, host=host, name=host)

        response = secET.fromstring(
            daemon.handle_command(
                '<get_scans scan_id="%s" details="1"/>' % scan_id
            )
        )
        results = response.findall('scan/results/result')

        # Comparing the whole list checks both order and count, so an empty
        # or truncated report fails instead of passing vacuously.
        self.assertEqual([res.get('name') for res in results], hosts)
1132