ospd.ospd   F
last analyzed

Complexity

Total Complexity 234

Size/Duplication

Total Lines 1629
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 944
dl 0
loc 1629
rs 1.656
c 0
b 0
f 0
wmc 234

89 Methods

Rating   Name   Duplication   Size   Complexity  
A OSPDaemon.get_scanner_param_mandatory() 0 7 2
A OSPDaemon.get_scanner_param_default() 0 7 2
A OSPDaemon.get_daemon_name() 0 3 1
A OSPDaemon.get_scanner_param_type() 0 7 2
A OSPDaemon.get_daemon_version() 0 3 1
A OSPDaemon.init() 0 8 1
A OSPDaemon.set_command_attributes() 0 5 2
A OSPDaemon.get_scanner_name() 0 3 1
A OSPDaemon.get_protocol_version() 0 3 1
A OSPDaemon.get_scanner_description() 0 3 1
A OSPDaemon.command_exists() 0 3 1
A OSPDaemon.get_server_version() 0 4 1
A OSPDaemon.get_scanner_version() 0 3 1
A OSPDaemon.process_finished_hosts() 0 9 2
C OSPDaemon.handle_client_stream() 0 45 11
A OSPDaemon._get_scan_progress_raw() 0 27 1
A OSPDaemon.get_vts_version() 0 3 1
A OSPDaemon.get_scanner_params() 0 2 1
A OSPDaemon.set_vts_version() 0 12 2
D OSPDaemon.preprocess_scan_params() 0 38 12
B OSPDaemon.__init__() 0 56 4
B OSPDaemon.add_vt() 0 51 1
A OSPDaemon.set_scanner_param() 0 7 1
A OSPDaemon.process_scan_params() 0 3 1
B OSPDaemon.clean_forgotten_scans() 0 26 8
A OSPDaemon.get_scan_options() 0 3 1
A OSPDaemon.create_scan() 0 28 3
A OSPDaemon.get_scan_host() 0 3 1
A OSPDaemon.get_count_running_scans() 0 8 4
A OSPDaemon.finish_scan() 0 5 1
A OSPDaemon.get_modification_time_vt_as_xml_str() 0 16 1
A OSPDaemon.scan_exists() 0 7 1
A OSPDaemon.get_custom_vt_as_xml_str() 0 16 1
A OSPDaemon.add_scan_error() 0 22 1
A OSPDaemon.get_scan_start_time() 0 3 1
A OSPDaemon.get_params_vt_as_xml_str() 0 16 1
A OSPDaemon.add_scan_alarm() 0 26 1
A OSPDaemon.get_scan_vts() 0 3 1
A OSPDaemon.get_count_queued_scans() 0 7 3
A OSPDaemon.get_scan_progress() 0 5 1
A OSPDaemon.get_severities_vt_as_xml_str() 0 16 1
A OSPDaemon.set_scan_total_hosts() 0 4 1
A OSPDaemon.get_scan_target_options() 0 4 1
A OSPDaemon.get_affected_vt_as_xml_str() 0 16 1
A OSPDaemon.is_new_scan_allowed() 0 17 3
B OSPDaemon.get_scan_xml() 0 51 6
A OSPDaemon.get_detection_vt_as_xml_str() 0 16 1
A OSPDaemon.check() 0 3 1
A OSPDaemon._get_scan_progress_xml() 0 8 1
A OSPDaemon.get_insight_vt_as_xml_str() 0 16 1
A OSPDaemon.set_scan_option() 0 3 1
A OSPDaemon.get_scan_status() 0 5 1
A OSPDaemon.scheduler() 0 2 1
A OSPDaemon.get_summary_vt_as_xml_str() 0 16 1
A OSPDaemon.wait_for_children() 0 4 2
A OSPDaemon.get_dependencies_vt_as_xml_str() 0 16 1
A OSPDaemon.get_scan_host_progress() 0 10 1
A OSPDaemon.get_impact_vt_as_xml_str() 0 16 1
A OSPDaemon.is_enough_free_memory() 0 38 5
A OSPDaemon.get_vt_iterator() 0 6 1
A OSPDaemon.add_scan_log() 0 26 1
A OSPDaemon.get_refs_vt_as_xml_str() 0 16 1
B OSPDaemon.start_scan() 0 37 6
A OSPDaemon.get_scan_ports() 0 3 1
A OSPDaemon.exec_scan() 0 3 1
B OSPDaemon.check_scan_process() 0 34 7
F OSPDaemon.get_vt_xml() 0 99 19
A OSPDaemon.set_scan_host_progress() 0 21 5
A OSPDaemon.get_solution_vt_as_xml_str() 0 16 1
A OSPDaemon.get_vts_selection_list() 0 33 5
A OSPDaemon.get_scan_exclude_hosts() 0 4 1
A OSPDaemon.delete_scan() 0 21 5
A OSPDaemon.set_scan_progress_batch() 0 5 1
A OSPDaemon.get_scan_end_time() 0 3 1
C OSPDaemon.daemon_exit_cleanup() 0 35 11
A OSPDaemon.get_scan_credentials() 0 4 1
A OSPDaemon.add_scan_host_detail() 0 12 1
B OSPDaemon.get_help_text() 0 30 5
A OSPDaemon.set_scan_progress() 0 9 1
A OSPDaemon.get_creation_time_vt_as_xml_str() 0 16 1
B OSPDaemon.stop_scan() 0 48 8
A OSPDaemon.interrupt_scan() 0 4 1
A OSPDaemon.handle_timeout() 0 7 1
A OSPDaemon.get_scan_results_xml() 0 16 2
C OSPDaemon.start_queued_scans() 0 41 9
B OSPDaemon.sort_host_finished() 0 37 5
D OSPDaemon.handle_command() 0 51 12
A OSPDaemon.set_scan_status() 0 4 1
A OSPDaemon.run() 0 12 3

1 Function

Rating   Name   Duplication   Size   Complexity  
A _terminate_process_group() 0 2 1

How to fix   Complexity   

Complexity

Complex classes like ospd.ospd often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
# Copyright (C) 2014-2021 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: AGPL-3.0-or-later
4
#
5
# This program is free software: you can redistribute it and/or modify
6
# it under the terms of the GNU Affero General Public License as
7
# published by the Free Software Foundation, either version 3 of the
8
# License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU Affero General Public License for more details.
14
#
15
# You should have received a copy of the GNU Affero General Public License
16
# along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
# pylint: disable=too-many-lines
19
20
""" OSP Daemon core class.
21
"""
22
23
import logging
24
import socket
25
import ssl
26
import multiprocessing
27
import time
28
import os
29
30
from pprint import pformat
31
from typing import (
32
    List,
33
    Any,
34
    Iterator,
35
    Dict,
36
    Optional,
37
    Iterable,
38
    Tuple,
39
    Union,
40
)
41
from xml.etree.ElementTree import Element, SubElement
42
43
import defusedxml.ElementTree as secET
44
45
import psutil
46
47
from ospd import __version__
48
from ospd.command import get_commands
49
from ospd.errors import OspdCommandError
50
from ospd.misc import ResultType, create_process
51
from ospd.network import target_str_to_list
52
from ospd.protocol import RequestParser
53
from ospd.scan import ScanCollection, ScanStatus, ScanProgress
54
from ospd.server import BaseServer, Stream
55
from ospd.vtfilter import VtsFilter
56
from ospd.vts import Vts
57
from ospd.xml import (
58
    elements_as_text,
59
    get_result_xml,
60
    get_progress_xml,
61
)
62
63
logger = logging.getLogger(__name__)

# The OSP protocol version reported by the daemon is the package version.
PROTOCOL_VERSION = __version__

SCHEDULER_CHECK_PERIOD = 10  # in seconds

MIN_TIME_BETWEEN_START_SCAN = 60  # in seconds

