Passed
Pull Request — master (#253)
by Juan José
01:24
created

ospd.ospd.OSPDaemon.add_scan_error()   A

Complexity

Conditions 1

Size

Total Lines 20
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 18
nop 8
dl 0
loc 20
rs 9.5
c 0
b 0
f 0

How to fix: Many Parameters

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

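There are several approaches to avoid long parameter lists:
- Group parameters that belong together into a single object (a "parameter object") and pass that instead.
- Pass the whole object a method needs rather than several of its individual fields.
- Let the method obtain a value itself (for example from `self`) instead of requiring every caller to pass it.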

1
# Copyright (C) 2014-2020 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: AGPL-3.0-or-later
4
#
5
# This program is free software: you can redistribute it and/or modify
6
# it under the terms of the GNU Affero General Public License as
7
# published by the Free Software Foundation, either version 3 of the
8
# License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU Affero General Public License for more details.
14
#
15
# You should have received a copy of the GNU Affero General Public License
16
# along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
# pylint: disable=too-many-lines
19
20
""" OSP Daemon core class.
21
"""
22
23
import logging
24
import socket
25
import ssl
26
import multiprocessing
27
import time
28
import os
29
30
from typing import (
31
    List,
32
    Any,
33
    Iterator,
34
    Dict,
35
    Optional,
36
    Iterable,
37
    Tuple,
38
    Union,
39
)
40
from xml.etree.ElementTree import Element, SubElement
41
42
import defusedxml.ElementTree as secET
43
44
from deprecated import deprecated
45
46
from ospd import __version__
47
from ospd.command import get_commands
48
from ospd.errors import OspdCommandError
49
from ospd.misc import ResultType
50
from ospd.network import resolve_hostname, target_str_to_list
51
from ospd.protocol import OspRequest, OspResponse, RequestParser
52
from ospd.scan import ScanCollection, ScanStatus
53
from ospd.server import BaseServer, Stream
54
from ospd.vtfilter import VtsFilter
55
from ospd.vts import Vts
56
from ospd.xml import (
57
    elements_as_text,
58
    get_result_xml,
59
    get_elements_from_dict,
60
)
61
62
logger = logging.getLogger(__name__)

# Version of the OSP protocol spoken by this daemon.
PROTOCOL_VERSION = "1.2"

# Scheduler check period, in seconds.
SCHEDULER_CHECK_PERIOD = 10

# Scanner parameters common to every OSPd daemon; OSPDaemon.__init__
# registers each entry through set_scanner_param().
BASE_SCANNER_PARAMS = dict(
    debug_mode=dict(
        type='boolean',
        name='Debug Mode',
        default=0,
        mandatory=0,
        description='Whether to get extra scan debug information.',
    ),
    dry_run=dict(
        type='boolean',
        name='Dry Run',
        default=0,
        mandatory=0,
        description='Whether to dry run scan.',
    ),
)  # type: Dict


def _terminate_process_group(process: multiprocessing.Process) -> None:
87
    os.killpg(os.getpgid(process.pid), 15)
88
89
90
class OSPDaemon:
91
92
    """ Daemon class for OSP traffic handling.
93
94
    Every scanner wrapper should subclass it and make necessary additions and
95
    changes.
96
97
    * Add any needed parameters in __init__.
98
    * Implement check() method which verifies scanner availability and other
99
      environment related conditions.
100
    * Implement process_scan_params and exec_scan methods which are
101
      specific to handling the <start_scan> command, executing the wrapped
102
      scanner and storing the results.
103
    * Implement other methods that assert to False such as get_scanner_name,
104
      get_scanner_version.
105
    * Use Call set_command_attributes at init time to add scanner command
106
      specific options eg. the w3af profile for w3af wrapper.
107
    """
108
109
    def __init__(
        self,
        *,
        customvtfilter=None,
        storage=None,
        max_scans=0,
        check_free_memory=False,
        **kwargs
    ):  # pylint: disable=unused-argument
        """ Set up the daemon's bookkeeping state.

        Keyword arguments:
            customvtfilter: VT filter to use; a fresh VtsFilter() is created
                when it is not given (or is falsy).
            storage: Passed straight to the Vts container.
            max_scans: Stored as ``self.max_scans``.
            check_free_memory: Stored as ``self.check_free_memory``.
            kwargs: Only ``scaninfo_store_time`` is read (None if absent);
                any other key is accepted and ignored.
        """
        self.scan_collection = ScanCollection()
        self.scan_processes = {}

        self.daemon_info = {
            'name': "OSPd",
            'version': __version__,
            'description': "No description",
        }

        self.scanner_info = {
            'name': 'No name',
            'version': 'No version',
            'description': 'No description',
        }

        # Expected to be filled in by the subclass.
        self.server_version = None
        # Becomes True once init() has started the server.
        self.initialized = None

        self.max_scans = max_scans
        self.check_free_memory = check_free_memory
        self.scaninfo_store_time = kwargs.get('scaninfo_store_time')

        self.protocol_version = PROTOCOL_VERSION

        # One instance of every OSP command, keyed by its protocol name.
        self.commands = {}
        for cmd_cls in get_commands():
            cmd = cmd_cls(self)
            self.commands[cmd.get_name()] = cmd

        # Parameters that every scanner supports.
        self.scanner_params = {}
        for param_name, param_spec in BASE_SCANNER_PARAMS.items():
            self.set_scanner_param(param_name, param_spec)

        self.vts = Vts(storage)
        self.vts_version = None

        self.vts_filter = customvtfilter or VtsFilter()

    def init(self, server: BaseServer) -> None:
        """ Start serving and mark the daemon as initialized.

        ``handle_client_stream`` is handed to ``server.start`` as the stream
        handler. Subclasses whose initialization is costly should override
        this; it is called after check().
        """
        stream_handler = self.handle_client_stream
        server.start(stream_handler)
        self.initialized = True

    def set_command_attributes(self, name: str, attributes: Dict) -> None:
        """ Replace the XML attributes of the command called ``name``.

        Unknown command names are silently ignored.
        """
        if not self.command_exists(name):
            return
        self.commands.get(name).attributes = attributes

    @deprecated(version="20.4", reason="Use set_scanner_param instead")
    def add_scanner_param(self, name: str, scanner_params: Dict) -> None:
        """ Deprecated alias: forwards both arguments to set_scanner_param(). """
        self.set_scanner_param(name, scanner_params)

    def set_scanner_param(self, name: str, scanner_params: Dict) -> None:
        """ Register (or overwrite) the scanner parameter ``name``.

        ``scanner_params`` is the parameter's description dict; this class
        reads its 'type', 'mandatory' and 'default' keys. Both arguments must
        be truthy, which is only checked with ``assert``.
        """
        assert name
        assert scanner_params
        self.scanner_params.update({name: scanner_params})

    def get_scanner_params(self) -> Dict:
        """ Return the registered scanner parameters.

        This is the live dict, not a copy.
        """
        registered = self.scanner_params
        return registered

    def add_vt(
        self,
        vt_id: str,
        name: str = None,
        vt_params: str = None,
        vt_refs: str = None,
        custom: str = None,
        vt_creation_time: str = None,
        vt_modification_time: str = None,
        vt_dependencies: str = None,
        summary: str = None,
        impact: str = None,
        affected: str = None,
        insight: str = None,
        solution: str = None,
        solution_t: str = None,
        solution_m: str = None,
        detection: str = None,
        qod_t: str = None,
        qod_v: str = None,
        severities: str = None,
    ) -> None:
        """ Store the information of one vulnerability test (VT).

        Every argument besides ``vt_id`` is optional and is forwarded
        unchanged, by keyword, to ``self.vts.add``.

        IMPORTANT: the VTs Data Manager holds the vts collection. When a
        routine has to consult a large collection intensively, consider
        working on a deepcopy(), because accessing the data manager's shared
        memory is very expensive. When the routine is done, set the temporary
        copy to None and delete it.
        """
        vt_fields = dict(
            name=name,
            vt_params=vt_params,
            vt_refs=vt_refs,
            custom=custom,
            vt_creation_time=vt_creation_time,
            vt_modification_time=vt_modification_time,
            vt_dependencies=vt_dependencies,
            summary=summary,
            impact=impact,
            affected=affected,
            insight=insight,
            solution=solution,
            solution_t=solution_t,
            solution_m=solution_m,
            detection=detection,
            qod_t=qod_t,
            qod_v=qod_v,
            severities=severities,
        )
        self.vts.add(vt_id, **vt_fields)

    def set_vts_version(self, vts_version: str) -> None:
        """ Record the version string that identifies the VTs collection.

        Parameters:
            vts_version (str): Identifies a unique vts version.

        Raises:
            OspdCommandError: if ``vts_version`` is empty or None.
        """
        if vts_version:
            self.vts_version = vts_version
            return

        raise OspdCommandError(
            'A vts_version parameter is required', 'set_vts_version'
        )

    def get_vts_version(self) -> Optional[str]:
        """ Return the stored VTs version (None until one has been set). """
        current = self.vts_version
        return current

    def command_exists(self, name: str) -> bool:
        """ Tell whether an OSP command called ``name`` is registered. """
        registered = self.commands
        return name in registered

    def get_scanner_name(self) -> str:
        """ Return the name of the wrapped scanner. """
        info = self.scanner_info
        return info['name']

    def get_scanner_version(self) -> str:
        """ Return the version of the wrapped scanner. """
        info = self.scanner_info
        return info['version']

    def get_scanner_description(self) -> str:
        """ Return the description of the wrapped scanner. """
        info = self.scanner_info
        return info['description']

    def get_server_version(self) -> str:
        """ Return this OSP server's version.

        The version must have been set (normally by the subclass); an unset
        or empty value trips the ``assert``.
        """
        version = self.server_version
        assert version
        return version

    def get_protocol_version(self) -> str:
        """ Return the OSP protocol version this daemon implements. """
        version = self.protocol_version
        return version

    def preprocess_scan_params(self, xml_params):
        """ Build, complete and validate the scan parameter dict.

        Each child element of ``xml_params`` contributes ``tag -> text``
        (``''`` when the element has no text). Registered scanner params that
        are missing from the request get their default value; for
        'selection' params the default is a '|'-separated list and its first
        entry is used.

        Validation, for params with a registered type:
            integer / boolean -> converted with int()
            boolean           -> must then be 0 or 1
            selection         -> must be one of the '|'-separated defaults
        Mandatory params must not be ''.

        Raises:
            OspdCommandError: (command 'start_scan') on an invalid or missing
                value.
        """
        params = {param.tag: param.text or '' for param in xml_params}

        # Set default values.
        for key in self.scanner_params:
            if key in params:
                continue
            default = self.get_scanner_param_default(key)
            if self.get_scanner_param_type(key) == 'selection':
                default = default.split('|')[0]
            params[key] = default

        # Validate values.
        for key in params:
            param_type = self.get_scanner_param_type(key)
            if not param_type:
                continue

            if param_type in ['integer', 'boolean']:
                try:
                    params[key] = int(params[key])
                except (TypeError, ValueError):
                    # TypeError covers a missing (None) default, which used
                    # to escape as an unhandled error instead of a clear
                    # "Invalid ... value" reply.
                    raise OspdCommandError(
                        'Invalid %s value' % key, 'start_scan'
                    ) from None

            if param_type == 'boolean':
                if params[key] not in [0, 1]:
                    raise OspdCommandError(
                        'Invalid %s value' % key, 'start_scan'
                    )
            elif param_type == 'selection':
                selection = self.get_scanner_param_default(key).split('|')
                if params[key] not in selection:
                    raise OspdCommandError(
                        'Invalid %s value' % key, 'start_scan'
                    )

            if self.get_scanner_param_mandatory(key) and params[key] == '':
                raise OspdCommandError(
                    'Mandatory %s value is missing' % key, 'start_scan'
                )

        return params

    def process_scan_params(self, params: Dict) -> Dict:
        """ Hook for subclasses to post-process the scan parameters.

        The base implementation hands ``params`` back unchanged.
        """
        processed = params
        return processed

    @staticmethod
    @deprecated(
        version="20.4",
        reason="Please use OspRequest.process_vts_params instead.",
    )
    def process_vts_params(scanner_vts) -> Dict:
        """ Deprecated: delegates to OspRequest.process_vts_params().

        The deprecation message now names the method that is actually
        called; it previously pointed at a non-existent
        ``process_vt_params``.
        """
        return OspRequest.process_vts_params(scanner_vts)

    @staticmethod
    @deprecated(
        version="20.4",
        reason="Please use OspRequest.process_credentials_elements instead.",
    )
    def process_credentials_elements(cred_tree) -> Dict:
        """ Deprecated: delegates to OspRequest.process_credentials_elements().

        The deprecation message now names the method that is actually
        called; it previously said ``process_credential_elements``.
        """
        return OspRequest.process_credentials_elements(cred_tree)

    @staticmethod
    @deprecated(
        version="20.4",
        reason="Please use OspRequest.process_target_element instead.",
    )
    def process_targets_element(scanner_target) -> List:
        """ Deprecated: delegates to OspRequest.process_target_element().

        The deprecation message now names the method that is actually
        called; it previously said ``process_targets_elements``.
        """
        return OspRequest.process_target_element(scanner_target)

    def stop_scan(self, scan_id: str) -> None:
        """ Stop the running scan ``scan_id``.

        Marks the scan STOPPED, runs the stop_scan_cleanup() hook, then
        terminates the scan process and its whole process group, and finally
        reaps it without blocking (unless it is the current process).

        Raises:
            OspdCommandError: if no process is known for ``scan_id`` or the
                process is no longer alive.
        """
        scan_process = self.scan_processes.get(scan_id)
        if not scan_process:
            raise OspdCommandError(
                'Scan not found {0}.'.format(scan_id), 'stop_scan'
            )
        if not scan_process.is_alive():
            raise OspdCommandError(
                'Scan already stopped or finished.', 'stop_scan'
            )

        self.set_scan_status(scan_id, ScanStatus.STOPPED)

        logger.info('%s: Scan stopping %s.', scan_id, scan_process.ident)

        # Let the subclass clean up before the process goes away.
        self.stop_scan_cleanup(scan_id)

        try:
            scan_process.terminate()
        except AttributeError:
            logger.debug('%s: The scanner task stopped unexpectedly.', scan_id)

        # terminate() only signals the direct process; also signal its
        # process group so anything it spawned there stops too.
        try:
            _terminate_process_group(scan_process)
        except ProcessLookupError:
            logger.info(
                '%s: Scan already stopped %s.', scan_id, scan_process.pid
            )

        # join(0) reaps without waiting; a process cannot join itself.
        if scan_process.ident != os.getpid():
            scan_process.join(0)

        logger.info('%s: Scan stopped.', scan_id)

    @staticmethod
    def stop_scan_cleanup(scan_id: str):
        """ Hook run by stop_scan() before the scan process is terminated.

        Does nothing here; subclasses that need to clean up before
        termination should implement it.
        """

    @staticmethod
    def target_is_finished(scan_id: str):
        """ Hook for a subclass-specific check before stopping.

        Does nothing here; subclasses that need such a check should
        implement it.
        """

    def exec_scan(self, scan_id: str):
        """ Run the actual scan ``scan_id``; must be implemented by subclasses.

        start_scan() calls this after marking the scan RUNNING.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def finish_scan(self, scan_id: str) -> None:
        """ Mark ``scan_id`` as finished: progress 100, then status FINISHED. """
        complete = 100
        self.set_scan_progress(scan_id, complete)
        self.set_scan_status(scan_id, ScanStatus.FINISHED)
        logger.info("%s: Scan finished.", scan_id)

    def get_daemon_name(self) -> str:
        """ Return the name of this OSP daemon. """
        info = self.daemon_info
        return info['name']

    def get_daemon_version(self) -> str:
        """ Return the version of this OSP daemon. """
        info = self.daemon_info
        return info['version']

    def get_scanner_param_type(self, param: str):
        """ Return the 'type' of scanner parameter ``param``.

        None when the parameter is not registered (or its entry is empty).
        """
        assert isinstance(param, str)
        entry = self.scanner_params.get(param)
        return entry.get('type') if entry else None

    def get_scanner_param_mandatory(self, param: str):
        """ Return the 'mandatory' flag of scanner parameter ``param``.

        False when the parameter is not registered (or its entry is empty).
        """
        assert isinstance(param, str)
        entry = self.scanner_params.get(param)
        return entry.get('mandatory') if entry else False

    def get_scanner_param_default(self, param: str):
        """ Return the 'default' value of scanner parameter ``param``.

        None when the parameter is not registered (or its entry is empty).
        """
        assert isinstance(param, str)
        entry = self.scanner_params.get(param)
        return entry.get('default') if entry else None

    @deprecated(
        version="20.4",
        reason="Please use OspResponse.create_scanner_params_xml instead.",
    )
    def get_scanner_params_xml(self):
        """ Deprecated: build the scanner params XML via
        OspResponse.create_scanner_params_xml(). """
        params = self.scanner_params
        return OspResponse.create_scanner_params_xml(params)

    def handle_client_stream(self, stream: Stream) -> None:
        """ Read one request from ``stream``, dispatch it, and reply.

        Data is read until the stream returns nothing, the request parser
        reports the end of the request, or an SSL error / socket timeout
        occurs. An AttributeError or ValueError while reading is logged and
        aborts handling; empty input is logged and ignored (neither path
        closes the stream). Command errors are answered with their XML;
        unexpected errors are logged and answered with a generic
        'Fatal error'. The stream is closed afterwards.
        """
        data = b''
        request_parser = RequestParser()

        while True:
            try:
                chunk = stream.read()
                if not chunk:
                    break

                data += chunk

                if request_parser.has_ended(chunk):
                    break
            except (AttributeError, ValueError) as error:
                logger.error(error)
                return
            except ssl.SSLError as error:
                logger.debug('Error: %s', error)
                break
            except socket.timeout as error:
                logger.debug('Request timeout: %s', error)
                break

        if not data:
            logger.debug("Empty client stream")
            return

        response = None
        try:
            self.handle_command(data, stream)
        except OspdCommandError as error:
            response = error.as_xml()
            logger.debug('Command error: %s', error.message)
        except Exception:  # pylint: disable=broad-except
            logger.exception('While handling client command:')
            response = OspdCommandError('Fatal error', 'error').as_xml()

        if response:
            stream.write(response)

        stream.close()

    def calculate_progress(self, scan_id: str) -> float:
        """ Return the overall progress of ``scan_id`` as computed by the
        scan collection from the per-host progress. """
        collection = self.scan_collection
        return collection.calculate_target_progress(scan_id)

    def process_exclude_hosts(self, scan_id: str, exclude_hosts: str) -> None:
        """ Drop the excluded hosts from the scan's progress table.

        ``exclude_hosts`` is a target string expanded with
        target_str_to_list(). Nothing happens when it is empty or expands to
        no hosts.
        """
        if not exclude_hosts:
            return

        exc_hosts_list = target_str_to_list(exclude_hosts)
        # NOTE(review): target_str_to_list() may yield an empty result (or
        # None) for unusable input; there is nothing to remove in that case.
        if not exc_hosts_list:
            return

        self.remove_scan_hosts_from_target_progress(scan_id, exc_hosts_list)

    def process_finished_hosts(self, scan_id: str, finished_hosts: str) -> None:
        """ Mark already finished hosts as done before launching the scan.

        Every host expanded from ``finished_hosts`` (via
        target_str_to_list()) is added to the finished hosts and given 100%
        progress, so the overall scan progress accounts for it.
        """
        if not finished_hosts:
            return

        hosts = target_str_to_list(finished_hosts)
        # NOTE(review): target_str_to_list() may yield an empty result (or
        # None) for unusable input; iterating None would raise TypeError.
        if not hosts:
            return

        for host in hosts:
            self.set_scan_host_finished(scan_id, finished_hosts=host)
            self.set_scan_host_progress(scan_id, host=host, progress=100)

    def start_scan(self, scan_id: str, target: Dict) -> None:
        """ Run scan ``scan_id`` against ``target``.

        Starts a new session for the current process, applies the target's
        'exclude_hosts' and 'finished_hosts', marks the scan RUNNING and runs
        exec_scan(). An exception raised there is recorded as a scan error
        (and logged) instead of propagating. Unless the scan was stopped in
        the meantime, it is marked finished at the end.

        Raises:
            OspdCommandError: if ``target`` is empty or None.
        """
        # New session/process group, so the scan can later be killed as a
        # whole group.
        os.setsid()

        # ``not target`` already covers None.
        if not target:
            raise OspdCommandError('Erroneous target', 'start_scan')

        logger.info("%s: Scan started.", scan_id)

        self.process_exclude_hosts(scan_id, target.get('exclude_hosts'))
        self.process_finished_hosts(scan_id, target.get('finished_hosts'))

        try:
            self.set_scan_status(scan_id, ScanStatus.RUNNING)
            self.exec_scan(scan_id)
        except Exception as e:  # pylint: disable=broad-except
            self.add_scan_error(
                scan_id,
                name='',
                host=self.get_scan_host(scan_id),
                value='Host process failure (%s).' % e,
            )
            logger.exception('While scanning: %s', scan_id)
        else:
            logger.info("%s: Host scan finished.", scan_id)

        if self.get_scan_status(scan_id) != ScanStatus.STOPPED:
            self.finish_scan(scan_id)

    def dry_run_scan(self, scan_id: str, target: Dict) -> None:
        """ Pretend to run scan ``scan_id`` and finish it immediately.

        Resolves the scan's host, logs a dry-run message, stores a single
        'Dry run result' log entry and marks the scan finished.
        """
        os.setsid()

        if isinstance(target, (list, tuple)):
            # Kept for callers that pass a sequence of hosts.
            target_hosts = target[0]
        else:
            # ``target`` is annotated as a Dict (start_scan() reads the same
            # kind of value with .get()), and ``target[0]`` on a dict raises
            # KeyError; use the host string stored for the scan instead.
            target_hosts = self.get_scan_host(scan_id)

        host = resolve_hostname(target_hosts)
        if host is None:
            logger.info("Couldn't resolve %s.", self.get_scan_host(scan_id))

        port = self.get_scan_ports(scan_id)

        logger.info("%s:%s: Dry run mode.", host, port)

        self.add_scan_log(scan_id, name='', host=host, value='Dry run result')

        self.finish_scan(scan_id)

    def handle_timeout(self, scan_id: str, host: str) -> None:
        """ Record a 'Timeout' scan error for ``host``.

        The error text names the wrapped scanner:
        "<scanner name> exec timeout.".
        """
        message = "{0} exec timeout.".format(self.get_scanner_name())
        self.add_scan_error(scan_id, host=host, name="Timeout", value=message)

    def remove_scan_hosts_from_target_progress(
        self, scan_id: str, exc_hosts_list: List
    ) -> None:
        """ Remove the given hosts from the scan's progress table. """
        collection = self.scan_collection
        collection.remove_hosts_from_target_progress(scan_id, exc_hosts_list)

    def set_scan_host_finished(
        self, scan_id: str, finished_hosts: Union[List[str], str],
    ) -> None:
        """ Add one host (a str) or several hosts (a list) to the scan's
        finished hosts. """
        hosts = finished_hosts
        if isinstance(hosts, str):
            hosts = [hosts]

        self.scan_collection.set_host_finished(scan_id, hosts)

    def set_scan_progress(self, scan_id: str, progress: int) -> None:
        """ Store the overall progress of ``scan_id`` (expected 0..100). """
        collection = self.scan_collection
        collection.set_progress(scan_id, progress)

    def set_scan_progress_batch(
        self, scan_id: str, host_progress: Dict[str, int]
    ):
        """ Store several hosts' progress at once, then refresh the overall
        scan progress from calculate_progress(). """
        self.scan_collection.set_host_progress(scan_id, host_progress)
        self.set_scan_progress(scan_id, self.calculate_progress(scan_id))

    def set_scan_host_progress(
        self, scan_id: str, host: str = None, progress: int = None,
    ) -> None:
        """ Set the progress of one host of the scan's target.

        Each update also refreshes the overall scan progress. The call is
        ignored when ``host`` is empty/None, ``progress`` is None, or
        ``progress`` lies outside 0..100.
        """
        # Check each condition explicitly. The former
        # ``host and progress < 0 or progress > 100`` parsed as
        # ``(host and progress < 0) or progress > 100``, which let a missing
        # host through and raised TypeError for a missing progress.
        if not host or progress is None or not 0 <= progress <= 100:
            return

        self.set_scan_progress_batch(scan_id, {host: progress})

    def set_scan_status(self, scan_id: str, status: ScanStatus) -> None:
        """ Store ``status`` as the current status of ``scan_id``. """
        collection = self.scan_collection
        collection.set_status(scan_id, status)

    def get_scan_status(self, scan_id: str) -> ScanStatus:
        """ Return the current status of ``scan_id``. """
        collection = self.scan_collection
        return collection.get_status(scan_id)

    def scan_exists(self, scan_id: str) -> bool:
        """ Tell whether the collection holds a scan with ID ``scan_id``.

        @return: The scan collection's answer (annotated as bool).
        """
        collection = self.scan_collection
        return collection.id_exists(scan_id)

    def get_help_text(self) -> str:
        """ Return the help for all commands in plain text.

        Each command gets a tab-indented line with its name padded to 22
        columns followed by its description, then its attributes and its
        elements when it has any.
        """
        # Collect the pieces and join once instead of re-concatenating the
        # growing string for every attribute and command.
        parts = []
        for name, info in self.commands.items():
            description = info.get_description()
            attributes = info.get_attributes()
            elements = info.get_elements()

            parts.append("\t{0: <22} {1}\n".format(name, description))

            if attributes:
                parts.append("\t Attributes:\n")
                parts.extend(
                    "\t  {0: <22} {1}\n".format(attrname, attrdesc)
                    for attrname, attrdesc in attributes.items()
                )

            if elements:
                parts.append("\t Elements:\n")
                # Module-level ospd.xml.elements_as_text, not the method.
                parts.append(elements_as_text(elements))

        return ''.join(parts)

    @deprecated(version="20.4", reason="Use ospd.xml.elements_as_text instead.")
    def elements_as_text(self, elems: Dict, indent: int = 2) -> str:
        """ Deprecated: format ``elems`` as plain text using the module-level
        ospd.xml.elements_as_text(). """
        text = elements_as_text(elems, indent)
        return text

    def delete_scan(self, scan_id: str) -> int:
        """ Delete scan ``scan_id`` from the collection.

        A RUNNING scan is left alone. Otherwise this waits (join() without a
        timeout) for the scan's process, if one is known, forgets the process
        once it has an exit code, and deletes the scan from the collection.

        @return: 1 if scan deleted, 0 otherwise.
        """
        if self.get_scan_status(scan_id) == ScanStatus.RUNNING:
            return 0

        # Don't delete the scan until the process stops
        scan_process = self.scan_processes.get(scan_id)
        if scan_process is None:
            logger.debug('Scan process for %s not found', scan_id)
        else:
            scan_process.join()
            # exitcode is None until the process has terminated; 0 is a
            # valid exit code, hence the explicit ``is not None``.
            if scan_process.exitcode is not None:
                del self.scan_processes[scan_id]

        return self.scan_collection.delete_scan(scan_id)

    def get_scan_results_xml(
        self, scan_id: str, pop_res: bool, max_res: Optional[int]
    ):
        """ Build a <results> element with one child per scan result.

        ``pop_res`` and ``max_res`` are passed through to the collection's
        results iterator.

        @return: The <results> Element.
        """
        result_iter = self.scan_collection.results_iterator(
            scan_id, pop_res, max_res
        )
        results = Element('results')
        # A list, not a generator: the pure-Python Element.extend iterates
        # its argument twice.
        results.extend([get_result_xml(res) for res in result_iter])

        logger.debug('Returning %d results', len(results))
        return results

    @deprecated(
        version="20.4",
        reason="Please use ospd.xml.get_elements_from_dict instead.",
    )
    def get_xml_str(self, data: Dict) -> List:
        """ Deprecated: build XML elements from ``data``.

        @param: Dictionary of xml tags and their elements.

        @return: Whatever ospd.xml.get_elements_from_dict() returns
            (annotated as a List).
        """
        elements = get_elements_from_dict(data)
        return elements

    def get_scan_xml(
        self,
        scan_id: str,
        detailed: bool = True,
        pop_res: bool = False,
        max_res: int = 0,
    ):
        """ Build a <scan> element describing scan ``scan_id``.

        The element carries the id, target, progress, status, start_time and
        end_time attributes (all stringified). With ``detailed`` it also
        contains the <results> element from get_scan_results_xml(). An empty
        ``scan_id`` yields a bare <scan/>.
        """
        if not scan_id:
            return Element('scan')

        attributes = {
            'id': scan_id,
            'target': self.get_scan_host(scan_id),
            'progress': self.get_scan_progress(scan_id),
            'status': self.get_scan_status(scan_id).name.lower(),
            'start_time': self.get_scan_start_time(scan_id),
            'end_time': self.get_scan_end_time(scan_id),
        }

        response = Element('scan')
        for attr_name, attr_value in attributes.items():
            response.set(attr_name, str(attr_value))

        if detailed:
            response.append(
                self.get_scan_results_xml(scan_id, pop_res, max_res)
            )
        return response

    @staticmethod
    def get_custom_vt_as_xml_str(  # pylint: disable=unused-argument
        vt_id: str, custom: Dict
    ) -> str:
        """ Render a VT's custom data as an XML string.

        Wrappers that use custom VT elements must implement this; the base
        version returns an empty string. The returned XML is embedded in a
        <custom></custom> element.

        @return: XML object as string for custom data.
        """
        return ""

    @staticmethod
    def get_params_vt_as_xml_str(  # pylint: disable=unused-argument
        vt_id: str, vt_params
    ) -> str:
        """ Render a VT's parameters as an XML string.

        Wrappers that use vt_params elements must implement this; the base
        version returns an empty string. The returned XML is embedded in a
        <params></params> element.

        @return: XML object as string for vt parameters data.
        """
        return ""

    @staticmethod
    def get_refs_vt_as_xml_str(  # pylint: disable=unused-argument
        vt_id: str, vt_refs
    ) -> str:
        """ Render a VT's references as an XML string.

        Wrappers that use refs elements must implement this; the base
        version returns an empty string. The returned XML is embedded in a
        <refs></refs> element.

        @return: XML object as string for vt references data.
        """
        return ""

    @staticmethod
    def get_dependencies_vt_as_xml_str(  # pylint: disable=unused-argument
        vt_id: str, vt_dependencies
    ) -> str:
        """ Render a VT's dependencies as an XML string.

        Wrappers that use vt_dependencies elements must implement this; the
        base version returns an empty string. The returned XML is embedded
        in a <dependencies></dependencies> element.

        @return: XML object as string for vt dependencies data.
        """
        return ""

    @staticmethod
    def get_creation_time_vt_as_xml_str(  # pylint: disable=unused-argument
        vt_id: str, vt_creation_time
    ) -> str:
        """ Render a VT's creation time as an XML string.

        Wrappers that use vt_creation_time elements must implement this;
        the base version returns an empty string. The returned XML is
        embedded in a <vt_creation_time></vt_creation_time> element.

        @return: XML object as string for vt creation time data.
        """
        return ""

    @staticmethod
    def get_modification_time_vt_as_xml_str(  # pylint: disable=unused-argument
        vt_id: str, vt_modification_time
    ) -> str:
        """ Render a VT's modification time as an XML string.

        Wrappers that use vt_modification_time elements must implement
        this; the base version returns an empty string. The returned XML is
        embedded in a <vt_modification_time></vt_modification_time> element.

        @return: XML object as string for vt modification time data.
        """
        return ""

    @staticmethod
859
    def get_summary_vt_as_xml_str(  # pylint: disable=unused-argument
860
        vt_id: str, summary
861
    ) -> str:
862
        """ Create a string representation of the XML object from the
863
        summary data object.
864
        This needs to be implemented by each ospd wrapper, in case
865
        summary elements for VTs are used.
866
867
        The summary XML object which is returned will be embedded
868
        into a <summary></summary> element.
869
870
        @return: XML object as string for summary data.
871
        """
872
        return ''
873
874
    @staticmethod
875
    def get_impact_vt_as_xml_str(  # pylint: disable=unused-argument
876
        vt_id: str, impact
877
    ) -> str:
878
        """ Create a string representation of the XML object from the
879
        impact data object.
880
        This needs to be implemented by each ospd wrapper, in case
881
        impact elements for VTs are used.
882
883
        The impact XML object which is returned will be embedded
884
        into a <impact></impact> element.
885
886
        @return: XML object as string for impact data.
887
        """
888
        return ''
889
890
    @staticmethod
891
    def get_affected_vt_as_xml_str(  # pylint: disable=unused-argument
892
        vt_id: str, affected
893
    ) -> str:
894
        """ Create a string representation of the XML object from the
895
        affected data object.
896
        This needs to be implemented by each ospd wrapper, in case
897
        affected elements for VTs are used.
898
899
        The affected XML object which is returned will be embedded
900
        into a <affected></affected> element.
901
902
        @return: XML object as string for affected data.
903
        """
904
        return ''
905
906
    @staticmethod
907
    def get_insight_vt_as_xml_str(  # pylint: disable=unused-argument
908
        vt_id: str, insight
909
    ) -> str:
910
        """ Create a string representation of the XML object from the
911
        insight data object.
912
        This needs to be implemented by each ospd wrapper, in case
913
        insight elements for VTs are used.
914
915
        The insight XML object which is returned will be embedded
916
        into a <insight></insight> element.
917
918
        @return: XML object as string for insight data.
919
        """
920
        return ''
921
922
    @staticmethod
923
    def get_solution_vt_as_xml_str(  # pylint: disable=unused-argument
924
        vt_id: str, solution, solution_type=None, solution_method=None
925
    ) -> str:
926
        """ Create a string representation of the XML object from the
927
        solution data object.
928
        This needs to be implemented by each ospd wrapper, in case
929
        solution elements for VTs are used.
930
931
        The solution XML object which is returned will be embedded
932
        into a <solution></solution> element.
933
934
        @return: XML object as string for solution data.
935
        """
936
        return ''
937
938
    @staticmethod
939
    def get_detection_vt_as_xml_str(  # pylint: disable=unused-argument
940
        vt_id: str, detection=None, qod_type=None, qod=None
941
    ) -> str:
942
        """ Create a string representation of the XML object from the
943
        detection data object.
944
        This needs to be implemented by each ospd wrapper, in case
945
        detection elements for VTs are used.
946
947
        The detection XML object which is returned is an element with
948
        tag <detection></detection> element
949
950
        @return: XML object as string for detection data.
951
        """
952
        return ''
953
954
    @staticmethod
955
    def get_severities_vt_as_xml_str(  # pylint: disable=unused-argument
956
        vt_id: str, severities
957
    ) -> str:
958
        """ Create a string representation of the XML object from the
959
        severities data object.
960
        This needs to be implemented by each ospd wrapper, in case
961
        severities elements for VTs are used.
962
963
        The severities XML objects which are returned will be embedded
964
        into a <severities></severities> element.
965
966
        @return: XML object as string for severities data.
967
        """
968
        return ''
969
970
    def get_vt_iterator(  # pylint: disable=unused-argument
        self, vt_selection: Optional[List[str]] = None, details: bool = True
    ) -> Iterator[Tuple[str, Dict]]:
        """ Return iterator object for getting elements
        from the VTs dictionary.

        This base implementation ignores vt_selection and details and
        yields every (vt_id, vt) item of self.vts.

        @vt_selection: List of VT IDs to iterate over (unused here).
        @details: Whether full VT details are wanted (unused here).

        @return: Iterator over (vt_id, vt) tuples.
        """
        # iter() makes the result a real iterator as annotated; a plain dict
        # items view is iterable but has no __next__.
        return iter(self.vts.items())
976
977
    def get_vt_xml(self, single_vt: Tuple[str, Dict]) -> Element:
        """ Gets a single vulnerability test information in XML format.

        The returned <vt> element carries the VT ID in its "id" attribute
        and always has a <name> child. For every other VT field that is set
        (truthy) in the VT dict, the matching get_*_vt_as_xml_str() hook is
        called, and the XML string it returns is parsed and appended as a
        child. Children are appended in this order: params, refs,
        dependencies, creation time, modification time, summary, impact,
        affected, insight, solution, detection, severities, custom.

        @single_vt: Tuple (vt_id, vt), vt being the dict with the VT data.

        @return: Element with the single vulnerability test information.
                 An empty <vt> Element if single_vt is empty.
        """
        if not single_vt:
            return Element('vt')

        vt_id, vt = single_vt

        vt_xml = Element('vt')
        vt_xml.set('id', vt_id)

        name_xml = SubElement(vt_xml, 'name')
        name_xml.text = str(vt.get('name'))

        def append_xml_str(xml_str: str) -> None:
            # Hooks a wrapper leaves unimplemented return '' in the base
            # class. '' is not parseable XML, so skip it instead of failing
            # with a ParseError.
            if xml_str:
                vt_xml.append(secET.fromstring(xml_str))

        def append_simple_fields(fields) -> None:
            # Fields whose hook only takes the VT ID and the field value.
            for key, to_xml_str in fields:
                value = vt.get(key)
                if value:
                    append_xml_str(to_xml_str(vt_id, value))

        append_simple_fields(
            (
                ('vt_params', self.get_params_vt_as_xml_str),
                ('vt_refs', self.get_refs_vt_as_xml_str),
                ('vt_dependencies', self.get_dependencies_vt_as_xml_str),
                ('creation_time', self.get_creation_time_vt_as_xml_str),
                (
                    'modification_time',
                    self.get_modification_time_vt_as_xml_str,
                ),
                ('summary', self.get_summary_vt_as_xml_str),
                ('impact', self.get_impact_vt_as_xml_str),
                ('affected', self.get_affected_vt_as_xml_str),
                ('insight', self.get_insight_vt_as_xml_str),
            )
        )

        if vt.get('solution'):
            append_xml_str(
                self.get_solution_vt_as_xml_str(
                    vt_id,
                    vt.get('solution'),
                    vt.get('solution_type'),
                    vt.get('solution_method'),
                )
            )

        # Detection is emitted as soon as any one of its fields is set.
        if vt.get('detection') or vt.get('qod_type') or vt.get('qod'):
            append_xml_str(
                self.get_detection_vt_as_xml_str(
                    vt_id,
                    vt.get('detection'),
                    vt.get('qod_type'),
                    vt.get('qod'),
                )
            )

        append_simple_fields(
            (
                ('severities', self.get_severities_vt_as_xml_str),
                ('custom', self.get_custom_vt_as_xml_str),
            )
        )

        return vt_xml
1075
1076
    def get_vts_selection_list(
        self,
        vt_id: Optional[str] = None,
        filtered_vts: Optional[Dict] = None,
    ) -> Iterable[str]:
        """
        Get list of VT's OID.

        filtered_vts has priority over vt_id:
        - filtered_vts given but empty (no VT matched the filter): an empty
          list is returned.
        - filtered_vts not empty: filtered_vts itself is returned.
        - otherwise, if vt_id is given: a list holding only vt_id. Whether
          vt_id is a known VT is not checked here.
        - otherwise: all the VT OIDs in self.vts.

        Arguments:
            vt_id (str, optional): ID of the vt to get.
            filtered_vts (Dict, optional): Filtered VTs collection.

        Return:
            Iterable of the selected VT's OIDs.
        """
        # No match for the filter: select nothing.
        if filtered_vts is not None and len(filtered_vts) == 0:
            return []

        if filtered_vts:
            return filtered_vts

        if vt_id:
            return [vt_id]

        return self.vts.keys()
1109
1110
    def handle_command(self, data: bytes, stream: Stream) -> None:
        """ Handles an osp command in a string.

        data is parsed as XML, and its root tag selects the handler in
        self.commands. The handler's response is written to stream. If the
        daemon is not initialized yet and the command requires it, an error
        response is written to stream instead.

        @data: XML command as received from the client.
        @stream: Stream the response is written to.

        @raise OspdCommandError: If data is not valid XML or no handler is
               registered for the command.
        """
        try:
            tree = secET.fromstring(data)
        except secET.ParseError as err:
            logger.debug("Erroneous client input: %s", data)
            raise OspdCommandError('Invalid data') from err

        command_name = tree.tag

        logger.debug('Handling %s command request.', command_name)

        command = self.commands.get(command_name, None)
        if not command:
            # This also covers "authenticate": without a registered handler
            # nothing can process it, and going on would fail with an
            # AttributeError on None below.
            raise OspdCommandError('Bogus command name')

        if not self.initialized and command.must_be_initialized:
            exception = OspdCommandError(
                '%s is still starting' % self.daemon_info['name'], 'error'
            )
            stream.write(exception.as_xml())
            return

        response = command.handle_xml(tree)

        # A handler returns either one bytes response or an iterable whose
        # items are written one by one.
        if isinstance(response, bytes):
            stream.write(response)
        else:
            for chunk in response:
                stream.write(chunk)
1142
1143
    def check(self):
        """ Subclass-specific check hook.

        Must be implemented by subclass: this base implementation always
        raises NotImplementedError.
        """
        raise NotImplementedError
1146
1147
    def run(self) -> None:
        """ Start the daemon's main loop and keep it going until interrupted.

        Each cycle sleeps SCHEDULER_CHECK_PERIOD seconds, then calls
        scheduler(), clean_forgotten_scans() and wait_for_children(), in
        that order. A KeyboardInterrupt (Ctrl-C) ends the loop.
        """
        try:
            while True:
                time.sleep(SCHEDULER_CHECK_PERIOD)
                for periodic_task in (
                    self.scheduler,
                    self.clean_forgotten_scans,
                    self.wait_for_children,
                ):
                    periodic_task()
        except KeyboardInterrupt:
            logger.info("Received Ctrl-C shutting-down ...")
1159
1160
    def scheduler(self):
        """ Periodic hook called from the run() loop; a no-op by default.

        Subclasses override it when they need to run tasks periodically.
        """
1163
1164
    def wait_for_children(self):
        """ Join the zombie scan processes to release their resources.

        Each process in self.scan_processes is joined with a zero timeout,
        so (assuming multiprocessing.Process objects) the call does not wait
        for processes that are still running.
        """
        for scan_process in self.scan_processes.values():
            scan_process.join(0)
1168
1169
    def create_scan(
        self,
        scan_id: str,
        targets: Dict,
        options: Optional[Dict],
        vt_selection: Dict,
    ) -> Optional[str]:
        """ Creates a new scan.

        If a scan with scan_id already exists and is STOPPED, it is created
        again in the scan collection (resumed). If it exists and is RUNNING
        or FINISHED, nothing is created.

        @scan_id: ID of the scan to create.
        @targets: Target to scan.
        @options: Miscellaneous scan options.
        @vt_selection: VTs selected for the scan.

        @return: New scan's ID. None if the scan_id already exists and the
                 scan status is RUNNING or FINISHED.
        """
        status = None
        scan_exists = self.scan_exists(scan_id)
        if scan_id and scan_exists:
            status = self.get_scan_status(scan_id)

        if scan_exists and status == ScanStatus.STOPPED:
            logger.info("Scan %s exists. Resuming scan.", scan_id)
        elif scan_exists and status in (
            ScanStatus.RUNNING,
            ScanStatus.FINISHED,
        ):
            logger.info(
                "Scan %s exists with status %s.", scan_id, status.name.lower()
            )
            return None

        return self.scan_collection.create_scan(
            scan_id, targets, options, vt_selection
        )
1201
1202
    def get_scan_options(self, scan_id: str) -> str:
        """ Return the options stored for the scan scan_id. """
        collection = self.scan_collection
        return collection.get_options(scan_id)
1205
1206
    def set_scan_option(self, scan_id: str, name: str, value: Any) -> None:
        """ Store value under the option name of the scan scan_id. """
        collection = self.scan_collection
        return collection.set_option(scan_id, name, value)
1209
1210
    def clean_forgotten_scans(self) -> None:
        """ Delete stopped or finished scans that were never deleted.

        A STOPPED or FINISHED scan with a non-zero end time is deleted once
        more than self.scaninfo_store_time hours have passed since it ended.
        Nothing is done when scaninfo_store_time is not set (falsy).
        """
        if not self.scaninfo_store_time:
            return

        for scan_id in list(self.scan_collection.ids_iterator()):
            end_time = int(self.get_scan_end_time(scan_id))
            scan_status = self.get_scan_status(scan_id)

            is_done = scan_status in (ScanStatus.STOPPED, ScanStatus.FINISHED)
            if not (is_done and end_time):
                continue

            elapsed = int(time.time()) - end_time
            if elapsed <= self.scaninfo_store_time * 3600:
                continue

            logger.debug(
                'Scan %s is older than %d hours and seems have been '
                'forgotten. Scan info will be deleted from the '
                'scan table',
                scan_id,
                self.scaninfo_store_time,
            )
            self.delete_scan(scan_id)
1235
1236
    def check_scan_process(self, scan_id: str) -> None:
        """ Check the scan's process, and terminate the scan if not alive.

        A complete scan (progress 100) gets its process joined with a zero
        timeout. If an incomplete scan's process is no longer alive, the
        scan is marked STOPPED and gets a "Scan process failure." error,
        unless it already was STOPPED.
        """
        process = self.scan_processes[scan_id]
        progress = self.get_scan_progress(scan_id)

        if progress == 100:
            process.join(0)
        elif progress < 100 and not process.is_alive():
            if self.get_scan_status(scan_id) != ScanStatus.STOPPED:
                self.set_scan_status(scan_id, ScanStatus.STOPPED)
                self.add_scan_error(
                    scan_id, name="", host="", value="Scan process failure."
                )

                logger.info("%s: Scan stopped with errors.", scan_id)
1252
1253
    def get_scan_progress(self, scan_id: str):
        """ Return the current progress value of the scan scan_id. """
        collection = self.scan_collection
        return collection.get_progress(scan_id)
1256
1257
    def get_scan_host(self, scan_id: str) -> str:
        """ Return the target host list of the scan scan_id. """
        collection = self.scan_collection
        return collection.get_host_list(scan_id)
1260
1261
    def get_scan_ports(self, scan_id: str) -> str:
        """ Return the port list of the scan scan_id. """
        collection = self.scan_collection
        return collection.get_ports(scan_id)
1264
1265
    def get_scan_exclude_hosts(self, scan_id: str):
        """ Gives a scan's exclude host list.

        @scan_id: ID of the scan.

        @return: The exclude hosts of the scan, as returned by the scan
                 collection.
        """
        return self.scan_collection.get_exclude_hosts(scan_id)
1269
1270
    def get_scan_credentials(self, scan_id: str) -> Dict:
        """ Gives a scan's credential list.

        @scan_id: ID of the scan.

        @return: The credentials of the scan, as returned by the scan
                 collection.
        """
        return self.scan_collection.get_credentials(scan_id)
1274
1275
    def get_scan_target_options(self, scan_id: str) -> Dict:
        """ Gives a scan's target option dict.

        @scan_id: ID of the scan.

        @return: The target options of the scan, as returned by the scan
                 collection.
        """
        return self.scan_collection.get_target_options(scan_id)
1279
1280
    def get_scan_vts(self, scan_id: str) -> Dict:
        """ Return the VTs selected for the scan scan_id. """
        collection = self.scan_collection
        return collection.get_vts(scan_id)
1283
1284
    def get_scan_unfinished_hosts(self, scan_id: str) -> List:
        """ Return the hosts of the scan scan_id not finished yet. """
        collection = self.scan_collection
        return collection.get_hosts_unfinished(scan_id)
1287
1288
    def get_scan_finished_hosts(self, scan_id: str) -> List:
        """ Get a list of finished hosts."""
        return self.scan_collection.get_hosts_finished(scan_id)
1291
1292
    def get_scan_start_time(self, scan_id: str) -> str:
        """ Return the time at which the scan scan_id started. """
        collection = self.scan_collection
        return collection.get_start_time(scan_id)
1295
1296
    def get_scan_end_time(self, scan_id: str) -> str:
        """ Return the time at which the scan scan_id ended. """
        collection = self.scan_collection
        return collection.get_end_time(scan_id)
1299
1300
    def add_scan_log(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
        test_id: str = '',
        qod: str = '',
    ) -> None:
        """ Store a LOG result for the scan scan_id.

        Log results are always stored with the severity '0.0' (the slot
        where add_scan_alarm() passes its severity).
        """
        result_fields = (host, hostname, name, value, port, test_id)
        self.scan_collection.add_result(
            scan_id, ResultType.LOG, *result_fields, '0.0', qod
        )
1325
1326
    def add_scan_error(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
        test_id: str = '',
    ) -> None:
        """ Adds an error result to scan_id scan.

        @scan_id: ID of the scan the result belongs to.
        @host: Host the error relates to.
        @hostname: Hostname of the host.
        @name: Name of the result.
        @value: Error message.
        @port: Port the error relates to.
        @test_id: Test ID of the result.

        Unlike add_scan_log() and add_scan_alarm(), no severity or QoD is
        passed; presumably the scan collection's defaults apply.
        """
        self.scan_collection.add_result(
            scan_id,
            ResultType.ERROR,
            host,
            hostname,
            name,
            value,
            port,
            test_id,
        )
1347
1348
    def add_scan_host_detail(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
    ) -> None:
        """ Store a HOST_DETAIL result for the scan scan_id.

        Only host, hostname, name and value are forwarded to the scan
        collection.
        """
        collection = self.scan_collection
        collection.add_result(
            scan_id,
            ResultType.HOST_DETAIL,
            host,
            hostname,
            name,
            value,
        )
1360
1361
    def add_scan_alarm(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
        test_id: str = '',
        severity: str = '',
        qod: str = '',
    ) -> None:
        """ Store an ALARM result for the scan scan_id.

        All result fields, including severity and qod, are forwarded to the
        scan collection in positional order.
        """
        result_fields = (
            host,
            hostname,
            name,
            value,
            port,
            test_id,
            severity,
            qod,
        )
        self.scan_collection.add_result(
            scan_id, ResultType.ALARM, *result_fields
        )
1386