Passed
Pull Request — master (#237)
by Juan José
01:35
created

ospd.ospd.OSPDaemon.handle_client_stream()   C

Complexity

Conditions 11

Size

Total Lines 45
Code Lines 36

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 11
eloc 36
nop 2
dl 0
loc 45
rs 5.4
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex classes like ospd.ospd.OSPDaemon.handle_client_stream() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
# Copyright (C) 2014-2020 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: GPL-2.0-or-later
4
#
5
# This program is free software; you can redistribute it and/or
6
# modify it under the terms of the GNU General Public License
7
# as published by the Free Software Foundation; either version 2
8
# of the License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU General Public License for more details.
14
#
15
# You should have received a copy of the GNU General Public License
16
# along with this program; if not, write to the Free Software
17
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18
19
# pylint: disable=too-many-lines
20
21
""" OSP Daemon core class.
22
"""
23
24
import logging
25
import socket
26
import ssl
27
import multiprocessing
28
import time
29
import os
30
31
from typing import (
32
    List,
33
    Any,
34
    Iterator,
35
    Dict,
36
    Optional,
37
    Iterable,
38
    Tuple,
39
)
40
from xml.etree.ElementTree import Element, SubElement
41
42
import defusedxml.ElementTree as secET
43
44
from deprecated import deprecated
45
46
from ospd import __version__
47
from ospd.command import get_commands
48
from ospd.errors import OspdCommandError
49
from ospd.misc import ResultType
50
from ospd.network import resolve_hostname, target_str_to_list
51
from ospd.protocol import OspRequest, OspResponse, RequestParser
52
from ospd.scan import ScanCollection, ScanStatus
53
from ospd.server import BaseServer, Stream
54
from ospd.vtfilter import VtsFilter
55
from ospd.vts import Vts
56
from ospd.xml import (
57
    elements_as_text,
58
    get_result_xml,
59
    get_elements_from_dict,
60
)
61
62
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Value OSPDaemon instances expose as their protocol version.
PROTOCOL_VERSION = "1.2"

SCHEDULER_CHECK_PERIOD = 10  # in seconds

# Scanner parameters registered on every daemon instance at construction
# time. Each entry carries the parameter's type, display name, default
# value, mandatory flag and description.
BASE_SCANNER_PARAMS = dict(
    debug_mode=dict(
        type='boolean',
        name='Debug Mode',
        default=0,
        mandatory=0,
        description='Whether to get extra scan debug information.',
    ),
    dry_run=dict(
        type='boolean',
        name='Dry Run',
        default=0,
        mandatory=0,
        description='Whether to dry run scan.',
    ),
)  # type: Dict
84
85
86
def _terminate_process_group(process: multiprocessing.Process) -> None:
    """ Send SIGTERM to the process group that *process* belongs to.

    Raises whatever os.getpgid / os.killpg raise, e.g. ProcessLookupError
    when the process no longer exists.
    """
    # Imported locally so the module's import block stays untouched.
    import signal

    # signal.SIGTERM replaces the bare signal number 15 used before.
    os.killpg(os.getpgid(process.pid), signal.SIGTERM)
88
89
90
class OSPDaemon:
91
92
    """ Daemon class for OSP traffic handling.
93
94
    Every scanner wrapper should subclass it and make necessary additions and
95
    changes.
96
97
    * Add any needed parameters in __init__.
98
    * Implement check() method which verifies scanner availability and other
99
      environment related conditions.
100
    * Implement process_scan_params and exec_scan methods which are
101
      specific to handling the <start_scan> command, executing the wrapped
102
      scanner and storing the results.
103
    * exec_scan() should return 0 if host is dead or not reached, 1 if host is
104
      alive and 2 if scan error or status is unknown.
105
    * Implement other methods that assert to False such as get_scanner_name,
106
      get_scanner_version.
107
    * Call set_command_attributes at init time to add scanner command
108
      specific options eg. the w3af profile for w3af wrapper.
109
    """
110
111
    def __init__(
        self, *, customvtfilter=None, **kwargs
    ):  # pylint: disable=unused-argument
        """ Set up the daemon's in-memory state.

        Arguments:
            customvtfilter: Optional VT filter object; a plain VtsFilter is
                created when none (or a falsy value) is given.
            kwargs: Only 'scaninfo_store_time' is read here.
        """
        self.scan_collection = ScanCollection()
        self.scan_processes = {}

        self.daemon_info = {
            'name': "OSPd",
            'version': __version__,
            'description': "No description",
        }

        self.scanner_info = {
            'name': 'No name',
            'version': 'No version',
            'description': 'No description',
        }

        # Set by the subclass.
        self.server_version = None

        # Set after initialization finished
        self.initialized = None

        self.scaninfo_store_time = kwargs.get('scaninfo_store_time')

        self.protocol_version = PROTOCOL_VERSION

        # Registry of command objects, keyed by command name. Commands are
        # constructed with a reference to this (partially built) daemon.
        self.commands = {}
        for command_class in get_commands():
            command = command_class(self)
            self.commands[command.get_name()] = command

        # Every daemon starts with the base scanner parameters.
        self.scanner_params = {}
        for param_name, param_spec in BASE_SCANNER_PARAMS.items():
            self.set_scanner_param(param_name, param_spec)

        self.vts = Vts()
        self.vts_version = None

        self.vts_filter = customvtfilter if customvtfilter else VtsFilter()
154
155
    def init(self, server: BaseServer) -> None:
156
        """ Should be overridden by a subclass if the initialization is costly.
157
158
            Will be called after check.
159
        """
160
        server.start(self.handle_client_stream)
161
        self.initialized = True
162
163
    def set_command_attributes(self, name: str, attributes: Dict) -> None:
164
        """ Sets the xml attributes of a specified command. """
165
        if self.command_exists(name):
166
            command = self.commands.get(name)
167
            command.attributes = attributes
168
169
    @deprecated(version="20.4", reason="Use set_scanner_param instead")
170
    def add_scanner_param(self, name: str, scanner_params: Dict) -> None:
171
        """ Set a scanner parameter. """
172
        self.set_scanner_param(name, scanner_params)
173
174
    def set_scanner_param(self, name: str, scanner_params: Dict) -> None:
175
        """ Set a scanner parameter. """
176
177
        assert name
178
        assert scanner_params
179
180
        self.scanner_params[name] = scanner_params
181
182
    def get_scanner_params(self) -> Dict:
183
        return self.scanner_params
184
185
    def add_vt(
186
        self,
187
        vt_id: str,
188
        name: str = None,
189
        vt_params: str = None,
190
        vt_refs: str = None,
191
        custom: str = None,
192
        vt_creation_time: str = None,
193
        vt_modification_time: str = None,
194
        vt_dependencies: str = None,
195
        summary: str = None,
196
        impact: str = None,
197
        affected: str = None,
198
        insight: str = None,
199
        solution: str = None,
200
        solution_t: str = None,
201
        solution_m: str = None,
202
        detection: str = None,
203
        qod_t: str = None,
204
        qod_v: str = None,
205
        severities: str = None,
206
    ) -> None:
207
        """ Add a vulnerability test information.
208
209
        IMPORTANT: The VT's Data Manager will store the vts collection.
210
        If the collection is considerably big and it will be consultated
211
        intensible during a routine, consider to do a deepcopy(), since
212
        accessing the shared memory in the data manager is very expensive.
213
        At the end of the routine, the temporal copy must be set to None
214
        and deleted.
215
        """
216
        self.vts.add(
217
            vt_id,
218
            name=name,
219
            vt_params=vt_params,
220
            vt_refs=vt_refs,
221
            custom=custom,
222
            vt_creation_time=vt_creation_time,
223
            vt_modification_time=vt_modification_time,
224
            vt_dependencies=vt_dependencies,
225
            summary=summary,
226
            impact=impact,
227
            affected=affected,
228
            insight=insight,
229
            solution=solution,
230
            solution_t=solution_t,
231
            solution_m=solution_m,
232
            detection=detection,
233
            qod_t=qod_t,
234
            qod_v=qod_v,
235
            severities=severities,
236
        )
237
238
    def set_vts_version(self, vts_version: str) -> None:
239
        """ Add into the vts dictionary an entry to identify the
240
        vts version.
241
242
        Parameters:
243
            vts_version (str): Identifies a unique vts version.
244
        """
245
        if not vts_version:
246
            raise OspdCommandError(
247
                'A vts_version parameter is required', 'set_vts_version'
248
            )
249
        self.vts_version = vts_version
250
251
    def get_vts_version(self) -> Optional[str]:
252
        """Return the vts version.
253
        """
254
        return self.vts_version
255
256
    def command_exists(self, name: str) -> bool:
257
        """ Checks if a commands exists. """
258
        return name in self.commands
259
260
    def get_scanner_name(self) -> str:
261
        """ Gives the wrapped scanner's name. """
262
        return self.scanner_info['name']
263
264
    def get_scanner_version(self) -> str:
265
        """ Gives the wrapped scanner's version. """
266
        return self.scanner_info['version']
267
268
    def get_scanner_description(self) -> str:
269
        """ Gives the wrapped scanner's description. """
270
        return self.scanner_info['description']
271
272
    def get_server_version(self) -> str:
273
        """ Gives the specific OSP server's version. """
274
        assert self.server_version
275
        return self.server_version
276
277
    def get_protocol_version(self) -> str:
278
        """ Gives the OSP's version. """
279
        return self.protocol_version
280
281
    def preprocess_scan_params(self, xml_params):
282
        """ Processes the scan parameters. """
283
        params = {}
284
285
        for param in xml_params:
286
            params[param.tag] = param.text or ''
287
288
        # Set default values.
289
        for key in self.scanner_params:
290
            if key not in params:
291
                params[key] = self.get_scanner_param_default(key)
292
                if self.get_scanner_param_type(key) == 'selection':
293
                    params[key] = params[key].split('|')[0]
294
295
        # Validate values.
296
        for key in params:
297
            param_type = self.get_scanner_param_type(key)
298
            if not param_type:
299
                continue
300
301
            if param_type in ['integer', 'boolean']:
302
                try:
303
                    params[key] = int(params[key])
304
                except ValueError:
305
                    raise OspdCommandError(
306
                        'Invalid %s value' % key, 'start_scan'
307
                    )
308
309
            if param_type == 'boolean':
310
                if params[key] not in [0, 1]:
311
                    raise OspdCommandError(
312
                        'Invalid %s value' % key, 'start_scan'
313
                    )
314
            elif param_type == 'selection':
315
                selection = self.get_scanner_param_default(key).split('|')
316
                if params[key] not in selection:
317
                    raise OspdCommandError(
318
                        'Invalid %s value' % key, 'start_scan'
319
                    )
320
            if self.get_scanner_param_mandatory(key) and params[key] == '':
321
                raise OspdCommandError(
322
                    'Mandatory %s value is missing' % key, 'start_scan'
323
                )
324
325
        return params
326
327
    def process_scan_params(self, params: Dict) -> Dict:
        """ Hook for child classes to post-process the scan parameters.

        The base implementation returns *params* unchanged.
        """
        return params
331
332
    @staticmethod
    @deprecated(
        version="20.4",
        reason="Please use OspRequest.process_vts_params instead.",
    )
    def process_vts_params(scanner_vts) -> Dict:
        """ Deprecated pass-through to OspRequest.process_vts_params. """
        return OspRequest.process_vts_params(scanner_vts)
339
340
    @staticmethod
    @deprecated(
        version="20.4",
        reason="Please use OspRequest.process_credentials_elements instead.",
    )
    def process_credentials_elements(cred_tree) -> Dict:
        """ Deprecated pass-through to
        OspRequest.process_credentials_elements. """
        return OspRequest.process_credentials_elements(cred_tree)
347
348
    @staticmethod
    @deprecated(
        version="20.4",
        reason="Please use OspRequest.process_target_element instead.",
    )
    def process_targets_element(scanner_target) -> List:
        """ Deprecated pass-through to OspRequest.process_target_element. """
        return OspRequest.process_target_element(scanner_target)
355
356
    def stop_scan(self, scan_id: str) -> None:
        """ Stop the running scan identified by scan_id.

        The scan is marked as STOPPED, the subclass clean-up hook is run,
        then the scan process and its process group are terminated.

        Raises:
            OspdCommandError: If no process is registered for scan_id or
                the process is no longer alive.
        """
        scan_process = self.scan_processes.get(scan_id)
        if not scan_process:
            raise OspdCommandError(
                'Scan not found {0}.'.format(scan_id), 'stop_scan'
            )
        if not scan_process.is_alive():
            raise OspdCommandError(
                'Scan already stopped or finished.', 'stop_scan'
            )

        self.set_scan_status(scan_id, ScanStatus.STOPPED)

        logger.info('%s: Scan stopping %s.', scan_id, scan_process.ident)

        self.stop_scan_cleanup(scan_id)

        try:
            scan_process.terminate()
        except AttributeError:
            logger.debug('%s: The scanner task stopped unexpectedly.', scan_id)

        try:
            _terminate_process_group(scan_process)
        except ProcessLookupError:
            # The process (group) is already gone; nothing left to signal.
            logger.info(
                '%s: Scan already stopped %s.', scan_id, scan_process.pid
            )

        # Only join a process other than the one executing this code.
        if scan_process.ident != os.getpid():
            scan_process.join(0)

        logger.info('%s: Scan stopped.', scan_id)
389
390
    @staticmethod
391
    def stop_scan_cleanup(scan_id: str):
392
        """ Should be implemented by subclass in case of a clean up before
393
        terminating is needed. """
394
395
    @staticmethod
396
    def target_is_finished(scan_id: str):
397
        """ Should be implemented by subclass in case of a check before
398
        stopping is needed. """
399
400
    def exec_scan(self, scan_id: str):
        """ Run the wrapped scanner for scan_id.

        Must be implemented by the subclass; the base implementation
        always raises NotImplementedError.
        """
        raise NotImplementedError
403
404
    def finish_scan(self, scan_id: str) -> None:
        """ Mark a scan as finished: progress 100 and status FINISHED. """
        full_progress = 100
        self.set_scan_progress(scan_id, full_progress)
        self.set_scan_status(scan_id, ScanStatus.FINISHED)
        logger.info("%s: Scan finished.", scan_id)
409
410
    def get_daemon_name(self) -> str:
411
        """ Gives osp daemon's name. """
412
        return self.daemon_info['name']
413
414
    def get_daemon_version(self) -> str:
415
        """ Gives osp daemon's version. """
416
        return self.daemon_info['version']
417
418
    def get_scanner_param_type(self, param: str):
419
        """ Returns type of a scanner parameter. """
420
        assert isinstance(param, str)
421
        entry = self.scanner_params.get(param)
422
        if not entry:
423
            return None
424
        return entry.get('type')
425
426
    def get_scanner_param_mandatory(self, param: str):
427
        """ Returns if a scanner parameter is mandatory. """
428
        assert isinstance(param, str)
429
        entry = self.scanner_params.get(param)
430
        if not entry:
431
            return False
432
        return entry.get('mandatory')
433
434
    def get_scanner_param_default(self, param: str):
435
        """ Returns default value of a scanner parameter. """
436
        assert isinstance(param, str)
437
        entry = self.scanner_params.get(param)
438
        if not entry:
439
            return None
440
        return entry.get('default')
441
442
    @deprecated(
443
        version="20.4",
444
        reason="Please use OspResponse.create_scanner_params_xml instead.",
445
    )
446
    def get_scanner_params_xml(self):
        """ Deprecated: returns what OspResponse.create_scanner_params_xml
        builds from the daemon's scanner params. """
        scanner_params = self.scanner_params
        return OspResponse.create_scanner_params_xml(scanner_params)
449
450
    def handle_client_stream(self, stream: Stream) -> None:
        """ Handles stream of data received from client.

        Reads one request from the stream, hands it to handle_command and,
        if the command failed, writes the error response back before
        closing the stream.
        """
        data = self._read_client_request(stream)

        if data is None:
            # Reading failed; the error has already been logged.
            return

        if not data:
            logger.debug("Empty client stream")
            return

        response = None
        try:
            self.handle_command(data, stream)
        except OspdCommandError as exception:
            response = exception.as_xml()
            logger.debug('Command error: %s', exception.message)
        except Exception:  # pylint: disable=broad-except
            logger.exception('While handling client command:')
            response = OspdCommandError('Fatal error', 'error').as_xml()

        if response:
            stream.write(response)

        stream.close()

    @staticmethod
    def _read_client_request(stream: Stream) -> Optional[bytes]:
        """ Read the request bytes from the stream.

        Reading stops when the stream returns no more data, when the
        request parser reports the end of the request, or on an SSL error
        or socket timeout (whatever was read so far is kept).

        Return:
            The bytes read (possibly empty), or None if an AttributeError
            or ValueError occurred while reading.
        """
        chunks = []
        request_parser = RequestParser()

        while True:
            try:
                buf = stream.read()
                if not buf:
                    break

                chunks.append(buf)

                if request_parser.has_ended(buf):
                    break
            except (AttributeError, ValueError) as message:
                logger.error(message)
                return None
            except ssl.SSLError as exception:
                logger.debug('Error: %s', exception)
                break
            except socket.timeout as exception:
                logger.debug('Request timeout: %s', exception)
                break

        return b''.join(chunks)
495
496
    def calculate_progress(self, scan_id: str) -> float:
497
        """ Calculate the total scan progress. """
498
499
        return self.scan_collection.calculate_target_progress(scan_id)
500
501
    def process_exclude_hosts(self, scan_id: str, exclude_hosts: str) -> None:
        """ Process the exclude hosts before launching the scans.

        The hosts string is expanded with target_str_to_list and the
        resulting hosts are removed from the scan's target progress.
        Nothing happens when exclude_hosts is empty or None.
        """
        if not exclude_hosts:
            return

        exc_hosts_list = target_str_to_list(exclude_hosts)
        self.remove_scan_hosts_from_target_progress(scan_id, exc_hosts_list)
509
510
    def process_finished_hosts(self, scan_id: str, finished_hosts: str) -> None:
        """ Process the finished hosts before launching the scans.

        Each host obtained from target_str_to_list is registered as
        finished and its progress is set to 100, so that it counts towards
        the scan progress. Nothing happens when finished_hosts is empty
        or None.
        """
        if not finished_hosts:
            return

        finished_hosts_list = target_str_to_list(finished_hosts)

        for host in finished_hosts_list:
            self.set_scan_host_finished(scan_id, host)
            self.set_scan_host_progress(scan_id, host, 100)
524
525
    def start_scan(self, scan_id: str, target: Dict) -> None:
        """ Starts the scan with scan_id.

        Processes the excluded and finished hosts of the target, sets the
        scan to RUNNING and calls exec_scan. Any exception raised while
        scanning is recorded as a scan error. Unless the scan was stopped
        meanwhile, it is marked as finished at the end.

        Raises:
            OspdCommandError: If target is None or empty.
        """
        # Puts the calling process into its own session.
        # NOTE(review): presumably so stop_scan can signal the whole
        # process group - confirm against the caller.
        os.setsid()

        if not target:
            raise OspdCommandError('Erroneous target', 'start_scan')

        logger.info("%s: Scan started.", scan_id)

        self.process_exclude_hosts(scan_id, target.get('exclude_hosts'))
        self.process_finished_hosts(scan_id, target.get('finished_hosts'))

        try:
            self.set_scan_status(scan_id, ScanStatus.RUNNING)
            # The return value of exec_scan is not used here.
            self.exec_scan(scan_id)
        except Exception as e:  # pylint: disable=broad-except
            self.add_scan_error(
                scan_id,
                name='',
                host=self.get_scan_host(scan_id),
                value='Host process failure (%s).' % e,
            )
            logger.exception('While scanning: %s', scan_id)
        else:
            logger.info("%s: Host scan finished.", scan_id)

        if self.get_scan_status(scan_id) != ScanStatus.STOPPED:
            self.finish_scan(scan_id)
553
554
    def dry_run_scan(self, scan_id: str, target: Dict) -> None:
        """ Dry runs a scan.

        Resolves the first target entry, logs the dry run, stores a
        single 'Dry run result' log entry and finishes the scan.
        """
        os.setsid()

        # NOTE(review): target is annotated as Dict but is indexed with 0
        # here - confirm the expected type against the caller.
        resolved_host = resolve_hostname(target[0])
        if resolved_host is None:
            logger.info("Couldn't resolve %s.", self.get_scan_host(scan_id))

        ports = self.get_scan_ports(scan_id)
        logger.info("%s:%s: Dry run mode.", resolved_host, ports)

        self.add_scan_log(
            scan_id, name='', host=resolved_host, value='Dry run result'
        )
        self.finish_scan(scan_id)
570
571
    def handle_timeout(self, scan_id: str, host: str) -> None:
572
        """ Handles scanner reaching timeout error. """
573
        self.add_scan_error(
574
            scan_id,
575
            host=host,
576
            name="Timeout",
577
            value="{0} exec timeout.".format(self.get_scanner_name()),
578
        )
579
580
    def remove_scan_hosts_from_target_progress(
581
        self, scan_id: str, exc_hosts_list: List
582
    ) -> None:
583
        """ Remove a list of hosts from the main scan progress table."""
584
        self.scan_collection.remove_hosts_from_target_progress(
585
            scan_id, exc_hosts_list
586
        )
587
588
    def set_scan_host_finished(self, scan_id: str, host: str) -> None:
589
        """ Add the host in a list of finished hosts """
590
        self.scan_collection.set_host_finished(scan_id, host)
591
592
    def set_scan_progress(self, scan_id: str, progress: int) -> None:
593
        """ Sets scan_id scan's progress which is a number
594
        between 0 and 100. """
595
        self.scan_collection.set_progress(scan_id, progress)
596
597
    def set_scan_host_progress(
598
        self, scan_id: str, host: str, progress: int
599
    ) -> None:
600
        """ Sets host's progress which is part of target.
601
        Each time a host progress is updated, the scan progress
602
        is updated too.
603
        """
604
        self.scan_collection.set_host_progress(scan_id, host, progress)
605
606
        scan_progress = self.calculate_progress(scan_id)
607
        self.set_scan_progress(scan_id, scan_progress)
608
609
    def set_scan_status(self, scan_id: str, status: ScanStatus) -> None:
610
        """ Set the scan's status."""
611
        self.scan_collection.set_status(scan_id, status)
612
613
    def get_scan_status(self, scan_id: str) -> ScanStatus:
614
        """ Get scan_id scans's status."""
615
        return self.scan_collection.get_status(scan_id)
616
617
    def scan_exists(self, scan_id: str) -> bool:
618
        """ Checks if a scan with ID scan_id is in collection.
619
620
        @return: 1 if scan exists, 0 otherwise.
621
        """
622
        return self.scan_collection.id_exists(scan_id)
623
624
    def get_help_text(self) -> str:
625
        """ Returns the help output in plain text format."""
626
627
        txt = ''
628
        for name, info in self.commands.items():
629
            description = info.get_description()
630
            attributes = info.get_attributes()
631
            elements = info.get_elements()
632
633
            command_txt = "\t{0: <22} {1}\n".format(name, description)
634
635
            if attributes:
636
                command_txt = ''.join([command_txt, "\t Attributes:\n"])
637
638
                for attrname, attrdesc in attributes.items():
639
                    attr_txt = "\t  {0: <22} {1}\n".format(attrname, attrdesc)
640
                    command_txt = ''.join([command_txt, attr_txt])
641
642
            if elements:
643
                command_txt = ''.join(
644
                    [command_txt, "\t Elements:\n", elements_as_text(elements),]
645
                )
646
647
            txt += command_txt
648
649
        return txt
650
651
    @deprecated(version="20.4", reason="Use ospd.xml.elements_as_text instead.")
652
    def elements_as_text(self, elems: Dict, indent: int = 2) -> str:
        """ Deprecated: forwards to the module-level elements_as_text
        imported from ospd.xml and returns its result. """
        formatted = elements_as_text(elems, indent)
        return formatted
655
656
    def delete_scan(self, scan_id: str) -> int:
        """ Deletes scan_id scan from collection.

        A running scan is never deleted. Otherwise the scan process (if
        one is registered) is joined and dropped from the process table
        before the scan is removed from the collection.

        @return: 0 if the scan is still running, otherwise the result of
                 the scan collection's delete_scan.
        """
        if self.get_scan_status(scan_id) == ScanStatus.RUNNING:
            return 0

        try:
            scan_process = self.scan_processes[scan_id]
        except KeyError:
            logger.debug('Scan process for %s not found', scan_id)
        else:
            # Don't delete the scan until the process stops
            scan_process.join()
            # An exit code of 0 counts as terminated too, so compare
            # against None explicitly.
            if scan_process.exitcode is not None:
                del self.scan_processes[scan_id]

        return self.scan_collection.delete_scan(scan_id)
676
677
    def get_scan_results_xml(
        self, scan_id: str, pop_res: bool, max_res: Optional[int]
    ):
        """ Gets scan_id scan's results in XML format.

        @return: A 'results' Element with one child per result yielded
                 by the scan collection's results iterator.
        """
        results_element = Element('results')
        result_iter = self.scan_collection.results_iterator(
            scan_id, pop_res, max_res
        )
        for scan_result in result_iter:
            results_element.append(get_result_xml(scan_result))

        logger.debug('Returning %d results', len(results_element))
        return results_element
692
693
    @deprecated(
694
        version="20.4",
695
        reason="Please use ospd.xml.get_elements_from_dict instead.",
696
    )
697
    def get_xml_str(self, data: Dict) -> List:
        """ Deprecated: forwards data to get_elements_from_dict from
        ospd.xml.

        @param: Dictionary of xml tags and their elements.

        @return: Whatever get_elements_from_dict returns for data.
        """
        elements = get_elements_from_dict(data)
        return elements
705
706
    def get_scan_xml(
707
        self,
708
        scan_id: str,
709
        detailed: bool = True,
710
        pop_res: bool = False,
711
        max_res: int = 0,
712
    ):
713
        """ Gets scan in XML format.
714
715
        @return: String of scan in XML format.
716
        """
717
        if not scan_id:
718
            return Element('scan')
719
720
        target = self.get_scan_host(scan_id)
721
        progress = self.get_scan_progress(scan_id)
722
        status = self.get_scan_status(scan_id)
723
        start_time = self.get_scan_start_time(scan_id)
724
        end_time = self.get_scan_end_time(scan_id)
725
        response = Element('scan')
726
        for name, value in [
727
            ('id', scan_id),
728
            ('target', target),
729
            ('progress', progress),
730
            ('status', status.name.lower()),
731
            ('start_time', start_time),
732
            ('end_time', end_time),
733
        ]:
734
            response.set(name, str(value))
735
        if detailed:
736
            response.append(
737
                self.get_scan_results_xml(scan_id, pop_res, max_res)
738
            )
739
        return response
740
741
    @staticmethod
742
    def get_custom_vt_as_xml_str(  # pylint: disable=unused-argument
743
        vt_id: str, custom: Dict
744
    ) -> str:
745
        """ Create a string representation of the XML object from the
746
        custom data object.
747
        This needs to be implemented by each ospd wrapper, in case
748
        custom elements for VTs are used.
749
750
        The custom XML object which is returned will be embedded
751
        into a <custom></custom> element.
752
753
        @return: XML object as string for custom data.
754
        """
755
        return ''
756
757
    @staticmethod
758
    def get_params_vt_as_xml_str(  # pylint: disable=unused-argument
759
        vt_id: str, vt_params
760
    ) -> str:
761
        """ Create a string representation of the XML object from the
762
        vt_params data object.
763
        This needs to be implemented by each ospd wrapper, in case
764
        vt_params elements for VTs are used.
765
766
        The params XML object which is returned will be embedded
767
        into a <params></params> element.
768
769
        @return: XML object as string for vt parameters data.
770
        """
771
        return ''
772
773
    @staticmethod
774
    def get_refs_vt_as_xml_str(  # pylint: disable=unused-argument
775
        vt_id: str, vt_refs
776
    ) -> str:
777
        """ Create a string representation of the XML object from the
778
        refs data object.
779
        This needs to be implemented by each ospd wrapper, in case
780
        refs elements for VTs are used.
781
782
        The refs XML object which is returned will be embedded
783
        into a <refs></refs> element.
784
785
        @return: XML object as string for vt references data.
786
        """
787
        return ''
788
789
    @staticmethod
790
    def get_dependencies_vt_as_xml_str(  # pylint: disable=unused-argument
791
        vt_id: str, vt_dependencies
792
    ) -> str:
793
        """ Create a string representation of the XML object from the
794
        vt_dependencies data object.
795
        This needs to be implemented by each ospd wrapper, in case
796
        vt_dependencies elements for VTs are used.
797
798
        The vt_dependencies XML object which is returned will be embedded
799
        into a <dependencies></dependencies> element.
800
801
        @return: XML object as string for vt dependencies data.
802
        """
803
        return ''
804
805
    @staticmethod
806
    def get_creation_time_vt_as_xml_str(  # pylint: disable=unused-argument
807
        vt_id: str, vt_creation_time
808
    ) -> str:
809
        """ Create a string representation of the XML object from the
810
        vt_creation_time data object.
811
        This needs to be implemented by each ospd wrapper, in case
812
        vt_creation_time elements for VTs are used.
813
814
        The vt_creation_time XML object which is returned will be embedded
815
        into a <vt_creation_time></vt_creation_time> element.
816
817
        @return: XML object as string for vt creation time data.
818
        """
819
        return ''
820
821
    @staticmethod
822
    def get_modification_time_vt_as_xml_str(  # pylint: disable=unused-argument
823
        vt_id: str, vt_modification_time
824
    ) -> str:
825
        """ Create a string representation of the XML object from the
826
        vt_modification_time data object.
827
        This needs to be implemented by each ospd wrapper, in case
828
        vt_modification_time elements for VTs are used.
829
830
        The vt_modification_time XML object which is returned will be embedded
831
        into a <vt_modification_time></vt_modification_time> element.
832
833
        @return: XML object as string for vt references data.
834
        """
835
        return ''
836
837
    @staticmethod
838
    def get_summary_vt_as_xml_str(  # pylint: disable=unused-argument
839
        vt_id: str, summary
840
    ) -> str:
841
        """ Create a string representation of the XML object from the
842
        summary data object.
843
        This needs to be implemented by each ospd wrapper, in case
844
        summary elements for VTs are used.
845
846
        The summary XML object which is returned will be embedded
847
        into a <summary></summary> element.
848
849
        @return: XML object as string for summary data.
850
        """
851
        return ''
852
853
    @staticmethod
854
    def get_impact_vt_as_xml_str(  # pylint: disable=unused-argument
855
        vt_id: str, impact
856
    ) -> str:
857
        """ Create a string representation of the XML object from the
858
        impact data object.
859
        This needs to be implemented by each ospd wrapper, in case
860
        impact elements for VTs are used.
861
862
        The impact XML object which is returned will be embedded
863
        into a <impact></impact> element.
864
865
        @return: XML object as string for impact data.
866
        """
867
        return ''
868
869
    @staticmethod
870
    def get_affected_vt_as_xml_str(  # pylint: disable=unused-argument
871
        vt_id: str, affected
872
    ) -> str:
873
        """ Create a string representation of the XML object from the
874
        affected data object.
875
        This needs to be implemented by each ospd wrapper, in case
876
        affected elements for VTs are used.
877
878
        The affected XML object which is returned will be embedded
879
        into a <affected></affected> element.
880
881
        @return: XML object as string for affected data.
882
        """
883
        return ''
884
885
    @staticmethod
886
    def get_insight_vt_as_xml_str(  # pylint: disable=unused-argument
887
        vt_id: str, insight
888
    ) -> str:
889
        """ Create a string representation of the XML object from the
890
        insight data object.
891
        This needs to be implemented by each ospd wrapper, in case
892
        insight elements for VTs are used.
893
894
        The insight XML object which is returned will be embedded
895
        into a <insight></insight> element.
896
897
        @return: XML object as string for insight data.
898
        """
899
        return ''
900
901
    @staticmethod
902
    def get_solution_vt_as_xml_str(  # pylint: disable=unused-argument
903
        vt_id: str, solution, solution_type=None, solution_method=None
904
    ) -> str:
905
        """ Create a string representation of the XML object from the
906
        solution data object.
907
        This needs to be implemented by each ospd wrapper, in case
908
        solution elements for VTs are used.
909
910
        The solution XML object which is returned will be embedded
911
        into a <solution></solution> element.
912
913
        @return: XML object as string for solution data.
914
        """
915
        return ''
916
917
    @staticmethod
918
    def get_detection_vt_as_xml_str(  # pylint: disable=unused-argument
919
        vt_id: str, detection=None, qod_type=None, qod=None
920
    ) -> str:
921
        """ Create a string representation of the XML object from the
922
        detection data object.
923
        This needs to be implemented by each ospd wrapper, in case
924
        detection elements for VTs are used.
925
926
        The detection XML object which is returned is an element with
927
        tag <detection></detection> element
928
929
        @return: XML object as string for detection data.
930
        """
931
        return ''
932
933
    @staticmethod
934
    def get_severities_vt_as_xml_str(  # pylint: disable=unused-argument
935
        vt_id: str, severities
936
    ) -> str:
937
        """ Create a string representation of the XML object from the
938
        severities data object.
939
        This needs to be implemented by each ospd wrapper, in case
940
        severities elements for VTs are used.
941
942
        The severities XML objects which are returned will be embedded
943
        into a <severities></severities> element.
944
945
        @return: XML object as string for severities data.
946
        """
947
        return ''
948
949
    def get_vt_iterator(
        self, vt_selection: List[str] = None, details: bool = True
    ) -> Iterator[Tuple[str, Dict]]:
        """ Return the (vt_id, vt) pairs of the VTs dictionary.

        This implementation ignores both vt_selection and details and
        always hands back the items view of self.vts. """
        all_vts = self.vts
        return all_vts.items()
955
956
    def get_vt_xml(self, single_vt: Tuple[str, Dict]) -> Element:
        """ Build the <vt> XML element for a single vulnerability test.

        Arguments:
            single_vt: Pair of the VT ID and the dict holding the VT data.
                A falsy value yields an empty <vt> element.

        Return:
            The <vt> element carrying the 'id' attribute, a <name> child
            and one child for every populated VT data field, in a fixed
            order.
        """
        vt_xml = Element('vt')
        if not single_vt:
            return vt_xml

        vt_id, vt = single_vt
        vt_xml.set('id', vt_id)

        # The name is always emitted; a missing name becomes the text 'None'.
        name_elem = SubElement(vt_xml, 'name')
        name_elem.text = str(vt.get('name'))

        def attach(xml_str: str) -> None:
            """ Parse an XML string and append it as child of <vt>. """
            vt_xml.append(secET.fromstring(xml_str))

        # Each optional field is only rendered when it holds a truthy value.
        # The sequence below defines the order of the children in the output.
        params = vt.get('vt_params')
        if params:
            attach(self.get_params_vt_as_xml_str(vt_id, params))

        refs = vt.get('vt_refs')
        if refs:
            attach(self.get_refs_vt_as_xml_str(vt_id, refs))

        dependencies = vt.get('vt_dependencies')
        if dependencies:
            attach(self.get_dependencies_vt_as_xml_str(vt_id, dependencies))

        creation_time = vt.get('creation_time')
        if creation_time:
            attach(self.get_creation_time_vt_as_xml_str(vt_id, creation_time))

        modification_time = vt.get('modification_time')
        if modification_time:
            attach(
                self.get_modification_time_vt_as_xml_str(
                    vt_id, modification_time
                )
            )

        summary = vt.get('summary')
        if summary:
            attach(self.get_summary_vt_as_xml_str(vt_id, summary))

        impact = vt.get('impact')
        if impact:
            attach(self.get_impact_vt_as_xml_str(vt_id, impact))

        affected = vt.get('affected')
        if affected:
            attach(self.get_affected_vt_as_xml_str(vt_id, affected))

        insight = vt.get('insight')
        if insight:
            attach(self.get_insight_vt_as_xml_str(vt_id, insight))

        solution = vt.get('solution')
        if solution:
            attach(
                self.get_solution_vt_as_xml_str(
                    vt_id,
                    solution,
                    vt.get('solution_type'),
                    vt.get('solution_method'),
                )
            )

        # Detection is rendered as soon as any of its three fields is set.
        detection = vt.get('detection')
        qod_type = vt.get('qod_type')
        qod = vt.get('qod')
        if detection or qod_type or qod:
            attach(
                self.get_detection_vt_as_xml_str(
                    vt_id, detection, qod_type, qod
                )
            )

        severities = vt.get('severities')
        if severities:
            attach(self.get_severities_vt_as_xml_str(vt_id, severities))

        custom = vt.get('custom')
        if custom:
            attach(self.get_custom_vt_as_xml_str(vt_id, custom))

        return vt_xml
1054
1055
    def get_vts_selection_list(
        self, vt_id: str = None, filtered_vts: Dict = None
    ) -> Iterable[str]:
        """
        Get the collection of VT OIDs to work on.

        The selection is resolved in this order:
          * no VTs loaded, or filtered_vts given but empty: empty list.
          * filtered_vts given and not empty: filtered_vts itself, taking
            priority over vt_id.
          * vt_id given: a list holding only vt_id. Whether the VT exists
            is not checked here.
          * otherwise: the keys of all loaded VTs.

        Arguments:
            vt_id (vt_id, optional): ID of the vt to get.
            filtered_vts (list, optional): Filtered VTs collection.

        Return:
            Collection of selected VT's OID.
        """
        if not self.vts:
            return []

        # The filter was applied but nothing matched it.
        if filtered_vts is not None and len(filtered_vts) == 0:
            return []

        if filtered_vts:
            return filtered_vts

        if vt_id:
            return [vt_id]

        return self.vts.keys()
1090
1091
    def handle_command(self, data: bytes, stream: Stream) -> None:
        """ Parse an OSP command and write its response to the stream.

        The command handler is looked up in self.commands by the tag of the
        root XML element. While the daemon is not initialized, commands
        flagged with must_be_initialized are not executed; an error response
        is written to the stream instead.

        Arguments:
            data: Raw XML of the command as received from the client.
            stream: Stream the response is written to.

        Raises:
            OspdCommandError: If data can not be parsed as XML or if no
                handler is registered for the command.
        """
        try:
            tree = secET.fromstring(data)
        except secET.ParseError:
            logger.debug("Erroneous client input: %s", data)
            raise OspdCommandError('Invalid data')

        command_name = tree.tag

        logger.debug('Handling %s command request.', command_name)

        command = self.commands.get(command_name, None)
        if not command:
            if command_name != "authenticate":
                raise OspdCommandError('Bogus command name')

            # "authenticate" is deliberately not reported as a bogus name,
            # but without a registered handler there is nothing to dispatch
            # to. Raise a command error here instead of failing further down
            # with an AttributeError on None.
            raise OspdCommandError(
                'No handler available for the %s command' % command_name
            )

        if not self.initialized and command.must_be_initialized:
            exception = OspdCommandError(
                '%s is still starting' % self.daemon_info['name'], 'error'
            )
            stream.write(exception.as_xml())
            return

        response = command.handle_xml(tree)

        if isinstance(response, bytes):
            stream.write(response)
        else:
            # Any other response is treated as an iterable of chunks which
            # are written one by one.
            for chunk in response:
                stream.write(chunk)
1123
1124
    def check(self):
        """ Placeholder which always raises NotImplementedError.

        Should be implemented by subclass. """
        raise NotImplementedError
1127
1128
    def run(self) -> None:
        """ Run the periodic housekeeping loop until interrupted.

        After each pause of SCHEDULER_CHECK_PERIOD seconds the scheduler
        hook is called, forgotten scans are cleaned up and the scan
        processes are joined. A KeyboardInterrupt ends the loop and is
        logged.
        """
        try:
            while True:
                # Pause first, then run the periodic tasks in fixed order.
                time.sleep(SCHEDULER_CHECK_PERIOD)

                self.scheduler()
                self.clean_forgotten_scans()
                self.wait_for_children()
        except KeyboardInterrupt:
            logger.info("Received Ctrl-C shutting-down ...")
1140
1141
    def scheduler(self):
        """ Hook for periodic tasks, called on every iteration of run().

        Does nothing here. Should be implemented by subclass in case of
        need to run tasks periodically. """
1144
1145
    def wait_for_children(self):
        """ Join every registered scan process to release the resources
        of processes which have already ended.

        A timeout of 0 is passed to join(), so a process which is still
        running is not waited for. """
        for scan_process in self.scan_processes.values():
            scan_process.join(0)
1149
1150
    def create_scan(
        self,
        scan_id: str,
        targets: Dict,
        options: Optional[Dict],
        vt_selection: Dict,
    ) -> Optional[str]:
        """ Creates a new scan or resumes a stopped one.

        @scan_id: ID of the scan to create.
        @targets: Targets to scan.
        @options: Miscellaneous scan options.
        @vt_selection: VTs selected for the scan.

        @return: The value returned by the scan collection for the new
                 scan. None if the scan_id already exists and the scan
                 status is RUNNING or FINISHED.
        """
        already_known = self.scan_exists(scan_id)

        # The status is only looked up for a non-empty ID of a known scan.
        if scan_id and already_known:
            status = self.get_scan_status(scan_id)
        else:
            status = None

        if already_known:
            if status == ScanStatus.STOPPED:
                logger.info("Scan %s exists. Resuming scan.", scan_id)
            elif status in (ScanStatus.RUNNING, ScanStatus.FINISHED):
                logger.info(
                    "Scan %s exists with status %s.",
                    scan_id,
                    status.name.lower(),
                )
                return None

        return self.scan_collection.create_scan(
            scan_id, targets, options, vt_selection
        )
1182
1183
    def get_scan_options(self, scan_id: str) -> str:
        """ Return the options which the scan collection holds for a scan. """
        scan_options = self.scan_collection.get_options(scan_id)
        return scan_options
1186
1187
    def set_scan_option(self, scan_id: str, name: str, value: Any) -> None:
        """ Store value under the option name for a scan.

        The call is forwarded to the scan collection and its result is
        returned unchanged. """
        outcome = self.scan_collection.set_option(scan_id, name, value)
        return outcome
1190
1191
    def clean_forgotten_scans(self) -> None:
        """ Delete stopped or finished scans which nobody has deleted.

        Nothing is done unless scaninfo_store_time is set. Otherwise each
        scan with status STOPPED or FINISHED and a non-zero end time is
        deleted once more than scaninfo_store_time hours have passed since
        that end time."""

        if not self.scaninfo_store_time:
            return

        # Work on a copy of the IDs, scans are deleted inside of the loop.
        for scan_id in list(self.scan_collection.ids_iterator()):
            end_time = int(self.get_scan_end_time(scan_id))
            scan_status = self.get_scan_status(scan_id)

            has_ended = scan_status in (
                ScanStatus.STOPPED,
                ScanStatus.FINISHED,
            )
            if not (has_ended and end_time):
                continue

            stored_time = int(time.time()) - end_time
            if stored_time > self.scaninfo_store_time * 3600:
                logger.debug(
                    'Scan %s is older than %d hours and seems have been '
                    'forgotten. Scan info will be deleted from the '
                    'scan table',
                    scan_id,
                    self.scaninfo_store_time,
                )
                self.delete_scan(scan_id)
1216
1217
    def check_scan_process(self, scan_id: str) -> None:
        """ Check the process of a scan and react if it ended.

        A scan below 100 percent progress whose process is not alive is
        set to STOPPED and gets an error result, unless it is STOPPED
        already. The process of a scan with exactly 100 percent progress
        is joined. """
        scan_process = self.scan_processes[scan_id]
        progress = self.get_scan_progress(scan_id)

        if progress < 100 and not scan_process.is_alive():
            # The process ended before the scan was completed.
            if self.get_scan_status(scan_id) != ScanStatus.STOPPED:
                self.set_scan_status(scan_id, ScanStatus.STOPPED)
                self.add_scan_error(
                    scan_id, name="", host="", value="Scan process failure."
                )

                logger.info("%s: Scan stopped with errors.", scan_id)
            return

        if progress == 100:
            scan_process.join(0)
1233
1234
    def get_scan_progress(self, scan_id: str):
        """ Return the current progress value which the scan collection
        reports for a scan. """
        progress = self.scan_collection.get_progress(scan_id)
        return progress
1237
1238
    def get_scan_host(self, scan_id: str) -> str:
        """ Return the host list which the scan collection holds for a
        scan. """
        host_list = self.scan_collection.get_host_list(scan_id)
        return host_list
1241
1242
    def get_scan_ports(self, scan_id: str) -> str:
        """ Return the ports list which the scan collection holds for a
        scan. """
        ports = self.scan_collection.get_ports(scan_id)
        return ports
1245
1246
    def get_scan_exclude_hosts(self, scan_id: str):
        """ Return the exclude host list which the scan collection holds
        for a scan. """
        excluded = self.scan_collection.get_exclude_hosts(scan_id)
        return excluded
1250
1251
    def get_scan_credentials(self, scan_id: str) -> Dict:
        """ Return the credentials which the scan collection holds for a
        scan. """
        credentials = self.scan_collection.get_credentials(scan_id)
        return credentials
1255
1256
    def get_scan_target_options(self, scan_id: str) -> Dict:
        """ Return the target options which the scan collection holds for
        a scan. """
        target_options = self.scan_collection.get_target_options(scan_id)
        return target_options
1260
1261
    def get_scan_vts(self, scan_id: str) -> Dict:
        """ Return the VTs which the scan collection holds for a scan. """
        scan_vts = self.scan_collection.get_vts(scan_id)
        return scan_vts
1264
1265
    def get_scan_unfinished_hosts(self, scan_id: str) -> List:
        """ Return the unfinished hosts which the scan collection reports
        for a scan."""
        unfinished = self.scan_collection.get_hosts_unfinished(scan_id)
        return unfinished
1268
1269
    def get_scan_finished_hosts(self, scan_id: str) -> List:
        """ Return the finished hosts which the scan collection reports
        for a scan."""
        finished = self.scan_collection.get_hosts_finished(scan_id)
        return finished
1272
1273
    def get_scan_start_time(self, scan_id: str) -> str:
        """ Return the start time which the scan collection holds for a
        scan. """
        start_time = self.scan_collection.get_start_time(scan_id)
        return start_time
1276
1277
    def get_scan_end_time(self, scan_id: str) -> str:
        """ Return the end time which the scan collection holds for a
        scan. """
        end_time = self.scan_collection.get_end_time(scan_id)
        return end_time
1280
1281
    def add_scan_log(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
        test_id: str = '',
        qod: str = '',
    ):
        """ Store a result of type LOG for the scan scan_id.

        All arguments are handed over positionally to the scan collection.
        The severity of a log result is fixed to '0.0'. """
        log_severity = '0.0'
        result_fields = (
            host,
            hostname,
            name,
            value,
            port,
            test_id,
            log_severity,
            qod,
        )
        self.scan_collection.add_result(
            scan_id, ResultType.LOG, *result_fields
        )
1305
1306
    def add_scan_error(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
        test_id='',
    ) -> None:
        """ Store a result of type ERROR for the scan scan_id.

        All arguments are handed over positionally to the scan collection.
        Neither a severity nor a qod is passed for an error result. """
        result_fields = (host, hostname, name, value, port, test_id)
        self.scan_collection.add_result(
            scan_id, ResultType.ERROR, *result_fields
        )
1327
1328
    def add_scan_host_detail(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
    ) -> None:
        """ Store a result of type HOST_DETAIL for the scan scan_id.

        All arguments are handed over positionally to the scan
        collection. """
        result_fields = (host, hostname, name, value)
        self.scan_collection.add_result(
            scan_id, ResultType.HOST_DETAIL, *result_fields
        )
1340
1341
    def add_scan_alarm(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
        test_id: str = '',
        severity: str = '',
        qod: str = '',
    ):
        """ Store a result of type ALARM for the scan scan_id.

        All arguments, including severity and qod, are handed over
        positionally to the scan collection. """
        result_fields = (
            host,
            hostname,
            name,
            value,
            port,
            test_id,
            severity,
            qod,
        )
        self.scan_collection.add_result(
            scan_id, ResultType.ALARM, *result_fields
        )
1366