# Scanner parameters registered on every daemon instance by __init__;
# subclasses can register more through set_scanner_param().
BASE_SCANNER_PARAMS: Dict = {
    'debug_mode': dict(
        type='boolean',
        name='Debug Mode',
        default=0,
        mandatory=0,
        description='Whether to get extra scan debug information.',
    ),
    'dry_run': dict(
        type='boolean',
        name='Dry Run',
        default=0,
        mandatory=0,
        description='Whether to dry run scan.',
    ),
}
87
88
89
def _terminate_process_group(process: multiprocessing.Process) -> None:
    """Send SIGTERM to the whole process group that *process* belongs to.

    Raises:
        ProcessLookupError: propagated from os.getpgid()/os.killpg() when
            the process or its group no longer exists.
    """
    import signal  # pylint: disable=import-outside-toplevel

    # Use the named constant instead of the bare number 15.
    os.killpg(os.getpgid(process.pid), signal.SIGTERM)
91
92
93
class OSPDaemon:
94
95
    """Daemon class for OSP traffic handling.
96
97
    Every scanner wrapper should subclass it and make necessary additions and
98
    changes.
99
100
    * Add any needed parameters in __init__.
101
    * Implement check() method which verifies scanner availability and other
102
      environment related conditions.
103
    * Implement process_scan_params and exec_scan methods which are
104
      specific to handling the <start_scan> command, executing the wrapped
105
      scanner and storing the results.
106
    * Implement other methods that assert to False such as get_scanner_name,
107
      get_scanner_version.
108
    * Use Call set_command_attributes at init time to add scanner command
109
      specific options eg. the w3af profile for w3af wrapper.
110
    """
111
112
    def __init__(
        self,
        *,
        customvtfilter=None,
        storage=None,
        max_scans=0,
        min_free_mem_scan_queue=0,
        file_storage_dir='/var/run/ospd',
        max_queued_scans=0,
        **kwargs,
    ):  # pylint: disable=unused-argument
        """Set up the daemon's in-memory state.

        Keyword-only arguments:
            customvtfilter: VT filter to use; a fresh VtsFilter() is used
                when this is falsy.
            storage: handed to the Vts collection.
            max_scans: stored as ``self.max_scans``.
            min_free_mem_scan_queue: stored as
                ``self.min_free_mem_scan_queue``.
            file_storage_dir: directory passed to ScanCollection.
            max_queued_scans: stored as ``self.max_queued_scans``.
            **kwargs: only ``scaninfo_store_time`` is read; anything else
                is ignored.
        """
        self.scan_collection = ScanCollection(file_storage_dir)
        self.scan_processes = {}

        self.daemon_info = {
            'name': "OSPd",
            'version': __version__,
            'description': "No description",
        }

        self.scanner_info = {
            'name': 'No name',
            'version': 'No version',
            'description': 'No description',
        }

        # Expected to be filled in by the subclass.
        self.server_version = None

        # init() switches this to True once start-up is complete.
        self.initialized = None

        self.max_scans = max_scans
        self.min_free_mem_scan_queue = min_free_mem_scan_queue
        self.max_queued_scans = max_queued_scans
        self.last_scan_start_time = 0

        self.scaninfo_store_time = kwargs.get('scaninfo_store_time')

        self.protocol_version = PROTOCOL_VERSION

        # Index every OSP command object by the name it reports.
        self.commands = {}
        for command_cls in get_commands():
            cmd = command_cls(self)
            self.commands[cmd.get_name()] = cmd

        self.scanner_params = {}
        for param_name, param_spec in BASE_SCANNER_PARAMS.items():
            self.set_scanner_param(param_name, param_spec)

        self.vts = Vts(storage)
        self.vts_version = None

        self.vts_filter = customvtfilter if customvtfilter else VtsFilter()
168
169
    def init(self, server: BaseServer) -> None:
170
        """Should be overridden by a subclass if the initialization is costly.
171
172
        Will be called after check.
173
        """
174
        self.scan_collection.init()
175
        server.start(self.handle_client_stream)
176
        self.initialized = True
177
178
    def set_command_attributes(self, name: str, attributes: Dict) -> None:
179
        """Sets the xml attributes of a specified command."""
180
        if self.command_exists(name):
181
            command = self.commands.get(name)
182
            command.attributes = attributes
183
184
    def set_scanner_param(self, name: str, scanner_params: Dict) -> None:
185
        """Set a scanner parameter."""
186
187
        assert name
188
        assert scanner_params
189
190
        self.scanner_params[name] = scanner_params
191
192
    def get_scanner_params(self) -> Dict:
193
        return self.scanner_params
194
195
    def add_vt(
196
        self,
197
        vt_id: str,
198
        name: str = None,
199
        vt_params: str = None,
200
        vt_refs: str = None,
201
        custom: str = None,
202
        vt_creation_time: str = None,
203
        vt_modification_time: str = None,
204
        vt_dependencies: str = None,
205
        summary: str = None,
206
        impact: str = None,
207
        affected: str = None,
208
        insight: str = None,
209
        solution: str = None,
210
        solution_t: str = None,
211
        solution_m: str = None,
212
        detection: str = None,
213
        qod_t: str = None,
214
        qod_v: str = None,
215
        severities: str = None,
216
    ) -> None:
217
        """Add a vulnerability test information.
218
219
        IMPORTANT: The VT's Data Manager will store the vts collection.
220
        If the collection is considerably big and it will be consultated
221
        intensible during a routine, consider to do a deepcopy(), since
222
        accessing the shared memory in the data manager is very expensive.
223
        At the end of the routine, the temporal copy must be set to None
224
        and deleted.
225
        """
226
        self.vts.add(
227
            vt_id,
228
            name=name,
229
            vt_params=vt_params,
230
            vt_refs=vt_refs,
231
            custom=custom,
232
            vt_creation_time=vt_creation_time,
233
            vt_modification_time=vt_modification_time,
234
            vt_dependencies=vt_dependencies,
235
            summary=summary,
236
            impact=impact,
237
            affected=affected,
238
            insight=insight,
239
            solution=solution,
240
            solution_t=solution_t,
241
            solution_m=solution_m,
242
            detection=detection,
243
            qod_t=qod_t,
244
            qod_v=qod_v,
245
            severities=severities,
246
        )
247
248
    def set_vts_version(self, vts_version: str) -> None:
249
        """Add into the vts dictionary an entry to identify the
250
        vts version.
251
252
        Parameters:
253
            vts_version (str): Identifies a unique vts version.
254
        """
255
        if not vts_version:
256
            raise OspdCommandError(
257
                'A vts_version parameter is required', 'set_vts_version'
258
            )
259
        self.vts_version = vts_version
260
261
    def get_vts_version(self) -> Optional[str]:
262
        """Return the vts version."""
263
        return self.vts_version
264
265
    def command_exists(self, name: str) -> bool:
266
        """Checks if a commands exists."""
267
        return name in self.commands
268
269
    def get_scanner_name(self) -> str:
270
        """Gives the wrapped scanner's name."""
271
        return self.scanner_info['name']
272
273
    def get_scanner_version(self) -> str:
274
        """Gives the wrapped scanner's version."""
275
        return self.scanner_info['version']
276
277
    def get_scanner_description(self) -> str:
278
        """Gives the wrapped scanner's description."""
279
        return self.scanner_info['description']
280
281
    def get_server_version(self) -> str:
282
        """Gives the specific OSP server's version."""
283
        assert self.server_version
284
        return self.server_version
285
286
    def get_protocol_version(self) -> str:
287
        """Gives the OSP's version."""
288
        return self.protocol_version
289
290
    def preprocess_scan_params(self, xml_params):
291
        """Processes the scan parameters."""
292
        params = {}
293
294
        for param in xml_params:
295
            params[param.tag] = param.text or ''
296
297
        # Validate values.
298
        for key in params:
299
            param_type = self.get_scanner_param_type(key)
300
            if not param_type:
301
                continue
302
303
            if param_type in ['integer', 'boolean']:
304
                try:
305
                    params[key] = int(params[key])
306
                except ValueError:
307
                    raise OspdCommandError(
308
                        'Invalid %s value' % key, 'start_scan'
309
                    ) from None
