Completed
Push — master ( 49ba58...c1f282 )
by Juan José
13s queued 11s
created

ospd.scan.ScanCollection.update_count_total()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 3
dl 0
loc 4
rs 10
c 0
b 0
f 0
1
# Copyright (C) 2014-2020 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: AGPL-3.0-or-later
4
#
5
# This program is free software: you can redistribute it and/or modify
6
# it under the terms of the GNU Affero General Public License as
7
# published by the Free Software Foundation, either version 3 of the
8
# License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU Affero General Public License for more details.
14
#
15
# You should have received a copy of the GNU Affero General Public License
16
# along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
import logging
19
import multiprocessing
20
import time
21
import uuid
22
23
from collections import OrderedDict
24
from enum import Enum, IntEnum
25
from typing import List, Any, Dict, Iterator, Optional, Iterable, Union
26
27
from ospd.network import target_str_to_list
28
from ospd.datapickler import DataPickler
29
from ospd.errors import OspdCommandError
30
31
LOGGER = logging.getLogger(name=__name__)
32
33
34
class ScanStatus(Enum):
    """States a scan can be in, as stored under a scan's 'status' key."""

    QUEUED = 0  # state assigned by ScanCollection.create_scan
    INIT = 1
    RUNNING = 2  # ScanCollection.delete_scan refuses to delete these
    STOPPED = 3  # ScanCollection.set_status records end_time for this one
    FINISHED = 4
    INTERRUPTED = 5
43
44
45
class ScanProgress(IntEnum):
    """Well-known progress values for a whole scan or a single host.

    Regular progress is a percentage; FINISHED and INIT bound that range,
    while the negative members are sentinels outside of it.
    """

    FINISHED = 100  # upper bound accepted by ScanCollection.set_progress
    INIT = 0  # starting value assigned when scan info is unpickled
    DEAD_HOST = -1
    INTERRUPTED = -2
