Passed
Pull Request — master (#310)
by
unknown
01:59
created

ospd.ospd   F

Complexity

Total Complexity 236

Size/Duplication

Total Lines 1581
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 913
dl 0
loc 1581
rs 1.687
c 0
b 0
f 0
wmc 236

95 Methods

Rating   Name   Duplication   Size   Complexity  
A OSPDaemon.get_xml_str() 0 12 1
A OSPDaemon.elements_as_text() 0 4 1
A OSPDaemon.scan_exists() 0 6 1
A OSPDaemon.get_scan_status() 0 3 1
A OSPDaemon.delete_scan() 0 20 5
A OSPDaemon.get_help_text() 0 26 5
A OSPDaemon.get_scan_results_xml() 0 15 2
A OSPDaemon.set_scan_status() 0 3 1
A OSPDaemon.process_vts_params() 0 7 1
A OSPDaemon.process_targets_element() 0 7 1
A OSPDaemon.process_credentials_elements() 0 7 1
A OSPDaemon.process_scan_params() 0 4 1
A OSPDaemon.finish_scan() 0 5 1
A OSPDaemon.target_is_finished() 0 3 1
A OSPDaemon.exec_scan() 0 3 1
A OSPDaemon.stop_scan_cleanup() 0 3 1
A OSPDaemon.get_scanner_name() 0 3 1
A OSPDaemon.get_vts_version() 0 4 1
A OSPDaemon.get_protocol_version() 0 3 1
A OSPDaemon.init() 0 8 1
A OSPDaemon.get_scanner_description() 0 3 1
A OSPDaemon.command_exists() 0 3 1
A OSPDaemon.get_scanner_params() 0 2 1
A OSPDaemon.set_vts_version() 0 12 2
B OSPDaemon.__init__() 0 55 4
A OSPDaemon.add_scanner_param() 0 4 1
A OSPDaemon.set_command_attributes() 0 5 2
A OSPDaemon.get_server_version() 0 4 1
B OSPDaemon.add_vt() 0 51 1
A OSPDaemon.set_scanner_param() 0 7 1
A OSPDaemon.get_scanner_version() 0 3 1
B OSPDaemon.clean_forgotten_scans() 0 26 8
A OSPDaemon.get_scan_options() 0 3 1
A OSPDaemon.create_scan() 0 26 3
A OSPDaemon.get_scan_host() 0 3 1
A OSPDaemon.get_modification_time_vt_as_xml_str() 0 15 1
A OSPDaemon.get_custom_vt_as_xml_str() 0 15 1
A OSPDaemon.add_scan_error() 0 22 1
A OSPDaemon.get_scan_start_time() 0 3 1
A OSPDaemon.get_params_vt_as_xml_str() 0 15 1
A OSPDaemon.add_scan_alarm() 0 26 1
A OSPDaemon.get_count_queued_scans() 0 7 3
A OSPDaemon.dry_run_scan() 0 16 2
A OSPDaemon.get_scan_vts() 0 3 1
A OSPDaemon.get_scan_progress() 0 3 1
A OSPDaemon.process_finished_hosts() 0 9 2
A OSPDaemon.get_severities_vt_as_xml_str() 0 15 1
A OSPDaemon.get_scan_target_options() 0 4 1
A OSPDaemon.get_affected_vt_as_xml_str() 0 15 1
A OSPDaemon.get_scanner_param_mandatory() 0 7 2
B OSPDaemon.get_scan_xml() 0 50 6
A OSPDaemon.is_new_scan_allowed() 0 17 3
A OSPDaemon.get_detection_vt_as_xml_str() 0 15 1
A OSPDaemon.check() 0 3 1
A OSPDaemon._get_scan_progress_xml() 0 25 1
A OSPDaemon.set_scan_option() 0 3 1
A OSPDaemon.get_insight_vt_as_xml_str() 0 15 1
A OSPDaemon.get_scanner_param_default() 0 7 2
A OSPDaemon.scheduler() 0 2 1
A OSPDaemon.get_summary_vt_as_xml_str() 0 15 1
A OSPDaemon.get_dependencies_vt_as_xml_str() 0 15 1
A OSPDaemon.wait_for_children() 0 4 2
A OSPDaemon.get_daemon_name() 0 3 1
A OSPDaemon.get_impact_vt_as_xml_str() 0 15 1
A OSPDaemon.get_scanner_param_type() 0 7 2
A OSPDaemon.is_enough_free_memory() 0 24 3
A OSPDaemon.get_vt_iterator() 0 6 1
A OSPDaemon.add_scan_log() 0 26 1
A OSPDaemon.get_refs_vt_as_xml_str() 0 15 1
B OSPDaemon.start_scan() 0 28 6
D OSPDaemon.preprocess_scan_params() 0 38 12
A OSPDaemon.get_scan_ports() 0 3 1
B OSPDaemon.check_scan_process() 0 24 7
F OSPDaemon.get_vt_xml() 0 98 19
A OSPDaemon.set_scan_host_progress() 0 18 5
A OSPDaemon.get_vts_selection_list() 0 33 5
A OSPDaemon.get_solution_vt_as_xml_str() 0 15 1
A OSPDaemon.get_scan_exclude_hosts() 0 4 1
C OSPDaemon.daemon_exit_cleanup() 0 30 11
A OSPDaemon.get_scan_end_time() 0 3 1
A OSPDaemon.set_scan_progress_batch() 0 5 1
A OSPDaemon.add_scan_host_detail() 0 12 1
A OSPDaemon.get_scan_credentials() 0 4 1
A OSPDaemon.set_scan_progress() 0 5 1
A OSPDaemon.get_creation_time_vt_as_xml_str() 0 15 1
B OSPDaemon.stop_scan() 0 47 8
A OSPDaemon.get_scanner_params_xml() 0 7 1
A OSPDaemon.interrupt_scan() 0 4 1
C OSPDaemon.handle_client_stream() 0 45 11
A OSPDaemon.handle_timeout() 0 7 1
A OSPDaemon.sort_host_finished() 0 27 5
C OSPDaemon.start_queued_scans() 0 40 9
D OSPDaemon.handle_command() 0 42 12
A OSPDaemon.get_daemon_version() 0 3 1
A OSPDaemon.run() 0 13 3

1 Function

Rating   Name   Duplication   Size   Complexity  
A _terminate_process_group() 0 2 1

How to fix   Complexity   

Complexity

Complex classes like ospd.ospd often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
# Copyright (C) 2014-2020 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: AGPL-3.0-or-later
4
#
5
# This program is free software: you can redistribute it and/or modify
6
# it under the terms of the GNU Affero General Public License as
7
# published by the Free Software Foundation, either version 3 of the
8
# License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU Affero General Public License for more details.
14
#
15
# You should have received a copy of the GNU Affero General Public License
16
# along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
# pylint: disable=too-many-lines
19
20
""" OSP Daemon core class.
21
"""
22
23
import logging
24
import socket
25
import ssl
26
import multiprocessing
27
import time
28
import os
29
30
from typing import (
31
    List,
32
    Any,
33
    Iterator,
34
    Dict,
35
    Optional,
36
    Iterable,
37
    Tuple,
38
    Union,
39
)
40
from xml.etree.ElementTree import Element, SubElement
41
42
import defusedxml.ElementTree as secET
43
44
import psutil
45
46
from deprecated import deprecated
47
48
from ospd import __version__
49
from ospd.command import get_commands
50
from ospd.errors import OspdCommandError
51
from ospd.misc import ResultType, create_process
52
from ospd.network import resolve_hostname, target_str_to_list
53
from ospd.protocol import OspRequest, OspResponse, RequestParser
54
from ospd.scan import ScanCollection, ScanStatus, ScanProgress
55
from ospd.server import BaseServer, Stream
56
from ospd.vtfilter import VtsFilter
57
from ospd.vts import Vts
58
from ospd.xml import (
59
    elements_as_text,
60
    get_result_xml,
61
    get_progress_xml,
62
    get_elements_from_dict,
63
)
64
65
logger = logging.getLogger(__name__)

# Version of the OSP protocol implemented by this daemon.
PROTOCOL_VERSION = "1.2"

# Scheduler check period.
SCHEDULER_CHECK_PERIOD = 10  # in seconds

# Scanner parameters every daemon starts with; OSPDaemon.__init__ registers
# each entry through set_scanner_param().
BASE_SCANNER_PARAMS = dict(
    debug_mode=dict(
        type='boolean',
        name='Debug Mode',
        default=0,
        mandatory=0,
        description='Whether to get extra scan debug information.',
    ),
    dry_run=dict(
        type='boolean',
        name='Dry Run',
        default=0,
        mandatory=0,
        description='Whether to dry run scan.',
    ),
)  # type: Dict
87
88
89
def _terminate_process_group(process: multiprocessing.Process) -> None:
90
    os.killpg(os.getpgid(process.pid), 15)
91
92
93
class OSPDaemon:
94
95
    """ Daemon class for OSP traffic handling.
96
97
    Every scanner wrapper should subclass it and make necessary additions and
98
    changes.
99
100
    * Add any needed parameters in __init__.
101
    * Implement check() method which verifies scanner availability and other
102
      environment related conditions.
103
    * Implement process_scan_params and exec_scan methods which are
104
      specific to handling the <start_scan> command, executing the wrapped
105
      scanner and storing the results.
106
    * Implement other methods that assert to False such as get_scanner_name,
107
      get_scanner_version.
108
    * Use Call set_command_attributes at init time to add scanner command
109
      specific options eg. the w3af profile for w3af wrapper.
110
    """
111
112
    def __init__(
        self,
        *,
        customvtfilter=None,
        storage=None,
        max_scans=0,
        min_free_mem_scan_queue=0,
        file_storage_dir='/var/run/ospd',
        max_queued_scans=0,
        **kwargs
    ):  # pylint: disable=unused-argument
        """ Set up the daemon's internal state.

        Keyword-only arguments:
            customvtfilter: VT filter object to use; when falsy a default
                VtsFilter() is created instead.
            storage: Handed to the Vts collection constructor.
            max_scans: Stored as self.max_scans.
            min_free_mem_scan_queue: Stored as self.min_free_mem_scan_queue.
            file_storage_dir: Directory handed to ScanCollection.
            max_queued_scans: Stored as self.max_queued_scans.
            **kwargs: Only 'scaninfo_store_time' is read; anything else is
                ignored.
        """
        self.scan_collection = ScanCollection(file_storage_dir)
        self.scan_processes = {}

        self.daemon_info = {
            'name': "OSPd",
            'version': __version__,
            'description': "No description",
        }

        self.scanner_info = {
            'name': 'No name',
            'version': 'No version',
            'description': 'No description',
        }

        # Expected to be filled in by the subclass.
        self.server_version = None

        # Becomes True once init() has completed.
        self.initialized = None

        self.max_scans = max_scans
        self.min_free_mem_scan_queue = min_free_mem_scan_queue
        self.max_queued_scans = max_queued_scans

        self.scaninfo_store_time = kwargs.get('scaninfo_store_time')

        self.protocol_version = PROTOCOL_VERSION

        # OSP command objects, keyed by command name.
        self.commands = {}
        for cmd_cls in get_commands():
            cmd = cmd_cls(self)
            self.commands[cmd.get_name()] = cmd

        self.scanner_params = {}
        for param_name, param_def in BASE_SCANNER_PARAMS.items():
            self.set_scanner_param(param_name, param_def)

        self.vts = Vts(storage)
        self.vts_version = None

        self.vts_filter = customvtfilter if customvtfilter else VtsFilter()
167
168
    def init(self, server: BaseServer) -> None:
169
        """ Should be overridden by a subclass if the initialization is costly.
170
171
            Will be called after check.
172
        """
173
        self.scan_collection.init()
174
        server.start(self.handle_client_stream)
175
        self.initialized = True
176
177
    def set_command_attributes(self, name: str, attributes: Dict) -> None:
178
        """ Sets the xml attributes of a specified command. """
179
        if self.command_exists(name):
180
            command = self.commands.get(name)
181
            command.attributes = attributes
182
183
    @deprecated(version="20.8", reason="Use set_scanner_param instead")
184
    def add_scanner_param(self, name: str, scanner_params: Dict) -> None:
185
        """ Set a scanner parameter. """
186
        self.set_scanner_param(name, scanner_params)
187
188
    def set_scanner_param(self, name: str, scanner_params: Dict) -> None:
189
        """ Set a scanner parameter. """
190
191
        assert name
192
        assert scanner_params
193
194
        self.scanner_params[name] = scanner_params
195
196
    def get_scanner_params(self) -> Dict:
197
        return self.scanner_params
198
199
    def add_vt(
200
        self,
201
        vt_id: str,
202
        name: str = None,
203
        vt_params: str = None,
204
        vt_refs: str = None,
205
        custom: str = None,
206
        vt_creation_time: str = None,
207
        vt_modification_time: str = None,
208
        vt_dependencies: str = None,
209
        summary: str = None,
210
        impact: str = None,
211
        affected: str = None,
212
        insight: str = None,
213
        solution: str = None,
214
        solution_t: str = None,
215
        solution_m: str = None,
216
        detection: str = None,
217
        qod_t: str = None,
218
        qod_v: str = None,
219
        severities: str = None,
220
    ) -> None:
221
        """ Add a vulnerability test information.
222
223
        IMPORTANT: The VT's Data Manager will store the vts collection.
224
        If the collection is considerably big and it will be consultated
225
        intensible during a routine, consider to do a deepcopy(), since
226
        accessing the shared memory in the data manager is very expensive.
227
        At the end of the routine, the temporal copy must be set to None
228
        and deleted.
229
        """
230
        self.vts.add(
231
            vt_id,
232
            name=name,
233
            vt_params=vt_params,
234
            vt_refs=vt_refs,
235
            custom=custom,
236
            vt_creation_time=vt_creation_time,
237
            vt_modification_time=vt_modification_time,
238
            vt_dependencies=vt_dependencies,
239
            summary=summary,
240
            impact=impact,
241
            affected=affected,
242
            insight=insight,
243
            solution=solution,
244
            solution_t=solution_t,
245
            solution_m=solution_m,
246
            detection=detection,
247
            qod_t=qod_t,
248
            qod_v=qod_v,
249
            severities=severities,
250
        )
251
252
    def set_vts_version(self, vts_version: str) -> None:
253
        """ Add into the vts dictionary an entry to identify the
254
        vts version.
255
256
        Parameters:
257
            vts_version (str): Identifies a unique vts version.
258
        """
259
        if not vts_version:
260
            raise OspdCommandError(
261
                'A vts_version parameter is required', 'set_vts_version'
262
            )
263
        self.vts_version = vts_version
264
265
    def get_vts_version(self) -> Optional[str]:
266
        """Return the vts version.
267
        """
268
        return self.vts_version
269
270
    def command_exists(self, name: str) -> bool:
271
        """ Checks if a commands exists. """
272
        return name in self.commands
273
274
    def get_scanner_name(self) -> str:
275
        """ Gives the wrapped scanner's name. """
276
        return self.scanner_info['name']
277
278
    def get_scanner_version(self) -> str:
279
        """ Gives the wrapped scanner's version. """
280
        return self.scanner_info['version']
281
282
    def get_scanner_description(self) -> str:
283
        """ Gives the wrapped scanner's description. """
284
        return self.scanner_info['description']
285
286
    def get_server_version(self) -> str:
287
        """ Gives the specific OSP server's version. """
288
        assert self.server_version
289
        return self.server_version
290
291
    def get_protocol_version(self) -> str:
292
        """ Gives the OSP's version. """
293
        return self.protocol_version
294
295
    def preprocess_scan_params(self, xml_params):
296
        """ Processes the scan parameters. """
297
        params = {}
298
299
        for param in xml_params:
300
            params[param.tag] = param.text or ''
301
302
        # Validate values.
303
        for key in params:
304
            param_type = self.get_scanner_param_type(key)
305
            if not param_type:
306
                continue
307
308
            if param_type in ['integer', 'boolean']:
309
                try:
310
                    params[key] = int(params[key])
311
                except ValueError:
312
                    raise OspdCommandError(
313
                        'Invalid %s value' % key, 'start_scan'
314
                    )
315
316
            if param_type == 'boolean':
317
                if params[key] not in [0, 1]:
318
                    raise OspdCommandError(
319
                        'Invalid %s value' % key, 'start_scan'
320
                    )
321
            elif param_type == 'selection':
322
                selection = self.get_scanner_param_default(key).split('|')
323
                if params[key] not in selection:
324
                    raise OspdCommandError(
325
                        'Invalid %s value' % key, 'start_scan'
326
                    )
327
            if self.get_scanner_param_mandatory(key) and params[key] == '':
328
                raise OspdCommandError(
329
                    'Mandatory %s value is missing' % key, 'start_scan'
330
                )
331
332
        return params
333
334
    def process_scan_params(self, params: Dict) -> Dict:
335
        """ This method is to be overridden by the child classes if necessary
336
        """
337
        return params
338
339
    @staticmethod
    @deprecated(
        version="20.8",
        reason="Please use OspRequest.process_vts_params instead.",
    )
    def process_vts_params(scanner_vts) -> Dict:
        """ Deprecated: delegates to OspRequest.process_vts_params().

        The deprecation message now names the method that is actually
        called; it used to point to a non-existent 'process_vt_params'.
        """
        return OspRequest.process_vts_params(scanner_vts)
346
347
    @staticmethod
    @deprecated(
        version="20.8",
        reason="Please use OspRequest.process_credentials_elements instead.",
    )
    def process_credentials_elements(cred_tree) -> Dict:
        """ Deprecated: delegates to OspRequest.process_credentials_elements().

        The deprecation message now names the method that is actually
        called; it used to point to 'process_credential_elements'.
        """
        return OspRequest.process_credentials_elements(cred_tree)
354
355
    @staticmethod
    @deprecated(
        version="20.8",
        reason="Please use OspRequest.process_target_element instead.",
    )
    def process_targets_element(scanner_target) -> List:
        """ Deprecated: delegates to OspRequest.process_target_element().

        The deprecation message now names the method that is actually
        called; it used to point to 'process_targets_elements'.
        """
        return OspRequest.process_target_element(scanner_target)
362
363
    def stop_scan(self, scan_id: str) -> None:
        """ Stop the scan identified by *scan_id*.

        A scan that is still QUEUED is only dequeued: its pickled scan info
        is removed and its status set to STOPPED. Otherwise the status is
        set to STOPPED, stop_scan_cleanup() is called, and the scan process
        and its process group are sent a termination request.

        Raises:
            OspdCommandError: if no process is known for the scan, or the
                process is no longer alive.
        """
        queued = (
            scan_id in self.scan_collection.ids_iterator()
            and self.get_scan_status(scan_id) == ScanStatus.QUEUED
        )
        if queued:
            logger.info('Scan %s has been removed from the queue.', scan_id)
            self.scan_collection.remove_file_pickled_scan_info(scan_id)
            self.set_scan_status(scan_id, ScanStatus.STOPPED)
            return

        process = self.scan_processes.get(scan_id)
        if not process:
            raise OspdCommandError(
                'Scan not found {0}.'.format(scan_id), 'stop_scan'
            )
        if not process.is_alive():
            raise OspdCommandError(
                'Scan already stopped or finished.', 'stop_scan'
            )

        self.set_scan_status(scan_id, ScanStatus.STOPPED)
        logger.info(
            '%s: Stopping Scan with the PID %s.', scan_id, process.ident
        )

        # Subclass hook, run before the process is killed.
        self.stop_scan_cleanup(scan_id)

        try:
            process.terminate()
        except AttributeError:
            logger.debug('%s: The scanner task stopped unexpectedly.', scan_id)

        # start_scan() calls os.setsid(), so the scan is expected to lead its
        # own process group; signal the whole group as well.
        try:
            _terminate_process_group(process)
        except ProcessLookupError:
            logger.info(
                '%s: Scan with the PID %s is already stopped.',
                scan_id,
                process.pid,
            )

        # join(0) reaps without blocking; never join the current process.
        if process.ident != os.getpid():
            process.join(0)

        logger.info('%s: Scan stopped.', scan_id)
410
411
    @staticmethod
412
    def stop_scan_cleanup(scan_id: str):
413
        """ Should be implemented by subclass in case of a clean up before
414
        terminating is needed. """
415
416
    @staticmethod
417
    def target_is_finished(scan_id: str):
418
        """ Should be implemented by subclass in case of a check before
419
        stopping is needed. """
420
421
    def exec_scan(self, scan_id: str):
422
        """ Asserts to False. Should be implemented by subclass. """
423
        raise NotImplementedError
424
425
    def finish_scan(self, scan_id: str) -> None:
        """ Mark scan *scan_id* as finished.

        Sets its progress to ScanProgress.FINISHED and its status to
        ScanStatus.FINISHED, then logs the event.
        """
        finished_progress = ScanProgress.FINISHED.value
        self.scan_collection.set_progress(scan_id, finished_progress)
        self.set_scan_status(scan_id, ScanStatus.FINISHED)
        logger.info("%s: Scan finished.", scan_id)
430
431
    def interrupt_scan(self, scan_id: str) -> None:
        """ Mark scan *scan_id* as INTERRUPTED and log it. """
        new_status = ScanStatus.INTERRUPTED
        self.set_scan_status(scan_id, new_status)
        logger.info("%s: Scan interrupted.", scan_id)
435
436
    def daemon_exit_cleanup(self) -> None:
        """ Clean up before the daemon exits.

        Removes the pickled scan info, asks every scan that is not STOPPED,
        FINISHED or INTERRUPTED to stop, and then polls once per second
        until every scan is in one of those states.

        NOTE: there is no timeout; this blocks until all scans have stopped.
        """
        settled_states = (
            ScanStatus.STOPPED,
            ScanStatus.FINISHED,
            ScanStatus.INTERRUPTED,
        )

        self.scan_collection.clean_up_pickled_scan_info()

        for scan_id in self.scan_collection.ids_iterator():
            if self.get_scan_status(scan_id) not in settled_states:
                self.stop_scan(scan_id)

        while True:
            still_running = [
                scan_id
                for scan_id in self.scan_collection.ids_iterator()
                if self.get_scan_status(scan_id) not in settled_states
            ]
            if not still_running:
                return
            time.sleep(1)
466
467
    def get_daemon_name(self) -> str:
468
        """ Gives osp daemon's name. """
469
        return self.daemon_info['name']
470
471
    def get_daemon_version(self) -> str:
472
        """ Gives osp daemon's version. """
473
        return self.daemon_info['version']
474
475
    def get_scanner_param_type(self, param: str):
476
        """ Returns type of a scanner parameter. """
477
        assert isinstance(param, str)
478
        entry = self.scanner_params.get(param)
479
        if not entry:
480
            return None
481
        return entry.get('type')
482
483
    def get_scanner_param_mandatory(self, param: str):
484
        """ Returns if a scanner parameter is mandatory. """
485
        assert isinstance(param, str)
486
        entry = self.scanner_params.get(param)
487
        if not entry:
488
            return False
489
        return entry.get('mandatory')
490
491
    def get_scanner_param_default(self, param: str):
492
        """ Returns default value of a scanner parameter. """
493
        assert isinstance(param, str)
494
        entry = self.scanner_params.get(param)
495
        if not entry:
496
            return None
497
        return entry.get('default')
498
499
    @deprecated(
        version="20.8",
        reason="Please use OspResponse.create_scanner_params_xml instead.",
    )
    def get_scanner_params_xml(self):
        """ Deprecated: build the XML of the daemon's scanner params with
        OspResponse.create_scanner_params_xml(). """
        params = self.scanner_params
        return OspResponse.create_scanner_params_xml(params)
506
507
    def handle_client_stream(self, stream: Stream) -> None:
        """ Read one request from a client stream, handle it and answer.

        Data is read from *stream* until it is exhausted or the
        RequestParser reports that the request has ended. SSL errors and
        socket timeouts end reading early; AttributeError/ValueError abort
        the request. The collected data is passed to handle_command(); if
        that raises, an XML error response is written back to the client.

        The stream is closed on every path, including aborted and empty
        requests.
        """
        try:
            data = b''
            request_parser = RequestParser()

            while True:
                try:
                    buf = stream.read()
                    if not buf:
                        break

                    data += buf

                    if request_parser.has_ended(buf):
                        break
                except (AttributeError, ValueError) as message:
                    logger.error(message)
                    return
                except ssl.SSLError as exception:
                    logger.debug('Error: %s', exception)
                    break
                except socket.timeout as exception:
                    logger.debug('Request timeout: %s', exception)
                    break

            if not data:
                logger.debug("Empty client stream")
                return

            response = None
            try:
                self.handle_command(data, stream)
            except OspdCommandError as exception:
                response = exception.as_xml()
                logger.debug('Command error: %s', exception.message)
            except Exception:  # pylint: disable=broad-except
                logger.exception('While handling client command:')
                exception = OspdCommandError('Fatal error', 'error')
                response = exception.as_xml()

            if response:
                stream.write(response)
        finally:
            # Always release the connection. A stream that is already broken
            # must not turn this cleanup into a new error.
            try:
                stream.close()
            except (AttributeError, OSError) as exception:
                logger.debug(
                    'Error while closing client stream: %s', exception
                )
552
553
    def process_finished_hosts(self, scan_id: str) -> None:
554
        """ Process the finished hosts before launching the scans."""
555
556
        finished_hosts = self.scan_collection.get_finished_hosts(scan_id)
557
        if not finished_hosts:
558
            return
559
560
        exc_finished_hosts_list = target_str_to_list(finished_hosts)
561
        self.scan_collection.set_host_finished(scan_id, exc_finished_hosts_list)
562
563
    def start_scan(self, scan_id: str) -> None:
        """ Run the scan identified by *scan_id*.

        Starts a new session (os.setsid(), so presumably this runs in a
        child process), marks already finished hosts, sets the status to
        RUNNING and calls exec_scan(). An exception raised while setting
        RUNNING or by exec_scan() is recorded as a scan error. Afterwards,
        unless the scan was stopped meanwhile, it is marked FINISHED when
        its progress reached FINISHED and INTERRUPTED otherwise.
        """
        os.setsid()

        self.process_finished_hosts(scan_id)

        try:
            self.set_scan_status(scan_id, ScanStatus.RUNNING)
            self.exec_scan(scan_id)
        except Exception as err:  # pylint: disable=broad-except
            self.add_scan_error(
                scan_id,
                name='',
                host=self.get_scan_host(scan_id),
                value='Host process failure (%s).' % err,
            )
            logger.exception('While scanning: %s', scan_id)
        else:
            logger.info("%s: Host scan finished.", scan_id)

        stopped = self.get_scan_status(scan_id) == ScanStatus.STOPPED
        self.set_scan_progress(scan_id)
        progress = self.get_scan_progress(scan_id)

        if stopped:
            return
        if progress == ScanProgress.FINISHED:
            self.finish_scan(scan_id)
        else:
            self.interrupt_scan(scan_id)
591
592
    def dry_run_scan(self, scan_id: str, target: Dict) -> None:
        """ Dry run a scan without calling exec_scan().

        Starts a new session, resolves target['hosts'], logs one
        'Dry run result' entry for the resolved host (None when resolution
        failed) and marks the scan as finished.
        """
        os.setsid()

        resolved = resolve_hostname(target.get('hosts'))
        if resolved is None:
            logger.info("Couldn't resolve %s.", self.get_scan_host(scan_id))

        ports = self.get_scan_ports(scan_id)
        logger.info("%s:%s: Dry run mode.", resolved, ports)

        self.add_scan_log(
            scan_id, name='', host=resolved, value='Dry run result'
        )
        self.finish_scan(scan_id)
608
609
    def handle_timeout(self, scan_id: str, host: str) -> None:
610
        """ Handles scanner reaching timeout error. """
611
        self.add_scan_error(
612
            scan_id,
613
            host=host,
614
            name="Timeout",
615
            value="{0} exec timeout.".format(self.get_scanner_name()),
616
        )
617
618
    def sort_host_finished(
        self, scan_id: str, finished_hosts: Union[List[str], str],
    ) -> None:
        """ Split finished hosts into alive and dead ones and update the
        scan collection.

        Hosts whose current progress is FINISHED are recorded as finished,
        and hosts whose progress is DEAD_HOST as dead. All given hosts are
        then removed from the current target progress. A single host may
        be passed as a plain string.
        """
        if isinstance(finished_hosts, str):
            finished_hosts = [finished_hosts]

        current_hosts = self.scan_collection.get_current_target_progress(
            scan_id
        )

        alive, dead = [], []
        for host in finished_hosts:
            state = current_hosts.get(host)
            if state == ScanProgress.FINISHED:
                alive.append(host)
            elif state == ScanProgress.DEAD_HOST:
                dead.append(host)

        self.scan_collection.set_host_dead(scan_id, dead)
        self.scan_collection.set_host_finished(scan_id, alive)
        self.scan_collection.remove_hosts_from_target_progress(
            scan_id, finished_hosts
        )
646
647
    def set_scan_progress(self, scan_id: str):
648
        """ Calculate the target progress with the current host states
649
        and stores in the scan table. """
650
        scan_progress = self.scan_collection.calculate_target_progress(scan_id)
651
        self.scan_collection.set_progress(scan_id, scan_progress)
652
653
    def set_scan_progress_batch(
654
        self, scan_id: str, host_progress: Dict[str, int]
655
    ):
656
        self.scan_collection.set_host_progress(scan_id, host_progress)
657
        self.set_scan_progress(scan_id)
658
659
    def set_scan_host_progress(
660
        self, scan_id: str, host: str = None, progress: int = None,
661
    ) -> None:
662
        """ Sets host's progress which is part of target.
663
        Each time a host progress is updated, the scan progress
664
        is updated too.
665
        """
666
        if host is None or progress is None:
667
            return
668
669
        if not isinstance(progress, int):
670
            try:
671
                progress = int(progress)
672
            except (TypeError, ValueError):
673
                return
674
675
        host_progress = {host: progress}
676
        self.set_scan_progress_batch(scan_id, host_progress)
677
678
    def set_scan_status(self, scan_id: str, status: ScanStatus) -> None:
679
        """ Set the scan's status."""
680
        self.scan_collection.set_status(scan_id, status)
681
682
    def get_scan_status(self, scan_id: str) -> ScanStatus:
683
        """ Get scan_id scans's status."""
684
        return self.scan_collection.get_status(scan_id)
685
686
    def scan_exists(self, scan_id: str) -> bool:
687
        """ Checks if a scan with ID scan_id is in collection.
688
689
        @return: 1 if scan exists, 0 otherwise.
690
        """
691
        return self.scan_collection.id_exists(scan_id)
692
693
    def get_help_text(self) -> str:
694
        """ Returns the help output in plain text format."""
695
696
        txt = ''
697
        for name, info in self.commands.items():
698
            description = info.get_description()
699
            attributes = info.get_attributes()
700
            elements = info.get_elements()
701
702
            command_txt = "\t{0: <22} {1}\n".format(name, description)
703
704
            if attributes:
705
                command_txt = ''.join([command_txt, "\t Attributes:\n"])
706
707
                for attrname, attrdesc in attributes.items():
708
                    attr_txt = "\t  {0: <22} {1}\n".format(attrname, attrdesc)
709
                    command_txt = ''.join([command_txt, attr_txt])
710
711
            if elements:
712
                command_txt = ''.join(
713
                    [command_txt, "\t Elements:\n", elements_as_text(elements),]
714
                )
715
716
            txt += command_txt
717
718
        return txt
719
720
    @deprecated(version="20.8", reason="Use ospd.xml.elements_as_text instead.")
    def elements_as_text(self, elems: Dict, indent: int = 2) -> str:
        """ Deprecated: format *elems* as indented plain text by delegating
        to ospd.xml.elements_as_text(). """
        text = elements_as_text(elems, indent)
        return text
724
725
    def delete_scan(self, scan_id: str) -> int:
        """ Delete scan *scan_id* from the collection.

        A RUNNING scan is never deleted. If a process was started for the
        scan, this first waits for it to exit (join()) and then forgets it.
        The removal itself is delegated to ScanCollection.delete_scan().

        @return: 1 if scan deleted, 0 otherwise.
        """
        if self.get_scan_status(scan_id) == ScanStatus.RUNNING:
            return 0

        # Don't delete the scan until the process stops.
        if scan_id not in self.scan_processes:
            logger.debug('Scan process for %s never started,', scan_id)
        else:
            scan_process = self.scan_processes[scan_id]
            scan_process.join()
            # exitcode is None only while the process has not terminated;
            # 0 is a valid exit code, hence the explicit None check.
            if scan_process.exitcode is not None:
                del self.scan_processes[scan_id]

        return self.scan_collection.delete_scan(scan_id)
745
746
    def get_scan_results_xml(
        self, scan_id: str, pop_res: bool, max_res: Optional[int]
    ):
        """ Gets scan_id scan's results in XML format.

        Every result produced by the scan collection's results_iterator()
        (pop_res and max_res are passed straight through) is converted with
        get_result_xml() and appended to one <results> element.

        @return: The <results> Element holding all converted results.
        """
        results_xml = Element('results')
        pending = self.scan_collection.results_iterator(
            scan_id, pop_res, max_res
        )
        for single_result in pending:
            results_xml.append(get_result_xml(single_result))

        logger.debug('Returning %d results', len(results_xml))
        return results_xml
761
762
    def _get_scan_progress_xml(self, scan_id: str):
        """ Gets scan_id scan's progress in XML format.

        Gathers the current target progress, the overall progress and the
        alive, dead, excluded and total host counts of the scan into one
        dictionary and hands it to get_progress_xml().

        @return: The value built by get_progress_xml() from that dictionary.
        """
        collection = self.scan_collection
        progress_info = {
            'current_hosts': collection.get_current_target_progress(scan_id),
            'overall': self.get_scan_progress(scan_id),
            'count_alive': collection.get_count_alive(scan_id),
            'count_dead': collection.get_count_dead(scan_id),
            'count_excluded': collection.simplify_exclude_host_count(scan_id),
            'count_total': collection.get_host_count(scan_id),
        }
        return get_progress_xml(progress_info)
787
788
    @deprecated(
        version="20.8",
        reason="Please use ospd.xml.get_elements_from_dict instead.",
    )
    def get_xml_str(self, data: Dict) -> List:
        """ Deprecated wrapper around get_elements_from_dict().

        @param: Dictionary of xml tags and their elements.

        @return: Exactly what get_elements_from_dict(data) returns.
        """
        elements = get_elements_from_dict(data)
        return elements
800
801
    def get_scan_xml(
        self,
        scan_id: str,
        detailed: bool = True,
        pop_res: bool = False,
        max_res: int = 0,
        progress: bool = False,
    ):
        """ Gets scan in XML format.

        The returned <scan> element carries id, target, progress, status,
        start_time and end_time attributes. If detailed, the results from
        get_scan_results_xml() (pop_res and max_res passed through) are
        appended; if progress, the detailed progress element is appended.
        A queued scan reports empty/zero values plus an empty <results>
        child, and ignores detailed and progress.

        @return: The <scan> Element (without attributes if scan_id is empty).
        """
        scan_xml = Element('scan')
        if not scan_id:
            return scan_xml

        status = self.get_scan_status(scan_id)
        if status == ScanStatus.QUEUED:
            target, scan_progress, start_time, end_time = '', 0, 0, 0
            scan_xml.append(Element('results'))
            detailed = progress = False
        else:
            target = self.get_scan_host(scan_id)
            scan_progress = self.get_scan_progress(scan_id)
            start_time = self.get_scan_start_time(scan_id)
            end_time = self.get_scan_end_time(scan_id)

        attributes = (
            ('id', scan_id),
            ('target', target),
            ('progress', scan_progress),
            ('status', status.name.lower()),
            ('start_time', start_time),
            ('end_time', end_time),
        )
        for attr_name, attr_value in attributes:
            scan_xml.set(attr_name, str(attr_value))

        if detailed:
            scan_xml.append(
                self.get_scan_results_xml(scan_id, pop_res, max_res)
            )
        if progress:
            scan_xml.append(self._get_scan_progress_xml(scan_id))

        return scan_xml
851
852
    @staticmethod
853
    def get_custom_vt_as_xml_str(  # pylint: disable=unused-argument
854
        vt_id: str, custom: Dict
855
    ) -> str:
856
        """ Create a string representation of the XML object from the
857
        custom data object.
858
        This needs to be implemented by each ospd wrapper, in case
859
        custom elements for VTs are used.
860
861
        The custom XML object which is returned will be embedded
862
        into a <custom></custom> element.
863
864
        @return: XML object as string for custom data.
865
        """
866
        return ''
867
868
    @staticmethod
869
    def get_params_vt_as_xml_str(  # pylint: disable=unused-argument
870
        vt_id: str, vt_params
871
    ) -> str:
872
        """ Create a string representation of the XML object from the
873
        vt_params data object.
874
        This needs to be implemented by each ospd wrapper, in case
875
        vt_params elements for VTs are used.
876
877
        The params XML object which is returned will be embedded
878
        into a <params></params> element.
879
880
        @return: XML object as string for vt parameters data.
881
        """
882
        return ''
883
884
    @staticmethod
885
    def get_refs_vt_as_xml_str(  # pylint: disable=unused-argument
886
        vt_id: str, vt_refs
887
    ) -> str:
888
        """ Create a string representation of the XML object from the
889
        refs data object.
890
        This needs to be implemented by each ospd wrapper, in case
891
        refs elements for VTs are used.
892
893
        The refs XML object which is returned will be embedded
894
        into a <refs></refs> element.
895
896
        @return: XML object as string for vt references data.
897
        """
898
        return ''
899
900
    @staticmethod
901
    def get_dependencies_vt_as_xml_str(  # pylint: disable=unused-argument
902
        vt_id: str, vt_dependencies
903
    ) -> str:
904
        """ Create a string representation of the XML object from the
905
        vt_dependencies data object.
906
        This needs to be implemented by each ospd wrapper, in case
907
        vt_dependencies elements for VTs are used.
908
909
        The vt_dependencies XML object which is returned will be embedded
910
        into a <dependencies></dependencies> element.
911
912
        @return: XML object as string for vt dependencies data.
913
        """
914
        return ''
915
916
    @staticmethod
917
    def get_creation_time_vt_as_xml_str(  # pylint: disable=unused-argument
918
        vt_id: str, vt_creation_time
919
    ) -> str:
920
        """ Create a string representation of the XML object from the
921
        vt_creation_time data object.
922
        This needs to be implemented by each ospd wrapper, in case
923
        vt_creation_time elements for VTs are used.
924
925
        The vt_creation_time XML object which is returned will be embedded
926
        into a <vt_creation_time></vt_creation_time> element.
927
928
        @return: XML object as string for vt creation time data.
929
        """
930
        return ''
931
932
    @staticmethod
933
    def get_modification_time_vt_as_xml_str(  # pylint: disable=unused-argument
934
        vt_id: str, vt_modification_time
935
    ) -> str:
936
        """ Create a string representation of the XML object from the
937
        vt_modification_time data object.
938
        This needs to be implemented by each ospd wrapper, in case
939
        vt_modification_time elements for VTs are used.
940
941
        The vt_modification_time XML object which is returned will be embedded
942
        into a <vt_modification_time></vt_modification_time> element.
943
944
        @return: XML object as string for vt references data.
945
        """
946
        return ''
947
948
    @staticmethod
949
    def get_summary_vt_as_xml_str(  # pylint: disable=unused-argument
950
        vt_id: str, summary
951
    ) -> str:
952
        """ Create a string representation of the XML object from the
953
        summary data object.
954
        This needs to be implemented by each ospd wrapper, in case
955
        summary elements for VTs are used.
956
957
        The summary XML object which is returned will be embedded
958
        into a <summary></summary> element.
959
960
        @return: XML object as string for summary data.
961
        """
962
        return ''
963
964
    @staticmethod
965
    def get_impact_vt_as_xml_str(  # pylint: disable=unused-argument
966
        vt_id: str, impact
967
    ) -> str:
968
        """ Create a string representation of the XML object from the
969
        impact data object.
970
        This needs to be implemented by each ospd wrapper, in case
971
        impact elements for VTs are used.
972
973
        The impact XML object which is returned will be embedded
974
        into a <impact></impact> element.
975
976
        @return: XML object as string for impact data.
977
        """
978
        return ''
979
980
    @staticmethod
981
    def get_affected_vt_as_xml_str(  # pylint: disable=unused-argument
982
        vt_id: str, affected
983
    ) -> str:
984
        """ Create a string representation of the XML object from the
985
        affected data object.
986
        This needs to be implemented by each ospd wrapper, in case
987
        affected elements for VTs are used.
988
989
        The affected XML object which is returned will be embedded
990
        into a <affected></affected> element.
991
992
        @return: XML object as string for affected data.
993
        """
994
        return ''
995
996
    @staticmethod
997
    def get_insight_vt_as_xml_str(  # pylint: disable=unused-argument
998
        vt_id: str, insight
999
    ) -> str:
1000
        """ Create a string representation of the XML object from the
1001
        insight data object.
1002
        This needs to be implemented by each ospd wrapper, in case
1003
        insight elements for VTs are used.
1004
1005
        The insight XML object which is returned will be embedded
1006
        into a <insight></insight> element.
1007
1008
        @return: XML object as string for insight data.
1009
        """
1010
        return ''
1011
1012
    @staticmethod
1013
    def get_solution_vt_as_xml_str(  # pylint: disable=unused-argument
1014
        vt_id: str, solution, solution_type=None, solution_method=None
1015
    ) -> str:
1016
        """ Create a string representation of the XML object from the
1017
        solution data object.
1018
        This needs to be implemented by each ospd wrapper, in case
1019
        solution elements for VTs are used.
1020
1021
        The solution XML object which is returned will be embedded
1022
        into a <solution></solution> element.
1023
1024
        @return: XML object as string for solution data.
1025
        """
1026
        return ''
1027
1028
    @staticmethod
1029
    def get_detection_vt_as_xml_str(  # pylint: disable=unused-argument
1030
        vt_id: str, detection=None, qod_type=None, qod=None
1031
    ) -> str:
1032
        """ Create a string representation of the XML object from the
1033
        detection data object.
1034
        This needs to be implemented by each ospd wrapper, in case
1035
        detection elements for VTs are used.
1036
1037
        The detection XML object which is returned is an element with
1038
        tag <detection></detection> element
1039
1040
        @return: XML object as string for detection data.
1041
        """
1042
        return ''
1043
1044
    @staticmethod
1045
    def get_severities_vt_as_xml_str(  # pylint: disable=unused-argument
1046
        vt_id: str, severities
1047
    ) -> str:
1048
        """ Create a string representation of the XML object from the
1049
        severities data object.
1050
        This needs to be implemented by each ospd wrapper, in case
1051
        severities elements for VTs are used.
1052
1053
        The severities XML objects which are returned will be embedded
1054
        into a <severities></severities> element.
1055
1056
        @return: XML object as string for severities data.
1057
        """
1058
        return ''
1059
1060
    def get_vt_iterator(  # pylint: disable=unused-argument
        self, vt_selection: List[str] = None, details: bool = True
    ) -> Iterator[Tuple[str, Dict]]:
        """ Return an iterable of (vt_id, vt) pairs taken from self.vts.

        vt_selection and details are accepted but unused by this base
        implementation: every VT in self.vts is returned.
        """
        all_vts = self.vts
        return all_vts.items()
1066
1067
    def get_vt_xml(self, single_vt: Tuple[str, Dict]) -> Element:
        """ Gets a single vulnerability test information in XML format.

        Builds a <vt id="..."> element with a <name> child. Then, for each
        optional field that is present (truthy) in the VT dictionary, the
        matching get_*_vt_as_xml_str() hook is called, and its XML string is
        parsed and appended. A hook returning an empty string (which is what
        every base-class hook returns) contributes nothing.

        @param single_vt: A (vt_id, vt dictionary) pair.

        @return: The <vt> Element. A bare <vt/> is returned when single_vt is
            empty or its VT dictionary is None.
        """
        if not single_vt or single_vt[1] is None:
            return Element('vt')

        vt_id, vt = single_vt

        vt_xml = Element('vt')
        vt_xml.set('id', vt_id)

        name_elem = SubElement(vt_xml, 'name')
        name_elem.text = str(vt.get('name'))

        def append_xml_str(xml_str: str) -> None:
            # Parsing '' raises ParseError, and the base-class hooks return
            # '' when a wrapper does not implement them: skip empty output
            # instead of failing the whole VT.
            if xml_str:
                vt_xml.append(secET.fromstring(xml_str))

        if vt.get('vt_params'):
            append_xml_str(
                self.get_params_vt_as_xml_str(vt_id, vt.get('vt_params'))
            )

        if vt.get('vt_refs'):
            append_xml_str(self.get_refs_vt_as_xml_str(vt_id, vt.get('vt_refs')))

        if vt.get('vt_dependencies'):
            append_xml_str(
                self.get_dependencies_vt_as_xml_str(
                    vt_id, vt.get('vt_dependencies')
                )
            )

        if vt.get('creation_time'):
            append_xml_str(
                self.get_creation_time_vt_as_xml_str(
                    vt_id, vt.get('creation_time')
                )
            )

        if vt.get('modification_time'):
            append_xml_str(
                self.get_modification_time_vt_as_xml_str(
                    vt_id, vt.get('modification_time')
                )
            )

        if vt.get('summary'):
            append_xml_str(
                self.get_summary_vt_as_xml_str(vt_id, vt.get('summary'))
            )

        if vt.get('impact'):
            append_xml_str(self.get_impact_vt_as_xml_str(vt_id, vt.get('impact')))

        if vt.get('affected'):
            append_xml_str(
                self.get_affected_vt_as_xml_str(vt_id, vt.get('affected'))
            )

        if vt.get('insight'):
            append_xml_str(
                self.get_insight_vt_as_xml_str(vt_id, vt.get('insight'))
            )

        if vt.get('solution'):
            append_xml_str(
                self.get_solution_vt_as_xml_str(
                    vt_id,
                    vt.get('solution'),
                    vt.get('solution_type'),
                    vt.get('solution_method'),
                )
            )

        # Detection is emitted if any of its three source fields is set.
        if vt.get('detection') or vt.get('qod_type') or vt.get('qod'):
            append_xml_str(
                self.get_detection_vt_as_xml_str(
                    vt_id, vt.get('detection'), vt.get('qod_type'), vt.get('qod')
                )
            )

        if vt.get('severities'):
            append_xml_str(
                self.get_severities_vt_as_xml_str(vt_id, vt.get('severities'))
            )

        if vt.get('custom'):
            append_xml_str(self.get_custom_vt_as_xml_str(vt_id, vt.get('custom')))

        return vt_xml
1165
1166
    def get_vts_selection_list(
        self, vt_id: str = None, filtered_vts: Dict = None
    ) -> Iterable[str]:
        """
        Get list of VT's OID.

        Selection rules, in priority order:
          * filtered_vts given but empty (the filter matched nothing):
            an empty list.
          * filtered_vts non-empty: filtered_vts itself, even if vt_id is
            also given.
          * vt_id given: the one-element list [vt_id]; it is not checked
            against self.vts.
          * otherwise: all VTs, as the keys view of self.vts.

        Arguments:
            vt_id (vt_id, optional): ID of the vt to get.
            filtered_vts (list, optional): Filtered VTs collection.

        Return:
            List of selected VT's OID.
        """
        filter_matched_nothing = (
            filtered_vts is not None and len(filtered_vts) == 0
        )
        if filter_matched_nothing:
            return []

        if filtered_vts:
            return filtered_vts
        if vt_id:
            return [vt_id]
        return self.vts.keys()
1199
1200
    def handle_command(self, data: bytes, stream: Stream) -> None:
        """ Handles an osp command in a string.

        Parses data as XML, looks up the command registered under the root
        tag and writes that command's response to stream. While the daemon
        is not initialized, a command marked must_be_initialized gets an
        error response instead. After a get_scans request for an existing
        scan, the scan's temporary result list is cleaned if the response
        was fully written, or restored if a write failed.

        @param data: Raw XML command received from the client.
        @param stream: Stream the response is written to.

        @raise OspdCommandError: If data is not valid XML, or its root tag
            does not name a registered command.
        """
        try:
            tree = secET.fromstring(data)
        except secET.ParseError:
            logger.debug("Erroneous client input: %s", data)
            raise OspdCommandError('Invalid data')

        command_name = tree.tag

        logger.debug('Handling %s command request.', command_name)

        command = self.commands.get(command_name, None)
        if not command:
            # Every command name, "authenticate" included, needs a registered
            # handler: the code below dereferences it and would otherwise fail
            # with AttributeError instead of a proper OSP error.
            raise OspdCommandError('Bogus command name')

        if not self.initialized and command.must_be_initialized:
            exception = OspdCommandError(
                '%s is still starting' % self.daemon_info['name'], 'error'
            )
            stream.write(exception.as_xml())
            return

        response = command.handle_xml(tree)

        # handle_xml() returns either a complete bytes response or an
        # iterable of chunks; chunk writing stops at the first failed write.
        write_success = True
        if isinstance(response, bytes):
            write_success = stream.write(response)
        else:
            for chunk in response:
                write_success = stream.write(chunk)
                if not write_success:
                    break

        if command_name == "get_scans":
            scan_id = tree.get('scan_id')
            if self.scan_exists(scan_id):
                if write_success:
                    self.scan_collection.clean_temp_result_list(scan_id)
                else:
                    self.scan_collection.restore_temp_result_list(scan_id)
1242
1243
    def check(self):
        """ Check hook for subclasses to implement.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError
1246
1247
    def run(self) -> None:
        """ Starts the Daemon, handling commands until interrupted.

        Every SCHEDULER_CHECK_PERIOD seconds it runs, in order: scheduler(),
        clean_forgotten_scans(), start_queued_scans() and
        wait_for_children(). Ctrl-C (KeyboardInterrupt) ends the loop with
        an info log.
        """
        try:
            while True:
                time.sleep(SCHEDULER_CHECK_PERIOD)
                for periodic_task in (
                    self.scheduler,
                    self.clean_forgotten_scans,
                    self.start_queued_scans,
                    self.wait_for_children,
                ):
                    periodic_task()
        except KeyboardInterrupt:
            logger.info("Received Ctrl-C shutting-down ...")
1260
1261
    def start_queued_scans(self) -> None:
        """ Starts a queued scan if it is allowed.

        Does nothing when no scan is queued, and only logs while the daemon
        is not initialized. Otherwise the scan ids are walked in order. For
        each queued scan whose start is allowed (is_new_scan_allowed() and
        is_enough_free_memory()), its info is unpickled, a process running
        start_scan() is created, registered and started, and the status
        becomes INIT. A scan whose info cannot be unpickled is stopped and
        skipped. The walk ends at the first queued scan that may not start.
        """
        queued_count = self.get_count_queued_scans()
        if not queued_count:
            return

        if not self.initialized:
            logger.info(
                "Queued task can not be started because a feed "
                "update is being performed."
            )
            return

        logger.info('Currently %d queued scans.', queued_count)

        for scan_id in self.scan_collection.ids_iterator():
            # Evaluated for every scan id, before the status check.
            can_start = (
                self.is_new_scan_allowed() and self.is_enough_free_memory()
            )
            if self.get_scan_status(scan_id) != ScanStatus.QUEUED:
                continue
            if not can_start:
                return

            try:
                self.scan_collection.unpickle_scan_info(scan_id)
            except OspdCommandError as err:
                logger.error("Start scan error %s", err)
                self.stop_scan(scan_id)
                continue

            process = create_process(func=self.start_scan, args=(scan_id,))
            self.scan_processes[scan_id] = process
            process.start()
            self.set_scan_status(scan_id, ScanStatus.INIT)

            logger.info('Starting scan %s.', scan_id)
1301
1302
    def is_new_scan_allowed(self) -> bool:
        """ Check if max_scans has been reached.

        A max_scans of 0 means there is no limit. Otherwise the number of
        entries in scan_processes is compared against it.

        Return:
            True if a new scan can be launch.
        """
        limit_reached = (
            self.max_scans != 0 and len(self.scan_processes) >= self.max_scans
        )
        if not limit_reached:
            return True

        logger.info(
            'Not possible to run a new scan. Max scan limit set '
            'to %d reached.',
            self.max_scans,
        )
        return False
1319
1320
    def is_enough_free_memory(self) -> bool:
        """ Check if there is enough free memory in the system to run
        a new scan. The necessary memory is a rough calculation and very
        conservative.

        A falsy min_free_mem_scan_queue disables the check. Otherwise the
        available memory reported by psutil, in MB, must be strictly greater
        than that threshold.

        Return:
            True if there is enough memory for a new scan.
        """
        required_mb = self.min_free_mem_scan_queue
        if not required_mb:
            return True

        available_mb = psutil.virtual_memory().available / (1024 * 1024)
        if available_mb <= required_mb:
            logger.info(
                'Not possible to run a new scan. Not enough free memory. '
                'Only %d MB available but at least %d are required',
                available_mb,
                required_mb,
            )
            return False

        return True
1344
1345
    def scheduler(self):
        """ Periodic-task hook, called on every iteration of run().

        The base implementation does nothing; subclasses that need periodic
        work override it.
        """
        return None
1348
1349
    def wait_for_children(self):
        """ Reap finished scan processes without blocking.

        Calls join(0) on every process in scan_processes, so terminated
        (zombie) children release their resources.
        """
        for process in self.scan_processes.values():
            process.join(0)
1353
1354
    def create_scan(
        self,
        scan_id: str,
        targets: Dict,
        options: Optional[Dict],
        vt_selection: Dict,
    ) -> Optional[str]:
        """ Creates a new scan.

        @scan_id: Requested scan ID.
        @targets: Targets to scan.
        @options: Miscellaneous scan options supplied via <scanner_params>
                  XML element.
        @vt_selection: Selected VTs for the scan.

        @return: New scan's ID, as returned by scan_collection.create_scan().
                 None (with an info log) if a non-empty scan_id already
                 exists.
        """
        already_known = self.scan_exists(scan_id)
        if not (scan_id and already_known):
            return self.scan_collection.create_scan(
                scan_id, targets, options, vt_selection
            )

        status = self.get_scan_status(scan_id)
        logger.info(
            "Scan %s exists with status %s.", scan_id, status.name.lower()
        )
        return None
1381
1382
    def get_scan_options(self, scan_id: str) -> str:
        """ Return the options of scan scan_id from the scan collection. """
        collection = self.scan_collection
        return collection.get_options(scan_id)
1385
1386
    def set_scan_option(self, scan_id: str, name: str, value: Any) -> None:
        """ Store value under option name for scan scan_id in the scan
        collection. """
        collection = self.scan_collection
        return collection.set_option(scan_id, name, value)
1389
1390
    def clean_forgotten_scans(self) -> None:
        """ Delete old stopped, finished or interrupted scans that were never
        deleted.

        Disabled when scaninfo_store_time is falsy. Otherwise a scan with a
        non-zero end time whose end lies more than scaninfo_store_time hours
        in the past is removed with delete_scan().
        """
        store_hours = self.scaninfo_store_time
        if not store_hours:
            return

        done_states = (
            ScanStatus.STOPPED,
            ScanStatus.FINISHED,
            ScanStatus.INTERRUPTED,
        )
        max_age_seconds = store_hours * 3600

        # Snapshot the ids: delete_scan() removes entries while we iterate.
        for scan_id in list(self.scan_collection.ids_iterator()):
            end_time = int(self.get_scan_end_time(scan_id))
            status = self.get_scan_status(scan_id)

            if status not in done_states or not end_time:
                continue
            if int(time.time()) - end_time <= max_age_seconds:
                continue

            logger.debug(
                'Scan %s is older than %d hours and seems have been '
                'forgotten. Scan info will be deleted from the '
                'scan table',
                scan_id,
                store_hours,
            )
            self.delete_scan(scan_id)
1416
1417
    def check_scan_process(self, scan_id: str) -> None:
        """ Check the scan's process, and terminate the scan if not alive.

        Queued scans are skipped. If the scan's progress is below FINISHED
        but its registered process is no longer alive, the scan is marked
        STOPPED (unless it already is), a "Scan process failure." error
        result is added and interrupt_scan() is called. If the progress is
        FINISHED, the registered process, if any, is joined without blocking.
        """
        if self.get_scan_status(scan_id) == ScanStatus.QUEUED:
            return

        scan_process = self.scan_processes.get(scan_id)
        progress = self.get_scan_progress(scan_id)

        if (
            progress < ScanProgress.FINISHED
            and scan_process
            and not scan_process.is_alive()
        ):
            if self.get_scan_status(scan_id) != ScanStatus.STOPPED:
                self.set_scan_status(scan_id, ScanStatus.STOPPED)
                self.add_scan_error(
                    scan_id, name="", host="", value="Scan process failure."
                )

                logger.info("%s: Scan stopped with errors.", scan_id)
                self.interrupt_scan(scan_id)

        # scan_processes.get() may return None (no process registered);
        # guard it so a finished scan without a process does not crash.
        elif progress == ScanProgress.FINISHED and scan_process:
            scan_process.join(0)
1441
1442
    def get_count_queued_scans(self) -> int:
        """ Return how many scans in the collection have QUEUED status. """
        return sum(
            1
            for scan_id in self.scan_collection.ids_iterator()
            if self.get_scan_status(scan_id) == ScanStatus.QUEUED
        )
1449
1450
    def get_scan_progress(self, scan_id: str) -> int:
        """ Return the current progress value of scan scan_id. """
        collection = self.scan_collection
        return collection.get_progress(scan_id)
1453
1454
    def get_scan_host(self, scan_id: str) -> str:
        """ Return the target host list of scan scan_id. """
        collection = self.scan_collection
        return collection.get_host_list(scan_id)
1457
1458
    def get_scan_ports(self, scan_id: str) -> str:
        """ Return the port list of scan scan_id. """
        collection = self.scan_collection
        return collection.get_ports(scan_id)
1461
1462
    def get_scan_exclude_hosts(self, scan_id: str):
        """ Return the excluded-hosts list of scan scan_id. """
        collection = self.scan_collection
        return collection.get_exclude_hosts(scan_id)
1466
1467
    def get_scan_credentials(self, scan_id: str) -> Dict:
        """ Return the credentials of scan scan_id. """
        collection = self.scan_collection
        return collection.get_credentials(scan_id)
1471
1472
    def get_scan_target_options(self, scan_id: str) -> Dict:
        """ Return the target options dictionary of scan scan_id. """
        collection = self.scan_collection
        return collection.get_target_options(scan_id)
1476
1477
    def get_scan_vts(self, scan_id: str) -> Dict:
        """ Return the VTs selected for scan scan_id. """
        collection = self.scan_collection
        return collection.get_vts(scan_id)
1480
1481
    def get_scan_start_time(self, scan_id: str) -> str:
        """ Return the start time of scan scan_id. """
        collection = self.scan_collection
        return collection.get_start_time(scan_id)
1484
1485
    def get_scan_end_time(self, scan_id: str) -> str:
        """ Return the end time of scan scan_id. """
        collection = self.scan_collection
        return collection.get_end_time(scan_id)
1488
1489
    def add_scan_log(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
        test_id: str = '',
        qod: str = '',
        uri: str = '',
    ) -> None:
        """ Adds a log result to scan_id scan.

        Log results are always stored with severity '0.0'; everything else
        is forwarded positionally to scan_collection.add_result().
        """
        log_severity = '0.0'
        leading_fields = (host, hostname, name, value, port, test_id)
        self.scan_collection.add_result(
            scan_id, ResultType.LOG, *leading_fields, log_severity, qod, uri
        )
1516
1517
    def add_scan_error(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
        test_id='',
        uri: str = '',
    ) -> None:
        """ Adds an error result to scan_id scan.

        Severity and qod are left at add_result()'s defaults; uri is passed
        by keyword so it reaches the uri field.
        """
        self.scan_collection.add_result(
            scan_id,
            ResultType.ERROR,
            host,
            hostname,
            name,
            value,
            port,
            test_id,
            # Positionally, the argument after test_id is severity (then qod,
            # then uri), so uri must not be passed in that slot.
            uri=uri,
        )
1540
1541
    def add_scan_host_detail(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        uri: str = '',
    ) -> None:
        """ Adds a host detail result to scan_id scan.

        Port, test_id, severity and qod are left at add_result()'s defaults;
        uri is passed by keyword so it reaches the uri field.
        """
        self.scan_collection.add_result(
            scan_id,
            ResultType.HOST_DETAIL,
            host,
            hostname,
            name,
            value,
            # Positionally, the argument after value is port, so uri must be
            # passed by keyword.
            uri=uri,
        )
1554
1555
    def add_scan_alarm(
        self,
        scan_id: str,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
        test_id: str = '',
        severity: str = '',
        qod: str = '',
        uri: str = '',
    ) -> None:
        """ Adds an alarm result to scan_id scan.

        All fields are forwarded positionally, in signature order, to
        scan_collection.add_result().
        """
        alarm_fields = (
            host,
            hostname,
            name,
            value,
            port,
            test_id,
            severity,
            qod,
            uri,
        )
        self.scan_collection.add_result(scan_id, ResultType.ALARM, *alarm_fields)
1582