310
311
            if param_type == 'boolean':
312
                if params[key] not in [0, 1]:
313
                    raise OspdCommandError(
314
                        'Invalid %s value' % key, 'start_scan'
315
                    )
316
            elif param_type == 'selection':
317
                selection = self.get_scanner_param_default(key).split('|')
318
                if params[key] not in selection:
319
                    raise OspdCommandError(
320
                        'Invalid %s value' % key, 'start_scan'
321
                    )
322
            if self.get_scanner_param_mandatory(key) and params[key] == '':
323
                raise OspdCommandError(
324
                    'Mandatory %s value is missing' % key, 'start_scan'
325
                )
326
327
        return params
328
329
    def process_scan_params(self, params: Dict) -> Dict:
        """Hook for subclasses to post-process scan parameters.

        The base implementation returns *params* unchanged (same object).
        """
        processed = params
        return processed
332
333
    def stop_scan(self, scan_id: str) -> None:
        """Stop scan *scan_id*.

        A queued scan is only dequeued: its pickled scan info is removed and
        its status set to STOPPED. For any other scan, its status is set to
        STOPPED, and its process is terminated along with that process's
        process group.

        Raises:
            OspdCommandError: if no process is recorded for the scan, or the
                recorded process is no longer alive.
        """
        is_queued = (
            scan_id in self.scan_collection.ids_iterator()
            and self.get_scan_status(scan_id) == ScanStatus.QUEUED
        )
        if is_queued:
            logger.info('Scan %s has been removed from the queue.', scan_id)
            self.scan_collection.remove_file_pickled_scan_info(scan_id)
            self.set_scan_status(scan_id, ScanStatus.STOPPED)
            return

        process = self.scan_processes.get(scan_id)
        if not process:
            raise OspdCommandError(
                'Scan not found {0}.'.format(scan_id), 'stop_scan'
            )
        if not process.is_alive():
            raise OspdCommandError(
                'Scan already stopped or finished.', 'stop_scan'
            )

        self.set_scan_status(scan_id, ScanStatus.STOPPED)
        logger.info(
            '%s: Stopping Scan with the PID %s.', scan_id, process.ident
        )

        try:
            process.terminate()
        except AttributeError:
            logger.debug('%s: The scanner task stopped unexpectedly.', scan_id)

        # Also signal any children the scan process spawned in its group.
        try:
            logger.debug(
                '%s: Terminating process group after stopping.', scan_id
            )
            _terminate_process_group(process)
        except ProcessLookupError:
            logger.info(
                '%s: Scan with the PID %s is already stopped.',
                scan_id,
                process.pid,
            )

        # Never join our own process; join(0) does not block.
        if process.ident != os.getpid():
            process.join(0)

        logger.info('%s: Scan stopped.', scan_id)
381
382
    def exec_scan(self, scan_id: str):
        """Run the actual scan for *scan_id*; subclasses must implement it.

        start_scan() calls this method and records any exception it raises
        as a scan error.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
385
386
    def finish_scan(self, scan_id: str) -> None:
        """Mark *scan_id* as finished: progress to FINISHED, then status."""
        finished_progress = ScanProgress.FINISHED.value
        self.scan_collection.set_progress(scan_id, finished_progress)
        self.set_scan_status(scan_id, ScanStatus.FINISHED)
        logger.info("%s: Scan finished.", scan_id)
391
392
    def interrupt_scan(self, scan_id: str) -> None:
        """Set the status of *scan_id* to INTERRUPTED and log it."""
        new_status = ScanStatus.INTERRUPTED
        self.set_scan_status(scan_id, new_status)
        logger.info("%s: Scan interrupted.", scan_id)
396
397
    def daemon_exit_cleanup(self) -> None:
        """Stop every unfinished scan and wait until all have settled.

        Pickled scan info is cleaned up first. Every scan whose status is
        not STOPPED, FINISHED or INTERRUPTED is then stopped, and the method
        blocks (polling once per second) until all scans are in one of those
        states.
        """
        terminal_states = (
            ScanStatus.STOPPED,
            ScanStatus.FINISHED,
            ScanStatus.INTERRUPTED,
        )

        self.scan_collection.clean_up_pickled_scan_info()

        # Stop scans which are not already stopped.
        for scan_id in self.scan_collection.ids_iterator():
            if self.get_scan_status(scan_id) in terminal_states:
                continue

            logger.debug("%s: Stopping scan before daemon exit.", scan_id)
            try:
                self.stop_scan(scan_id)
            except OspdCommandError as exc:
                # stop_scan() raises when no live process backs the scan.
                # Do not let one such scan abort the cleanup of the others,
                # and mark it stopped so the wait loop below can terminate.
                logger.debug(
                    "%s: Could not stop scan before daemon exit: %s",
                    scan_id,
                    exc.message,
                )
                self.set_scan_status(scan_id, ScanStatus.STOPPED)

        # Wait for scans to be in some stopped state.
        while True:
            all_stopped = all(
                self.get_scan_status(scan_id) in terminal_states
                for scan_id in self.scan_collection.ids_iterator()
            )
            if all_stopped:
                logger.debug(
                    "All scans stopped and daemon clean and ready to exit"
                )
                return

            logger.debug("Waiting for running scans before daemon exit. ")
            time.sleep(1)
432
433
    def get_daemon_name(self) -> str:
434
        """Gives osp daemon's name."""
435
        return self.daemon_info['name']
436
437
    def get_daemon_version(self) -> str:
438
        """Gives osp daemon's version."""
439
        return self.daemon_info['version']
440
441
    def get_scanner_param_type(self, param: str):
442
        """Returns type of a scanner parameter."""
443
        assert isinstance(param, str)
444
        entry = self.scanner_params.get(param)
445
        if not entry:
446
            return None
447
        return entry.get('type')
448
449
    def get_scanner_param_mandatory(self, param: str):
450
        """Returns if a scanner parameter is mandatory."""
451
        assert isinstance(param, str)
452
        entry = self.scanner_params.get(param)
453
        if not entry:
454
            return False
455
        return entry.get('mandatory')
456
457
    def get_scanner_param_default(self, param: str):
458
        """Returns default value of a scanner parameter."""
459
        assert isinstance(param, str)
460
        entry = self.scanner_params.get(param)
461
        if not entry:
462
            return None
463
        return entry.get('default')
464
465
    def handle_client_stream(self, stream: Stream) -> None:
        """Read one request from *stream*, execute it, then close the stream.

        Data is read until the stream returns nothing, the request parser
        reports the end of the request, or an SSL error / socket timeout
        occurs. An empty request is only logged. An OspdCommandError from
        the command is sent back as its XML form; any other exception is
        logged and answered with a generic 'Fatal error' response.

        Note: if reading raises AttributeError or ValueError, the error is
        logged and the method returns without responding or closing the
        stream.
        """
        received = b''

        parser = RequestParser()

        while True:
            try:
                chunk = stream.read()
                if not chunk:
                    break
                received += chunk
                if parser.has_ended(chunk):
                    break
            except (AttributeError, ValueError) as error:
                logger.error(error)
                return
            except ssl.SSLError as error:
                logger.debug('Error: %s', error)
                break
            except socket.timeout as error:
                logger.debug('Request timeout: %s', error)
                break

        if not received:
            logger.debug("Empty client stream")
            return

        try:
            self.handle_command(received, stream)
        except OspdCommandError as error:
            response = error.as_xml()
            logger.debug('Command error: %s', error.message)
        except Exception:  # pylint: disable=broad-except
            logger.exception('While handling client command:')
            fatal = OspdCommandError('Fatal error', 'error')
            response = fatal.as_xml()
        else:
            response = None

        if response:
            stream.write(response)

        stream.close()
510
511
    def process_finished_hosts(self, scan_id: str) -> None:
512
        """Process the finished hosts before launching the scans."""
513
514
        finished_hosts = self.scan_collection.get_finished_hosts(scan_id)
