Passed
Pull Request — master (#266)
by Juan José
01:20
created

ospd.ospd.OSPDaemon.get_scan_vts()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
# Copyright (C) 2014-2020 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: AGPL-3.0-or-later
4
#
5
# This program is free software: you can redistribute it and/or modify
6
# it under the terms of the GNU Affero General Public License as
7
# published by the Free Software Foundation, either version 3 of the
8
# License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU Affero General Public License for more details.
14
#
15
# You should have received a copy of the GNU Affero General Public License
16
# along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
# pylint: disable=too-many-lines
19
20
""" OSP Daemon core class.
21
"""
22
23
import logging
24
import socket
25
import ssl
26
import multiprocessing
27
import time
28
import os
29
30
from typing import (
31
    List,
32
    Any,
33
    Iterator,
34
    Dict,
35
    Optional,
36
    Iterable,
37
    Tuple,
38
    Union,
39
)
40
from xml.etree.ElementTree import Element, SubElement
41
42
import defusedxml.ElementTree as secET
43
44
from deprecated import deprecated
45
46
from ospd import __version__
47
from ospd.command import get_commands
48
from ospd.errors import OspdCommandError
49
from ospd.misc import ResultType
50
from ospd.network import resolve_hostname, target_str_to_list
51
from ospd.protocol import OspRequest, OspResponse, RequestParser
52
from ospd.scan import ScanCollection, ScanStatus
53
from ospd.server import BaseServer, Stream
54
from ospd.vtfilter import VtsFilter
55
from ospd.vts import Vts
56
from ospd.xml import (
57
    elements_as_text,
58
    get_result_xml,
59
    get_progress_xml,
60
    get_elements_from_dict,
61
)
62
63
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# OSP protocol version; OSPDaemon.__init__ copies it to self.protocol_version.
PROTOCOL_VERSION = "1.2"

SCHEDULER_CHECK_PERIOD = 10  # in seconds

# Scanner parameters that OSPDaemon.__init__ registers on every daemon through
# set_scanner_param(). Each key is a parameter name; each value is the record
# describing it ('type', 'name', 'default', 'mandatory', 'description').
BASE_SCANNER_PARAMS = {
    'debug_mode': {
        'type': 'boolean',
        'name': 'Debug Mode',
        'default': 0,
        'mandatory': 0,
        'description': 'Whether to get extra scan debug information.',
    },
    'dry_run': {
        'type': 'boolean',
        'name': 'Dry Run',
        'default': 0,
        'mandatory': 0,
        'description': 'Whether to dry run scan.',
    },
}  # type: Dict

# Progress value written by finish_scan(); sort_host_finished() counts a host
# with this progress as alive.
PROGRESS_FINISHED = 100
# Progress value that sort_host_finished() counts as a dead host.
PROGRESS_DEAD_HOST = -1
88
89
90
def _terminate_process_group(process: multiprocessing.Process) -> None:
    """ Send signal 15 to the process group that `process` belongs to. """
    group_id = os.getpgid(process.pid)
    os.killpg(group_id, 15)
92
93
94
class OSPDaemon:
95
96
    """ Daemon class for OSP traffic handling.
97
98
    Every scanner wrapper should subclass it and make necessary additions and
99
    changes.
100
101
    * Add any needed parameters in __init__.
102
    * Implement check() method which verifies scanner availability and other
103
      environment related conditions.
104
    * Implement process_scan_params and exec_scan methods which are
105
      specific to handling the <start_scan> command, executing the wrapped
106
      scanner and storing the results.
107
    * Implement other methods that assert to False such as get_scanner_name,
108
      get_scanner_version.
109
    * Call set_command_attributes at init time to add scanner command
110
      specific options eg. the w3af profile for w3af wrapper.
111
    """
112
113
    def __init__(
        self,
        *,
        customvtfilter=None,
        storage=None,
        max_scans=0,
        check_free_memory=False,
        **kwargs
    ):  # pylint: disable=unused-argument
        """ Set up the daemon's internal state.

        Arguments:
            customvtfilter: VT filter object to use; when falsy a default
                VtsFilter is created.
            storage: Passed through to the Vts collection constructor.
            max_scans: Stored as self.max_scans.
            check_free_memory: Stored as self.check_free_memory.
            kwargs: Only 'scaninfo_store_time' is read from it here.
        """
        self.scan_collection = ScanCollection()
        self.scan_processes = {}

        self.daemon_info = {
            'name': "OSPd",
            'version': __version__,
            'description': "No description",
        }

        self.scanner_info = {
            'name': 'No name',
            'version': 'No version',
            'description': 'No description',
        }

        # Set by the subclass.
        self.server_version = None

        # Set after initialization finished
        self.initialized = None

        self.max_scans = max_scans
        self.check_free_memory = check_free_memory
        self.scaninfo_store_time = kwargs.get('scaninfo_store_time')
        self.protocol_version = PROTOCOL_VERSION

        # Register one instance of every known command under its name.
        self.commands = {}
        for command_cls in get_commands():
            instance = command_cls(self)
            self.commands[instance.get_name()] = instance

        # Every daemon starts out with the base scanner parameters.
        self.scanner_params = {}
        for param_name, param_info in BASE_SCANNER_PARAMS.items():
            self.set_scanner_param(param_name, param_info)

        self.vts = Vts(storage)
        self.vts_version = None

        self.vts_filter = customvtfilter if customvtfilter else VtsFilter()
165
166
    def init(self, server: BaseServer) -> None:
167
        """ Should be overridden by a subclass if the initialization is costly.
168
169
            Will be called after check.
170
        """
171
        server.start(self.handle_client_stream)
172
        self.initialized = True
173
174
    def set_command_attributes(self, name: str, attributes: Dict) -> None:
175
        """ Sets the xml attributes of a specified command. """
176
        if self.command_exists(name):
177
            command = self.commands.get(name)
178
            command.attributes = attributes
179
180
    @deprecated(version="20.8", reason="Use set_scanner_param instead")
    def add_scanner_param(self, name: str, scanner_params: Dict) -> None:
        """ Deprecated alias that forwards to set_scanner_param(). """
        self.set_scanner_param(name, scanner_params)
184
185
    def set_scanner_param(self, name: str, scanner_params: Dict) -> None:
186
        """ Set a scanner parameter. """
187
188
        assert name
189
        assert scanner_params
190
191
        self.scanner_params[name] = scanner_params
192
193
    def get_scanner_params(self) -> Dict:
194
        return self.scanner_params
195
196
    def add_vt(
197
        self,
198
        vt_id: str,
199
        name: str = None,
200
        vt_params: str = None,
201
        vt_refs: str = None,
202
        custom: str = None,
203
        vt_creation_time: str = None,
204
        vt_modification_time: str = None,
205
        vt_dependencies: str = None,
206
        summary: str = None,
207
        impact: str = None,
208
        affected: str = None,
209
        insight: str = None,
210
        solution: str = None,
211
        solution_t: str = None,
212
        solution_m: str = None,
213
        detection: str = None,
214
        qod_t: str = None,
215
        qod_v: str = None,
216
        severities: str = None,
217
    ) -> None:
218
        """ Add a vulnerability test information.
219
220
        IMPORTANT: The VT's Data Manager will store the vts collection.
221
        If the collection is considerably big and it will be consultated
222
        intensible during a routine, consider to do a deepcopy(), since
223
        accessing the shared memory in the data manager is very expensive.
224
        At the end of the routine, the temporal copy must be set to None
225
        and deleted.
226
        """
227
        self.vts.add(
228
            vt_id,
229
            name=name,
230
            vt_params=vt_params,
231
            vt_refs=vt_refs,
232
            custom=custom,
233
            vt_creation_time=vt_creation_time,
234
            vt_modification_time=vt_modification_time,
235
            vt_dependencies=vt_dependencies,
236
            summary=summary,
237
            impact=impact,
238
            affected=affected,
239
            insight=insight,
240
            solution=solution,
241
            solution_t=solution_t,
242
            solution_m=solution_m,
243
            detection=detection,
244
            qod_t=qod_t,
245
            qod_v=qod_v,
246
            severities=severities,
247
        )
248
249
    def set_vts_version(self, vts_version: str) -> None:
250
        """ Add into the vts dictionary an entry to identify the
251
        vts version.
252
253
        Parameters:
254
            vts_version (str): Identifies a unique vts version.
255
        """
256
        if not vts_version:
257
            raise OspdCommandError(
258
                'A vts_version parameter is required', 'set_vts_version'
259
            )
260
        self.vts_version = vts_version
261
262
    def get_vts_version(self) -> Optional[str]:
263
        """Return the vts version.
264
        """
265
        return self.vts_version
266
267
    def command_exists(self, name: str) -> bool:
268
        """ Checks if a commands exists. """
269
        return name in self.commands
270
271
    def get_scanner_name(self) -> str:
272
        """ Gives the wrapped scanner's name. """
273
        return self.scanner_info['name']
274
275
    def get_scanner_version(self) -> str:
276
        """ Gives the wrapped scanner's version. """
277
        return self.scanner_info['version']
278
279
    def get_scanner_description(self) -> str:
280
        """ Gives the wrapped scanner's description. """
281
        return self.scanner_info['description']
282
283
    def get_server_version(self) -> str:
284
        """ Gives the specific OSP server's version. """
285
        assert self.server_version
286
        return self.server_version
287
288
    def get_protocol_version(self) -> str:
289
        """ Gives the OSP's version. """
290
        return self.protocol_version
291
292
    def preprocess_scan_params(self, xml_params):
293
        """ Processes the scan parameters. """
294
        params = {}
295
296
        for param in xml_params:
297
            params[param.tag] = param.text or ''
298
299
        # Set default values.
300
        for key in self.scanner_params:
301
            if key not in params:
302
                params[key] = self.get_scanner_param_default(key)
303
                if self.get_scanner_param_type(key) == 'selection':
304
                    params[key] = params[key].split('|')[0]
305
306
        # Validate values.
307
        for key in params:
308
            param_type = self.get_scanner_param_type(key)
309
            if not param_type:
310
                continue
311
312
            if param_type in ['integer', 'boolean']:
313
                try:
314
                    params[key] = int(params[key])
315
                except ValueError:
316
                    raise OspdCommandError(
317
                        'Invalid %s value' % key, 'start_scan'
318
                    )
319
320
            if param_type == 'boolean':
321
                if params[key] not in [0, 1]:
322
                    raise OspdCommandError(
323
                        'Invalid %s value' % key, 'start_scan'
324
                    )
325
            elif param_type == 'selection':
326
                selection = self.get_scanner_param_default(key).split('|')
327
                if params[key] not in selection:
328
                    raise OspdCommandError(
329
                        'Invalid %s value' % key, 'start_scan'
330
                    )
331
            if self.get_scanner_param_mandatory(key) and params[key] == '':
332
                raise OspdCommandError(
333
                    'Mandatory %s value is missing' % key, 'start_scan'
334
                )
335
336
        return params
337
338
    def process_scan_params(self, params: Dict) -> Dict:
339
        """ This method is to be overridden by the child classes if necessary
340
        """
341
        return params
342
343
    @staticmethod
    @deprecated(
        version="20.8",
        reason="Please use OspRequest.process_vts_params instead.",
    )
    def process_vts_params(scanner_vts) -> Dict:
        """ Deprecated wrapper around OspRequest.process_vts_params(). """
        return OspRequest.process_vts_params(scanner_vts)
350
351
    @staticmethod
    @deprecated(
        version="20.8",
        reason="Please use OspRequest.process_credentials_elements instead.",
    )
    def process_credentials_elements(cred_tree) -> Dict:
        """ Deprecated wrapper around
        OspRequest.process_credentials_elements(). """
        return OspRequest.process_credentials_elements(cred_tree)
358
359
    @staticmethod
    @deprecated(
        version="20.8",
        reason="Please use OspRequest.process_target_element instead.",
    )
    def process_targets_element(scanner_target) -> List:
        """ Deprecated wrapper around OspRequest.process_target_element(). """
        return OspRequest.process_target_element(scanner_target)
366
367
    def stop_scan(self, scan_id: str) -> None:
        """ Stop the running scan `scan_id` and terminate its process.

        Raises:
            OspdCommandError: If no process is registered for the scan or
                the process is no longer alive.
        """
        process = self.scan_processes.get(scan_id)
        if not process:
            raise OspdCommandError(
                'Scan not found {0}.'.format(scan_id), 'stop_scan'
            )
        if not process.is_alive():
            raise OspdCommandError(
                'Scan already stopped or finished.', 'stop_scan'
            )

        # The status is switched before the process is touched.
        self.set_scan_status(scan_id, ScanStatus.STOPPED)
        logger.info('%s: Scan stopping %s.', scan_id, process.ident)

        # Give the subclass a chance to clean up before terminating.
        self.stop_scan_cleanup(scan_id)

        try:
            process.terminate()
        except AttributeError:
            logger.debug('%s: The scanner task stopped unexpectedly.', scan_id)

        # Also signal the whole process group of the scan process.
        try:
            _terminate_process_group(process)
        except ProcessLookupError:
            logger.info('%s: Scan already stopped %s.', scan_id, process.pid)

        # Never join the current process itself.
        if process.ident != os.getpid():
            process.join(0)

        logger.info('%s: Scan stopped.', scan_id)
400
401
    @staticmethod
402
    def stop_scan_cleanup(scan_id: str):
403
        """ Should be implemented by subclass in case of a clean up before
404
        terminating is needed. """
405
406
    @staticmethod
407
    def target_is_finished(scan_id: str):
408
        """ Should be implemented by subclass in case of a check before
409
        stopping is needed. """
410
411
    def exec_scan(self, scan_id: str):
412
        """ Asserts to False. Should be implemented by subclass. """
413
        raise NotImplementedError
414
415
    def finish_scan(self, scan_id: str) -> None:
        """ Mark the scan as finished: progress is set to PROGRESS_FINISHED
        and the status to ScanStatus.FINISHED. """
        collection = self.scan_collection
        collection.set_progress(scan_id, PROGRESS_FINISHED)
        self.set_scan_status(scan_id, ScanStatus.FINISHED)
        logger.info("%s: Scan finished.", scan_id)
420
421
    def get_daemon_name(self) -> str:
422
        """ Gives osp daemon's name. """
423
        return self.daemon_info['name']
424
425
    def get_daemon_version(self) -> str:
426
        """ Gives osp daemon's version. """
427
        return self.daemon_info['version']
428
429
    def get_scanner_param_type(self, param: str):
430
        """ Returns type of a scanner parameter. """
431
        assert isinstance(param, str)
432
        entry = self.scanner_params.get(param)
433
        if not entry:
434
            return None
435
        return entry.get('type')
436
437
    def get_scanner_param_mandatory(self, param: str):
438
        """ Returns if a scanner parameter is mandatory. """
439
        assert isinstance(param, str)
440
        entry = self.scanner_params.get(param)
441
        if not entry:
442
            return False
443
        return entry.get('mandatory')
444
445
    def get_scanner_param_default(self, param: str):
446
        """ Returns default value of a scanner parameter. """
447
        assert isinstance(param, str)
448
        entry = self.scanner_params.get(param)
449
        if not entry:
450
            return None
451
        return entry.get('default')
452
453
    @deprecated(
        version="20.8",
        reason="Please use OspResponse.create_scanner_params_xml instead.",
    )
    def get_scanner_params_xml(self):
        """ Deprecated: build the xml for the daemon's scanner params by
        delegating to OspResponse.create_scanner_params_xml(). """
        scanner_params = self.scanner_params
        return OspResponse.create_scanner_params_xml(scanner_params)
460
461
    def handle_client_stream(self, stream: Stream) -> None:
        """ Read one request from the client stream and handle it.

        Data is read until the stream yields nothing, the request parser
        reports the end of the request, or an SSL error / timeout occurs.
        An error response is written only if handling the command failed.
        The stream is closed afterwards.
        """
        received = b''
        parser = RequestParser()

        while True:
            try:
                chunk = stream.read()
                if not chunk:
                    break

                received += chunk

                if parser.has_ended(chunk):
                    break
            except (AttributeError, ValueError) as message:
                # Returns right away: nothing is handled, stream not closed.
                logger.error(message)
                return
            except ssl.SSLError as exception:
                logger.debug('Error: %s', exception)
                break
            except socket.timeout as exception:
                logger.debug('Request timeout: %s', exception)
                break

        if not received:
            logger.debug("Empty client stream")
            return

        error_response = None
        try:
            self.handle_command(received, stream)
        except OspdCommandError as exception:
            error_response = exception.as_xml()
            logger.debug('Command error: %s', exception.message)
        except Exception:  # pylint: disable=broad-except
            logger.exception('While handling client command:')
            error_response = OspdCommandError('Fatal error', 'error').as_xml()

        if error_response:
            stream.write(error_response)

        stream.close()
506
507
    def calculate_progress(self, scan_id: str) -> float:
508
        """ Calculate the total scan progress. """
509
510
        return self.scan_collection.calculate_target_progress(scan_id)
511
512
    def process_finished_hosts(self, scan_id: str, finished_hosts: str) -> None:
513
        """ Process the finished hosts before launching the scans."""
514
515
        if not finished_hosts:
516
            return
517
518
        exc_finished_hosts_list = target_str_to_list(finished_hosts)
519
        self.scan_collection.set_host_finished(scan_id, exc_finished_hosts_list)
520
521
    def start_scan(self, scan_id: str, target: Dict) -> None:
        """ Run the scan `scan_id` in the current process.

        The process first becomes a session leader via os.setsid(). Any
        exception raised while executing the scan is stored as a scan
        error and logged. Unless the scan was stopped in the meantime it
        is marked as finished at the end.

        Raises:
            OspdCommandError: If `target` is None or empty.
        """
        os.setsid()

        if not target:
            raise OspdCommandError('Erroneous target', 'start_scan')

        logger.info("%s: Scan started.", scan_id)

        already_finished = target.get('finished_hosts')
        self.process_finished_hosts(scan_id, already_finished)

        try:
            self.set_scan_status(scan_id, ScanStatus.RUNNING)
            self.exec_scan(scan_id)
        except Exception as e:  # pylint: disable=broad-except
            self.add_scan_error(
                scan_id,
                name='',
                host=self.get_scan_host(scan_id),
                value='Host process failure (%s).' % e,
            )
            logger.exception('While scanning: %s', scan_id)
        else:
            logger.info("%s: Host scan finished.", scan_id)

        # A stopped scan must keep its STOPPED status.
        if self.get_scan_status(scan_id) != ScanStatus.STOPPED:
            self.finish_scan(scan_id)
548
549
    def dry_run_scan(self, scan_id: str, target: Dict) -> None:
        """ Dry run a scan: no scanner is executed, a single log result is
        added and the scan is marked as finished.

        The process first becomes a session leader via os.setsid(). A
        failed hostname resolution is only logged.
        """
        os.setsid()

        # NOTE(review): target is annotated as Dict but indexed with 0.
        resolved_host = resolve_hostname(target[0])
        if resolved_host is None:
            logger.info("Couldn't resolve %s.", self.get_scan_host(scan_id))

        ports = self.get_scan_ports(scan_id)
        logger.info("%s:%s: Dry run mode.", resolved_host, ports)

        self.add_scan_log(
            scan_id, name='', host=resolved_host, value='Dry run result'
        )
        self.finish_scan(scan_id)
565
566
    def handle_timeout(self, scan_id: str, host: str) -> None:
567
        """ Handles scanner reaching timeout error. """
568
        self.add_scan_error(
569
            scan_id,
570
            host=host,
571
            name="Timeout",
572
            value="{0} exec timeout.".format(self.get_scanner_name()),
573
        )
574
575
    def sort_host_finished(
        self, scan_id: str, finished_hosts: Union[List[str], str],
    ) -> None:
        """ Sort the finished hosts into alive and dead ones, report both
        groups to the scan collection and drop the hosts from the current
        target progress.

        A single host may be given as a plain string. A host counts as
        alive when its current progress is PROGRESS_FINISHED and as dead
        when it is PROGRESS_DEAD_HOST; hosts with any other progress are
        reported in neither group.
        """
        if isinstance(finished_hosts, str):
            finished_hosts = [finished_hosts]

        collection = self.scan_collection
        progress_by_host = collection.get_current_target_progress(scan_id)

        alive, dead = [], []
        for host in finished_hosts:
            host_progress = progress_by_host.get(host)
            if host_progress == PROGRESS_FINISHED:
                alive.append(host)
            elif host_progress == PROGRESS_DEAD_HOST:
                dead.append(host)

        collection.set_host_dead(scan_id, dead)
        collection.set_host_finished(scan_id, alive)
        collection.remove_hosts_from_target_progress(scan_id, finished_hosts)
603
604
    def set_scan_progress_batch(
605
        self, scan_id: str, host_progress: Dict[str, int]
606
    ):
607
        self.scan_collection.set_host_progress(scan_id, host_progress)
608
609
        scan_progress = self.calculate_progress(scan_id)
610
        self.scan_collection.set_progress(scan_id, scan_progress)
611
612
    def set_scan_host_progress(
613
        self, scan_id: str, host: str = None, progress: int = None,
614
    ) -> None:
615
        """ Sets host's progress which is part of target.
616
        Each time a host progress is updated, the scan progress
617
        is updated too.
618
        """
619
        host_progress = {host: progress}
620
        self.set_scan_progress_batch(scan_id, host_progress)
621
622
    def set_scan_status(self, scan_id: str, status: ScanStatus) -> None:
623
        """ Set the scan's status."""
624
        self.scan_collection.set_status(scan_id, status)
625
626
    def get_scan_status(self, scan_id: str) -> ScanStatus:
627
        """ Get scan_id scans's status."""
628
        return self.scan_collection.get_status(scan_id)
629
630
    def scan_exists(self, scan_id: str) -> bool:
631
        """ Checks if a scan with ID scan_id is in collection.
632
633
        @return: 1 if scan exists, 0 otherwise.
634
        """
635
        return self.scan_collection.id_exists(scan_id)
636
637
    def get_help_text(self) -> str:
638
        """ Returns the help output in plain text format."""
639
640
        txt = ''
641
        for name, info in self.commands.items():
642
            description = info.get_description()
643
            attributes = info.get_attributes()
644
            elements = info.get_elements()
645
646
            command_txt = "\t{0: <22} {1}\n".format(name, description)
647
648
            if attributes:
649
                command_txt = ''.join([command_txt, "\t Attributes:\n"])
650
651
                for attrname, attrdesc in attributes.items():
652
                    attr_txt = "\t  {0: <22} {1}\n".format(attrname, attrdesc)
653
                    command_txt = ''.join([command_txt, attr_txt])
654
655
            if elements:
656
                command_txt = ''.join(
657
                    [command_txt, "\t Elements:\n", elements_as_text(elements),]
658
                )
659
660
            txt += command_txt
661
662
        return txt
663
664
    @deprecated(version="20.8", reason="Use ospd.xml.elements_as_text instead.")
    def elements_as_text(self, elems: Dict, indent: int = 2) -> str:
        """ Deprecated: format the elems dictionary as plain text by
        delegating to the module-level elements_as_text(). """
        return elements_as_text(elems, indent)
668
669
    def delete_scan(self, scan_id: str) -> int:
        """ Delete the scan scan_id from the collection.

        A running scan is never deleted. For any other scan the method
        first waits for the scan process (if one is registered) to end and
        forgets it once it has an exit code.

        @return: 0 if the scan is still running, otherwise the result of
                 the collection's delete_scan().
        """
        if self.get_scan_status(scan_id) == ScanStatus.RUNNING:
            return 0

        # Don't delete the scan until the process stops
        try:
            process = self.scan_processes[scan_id]
        except KeyError:
            logger.debug('Scan process for %s not found', scan_id)
        else:
            process.join()
            if process.exitcode is not None:
                del self.scan_processes[scan_id]

        return self.scan_collection.delete_scan(scan_id)
689
690
    def get_scan_results_xml(
        self, scan_id: str, pop_res: bool, max_res: Optional[int]
    ):
        """ Collect the results of the scan scan_id below a <results>
        element.

        @return: 'results' Element with one child per result delivered by
                 the collection's results_iterator().
        """
        results_elem = Element('results')

        result_iter = self.scan_collection.results_iterator(
            scan_id, pop_res, max_res
        )
        for entry in result_iter:
            results_elem.append(get_result_xml(entry))

        logger.debug('Returning %d results', len(results_elem))
        return results_elem
705
706
    def _get_scan_progress_xml(self, scan_id: str):
        """ Gather the progress figures of the scan scan_id and convert
        them with get_progress_xml().

        @return: Whatever get_progress_xml() builds from the figures.
        """
        collection = self.scan_collection

        # Entries are evaluated top to bottom, same order as the lookups.
        progress_info = {
            'current_hosts': collection.get_current_target_progress(scan_id),
            'overall': collection.get_progress(scan_id),
            'count_alive': collection.get_count_alive(scan_id),
            'count_dead': collection.get_count_dead(scan_id),
            'count_excluded': collection.simplify_exclude_host_count(scan_id),
            'count_total': collection.get_host_count(scan_id),
        }

        return get_progress_xml(progress_info)
731
732
    @deprecated(
        version="20.8",
        reason="Please use ospd.xml.get_elements_from_dict instead.",
    )
    def get_xml_str(self, data: Dict) -> List:
        """ Deprecated: convert the provided data structure to xml by
        delegating to get_elements_from_dict().

        @param: Dictionary of xml tags and their elements.

        @return: Result of get_elements_from_dict() for the data.
        """
        return get_elements_from_dict(data)
744
745
    def get_scan_xml(
746
        self,
747
        scan_id: str,
748
        detailed: bool = True,
749
        pop_res: bool = False,
750
        max_res: int = 0,
751
        progress: bool = False,
752
    ):
753
        """ Gets scan in XML format.
754
755
        @return: String of scan in XML format.
756
        """
757
        if not scan_id:
758
            return Element('scan')
759
760
        target = self.get_scan_host(scan_id)
761
        progress = self.get_scan_progress(scan_id)
762
        status = self.get_scan_status(scan_id)
763
        start_time = self.get_scan_start_time(scan_id)
764
        end_time = self.get_scan_end_time(scan_id)
765
        response = Element('scan')
766
        for name, value in [
767
            ('id', scan_id),
768
            ('target', target),
769
            ('progress', progress),
770
            ('status', status.name.lower()),
771
            ('start_time', start_time),
772
            ('end_time', end_time),
773
        ]:
774
            response.set(name, str(value))
775
        if detailed:
776
            response.append(
777
                self.get_scan_results_xml(scan_id, pop_res, max_res)
778
            )
779
        if progress:
780
            response.append(self._get_scan_progress_xml(scan_id))
781
782
        return response
783
784
    @staticmethod
785
    def get_custom_vt_as_xml_str(  # pylint: disable=unused-argument
786
        vt_id: str, custom: Dict
787
    ) -> str:
788
        """ Create a string representation of the XML object from the
789
        custom data object.
790
        This needs to be implemented by each ospd wrapper, in case
791
        custom elements for VTs are used.
792
793
        The custom XML object which is returned will be embedded
794
        into a <custom></custom> element.
795
796
        @return: XML object as string for custom data.
797
        """
798
        return ''
799
800
    @staticmethod
801
    def get_params_vt_as_xml_str(  # pylint: disable=unused-argument
802
        vt_id: str, vt_params
803
    ) -> str:
804
        """ Create a string representation of the XML object from the
805
        vt_params data object.
806
        This needs to be implemented by each ospd wrapper, in case
807
        vt_params elements for VTs are used.
808
809
        The params XML object which is returned will be embedded
810
        into a <params></params> element.
811
812
        @return: XML object as string for vt parameters data.
813
        """
814
        return ''
815
816
    @staticmethod
817
    def get_refs_vt_as_xml_str(  # pylint: disable=unused-argument
818
        vt_id: str, vt_refs
819
    ) -> str:
820
        """ Create a string representation of the XML object from the
821
        refs data object.
822
        This needs to be implemented by each ospd wrapper, in case
823
        refs elements for VTs are used.
824
825
        The refs XML object which is returned will be embedded
826
        into a <refs></refs> element.
827
828
        @return: XML object as string for vt references data.
829
        """
830
        return ''
831
832
    @staticmethod
833
    def get_dependencies_vt_as_xml_str(  # pylint: disable=unused-argument
834
        vt_id: str, vt_dependencies
835
    ) -> str:
836
        """ Create a string representation of the XML object from the
837
        vt_dependencies data object.
838
        This needs to be implemented by each ospd wrapper, in case
839
        vt_dependencies elements for VTs are used.
840
841
        The vt_dependencies XML object which is returned will be embedded
842
        into a <dependencies></dependencies> element.
843
844
        @return: XML object as string for vt dependencies data.
845
        """
846
        return ''
847
848
    @staticmethod
849
    def get_creation_time_vt_as_xml_str(  # pylint: disable=unused-argument
850
        vt_id: str, vt_creation_time
851
    ) -> str:
852
        """ Create a string representation of the XML object from the
853
        vt_creation_time data object.
854
        This needs to be implemented by each ospd wrapper, in case
855
        vt_creation_time elements for VTs are used.
856
857
        The vt_creation_time XML object which is returned will be embedded
858
        into a <vt_creation_time></vt_creation_time> element.
859
860
        @return: XML object as string for vt creation time data.
861
        """
862
        return ''
863
864
    @staticmethod
865
    def get_modification_time_vt_as_xml_str(  # pylint: disable=unused-argument
866
        vt_id: str, vt_modification_time
867
    ) -> str:
868
        """ Create a string representation of the XML object from the
869
        vt_modification_time data object.
870
        This needs to be implemented by each ospd wrapper, in case
871
        vt_modification_time elements for VTs are used.
872
873
        The vt_modification_time XML object which is returned will be embedded
874
        into a <vt_modification_time></vt_modification_time> element.
875
876
        @return: XML object as string for vt references data.
877
        """
878
        return ''
879
880
    @staticmethod
881
    def get_summary_vt_as_xml_str(  # pylint: disable=unused-argument
882
        vt_id: str, summary
883
    ) -> str:
884
        """ Create a string representation of the XML object from the
885
        summary data object.
886
        This needs to be implemented by each ospd wrapper, in case
887
        summary elements for VTs are used.
888
889
        The summary XML object which is returned will be embedded
890
        into a <summary></summary> element.
891
892
        @return: XML object as string for summary data.
893
        """
894
        return ''
895
896
    @staticmethod
897
    def get_impact_vt_as_xml_str(  # pylint: disable=unused-argument
898
        vt_id: str, impact
899
    ) -> str:
900
        """ Create a string representation of the XML object from the
901
        impact data object.
902
        This needs to be implemented by each ospd wrapper, in case
903
        impact elements for VTs are used.
904
905
        The impact XML object which is returned will be embedded
906
        into a <impact></impact> element.
907
908
        @return: XML object as string for impact data.
909
        """
910
        return ''
911
912
    @staticmethod
913
    def get_affected_vt_as_xml_str(  # pylint: disable=unused-argument
914
        vt_id: str, affected
915
    ) -> str:
916
        """ Create a string representation of the XML object from the
917
        affected data object.
918
        This needs to be implemented by each ospd wrapper, in case
919
        affected elements for VTs are used.
920
921
        The affected XML object which is returned will be embedded
922
        into a <affected></affected> element.
923
924
        @return: XML object as string for affected data.
925
        """
926
        return ''
927
928
    @staticmethod
929
    def get_insight_vt_as_xml_str(  # pylint: disable=unused-argument
930
        vt_id: str, insight
931
    ) -> str:
932
        """ Create a string representation of the XML object from the
933
        insight data object.
934
        This needs to be implemented by each ospd wrapper, in case
935
        insight elements for VTs are used.
936
937
        The insight XML object which is returned will be embedded
938
        into a <insight></insight> element.
939
940
        @return: XML object as string for insight data.
941
        """
942
        return ''
943
944
    @staticmethod
945
    def get_solution_vt_as_xml_str(  # pylint: disable=unused-argument
946
        vt_id: str, solution, solution_type=None, solution_method=None
947
    ) -> str:
948
        """ Create a string representation of the XML object from the
949
        solution data object.
950
        This needs to be implemented by each ospd wrapper, in case
951
        solution elements for VTs are used.
952
953
        The solution XML object which is returned will be embedded
954
        into a <solution></solution> element.
955
956
        @return: XML object as string for solution data.
957
        """
958
        return ''
959
960
    @staticmethod
961
    def get_detection_vt_as_xml_str(  # pylint: disable=unused-argument
962
        vt_id: str, detection=None, qod_type=None, qod=None
963
    ) -> str:
964
        """ Create a string representation of the XML object from the
965
        detection data object.
966
        This needs to be implemented by each ospd wrapper, in case
967
        detection elements for VTs are used.
968
969
        The detection XML object which is returned is an element with
970
        tag <detection></detection> element
971
972
        @return: XML object as string for detection data.
973
        """
974
        return ''
975
976
    @staticmethod
977
    def get_severities_vt_as_xml_str(  # pylint: disable=unused-argument
978
        vt_id: str, severities
979
    ) -> str:
980
        """ Create a string representation of the XML object from the
981
        severities data object.
982
        This needs to be implemented by each ospd wrapper, in case
983
        severities elements for VTs are used.
984
985
        The severities XML objects which are returned will be embedded
986
        into a <severities></severities> element.
987
988
        @return: XML object as string for severities data.
989
        """
990
        return ''
991
992
    def get_vt_iterator(  # pylint: disable=unused-argument
        self, vt_selection: List[str] = None, details: bool = True
    ) -> Iterator[Tuple[str, Dict]]:
        """ Return the (vt_id, vt) pairs of the VTs collection.

        This base version ignores both parameters and hands back
        self.vts.items() unchanged.

        Arguments:
            vt_selection (list, optional): Unused here.
            details (bool, optional): Unused here.

        Return:
            The result of self.vts.items().
        """
        vt_pairs = self.vts.items()
        return vt_pairs
998
999
    def get_vt_xml(self, single_vt: Tuple[str, Dict]) -> Element:
        """ Build the <vt> XML element for a single vulnerability test.

        The element gets the VT ID as its 'id' attribute and a <name>
        child. For every optional entry that is set in the VT, the
        matching get_*_vt_as_xml_str() method is called, its string result
        is parsed and the parsed element is appended as a further child.
        The children keep a fixed order: params, refs, dependencies,
        creation time, modification time, summary, impact, affected,
        insight, solution, detection, severities, custom.

        Arguments:
            single_vt (tuple): Pair of VT ID and VT dictionary.

        Return:
            The <vt> element. It is empty (no attributes, no children)
            when single_vt is empty or its dictionary is None.
        """
        if not single_vt or single_vt[1] is None:
            return Element('vt')

        vt_id, vt = single_vt

        vt_xml = Element('vt')
        vt_xml.set('id', vt_id)
        SubElement(vt_xml, 'name').text = str(vt.get('name'))

        def append_section(serializer_name: str, *values) -> None:
            # Look the serializer up only when it is really needed, then
            # parse its XML string and attach the element to <vt>.
            serializer = getattr(self, serializer_name)
            vt_xml.append(secET.fromstring(serializer(vt_id, *values)))

        # Entries whose serializer takes just the entry's own value.
        leading_sections = (
            ('vt_params', 'get_params_vt_as_xml_str'),
            ('vt_refs', 'get_refs_vt_as_xml_str'),
            ('vt_dependencies', 'get_dependencies_vt_as_xml_str'),
            ('creation_time', 'get_creation_time_vt_as_xml_str'),
            ('modification_time', 'get_modification_time_vt_as_xml_str'),
            ('summary', 'get_summary_vt_as_xml_str'),
            ('impact', 'get_impact_vt_as_xml_str'),
            ('affected', 'get_affected_vt_as_xml_str'),
            ('insight', 'get_insight_vt_as_xml_str'),
        )
        for key, serializer_name in leading_sections:
            if vt.get(key):
                append_section(serializer_name, vt.get(key))

        if vt.get('solution'):
            append_section(
                'get_solution_vt_as_xml_str',
                vt.get('solution'),
                vt.get('solution_type'),
                vt.get('solution_method'),
            )

        # Detection is emitted as soon as any of its three parts is set.
        if vt.get('detection') or vt.get('qod_type') or vt.get('qod'):
            append_section(
                'get_detection_vt_as_xml_str',
                vt.get('detection'),
                vt.get('qod_type'),
                vt.get('qod'),
            )

        trailing_sections = (
            ('severities', 'get_severities_vt_as_xml_str'),
            ('custom', 'get_custom_vt_as_xml_str'),
        )
        for key, serializer_name in trailing_sections:
            if vt.get(key):
                append_section(serializer_name, vt.get(key))

        return vt_xml
1097
1098
    def get_vts_selection_list(
        self, vt_id: str = None, filtered_vts: Dict = None
    ) -> Iterable[str]:
        """
        Select the VT OIDs to work on.

        The selection is resolved in this order:
          1. filtered_vts given but empty: nothing matched the filter, an
             empty list is returned.
          2. filtered_vts non-empty: it is returned as is and takes
             priority over vt_id.
          3. vt_id given: a one-element list with this ID is returned. The
             ID is not checked against the known VTs here.
          4. Otherwise: the keys of all known VTs are returned.

        Arguments:
            vt_id (vt_id, optional): ID of the vt to get.
            filtered_vts (list, optional): Filtered VTs collection.

        Return:
            Collection of the selected VT OIDs.
        """
        # A filter was applied but it matched nothing.
        if filtered_vts is not None and len(filtered_vts) == 0:
            return []

        if filtered_vts:
            return filtered_vts

        if vt_id:
            return [vt_id]

        return self.vts.keys()
1131
1132
    def handle_command(self, data: bytes, stream: Stream) -> None:
        """ Handle one OSP command and write its response to the stream.

        The command is parsed as XML, its root tag is looked up in
        self.commands and the matching command object produces the
        response. While the daemon is not initialized, commands flagged
        with must_be_initialized are answered with an error response
        instead of being executed.

        Arguments:
            data (bytes): Raw XML of the command sent by the client.
            stream (Stream): Stream the response is written to.

        Raises:
            OspdCommandError: If data is not parseable XML ('Invalid data')
                or the command is unknown ('Bogus command name').
        """
        try:
            tree = secET.fromstring(data)
        except secET.ParseError as err:
            logger.debug("Erroneous client input: %s", data)
            raise OspdCommandError('Invalid data') from err

        command_name = tree.tag

        logger.debug('Handling %s command request.', command_name)

        command = self.commands.get(command_name, None)
        if not command and command_name != "authenticate":
            raise OspdCommandError('Bogus command name')

        if command is None:
            # Only reachable for "authenticate", which is exempt from the
            # bogus-command error above. Without a registered handler
            # there is nothing to dispatch to; returning here avoids an
            # AttributeError on None further down.
            # NOTE(review): assumes silently ignoring the command is the
            # intended treatment of "authenticate" - confirm.
            logger.debug(
                'No handler registered for %s command. Ignoring it.',
                command_name,
            )
            return

        if not self.initialized and command.must_be_initialized:
            exception = OspdCommandError(
                '%s is still starting' % self.daemon_info['name'], 'error'
            )
            stream.write(exception.as_xml())
            return

        response = command.handle_xml(tree)

        if isinstance(response, bytes):
            stream.write(response)
        else:
            # Anything that is not bytes is treated as an iterable of
            # chunks, written one by one.
            for chunk in response:
                stream.write(chunk)
1164
1165
    def check(self):
        """ Placeholder that has to be implemented by the subclass.

        Raises:
            NotImplementedError: Always, in this base version.
        """
        raise NotImplementedError()
1168
1169
    def run(self) -> None:
        """ Run the daemon's periodic main loop until interrupted.

        Each cycle sleeps for SCHEDULER_CHECK_PERIOD and then calls, in
        this order, scheduler(), clean_forgotten_scans() and
        wait_for_children(). The loop only ends on a KeyboardInterrupt,
        which is logged and not re-raised.
        """
        try:
            while True:
                # Sleep first, so the periodic tasks never run right at
                # start-up.
                time.sleep(SCHEDULER_CHECK_PERIOD)
                self.scheduler()
                self.clean_forgotten_scans()
                self.wait_for_children()
        except KeyboardInterrupt:
            logger.info("Received Ctrl-C shutting-down ...")
1181
1182
    def scheduler(self):
        """ Hook for periodic tasks; does nothing in this base version.

        run() calls it once per cycle. A subclass overrides it when it
        needs to run tasks periodically.
        """
1185
1186
    def wait_for_children(self):
        """ Join every registered scan process without blocking.

        join() is called with a timeout of 0 on each entry of
        self.scan_processes, so that processes which already ended
        (zombies) release their resources.
        """
        processes = self.scan_processes
        for scan_id in processes:
            processes[scan_id].join(0)
1190
1191
    def create_scan(
        self,
        scan_id: str,
        targets: Dict,
        options: Optional[Dict],
        vt_selection: Dict,
    ) -> Optional[str]:
        """ Create a new scan, unless one with this ID already exists.

        Arguments:
            scan_id (str): ID for the new scan.
            targets (dict): Targets to scan.
            options (dict, optional): Miscellaneous scan options.
            vt_selection (dict): VTs selected for the scan.

        Return:
            Whatever scan_collection.create_scan() returns for the new
            scan. None if scan_id is set and the scan already exists; in
            that case the current status is logged and nothing is created.
        """
        already_known = self.scan_exists(scan_id)
        if not (scan_id and already_known):
            return self.scan_collection.create_scan(
                scan_id, targets, options, vt_selection
            )

        current_status = self.get_scan_status(scan_id)
        logger.info(
            "Scan %s exists with status %s.",
            scan_id,
            current_status.name.lower(),
        )
        return None
1217
1218
    def get_scan_options(self, scan_id: str) -> str:
        """ Return the options of the scan with the given ID.

        Delegates to scan_collection.get_options().
        """
        scan_options = self.scan_collection.get_options(scan_id)
        return scan_options
1221
1222
    def set_scan_option(self, scan_id: str, name: str, value: Any) -> None:
        """ Set one option of the scan with the given ID.

        Delegates to scan_collection.set_option() and passes its return
        value through.
        """
        outcome = self.scan_collection.set_option(scan_id, name, value)
        return outcome
1225
1226
    def clean_forgotten_scans(self) -> None:
        """ Delete stopped or finished scans nobody collected in time.

        Does nothing when self.scaninfo_store_time is not set. Otherwise
        every scan that is in status STOPPED or FINISHED, has an end time
        and ended more than scaninfo_store_time hours ago is deleted.
        """
        if not self.scaninfo_store_time:
            return

        # Iterate over a snapshot of the IDs, because scans get deleted
        # inside the loop.
        for scan_id in list(self.scan_collection.ids_iterator()):
            end_time = int(self.get_scan_end_time(scan_id))
            scan_status = self.get_scan_status(scan_id)

            has_ended = scan_status in (
                ScanStatus.STOPPED,
                ScanStatus.FINISHED,
            )
            if not (has_ended and end_time):
                continue

            stored_time = int(time.time()) - end_time
            # scaninfo_store_time is given in hours, stored_time in
            # seconds.
            if stored_time > self.scaninfo_store_time * 3600:
                logger.debug(
                    'Scan %s is older than %d hours and seems have been '
                    'forgotten. Scan info will be deleted from the '
                    'scan table',
                    scan_id,
                    self.scaninfo_store_time,
                )
                self.delete_scan(scan_id)
1251
1252
    def check_scan_process(self, scan_id: str) -> None:
        """ Check a scan's process and flag the scan if the process died.

        If the scan has not reached PROGRESS_FINISHED and its process is
        no longer alive, the scan is set to STOPPED and a "Scan process
        failure." error result is added, unless the scan is in status
        STOPPED already. If the scan reached PROGRESS_FINISHED, its
        process is joined without blocking.
        """
        scan_process = self.scan_processes[scan_id]
        progress = self.get_scan_progress(scan_id)

        if progress < PROGRESS_FINISHED and not scan_process.is_alive():
            # The process is gone although the scan is not done.
            if self.get_scan_status(scan_id) != ScanStatus.STOPPED:
                self.set_scan_status(scan_id, ScanStatus.STOPPED)
                self.add_scan_error(
                    scan_id, name="", host="", value="Scan process failure."
                )

                logger.info("%s: Scan stopped with errors.", scan_id)

        elif progress == PROGRESS_FINISHED:
            scan_process.join(0)
1268
1269
    def get_scan_progress(self, scan_id: str):
        """ Return the current progress value of the scan.

        Delegates to scan_collection.get_progress().
        """
        progress = self.scan_collection.get_progress(scan_id)
        return progress
1272
1273
    def get_scan_host(self, scan_id: str) -> str:
        """ Return the host list of the scan.

        Delegates to scan_collection.get_host_list().
        """
        host_list = self.scan_collection.get_host_list(scan_id)
        return host_list
1276
1277
    def get_scan_ports(self, scan_id: str) -> str:
        """ Return the ports list of the scan.

        Delegates to scan_collection.get_ports().
        """
        ports = self.scan_collection.get_ports(scan_id)
        return ports
1280
1281
    def get_scan_exclude_hosts(self, scan_id: str):
        """ Return the exclude hosts list of the scan.

        Delegates to scan_collection.get_exclude_hosts().
        """
        excluded = self.scan_collection.get_exclude_hosts(scan_id)
        return excluded
1285
1286
    def get_scan_credentials(self, scan_id: str) -> Dict:
        """ Return the credentials of the scan.

        Delegates to scan_collection.get_credentials().
        """
        credentials = self.scan_collection.get_credentials(scan_id)
        return credentials
1290
1291
    def get_scan_target_options(self, scan_id: str) -> Dict:
        """ Return the target options of the scan.

        Delegates to scan_collection.get_target_options().
        """
        target_options = self.scan_collection.get_target_options(scan_id)
        return target_options
1295
1296
    def get_scan_vts(self, scan_id: str) -> Dict:
        """ Return the VTs of the scan.

        Delegates to scan_collection.get_vts().
        """
        scan_vts = self.scan_collection.get_vts(scan_id)
        return scan_vts
1299
1300
    def get_scan_start_time(self, scan_id: str) -> str:
        """ Return the start time of the scan.

        Delegates to scan_collection.get_start_time().
        """
        start_time = self.scan_collection.get_start_time(scan_id)
        return start_time
1303
1304
    def get_scan_end_time(self, scan_id: str) -> str:
        """ Return the end time of the scan.

        Delegates to scan_collection.get_end_time().
        """
        end_time = self.scan_collection.get_end_time(scan_id)
        return end_time
1307
1308
    def add_scan_log(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
        test_id: str = '',
        qod: str = '',
    ) -> None:
        """ Add a result of type LOG to the scan with the given ID.

        The result is stored through scan_collection.add_result(). Log
        results always carry the fixed severity '0.0'.
        """
        log_fields = (
            host,
            hostname,
            name,
            value,
            port,
            test_id,
            '0.0',
            qod,
        )
        self.scan_collection.add_result(scan_id, ResultType.LOG, *log_fields)
1333
1334
    def add_scan_error(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
        test_id='',
    ) -> None:
        """ Add a result of type ERROR to the scan with the given ID.

        The result is stored through scan_collection.add_result(). No
        severity and no QoD are passed for error results.
        """
        error_fields = (host, hostname, name, value, port, test_id)
        self.scan_collection.add_result(
            scan_id, ResultType.ERROR, *error_fields
        )
1355
1356
    def add_scan_host_detail(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
    ) -> None:
        """ Add a result of type HOST_DETAIL to the scan with the given ID.

        The result is stored through scan_collection.add_result(). Only
        host, hostname, name and value are passed for host details.
        """
        detail_fields = (host, hostname, name, value)
        self.scan_collection.add_result(
            scan_id, ResultType.HOST_DETAIL, *detail_fields
        )
1368
1369
    def add_scan_alarm(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
        test_id: str = '',
        severity: str = '',
        qod: str = '',
    ) -> None:
        """ Add a result of type ALARM to the scan with the given ID.

        The result is stored through scan_collection.add_result(), with
        the severity and QoD given by the caller.
        """
        alarm_fields = (
            host,
            hostname,
            name,
            value,
            port,
            test_id,
            severity,
            qod,
        )
        self.scan_collection.add_result(
            scan_id, ResultType.ALARM, *alarm_fields
        )
        )
1394