52
53
54
class ScanCollection:

    """Scans collection, managing scans and results read and write, exposing
    only needed information.

    Each scan has meta-information such as scan ID, current progress (from 0 to
    100), start time, end time, scan target and options and a list of results.

    There are 4 types of results: Alarms, Logs, Errors and Host Details.

    Per-scan data lives in ``scans_table[scan_id]``, a dictionary created with
    ``data_manager.dict()`` (see create_scan). Such a manager proxy hands out
    copies of the lists and dicts stored in it. Every method that modifies a
    nested container therefore reads it, changes the local copy and then
    writes it back under the same key. Only that write-back propagates the
    change to the other processes.

    Todo:
    - Better checking for Scan ID existence and handling otherwise.
    - More data validation.
    - Mutex access per table/scan_info.

    """

    def __init__(self, file_storage_dir: str) -> None:
        """Initialize the Scan Collection.

        Arguments:
            file_storage_dir: Directory passed to DataPickler to store and
                load the pickled scan information.
        """

        self.data_manager = (
            None
        )  # type: Optional[multiprocessing.managers.SyncManager]
        self.scans_table = dict()  # type: Dict
        self.file_storage_dir = file_storage_dir

    def init(self) -> None:
        """Start the multiprocessing manager that create_scan uses to build
        the per-scan shared dictionaries."""
        self.data_manager = multiprocessing.Manager()

    def add_result(
        self,
        scan_id: str,
        result_type: int,
        host: str = '',
        hostname: str = '',
        name: str = '',
        value: str = '',
        port: str = '',
        test_id: str = '',
        severity: str = '',
        qod: str = '',
        uri: str = '',
    ) -> None:
        """Add a result to a scan in the table.

        ``scan_id`` must be non-empty, and at least one of ``name`` or
        ``value`` must be non-empty.
        """

        assert scan_id
        assert len(name) or len(value)

        result = OrderedDict()  # type: Dict
        result['type'] = result_type
        result['name'] = name
        result['severity'] = severity
        result['test_id'] = test_id
        result['value'] = value
        result['host'] = host
        result['hostname'] = hostname
        result['port'] = port
        result['qod'] = qod
        result['uri'] = uri
        results = self.scans_table[scan_id]['results']
        results.append(result)

        # Set scan_info's results to propagate results to parent process.
        self.scans_table[scan_id]['results'] = results

    def add_result_list(
        self, scan_id: str, result_list: Iterable[Dict[str, str]]
    ) -> None:
        """
        Add a batch of results to the result's table for the corresponding
        scan_id
        """
        results = self.scans_table[scan_id]['results']
        results.extend(result_list)

        # Set scan_info's results to propagate results to parent process.
        self.scans_table[scan_id]['results'] = results

    def remove_hosts_from_target_progress(
        self, scan_id: str, hosts: List
    ) -> None:
        """Remove a list of hosts from the main scan progress table to avoid
        the hosts to be included in the calculation of the scan progress"""
        if not hosts:
            return

        target = self.scans_table[scan_id].get('target_progress')
        for host in hosts:
            if host in target:
                del target[host]

        # Set scan_info's target_progress to propagate progresses
        # to parent process.
        self.scans_table[scan_id]['target_progress'] = target

    def set_progress(self, scan_id: str, progress: int) -> None:
        """Set scan_id scan's progress.

        Values outside the range (INIT, FINISHED] are not stored. A progress
        equal to FINISHED also records the scan's end time.
        """

        if ScanProgress.INIT < progress <= ScanProgress.FINISHED:
            self.scans_table[scan_id]['progress'] = progress

        if progress == ScanProgress.FINISHED:
            self.scans_table[scan_id]['end_time'] = int(time.time())

    def set_host_progress(
        self, scan_id: str, host_progress_batch: Dict[str, int]
    ) -> None:
        """Merge a batch of per-host progress values into scan_id scan's
        target progress table."""

        host_progresses = self.scans_table[scan_id].get('target_progress')
        host_progresses.update(host_progress_batch)

        # Set scan_info's target_progress to propagate progresses
        # to parent process.
        self.scans_table[scan_id]['target_progress'] = host_progresses

    def set_host_finished(self, scan_id: str, hosts: List[str]) -> None:
        """ Increase the amount of finished hosts which were alive."""

        total_finished = len(hosts)
        count_alive = (
            self.scans_table[scan_id].get('count_alive') + total_finished
        )
        self.scans_table[scan_id]['count_alive'] = count_alive

    def set_host_dead(self, scan_id: str, hosts: List[str]) -> None:
        """ Increase the amount of dead hosts by the number of hosts given. """

        total_dead = len(hosts)
        count_dead = self.scans_table[scan_id].get('count_dead') + total_dead
        self.scans_table[scan_id]['count_dead'] = count_dead

    def set_amount_dead_hosts(self, scan_id: str, total_dead: int) -> None:
        """ Increase the amount of dead hosts by total_dead. """

        count_dead = self.scans_table[scan_id].get('count_dead') + total_dead
        self.scans_table[scan_id]['count_dead'] = count_dead

    def clean_temp_result_list(self, scan_id: str) -> None:
        """ Clean the results stored in the temporary list. """
        self.scans_table[scan_id]['temp_results'] = list()

    def restore_temp_result_list(self, scan_id: str) -> None:
        """Add the results stored in the temporary list into the results
        list again, then empty the temporary list."""
        result_aux = self.scans_table[scan_id].get('results', list())
        result_aux.extend(self.scans_table[scan_id].get('temp_results', list()))

        # Propagate results
        self.scans_table[scan_id]['results'] = result_aux
        self.clean_temp_result_list(scan_id)

    def results_iterator(
        self,
        scan_id: str,
        pop_res: bool = False,
        max_res: Optional[int] = None,
    ) -> Iterator[Any]:
        """Returns an iterator over scan_id scan's results. If pop_res is True,
        it removes the fetched results from the list.

        If max_res is None, return all the results.
        Otherwise, if max_res = N > 0 return N as maximum number of results.

        max_res works only together with pop_res.

        Popped results are kept in the scan's 'temp_results' list, so they
        can be put back with restore_temp_result_list.
        """
        if pop_res and max_res:
            result_aux = self.scans_table[scan_id].get('results', list())
            self.scans_table[scan_id]['results'] = result_aux[max_res:]
            self.scans_table[scan_id]['temp_results'] = result_aux[:max_res]
            return iter(self.scans_table[scan_id]['temp_results'])

        if pop_res:
            self.scans_table[scan_id]['temp_results'] = self.scans_table[
                scan_id
            ].get('results', list())
            self.scans_table[scan_id]['results'] = list()
            return iter(self.scans_table[scan_id]['temp_results'])

        return iter(self.scans_table[scan_id]['results'])

    def ids_iterator(self) -> Iterator[str]:
        """ Returns an iterator over the collection's scan IDS. """

        # Do not iterate over the scans_table because it can change
        # during iteration, since it is accessed by multiple processes.
        scan_id_list = list(self.scans_table)
        return iter(scan_id_list)

    def clean_up_pickled_scan_info(self) -> None:
        """ Remove the pickled scan info files of all still-queued scans. """
        for scan_id in self.ids_iterator():
            if self.get_status(scan_id) == ScanStatus.QUEUED:
                self.remove_file_pickled_scan_info(scan_id)

    def remove_file_pickled_scan_info(self, scan_id: str) -> None:
        """ Remove the pickled scan info file stored for scan_id. """
        pickler = DataPickler(self.file_storage_dir)
        pickler.remove_file(scan_id)

    def unpickle_scan_info(self, scan_id: str) -> None:
        """Unpickle a stored scan_info corresponding to the scan_id
        and store it in the scan_table.

        Also resets the scan's results, progress and host counters and sets
        its start time. The pickle file is removed afterwards.

        Raises:
            OspdCommandError: if the stored scan info cannot be loaded; the
                pickle file is removed in that case as well.
        """

        scan_info = self.scans_table.get(scan_id)
        scan_info_hash = scan_info.pop('scan_info_hash')

        pickler = DataPickler(self.file_storage_dir)
        unpickled_scan_info = pickler.load_data(scan_id, scan_info_hash)

        if not unpickled_scan_info:
            pickler.remove_file(scan_id)
            raise OspdCommandError(
                'Not possible to unpickle stored scan info for %s' % scan_id,
                'start_scan',
            )

        scan_info['results'] = list()
        scan_info['temp_results'] = list()
        scan_info['progress'] = ScanProgress.INIT.value
        scan_info['target_progress'] = dict()
        scan_info['count_alive'] = 0
        scan_info['count_dead'] = 0
        scan_info['count_total'] = 0
        scan_info['target'] = unpickled_scan_info.pop('target')
        scan_info['vts'] = unpickled_scan_info.pop('vts')
        scan_info['options'] = unpickled_scan_info.pop('options')
        scan_info['start_time'] = int(time.time())
        scan_info['end_time'] = 0

        self.scans_table[scan_id] = scan_info

        pickler.remove_file(scan_id)

    def create_scan(
        self,
        scan_id: str = '',
        target: Optional[Dict] = None,
        options: Optional[Dict] = None,
        vts: Optional[Dict] = None,
    ) -> Optional[str]:
        """Creates a new scan with provided scan information.

        The scan is registered with status QUEUED. The 'credentials' entry
        is popped from the caller's ``target`` dict and kept in the in-memory
        scan info. The remaining target, the options and the vts are pickled
        to disk through DataPickler and loaded back by unpickle_scan_info.

        @scan_id: Scan ID to use. A random UUID4 is generated when empty.
        @target: Target to scan. Must contain a 'credentials' key.
        @options: Miscellaneous scan options supplied via <scanner_params>
                  XML element.
        @vts: VTs for the scan.

        @return: Scan's ID. None if the scan info could not be stored.
        """

        if not options:
            options = dict()

        credentials = target.pop('credentials')

        scan_info = self.data_manager.dict()  # type: Dict
        scan_info['status'] = ScanStatus.QUEUED
        scan_info['credentials'] = credentials
        scan_info['start_time'] = int(time.time())

        scan_info_to_pickle = {
            'target': target,
            'options': options,
            'vts': vts,
        }

        if not scan_id:
            scan_id = str(uuid.uuid4())

        pickler = DataPickler(self.file_storage_dir)
        scan_info_hash = None
        try:
            scan_info_hash = pickler.store_data(scan_id, scan_info_to_pickle)
        except OspdCommandError as e:
            LOGGER.error(e)
            return None

        scan_info['scan_id'] = scan_id
        scan_info['scan_info_hash'] = scan_info_hash

        self.scans_table[scan_id] = scan_info
        return scan_id

    def set_status(self, scan_id: str, status: ScanStatus) -> None:
        """Set scan_id scan's status. STOPPED also records the end time."""
        self.scans_table[scan_id]['status'] = status
        if status == ScanStatus.STOPPED:
            self.scans_table[scan_id]['end_time'] = int(time.time())

    def get_status(self, scan_id: str) -> ScanStatus:
        """ Get scan_id scans's status."""

        return self.scans_table[scan_id].get('status')

    def get_options(self, scan_id: str) -> Dict:
        """Get scan_id scan's options dictionary.

        When the scan entry is a manager proxy, the returned dict is a copy:
        use set_option to change an option.
        """

        return self.scans_table[scan_id].get('options')

    def set_option(self, scan_id: str, name: str, value: Any) -> None:
        """Set a scan_id scan's name option to value."""

        # Mutating self.scans_table[scan_id]['options'] in place would only
        # change the copy returned by the manager proxy; write it back.
        options = self.scans_table[scan_id]['options']
        options[name] = value
        self.scans_table[scan_id]['options'] = options

    def get_progress(self, scan_id: str) -> int:
        """ Get a scan's current progress value. """

        return self.scans_table[scan_id].get('progress', ScanProgress.INIT)

    def get_count_dead(self, scan_id: str) -> int:
        """ Get a scan's current dead host count. """

        return self.scans_table[scan_id]['count_dead']

    def get_count_alive(self, scan_id: str) -> int:
        """ Get a scan's current alive host count. """

        return self.scans_table[scan_id]['count_alive']

    def update_count_total(self, scan_id: str, count_total: int) -> None:
        """ Sets a scan's total hosts."""

        self.scans_table[scan_id]['count_total'] = count_total

    def get_count_total(self, scan_id: str) -> int:
        """Get a scan's total host count.

        If no non-zero total was stored yet, it is computed from the target's
        host list and stored.
        """

        count_total = self.scans_table[scan_id]['count_total']
        if not count_total:
            count_total = self.get_host_count(scan_id)
            self.update_count_total(scan_id, count_total)

        return count_total

    def get_current_target_progress(self, scan_id: str) -> Dict[str, int]:
        """ Get a scan's current hosts progress """
        return self.scans_table[scan_id]['target_progress']

    def simplify_exclude_host_count(self, scan_id: str) -> int:
        """Remove from exclude_hosts the received hosts in the finished_hosts
        list sent by the client.
        The finished hosts are sent also as exclude hosts for backward
        compatibility purposes.

        Return:
            Count of excluded host.
        """

        exc_hosts_list = target_str_to_list(self.get_exclude_hosts(scan_id))

        finished_hosts_list = target_str_to_list(
            self.get_finished_hosts(scan_id)
        )

        if finished_hosts_list and exc_hosts_list:
            for finished in finished_hosts_list:
                if finished in exc_hosts_list:
                    exc_hosts_list.remove(finished)

        return len(exc_hosts_list) if exc_hosts_list else 0

    def calculate_target_progress(self, scan_id: str) -> int:
        """Get a target's current progress value.
        The value is calculated with the progress of each single host
        in the target.

        Finished alive hosts count as 100 each; dead and excluded hosts are
        left out of the denominator. If no host remains, FINISHED is returned.
        """

        total_hosts = self.get_count_total(scan_id)
        exc_hosts = self.simplify_exclude_host_count(scan_id)
        count_alive = self.get_count_alive(scan_id)
        count_dead = self.get_count_dead(scan_id)
        host_progresses = self.get_current_target_progress(scan_id)

        try:
            t_prog = int(
                (sum(host_progresses.values()) + 100 * count_alive)
                / (total_hosts - exc_hosts - count_dead)
            )
        except ZeroDivisionError:
            # Consider the case in which all hosts are dead or excluded
            t_prog = ScanProgress.FINISHED.value

        return t_prog

    def get_start_time(self, scan_id: str) -> int:
        """ Get a scan's start time (seconds since the epoch). """

        return self.scans_table[scan_id]['start_time']

    def get_end_time(self, scan_id: str) -> int:
        """ Get a scan's end time (seconds since the epoch, 0 if unset). """

        return self.scans_table[scan_id]['end_time']

    def get_host_list(self, scan_id: str) -> Optional[str]:
        """ Get a scan's host list, as stored in the target. """

        return self.scans_table[scan_id]['target'].get('hosts')

    def get_host_count(self, scan_id: str) -> int:
        """ Get total host count in the target. """
        host = self.get_host_list(scan_id)
        total_hosts = 0

        if host:
            total_hosts = len(target_str_to_list(host))

        return total_hosts

    def get_ports(self, scan_id: str) -> str:
        """Get a scan's ports list.

        The 'ports' entry is removed from the stored target, so it can be
        fetched only once; a second call raises KeyError.
        """
        target = self.scans_table[scan_id].get('target')
        ports = target.pop('ports')
        self.scans_table[scan_id]['target'] = target
        return ports

    def get_exclude_hosts(self, scan_id: str) -> str:
        """Get an exclude host list for a given target."""
        return self.scans_table[scan_id]['target'].get('exclude_hosts')

    def get_finished_hosts(self, scan_id: str) -> str:
        """Get the finished host list sent by the client for a given target."""
        return self.scans_table[scan_id]['target'].get('finished_hosts')

    def get_credentials(self, scan_id: str) -> Dict[str, Dict[str, str]]:
        """Get a scan's credential list. It returns a dictionary with
        the corresponding credential for a given target.
        """
        return self.scans_table[scan_id].get('credentials')

    def get_target_options(self, scan_id: str) -> Dict[str, str]:
        """Get a scan's target option dictionary.
        It returns a dictionary with the corresponding options for
        a given target.
        """
        return self.scans_table[scan_id]['target'].get('options')

    def get_vts(self, scan_id: str) -> Dict[str, Union[Dict[str, str], List]]:
        """Get a scan's vts.

        The 'vts' entry is removed from the scan info, so it can be fetched
        only once; a second call raises KeyError.
        """
        scan_info = self.scans_table[scan_id]
        vts = scan_info.pop('vts')
        self.scans_table[scan_id] = scan_info

        return vts

    def id_exists(self, scan_id: str) -> bool:
        """ Check whether a scan exists in the table. """

        return self.scans_table.get(scan_id) is not None

    def delete_scan(self, scan_id: str) -> bool:
        """Delete a scan unless it is still running.

        Returns:
            False if the scan is RUNNING (nothing deleted), True otherwise.
        """

        if self.get_status(scan_id) == ScanStatus.RUNNING:
            return False

        del self.scans_table[scan_id]

        return True
509