515
        if not finished_hosts:
516
            return
517
518
        exc_finished_hosts_list = target_str_to_list(finished_hosts)
519
        self.scan_collection.set_host_finished(scan_id, exc_finished_hosts_list)
520
521
    def start_scan(self, scan_id: str) -> None:
        """Run scan *scan_id* and settle its final status.

        Calls os.setsid() first, then marks already-finished hosts, sets the
        status to RUNNING and calls exec_scan(). An exception from
        exec_scan() is recorded as a scan error rather than propagated.
        Afterwards the overall progress is recalculated. Unless the scan was
        STOPPED in the meantime, it is finished when progress reached
        ScanProgress.FINISHED and interrupted otherwise.
        """
        os.setsid()

        self.process_finished_hosts(scan_id)

        try:
            self.set_scan_status(scan_id, ScanStatus.RUNNING)
            self.exec_scan(scan_id)
        except Exception as e:  # pylint: disable=broad-except
            self.add_scan_error(
                scan_id,
                name='',
                host=self.get_scan_host(scan_id),
                value='Host process failure (%s).' % e,
            )
            logger.exception('%s: Exception %s while scanning', scan_id, e)
        else:
            logger.info("%s: Host scan finished.", scan_id)

        status = self.get_scan_status(scan_id)
        was_stopped = status == ScanStatus.STOPPED
        self.set_scan_progress(scan_id)
        progress = self.get_scan_progress(scan_id)

        if not was_stopped:
            if progress == ScanProgress.FINISHED:
                self.finish_scan(scan_id)
            else:
                logger.info(
                    "%s: Host scan got interrupted. Progress: %d, Status: %s",
                    scan_id,
                    progress,
                    status.name,
                )
                self.interrupt_scan(scan_id)

        # Logs the final raw progress; the returned dict is not needed here.
        self._get_scan_progress_raw(scan_id)
558
559
    def handle_timeout(self, scan_id: str, host: str) -> None:
560
        """Handles scanner reaching timeout error."""
561
        self.add_scan_error(
562
            scan_id,
563
            host=host,
564
            name="Timeout",
565
            value="{0} exec timeout.".format(self.get_scanner_name()),
566
        )
567
568
    def sort_host_finished(
        self,
        scan_id: str,
        finished_hosts: Union[List[str], str],
    ) -> None:
        """Classify finished hosts as alive or dead and update the scan.

        A host whose current target progress is ScanProgress.FINISHED is
        counted alive, one at ScanProgress.DEAD_HOST is counted dead; any
        other state is only logged. All given hosts are then removed from
        the current target progress.

        Args:
            scan_id: scan to update.
            finished_hosts: a single host or a list of hosts.
        """
        if isinstance(finished_hosts, str):
            finished_hosts = [finished_hosts]

        alive_hosts = []
        dead_hosts = []

        current_hosts = self.scan_collection.get_current_target_progress(
            scan_id
        )
        for finished_host in finished_hosts:
            progress = current_hosts.get(finished_host)
            if progress == ScanProgress.FINISHED:
                alive_hosts.append(finished_host)
            elif progress == ScanProgress.DEAD_HOST:
                dead_hosts.append(finished_host)
            else:
                # Use %s, not %d: progress is None for a host missing from
                # the current target progress, and %d would make the
                # logging call fail to format the message.
                logger.debug(
                    'The host %s is considered dead or finished, but '
                    'its progress is still %s. This can lead to '
                    'interrupted scan.',
                    finished_host,
                    progress,
                )

        self.scan_collection.set_host_dead(scan_id, dead_hosts)

        self.scan_collection.set_host_finished(scan_id, alive_hosts)

        self.scan_collection.remove_hosts_from_target_progress(
            scan_id, finished_hosts
        )
606
607
    def set_scan_progress(self, scan_id: str):
        """Recompute the overall progress of *scan_id* and store it.

        The value comes from the scan collection's calculation over the
        current host states.
        """
        logger.debug("Calculating scan progress with the following data:")
        # Called only for its debug logging of the raw progress data.
        self._get_scan_progress_raw(scan_id)

        collection = self.scan_collection
        new_progress = collection.calculate_target_progress(scan_id)
        collection.set_progress(scan_id, new_progress)
616
617
    def set_scan_progress_batch(
618
        self, scan_id: str, host_progress: Dict[str, int]
619
    ):
620
        self.scan_collection.set_host_progress(scan_id, host_progress)
621
        self.set_scan_progress(scan_id)
622
623
    def set_scan_host_progress(
624
        self,
625
        scan_id: str,
626
        host: str = None,
627
        progress: int = None,
628
    ) -> None:
629
        """Sets host's progress which is part of target.
630
        Each time a host progress is updated, the scan progress
631
        is updated too.
632
        """
633
        if host is None or progress is None:
634
            return
635
636
        if not isinstance(progress, int):
637
            try:
638
                progress = int(progress)
639
            except (TypeError, ValueError):
640
                return
641
642
        host_progress = {host: progress}
643
        self.set_scan_progress_batch(scan_id, host_progress)
644
645
    def get_scan_host_progress(
646
        self,
647
        scan_id: str,
648
        host: str = None,
649
    ) -> int:
650
        """Get host's progress which is part of target."""
651
        current_progress = self.scan_collection.get_current_target_progress(
652
            scan_id
653
        )
654
        return current_progress.get(host)
655
656
    def set_scan_status(self, scan_id: str, status: ScanStatus) -> None:
        """Store *status* as the status of scan *scan_id*."""
        logger.debug('%s: Set scan status %s,', scan_id, status.name)
        collection = self.scan_collection
        collection.set_status(scan_id, status)
660
661
    def get_scan_status(self, scan_id: str) -> ScanStatus:
        """Return the current status of scan *scan_id* and log it."""
        current = self.scan_collection.get_status(scan_id)
        logger.debug('%s: Current scan status: %s,', scan_id, current.name)
        return current
666
667
    def scan_exists(self, scan_id: str) -> bool:
668
        """Checks if a scan with ID scan_id is in collection.
669
670
        Returns:
671
            1 if scan exists, 0 otherwise.
672
        """
673
        return self.scan_collection.id_exists(scan_id)
674
675
    def get_help_text(self) -> str:
676
        """Returns the help output in plain text format."""
677
678
        txt = ''
679
        for name, info in self.commands.items():
680
            description = info.get_description()
681
            attributes = info.get_attributes()
682
            elements = info.get_elements()
683
684
            command_txt = "\t{0: <22} {1}\n".format(name, description)
685
686
            if attributes:
687
                command_txt = ''.join([command_txt, "\t Attributes:\n"])
688
689
                for attrname, attrdesc in attributes.items():
690
                    attr_txt = "\t  {0: <22} {1}\n".format(attrname, attrdesc)
691
                    command_txt = ''.join([command_txt, attr_txt])
692
693
            if elements:
694
                command_txt = ''.join(
695
                    [
696
                        command_txt,
697
                        "\t Elements:\n",
698
                        elements_as_text(elements),
699
                    ]
700
                )
701
702
            txt += command_txt
703
704
        return txt
705
706
    def delete_scan(self, scan_id: str) -> int:
        """Delete scan *scan_id* from the collection.

        A RUNNING scan is never deleted. Otherwise, if a process is recorded
        for the scan, this waits (without timeout) for it to exit and then
        forgets it before deleting the scan.

        Returns:
            0 if the scan is running, else the value returned by
            ``ScanCollection.delete_scan`` (documented as 1 if the scan was
            deleted, 0 otherwise).
        """
        if self.get_scan_status(scan_id) == ScanStatus.RUNNING:
            return 0

        exitcode = None
        try:
            process = self.scan_processes[scan_id]
            # Don't delete the scan until its process has stopped.
            process.join()
            exitcode = process.exitcode
        except KeyError:
            logger.debug('Scan process for %s never started,', scan_id)

        if exitcode is not None:
            del self.scan_processes[scan_id]

        return self.scan_collection.delete_scan(scan_id)
