Passed
Pull Request — master (#211)
by Juan José
01:24
created

ospd.ospd.OSPDaemon.get_scan_target_progress()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
# Copyright (C) 2014-2020 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: GPL-2.0-or-later
4
#
5
# This program is free software; you can redistribute it and/or
6
# modify it under the terms of the GNU General Public License
7
# as published by the Free Software Foundation; either version 2
8
# of the License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU General Public License for more details.
14
#
15
# You should have received a copy of the GNU General Public License
16
# along with this program; if not, write to the Free Software
17
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18
19
# pylint: disable=too-many-lines
20
21
""" OSP Daemon core class.
22
"""
23
24
import logging
25
import socket
26
import ssl
27
import multiprocessing
28
import time
29
import os
30
31
from typing import List, Any, Dict, Optional, Iterable
32
from xml.etree.ElementTree import Element, SubElement
33
34
import defusedxml.ElementTree as secET
35
36
from deprecated import deprecated
37
38
from ospd import __version__
39
from ospd.command import get_commands
40
from ospd.errors import OspdCommandError
41
from ospd.misc import ResultType
42
from ospd.network import resolve_hostname, target_str_to_list
43
from ospd.protocol import OspRequest, OspResponse
44
from ospd.scan import ScanCollection, ScanStatus
45
from ospd.server import BaseServer
46
from ospd.vtfilter import VtsFilter
47
from ospd.vts import Vts
48
from ospd.xml import (
49
    elements_as_text,
50
    get_result_xml,
51
    get_elements_from_dict,
52
)
53
54
logger = logging.getLogger(__name__)

# Version of the OSP protocol implemented by this daemon.
PROTOCOL_VERSION = "1.2"

SCHEDULER_CHECK_PERIOD = 5  # in seconds

# Scanner parameters that OSPDaemon.__init__ registers on every daemon
# through set_scanner_param(); subclasses can register further ones.
BASE_SCANNER_PARAMS = dict(
    debug_mode=dict(
        type='boolean',
        name='Debug Mode',
        default=0,
        mandatory=0,
        description='Whether to get extra scan debug information.',
    ),
    dry_run=dict(
        type='boolean',
        name='Dry Run',
        default=0,
        mandatory=0,
        description='Whether to dry run scan.',
    ),
)  # type: Dict
76
77
78
def _terminate_process_group(process: multiprocessing.Process) -> None:
79
    os.killpg(os.getpgid(process.pid), 15)
80
81
82
class OSPDaemon:
83
84
    """ Daemon class for OSP traffic handling.
85
86
    Every scanner wrapper should subclass it and make necessary additions and
87
    changes.
88
89
    * Add any needed parameters in __init__.
90
    * Implement check() method which verifies scanner availability and other
91
      environment related conditions.
92
    * Implement process_scan_params and exec_scan methods which are
93
      specific to handling the <start_scan> command, executing the wrapped
94
      scanner and storing the results.
95
    * exec_scan() should return 0 if host is dead or not reached, 1 if host is
96
      alive and 2 if scan error or status is unknown.
97
    * Implement other methods that assert to False such as get_scanner_name,
98
      get_scanner_version.
99
    * Call set_command_attributes at init time to add scanner command
100
      specific options eg. the w3af profile for w3af wrapper.
101
    """
102
103
    def __init__(
        self, *, customvtfilter=None, **kwargs
    ):  # pylint: disable=unused-argument
        """Set up the daemon's internal state.

        Keyword Args:
            customvtfilter: Optional VT filter object. When it is falsy, a
                new VtsFilter() is used instead.
            **kwargs: Extra options. Only ``scaninfo_store_time`` is read
                here; other keys are ignored.
        """
        self.scan_collection = ScanCollection()
        self.scan_processes = {}

        self.daemon_info = {
            'name': "OSPd",
            'version': __version__,
            'description': "No description",
        }
        self.scanner_info = {
            'name': 'No name',
            'version': 'No version',
            'description': 'No description',
        }

        self.server_version = None  # Filled in by the subclass.

        self.initialized = None  # init() sets this to True.

        self.scaninfo_store_time = kwargs.get('scaninfo_store_time')

        self.protocol_version = PROTOCOL_VERSION

        # Instantiate every registered command (each receives this daemon)
        # and index it by the command's name.
        self.commands = {}
        for cmd_cls in get_commands():
            cmd = cmd_cls(self)
            self.commands[cmd.get_name()] = cmd

        self.scanner_params = {}
        for param_name, param_spec in BASE_SCANNER_PARAMS.items():
            self.set_scanner_param(param_name, param_spec)

        self.vts = Vts()
        self.vts_version = None

        self.vts_filter = customvtfilter if customvtfilter else VtsFilter()
146
147
    def init(self, server: BaseServer) -> None:
        """Start serving client connections and mark the daemon as ready.

        Subclasses should override this when initialization is costly.
        It is called after check.

        Args:
            server: Server whose ``start`` method receives
                ``handle_client_stream`` as the per-client handler.
        """
        client_handler = self.handle_client_stream
        server.start(client_handler)
        self.initialized = True
154
155
    def set_command_attributes(self, name: str, attributes: Dict) -> None:
156
        """ Sets the xml attributes of a specified command. """
157
        if self.command_exists(name):
158
            command = self.commands.get(name)
159
            command.attributes = attributes
160
161
    @deprecated(version="20.4", reason="Use set_scanner_param instead")
162
    def add_scanner_param(self, name: str, scanner_params: Dict) -> None:
163
        """ Set a scanner parameter. """
164
        self.set_scanner_param(name, scanner_params)
165
166
    def set_scanner_param(self, name: str, scanner_params: Dict) -> None:
167
        """ Set a scanner parameter. """
168
169
        assert name
170
        assert scanner_params
171
172
        self.scanner_params[name] = scanner_params
173
174
    def get_scanner_params(self) -> Dict:
175
        return self.scanner_params
176
177
    def add_vt(
178
        self,
179
        vt_id: str,
180
        name: str = None,
181
        vt_params: str = None,
182
        vt_refs: str = None,
183
        custom: str = None,
184
        vt_creation_time: str = None,
185
        vt_modification_time: str = None,
186
        vt_dependencies: str = None,
187
        summary: str = None,
188
        impact: str = None,
189
        affected: str = None,
190
        insight: str = None,
191
        solution: str = None,
192
        solution_t: str = None,
193
        solution_m: str = None,
194
        detection: str = None,
195
        qod_t: str = None,
196
        qod_v: str = None,
197
        severities: str = None,
198
    ) -> None:
199
        """ Add a vulnerability test information.
200
201
        IMPORTANT: The VT's Data Manager will store the vts collection.
202
        If the collection is considerably big and it will be consultated
203
        intensible during a routine, consider to do a deepcopy(), since
204
        accessing the shared memory in the data manager is very expensive.
205
        At the end of the routine, the temporal copy must be set to None
206
        and deleted.
207
        """
208
        self.vts.add(
209
            vt_id,
210
            name=name,
211
            vt_params=vt_params,
212
            vt_refs=vt_refs,
213
            custom=custom,
214
            vt_creation_time=vt_creation_time,
215
            vt_modification_time=vt_modification_time,
216
            vt_dependencies=vt_dependencies,
217
            summary=summary,
218
            impact=impact,
219
            affected=affected,
220
            insight=insight,
221
            solution=solution,
222
            solution_t=solution_t,
223
            solution_m=solution_m,
224
            detection=detection,
225
            qod_t=qod_t,
226
            qod_v=qod_v,
227
            severities=severities,
228
        )
229
230
    def set_vts_version(self, vts_version: str) -> None:
231
        """ Add into the vts dictionary an entry to identify the
232
        vts version.
233
234
        Parameters:
235
            vts_version (str): Identifies a unique vts version.
236
        """
237
        if not vts_version:
238
            raise OspdCommandError(
239
                'A vts_version parameter is required', 'set_vts_version'
240
            )
241
        self.vts_version = vts_version
242
243
    def get_vts_version(self) -> Optional[str]:
244
        """Return the vts version.
245
        """
246
        return self.vts_version
247
248
    def command_exists(self, name: str) -> bool:
249
        """ Checks if a commands exists. """
250
        return name in self.commands
251
252
    def get_scanner_name(self) -> str:
253
        """ Gives the wrapped scanner's name. """
254
        return self.scanner_info['name']
255
256
    def get_scanner_version(self) -> str:
257
        """ Gives the wrapped scanner's version. """
258
        return self.scanner_info['version']
259
260
    def get_scanner_description(self) -> str:
261
        """ Gives the wrapped scanner's description. """
262
        return self.scanner_info['description']
263
264
    def get_server_version(self) -> str:
265
        """ Gives the specific OSP server's version. """
266
        assert self.server_version
267
        return self.server_version
268
269
    def get_protocol_version(self) -> str:
270
        """ Gives the OSP's version. """
271
        return self.protocol_version
272
273
    def preprocess_scan_params(self, xml_params):
274
        """ Processes the scan parameters. """
275
        params = {}
276
277
        for param in xml_params:
278
            params[param.tag] = param.text or ''
279
280
        # Set default values.
281
        for key in self.scanner_params:
282
            if key not in params:
283
                params[key] = self.get_scanner_param_default(key)
284
                if self.get_scanner_param_type(key) == 'selection':
285
                    params[key] = params[key].split('|')[0]
286
287
        # Validate values.
288
        for key in params:
289
            param_type = self.get_scanner_param_type(key)
290
            if not param_type:
291
                continue
292
293
            if param_type in ['integer', 'boolean']:
294
                try:
295
                    params[key] = int(params[key])
296
                except ValueError:
297
                    raise OspdCommandError(
298
                        'Invalid %s value' % key, 'start_scan'
299
                    )
300
301
            if param_type == 'boolean':
302
                if params[key] not in [0, 1]:
303
                    raise OspdCommandError(
304
                        'Invalid %s value' % key, 'start_scan'
305
                    )
306
            elif param_type == 'selection':
307
                selection = self.get_scanner_param_default(key).split('|')
308
                if params[key] not in selection:
309
                    raise OspdCommandError(
310
                        'Invalid %s value' % key, 'start_scan'
311
                    )
312
            if self.get_scanner_param_mandatory(key) and params[key] == '':
313
                raise OspdCommandError(
314
                    'Mandatory %s value is missing' % key, 'start_scan'
315
                )
316
317
        return params
318
319
    def process_scan_params(self, params: Dict) -> Dict:
320
        """ This method is to be overridden by the child classes if necessary
321
        """
322
        return params
323
324
    @staticmethod
    @deprecated(
        version="20.4",
        reason="Please use OspRequest.process_vts_params instead.",
    )
    def process_vts_params(scanner_vts) -> Dict:
        """Deprecated wrapper around ``OspRequest.process_vts_params``.

        The deprecation reason names the method this wrapper actually
        delegates to, so users are pointed at a real replacement.

        Args:
            scanner_vts: Passed unchanged to OspRequest.process_vts_params.

        Returns:
            Whatever OspRequest.process_vts_params returns.
        """
        return OspRequest.process_vts_params(scanner_vts)
331
332
    @staticmethod
    @deprecated(
        version="20.4",
        reason="Please use OspRequest.process_credentials_elements instead.",
    )
    def process_credentials_elements(cred_tree) -> Dict:
        """Deprecated wrapper around ``OspRequest.process_credentials_elements``.

        The deprecation reason names the method this wrapper actually
        delegates to.

        Args:
            cred_tree: Passed unchanged to the OspRequest method.

        Returns:
            Whatever OspRequest.process_credentials_elements returns.
        """
        return OspRequest.process_credentials_elements(cred_tree)
339
340
    @staticmethod
    @deprecated(
        version="20.4",
        reason="Please use OspRequest.process_targets_element instead.",
    )
    def process_targets_element(scanner_target) -> List:
        """Deprecated wrapper around ``OspRequest.process_targets_element``.

        The deprecation reason names the method this wrapper actually
        delegates to.

        Args:
            scanner_target: Passed unchanged to the OspRequest method.

        Returns:
            Whatever OspRequest.process_targets_element returns.
        """
        return OspRequest.process_targets_element(scanner_target)
347
348
    def stop_scan(self, scan_id: str) -> None:
        """Stop the running scan *scan_id* and its process group.

        The scan is marked STOPPED and stop_scan_cleanup() runs first. The
        scan process is then terminated, its process group is signalled,
        and the process is joined without waiting.

        Raises:
            OspdCommandError: If no process is known for the scan, or the
                process is no longer alive.
        """
        proc = self.scan_processes.get(scan_id)
        if not proc:
            raise OspdCommandError(
                'Scan not found {0}.'.format(scan_id), 'stop_scan'
            )
        if not proc.is_alive():
            raise OspdCommandError(
                'Scan already stopped or finished.', 'stop_scan'
            )

        self.set_scan_status(scan_id, ScanStatus.STOPPED)
        logger.info('%s: Scan stopping %s.', scan_id, proc.ident)
        self.stop_scan_cleanup(scan_id)

        try:
            proc.terminate()
        except AttributeError:
            logger.debug('%s: The scanner task stopped unexpectedly.', scan_id)

        try:
            _terminate_process_group(proc)
        except ProcessLookupError:
            logger.info('%s: Scan already stopped %s.', scan_id, proc.pid)

        # Only join when the scan process is not the current process;
        # a timeout of 0 means join() does not block.
        if proc.ident != os.getpid():
            proc.join(0)

        logger.info('%s: Scan stopped.', scan_id)
381
382
    @staticmethod
383
    def stop_scan_cleanup(scan_id: str):
384
        """ Should be implemented by subclass in case of a clean up before
385
        terminating is needed. """
386
387
    @staticmethod
388
    def target_is_finished(scan_id: str):
389
        """ Should be implemented by subclass in case of a check before
390
        stopping is needed. """
391
392
    def exec_scan(self, scan_id: str):
393
        """ Asserts to False. Should be implemented by subclass. """
394
        raise NotImplementedError
395
396
    def finish_scan(self, scan_id: str) -> None:
        """Set the scan's progress to 100 and its status to FINISHED."""
        complete = 100
        self.set_scan_progress(scan_id, complete)
        self.set_scan_status(scan_id, ScanStatus.FINISHED)
        logger.info("%s: Scan finished.", scan_id)
401
402
    def get_daemon_name(self) -> str:
403
        """ Gives osp daemon's name. """
404
        return self.daemon_info['name']
405
406
    def get_daemon_version(self) -> str:
407
        """ Gives osp daemon's version. """
408
        return self.daemon_info['version']
409
410
    def get_scanner_param_type(self, param: str):
411
        """ Returns type of a scanner parameter. """
412
        assert isinstance(param, str)
413
        entry = self.scanner_params.get(param)
414
        if not entry:
415
            return None
416
        return entry.get('type')
417
418
    def get_scanner_param_mandatory(self, param: str):
419
        """ Returns if a scanner parameter is mandatory. """
420
        assert isinstance(param, str)
421
        entry = self.scanner_params.get(param)
422
        if not entry:
423
            return False
424
        return entry.get('mandatory')
425
426
    def get_scanner_param_default(self, param: str):
427
        """ Returns default value of a scanner parameter. """
428
        assert isinstance(param, str)
429
        entry = self.scanner_params.get(param)
430
        if not entry:
431
            return None
432
        return entry.get('default')
433
434
    @deprecated(
        version="20.4",
        reason="Please use OspResponse.create_scanner_params_xml instead.",
    )
    def get_scanner_params_xml(self):
        """Deprecated: build the scanner params XML for this daemon.

        Delegates to OspResponse.create_scanner_params_xml with
        ``self.scanner_params``.
        """
        registered = self.scanner_params
        return OspResponse.create_scanner_params_xml(registered)
441
442
    def handle_client_stream(self, stream) -> None:
        """Read one client request from *stream*, run it and reply.

        The stream is read until ``read()`` returns nothing, an SSL error
        or a socket timeout occurs. AttributeError/ValueError while reading
        aborts handling. An empty request is ignored.

        Before init() has completed, an OspdCommandError "still starting"
        response is sent. Otherwise the data goes to handle_command().
        Command errors and unexpected exceptions are turned into an XML
        error response. The stream is closed after any response.
        """
        chunks = b''

        while True:
            try:
                buf = stream.read()
                if not buf:
                    break

                chunks += buf
            except (AttributeError, ValueError) as err:
                logger.error(err)
                return
            except ssl.SSLError as err:
                logger.debug('Error: %s', err)
                break
            except socket.timeout:
                break

        if not chunks:
            logger.debug("Empty client stream")
            return

        if not self.initialized:
            starting = OspdCommandError(
                '%s is still starting' % self.daemon_info['name'], 'error'
            )
            stream.write(starting.as_xml())
            stream.close()
            return

        response = None
        try:
            self.handle_command(chunks, stream)
        except OspdCommandError as err:
            response = err.as_xml()
            logger.debug('Command error: %s', err.message)
        except Exception:  # pylint: disable=broad-except
            logger.exception('While handling client command:')
            response = OspdCommandError('Fatal error', 'error').as_xml()

        if response:
            stream.write(response)
        stream.close()
490
491
    def calculate_progress(self, scan_id: str) -> float:
492
        """ Calculate the total scan progress. """
493
494
        return self.scan_collection.calculate_target_progress(scan_id)
495
496
    def process_exclude_hosts(self, scan_id: str, exclude_hosts: str) -> None:
497
        """ Process the exclude hosts before launching the scans."""
498
499
        exc_hosts_list = ''
500
        if not exclude_hosts:
501
            return
502
        exc_hosts_list = target_str_to_list(exclude_hosts)
503
        self.remove_scan_hosts_from_target_progress(scan_id, exc_hosts_list)
504
505
    def process_finished_hosts(self, scan_id: str, finished_hosts: str) -> None:
506
        """ Process the finished hosts before launching the scans.
507
        Set finished hosts as finished with 100% to calculate
508
        the scan progress."""
509
510
        exc_hosts_list = ''
511
        if not finished_hosts:
512
            return
513
514
        exc_hosts_list = target_str_to_list(finished_hosts)
515
516
        for host in exc_hosts_list:
517
            self.set_scan_host_finished(scan_id, host)
518
            self.set_scan_host_progress(scan_id, host, 100)
519
520
    def start_scan(self, scan_id: str, target: Dict) -> None:
        """Run the scan identified by *scan_id*.

        Starts a new session (os.setsid), applies the target's
        ``exclude_hosts`` and ``finished_hosts`` to the progress
        bookkeeping, marks the scan RUNNING and calls exec_scan(). An
        exception from exec_scan() is recorded as a scan error and logged.
        Unless the scan was stopped in the meantime, it is then marked
        finished.

        Args:
            scan_id: ID of the scan to run.
            target: Target description; its ``exclude_hosts`` and
                ``finished_hosts`` entries are read via ``.get``.

        Raises:
            OspdCommandError: If *target* is None or empty.
        """
        os.setsid()

        if not target:
            raise OspdCommandError('Erroneous target', 'start_scan')

        logger.info("%s: Scan started.", scan_id)

        self.process_exclude_hosts(scan_id, target.get('exclude_hosts'))
        self.process_finished_hosts(scan_id, target.get('finished_hosts'))

        try:
            self.set_scan_status(scan_id, ScanStatus.RUNNING)
            # exec_scan()'s return value is not used here.
            self.exec_scan(scan_id)
        except Exception as e:  # pylint: disable=broad-except
            self.add_scan_error(
                scan_id,
                name='',
                host=self.get_scan_host(scan_id),
                value='Host process failure (%s).' % e,
            )
            logger.exception('While scanning: %s', scan_id)
        else:
            logger.info("%s: Host scan finished.", scan_id)

        # stop_scan() sets STOPPED; do not overwrite that with FINISHED.
        if self.get_scan_status(scan_id) != ScanStatus.STOPPED:
            self.finish_scan(scan_id)
548
549
    def dry_run_scan(self, scan_id: str, target: Dict) -> None:
        """Simulate a scan without invoking the scanner.

        Starts a new session (os.setsid) and resolves ``target[0]``. It
        then logs the host and scan ports, adds a 'Dry run result' log
        entry and finishes the scan.

        NOTE(review): ``target`` is annotated as Dict but indexed with
        ``[0]``; confirm what callers actually pass.
        """
        os.setsid()

        resolved = resolve_hostname(target[0])
        if resolved is None:
            logger.info("Couldn't resolve %s.", self.get_scan_host(scan_id))

        ports = self.get_scan_ports(scan_id)
        logger.info("%s:%s: Dry run mode.", resolved, ports)

        self.add_scan_log(
            scan_id, name='', host=resolved, value='Dry run result'
        )
        self.finish_scan(scan_id)
565
566
    def handle_timeout(self, scan_id: str, host: str) -> None:
567
        """ Handles scanner reaching timeout error. """
568
        self.add_scan_error(
569
            scan_id,
570
            host=host,
571
            name="Timeout",
572
            value="{0} exec timeout.".format(self.get_scanner_name()),
573
        )
574
575
    def remove_scan_hosts_from_target_progress(
576
        self, scan_id: str, exc_hosts_list: List
577
    ) -> None:
578
        """ Remove a list of hosts from the main scan progress table."""
579
        self.scan_collection.remove_hosts_from_target_progress(
580
            scan_id, exc_hosts_list
581
        )
582
583
    def set_scan_host_finished(self, scan_id: str, host: str) -> None:
584
        """ Add the host in a list of finished hosts """
585
        self.scan_collection.set_host_finished(scan_id, host)
586
587
    def set_scan_progress(self, scan_id: str, progress: int) -> None:
588
        """ Sets scan_id scan's progress which is a number
589
        between 0 and 100. """
590
        self.scan_collection.set_progress(scan_id, progress)
591
592
    def set_scan_host_progress(
593
        self, scan_id: str, host: str, progress: int
594
    ) -> None:
595
        """ Sets host's progress which is part of target.
596
        Each time a host progress is updated, the scan progress
597
        is updated too.
598
        """
599
        self.scan_collection.set_host_progress(scan_id, host, progress)
600
601
        scan_progress = self.calculate_progress(scan_id)
602
        self.set_scan_progress(scan_id, scan_progress)
603
604
    def set_scan_status(self, scan_id: str, status: ScanStatus) -> None:
        """Store *status* as the current status of scan *scan_id*."""
        collection = self.scan_collection
        collection.set_status(scan_id, status)
607
608
    def get_scan_status(self, scan_id: str) -> ScanStatus:
        """Return the current status of scan *scan_id*."""
        collection = self.scan_collection
        return collection.get_status(scan_id)
611
612
    def scan_exists(self, scan_id: str) -> bool:
613
        """ Checks if a scan with ID scan_id is in collection.
614
615
        @return: 1 if scan exists, 0 otherwise.
616
        """
617
        return self.scan_collection.id_exists(scan_id)
618
619
    def get_help_text(self) -> str:
620
        """ Returns the help output in plain text format."""
621
622
        txt = ''
623
        for name, info in self.commands.items():
624
            description = info.get_description()
625
            attributes = info.get_attributes()
626
            elements = info.get_elements()
627
628
            command_txt = "\t{0: <22} {1}\n".format(name, description)
629
630
            if attributes:
631
                command_txt = ''.join([command_txt, "\t Attributes:\n"])
632
633
                for attrname, attrdesc in attributes.items():
634
                    attr_txt = "\t  {0: <22} {1}\n".format(attrname, attrdesc)
635
                    command_txt = ''.join([command_txt, attr_txt])
636
637
            if elements:
638
                command_txt = ''.join(
639
                    [command_txt, "\t Elements:\n", elements_as_text(elements),]
640
                )
641
642
            txt += command_txt
643
644
        return txt
645
646
    @deprecated(version="20.4", reason="Use ospd.xml.elements_as_text instead.")
    def elements_as_text(self, elems: Dict, indent: int = 2) -> str:
        """Deprecated: format *elems* as indented plain text.

        The name below resolves to the module-level ``elements_as_text``
        imported from ``ospd.xml``, not to this method.
        """
        text = elements_as_text(elems, indent)
        return text
650
651
    def delete_scan(self, scan_id: str) -> int:
        """Delete scan *scan_id* from the collection.

        A RUNNING scan is not deleted (returns 0). Otherwise the scan's
        process, if known, is joined first; its entry in scan_processes
        is dropped once it has an exit code.

        Returns:
            0 for a running scan, else the result of
            ``scan_collection.delete_scan(scan_id)``.
        """
        if self.get_scan_status(scan_id) == ScanStatus.RUNNING:
            return 0

        # Wait for the scan process to end before deleting the scan.
        exitcode = None
        try:
            proc = self.scan_processes[scan_id]
            proc.join()
            exitcode = proc.exitcode
        except KeyError:
            logger.debug('Scan process for %s not found', scan_id)

        # Process.exitcode is None until the process has terminated.
        if exitcode is not None:
            del self.scan_processes[scan_id]

        return self.scan_collection.delete_scan(scan_id)
671
672
    def get_scan_results_xml(
        self, scan_id: str, pop_res: bool, max_res: Optional[int]
    ):
        """Return a <results> Element holding the scan's results.

        Results come from ``scan_collection.results_iterator(scan_id,
        pop_res, max_res)``. Each one is converted with get_result_xml.
        """
        results_elem = Element('results')
        result_iter = self.scan_collection.results_iterator(
            scan_id, pop_res, max_res
        )
        for res in result_iter:
            results_elem.append(get_result_xml(res))

        logger.debug('Returning %d results', len(results_elem))
        return results_elem
687
688
    @deprecated(
        version="20.4",
        reason="Please use ospd.xml.get_elements_from_dict instead.",
    )
    def get_xml_str(self, data: Dict) -> List:
        """Deprecated: convert *data* to XML via get_elements_from_dict.

        @param: Dictionary of xml tags and their elements.

        @return: Whatever ospd.xml.get_elements_from_dict returns.
        """
        converted = get_elements_from_dict(data)
        return converted
700
701
    def get_scan_xml(
702
        self,
703
        scan_id: str,
704
        detailed: bool = True,
705
        pop_res: bool = False,
706
        max_res: int = 0,
707
    ):
708
        """ Gets scan in XML format.
709
710
        @return: String of scan in XML format.
711
        """
712
        if not scan_id:
713
            return Element('scan')
714
715
        target = self.get_scan_host(scan_id)
716
        progress = self.get_scan_progress(scan_id)
717
        status = self.get_scan_status(scan_id)
718
        start_time = self.get_scan_start_time(scan_id)
719
        end_time = self.get_scan_end_time(scan_id)
720
        response = Element('scan')
721
        for name, value in [
722
            ('id', scan_id),
723
            ('target', target),
724
            ('progress', progress),
725
            ('status', status.name.lower()),
726
            ('start_time', start_time),
727
            ('end_time', end_time),
728
        ]:
729
            response.set(name, str(value))
730
        if detailed:
731
            response.append(
732
                self.get_scan_results_xml(scan_id, pop_res, max_res)
733
            )
734
        return response
735
736
    @staticmethod
737
    def get_custom_vt_as_xml_str(  # pylint: disable=unused-argument
738
        vt_id: str, custom: Dict
739
    ) -> str:
740
        """ Create a string representation of the XML object from the
741
        custom data object.
742
        This needs to be implemented by each ospd wrapper, in case
743
        custom elements for VTs are used.
744
745
        The custom XML object which is returned will be embedded
746
        into a <custom></custom> element.
747
748
        @return: XML object as string for custom data.
749
        """
750
        return ''
751
752
    @staticmethod
753
    def get_params_vt_as_xml_str(  # pylint: disable=unused-argument
754
        vt_id: str, vt_params
755
    ) -> str:
756
        """ Create a string representation of the XML object from the
757
        vt_params data object.
758
        This needs to be implemented by each ospd wrapper, in case
759
        vt_params elements for VTs are used.
760
761
        The params XML object which is returned will be embedded
762
        into a <params></params> element.
763
764
        @return: XML object as string for vt parameters data.
765
        """
766
        return ''
767
768
    @staticmethod
769
    def get_refs_vt_as_xml_str(  # pylint: disable=unused-argument
770
        vt_id: str, vt_refs
771
    ) -> str:
772
        """ Create a string representation of the XML object from the
773
        refs data object.
774
        This needs to be implemented by each ospd wrapper, in case
775
        refs elements for VTs are used.
776
777
        The refs XML object which is returned will be embedded
778
        into a <refs></refs> element.
779
780
        @return: XML object as string for vt references data.
781
        """
782
        return ''
783
784
    @staticmethod
785
    def get_dependencies_vt_as_xml_str(  # pylint: disable=unused-argument
786
        vt_id: str, vt_dependencies
787
    ) -> str:
788
        """ Create a string representation of the XML object from the
789
        vt_dependencies data object.
790
        This needs to be implemented by each ospd wrapper, in case
791
        vt_dependencies elements for VTs are used.
792
793
        The vt_dependencies XML object which is returned will be embedded
794
        into a <dependencies></dependencies> element.
795
796
        @return: XML object as string for vt dependencies data.
797
        """
798
        return ''
799
800
    @staticmethod
801
    def get_creation_time_vt_as_xml_str(  # pylint: disable=unused-argument
802
        vt_id: str, vt_creation_time
803
    ) -> str:
804
        """ Create a string representation of the XML object from the
805
        vt_creation_time data object.
806
        This needs to be implemented by each ospd wrapper, in case
807
        vt_creation_time elements for VTs are used.
808
809
        The vt_creation_time XML object which is returned will be embedded
810
        into a <vt_creation_time></vt_creation_time> element.
811
812
        @return: XML object as string for vt creation time data.
813
        """
814
        return ''
815
816
    @staticmethod
817
    def get_modification_time_vt_as_xml_str(  # pylint: disable=unused-argument
818
        vt_id: str, vt_modification_time
819
    ) -> str:
820
        """ Create a string representation of the XML object from the
821
        vt_modification_time data object.
822
        This needs to be implemented by each ospd wrapper, in case
823
        vt_modification_time elements for VTs are used.
824
825
        The vt_modification_time XML object which is returned will be embedded
826
        into a <vt_modification_time></vt_modification_time> element.
827
828
        @return: XML object as string for vt references data.
829
        """
830
        return ''
831
832
    @staticmethod
833
    def get_summary_vt_as_xml_str(  # pylint: disable=unused-argument
834
        vt_id: str, summary
835
    ) -> str:
836
        """ Create a string representation of the XML object from the
837
        summary data object.
838
        This needs to be implemented by each ospd wrapper, in case
839
        summary elements for VTs are used.
840
841
        The summary XML object which is returned will be embedded
842
        into a <summary></summary> element.
843
844
        @return: XML object as string for summary data.
845
        """
846
        return ''
847
848
    @staticmethod
849
    def get_impact_vt_as_xml_str(  # pylint: disable=unused-argument
850
        vt_id: str, impact
851
    ) -> str:
852
        """ Create a string representation of the XML object from the
853
        impact data object.
854
        This needs to be implemented by each ospd wrapper, in case
855
        impact elements for VTs are used.
856
857
        The impact XML object which is returned will be embedded
858
        into a <impact></impact> element.
859
860
        @return: XML object as string for impact data.
861
        """
862
        return ''
863
864
    @staticmethod
865
    def get_affected_vt_as_xml_str(  # pylint: disable=unused-argument
866
        vt_id: str, affected
867
    ) -> str:
868
        """ Create a string representation of the XML object from the
869
        affected data object.
870
        This needs to be implemented by each ospd wrapper, in case
871
        affected elements for VTs are used.
872
873
        The affected XML object which is returned will be embedded
874
        into a <affected></affected> element.
875
876
        @return: XML object as string for affected data.
877
        """
878
        return ''
879
880
    @staticmethod
881
    def get_insight_vt_as_xml_str(  # pylint: disable=unused-argument
882
        vt_id: str, insight
883
    ) -> str:
884
        """ Create a string representation of the XML object from the
885
        insight data object.
886
        This needs to be implemented by each ospd wrapper, in case
887
        insight elements for VTs are used.
888
889
        The insight XML object which is returned will be embedded
890
        into a <insight></insight> element.
891
892
        @return: XML object as string for insight data.
893
        """
894
        return ''
895
896
    @staticmethod
897
    def get_solution_vt_as_xml_str(  # pylint: disable=unused-argument
898
        vt_id: str, solution, solution_type=None, solution_method=None
899
    ) -> str:
900
        """ Create a string representation of the XML object from the
901
        solution data object.
902
        This needs to be implemented by each ospd wrapper, in case
903
        solution elements for VTs are used.
904
905
        The solution XML object which is returned will be embedded
906
        into a <solution></solution> element.
907
908
        @return: XML object as string for solution data.
909
        """
910
        return ''
911
912
    @staticmethod
913
    def get_detection_vt_as_xml_str(  # pylint: disable=unused-argument
914
        vt_id: str, detection=None, qod_type=None, qod=None
915
    ) -> str:
916
        """ Create a string representation of the XML object from the
917
        detection data object.
918
        This needs to be implemented by each ospd wrapper, in case
919
        detection elements for VTs are used.
920
921
        The detection XML object which is returned is an element with
922
        tag <detection></detection> element
923
924
        @return: XML object as string for detection data.
925
        """
926
        return ''
927
928
    @staticmethod
929
    def get_severities_vt_as_xml_str(  # pylint: disable=unused-argument
930
        vt_id: str, severities
931
    ) -> str:
932
        """ Create a string representation of the XML object from the
933
        severities data object.
934
        This needs to be implemented by each ospd wrapper, in case
935
        severities elements for VTs are used.
936
937
        The severities XML objects which are returned will be embedded
938
        into a <severities></severities> element.
939
940
        @return: XML object as string for severities data.
941
        """
942
        return ''
943
944
    def get_vt_xml(self, vt_id: str) -> Element:
        """ Gets a single vulnerability test information in XML format.

        The VT's ``name`` always becomes a <name> child element. Each other
        optional field that is set (truthy) in the VT data is converted to
        an XML string by the matching ``get_*_vt_as_xml_str`` hook, parsed
        with defusedxml and appended as a child element, in this order:
        vt_params, vt_refs, vt_dependencies, creation_time,
        modification_time, summary, impact, affected, insight, solution,
        detection (also triggered by qod_type / qod), severities, custom.

        @param vt_id: ID of the VT to render.

        @return: <vt> Element (not a string) describing the VT. An empty
                 <vt/> element without an ``id`` attribute is returned when
                 vt_id is empty or self.vts has no entry for it.
        """
        if not vt_id:
            return Element('vt')

        vt = self.vts.get(vt_id)
        if vt is None:
            # Unknown VT id (self.vts.get() is expected to return None for
            # it): answer with an empty <vt/> like the empty-id case instead
            # of failing with an AttributeError on None.get() below.
            return Element('vt')

        vt_xml = Element('vt')
        vt_xml.set('id', vt_id)

        name_elem = SubElement(vt_xml, 'name')
        name_elem.text = str(vt.get('name'))

        def append_xml_str(xml_str: str) -> None:
            # The hooks return XML text; parse it safely and attach the
            # resulting element to <vt>.
            vt_xml.append(secET.fromstring(xml_str))

        if vt.get('vt_params'):
            append_xml_str(
                self.get_params_vt_as_xml_str(vt_id, vt.get('vt_params'))
            )

        if vt.get('vt_refs'):
            append_xml_str(
                self.get_refs_vt_as_xml_str(vt_id, vt.get('vt_refs'))
            )

        if vt.get('vt_dependencies'):
            append_xml_str(
                self.get_dependencies_vt_as_xml_str(
                    vt_id, vt.get('vt_dependencies')
                )
            )

        if vt.get('creation_time'):
            append_xml_str(
                self.get_creation_time_vt_as_xml_str(
                    vt_id, vt.get('creation_time')
                )
            )

        if vt.get('modification_time'):
            append_xml_str(
                self.get_modification_time_vt_as_xml_str(
                    vt_id, vt.get('modification_time')
                )
            )

        if vt.get('summary'):
            append_xml_str(
                self.get_summary_vt_as_xml_str(vt_id, vt.get('summary'))
            )

        if vt.get('impact'):
            append_xml_str(
                self.get_impact_vt_as_xml_str(vt_id, vt.get('impact'))
            )

        if vt.get('affected'):
            append_xml_str(
                self.get_affected_vt_as_xml_str(vt_id, vt.get('affected'))
            )

        if vt.get('insight'):
            append_xml_str(
                self.get_insight_vt_as_xml_str(vt_id, vt.get('insight'))
            )

        if vt.get('solution'):
            append_xml_str(
                self.get_solution_vt_as_xml_str(
                    vt_id,
                    vt.get('solution'),
                    vt.get('solution_type'),
                    vt.get('solution_method'),
                )
            )

        # Detection output is produced when any of its three parts is set.
        if vt.get('detection') or vt.get('qod_type') or vt.get('qod'):
            append_xml_str(
                self.get_detection_vt_as_xml_str(
                    vt_id, vt.get('detection'), vt.get('qod_type'), vt.get('qod')
                )
            )

        if vt.get('severities'):
            append_xml_str(
                self.get_severities_vt_as_xml_str(vt_id, vt.get('severities'))
            )

        if vt.get('custom'):
            append_xml_str(
                self.get_custom_vt_as_xml_str(vt_id, vt.get('custom'))
            )

        return vt_xml
1042
1043
    def get_vts_selection_list(
        self, vt_id: str = None, filtered_vts: Dict = None
    ) -> Iterable[str]:
        """
        Select which VT OIDs should be reported.

        Resolution order:
          - no VTs loaded, or filtered_vts given but empty (the filter
            matched nothing): an empty list;
          - non-empty filtered_vts: filtered_vts itself (it wins over vt_id);
          - vt_id given: the one-element list [vt_id] (whether it exists in
            self.vts is not checked here);
          - otherwise: an iterator over all VT OIDs in self.vts.

        Arguments:
            vt_id (str, optional): ID of a single VT to select.
            filtered_vts (Dict, optional): Filtered VTs collection.

        Return:
            Iterable of the selected VT OIDs.
        """
        if not self.vts:
            return []

        # A filter was applied but nothing matched it.
        if filtered_vts is not None and len(filtered_vts) == 0:
            return []

        if filtered_vts:
            return filtered_vts

        return [vt_id] if vt_id else iter(self.vts)
1078
1079
    def handle_command(self, command: str, stream) -> None:
        """ Handles an OSP command in a string.

        Parses ``command`` as XML with defusedxml, looks up the handler
        registered in self.commands under the root tag, and writes the
        handler's response to ``stream``: in one call if the response is
        bytes, otherwise chunk by chunk.

        @param command: Raw OSP command (XML text).
        @param stream: Object with a ``write`` method receiving the response.

        @raise OspdCommandError: If the input is not well-formed XML, or if
                                 no handler is registered for its root tag.
        """
        try:
            tree = secET.fromstring(command)
        except secET.ParseError as err:
            logger.debug("Erroneous client input: %s", command)
            raise OspdCommandError('Invalid data') from err

        handler = self.commands.get(tree.tag, None)
        if not handler:
            # "authenticate" used to be exempted from this check. With no
            # handler registered it then crashed with an AttributeError on
            # None.handle_xml(), so report it like any other unknown command.
            raise OspdCommandError('Bogus command name')

        response = handler.handle_xml(tree)

        if isinstance(response, bytes):
            stream.write(response)
        else:
            for data in response:
                stream.write(data)
1101
1102
    def check(self):
        """ Subclass hook; must be implemented by each concrete daemon.

        @raise NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()
1105
1106
    def run(self) -> None:
        """ Run the daemon's main loop until interrupted with Ctrl-C.

        Every 10 seconds the loop calls, in order, the scheduler() hook,
        clean_forgotten_scans() and wait_for_children(). A KeyboardInterrupt
        ends the loop after logging a shutdown message.
        """
        try:
            while True:
                time.sleep(10)
                for periodic_task in (
                    self.scheduler,
                    self.clean_forgotten_scans,
                    self.wait_for_children,
                ):
                    periodic_task()
        except KeyboardInterrupt:
            logger.info("Received Ctrl-C shutting-down ...")
1118
1119
    def scheduler(self):
        """ Periodic-work hook, called by run() on every loop iteration.

        Does nothing in the base class; subclasses that need to run tasks
        periodically override it.
        """
        return None
1122
1123
    def wait_for_children(self):
        """ Join every scan process with a zero timeout (non-blocking) so
        that exited processes release their resources. """
        for process in self.scan_processes.values():
            process.join(0)
1127
1128
    def create_scan(
        self,
        scan_id: str,
        targets: Dict,
        options: Optional[Dict],
        vt_selection: Dict,
    ) -> Optional[str]:
        """ Creates a new scan, or resumes a stopped one with the same ID.

        @scan_id: ID to use for the scan.
        @targets: Target(s) to scan.
        @options: Miscellaneous scan options.
        @vt_selection: VTs selected for the scan.

        @return: The value returned by scan_collection.create_scan() (the
                 new scan's ID). None, without creating anything, if a scan
                 with this ID already exists and is RUNNING or FINISHED.
        """
        scan_exists = self.scan_exists(scan_id)
        status = self.get_scan_status(scan_id) if scan_id and scan_exists else None

        if scan_exists:
            if status == ScanStatus.STOPPED:
                logger.info("Scan %s exists. Resuming scan.", scan_id)
            elif status in (ScanStatus.RUNNING, ScanStatus.FINISHED):
                logger.info(
                    "Scan %s exists with status %s.",
                    scan_id,
                    status.name.lower(),
                )
                return None

        return self.scan_collection.create_scan(
            scan_id, targets, options, vt_selection
        )
1160
1161
    def get_scan_options(self, scan_id: str) -> str:
        """ Return the options stored for scan ``scan_id``. """
        collection = self.scan_collection
        return collection.get_options(scan_id)
1164
1165
    def set_scan_option(self, scan_id: str, name: str, value: Any) -> None:
        """ Store ``value`` under option ``name`` for scan ``scan_id``,
        passing through whatever the scan collection returns. """
        collection = self.scan_collection
        return collection.set_option(scan_id, name, value)
1168
1169
    def clean_forgotten_scans(self) -> None:
        """ Delete stopped or finished scans that were never deleted and
        whose end time lies more than ``scaninfo_store_time`` hours in the
        past. Does nothing when ``scaninfo_store_time`` is not set."""
        if not self.scaninfo_store_time:
            return

        # Materialise the ids first, since delete_scan() is called while
        # walking them.
        for scan_id in list(self.scan_collection.ids_iterator()):
            end_time = int(self.get_scan_end_time(scan_id))
            status = self.get_scan_status(scan_id)

            is_done = status in (ScanStatus.STOPPED, ScanStatus.FINISHED)
            if not (is_done and end_time):
                continue

            age_in_seconds = int(time.time()) - end_time
            if age_in_seconds <= self.scaninfo_store_time * 3600:
                continue

            logger.debug(
                'Scan %s is older than %d hours and seems have been '
                'forgotten. Scan info will be deleted from the '
                'scan table',
                scan_id,
                self.scaninfo_store_time,
            )
            self.delete_scan(scan_id)
1194
1195
    def check_scan_process(self, scan_id: str) -> None:
        """ Check a scan's process and stop the scan if the process died.

        If progress is below 100 and the process is no longer alive, the
        scan is set to STOPPED and gets a "Scan process failure." error,
        unless it was already STOPPED. If progress is exactly 100, the
        process is joined without blocking.
        """
        process = self.scan_processes[scan_id]
        progress = self.get_scan_progress(scan_id)

        if progress == 100:
            process.join(0)
            return

        if not progress < 100 or process.is_alive():
            return

        if self.get_scan_status(scan_id) == ScanStatus.STOPPED:
            return

        self.set_scan_status(scan_id, ScanStatus.STOPPED)
        self.add_scan_error(
            scan_id, name="", host="", value="Scan process failure."
        )
        logger.info("%s: Scan stopped with errors.", scan_id)
1211
1212
    def get_scan_progress(self, scan_id: str):
        """ Return the current progress value of scan ``scan_id``. """
        collection = self.scan_collection
        return collection.get_progress(scan_id)
1215
1216
    def get_scan_host(self, scan_id: str) -> str:
        """ Return the host list of scan ``scan_id``'s target. """
        collection = self.scan_collection
        return collection.get_host_list(scan_id)
1219
1220
    def get_scan_ports(self, scan_id: str) -> str:
        """ Return the port list of scan ``scan_id``. """
        collection = self.scan_collection
        return collection.get_ports(scan_id)
1223
1224
    def get_scan_exclude_hosts(self, scan_id: str):
        """ Return the list of hosts excluded from scan ``scan_id``. """
        collection = self.scan_collection
        return collection.get_exclude_hosts(scan_id)
1228
1229
    def get_scan_credentials(self, scan_id: str) -> Dict:
        """ Return the credentials configured for scan ``scan_id``. """
        collection = self.scan_collection
        return collection.get_credentials(scan_id)
1233
1234
    def get_scan_target_options(self, scan_id: str) -> Dict:
        """ Return the target options dict of scan ``scan_id``. """
        collection = self.scan_collection
        return collection.get_target_options(scan_id)
1238
1239
    def get_scan_vts(self, scan_id: str) -> Dict:
        """ Return the VTs selected for scan ``scan_id``. """
        collection = self.scan_collection
        return collection.get_vts(scan_id)
1242
1243
    def get_scan_unfinished_hosts(self, scan_id: str) -> List:
        """ Return the hosts of scan ``scan_id`` not yet finished."""
        collection = self.scan_collection
        return collection.get_hosts_unfinished(scan_id)
1246
1247
    def get_scan_finished_hosts(self, scan_id: str) -> List:
        """ Return the hosts of scan ``scan_id`` that are already finished."""
        collection = self.scan_collection
        return collection.get_hosts_finished(scan_id)
1250
1251
    def get_scan_start_time(self, scan_id: str) -> str:
        """ Return the recorded start time of scan ``scan_id``. """
        collection = self.scan_collection
        return collection.get_start_time(scan_id)
1254
1255
    def get_scan_end_time(self, scan_id: str) -> str:
        """ Return the recorded end time of scan ``scan_id``. """
        collection = self.scan_collection
        return collection.get_end_time(scan_id)
1258
1259
    def add_scan_log(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
        test_id: str = '',
        qod: str = '',
    ):
        """ Record a LOG result for scan ``scan_id``.

        Log results are always stored with a severity of '0.0'.
        """
        result_fields = (host, hostname, name, value, port, test_id, '0.0', qod)
        self.scan_collection.add_result(scan_id, ResultType.LOG, *result_fields)
1283
1284
    def add_scan_error(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
    ) -> None:
        """ Record an ERROR result for scan ``scan_id``. """
        result_fields = (host, hostname, name, value, port)
        self.scan_collection.add_result(
            scan_id, ResultType.ERROR, *result_fields
        )
1297
1298
    def add_scan_host_detail(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
    ) -> None:
        """ Record a HOST_DETAIL result for scan ``scan_id``. """
        result_fields = (host, hostname, name, value)
        self.scan_collection.add_result(
            scan_id, ResultType.HOST_DETAIL, *result_fields
        )
1310
1311
    def add_scan_alarm(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
        test_id: str = '',
        severity: str = '',
        qod: str = '',
    ):
        """ Record an ALARM result, with its severity and QoD, for scan
        ``scan_id``. """
        result_fields = (
            host, hostname, name, value, port, test_id, severity, qod
        )
        self.scan_collection.add_result(
            scan_id, ResultType.ALARM, *result_fields
        )
1336