Passed
Pull Request — master (#102)
by Juan José
01:15
created

ospd.misc.ScanCollection.get_hosts_unfinished()   A

Complexity

Conditions 4

Size

Total Lines 11
Code Lines 8

Duplication

Lines 11
Ratio 100 %

Importance

Changes 0
Metric Value
cc 4
eloc 8
nop 2
dl 11
loc 11
rs 10
c 0
b 0
f 0
1
# Copyright (C) 2014-2018 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: GPL-2.0-or-later
4
#
5
# This program is free software; you can redistribute it and/or
6
# modify it under the terms of the GNU General Public License
7
# as published by the Free Software Foundation; either version 2
8
# of the License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU General Public License for more details.
14
#
15
# You should have received a copy of the GNU General Public License
16
# along with this program; if not, write to the Free Software
17
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18
19
""" Miscellaneous classes and functions related to OSPD.
20
"""
21
22
23
# Needed to say that when we import ospd, we mean the package and not the
24
# module in that directory.
25
from __future__ import absolute_import
26
from __future__ import print_function
27
28
import argparse
29
import binascii
30
import collections
31
import logging
32
import logging.handlers
33
import os
34
import re
35
import socket
36
import struct
37
import sys
38
import time
39
import ssl
40
import uuid
41
import multiprocessing
42
import itertools
43
from enum import Enum
44
45
# Logger shared by every helper in this module.
LOGGER = logging.getLogger(__name__)

# Default network endpoint values (presumably the daemon's default bind
# address and port -- confirm against the argument parser).
ADDRESS = "0.0.0.0"
PORT = 1234

# Default certificate and key file locations of a default OpenVAS install.
CA_FILE = "/usr/var/lib/gvm/CA/cacert.pem"
CERT_FILE = "/usr/var/lib/gvm/CA/servercert.pem"
KEY_FILE = "/usr/var/lib/gvm/private/CA/serverkey.pem"
54
55
class ScanStatus(Enum):
    """Lifecycle states of a scan.

    Members carry the integer values 0 to 3 in declaration order:
    INIT, RUNNING, STOPPED, FINISHED.
    """
    # Tuple unpacking binds the names left to right, so the member order
    # and values match an explicit one-per-line declaration.
    INIT, RUNNING, STOPPED, FINISHED = range(4)
61
62
63
64 View Code Duplication
class ScanCollection(object):

    """ Scans collection, managing scans and results read and write, exposing
    only needed information.

    Each scan has meta-information such as scan ID, current progress (from 0 to
    100), start time, end time, scan targets and options, per-host progress,
    finished hosts and a list of results.

    There are 4 types of results: Alarms, Logs, Errors and Host Details.

    Each scan entry is a dict created by a multiprocessing.Manager. Reading a
    nested value from such a shared dict (e.g. its 'results' list) returns a
    copy, so every method that changes a nested value assigns it back to the
    entry to propagate the change to other processes.

    Todo:
    - Better checking for Scan ID existence and handling otherwise.
    - More data validation.
    - Mutex access per table/scan_info.
    """

    def __init__(self):
        """ Initialize the Scan Collection. """

        # Created on demand by create_scan() and dropped by delete_scan()
        # once the last scan has been removed.
        self.data_manager = None
        self.scans_table = dict()

    def add_result(self, scan_id, result_type, host='', name='', value='',
                   port='', test_id='', severity='', qod=''):
        """ Add a result to a scan in the table.

        At least one of name or value must be non-empty.
        """

        assert scan_id
        assert len(name) or len(value)
        result = dict()
        result['type'] = result_type
        result['name'] = name
        result['severity'] = severity
        result['test_id'] = test_id
        result['value'] = value
        result['host'] = host
        result['port'] = port
        result['qod'] = qod
        results = self.scans_table[scan_id]['results']
        results.append(result)
        # Set scan_info's results to propagate results to parent process.
        self.scans_table[scan_id]['results'] = results

    def set_progress(self, scan_id, progress):
        """ Set scan_id scan's progress.

        Values outside ]0, 100] are not stored. A progress of 100 also
        records the current time as the scan's end time.
        """

        if 0 < progress <= 100:
            self.scans_table[scan_id]['progress'] = progress
        if progress == 100:
            self.scans_table[scan_id]['end_time'] = int(time.time())

    def set_target_progress(self, scan_id, target, host, progress):
        """ Set the progress of a single host of a target in scan_id scan.

        Values outside ]0, 100] are ignored.
        """
        if 0 < progress <= 100:
            targets = self.scans_table[scan_id]['target_progress']
            targets[target][host] = progress
            # Set scan_info's target_progress to propagate progresses
            # to parent process.
            self.scans_table[scan_id]['target_progress'] = targets

    def set_host_finished(self, scan_id, target, host):
        """ Add host(s) to the list of finished hosts of a target.

        host may be a single host string or a list/tuple/set of hosts.
        """
        finished_hosts = self.scans_table[scan_id]['finished_hosts']
        if isinstance(host, (list, tuple, set, frozenset)):
            finished_hosts[target].extend(host)
        else:
            # extend() with a string would add every character as a host.
            finished_hosts[target].append(host)
        # Set scan_info's finished_hosts to propagate them to parent process.
        self.scans_table[scan_id]['finished_hosts'] = finished_hosts

    def get_hosts_unfinished(self, scan_id):
        """ Get a list of the hosts which are not finished yet.

        Every target of the scan is expanded with target_str_to_list() and
        the hosts recorded as finished for that target are left out. A
        target which cannot be expanded contributes no hosts.
        """

        unfinished_hosts = list()
        finished_hosts = self.scans_table[scan_id]['finished_hosts']
        for target, finished in finished_hosts.items():
            done = set(finished)
            for host in target_str_to_list(target) or []:
                if host not in done:
                    unfinished_hosts.append(host)

        return unfinished_hosts

    def results_iterator(self, scan_id, pop_res):
        """ Returns an iterator over scan_id scan's results. If pop_res is True,
        the fetched results are removed from the scan's result list.
        """

        if pop_res:
            result_aux = self.scans_table[scan_id]['results']
            self.scans_table[scan_id]['results'] = list()
            return iter(result_aux)

        return iter(self.scans_table[scan_id]['results'])

    def ids_iterator(self):
        """ Returns an iterator over the collection's scan IDs. """

        return iter(self.scans_table.keys())

    def remove_single_result(self, scan_id, result):
        """ Remove a single result from scan_id scan's result list.

        result may be the index of the result in the list or the result
        dictionary itself (the first equal entry is removed; ValueError if
        there is none).
        """
        results = self.scans_table[scan_id]['results']
        try:
            del results[result]
        except TypeError:
            # Not usable as a list index: treat it as the result entry.
            results.remove(result)
        # Write the list back so the change reaches the shared scan entry.
        self.scans_table[scan_id]['results'] = results

    def del_results_for_stopped_hosts(self, scan_id):
        """ Remove the results of the hosts which did not finish.

        Called by resume_scan(): results whose 'host' is one of the hosts
        returned by get_hosts_unfinished() are dropped.
        """
        unfinished_hosts = set(self.get_hosts_unfinished(scan_id))
        results = self.scans_table[scan_id]['results']
        # Build a filtered copy instead of deleting while iterating, and
        # write it back so the change reaches the shared scan entry.
        self.scans_table[scan_id]['results'] = [
            result for result in results
            if result['host'] not in unfinished_hosts]

    def resume_scan(self, scan_id, options):
        """ Reset the scan status in the scan_table to INIT.
        Also, overwrite the options, because a resume task cmd
        can add some new option. E.g. exclude hosts list.
        Results of hosts which did not finish are deleted.
        Parameters:
            scan_id (uuid): Scan ID to identify the scan process to be resumed.
            options (dict): Options for the scan to be resumed. These options
                            are not merged with the already existing ones;
                            the old ones are replaced (only when options is
                            non-empty).

        Return:
            Scan ID which identifies the current scan.
        """
        self.scans_table[scan_id]['status'] = ScanStatus.INIT
        if options:
            self.scans_table[scan_id]['options'] = options

        self.del_results_for_stopped_hosts(scan_id)

        return scan_id

    def create_scan(self, scan_id='', targets='', options=None, vts=''):
        """ Creates a new scan with provided scan information.

        targets is iterated as (target, ports, credentials) triples. If
        scan_id names an existing STOPPED scan, that scan is resumed instead.
        An empty or None scan_id gets a new UUID4 string.

        Returns the scan ID.
        """

        if self.data_manager is None:
            self.data_manager = multiprocessing.Manager()

        # Check if it is possible to resume task. To avoid to resume, the
        # scan must be deleted from the scans_table.
        if scan_id and self.id_exists(scan_id) and (
                self.get_status(scan_id) == ScanStatus.STOPPED):
            return self.resume_scan(scan_id, options)

        if not options:
            options = dict()
        scan_info = self.data_manager.dict()
        scan_info['results'] = list()
        scan_info['finished_hosts'] = {
            target: [] for target, _, _ in targets}
        scan_info['progress'] = 0
        scan_info['target_progress'] = {
            target: {} for target, _, _ in targets}
        scan_info['targets'] = targets
        scan_info['vts'] = vts
        scan_info['options'] = options
        scan_info['start_time'] = int(time.time())
        scan_info['end_time'] = "0"
        scan_info['status'] = ScanStatus.INIT
        if scan_id is None or scan_id == '':
            scan_id = str(uuid.uuid4())
        scan_info['scan_id'] = scan_id
        self.scans_table[scan_id] = scan_info
        return scan_id

    def set_status(self, scan_id, status):
        """ Sets scan_id scan's status. """
        self.scans_table[scan_id]['status'] = status

    def get_status(self, scan_id):
        """ Get scan_id scan's status."""

        return self.scans_table[scan_id]['status']

    def get_options(self, scan_id):
        """ Get scan_id scan's options list. """

        return self.scans_table[scan_id]['options']

    def set_option(self, scan_id, name, value):
        """ Set a scan_id scan's name option to value. """

        self.scans_table[scan_id]['options'][name] = value

    def get_progress(self, scan_id):
        """ Get a scan's current progress value. """

        return self.scans_table[scan_id]['progress']

    def get_target_progress(self, scan_id, target):
        """ Get a target's current progress value.
        The value is calculated with the progress of each single host
        in the target: the sum of the per-host progresses divided by the
        number of hosts the target expands to.

        Raises ZeroDivisionError (after logging it) if the target expands
        to no hosts.
        """

        total_hosts = len(target_str_to_list(target))
        host_progresses = self.scans_table[scan_id]['target_progress'].get(
            target)
        try:
            t_prog = sum(host_progresses.values()) / total_hosts
        except ZeroDivisionError:
            LOGGER.error("Zero division error in %s",
                         self.get_target_progress.__name__)
            raise
        return t_prog

    def get_start_time(self, scan_id):
        """ Get a scan's start time. """

        return self.scans_table[scan_id]['start_time']

    def get_end_time(self, scan_id):
        """ Get a scan's end time. """

        return self.scans_table[scan_id]['end_time']

    def get_target_list(self, scan_id):
        """ Get a scan's target list. """

        target_list = []
        for target, _, _ in self.scans_table[scan_id]['targets']:
            target_list.append(target)
        return target_list

    def get_ports(self, scan_id, target):
        """ Get a scan's ports list. If a target is specified
        it will return the corresponding port for it. If not, or if the
        target is not found, it returns the port item of the first nested
        list in the target's list.
        """
        if target:
            for item in self.scans_table[scan_id]['targets']:
                if target == item[0]:
                    return item[1]

        return self.scans_table[scan_id]['targets'][0][1]

    def get_credentials(self, scan_id, target):
        """ Get a scan's credential list. It returns the dictionary with
        the corresponding credential for a given target, or None when the
        target is empty or not found.
        """
        if target:
            for item in self.scans_table[scan_id]['targets']:
                if target == item[0]:
                    return item[2]

    def get_vts(self, scan_id):
        """ Get a scan's vts list. """

        return self.scans_table[scan_id]['vts']

    def id_exists(self, scan_id):
        """ Check whether a scan exists in the table. """

        return self.scans_table.get(scan_id) is not None

    def delete_scan(self, scan_id):
        """ Delete a scan unless it is running.

        Returns False (keeping the scan) while its status is RUNNING and
        True after deletion. The data manager is dropped once the table is
        empty.
        """

        if self.get_status(scan_id) == ScanStatus.RUNNING:
            return False
        self.scans_table.pop(scan_id)
        if len(self.scans_table) == 0:
            del self.data_manager
            self.data_manager = None
        return True
322
323
324 View Code Duplication
class ResultType(object):

    """ Kinds of scan results and conversion between their integer values
    and their display names. """

    ALARM = 0
    LOG = 1
    ERROR = 2
    HOST_DETAIL = 3

    @classmethod
    def _name_pairs(cls):
        """ Return the (value, name) pairs of all result types, in the
        order in which they are matched. """
        return ((cls.ALARM, "Alarm"),
                (cls.LOG, "Log Message"),
                (cls.ERROR, "Error Message"),
                (cls.HOST_DETAIL, "Host Detail"))

    @classmethod
    def get_str(cls, result_type):
        """ Return the display name of a result type value.

        An unknown value fails an assertion.
        """
        for type_value, type_name in cls._name_pairs():
            if result_type == type_value:
                return type_name
        assert False, "Erroneous result type {0}.".format(result_type)

    @classmethod
    def get_type(cls, result_name):
        """ Return the result type value matching a display name.

        An unknown name fails an assertion.
        """
        for type_value, type_name in cls._name_pairs():
            if result_name == type_name:
                return type_value
        assert False, "Erroneous result name {0}.".format(result_name)
360
361
362
__inet_pton = None


def inet_pton(address_family, ip_string):
    """ A platform independent version of inet_pton.

    The implementation is chosen on first use: socket.inet_pton when the
    socket module provides it, ospd.win_socket.inet_pton otherwise. The
    choice is cached in the module-level __inet_pton.
    """
    global __inet_pton
    if __inet_pton is not None:
        return __inet_pton(address_family, ip_string)

    if hasattr(socket, 'inet_pton'):
        __inet_pton = socket.inet_pton
    else:
        from ospd import win_socket
        __inet_pton = win_socket.inet_pton
    return __inet_pton(address_family, ip_string)
376
377
378
__inet_ntop = None


def inet_ntop(address_family, packed_ip):
    """ A platform independent version of inet_ntop.

    The implementation is chosen on first use: socket.inet_ntop when the
    socket module provides it, ospd.win_socket.inet_ntop otherwise. The
    choice is cached in the module-level __inet_ntop.
    """
    global __inet_ntop
    if __inet_ntop is not None:
        return __inet_ntop(address_family, packed_ip)

    if hasattr(socket, 'inet_ntop'):
        __inet_ntop = socket.inet_ntop
    else:
        from ospd import win_socket
        __inet_ntop = win_socket.inet_ntop
    return __inet_ntop(address_family, packed_ip)
392
393
394
def target_to_ipv4(target):
    """ Return [target] if target is a single IPv4 address, None otherwise. """
    try:
        inet_pton(socket.AF_INET, target)
    except socket.error:
        return None
    return [target]
402
403
404
def target_to_ipv6(target):
    """ Return [target] if target is a single IPv6 address, None otherwise. """
    try:
        inet_pton(socket.AF_INET6, target)
    except socket.error:
        return None
    return [target]
412
413
414 View Code Duplication
def ipv4_range_to_list(start_packed, end_packed):
    """ Return a list of IPv4 entries from start_packed to end_packed.

    Both bounds are 4-byte packed addresses and both are included; an end
    lower than the start gives an empty list.
    """
    first, = struct.unpack('!L', start_packed)
    last, = struct.unpack('!L', end_packed)
    return [socket.inet_ntoa(struct.pack('!L', number))
            for number in range(first, last + 1)]
424
425
426 View Code Duplication
def target_to_ipv4_short(target):
    """ Attempt to return a IPv4 short range list from a target string.

    A short range gives only the last octet of the end address, e.g.
    "192.168.1.10-20" covers 192.168.1.10 to 192.168.1.20 (both included).

    Returns None when target is not such a range, or when the end octet is
    not within [start octet, 255].
    """

    splitted = target.split('-')
    if len(splitted) != 2:
        return None
    try:
        start_packed = inet_pton(socket.AF_INET, splitted[0])
        end_value = int(splitted[1])
    except (socket.error, ValueError):
        return None
    # Last octet of the start address as an int. Indexing a bytearray gives
    # an int on both Python 2 (str input) and Python 3 (bytes input), whereas
    # bytes(start_packed[3]) on Python 3 builds a zero-filled buffer of that
    # length (always read as 0, and an error for octet 0).
    start_value = bytearray(start_packed)[3]
    if not start_value <= end_value <= 255:
        return None
    end_packed = start_packed[0:3] + struct.pack('B', end_value)
    return ipv4_range_to_list(start_packed, end_packed)
442
443
444 View Code Duplication
def target_to_ipv4_cidr(target):
    """ Attempt to return a IPv4 CIDR list from a target string.

    For "address/block" with 0 < block <= 30, returns every address of the
    network except the network address and the broadcast address. Returns
    None for anything else.
    """

    parts = target.split('/')
    if len(parts) != 2:
        return None
    address, prefix = parts
    try:
        packed = inet_pton(socket.AF_INET, address)
        block = int(prefix)
    except (socket.error, ValueError):
        return None
    if not 0 < block <= 30:
        return None
    host_bits = 32 - block
    # Clear the host bits to get the network address.
    network = (int(binascii.hexlify(packed), 16) >> host_bits) << host_bits
    first_host = network + 1
    last_host = (first_host | (0xffffffff >> block)) - 1
    return ipv4_range_to_list(struct.pack('!I', first_host),
                              struct.pack('!I', last_host))
463
464
465 View Code Duplication
def target_to_ipv6_cidr(target):
    """ Attempt to return a IPv6 CIDR list from a target string.

    For "address/block" with 0 < block <= 126, returns every address of the
    network except its first and last address. Returns None for anything
    else.

    NOTE: the whole range is materialized as a list, so short prefixes
    produce extremely large lists.
    """

    parts = target.split('/')
    if len(parts) != 2:
        return None
    address, prefix = parts
    try:
        packed = inet_pton(socket.AF_INET6, address)
        block = int(prefix)
    except (socket.error, ValueError):
        return None
    if not 0 < block <= 126:
        return None
    host_bits = 128 - block
    # Clear the host bits to get the network address.
    network = (int(binascii.hexlify(packed), 16) >> host_bits) << host_bits
    first_host = network + 1
    last_host = (first_host | (int('ff' * 16, 16) >> block)) - 1
    # A 128-bit value is packed as two big-endian 64-bit halves.
    low_mask = (1 << 64) - 1
    start_packed = struct.pack('!QQ', first_host >> 64, first_host & low_mask)
    end_packed = struct.pack('!QQ', last_host >> 64, last_host & low_mask)
    return ipv6_range_to_list(start_packed, end_packed)
488
489
490 View Code Duplication
def target_to_ipv4_long(target):
    """ Attempt to return a IPv4 long-range list from a target string.

    The range is written as two full IPv4 addresses, "start-end", both
    included. Returns None when target is not such a range or when end is
    lower than start.
    """
    bounds = target.split('-')
    if len(bounds) != 2:
        return None
    try:
        start_packed, end_packed = [inet_pton(socket.AF_INET, bound)
                                    for bound in bounds]
    except socket.error:
        return None
    if start_packed > end_packed:
        return None
    return ipv4_range_to_list(start_packed, end_packed)
504
505
506 View Code Duplication
def ipv6_range_to_list(start_packed, end_packed):
    """ Return a list of IPv6 entries from start_packed to end_packed.

    Both bounds are 16-byte packed addresses and both are included; an end
    lower than the start gives an empty list.
    """
    low_mask = (1 << 64) - 1
    first = int(binascii.hexlify(start_packed), 16)
    last = int(binascii.hexlify(end_packed), 16)
    # Each 128-bit value is packed as two big-endian 64-bit halves.
    return [inet_ntop(socket.AF_INET6,
                      struct.pack('!2Q', number >> 64, number & low_mask))
            for number in range(first, last + 1)]
519
520
521 View Code Duplication
def target_to_ipv6_short(target):
    """ Attempt to return a IPv6 short-range list from a target string.

    A short range gives only the last 16-bit group of the end address in
    hexadecimal, e.g. "fe80::1-ff". Returns None when target is not such a
    range, or when the end group is not within [start group, 0xffff].
    """
    bounds = target.split('-')
    if len(bounds) != 2:
        return None
    try:
        start_packed = inet_pton(socket.AF_INET6, bounds[0])
        end_value = int(bounds[1], 16)
    except (socket.error, ValueError):
        return None
    # The last two bytes hold the final 16-bit group of the start address.
    start_value = int(binascii.hexlify(start_packed[14:]), 16)
    if not start_value <= end_value <= 0xffff:
        return None
    end_packed = start_packed[:14] + struct.pack('!H', end_value)
    return ipv6_range_to_list(start_packed, end_packed)
537
538
539 View Code Duplication
def target_to_ipv6_long(target):
    """ Attempt to return a IPv6 long-range list from a target string.

    The range is written as two full IPv6 addresses, "start-end", both
    included. Returns None when target is not such a range or when end is
    lower than start.
    """
    bounds = target.split('-')
    if len(bounds) != 2:
        return None
    try:
        start_packed, end_packed = [inet_pton(socket.AF_INET6, bound)
                                    for bound in bounds]
    except socket.error:
        return None
    if start_packed > end_packed:
        return None
    return ipv6_range_to_list(start_packed, end_packed)
553
554
555
def target_to_hostname(target):
    """ Attempt to return a single hostname list from a target string.

    target is accepted when it is 1 to 255 characters long and consists
    only of word characters, dots and hyphens. Returns [target] then, and
    None otherwise.
    """

    if len(target) == 0 or len(target) > 255:
        return None
    # \Z anchors at the very end of the string; '$' would also match just
    # before a trailing newline and accept e.g. "host\n".
    if not re.match(r'^[\w.-]+\Z', target):
        return None
    return [target]
563
564
565 View Code Duplication
def target_to_list(target):
    """ Attempt to return a list of single hosts from a target string.

    The parsers are tried in this order: single IPv4, single IPv6, IPv4
    CIDR, IPv6 CIDR, IPv4 short range, IPv4 long range, IPv6 short range,
    IPv6 long range, hostname. The first non-empty result is returned; if
    none matches, the (falsy) result of the last parser is returned.
    """
    parsers = (
        target_to_ipv4,
        target_to_ipv6,
        target_to_ipv4_cidr,
        target_to_ipv6_cidr,
        target_to_ipv4_short,
        target_to_ipv4_long,
        target_to_ipv6_short,
        target_to_ipv6_long,
        target_to_hostname,
    )
    host_list = None
    for parse in parsers:
        host_list = parse(target)
        if host_list:
            break
    return host_list
595
596
597 View Code Duplication
def target_str_to_list(target_str):
    """ Parse a comma-separated targets string into a list of single hosts.

    Each entry is stripped and expanded with target_to_list(). Hosts keep
    their first-seen order and duplicates are dropped. If any entry is
    invalid, it is logged and None is returned.
    """
    hosts = list()
    for entry in target_str.split(','):
        entry = entry.strip()
        expanded = target_to_list(entry)
        if not expanded:
            LOGGER.info("{0}: Invalid target value".format(entry))
            return None
        hosts.extend(expanded)
    # OrderedDict keys drop duplicates while keeping insertion order.
    return list(collections.OrderedDict.fromkeys(hosts))
609
610
611
def resolve_hostname(hostname):
    """ Return the IPv4 address of hostname as a string, or None when the
    name cannot be resolved. hostname must be non-empty. """
    assert hostname
    try:
        address = socket.gethostbyname(hostname)
    except socket.gaierror:
        address = None
    return address
619
620
621 View Code Duplication
def port_range_expand(portrange):
    """
    Expand a port range into the individual ports it contains.

    @input Port range, split at its first '-'.
    e.g. "4-8"

    @return List of integers, both bounds included, or None (after logging)
    when portrange is empty or has no '-'.
    e.g. [4, 5, 6, 7, 8]
    """
    if not portrange or '-' not in portrange:
        LOGGER.info("Invalid port range format")
        return None
    low, _, high = portrange.partition('-')
    return list(range(int(low), int(high) + 1))
639
640
641 View Code Duplication
def port_str_arrange(ports):
    """ Gives a str in the format (always tcp listed first).
    T:<tcp ports/portrange comma separated>U:<udp ports comma separated>

    Only when both markers are present and 'U' comes before 'T' is the
    string rebuilt (TCP part first); otherwise it is returned unchanged.
    """
    tcp_at = ports.find("T")
    udp_at = ports.find("U")
    if tcp_at == -1 or udp_at == -1 or tcp_at < udp_at:
        return ports
    return ports[tcp_at:] + ports[udp_at:tcp_at]
651
652
653 View Code Duplication
def ports_str_check_failed(port_str):
    """
    Check whether the port string is malformed.

    Returns True when it contains a character other than T, U, ':', digits,
    ',', space or '-', when 'T' or 'U' appears more than once, or when there
    are fewer ':' than protocol markers. Returns False otherwise.
    """
    if re.search(r'[^TU:0-9, \-]', port_str):
        return True
    tcp_markers = port_str.count('T')
    udp_markers = port_str.count('U')
    if tcp_markers > 1 or udp_markers > 1:
        return True
    return port_str.count(':') < tcp_markers + udp_markers
668
669
670 View Code Duplication
def _port_items_to_list(port_items):
    """ Expand comma-separated ports/port ranges into a sorted int list.

    Empty items (e.g. left by a trailing comma) are skipped.
    """
    port_list = list()
    for item in port_items.split(','):
        if not item:
            continue
        if '-' in item:
            port_list.extend(port_range_expand(item))
        else:
            port_list.append(int(item))
    port_list.sort()
    return port_list


def ports_as_list(port_str):
    """
    Parses a ports string into two list of individual tcp and udp ports.

    @input string containing a port list
    e.g. T:1,2,3,5-8 U:22,80,600-1024
    A string without "T:"/"U:" markers is read as TCP ports only.

    @return two list of sorted integers, for tcp and udp ports respectively,
    or [None, None] (after logging) for an empty or malformed string.
    """
    if not port_str:
        LOGGER.info("Invalid port value")
        return [None, None]

    if ports_str_check_failed(port_str):
        LOGGER.info("{0}: Port list malformed.".format(port_str))
        return [None, None]

    ports = port_str.replace(' ', '')

    # Drop the comma separating a protocol section from the previous one,
    # e.g. "T:80,U:53" -> "T:80U:53". Each marker is looked up on the
    # current string because an earlier removal shifts the indices. A
    # missing marker (-1) or one at index 0 has no preceding character, so
    # ports[pos - 1] would wrongly inspect the end of the string.
    for marker in ('T', 'U'):
        pos = ports.find(marker)
        if pos > 0 and ports[pos - 1] == ',':
            ports = ports[:pos - 1] + ports[pos:]

    has_tcp = 'T' in ports
    has_udp = 'U' in ports
    ports = port_str_arrange(ports)

    tports = ''
    uports = ''
    if has_tcp and has_udp:
        # TCP ports listed first, then UDP ports
        tports = ports[ports.index('T:') + 2:ports.index('U:')]
        uports = ports[ports.index('U:') + 2:]
    elif has_udp:
        # Only UDP ports
        uports = ports[ports.index('U:') + 2:]
    elif has_tcp:
        # Only TCP ports
        tports = ports[ports.index('T:') + 2:]
    else:
        tports = ports

    return (_port_items_to_list(tports), _port_items_to_list(uports))
730
731
732
def get_tcp_port_list(port_str):
    """ Return the TCP part of ports_as_list() for a port list string. """
    tcp_ports, _ = ports_as_list(port_str)
    return tcp_ports
735
736
737
def get_udp_port_list(port_str):
    """ Return the UDP part of ports_as_list() for a port list string. """
    _, udp_ports = ports_as_list(port_str)
    return udp_ports
740
741
742 View Code Duplication
def port_list_compress(port_list):
    """ Compress a port list into a comma-separated string.

    Duplicates are removed, ports are sorted and each run of consecutive
    ports is written as "first-last", e.g. [5, 1, 2, 3] -> "1-3,5".
    An empty or missing list is logged and gives ''.
    """
    if not port_list or len(port_list) == 0:
        LOGGER.info("Invalid or empty port list.")
        return ''

    unique_ports = sorted(set(port_list))
    chunks = []
    # Within a run of consecutive ports, port - position stays constant.
    for _, run in itertools.groupby(enumerate(unique_ports),
                                    lambda pair: pair[1] - pair[0]):
        run_ports = [port for _, port in run]
        first, last = run_ports[0], run_ports[-1]
        if first == last:
            chunks.append(str(first))
        else:
            chunks.append(str(first) + '-' + str(last))
    return ','.join(chunks)
760
761
762
def valid_uuid(value):
    """ Check if value is a valid UUID string.

    Returns True when uuid.UUID accepts value, False when it raises
    TypeError, ValueError or AttributeError. Passing version=4 makes
    uuid.UUID set the version bits; it does not reject other UUID versions.
    """
    try:
        uuid.UUID(value, version=4)
    except (TypeError, ValueError, AttributeError):
        return False
    return True
770
771
772 View Code Duplication
def create_args_parser(description):
    """ Create a command-line arguments parser for OSPD.

    @param description: Text shown by argparse in the --help output.
    @return: An argparse.ArgumentParser with the common OSPD options: port,
             bind address, unix socket, key/cert/CA files, log level,
             foreground mode, log file and version.
    """

    parser = argparse.ArgumentParser(description=description)

    def network_port(string):
        """ Check if provided string is a valid network port. """

        value = int(string)
        if not 0 < value <= 65535:
            raise argparse.ArgumentTypeError(
                'port must be in ]0,65535] interval')
        return value

    def cacert_file(cacert):
        """ Check if provided file is a valid, currently active CA
        Certificate. """
        try:
            context = ssl.create_default_context(cafile=cacert)
        except AttributeError:
            # Python version < 2.7.9
            return cacert
        except ssl.SSLError:
            # ssl.SSLError is a subclass of IOError/OSError, so it has to be
            # handled before the IOError clause: the file exists but its
            # contents could not be loaded as a certificate.
            raise argparse.ArgumentTypeError('CA Certificate is erroneous')
        except IOError:
            raise argparse.ArgumentTypeError('CA Certificate not found')
        try:
            ca_cert = context.get_ca_certs()[0]
            not_after = ssl.cert_time_to_seconds(ca_cert['notAfter'])
            not_before = ssl.cert_time_to_seconds(ca_cert['notBefore'])
        except (KeyError, IndexError):
            raise argparse.ArgumentTypeError('CA Certificate is erroneous')
        now = int(time.time())
        if not_after < now:
            raise argparse.ArgumentTypeError('CA Certificate expired')
        if not_before > now:
            raise argparse.ArgumentTypeError('CA Certificate not active yet')
        return cacert

    def log_level(string):
        """ Check if provided string is a valid log level. """

        # Map e.g. 'debug' to logging.DEBUG; unknown names give None.
        value = getattr(logging, string.upper(), None)
        if not isinstance(value, int):
            raise argparse.ArgumentTypeError(
                'log level must be one of {debug,info,warning,error,critical}')
        return value

    def filename(string):
        """ Check if provided string is a valid file path. """

        if not os.path.isfile(string):
            raise argparse.ArgumentTypeError(
                '%s is not a valid file path' % string)
        return string

    parser.add_argument('-p', '--port', default=PORT, type=network_port,
                        help='TCP Port to listen on. Default: {0}'.format(PORT))
    parser.add_argument('-b', '--bind-address', default=ADDRESS,
                        help='Address to listen on. Default: {0}'
                        .format(ADDRESS))
    parser.add_argument('-u', '--unix-socket',
                        help='Unix file socket to listen on.')
    # The key/cert/CA defaults are applied later by get_common_args(), so
    # no default is set here.
    parser.add_argument('-k', '--key-file', type=filename,
                        help='Server key file. Default: {0}'.format(KEY_FILE))
    parser.add_argument('-c', '--cert-file', type=filename,
                        help='Server cert file. Default: {0}'.format(CERT_FILE))
    parser.add_argument('--ca-file', type=cacert_file,
                        help='CA cert file. Default: {0}'.format(CA_FILE))
    # argparse passes string defaults through type=, so 'warning' becomes
    # logging.WARNING.
    parser.add_argument('-L', '--log-level', default='warning', type=log_level,
                        help='Wished level of logging. Default: WARNING')
    parser.add_argument('--foreground', action='store_true',
                        help='Run in foreground and logs all messages to console.')
    parser.add_argument('-l', '--log-file', type=filename,
                        help='Path to the logging file.')
    parser.add_argument('--version', action='store_true',
                        help='Print version then exit.')
    return parser
847
848
849
def go_to_background():
    """ Daemonize the running process.

    Forks once. The parent exits with status 0 and only the child returns
    from this function. If the fork fails, the error is logged and the
    process exits with the message 'Fork failed'.

    NOTE(review): there is no setsid()/second fork, so the child keeps the
    parent's session and controlling terminal.
    """
    try:
        child_pid = os.fork()
    except OSError as fork_error:
        LOGGER.error('Fork failed: {0}'.format(fork_error))
        sys.exit('Fork failed')

    if child_pid:
        # Non-zero pid: this is the parent, which steps aside.
        sys.exit()
857
858
859 View Code Duplication
def get_common_args(parser, args=None):
    """ Return list of OSPD common command-line arguments from parser, after
    validating provided values or setting default ones.

    @param parser: Parser built by create_args_parser() (or one defining the
                   same option destinations).
    @param args: Argument list to parse; None means sys.argv[1:].
    @return: dict with the keys port, address, unix_socket, keyfile,
             certfile, cafile, log_level, foreground, log_file and version.
             keyfile, certfile and cafile fall back to KEY_FILE, CERT_FILE
             and CA_FILE when the option was not given.
    """

    opts = parser.parse_args(args)

    return {
        # Listener configuration.
        'port': opts.port,
        'address': opts.bind_address,
        'unix_socket': opts.unix_socket,
        # TLS material, using the module-level default paths as fallback.
        'keyfile': opts.key_file or KEY_FILE,
        'certfile': opts.cert_file or CERT_FILE,
        'cafile': opts.ca_file or CA_FILE,
        # Logging / runtime behaviour.
        'log_level': opts.log_level,
        'foreground': opts.foreground,
        'log_file': opts.log_file,
        'version': opts.version,
    }
900
901
902 View Code Duplication
def print_version(wrapper):
    """ Prints the server version and license information.

    @param wrapper: Scanner wrapper providing get_scanner_name(),
                    get_server_version(), get_protocol_version(),
                    get_daemon_name() and get_daemon_version().
    """

    # Each line is printed as soon as its values are fetched.
    print("OSP Server for {0} version {1}".format(
        wrapper.get_scanner_name(), wrapper.get_server_version()))
    print("OSP Version: {0}".format(wrapper.get_protocol_version()))
    print("Using: {0} {1}".format(
        wrapper.get_daemon_name(), wrapper.get_daemon_version()))

    license_notice = ("Copyright (C) 2014, 2015 Greenbone Networks GmbH\n"
                      "License GPLv2+: GNU GPL version 2 or later\n"
                      "This is free software: you are free to change"
                      " and redistribute it.\n"
                      "There is NO WARRANTY, to the extent permitted by law.")
    print(license_notice)
918
919
920 View Code Duplication
def main(name, klass):
    """ OSPD Main function.

    Parse the common command-line options, build the scanner wrapper,
    configure root logging according to the options and run the wrapper.

    @param name: Description handed to create_args_parser().
    @param klass: Wrapper class, instantiated with the certfile, keyfile and
                  cafile keyword arguments.
    @return: 1 if wrapper.check() fails, otherwise the value returned by
             wrapper.run(). --version prints the version and exits instead.
    """

    cargs = get_common_args(create_args_parser(name))
    root_logger = logging.getLogger()
    root_logger.setLevel(cargs['log_level'])
    wrapper = klass(certfile=cargs['certfile'], keyfile=cargs['keyfile'],
                    cafile=cargs['cafile'])

    if cargs['version']:
        print_version(wrapper)
        sys.exit()

    timestamped_format = '%(asctime)s %(name)s: %(levelname)s: %(message)s'
    if cargs['foreground']:
        # Stay attached and log to the console; no fork.
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter(timestamped_format))
        root_logger.addHandler(handler)
    elif cargs['log_file']:
        handler = logging.handlers.WatchedFileHandler(cargs['log_file'])
        handler.setFormatter(logging.Formatter(timestamped_format))
        root_logger.addHandler(handler)
        go_to_background()
    else:
        handler = logging.handlers.SysLogHandler('/dev/log')
        handler.setFormatter(
            logging.Formatter('%(name)s: %(levelname)s: %(message)s'))
        root_logger.addHandler(handler)
        # Duplicate syslog's file descriptor onto stdout (1) and stderr (2),
        # so output written to them goes to the syslog socket.
        syslog_fd = handler.socket.fileno()
        for std_fd in (1, 2):
            os.dup2(syslog_fd, std_fd)
        go_to_background()

    if not wrapper.check():
        return 1
    return wrapper.run(cargs['address'], cargs['port'], cargs['unix_socket'])
963