727
728
    def get_scan_results_xml(
        self, scan_id: str, pop_res: bool, max_res: Optional[int]
    ):
        """Build a ``<results>`` element holding the results of *scan_id*.

        *pop_res* and *max_res* are passed to the collection's results
        iterator; each result is converted with get_result_xml().

        Returns:
            The ``<results>`` Element.
        """
        result_elements = [
            get_result_xml(result)
            for result in self.scan_collection.results_iterator(
                scan_id, pop_res, max_res
            )
        ]
        results = Element('results')
        results.extend(result_elements)

        logger.debug('Returning %d results', len(results))
        return results
744
745
    def _get_scan_progress_raw(self, scan_id: str) -> Dict:
        """Collect the raw progress figures of *scan_id* and log them.

        Returns:
            Dict with the keys ``current_hosts`` (current per-host target
            progress), ``overall`` (scan progress), ``count_alive``,
            ``count_dead``, ``count_excluded`` and ``count_total``.
        """
        collection = self.scan_collection
        current_progress = {
            'current_hosts': collection.get_current_target_progress(scan_id),
            'overall': self.get_scan_progress(scan_id),
            'count_alive': collection.get_count_alive(scan_id),
            'count_dead': collection.get_count_dead(scan_id),
            'count_excluded': (
                collection.get_simplified_exclude_host_count(scan_id)
            ),
            'count_total': collection.get_count_total(scan_id),
        }

        # Log via the module logger rather than the root-level
        # ``logging.debug``, so this output honours the ospd logger
        # configuration and never triggers an implicit basicConfig().
        logger.debug(
            "%s: Current progress: \n%s",
            scan_id,
            pformat(current_progress),
        )
        return current_progress
772
773
    def _get_scan_progress_xml(self, scan_id: str):
774
        """Gets scan_id scan's progress in XML format.
775
776
        Returns:
777
            String of scan progress in xml.
778
        """
779
        current_progress = self._get_scan_progress_raw(scan_id)
780
        return get_progress_xml(current_progress)
781
782
    def get_scan_xml(
783
        self,
784
        scan_id: str,
785
        detailed: bool = True,
786
        pop_res: bool = False,
787
        max_res: int = 0,
788
        progress: bool = False,
789
    ):
790
        """Gets scan in XML format.
791
792
        Returns:
793
            String of scan in XML format.
794
        """
795
        if not scan_id:
796
            return Element('scan')
797
798
        if self.get_scan_status(scan_id) == ScanStatus.QUEUED:
799
            target = ''
800
            scan_progress = 0
801
            status = self.get_scan_status(scan_id)
802
            start_time = 0
803
            end_time = 0
804
            response = Element('scan')
805
            detailed = False
806
            progress = False
807
            response.append(Element('results'))
808
        else:
809
            target = self.get_scan_host(scan_id)
810
            scan_progress = self.get_scan_progress(scan_id)
811
            status = self.get_scan_status(scan_id)
812
            start_time = self.get_scan_start_time(scan_id)
813
            end_time = self.get_scan_end_time(scan_id)
814
            response = Element('scan')
815
816
        for name, value in [
817
            ('id', scan_id),
818
            ('target', target),
819
            ('progress', scan_progress),
820
            ('status', status.name.lower()),
821
            ('start_time', start_time),
822
            ('end_time', end_time),
823
        ]:
824
            response.set(name, str(value))
825
        if detailed:
826
            response.append(
827
                self.get_scan_results_xml(scan_id, pop_res, max_res)
828
            )
829
        if progress:
830
            response.append(self._get_scan_progress_xml(scan_id))
831
832
        return response
833
834
    @staticmethod
835
    def get_custom_vt_as_xml_str(  # pylint: disable=unused-argument
836
        vt_id: str, custom: Dict
837
    ) -> str:
838
        """Create a string representation of the XML object from the
839
        custom data object.
840
        This needs to be implemented by each ospd wrapper, in case
841
        custom elements for VTs are used.
842
843
        The custom XML object which is returned will be embedded
844
        into a <custom></custom> element.
845
846
        Returns:
847
            XML object as string for custom data.
848
        """
849
        return ''
850
851
    @staticmethod
852
    def get_params_vt_as_xml_str(  # pylint: disable=unused-argument
853
        vt_id: str, vt_params
854
    ) -> str:
855
        """Create a string representation of the XML object from the
856
        vt_params data object.
857
        This needs to be implemented by each ospd wrapper, in case
858
        vt_params elements for VTs are used.
859
860
        The params XML object which is returned will be embedded
861
        into a <params></params> element.
862
863
        Returns:
864
            XML object as string for vt parameters data.
865
        """
866
        return ''
867
868
    @staticmethod
869
    def get_refs_vt_as_xml_str(  # pylint: disable=unused-argument
870
        vt_id: str, vt_refs
871
    ) -> str:
872
        """Create a string representation of the XML object from the
873
        refs data object.
874
        This needs to be implemented by each ospd wrapper, in case
875
        refs elements for VTs are used.
876
877
        The refs XML object which is returned will be embedded
878
        into a <refs></refs> element.
879
880
        Returns:
881
            XML object as string for vt references data.
882
        """
883
        return ''
884
885
    @staticmethod
886
    def get_dependencies_vt_as_xml_str(  # pylint: disable=unused-argument
887
        vt_id: str, vt_dependencies
888
    ) -> str:
889
        """Create a string representation of the XML object from the
890
        vt_dependencies data object.
891
        This needs to be implemented by each ospd wrapper, in case
892
        vt_dependencies elements for VTs are used.
893
894
        The vt_dependencies XML object which is returned will be embedded
895
        into a <dependencies></dependencies> element.
896
897
        Returns:
898
            XML object as string for vt dependencies data.
899
        """
900
        return ''
901
902
    @staticmethod
903
    def get_creation_time_vt_as_xml_str(  # pylint: disable=unused-argument
904
        vt_id: str, vt_creation_time
905
    ) -> str:
906
        """Create a string representation of the XML object from the
907
        vt_creation_time data object.
908
        This needs to be implemented by each ospd wrapper, in case
909
        vt_creation_time elements for VTs are used.
910
911
        The vt_creation_time XML object which is returned will be embedded
912
        into a <vt_creation_time></vt_creation_time> element.
913
914
        Returns:
915
            XML object as string for vt creation time data.
916
        """
917
        return ''
918
919
    @staticmethod
920
    def get_modification_time_vt_as_xml_str(  # pylint: disable=unused-argument
921
        vt_id: str, vt_modification_time
922
    ) -> str:
923
        """Create a string representation of the XML object from the
924
        vt_modification_time data object.
925
        This needs to be implemented by each ospd wrapper, in case
926
        vt_modification_time elements for VTs are used.
927
928
        The vt_modification_time XML object which is returned will be embedded
929
        into a <vt_modification_time></vt_modification_time> element.
930
931
        Returns:
932
            XML object as string for vt references data.
933
        """
934
        return ''
935
936
    @staticmethod
937
    def get_summary_vt_as_xml_str(  # pylint: disable=unused-argument
938
        vt_id: str, summary
939
    ) -> str:
940
        """Create a string representation of the XML object from the
941
        summary data object.
942
        This needs to be implemented by each ospd wrapper, in case
943
        summary elements for VTs are used.
944
945
        The summary XML object which is returned will be embedded
946
        into a <summary></summary> element.
947
948
        Returns:
949
            XML object as string for summary data.
950
        """
951
        return ''
952
953
    @staticmethod
954
    def get_impact_vt_as_xml_str(  # pylint: disable=unused-argument
955
        vt_id: str, impact
956
    ) -> str:
957
        """Create a string representation of the XML object from the
958
        impact data object.
959
        This needs to be implemented by each ospd wrapper, in case
960
        impact elements for VTs are used.
961
962
        The impact XML object which is returned will be embedded
963
        into a <impact></impact> element.
964
965
        Returns:
966
            XML object as string for impact data.
967
        """
968
        return ''
