Completed
Push — master ( edf821...f5f7b9 )
by Juan José
14s queued 11s
created

ospd.ospd.OSPDaemon.handle_client_stream()   D

Complexity

Conditions 12

Size

Total Lines 54
Code Lines 43

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 12
eloc 43
nop 2
dl 0
loc 54
rs 4.8
c 0
b 0
f 0

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

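Commonly applied refactorings include Extract Method: move a cohesive part of the method body into a new, well-named method and call it from the original place.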

Complexity

Complex methods like ospd.ospd.OSPDaemon.handle_client_stream() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
# Copyright (C) 2014-2020 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: GPL-2.0-or-later
4
#
5
# This program is free software; you can redistribute it and/or
6
# modify it under the terms of the GNU General Public License
7
# as published by the Free Software Foundation; either version 2
8
# of the License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU General Public License for more details.
14
#
15
# You should have received a copy of the GNU General Public License
16
# along with this program; if not, write to the Free Software
17
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18
19
# pylint: disable=too-many-lines
20
21
""" OSP Daemon core class.
22
"""
23
24
import logging
25
import socket
26
import ssl
27
import multiprocessing
28
import time
29
import os
30
31
from typing import (
32
    List,
33
    Any,
34
    Iterator,
35
    Dict,
36
    Optional,
37
    Iterable,
38
    Tuple,
39
)
40
from xml.etree.ElementTree import Element, SubElement
41
42
import defusedxml.ElementTree as secET
43
44
from deprecated import deprecated
45
46
from ospd import __version__
47
from ospd.command import get_commands
48
from ospd.errors import OspdCommandError
49
from ospd.misc import ResultType
50
from ospd.network import resolve_hostname, target_str_to_list
51
from ospd.protocol import OspRequest, OspResponse, RequestParser
52
from ospd.scan import ScanCollection, ScanStatus
53
from ospd.server import BaseServer, Stream
54
from ospd.vtfilter import VtsFilter
55
from ospd.vts import Vts
56
from ospd.xml import (
57
    elements_as_text,
58
    get_result_xml,
59
    get_elements_from_dict,
60
)
61
62
logger = logging.getLogger(__name__)

PROTOCOL_VERSION = "1.2"

SCHEDULER_CHECK_PERIOD = 10  # in seconds

# Scanner parameters every daemon starts with; OSPDaemon.__init__ registers
# each entry through set_scanner_param().
BASE_SCANNER_PARAMS = dict(
    debug_mode=dict(
        type='boolean',
        name='Debug Mode',
        default=0,
        mandatory=0,
        description='Whether to get extra scan debug information.',
    ),
    dry_run=dict(
        type='boolean',
        name='Dry Run',
        default=0,
        mandatory=0,
        description='Whether to dry run scan.',
    ),
)  # type: Dict
84
85
86
def _terminate_process_group(process: multiprocessing.Process) -> None:
87
    os.killpg(os.getpgid(process.pid), 15)
88
89
90
class OSPDaemon:
91
92
    """ Daemon class for OSP traffic handling.
93
94
    Every scanner wrapper should subclass it and make necessary additions and
95
    changes.
96
97
    * Add any needed parameters in __init__.
98
    * Implement check() method which verifies scanner availability and other
99
      environment related conditions.
100
    * Implement process_scan_params and exec_scan methods which are
101
      specific to handling the <start_scan> command, executing the wrapped
102
      scanner and storing the results.
103
    * exec_scan() should return 0 if host is dead or not reached, 1 if host is
104
      alive and 2 if scan error or status is unknown.
105
    * Implement other methods that assert to False such as get_scanner_name,
106
      get_scanner_version.
107
    * Use Call set_command_attributes at init time to add scanner command
108
      specific options eg. the w3af profile for w3af wrapper.
109
    """
110
111
    def __init__(
        self, *, customvtfilter=None, **kwargs
    ):  # pylint: disable=unused-argument
        """ Set up the daemon's internal state.

        Keyword Arguments:
            customvtfilter: VT filter object to use. When it is falsy a new
                VtsFilter() instance is created instead.
            scaninfo_store_time: Read from ``kwargs`` and stored unchanged in
                ``self.scaninfo_store_time`` (None when not given). Other
                keyword arguments are accepted and ignored.
        """
        self.scan_collection = ScanCollection()
        self.scan_processes = {}

        self.daemon_info = {
            'name': "OSPd",
            'version': __version__,
            'description': "No description",
        }

        self.scanner_info = {
            'name': 'No name',
            'version': 'No version',
            'description': 'No description',
        }

        # Subclasses are expected to assign the server version.
        self.server_version = None

        # init() switches this to True once the server has been started.
        self.initialized = None

        self.scaninfo_store_time = kwargs.get('scaninfo_store_time')

        self.protocol_version = PROTOCOL_VERSION

        # Index every available command object by its name.
        self.commands = {}
        for command_cls in get_commands():
            cmd = command_cls(self)
            self.commands[cmd.get_name()] = cmd

        self.scanner_params = {}
        for param_name, param_spec in BASE_SCANNER_PARAMS.items():
            self.set_scanner_param(param_name, param_spec)

        self.vts = Vts()
        self.vts_version = None

        self.vts_filter = customvtfilter if customvtfilter else VtsFilter()
154
155
    def init(self, server: BaseServer) -> None:
156
        """ Should be overridden by a subclass if the initialization is costly.
157
158
            Will be called after check.
159
        """
160
        server.start(self.handle_client_stream)
161
        self.initialized = True
162
163
    def set_command_attributes(self, name: str, attributes: Dict) -> None:
164
        """ Sets the xml attributes of a specified command. """
165
        if self.command_exists(name):
166
            command = self.commands.get(name)
167
            command.attributes = attributes
168
169
    @deprecated(version="20.4", reason="Use set_scanner_param instead")
    def add_scanner_param(self, name: str, scanner_params: Dict) -> None:
        """ Deprecated alias: register a scanner parameter through
        set_scanner_param(). """
        return self.set_scanner_param(name, scanner_params)
173
174
    def set_scanner_param(self, name: str, scanner_params: Dict) -> None:
175
        """ Set a scanner parameter. """
176
177
        assert name
178
        assert scanner_params
179
180
        self.scanner_params[name] = scanner_params
181
182
    def get_scanner_params(self) -> Dict:
183
        return self.scanner_params
184
185
    def add_vt(
186
        self,
187
        vt_id: str,
188
        name: str = None,
189
        vt_params: str = None,
190
        vt_refs: str = None,
191
        custom: str = None,
192
        vt_creation_time: str = None,
193
        vt_modification_time: str = None,
194
        vt_dependencies: str = None,
195
        summary: str = None,
196
        impact: str = None,
197
        affected: str = None,
198
        insight: str = None,
199
        solution: str = None,
200
        solution_t: str = None,
201
        solution_m: str = None,
202
        detection: str = None,
203
        qod_t: str = None,
204
        qod_v: str = None,
205
        severities: str = None,
206
    ) -> None:
207
        """ Add a vulnerability test information.
208
209
        IMPORTANT: The VT's Data Manager will store the vts collection.
210
        If the collection is considerably big and it will be consultated
211
        intensible during a routine, consider to do a deepcopy(), since
212
        accessing the shared memory in the data manager is very expensive.
213
        At the end of the routine, the temporal copy must be set to None
214
        and deleted.
215
        """
216
        self.vts.add(
217
            vt_id,
218
            name=name,
219
            vt_params=vt_params,
220
            vt_refs=vt_refs,
221
            custom=custom,
222
            vt_creation_time=vt_creation_time,
223
            vt_modification_time=vt_modification_time,
224
            vt_dependencies=vt_dependencies,
225
            summary=summary,
226
            impact=impact,
227
            affected=affected,
228
            insight=insight,
229
            solution=solution,
230
            solution_t=solution_t,
231
            solution_m=solution_m,
232
            detection=detection,
233
            qod_t=qod_t,
234
            qod_v=qod_v,
235
            severities=severities,
236
        )
237
238
    def set_vts_version(self, vts_version: str) -> None:
239
        """ Add into the vts dictionary an entry to identify the
240
        vts version.
241
242
        Parameters:
243
            vts_version (str): Identifies a unique vts version.
244
        """
245
        if not vts_version:
246
            raise OspdCommandError(
247
                'A vts_version parameter is required', 'set_vts_version'
248
            )
249
        self.vts_version = vts_version
250
251
    def get_vts_version(self) -> Optional[str]:
252
        """Return the vts version.
253
        """
254
        return self.vts_version
255
256
    def command_exists(self, name: str) -> bool:
257
        """ Checks if a commands exists. """
258
        return name in self.commands
259
260
    def get_scanner_name(self) -> str:
261
        """ Gives the wrapped scanner's name. """
262
        return self.scanner_info['name']
263
264
    def get_scanner_version(self) -> str:
265
        """ Gives the wrapped scanner's version. """
266
        return self.scanner_info['version']
267
268
    def get_scanner_description(self) -> str:
269
        """ Gives the wrapped scanner's description. """
270
        return self.scanner_info['description']
271
272
    def get_server_version(self) -> str:
273
        """ Gives the specific OSP server's version. """
274
        assert self.server_version
275
        return self.server_version
276
277
    def get_protocol_version(self) -> str:
278
        """ Gives the OSP's version. """
279
        return self.protocol_version
280
281
    def preprocess_scan_params(self, xml_params):
282
        """ Processes the scan parameters. """
283
        params = {}
284
285
        for param in xml_params:
286
            params[param.tag] = param.text or ''
287
288
        # Set default values.
289
        for key in self.scanner_params:
290
            if key not in params:
291
                params[key] = self.get_scanner_param_default(key)
292
                if self.get_scanner_param_type(key) == 'selection':
293
                    params[key] = params[key].split('|')[0]
294
295
        # Validate values.
296
        for key in params:
297
            param_type = self.get_scanner_param_type(key)
298
            if not param_type:
299
                continue
300
301
            if param_type in ['integer', 'boolean']:
302
                try:
303
                    params[key] = int(params[key])
304
                except ValueError:
305
                    raise OspdCommandError(
306
                        'Invalid %s value' % key, 'start_scan'
307
                    )
308
309
            if param_type == 'boolean':
310
                if params[key] not in [0, 1]:
311
                    raise OspdCommandError(
312
                        'Invalid %s value' % key, 'start_scan'
313
                    )
314
            elif param_type == 'selection':
315
                selection = self.get_scanner_param_default(key).split('|')
316
                if params[key] not in selection:
317
                    raise OspdCommandError(
318
                        'Invalid %s value' % key, 'start_scan'
319
                    )
320
            if self.get_scanner_param_mandatory(key) and params[key] == '':
321
                raise OspdCommandError(
322
                    'Mandatory %s value is missing' % key, 'start_scan'
323
                )
324
325
        return params
326
327
    def process_scan_params(self, params: Dict) -> Dict:
328
        """ This method is to be overridden by the child classes if necessary
329
        """
330
        return params
331
332
    @staticmethod
    @deprecated(
        version="20.4",
        reason="Please use OspRequest.process_vts_params instead.",
    )
    def process_vts_params(scanner_vts) -> Dict:
        """ Deprecated wrapper around OspRequest.process_vts_params().

        The deprecation reason now names the function actually called here.
        It previously pointed to a non-existent ``process_vt_params``.
        """
        return OspRequest.process_vts_params(scanner_vts)
339
340
    @staticmethod
    @deprecated(
        version="20.4",
        reason="Please use OspRequest.process_credentials_elements instead.",
    )
    def process_credentials_elements(cred_tree) -> Dict:
        """ Deprecated wrapper around
        OspRequest.process_credentials_elements().

        The deprecation reason now names the function actually called here.
        It previously said ``process_credential_elements``.
        """
        return OspRequest.process_credentials_elements(cred_tree)
347
348
    @staticmethod
    @deprecated(
        version="20.4",
        reason="Please use OspRequest.process_target_element instead.",
    )
    def process_targets_element(scanner_target) -> List:
        """ Deprecated wrapper around OspRequest.process_target_element().

        The deprecation reason now names the function actually called here.
        It previously said ``process_targets_elements``.
        """
        return OspRequest.process_target_element(scanner_target)
355
356
    def stop_scan(self, scan_id: str) -> None:
        """ Stop the scan process belonging to *scan_id*.

        The scan status is set to STOPPED and stop_scan_cleanup() is called
        before the process is terminated. Its process group is then
        signalled through _terminate_process_group().

        Raises:
            OspdCommandError: if no process is registered for *scan_id*, or
                the process is no longer alive.
        """
        scan_process = self.scan_processes.get(scan_id)
        if not scan_process:
            raise OspdCommandError(
                'Scan not found {0}.'.format(scan_id), 'stop_scan'
            )
        if not scan_process.is_alive():
            raise OspdCommandError(
                'Scan already stopped or finished.', 'stop_scan'
            )

        self.set_scan_status(scan_id, ScanStatus.STOPPED)

        logger.info('%s: Scan stopping %s.', scan_id, scan_process.ident)

        self.stop_scan_cleanup(scan_id)

        try:
            scan_process.terminate()
        except AttributeError:
            logger.debug('%s: The scanner task stopped unexpectedly.', scan_id)

        try:
            _terminate_process_group(scan_process)
        except ProcessLookupError:
            # The process is already gone; nothing is left to signal.
            logger.info(
                '%s: Scan already stopped %s.', scan_id, scan_process.pid
            )

        # Guard against joining the current process. join(0) does not wait.
        if scan_process.ident != os.getpid():
            scan_process.join(0)

        logger.info('%s: Scan stopped.', scan_id)
389
390
    @staticmethod
391
    def stop_scan_cleanup(scan_id: str):
392
        """ Should be implemented by subclass in case of a clean up before
393
        terminating is needed. """
394
395
    @staticmethod
396
    def target_is_finished(scan_id: str):
397
        """ Should be implemented by subclass in case of a check before
398
        stopping is needed. """
399
400
    def exec_scan(self, scan_id: str):
401
        """ Asserts to False. Should be implemented by subclass. """
402
        raise NotImplementedError
403
404
    def finish_scan(self, scan_id: str) -> None:
        """ Mark *scan_id* as finished: progress 100, status FINISHED. """
        complete = 100
        self.set_scan_progress(scan_id, complete)
        self.set_scan_status(scan_id, ScanStatus.FINISHED)
        logger.info("%s: Scan finished.", scan_id)
409
410
    def get_daemon_name(self) -> str:
411
        """ Gives osp daemon's name. """
412
        return self.daemon_info['name']
413
414
    def get_daemon_version(self) -> str:
415
        """ Gives osp daemon's version. """
416
        return self.daemon_info['version']
417
418
    def get_scanner_param_type(self, param: str):
419
        """ Returns type of a scanner parameter. """
420
        assert isinstance(param, str)
421
        entry = self.scanner_params.get(param)
422
        if not entry:
423
            return None
424
        return entry.get('type')
425
426
    def get_scanner_param_mandatory(self, param: str):
427
        """ Returns if a scanner parameter is mandatory. """
428
        assert isinstance(param, str)
429
        entry = self.scanner_params.get(param)
430
        if not entry:
431
            return False
432
        return entry.get('mandatory')
433
434
    def get_scanner_param_default(self, param: str):
435
        """ Returns default value of a scanner parameter. """
436
        assert isinstance(param, str)
437
        entry = self.scanner_params.get(param)
438
        if not entry:
439
            return None
440
        return entry.get('default')
441
442
    @deprecated(
        version="20.4",
        reason="Please use OspResponse.create_scanner_params_xml instead.",
    )
    def get_scanner_params_xml(self):
        """ Deprecated: return the scanner params as XML built by
        OspResponse.create_scanner_params_xml(). """
        params = self.scanner_params
        return OspResponse.create_scanner_params_xml(params)
449
450
    def handle_client_stream(self, stream: Stream) -> None:
        """ Handles stream of data received from client.

        Reads the request from *stream*. If init() has not completed yet, the
        client gets a "still starting" error. Otherwise the request is passed
        to handle_command(). An OspdCommandError raised while handling it is
        returned to the client as XML. Any other exception is logged and
        reported as a generic 'Fatal error'.

        The stream is left open when reading fails with AttributeError or
        ValueError, or when no data was received. In every other case it is
        closed before returning.

        Args:
            stream: Client connection to read the request from and write the
                response to.
        """
        data = self._read_client_stream(stream)
        if data is None:
            # The read/parse error has already been logged.
            return

        if not data:
            logger.debug("Empty client stream")
            return

        if not self.initialized:
            exception = OspdCommandError(
                '%s is still starting' % self.daemon_info['name'], 'error'
            )
            response = exception.as_xml()
            stream.write(response)
            stream.close()
            return

        response = self._run_client_command(data, stream)
        if response:
            stream.write(response)

        stream.close()

    @staticmethod
    def _read_client_stream(stream: Stream) -> Optional[bytes]:
        """ Read one client request from *stream*.

        Reading stops when stream.read() returns no data, when the
        RequestParser reports the request has ended, or on an SSL error or
        socket timeout (both logged at debug level).

        Returns:
            The request bytes (possibly b''), or None if reading raised
            AttributeError or ValueError (logged as an error).
        """
        request_parser = RequestParser()
        # Collect the chunks and join them once. Repeated ``bytes += bytes``
        # copies everything received so far on every chunk.
        chunks = []  # type: List[bytes]

        while True:
            try:
                buf = stream.read()
                if not buf:
                    break

                chunks.append(buf)

                if request_parser.has_ended(buf):
                    break
            except (AttributeError, ValueError) as message:
                logger.error(message)
                return None
            except ssl.SSLError as exception:
                logger.debug('Error: %s', exception)
                break
            except socket.timeout as exception:
                logger.debug('Request timeout: %s', exception)
                break

        return b''.join(chunks)

    def _run_client_command(self, data, stream: Stream):
        """ Run the request *data* through handle_command().

        Returns:
            None on success, otherwise the error response produced by
            OspdCommandError.as_xml() that should be sent to the client.
        """
        try:
            self.handle_command(data, stream)
        except OspdCommandError as exception:
            response = exception.as_xml()
            logger.debug('Command error: %s', exception.message)
            return response
        except Exception:  # pylint: disable=broad-except
            logger.exception('While handling client command:')
            exception = OspdCommandError('Fatal error', 'error')
            return exception.as_xml()

        return None
504
505
    def calculate_progress(self, scan_id: str) -> float:
506
        """ Calculate the total scan progress. """
507
508
        return self.scan_collection.calculate_target_progress(scan_id)
509
510
    def process_exclude_hosts(self, scan_id: str, exclude_hosts: str) -> None:
511
        """ Process the exclude hosts before launching the scans."""
512
513
        exc_hosts_list = ''
514
        if not exclude_hosts:
515
            return
516
        exc_hosts_list = target_str_to_list(exclude_hosts)
517
        self.remove_scan_hosts_from_target_progress(scan_id, exc_hosts_list)
518
519
    def process_finished_hosts(self, scan_id: str, finished_hosts: str) -> None:
520
        """ Process the finished hosts before launching the scans.
521
        Set finished hosts as finished with 100% to calculate
522
        the scan progress."""
523
524
        exc_hosts_list = ''
525
        if not finished_hosts:
526
            return
527
528
        exc_hosts_list = target_str_to_list(finished_hosts)
529
530
        for host in exc_hosts_list:
531
            self.set_scan_host_finished(scan_id, host)
532
            self.set_scan_host_progress(scan_id, host, 100)
533
534
    def start_scan(self, scan_id: str, target: Dict) -> None:
        """ Starts the scan with scan_id.

        os.setsid() is called first, which makes the calling process the
        leader of a new session. This presumably runs in a dedicated scan
        process, so that stop_scan() can later signal its whole process
        group.

        The hosts under the target's 'exclude_hosts' and 'finished_hosts'
        keys are processed before exec_scan() runs. An exception from
        exec_scan() is recorded as a scan error and logged, not propagated.
        Unless the scan was stopped meanwhile, it is then marked as finished.

        Raises:
            OspdCommandError: if *target* is None or empty.
        """
        os.setsid()

        if not target:
            raise OspdCommandError('Erroneous target', 'start_scan')

        logger.info("%s: Scan started.", scan_id)

        self.process_exclude_hosts(scan_id, target.get('exclude_hosts'))
        self.process_finished_hosts(scan_id, target.get('finished_hosts'))

        try:
            self.set_scan_status(scan_id, ScanStatus.RUNNING)
            # exec_scan()'s return value is not used here.
            self.exec_scan(scan_id)
        except Exception as e:  # pylint: disable=broad-except
            self.add_scan_error(
                scan_id,
                name='',
                host=self.get_scan_host(scan_id),
                value='Host process failure (%s).' % e,
            )
            logger.exception('While scanning: %s', scan_id)
        else:
            logger.info("%s: Host scan finished.", scan_id)

        if self.get_scan_status(scan_id) != ScanStatus.STOPPED:
            self.finish_scan(scan_id)
562
563
    def dry_run_scan(self, scan_id: str, target: Dict) -> None:
        """ Dry runs a scan.

        Nothing is scanned. The target host is resolved, a 'Dry run result'
        log entry is added, and the scan is marked as finished.

        Args:
            scan_id: ID of the scan to dry-run.
            target: Target dictionary, the same kind of dict that start_scan()
                reads 'exclude_hosts'/'finished_hosts' from. The host string
                is taken from its 'hosts' key. For backward compatibility a
                non-dict sequence is still accepted, and its first item is
                used.
        """
        os.setsid()

        # ``target`` is annotated and used as a dict, so ``target[0]`` raised
        # KeyError for dict targets.
        # NOTE(review): 'hosts' is assumed to be the key holding the host
        # string; confirm against OspRequest.process_target_element.
        if isinstance(target, dict):
            hosts = target.get('hosts')
        else:
            hosts = target[0]

        host = resolve_hostname(hosts)
        if host is None:
            logger.info("Couldn't resolve %s.", self.get_scan_host(scan_id))

        port = self.get_scan_ports(scan_id)

        logger.info("%s:%s: Dry run mode.", host, port)

        self.add_scan_log(scan_id, name='', host=host, value='Dry run result')

        self.finish_scan(scan_id)
579
580
    def handle_timeout(self, scan_id: str, host: str) -> None:
581
        """ Handles scanner reaching timeout error. """
582
        self.add_scan_error(
583
            scan_id,
584
            host=host,
585
            name="Timeout",
586
            value="{0} exec timeout.".format(self.get_scanner_name()),
587
        )
588
589
    def remove_scan_hosts_from_target_progress(
590
        self, scan_id: str, exc_hosts_list: List
591
    ) -> None:
592
        """ Remove a list of hosts from the main scan progress table."""
593
        self.scan_collection.remove_hosts_from_target_progress(
594
            scan_id, exc_hosts_list
595
        )
596
597
    def set_scan_host_finished(self, scan_id: str, host: str) -> None:
598
        """ Add the host in a list of finished hosts """
599
        self.scan_collection.set_host_finished(scan_id, host)
600
601
    def set_scan_progress(self, scan_id: str, progress: int) -> None:
602
        """ Sets scan_id scan's progress which is a number
603
        between 0 and 100. """
604
        self.scan_collection.set_progress(scan_id, progress)
605
606
    def set_scan_host_progress(
607
        self, scan_id: str, host: str, progress: int
608
    ) -> None:
609
        """ Sets host's progress which is part of target.
610
        Each time a host progress is updated, the scan progress
611
        is updated too.
612
        """
613
        self.scan_collection.set_host_progress(scan_id, host, progress)
614
615
        scan_progress = self.calculate_progress(scan_id)
616
        self.set_scan_progress(scan_id, scan_progress)
617
618
    def set_scan_status(self, scan_id: str, status: ScanStatus) -> None:
619
        """ Set the scan's status."""
620
        self.scan_collection.set_status(scan_id, status)
621
622
    def get_scan_status(self, scan_id: str) -> ScanStatus:
623
        """ Get scan_id scans's status."""
624
        return self.scan_collection.get_status(scan_id)
625
626
    def scan_exists(self, scan_id: str) -> bool:
627
        """ Checks if a scan with ID scan_id is in collection.
628
629
        @return: 1 if scan exists, 0 otherwise.
630
        """
631
        return self.scan_collection.id_exists(scan_id)
632
633
    def get_help_text(self) -> str:
634
        """ Returns the help output in plain text format."""
635
636
        txt = ''
637
        for name, info in self.commands.items():
638
            description = info.get_description()
639
            attributes = info.get_attributes()
640
            elements = info.get_elements()
641
642
            command_txt = "\t{0: <22} {1}\n".format(name, description)
643
644
            if attributes:
645
                command_txt = ''.join([command_txt, "\t Attributes:\n"])
646
647
                for attrname, attrdesc in attributes.items():
648
                    attr_txt = "\t  {0: <22} {1}\n".format(attrname, attrdesc)
649
                    command_txt = ''.join([command_txt, attr_txt])
650
651
            if elements:
652
                command_txt = ''.join(
653
                    [command_txt, "\t Elements:\n", elements_as_text(elements),]
654
                )
655
656
            txt += command_txt
657
658
        return txt
659
660
    @deprecated(version="20.4", reason="Use ospd.xml.elements_as_text instead.")
    def elements_as_text(self, elems: Dict, indent: int = 2) -> str:
        """ Deprecated: format *elems* as indented plain text by delegating
        to the module-level ospd.xml.elements_as_text(). """
        text = elements_as_text(elems, indent)
        return text
664
665
    def delete_scan(self, scan_id: str) -> int:
        """ Deletes scan_id scan from collection.

        A running scan is never deleted. Otherwise, if a scan process is
        registered for *scan_id*, this waits for it (join() without a
        timeout) and forgets it once it has an exit code.

        @return: 0 if the scan is running, otherwise the result of
                 scan_collection.delete_scan(scan_id).
        """
        if self.get_scan_status(scan_id) == ScanStatus.RUNNING:
            return 0

        # Don't delete the scan until the process stops.
        exitcode = None
        try:
            scan_process = self.scan_processes[scan_id]
        except KeyError:
            logger.debug('Scan process for %s not found', scan_id)
        else:
            scan_process.join()
            exitcode = scan_process.exitcode

        # An exit code of 0 is valid, so compare against None explicitly.
        if exitcode is not None:
            del self.scan_processes[scan_id]

        return self.scan_collection.delete_scan(scan_id)
685
686
    def get_scan_results_xml(
        self, scan_id: str, pop_res: bool, max_res: Optional[int]
    ):
        """ Build a <results> element holding the results of *scan_id*.

        *pop_res* and *max_res* are passed unchanged to the scan collection's
        results_iterator().

        @return: xml.etree Element named 'results'.
        """
        results_elem = Element('results')
        scan_results = self.scan_collection.results_iterator(
            scan_id, pop_res, max_res
        )
        for item in scan_results:
            results_elem.append(get_result_xml(item))

        logger.debug('Returning %d results', len(results_elem))
        return results_elem
701
702
    @deprecated(
        version="20.4",
        reason="Please use ospd.xml.get_elements_from_dict instead.",
    )
    def get_xml_str(self, data: Dict) -> List:
        """ Deprecated: build XML elements from *data* (a dict of tags and
        their contents) via ospd.xml.get_elements_from_dict().

        @return: Whatever get_elements_from_dict() returns for *data*.
        """
        elements = get_elements_from_dict(data)
        return elements
714
715
    def get_scan_xml(
716
        self,
717
        scan_id: str,
718
        detailed: bool = True,
719
        pop_res: bool = False,
720
        max_res: int = 0,
721
    ):
722
        """ Gets scan in XML format.
723
724
        @return: String of scan in XML format.
725
        """
726
        if not scan_id:
727
            return Element('scan')
728
729
        target = self.get_scan_host(scan_id)
730
        progress = self.get_scan_progress(scan_id)
731
        status = self.get_scan_status(scan_id)
732
        start_time = self.get_scan_start_time(scan_id)
733
        end_time = self.get_scan_end_time(scan_id)
734
        response = Element('scan')
735
        for name, value in [
736
            ('id', scan_id),
737
            ('target', target),
738
            ('progress', progress),
739
            ('status', status.name.lower()),
740
            ('start_time', start_time),
741
            ('end_time', end_time),
742
        ]:
743
            response.set(name, str(value))
744
        if detailed:
745
            response.append(
746
                self.get_scan_results_xml(scan_id, pop_res, max_res)
747
            )
748
        return response
749
750
    @staticmethod
751
    def get_custom_vt_as_xml_str(  # pylint: disable=unused-argument
752
        vt_id: str, custom: Dict
753
    ) -> str:
754
        """ Create a string representation of the XML object from the
755
        custom data object.
756
        This needs to be implemented by each ospd wrapper, in case
757
        custom elements for VTs are used.
758
759
        The custom XML object which is returned will be embedded
760
        into a <custom></custom> element.
761
762
        @return: XML object as string for custom data.
763
        """
764
        return ''
765
766
    @staticmethod
767
    def get_params_vt_as_xml_str(  # pylint: disable=unused-argument
768
        vt_id: str, vt_params
769
    ) -> str:
770
        """ Create a string representation of the XML object from the
771
        vt_params data object.
772
        This needs to be implemented by each ospd wrapper, in case
773
        vt_params elements for VTs are used.
774
775
        The params XML object which is returned will be embedded
776
        into a <params></params> element.
777
778
        @return: XML object as string for vt parameters data.
779
        """
780
        return ''
781
782
    @staticmethod
783
    def get_refs_vt_as_xml_str(  # pylint: disable=unused-argument
784
        vt_id: str, vt_refs
785
    ) -> str:
786
        """ Create a string representation of the XML object from the
787
        refs data object.
788
        This needs to be implemented by each ospd wrapper, in case
789
        refs elements for VTs are used.
790
791
        The refs XML object which is returned will be embedded
792
        into a <refs></refs> element.
793
794
        @return: XML object as string for vt references data.
795
        """
796
        return ''
797
798
    @staticmethod
799
    def get_dependencies_vt_as_xml_str(  # pylint: disable=unused-argument
800
        vt_id: str, vt_dependencies
801
    ) -> str:
802
        """ Create a string representation of the XML object from the
803
        vt_dependencies data object.
804
        This needs to be implemented by each ospd wrapper, in case
805
        vt_dependencies elements for VTs are used.
806
807
        The vt_dependencies XML object which is returned will be embedded
808
        into a <dependencies></dependencies> element.
809
810
        @return: XML object as string for vt dependencies data.
811
        """
812
        return ''
813
814
    @staticmethod
    def get_creation_time_vt_as_xml_str(  # pylint: disable=unused-argument
        vt_id: str, vt_creation_time
    ) -> str:
        """ Return the XML string carrying a VT's creation time.

        Stub: the base implementation always returns ''. An ospd wrapper
        whose VTs carry 'creation_time' data must override it, since
        get_vt_xml() parses the returned string as one XML element and
        appends it to the <vt> element.

        @return: XML object as string for vt creation time data. Presumably
                 a <vt_creation_time></vt_creation_time> element.
                 NOTE(review): confirm the tag name OSP clients expect.
        """
        no_creation_time_xml = ''
        return no_creation_time_xml
829
830
    @staticmethod
    def get_modification_time_vt_as_xml_str(  # pylint: disable=unused-argument
        vt_id: str, vt_modification_time
    ) -> str:
        """ Return the XML string carrying a VT's last modification time.

        Stub: the base implementation always returns ''. An ospd wrapper
        whose VTs carry 'modification_time' data must override it, since
        get_vt_xml() parses the returned string as one XML element and
        appends it to the <vt> element.

        @return: XML object as string for vt modification time data.
                 Presumably a <vt_modification_time></vt_modification_time>
                 element. NOTE(review): confirm the tag name OSP clients
                 expect.
        """
        no_modification_time_xml = ''
        return no_modification_time_xml
845
846
    @staticmethod
    def get_summary_vt_as_xml_str(  # pylint: disable=unused-argument
        vt_id: str, summary
    ) -> str:
        """ Return the XML string carrying a VT's summary.

        Stub: the base implementation always returns ''. An ospd wrapper
        whose VTs carry 'summary' data must override it, since get_vt_xml()
        parses the returned string as one XML element and appends it to
        the <vt> element.

        @return: XML object as string for summary data, i.e. the
                 <summary></summary> element.
        """
        no_summary_xml = ''
        return no_summary_xml
861
862
    @staticmethod
    def get_impact_vt_as_xml_str(  # pylint: disable=unused-argument
        vt_id: str, impact
    ) -> str:
        """ Return the XML string carrying a VT's impact description.

        Stub: the base implementation always returns ''. An ospd wrapper
        whose VTs carry 'impact' data must override it, since get_vt_xml()
        parses the returned string as one XML element and appends it to
        the <vt> element.

        @return: XML object as string for impact data, i.e. the
                 <impact></impact> element.
        """
        no_impact_xml = ''
        return no_impact_xml
877
878
    @staticmethod
    def get_affected_vt_as_xml_str(  # pylint: disable=unused-argument
        vt_id: str, affected
    ) -> str:
        """ Return the XML string describing what a VT affects.

        Stub: the base implementation always returns ''. An ospd wrapper
        whose VTs carry 'affected' data must override it, since
        get_vt_xml() parses the returned string as one XML element and
        appends it to the <vt> element.

        @return: XML object as string for affected data, i.e. the
                 <affected></affected> element.
        """
        no_affected_xml = ''
        return no_affected_xml
893
894
    @staticmethod
    def get_insight_vt_as_xml_str(  # pylint: disable=unused-argument
        vt_id: str, insight
    ) -> str:
        """ Return the XML string carrying a VT's insight text.

        Stub: the base implementation always returns ''. An ospd wrapper
        whose VTs carry 'insight' data must override it, since get_vt_xml()
        parses the returned string as one XML element and appends it to
        the <vt> element.

        @return: XML object as string for insight data, i.e. the
                 <insight></insight> element.
        """
        no_insight_xml = ''
        return no_insight_xml
909
910
    @staticmethod
    def get_solution_vt_as_xml_str(  # pylint: disable=unused-argument
        vt_id: str, solution, solution_type=None, solution_method=None
    ) -> str:
        """ Return the XML string describing a VT's solution.

        Stub: the base implementation always returns ''. An ospd wrapper
        whose VTs carry 'solution' data must override it, since
        get_vt_xml() parses the returned string as one XML element and
        appends it to the <vt> element. solution_type and solution_method
        are passed along with the solution text.

        @return: XML object as string for solution data, i.e. the
                 <solution></solution> element.
        """
        no_solution_xml = ''
        return no_solution_xml
925
926
    @staticmethod
    def get_detection_vt_as_xml_str(  # pylint: disable=unused-argument
        vt_id: str, detection=None, qod_type=None, qod=None
    ) -> str:
        """ Return the XML string describing how a VT detects its finding.

        Stub: the base implementation always returns ''. An ospd wrapper
        whose VTs carry 'detection', 'qod_type' or 'qod' data must override
        it, since get_vt_xml() parses the returned string as one XML
        element and appends it to the <vt> element.

        @return: XML object as string for detection data: an element with
                 tag <detection></detection>.
        """
        no_detection_xml = ''
        return no_detection_xml
941
942
    @staticmethod
    def get_severities_vt_as_xml_str(  # pylint: disable=unused-argument
        vt_id: str, severities
    ) -> str:
        """ Return the XML string describing a VT's severities.

        Stub: the base implementation always returns ''. An ospd wrapper
        whose VTs carry 'severities' data must override it, since
        get_vt_xml() parses the returned string as one XML element and
        appends it to the <vt> element.

        @return: XML object as string for severities data, i.e. the
                 <severities></severities> element.
        """
        no_severities_xml = ''
        return no_severities_xml
957
958
    def get_vt_iterator(self) -> Iterable[Tuple[str, Dict]]:
        """ Return the (vt_id, vt) pairs of the VTs dictionary.

        The base implementation returns self.vts.items(). That is an
        iterable (re-iterable) view rather than an Iterator, so the return
        type is annotated as Iterable.
        """
        return self.vts.items()
962
963
    def get_vt_xml(self, single_vt: Tuple[str, Dict]) -> Element:
        """ Build the <vt> XML element for a single vulnerability test.

        @param single_vt: A (vt_id, vt) pair, where vt is the dictionary of
                          data stored for that VT. A falsy value yields an
                          empty <vt/> element.

        @return: Element <vt id="vt_id"> with a <name> child, followed by
                 one child per optional VT field present in vt. Each such
                 child is produced by the corresponding overridable
                 get_*_vt_as_xml_str() method.
        """
        if not single_vt:
            return Element('vt')

        vt_id, vt = single_vt

        vt_xml = Element('vt')
        vt_xml.set('id', vt_id)

        # <name> is always emitted; a VT without a name gets the text 'None'.
        name_elem = SubElement(vt_xml, 'name')
        name_elem.text = str(vt.get('name'))

        # Each entry: (formatter method name, vt keys of which at least one
        # must be truthy to emit the child, extra vt keys passed after them).
        # All keys' values are passed positionally after vt_id. The entries
        # are listed in the order the children are emitted.
        optional_children = (
            ('get_params_vt_as_xml_str', ('vt_params',), ()),
            ('get_refs_vt_as_xml_str', ('vt_refs',), ()),
            ('get_dependencies_vt_as_xml_str', ('vt_dependencies',), ()),
            ('get_creation_time_vt_as_xml_str', ('creation_time',), ()),
            ('get_modification_time_vt_as_xml_str', ('modification_time',), ()),
            ('get_summary_vt_as_xml_str', ('summary',), ()),
            ('get_impact_vt_as_xml_str', ('impact',), ()),
            ('get_affected_vt_as_xml_str', ('affected',), ()),
            ('get_insight_vt_as_xml_str', ('insight',), ()),
            (
                'get_solution_vt_as_xml_str',
                ('solution',),
                ('solution_type', 'solution_method'),
            ),
            (
                'get_detection_vt_as_xml_str',
                ('detection', 'qod_type', 'qod'),
                (),
            ),
            ('get_severities_vt_as_xml_str', ('severities',), ()),
            ('get_custom_vt_as_xml_str', ('custom',), ()),
        )

        for formatter_name, trigger_keys, extra_keys in optional_children:
            if not any(vt.get(key) for key in trigger_keys):
                continue
            formatter = getattr(self, formatter_name)
            args = [vt.get(key) for key in trigger_keys + extra_keys]
            child_xml_str = formatter(vt_id, *args)
            vt_xml.append(secET.fromstring(child_xml_str))

        return vt_xml
1061
1062
    def get_vts_selection_list(
        self, vt_id: Optional[str] = None, filtered_vts: Optional[Dict] = None
    ) -> Iterable[str]:
        """
        Get the collection of VT OIDs to process.

        The collection is chosen as follows:
        - empty, if the daemon has no VTs loaded;
        - empty, if filtered_vts is given but empty (nothing matched the
          filter);
        - filtered_vts itself, if it is non-empty (it has priority over
          vt_id);
        - [vt_id], if vt_id is given. Whether vt_id exists in self.vts is
          NOT checked here; callers must validate it;
        - all loaded VT OIDs (self.vts.keys()) otherwise.

        Arguments:
            vt_id (str, optional): ID of the single VT to select.
            filtered_vts (dict, optional): Filtered VTs collection, returned
                as-is when non-empty.

        Return:
            Iterable of selected VT OIDs.
        """
        if not self.vts:
            return []

        # An explicit but empty filter means nothing matched it, which is
        # different from "no filter given" (None).
        if filtered_vts is not None and len(filtered_vts) == 0:
            return []

        if filtered_vts:
            return filtered_vts
        if vt_id:
            return [vt_id]
        return self.vts.keys()
1097
1098
    def handle_command(self, data: bytes, stream: Stream) -> None:
        """ Handle a single OSP command and write its response to a stream.

        @param data: Raw XML of the command as received from the client.
        @param stream: Client stream the response is written to.

        @raise OspdCommandError: 'Invalid data' if data is not well-formed
                                 XML, 'Bogus command name' if its root tag
                                 does not name a registered command.
        """
        try:
            tree = secET.fromstring(data)
        except secET.ParseError as err:
            logger.debug("Erroneous client input: %s", data)
            raise OspdCommandError('Invalid data') from err

        command_name = tree.tag

        logger.debug('Handling %s command request.', command_name)

        # Every command name, "authenticate" included, must be registered:
        # an unregistered name has no handler to call handle_xml() on.
        command = self.commands.get(command_name)
        if not command:
            raise OspdCommandError('Bogus command name')

        response = command.handle_xml(tree)

        # A handler returns either a single bytes response or an iterable
        # of chunks that are written one after another.
        if isinstance(response, bytes):
            stream.write(response)
        else:
            for chunk in response:
                stream.write(chunk)
1122
1123
    def check(self):
        """ Hook for daemon-specific checks.

        The base class offers no implementation and always raises
        NotImplementedError; subclasses are expected to override it.
        """
        raise NotImplementedError()
1126
1127
    def run(self) -> None:
        """ Run the daemon's main loop until it is interrupted.

        Every SCHEDULER_CHECK_PERIOD seconds the loop calls, in order, the
        scheduler() hook, clean_forgotten_scans() and wait_for_children().
        A KeyboardInterrupt ends the loop and is logged.
        """
        try:
            while True:
                time.sleep(SCHEDULER_CHECK_PERIOD)
                for periodic_task in (
                    self.scheduler,
                    self.clean_forgotten_scans,
                    self.wait_for_children,
                ):
                    periodic_task()
        except KeyboardInterrupt:
            logger.info("Received Ctrl-C shutting-down ...")
1139
1140
    def scheduler(self):
        """ Hook for periodic work, called on every pass of run()'s loop.

        The base implementation does nothing; subclasses override it when
        they need to run tasks periodically.
        """
1143
1144
    def wait_for_children(self):
        """ Join finished scan processes so their resources are released.

        Every tracked scan process is joined with a timeout of 0, so the
        call does not wait for processes that are still running.
        """
        for process in self.scan_processes.values():
            process.join(0)
1148
1149
    def create_scan(
        self,
        scan_id: str,
        targets: Dict,
        options: Optional[Dict],
        vt_selection: Dict,
    ) -> Optional[str]:
        """ Creates a new scan, or hands a stopped scan over again to resume it.

        @param scan_id: ID of the scan to create.
        @param targets: Target data of the scan.
        @param options: Miscellaneous scan options.
        @param vt_selection: VTs selected for the scan.

        @return: The value returned by scan_collection.create_scan() (the
                 scan's ID). None if a scan with scan_id already exists and
                 its status is RUNNING or FINISHED.
        """
        status = None
        scan_exists = self.scan_exists(scan_id)
        if scan_id and scan_exists:
            status = self.get_scan_status(scan_id)

        if scan_exists and status == ScanStatus.STOPPED:
            # A stopped scan is passed to the scan collection again below.
            logger.info("Scan %s exists. Resuming scan.", scan_id)
        elif scan_exists and status in (
            ScanStatus.RUNNING,
            ScanStatus.FINISHED,
        ):
            logger.info(
                "Scan %s exists with status %s.", scan_id, status.name.lower()
            )
            return None

        return self.scan_collection.create_scan(
            scan_id, targets, options, vt_selection
        )
1181
1182
    def get_scan_options(self, scan_id: str) -> str:
        """ Return the options stored for a scan in the scan collection. """
        collection = self.scan_collection
        return collection.get_options(scan_id)
1185
1186
    def set_scan_option(self, scan_id: str, name: str, value: Any) -> None:
        """ Store value under option name for a scan in the scan collection. """
        collection = self.scan_collection
        return collection.set_option(scan_id, name, value)
1189
1190
    def clean_forgotten_scans(self) -> None:
        """ Delete stopped or finished scans that were never deleted.

        A scan is deleted once its status is STOPPED or FINISHED, it has a
        non-zero end time, and more than scaninfo_store_time hours have
        passed since that end time. Nothing happens when
        scaninfo_store_time is unset or 0.
        """
        if not self.scaninfo_store_time:
            return

        # Iterate over a snapshot of the ids since scans are deleted inside
        # the loop.
        for scan_id in list(self.scan_collection.ids_iterator()):
            end_time = int(self.get_scan_end_time(scan_id))
            scan_status = self.get_scan_status(scan_id)

            is_done = scan_status in (ScanStatus.STOPPED, ScanStatus.FINISHED)
            if not (is_done and end_time):
                continue

            age_seconds = int(time.time()) - end_time
            if age_seconds <= self.scaninfo_store_time * 3600:
                continue

            logger.debug(
                'Scan %s is older than %d hours and seems have been '
                'forgotten. Scan info will be deleted from the '
                'scan table',
                scan_id,
                self.scaninfo_store_time,
            )
            self.delete_scan(scan_id)
1215
1216
    def check_scan_process(self, scan_id: str) -> None:
        """ Check a scan's process and stop the scan if the process died.

        If progress is below 100 and the process is no longer alive, the
        scan is set to STOPPED (unless it already is) and a
        "Scan process failure." error is added. When progress is exactly
        100, the process is joined with a timeout of 0.
        """
        scan_process = self.scan_processes[scan_id]
        progress = self.get_scan_progress(scan_id)

        if progress == 100:
            scan_process.join(0)
        elif progress < 100 and not scan_process.is_alive():
            if self.get_scan_status(scan_id) == ScanStatus.STOPPED:
                return
            self.set_scan_status(scan_id, ScanStatus.STOPPED)
            self.add_scan_error(
                scan_id, name="", host="", value="Scan process failure."
            )

            logger.info("%s: Scan stopped with errors.", scan_id)
1232
1233
    def get_scan_progress(self, scan_id: str):
        """ Return the current progress value recorded for a scan. """
        collection = self.scan_collection
        return collection.get_progress(scan_id)
1236
1237
    def get_scan_host(self, scan_id: str) -> str:
        """ Return the host list of a scan, as kept by the scan collection. """
        collection = self.scan_collection
        return collection.get_host_list(scan_id)
1240
1241
    def get_scan_ports(self, scan_id: str) -> str:
        """ Return the port list of a scan, as kept by the scan collection. """
        collection = self.scan_collection
        return collection.get_ports(scan_id)
1244
1245
    def get_scan_exclude_hosts(self, scan_id: str):
        """ Gives the list of hosts excluded from the scan scan_id, as kept
        by the scan collection. """
        return self.scan_collection.get_exclude_hosts(scan_id)
1249
1250
    def get_scan_credentials(self, scan_id: str) -> Dict:
        """ Gives the credentials of the scan scan_id, as kept by the scan
        collection. """
        return self.scan_collection.get_credentials(scan_id)
1254
1255
    def get_scan_target_options(self, scan_id: str) -> Dict:
        """ Gives the target options dict of the scan scan_id, as kept by
        the scan collection. """
        return self.scan_collection.get_target_options(scan_id)
1259
1260
    def get_scan_vts(self, scan_id: str) -> Dict:
        """ Return the VTs selected for a scan, as kept by the scan collection. """
        collection = self.scan_collection
        return collection.get_vts(scan_id)
1263
1264
    def get_scan_unfinished_hosts(self, scan_id: str) -> List:
        """ Return the hosts of a scan that are not finished yet. """
        collection = self.scan_collection
        return collection.get_hosts_unfinished(scan_id)
1267
1268
    def get_scan_finished_hosts(self, scan_id: str) -> List:
        """ Get a list of finished hosts."""
        return self.scan_collection.get_hosts_finished(scan_id)
1271
1272
    def get_scan_start_time(self, scan_id: str) -> str:
        """ Return the start time recorded for a scan. """
        collection = self.scan_collection
        return collection.get_start_time(scan_id)
1275
1276
    def get_scan_end_time(self, scan_id: str) -> str:
        """ Return the end time recorded for a scan. """
        collection = self.scan_collection
        return collection.get_end_time(scan_id)
1279
1280
    def add_scan_log(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
        test_id: str = '',
        qod: str = '',
    ):
        """ Record a LOG result for the scan scan_id.

        Log results are always stored with the fixed severity '0.0'.
        """
        details = (host, hostname, name, value, port, test_id, '0.0', qod)
        self.scan_collection.add_result(scan_id, ResultType.LOG, *details)
1304
1305
    def add_scan_error(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
    ) -> None:
        """ Record an ERROR result for the scan scan_id. """
        details = (host, hostname, name, value, port)
        self.scan_collection.add_result(scan_id, ResultType.ERROR, *details)
1318
1319
    def add_scan_host_detail(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
    ) -> None:
        """ Record a HOST_DETAIL result for the scan scan_id. """
        details = (host, hostname, name, value)
        self.scan_collection.add_result(
            scan_id, ResultType.HOST_DETAIL, *details
        )
1331
1332
    def add_scan_alarm(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
        test_id: str = '',
        severity: str = '',
        qod: str = '',
    ):
        """ Record an ALARM result, with its severity, for the scan scan_id. """
        details = (host, hostname, name, value, port, test_id, severity, qod)
        self.scan_collection.add_result(scan_id, ResultType.ALARM, *details)
1357