969
970
    @staticmethod
971
    def get_affected_vt_as_xml_str(  # pylint: disable=unused-argument
972
        vt_id: str, affected
973
    ) -> str:
974
        """Create a string representation of the XML object from the
975
        affected data object.
976
        This needs to be implemented by each ospd wrapper, in case
977
        affected elements for VTs are used.
978
979
        The affected XML object which is returned will be embedded
980
        into a <affected></affected> element.
981
982
        Returns:
983
            XML object as string for affected data.
984
        """
985
        return ''
986
987
    @staticmethod
988
    def get_insight_vt_as_xml_str(  # pylint: disable=unused-argument
989
        vt_id: str, insight
990
    ) -> str:
991
        """Create a string representation of the XML object from the
992
        insight data object.
993
        This needs to be implemented by each ospd wrapper, in case
994
        insight elements for VTs are used.
995
996
        The insight XML object which is returned will be embedded
997
        into a <insight></insight> element.
998
999
        Returns:
1000
            XML object as string for insight data.
1001
        """
1002
        return ''
1003
1004
    @staticmethod
1005
    def get_solution_vt_as_xml_str(  # pylint: disable=unused-argument
1006
        vt_id: str, solution, solution_type=None, solution_method=None
1007
    ) -> str:
1008
        """Create a string representation of the XML object from the
1009
        solution data object.
1010
        This needs to be implemented by each ospd wrapper, in case
1011
        solution elements for VTs are used.
1012
1013
        The solution XML object which is returned will be embedded
1014
        into a <solution></solution> element.
1015
1016
        Returns:
1017
            XML object as string for solution data.
1018
        """
1019
        return ''
1020
1021
    @staticmethod
1022
    def get_detection_vt_as_xml_str(  # pylint: disable=unused-argument
1023
        vt_id: str, detection=None, qod_type=None, qod=None
1024
    ) -> str:
1025
        """Create a string representation of the XML object from the
1026
        detection data object.
1027
        This needs to be implemented by each ospd wrapper, in case
1028
        detection elements for VTs are used.
1029
1030
        The detection XML object which is returned is an element with
1031
        tag <detection></detection> element
1032
1033
        Returns:
1034
            XML object as string for detection data.
1035
        """
1036
        return ''
1037
1038
    @staticmethod
1039
    def get_severities_vt_as_xml_str(  # pylint: disable=unused-argument
1040
        vt_id: str, severities
1041
    ) -> str:
1042
        """Create a string representation of the XML object from the
1043
        severities data object.
1044
        This needs to be implemented by each ospd wrapper, in case
1045
        severities elements for VTs are used.
1046
1047
        The severities XML objects which are returned will be embedded
1048
        into a <severities></severities> element.
1049
1050
        Returns:
1051
            XML object as string for severities data.
1052
        """
1053
        return ''
1054
1055
    def get_vt_iterator(  # pylint: disable=unused-argument
1056
        self, vt_selection: List[str] = None, details: bool = True
1057
    ) -> Iterator[Tuple[str, Dict]]:
1058
        """Return iterator object for getting elements
1059
        from the VTs dictionary."""
1060
        return self.vts.items()
1061
1062
    def get_vt_xml(self, single_vt: Tuple[str, Dict]) -> Element:
1063
        """Gets a single vulnerability test information in XML format.
1064
1065
        Returns:
1066
            String of single vulnerability test information in XML format.
1067
        """
1068
        if not single_vt or single_vt[1] is None:
1069
            return Element('vt')
1070
1071
        vt_id, vt = single_vt
1072
1073
        name = vt.get('name')
1074
        vt_xml = Element('vt')
1075
        vt_xml.set('id', vt_id)
1076
1077
        for name, value in [('name', name)]:
1078
            elem = SubElement(vt_xml, name)
1079
            elem.text = str(value)
1080
1081
        if vt.get('vt_params'):
1082
            params_xml_str = self.get_params_vt_as_xml_str(
1083
                vt_id, vt.get('vt_params')
1084
            )
1085
            vt_xml.append(secET.fromstring(params_xml_str))
1086
1087
        if vt.get('vt_refs'):
1088
            refs_xml_str = self.get_refs_vt_as_xml_str(vt_id, vt.get('vt_refs'))
1089
            vt_xml.append(secET.fromstring(refs_xml_str))
1090
1091
        if vt.get('vt_dependencies'):
1092
            dependencies = self.get_dependencies_vt_as_xml_str(
1093
                vt_id, vt.get('vt_dependencies')
1094
            )
1095
            vt_xml.append(secET.fromstring(dependencies))
1096
1097
        if vt.get('creation_time'):
1098
            vt_ctime = self.get_creation_time_vt_as_xml_str(
1099
                vt_id, vt.get('creation_time')
1100
            )
1101
            vt_xml.append(secET.fromstring(vt_ctime))
1102
1103
        if vt.get('modification_time'):
1104
            vt_mtime = self.get_modification_time_vt_as_xml_str(
1105
                vt_id, vt.get('modification_time')
1106
            )
1107
            vt_xml.append(secET.fromstring(vt_mtime))
1108
1109
        if vt.get('summary'):
1110
            summary_xml_str = self.get_summary_vt_as_xml_str(
1111
                vt_id, vt.get('summary')
1112
            )
1113
            vt_xml.append(secET.fromstring(summary_xml_str))
1114
1115
        if vt.get('impact'):
1116
            impact_xml_str = self.get_impact_vt_as_xml_str(
1117
                vt_id, vt.get('impact')
1118
            )
1119
            vt_xml.append(secET.fromstring(impact_xml_str))
1120
1121
        if vt.get('affected'):
1122
            affected_xml_str = self.get_affected_vt_as_xml_str(
1123
                vt_id, vt.get('affected')
1124
            )
1125
            vt_xml.append(secET.fromstring(affected_xml_str))
1126
1127
        if vt.get('insight'):
1128
            insight_xml_str = self.get_insight_vt_as_xml_str(
1129
                vt_id, vt.get('insight')
1130
            )
1131
            vt_xml.append(secET.fromstring(insight_xml_str))
1132
1133
        if vt.get('solution'):
1134
            solution_xml_str = self.get_solution_vt_as_xml_str(
1135
                vt_id,
1136
                vt.get('solution'),
1137
                vt.get('solution_type'),
1138
                vt.get('solution_method'),
1139
            )
1140
            vt_xml.append(secET.fromstring(solution_xml_str))
1141
1142
        if vt.get('detection') or vt.get('qod_type') or vt.get('qod'):
1143
            detection_xml_str = self.get_detection_vt_as_xml_str(
1144
                vt_id, vt.get('detection'), vt.get('qod_type'), vt.get('qod')
1145
            )
1146
            vt_xml.append(secET.fromstring(detection_xml_str))
1147
1148
        if vt.get('severities'):
1149
            severities_xml_str = self.get_severities_vt_as_xml_str(
1150
                vt_id, vt.get('severities')
1151
            )
1152
            vt_xml.append(secET.fromstring(severities_xml_str))
1153
1154
        if vt.get('custom'):
1155
            custom_xml_str = self.get_custom_vt_as_xml_str(
1156
                vt_id, vt.get('custom')
1157
            )
1158
            vt_xml.append(secET.fromstring(custom_xml_str))
1159
1160
        return vt_xml
1161
1162
    def get_vts_selection_list(
1163
        self, vt_id: str = None, filtered_vts: Dict = None
1164
    ) -> Iterable[str]:
1165
        """
1166
        Get list of VT's OID.
1167
        If vt_id is specified, the collection will contain only this vt, if
1168
        found.
1169
        If no vt_id is specified or filtered_vts is None (default), the
1170
        collection will contain all vts. Otherwise those vts passed
1171
        in filtered_vts or vt_id are returned. In case of both vt_id and
1172
        filtered_vts are given, filtered_vts has priority.
1173
1174
        Arguments:
1175
            vt_id (vt_id, optional): ID of the vt to get.
1176
            filtered_vts (list, optional): Filtered VTs collection.
1177
1178
        Returns:
1179
            List of selected VT's OID.
1180
        """
1181
        vts_xml = []
1182
1183
        # No match for the filter
1184
        if filtered_vts is not None and len(filtered_vts) == 0:
1185
            return vts_xml
1186
1187
        if filtered_vts:
1188
            vts_list = filtered_vts
1189
        elif vt_id:
1190
            vts_list = [vt_id]
1191
        else:
1192
            vts_list = self.vts.keys()
1193
1194
        return vts_list
1195
1196
    def handle_command(self, data: bytes, stream: Stream) -> None:
1197
        """Handles an osp command in a string."""
1198
        try:
1199
            tree = secET.fromstring(data)
1200
        except secET.ParseError as e:
1201
            logger.debug("Erroneous client input: %s", data)
1202
            raise OspdCommandError('Invalid data') from e
1203
1204
        command_name = tree.tag
1205
1206
        logger.debug('Handling %s command request.', command_name)
1207
1208
        command = self.commands.get(command_name, None)
1209
        if not command and command_name != "authenticate":
1210
            raise OspdCommandError('Bogus command name')
1211
1212
        if not self.initialized and command.must_be_initialized:
1213
            exception = OspdCommandError(
1214
                '%s is still starting' % self.daemon_info['name'], 'error'
1215
            )
1216
            response = exception.as_xml()
1217
            stream.write(response)
1218
            return
1219
1220
        response = command.handle_xml(tree)
1221
1222
        write_success = True
1223
        if isinstance(response, bytes):
1224
            write_success = stream.write(response)
1225
        else:
1226
            for data in response:
1227
                write_success = stream.write(data)
1228
                if not write_success:
1229
                    break
1230
1231
        scan_id = tree.get('scan_id')
1232
        if self.scan_exists(scan_id) and command_name == "get_scans":
1233
            if write_success:
1234
                logger.debug(
1235
                    '%s: Results sent successfully to the client. Cleaning '
1236
                    'temporary result list.',
1237
                    scan_id,
1238
                )
1239
                self.scan_collection.clean_temp_result_list(scan_id)
1240
            else:
1241
                logger.debug(
1242
                    '%s: Failed sending results to the client. Restoring '
1243
                    'result list into the cache.',
1244
                    scan_id,
1245
                )
1246
                self.scan_collection.restore_temp_result_list(scan_id)
1247
1248
    def check(self):
1249
        """Asserts to False. Should be implemented by subclass."""
1250
        raise NotImplementedError
1251
1252
    def run(self) -> None:
1253
        """Starts the Daemon, handling commands until interrupted."""
1254
1255
        try:
1256
            while True:
1257
                time.sleep(SCHEDULER_CHECK_PERIOD)
1258
                self.scheduler()
1259
                self.clean_forgotten_scans()
1260
                self.start_queued_scans()
1261
                self.wait_for_children()
1262
        except KeyboardInterrupt:
1263
            logger.info("Received Ctrl-C shutting-down ...")
1264
1265
    def start_queued_scans(self) -> None:
1266
        """Starts a queued scan if it is allowed"""
1267
1268
        current_queued_scans = self.get_count_queued_scans()
1269
        if not current_queued_scans:
1270
            return
1271
1272
        if not self.initialized:
1273
            logger.info(
1274
                "Queued task can not be started because a feed "
1275
                "update is being performed."
1276
            )
1277
            return
1278
1279
        logger.info('Currently %d queued scans.', current_queued_scans)
1280
1281
        for scan_id in self.scan_collection.ids_iterator():
1282
            scan_allowed = (
1283
                self.is_new_scan_allowed() and self.is_enough_free_memory()
1284
            )
1285
            scan_is_queued = self.get_scan_status(scan_id) == ScanStatus.QUEUED
1286
1287
            if scan_is_queued and scan_allowed:
1288
                try:
1289
                    self.scan_collection.unpickle_scan_info(scan_id)
1290
                except OspdCommandError as e:
1291
                    logger.error("Start scan error %s", e)
1292
                    self.stop_scan(scan_id)
1293
                    continue
1294
1295
                scan_func = self.start_scan
1296
                scan_process = create_process(func=scan_func, args=(scan_id,))
1297
                self.scan_processes[scan_id] = scan_process
1298
                scan_process.start()
1299
                self.set_scan_status(scan_id, ScanStatus.INIT)
1300
1301
                current_queued_scans = current_queued_scans - 1
1302
                self.last_scan_start_time = time.time()
1303
                logger.info('Starting scan %s.', scan_id)
1304
            elif scan_is_queued and not scan_allowed:
1305
                return
1306
1307
    def is_new_scan_allowed(self) -> bool:
1308
        """Check if max_scans has been reached.
1309
1310
        Returns:
1311
            True if a new scan can be launch.
1312
        """
1313
        if (self.max_scans != 0) and (
1314
            len(self.scan_processes) >= self.max_scans
1315
        ):
1316
            logger.info(
1317
                'Not possible to run a new scan. Max scan limit set '
1318
                'to %d reached.',
1319
                self.max_scans,
1320
            )
1321
            return False
1322
1323
        return True
1324
1325
    def is_enough_free_memory(self) -> bool:
1326
        """Check if there is enough free memory in the system to run
1327
        a new scan. The necessary memory is a rough calculation and very
1328
        conservative.
1329
1330
        Returns:
1331
            True if there is enough memory for a new scan.
1332
        """
1333
        if not self.min_free_mem_scan_queue:
1334
            return True
1335
1336
        # If min_free_mem_scan_queue option is set, also wait some time
1337
        # between scans. Consider the case in which the last scan
1338
        # finished in a few seconds and there is no need to wait.
1339
        time_between_start_scan = time.time() - self.last_scan_start_time
1340
        if (
1341
            time_between_start_scan < MIN_TIME_BETWEEN_START_SCAN
1342
            and self.get_count_running_scans()
1343
        ):
1344
            logger.debug(
1345
                'Not possible to run a new scan right now, a scan have been '
1346
                'just started.'
1347
            )
1348
            return False
1349
1350
        free_mem = psutil.virtual_memory().available / (1024 * 1024)
1351
1352
        if free_mem > self.min_free_mem_scan_queue:
1353
            return True
1354
1355
        logger.info(
1356
            'Not possible to run a new scan. Not enough free memory. '
1357
            'Only %d MB available but at least %d are required',
1358
            free_mem,
1359
            self.min_free_mem_scan_queue,
1360
        )
1361
1362
        return False
1363
1364
    def scheduler(self):
1365
        """Should be implemented by subclass in case of need
1366
        to run tasks periodically."""
1367
1368
    def wait_for_children(self):
1369
        """Join the zombie process to releases resources."""
1370
        for _, process in self.scan_processes.items():
1371
            process.join(0)
1372
1373
    def create_scan(
1374
        self,
1375
        scan_id: str,
1376
        targets: Dict,
1377
        options: Optional[Dict],
1378
        vt_selection: Dict,
1379
    ) -> Optional[str]:
1380
        """Creates a new scan.
1381
1382
        Arguments:
1383
            target: Target to scan.
1384
            options: Miscellaneous scan options supplied via <scanner_params>
1385
                  XML element.
1386
1387
        Returns:
1388
            New scan's ID. None if the scan_id already exists.
1389
        """
1390
        status = None
1391
        scan_exists = self.scan_exists(scan_id)
1392
        if scan_id and scan_exists:
1393
            status = self.get_scan_status(scan_id)
1394
            logger.info(
1395
                "Scan %s exists with status %s.", scan_id, status.name.lower()
1396
            )
1397
            return
1398
1399
        return self.scan_collection.create_scan(
1400
            scan_id, targets, options, vt_selection
1401
        )
1402
1403
    def get_scan_options(self, scan_id: str) -> str:
1404
        """Gives a scan's list of options."""
1405
        return self.scan_collection.get_options(scan_id)
1406
1407
    def set_scan_option(self, scan_id: str, name: str, value: Any) -> None:
1408
        """Sets a scan's option to a provided value."""
1409
        return self.scan_collection.set_option(scan_id, name, value)
1410
1411
    def set_scan_total_hosts(self, scan_id: str, count_total: int) -> None:
1412
        """Sets a scan's total hosts. Allow the scanner to update
1413
        the total count of host to be scanned."""
1414
        self.scan_collection.update_count_total(scan_id, count_total)
1415
1416
    def clean_forgotten_scans(self) -> None:
1417
        """Check for old stopped or finished scans which have not been
1418
        deleted and delete them if the are older than the set value."""
1419
1420
        if not self.scaninfo_store_time:
1421
            return
1422
1423
        for scan_id in list(self.scan_collection.ids_iterator()):
1424
            end_time = int(self.get_scan_end_time(scan_id))
1425
            scan_status = self.get_scan_status(scan_id)
1426
1427
            if (
1428
                scan_status == ScanStatus.STOPPED
1429
                or scan_status == ScanStatus.FINISHED
1430
                or scan_status == ScanStatus.INTERRUPTED
1431
            ) and end_time:
1432
                stored_time = int(time.time()) - end_time
1433
                if stored_time > self.scaninfo_store_time * 3600:
1434
                    logger.debug(
1435
                        'Scan %s is older than %d hours and seems have been '
1436
                        'forgotten. Scan info will be deleted from the '
1437
                        'scan table',
1438
                        scan_id,
1439
                        self.scaninfo_store_time,
1440
                    )
1441
                    self.delete_scan(scan_id)
1442
1443
    def check_scan_process(self, scan_id: str) -> None:
1444
        """Check the scan's process, and terminate the scan if not alive."""
1445
        status = self.get_scan_status(scan_id)
1446
        if status == ScanStatus.QUEUED:
1447
            return
1448
1449
        scan_process = self.scan_processes.get(scan_id)
1450
        progress = self.get_scan_progress(scan_id)
1451
1452
        if (
1453
            progress < ScanProgress.FINISHED
1454
            and scan_process
1455
            and not scan_process.is_alive()
1456
        ):
1457
            if not status == ScanStatus.STOPPED:
1458
                self.add_scan_error(
1459
                    scan_id, name="", host="", value="Scan process Failure"
1460
                )
1461
1462
                logger.info(
1463
                    "%s: Scan process is dead and its progress is %d",
1464
                    scan_id,
1465
                    progress,
1466
                )
1467
                self.interrupt_scan(scan_id)
1468
1469
        elif progress == ScanProgress.FINISHED:
1470
            scan_process.join(0)
1471
1472
        logger.debug(
1473
            "%s: Check scan process: \n\tProgress %d\n\t Status: %s",
1474
            scan_id,
1475
            progress,
1476
            status.name,
1477
        )
1478
1479
    def get_count_queued_scans(self) -> int:
1480
        """Get the amount of scans with queued status"""
1481
        count = 0
1482
        for scan_id in self.scan_collection.ids_iterator():
1483
            if self.get_scan_status(scan_id) == ScanStatus.QUEUED:
1484
                count += 1
1485
        return count
1486
1487
    def get_count_running_scans(self) -> int:
1488
        """Get the amount of scans with INIT/RUNNING status"""
1489
        count = 0
1490
        for scan_id in self.scan_collection.ids_iterator():
1491
            status = self.get_scan_status(scan_id)
1492
            if status == ScanStatus.RUNNING or status == ScanStatus.INIT:
1493
                count += 1
1494
        return count
1495
1496
    def get_scan_progress(self, scan_id: str) -> int:
1497
        """Gives a scan's current progress value."""
1498
        progress = self.scan_collection.get_progress(scan_id)
1499
        logger.debug('%s: Current scan progress: %s,', scan_id, progress)
1500
        return progress
1501
1502
    def get_scan_host(self, scan_id: str) -> str:
1503
        """Gives a scan's target."""
1504
        return self.scan_collection.get_host_list(scan_id)
1505
1506
    def get_scan_ports(self, scan_id: str) -> str:
1507
        """Gives a scan's ports list."""
1508
        return self.scan_collection.get_ports(scan_id)
1509
1510
    def get_scan_exclude_hosts(self, scan_id: str):
1511
        """Gives a scan's exclude host list. If a target is passed gives
1512
        the exclude host list for the given target."""
1513
        return self.scan_collection.get_exclude_hosts(scan_id)
1514
1515
    def get_scan_credentials(self, scan_id: str) -> Dict:
1516
        """Gives a scan's credential list. If a target is passed gives
1517
        the credential list for the given target."""
1518
        return self.scan_collection.get_credentials(scan_id)
1519
1520
    def get_scan_target_options(self, scan_id: str) -> Dict:
1521
        """Gives a scan's target option dict. If a target is passed gives
1522
        the credential list for the given target."""
1523
        return self.scan_collection.get_target_options(scan_id)
1524
1525
    def get_scan_vts(self, scan_id: str) -> Dict:
1526
        """Gives a scan's vts."""
1527
        return self.scan_collection.get_vts(scan_id)
1528
1529
    def get_scan_start_time(self, scan_id: str) -> str:
1530
        """Gives a scan's start time."""
1531
        return self.scan_collection.get_start_time(scan_id)
1532
1533
    def get_scan_end_time(self, scan_id: str) -> str:
1534
        """Gives a scan's end time."""
1535
        return self.scan_collection.get_end_time(scan_id)
1536
1537
    def add_scan_log(
1538
        self,
1539
        scan_id: str,
1540
        host: str = '',
1541
        hostname: str = '',
1542
        name: str = '',
1543
        value: str = '',
1544
        port: str = '',
1545
        test_id: str = '',
1546
        qod: str = '',
1547
        uri: str = '',
1548
    ) -> None:
1549
        """Adds a log result to scan_id scan."""
1550
1551
        self.scan_collection.add_result(
1552
            scan_id,
1553
            ResultType.LOG,
1554
            host,
1555
            hostname,
1556
            name,
1557
            value,
1558
            port,
1559
            test_id,
1560
            '0.0',
1561
            qod,
1562
            uri,
1563
        )
1564
1565
    def add_scan_error(
1566
        self,
1567
        scan_id: str,
1568
        host: str = '',
1569
        hostname: str = '',
1570
        name: str = '',
1571
        value: str = '',
1572
        port: str = '',
1573
        test_id='',
1574
        uri: str = '',
1575
    ) -> None:
1576
        """Adds an error result to scan_id scan."""
1577
        self.scan_collection.add_result(
1578
            scan_id,
1579
            ResultType.ERROR,
1580
            host,
1581
            hostname,
1582
            name,
1583
            value,
1584
            port,
1585
            test_id,
1586
            uri,
1587
        )
1588
1589
    def add_scan_host_detail(
1590
        self,
1591
        scan_id: str,
1592
        host: str = '',
1593
        hostname: str = '',
1594
        name: str = '',
1595
        value: str = '',
1596
        uri: str = '',
1597
    ) -> None:
1598
        """Adds a host detail result to scan_id scan."""
1599
        self.scan_collection.add_result(
1600
            scan_id, ResultType.HOST_DETAIL, host, hostname, name, value, uri
1601
        )
1602
1603
    def add_scan_alarm(
1604
        self,
1605
        scan_id: str,
1606
        host: str = '',
1607
        hostname: str = '',
1608
        name: str = '',
1609
        value: str = '',
1610
        port: str = '',
1611
        test_id: str = '',
1612
        severity: str = '',
1613
        qod: str = '',
1614
        uri: str = '',
1615
    ) -> None:
1616
        """Adds an alarm result to scan_id scan."""
1617
        self.scan_collection.add_result(
1618
            scan_id,
1619
            ResultType.ALARM,
1620
            host,
1621
            hostname,
1622
            name,
1623
            value,
1624
            port,
1625
            test_id,
1626
            severity,
1627
            qod,
1628
            uri,
1629
        )
1630