Passed
Push — master ( 5898b3...6f2eae )
by Juan José
01:49
created

ospd.misc.ScanCollection.set_host_finished()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 4

Duplication

Lines 5
Ratio 100 %

Importance

Changes 0
Metric Value
cc 1
eloc 4
nop 4
dl 5
loc 5
rs 10
c 0
b 0
f 0
1
# Copyright (C) 2014-2018 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: GPL-2.0-or-later
4
#
5
# This program is free software; you can redistribute it and/or
6
# modify it under the terms of the GNU General Public License
7
# as published by the Free Software Foundation; either version 2
8
# of the License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU General Public License for more details.
14
#
15
# You should have received a copy of the GNU General Public License
16
# along with this program; if not, write to the Free Software
17
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18
19
""" Miscellaneous classes and functions related to OSPD.
20
"""
21
22
23
# Needed to say that when we import ospd, we mean the package and not the
24
# module in that directory.
25
from __future__ import absolute_import
26
from __future__ import print_function
27
28
import argparse
29
import binascii
30
import collections
31
import logging
32
import logging.handlers
33
import os
34
import re
35
import socket
36
import struct
37
import sys
38
import time
39
import ssl
40
import uuid
41
import multiprocessing
42
import itertools
43
44
LOGGER = logging.getLogger(__name__)

# Default file locations as used by an OpenVAS default installation.
_GVM_STATE_DIR = "/usr/var/lib/gvm"
KEY_FILE = _GVM_STATE_DIR + "/private/CA/serverkey.pem"  # server private key
CERT_FILE = _GVM_STATE_DIR + "/CA/servercert.pem"  # server certificate
CA_FILE = _GVM_STATE_DIR + "/CA/cacert.pem"  # CA certificate

# Default TCP port and bind address the daemon listens on.
PORT = 1234
ADDRESS = "0.0.0.0"
53
54
55 View Code Duplication
class ScanCollection(object):

    """ Scans collection, managing scans and results read and write, exposing
    only needed information.

    Each scan has meta-information such as scan ID, current progress (from 0 to
    100), start time, end time, scan target and options and a list of results.

    There are 4 types of results: Alarms, Logs, Errors and Host Details.

    Each scan's information lives in a multiprocessing.Manager dict proxy.
    Values read from such a proxy are copies, so every mutator below reads
    the value, modifies it and assigns it back; otherwise the change would
    be lost.

    Todo:
    - Better checking for Scan ID existence and handling otherwise.
    - More data validation.
    - Mutex access per table/scan_info.

    """

    def __init__(self):
        """ Initialize the Scan Collection. """

        # Created lazily by create_scan(); dropped by delete_scan() once the
        # table becomes empty.
        self.data_manager = None
        # scan_id -> Manager dict proxy holding that scan's information.
        self.scans_table = dict()

    def add_result(self, scan_id, result_type, host='', name='', value='',
                   port='', test_id='', severity='', qod=''):
        """ Add a result to a scan in the table.

        At least one of name or value must be non-empty.
        """

        assert scan_id
        assert len(name) or len(value)
        result = {
            'type': result_type,
            'name': name,
            'severity': severity,
            'test_id': test_id,
            'value': value,
            'host': host,
            'port': port,
            'qod': qod,
        }
        results = self.scans_table[scan_id]['results']
        results.append(result)
        # Set scan_info's results to propagate results to parent process.
        self.scans_table[scan_id]['results'] = results

    def set_progress(self, scan_id, progress):
        """ Sets scan_id scan's progress.

        Values outside ]0, 100] are ignored. Reaching 100 also records the
        scan's end time.
        """

        if 0 < progress <= 100:
            self.scans_table[scan_id]['progress'] = progress
        if progress == 100:
            self.scans_table[scan_id]['end_time'] = int(time.time())

    def set_target_progress(self, scan_id, target, progress):
        """ Sets scan_id scan's progress for a single target.

        Values outside ]0, 100] are ignored.
        """
        if 0 < progress <= 100:
            target_progress = self.scans_table[scan_id]['target_progress']
            target_progress[target] = progress
            # Set scan_info's target_progress to propagate progresses
            # to parent process.
            self.scans_table[scan_id]['target_progress'] = target_progress

    def set_host_finished(self, scan_id, target, host):
        """ Add the host in a list of finished hosts of the given target.

        @param host: A single host string, or an iterable of host strings.
        """
        finished_hosts = self.scans_table[scan_id]['finished_hosts']
        # list.extend() on a string would add one entry per character, so a
        # single host name must be appended instead. type(u'') is unicode on
        # Python 2 and str on Python 3.
        if isinstance(host, (str, type(u''))):
            finished_hosts[target].append(host)
        else:
            finished_hosts[target].extend(host)
        # Re-assign so the change reaches the Manager-held scan info.
        self.scans_table[scan_id]['finished_hosts'] = finished_hosts

    def results_iterator(self, scan_id, pop_res):
        """ Returns an iterator over scan_id scan's results. If pop_res is True,
        it removes the fetched results from the list.
        """
        if pop_res:
            result_aux = self.scans_table[scan_id]['results']
            self.scans_table[scan_id]['results'] = list()
            return iter(result_aux)

        return iter(self.scans_table[scan_id]['results'])

    def ids_iterator(self):
        """ Returns an iterator over the collection's scan IDS. """

        return iter(self.scans_table.keys())

    def create_scan(self, scan_id='', targets='', target_str=None,
                    options=None, vts=''):
        """ Creates a new scan with provided scan information.

        @param scan_id: Scan ID to use; a new UUID4 string if empty or None.
        @param targets: Iterable of (target, ports, credentials) triples.
        @param target_str: Legacy target string, preferred by get_target().
        @param options: dict of scan options; None means no options.
        @param vts: VTs to run.
        @return: The ID of the created scan.
        """

        if options is None:
            # Avoid a shared mutable default argument.
            options = dict()
        if self.data_manager is None:
            self.data_manager = multiprocessing.Manager()
        scan_info = self.data_manager.dict()
        scan_info['results'] = list()
        scan_info['finished_hosts'] = {
            target: [] for target, _, _ in targets}
        scan_info['progress'] = 0
        scan_info['target_progress'] = {
            target: 0 for target, _, _ in targets}
        scan_info['targets'] = targets
        scan_info['legacy_target'] = target_str
        scan_info['vts'] = vts
        scan_info['options'] = options
        scan_info['start_time'] = int(time.time())
        scan_info['end_time'] = "0"
        scan_info['status'] = ""
        if scan_id is None or scan_id == '':
            scan_id = str(uuid.uuid4())
        scan_info['scan_id'] = scan_id
        self.scans_table[scan_id] = scan_info
        return scan_id

    def set_status(self, scan_id, status):
        """ Sets scan_id scan's status. """
        self.scans_table[scan_id]['status'] = status

    def get_status(self, scan_id):
        """ Get scan_id scans's status."""

        return self.scans_table[scan_id]['status']

    def get_options(self, scan_id):
        """ Get scan_id scan's options list. """

        return self.scans_table[scan_id]['options']

    def set_option(self, scan_id, name, value):
        """ Set a scan_id scan's name option to value. """

        # Item assignment on the dict returned by the proxy would only modify
        # a copy; read, modify and write back so the option is kept.
        options = self.scans_table[scan_id]['options']
        options[name] = value
        self.scans_table[scan_id]['options'] = options

    def get_progress(self, scan_id):
        """ Get a scan's current progress value. """

        return self.scans_table[scan_id]['progress']

    def get_target_progress(self, scan_id):
        """ Get a scan's per-target progress values (target -> progress). """

        return self.scans_table[scan_id]['target_progress']

    def get_start_time(self, scan_id):
        """ Get a scan's start time. """

        return self.scans_table[scan_id]['start_time']

    def get_end_time(self, scan_id):
        """ Get a scan's end time. """

        return self.scans_table[scan_id]['end_time']

    def get_target(self, scan_id):
        """ Get a scan's target list.

        Returns the legacy target string when set, otherwise the scan's
        targets joined with commas.
        """
        scan_info = self.scans_table[scan_id]
        if scan_info['legacy_target']:
            return scan_info['legacy_target']

        return ','.join(item[0] for item in scan_info['targets'])

    def get_ports(self, scan_id, target):
        """ Get a scan's ports list. If a target is specified
        it will return the corresponding port for it. If not,
        it returns the port item of the first nested list in
        the target's list.
        """
        if target:
            for item in self.scans_table[scan_id]['targets']:
                if target == item[0]:
                    return item[1]

        return self.scans_table[scan_id]['targets'][0][1]

    def get_credentials(self, scan_id, target):
        """ Get a scan's credential list. It return dictionary with
        the corresponding credential for a given target.

        Returns None when target is empty or not part of the scan.
        """
        if target:
            for item in self.scans_table[scan_id]['targets']:
                if target == item[0]:
                    return item[2]
        return None

    def get_vts(self, scan_id):
        """ Get a scan's vts list. """

        return self.scans_table[scan_id]['vts']

    def id_exists(self, scan_id):
        """ Check whether a scan exists in the table. """

        return self.scans_table.get(scan_id) is not None

    def delete_scan(self, scan_id):
        """ Delete a scan if fully finished.

        @return: False if the scan is still running, True once deleted.
        """

        if self.get_status(scan_id) == "running":
            return False
        self.scans_table.pop(scan_id)
        if len(self.scans_table) == 0:
            # No scan needs the Manager anymore; drop the reference to it.
            del self.data_manager
            self.data_manager = None
        return True
256
257
258 View Code Duplication
class ResultType(object):

    """ Various scan results types values. """

    ALARM = 0
    LOG = 1
    ERROR = 2
    HOST_DETAIL = 3

    @classmethod
    def get_str(cls, result_type):
        """ Return the human readable name of a result type value.

        An unknown value triggers an assertion error.
        """
        known = ((cls.ALARM, "Alarm"),
                 (cls.LOG, "Log Message"),
                 (cls.ERROR, "Error Message"),
                 (cls.HOST_DETAIL, "Host Detail"))
        for type_value, type_name in known:
            if result_type == type_value:
                return type_name
        assert False, "Erroneous result type {0}.".format(result_type)

    @classmethod
    def get_type(cls, result_name):
        """ Return the result type value for a human readable name.

        An unknown name triggers an assertion error.
        """
        known = (("Alarm", cls.ALARM),
                 ("Log Message", cls.LOG),
                 ("Error Message", cls.ERROR),
                 ("Host Detail", cls.HOST_DETAIL))
        for type_name, type_value in known:
            if result_name == type_name:
                return type_value
        assert False, "Erroneous result name {0}.".format(result_name)
294
295
296
__inet_pton = None


def inet_pton(address_family, ip_string):
    """ A platform independent version of inet_pton.

    Uses socket.inet_pton when the socket module provides it, otherwise the
    implementation from ospd.win_socket. The chosen function is cached in a
    module-level variable on the first call.
    """
    global __inet_pton
    if __inet_pton is None:
        __inet_pton = getattr(socket, 'inet_pton', None)
        if __inet_pton is None:
            from ospd import win_socket
            __inet_pton = win_socket.inet_pton

    return __inet_pton(address_family, ip_string)
310
311
312
__inet_ntop = None


def inet_ntop(address_family, packed_ip):
    """ A platform independent version of inet_ntop.

    Uses socket.inet_ntop when the socket module provides it, otherwise the
    implementation from ospd.win_socket. The chosen function is cached in a
    module-level variable on the first call.
    """
    global __inet_ntop
    if __inet_ntop is None:
        __inet_ntop = getattr(socket, 'inet_ntop', None)
        if __inet_ntop is None:
            from ospd import win_socket
            __inet_ntop = win_socket.inet_ntop

    return __inet_ntop(address_family, packed_ip)
326
327
328
def target_to_ipv4(target):
    """ Attempt to return a single IPv4 host list from a target string.

    Returns [target] if it parses as an IPv4 address, None otherwise.
    """
    try:
        inet_pton(socket.AF_INET, target)
    except socket.error:
        return None
    return [target]
336
337
338
def target_to_ipv6(target):
    """ Attempt to return a single IPv6 host list from a target string.

    Returns [target] if it parses as an IPv6 address, None otherwise.
    """
    try:
        inet_pton(socket.AF_INET6, target)
    except socket.error:
        return None
    return [target]
346
347
348 View Code Duplication
def ipv4_range_to_list(start_packed, end_packed):
    """ Return a list of IPv4 entries from start_packed to end_packed.

    Both bounds are 4-byte packed addresses in network byte order and both
    are included. An empty list is returned when end is below start.
    """
    first, = struct.unpack('!L', start_packed)
    last, = struct.unpack('!L', end_packed)
    return [socket.inet_ntoa(struct.pack('!L', number))
            for number in range(first, last + 1)]
358
359
360 View Code Duplication
def target_to_ipv4_short(target):
    """ Attempt to return a IPv4 short range list from a target string.

    A short range only varies the last octet, e.g. "192.168.1.10-20" for
    192.168.1.10 up to 192.168.1.20 (both included).

    Returns None if target is not of that form, if the end octet is outside
    0-255, or if it is lower than the start octet.
    """
    splitted = target.split('-')
    if len(splitted) != 2:
        return None
    try:
        start_packed = inet_pton(socket.AF_INET, splitted[0])
        end_value = int(splitted[1])
    except (socket.error, ValueError):
        return None
    # Slice instead of index: on Python 3, indexing bytes gives an int, and
    # bytes(int) builds a zero-filled buffer rather than the octet itself.
    start_value = struct.unpack('B', start_packed[3:4])[0]
    if end_value < 0 or end_value > 255 or end_value < start_value:
        return None
    end_packed = start_packed[0:3] + struct.pack('B', end_value)
    return ipv4_range_to_list(start_packed, end_packed)
376
377
378 View Code Duplication
def target_to_ipv4_cidr(target):
    """ Attempt to return a IPv4 CIDR list from a target string.

    Accepts "address/prefix" with a prefix length from 1 to 30. Returns the
    addresses of that network without its first and last address, or None
    if target is not such a CIDR.
    """
    parts = target.split('/')
    if len(parts) != 2:
        return None
    address, prefix = parts
    try:
        network_packed = inet_pton(socket.AF_INET, address)
        block = int(prefix)
    except (socket.error, ValueError):
        return None
    if not 0 < block <= 30:
        return None
    host_bits = 32 - block
    # Clear the host bits, then step one address past the network address.
    first_host = ((int(binascii.hexlify(network_packed), 16) >> host_bits)
                  << host_bits) + 1
    # Set all host bits (last address of the network) and step back one.
    last_host = (first_host | (0xffffffff >> block)) - 1
    return ipv4_range_to_list(struct.pack('!I', first_host),
                              struct.pack('!I', last_host))
397
398
399 View Code Duplication
def target_to_ipv6_cidr(target):
    """ Attempt to return a IPv6 CIDR list from a target string.

    Accepts "address/prefix" with a prefix length from 1 to 126. Returns the
    addresses of that network without its first and last address, or None
    if target is not such a CIDR.
    """
    parts = target.split('/')
    if len(parts) != 2:
        return None
    address, prefix = parts
    try:
        network_packed = inet_pton(socket.AF_INET6, address)
        block = int(prefix)
    except (socket.error, ValueError):
        return None
    if not 0 < block <= 126:
        return None
    host_bits = 128 - block
    # Clear the host bits, then step one address past the network address.
    first_host = ((int(binascii.hexlify(network_packed), 16) >> host_bits)
                  << host_bits) + 1
    # Set all host bits (last address of the network) and step back one.
    last_host = (first_host | (int('ff' * 16, 16) >> block)) - 1
    low_mask = (1 << 64) - 1
    start_packed = struct.pack('!QQ', first_host >> 64, first_host & low_mask)
    end_packed = struct.pack('!QQ', last_host >> 64, last_host & low_mask)
    return ipv6_range_to_list(start_packed, end_packed)
422
423
424 View Code Duplication
def target_to_ipv4_long(target):
    """ Attempt to return a IPv4 long-range list from a target string.

    A long range has two full addresses, e.g. "10.0.0.250-10.0.1.5". Returns
    None if target is not of that form or if the end is below the start.
    """
    parts = target.split('-')
    if len(parts) != 2:
        return None
    first, last = parts
    try:
        start_packed = inet_pton(socket.AF_INET, first)
        end_packed = inet_pton(socket.AF_INET, last)
    except socket.error:
        return None
    # Packed addresses are equal-length big-endian byte strings, so a byte
    # comparison orders them numerically.
    if start_packed > end_packed:
        return None
    return ipv4_range_to_list(start_packed, end_packed)
438
439
440 View Code Duplication
def ipv6_range_to_list(start_packed, end_packed):
    """ Return a list of IPv6 entries from start_packed to end_packed.

    Both bounds are packed addresses and both are included. An empty list is
    returned when end is below start.
    """
    low_mask = (1 << 64) - 1
    first = int(binascii.hexlify(start_packed), 16)
    last = int(binascii.hexlify(end_packed), 16)
    addresses = []
    for number in range(first, last + 1):
        # Split the 128-bit value into two 64-bit halves for struct.
        packed = struct.pack('!2Q', number >> 64, number & low_mask)
        addresses.append(inet_ntop(socket.AF_INET6, packed))
    return addresses
453
454
455 View Code Duplication
def target_to_ipv6_short(target):
    """ Attempt to return a IPv6 short-range list from a target string.

    A short range only varies the last 16-bit group, given in hex after the
    dash, e.g. "fe80::1-a". Returns None if target is not of that form, if
    the end value exceeds 0xffff, or if it is below the start value.
    """
    parts = target.split('-')
    if len(parts) != 2:
        return None
    first, last = parts
    try:
        start_packed = inet_pton(socket.AF_INET6, first)
        end_value = int(last, 16)
    except (socket.error, ValueError):
        return None
    start_value = int(binascii.hexlify(start_packed[14:]), 16)
    # start_value is never negative, so this also rejects end_value < 0.
    if not start_value <= end_value <= 0xffff:
        return None
    end_packed = start_packed[:14] + struct.pack('!H', end_value)
    return ipv6_range_to_list(start_packed, end_packed)
471
472
473 View Code Duplication
def target_to_ipv6_long(target):
    """ Attempt to return a IPv6 long-range list from a target string.

    A long range has two full addresses separated by a dash. Returns None if
    target is not of that form or if the end is below the start.
    """
    parts = target.split('-')
    if len(parts) != 2:
        return None
    first, last = parts
    try:
        start_packed = inet_pton(socket.AF_INET6, first)
        end_packed = inet_pton(socket.AF_INET6, last)
    except socket.error:
        return None
    # Equal-length big-endian byte strings compare like their numbers.
    if start_packed > end_packed:
        return None
    return ipv6_range_to_list(start_packed, end_packed)
487
488
489
def target_to_hostname(target):
    """ Attempt to return a single hostname list from a target string.

    The target must be 1 to 255 characters long and consist only of word
    characters, dots and dashes. Returns [target] or None.
    """
    if not target or len(target) > 255:
        return None
    # \Z, unlike $, does not match before a trailing newline, so values such
    # as "host\n" are rejected instead of slipping through validation.
    if not re.match(r'[\w.-]+\Z', target):
        return None
    return [target]
497
498
499 View Code Duplication
def target_to_list(target):
    """ Attempt to return a list of single hosts from a target string.

    Parsers are tried in this order: IPv4 address, IPv6 address, IPv4 CIDR,
    IPv6 CIDR, IPv4 short range, IPv4 long range, IPv6 short range, IPv6
    long range and finally hostname. The first non-empty result wins. If
    none matches, the (falsy) result of the last parser is returned.
    """
    parsers = (
        target_to_ipv4,
        target_to_ipv6,
        target_to_ipv4_cidr,
        target_to_ipv6_cidr,
        target_to_ipv4_short,
        target_to_ipv4_long,
        target_to_ipv6_short,
        target_to_ipv6_long,
        target_to_hostname,
    )
    hosts = None
    for parser in parsers:
        hosts = parser(target)
        if hosts:
            break
    return hosts
529
530
531 View Code Duplication
def target_str_to_list(target_str):
    """ Parses a targets string into a list of individual targets.

    target_str is a comma separated list of target expressions, each
    expanded with target_to_list(). The resulting hosts are returned without
    duplicates, in order of first appearance. If any element cannot be
    parsed, it is logged and None is returned.
    """
    hosts = list()
    for raw_target in target_str.split(','):
        item = raw_target.strip()
        expanded = target_to_list(item)
        if not expanded:
            LOGGER.info("{0}: Invalid target value".format(item))
            return None
        hosts.extend(expanded)
    # OrderedDict keys drop duplicates while keeping insertion order.
    return list(collections.OrderedDict.fromkeys(hosts))
543
544
545
def resolve_hostname(hostname):
    """ Returns IP of a hostname.

    Returns the address string from socket.gethostbyname(), or None when the
    name cannot be resolved.
    """
    assert hostname
    try:
        address = socket.gethostbyname(hostname)
    except socket.gaierror:
        address = None
    return address
553
554
555 View Code Duplication
def port_range_expand(portrange):
    """
    Receive a port range and expands it in individual ports.

    @input Port range.
    e.g. "4-8"

    @return List of integers, both bounds included, or None (after logging)
    when portrange is empty or contains no '-'.
    e.g. [4, 5, 6, 7, 8]
    """
    if not portrange or '-' not in portrange:
        LOGGER.info("Invalid port range format")
        return None
    low, _, high = portrange.partition('-')
    return list(range(int(low), int(high) + 1))
573
574
575 View Code Duplication
def port_str_arrange(ports):
    """ Gives a str in the format (always tcp listed first).
    T:<tcp ports/portrange comma separated>U:<udp ports comma separated>

    Only reorders when both markers are present and "U" comes before "T".
    """
    tcp_at = ports.find("T")
    udp_at = ports.find("U")
    # find() returns -1 when missing; udp_at >= 0 with tcp_at > udp_at
    # implies both markers exist with UDP listed first.
    if tcp_at > udp_at >= 0:
        return ports[tcp_at:] + ports[udp_at:tcp_at]
    return ports
585
586
587 View Code Duplication
def ports_str_check_failed(port_str):
    """
    Check if the port string is well formed.
    Return True if fail, False other case.

    Fails on characters other than T, U, ':', digits, ',', ' ' and '-', on
    more than one T or U marker, and when there are fewer ':' than markers.
    """
    if re.search(r'[^TU:0-9, \-]', port_str):
        return True
    tcp_marks = port_str.count('T')
    udp_marks = port_str.count('U')
    return (tcp_marks > 1
            or udp_marks > 1
            or port_str.count(':') < tcp_marks + udp_marks)
602
603
604 View Code Duplication
def ports_as_list(port_str):
    """
    Parses a ports string into two list of individual tcp and udp ports.

    @input string containing a port list
    e.g. T:1,2,3,5-8 U:22,80,600-1024

    @return two list of sorted integers, for tcp and udp ports respectively.
    [None, None] is returned (and the reason logged) when port_str is empty
    or malformed.
    """
    if not port_str:
        LOGGER.info("Invalid port value")
        return [None, None]

    if ports_str_check_failed(port_str):
        LOGGER.info("{0}: Port list malformed.".format(port_str))
        return [None, None]

    ports = port_str.replace(' ', '')

    # Drop a comma directly in front of a protocol marker ("T:1,U:2").
    # Only look when the marker exists and is not the first character: with
    # an index of -1 or 0, "index - 1" wraps around to the end of the string
    # and would merge or duplicate ports. The UDP index is looked up after
    # the TCP fix-up so it is not stale.
    b_tcp = ports.find("T")
    if b_tcp > 0 and ports[b_tcp - 1] == ',':
        ports = ports[:b_tcp - 1] + ports[b_tcp:]
    b_udp = ports.find("U")
    if b_udp > 0 and ports[b_udp - 1] == ',':
        ports = ports[:b_udp - 1] + ports[b_udp:]
    ports = port_str_arrange(ports)

    tports = ''
    uports = ''
    # TCP ports listed first, then UDP ports
    if b_udp != -1 and b_tcp != -1:
        tports = ports[ports.index('T:') + 2:ports.index('U:')]
        uports = ports[ports.index('U:') + 2:]
    # Only UDP ports
    elif b_tcp == -1 and b_udp != -1:
        uports = ports[ports.index('U:') + 2:]
    # Only TCP ports
    elif b_udp == -1 and b_tcp != -1:
        tports = ports[ports.index('T:') + 2:]
    # No protocol marker: everything is TCP
    else:
        tports = ports

    tcp_list = _port_items_to_list(tports)
    udp_list = _port_items_to_list(uports)

    return (tcp_list, udp_list)


def _port_items_to_list(port_items):
    """ Expand a comma separated string of ports/port ranges, sorted. """
    port_list = list()
    for port in port_items.split(','):
        if not port:
            # Tolerate empty items left by stray commas ("1,,2" or "1,2,").
            continue
        if '-' in port:
            port_list.extend(port_range_expand(port))
        else:
            port_list.append(int(port))
    port_list.sort()
    return port_list
664
665
666
def get_tcp_port_list(port_str):
    """ Return a list with tcp ports from a given port list in string format.

    Returns None when port_str is empty or malformed.
    """
    tcp_ports, _ = ports_as_list(port_str)
    return tcp_ports
669
670
671
def get_udp_port_list(port_str):
    """ Return a list with udp ports from a given port list in string format.

    Returns None when port_str is empty or malformed.
    """
    _, udp_ports = ports_as_list(port_str)
    return udp_ports
674
675
676 View Code Duplication
def port_list_compress(port_list):
    """ Compress a port list and return a string.

    Consecutive ports are collapsed into "first-last" ranges, e.g.
    [1, 2, 3, 5, 7, 8] -> "1-3,5,7-8". Duplicates and input order do not
    matter. Any iterable of ints is accepted. An empty or None input is
    logged and yields ''.
    """
    if not port_list:
        LOGGER.info("Invalid or empty port list.")
        return ''

    port_list = sorted(set(port_list))
    compressed_list = []
    # In a sorted, duplicate-free list, port - index stays constant exactly
    # along a run of consecutive ports, so grouping by it yields the runs.
    for _, group in itertools.groupby(enumerate(port_list),
                                      lambda t: t[1] - t[0]):
        run = [port for _, port in group]
        first, last = run[0], run[-1]
        if first == last:
            compressed_list.append(str(first))
        else:
            compressed_list.append(str(first) + '-' + str(last))

    return ','.join(compressed_list)
694
695
696
def valid_uuid(value):
    """ Check if value is a valid UUID.

    Any value accepted by uuid.UUID() counts as valid. Passing version=4
    sets the version bits rather than checking them, so UUIDs of other
    versions are accepted too.
    """
    try:
        uuid.UUID(value, version=4)
    except (TypeError, ValueError, AttributeError):
        return False
    return True
704
705
706 View Code Duplication
def create_args_parser(description):
    """ Create a command-line arguments parser for OSPD.

    @param description: Description shown in the parser's help output.
    @return: argparse.ArgumentParser with the common OSPD options.
    """

    parser = argparse.ArgumentParser(description=description)

    # The validator function names below are kept as they are: argparse
    # reports them in "invalid <name> value" error messages.

    def network_port(string):
        """ Check if provided string is a valid network port. """
        port = int(string)
        if port <= 0 or port > 65535:
            raise argparse.ArgumentTypeError(
                'port must be in ]0,65535] interval')
        return port

    def cacert_file(cacert):
        """ Check if provided file is a valid CA Certificate """
        try:
            context = ssl.create_default_context(cafile=cacert)
        except AttributeError:
            # Python version < 2.7.9
            return cacert
        except IOError:
            raise argparse.ArgumentTypeError('CA Certificate not found')
        try:
            first_ca = context.get_ca_certs()[0]
            not_after = ssl.cert_time_to_seconds(first_ca['notAfter'])
            not_before = ssl.cert_time_to_seconds(first_ca['notBefore'])
        except (KeyError, IndexError):
            raise argparse.ArgumentTypeError('CA Certificate is erroneous')
        now = int(time.time())
        if not_after < now:
            raise argparse.ArgumentTypeError('CA Certificate expired')
        if not_before > now:
            raise argparse.ArgumentTypeError('CA Certificate not active yet')
        return cacert

    def log_level(string):
        """ Check if provided string is a valid log level. """
        level = getattr(logging, string.upper(), None)
        if isinstance(level, int):
            return level
        raise argparse.ArgumentTypeError(
            'log level must be one of {debug,info,warning,error,critical}')

    def filename(string):
        """ Check if provided string is a valid file path. """
        if os.path.isfile(string):
            return string
        raise argparse.ArgumentTypeError(
            '%s is not a valid file path' % string)

    # Network endpoint options.
    parser.add_argument(
        '-p', '--port', default=PORT, type=network_port,
        help='TCP Port to listen on. Default: {0}'.format(PORT))
    parser.add_argument(
        '-b', '--bind-address', default=ADDRESS,
        help='Address to listen on. Default: {0}'.format(ADDRESS))
    parser.add_argument(
        '-u', '--unix-socket',
        help='Unix file socket to listen on.')
    # TLS file options; defaults are applied later by get_common_args().
    parser.add_argument(
        '-k', '--key-file', type=filename,
        help='Server key file. Default: {0}'.format(KEY_FILE))
    parser.add_argument(
        '-c', '--cert-file', type=filename,
        help='Server cert file. Default: {0}'.format(CERT_FILE))
    parser.add_argument(
        '--ca-file', type=cacert_file,
        help='CA cert file. Default: {0}'.format(CA_FILE))
    # Logging and process options.
    parser.add_argument(
        '-L', '--log-level', default='warning', type=log_level,
        help='Wished level of logging. Default: WARNING')
    parser.add_argument(
        '--foreground', action='store_true',
        help='Run in foreground and logs all messages to console.')
    parser.add_argument(
        '-l', '--log-file', type=filename,
        help='Path to the logging file.')
    parser.add_argument(
        '--version', action='store_true',
        help='Print version then exit.')
    return parser
781
782
783
def go_to_background():
    """ Daemonize the running process.

    Forks once: the parent exits and the child continues. If the fork fails,
    the error is logged and the process exits.
    """
    try:
        child_pid = os.fork()
    except OSError as errmsg:
        LOGGER.error('Fork failed: {0}'.format(errmsg))
        sys.exit('Fork failed')
    if child_pid:
        sys.exit()
791
792
793 View Code Duplication
def get_common_args(parser, args=None):
    """ Return list of OSPD common command-line arguments from parser, after
    validating provided values or setting default ones.

    @param parser: Parser as built by create_args_parser().
    @param args: Argument list to parse; None lets argparse use sys.argv.
    @return: dict with the keys port, address, unix_socket, keyfile,
             certfile, cafile, log_level, foreground, log_file and version.
    """
    options = parser.parse_args(args)

    return {
        # Where to listen: TCP port, bind address, Unix socket.
        'port': options.port,
        'address': options.bind_address,
        'unix_socket': options.unix_socket,
        # TLS files, falling back to the default installation paths.
        'keyfile': options.key_file or KEY_FILE,
        'certfile': options.cert_file or CERT_FILE,
        'cafile': options.ca_file or CA_FILE,
        'log_level': options.log_level,
        'foreground': options.foreground,
        'log_file': options.log_file,
        'version': options.version,
    }
834
835
836 View Code Duplication
def print_version(wrapper):
    """ Prints the server version and license information.

    @param wrapper: Server object providing get_scanner_name(),
                    get_server_version(), get_protocol_version(),
                    get_daemon_name() and get_daemon_version().
    """

    scanner_name = wrapper.get_scanner_name()
    server_version = wrapper.get_server_version()
    print("OSP Server for {0} version {1}".format(scanner_name, server_version))
    protocol_version = wrapper.get_protocol_version()
    print("OSP Version: {0}".format(protocol_version))
    daemon_name = wrapper.get_daemon_name()
    daemon_version = wrapper.get_daemon_version()
    print("Using: {0} {1}".format(daemon_name, daemon_version))
    # Year range kept in sync with this file's copyright header.
    print("Copyright (C) 2014-2018 Greenbone Networks GmbH\n"
          "License GPLv2+: GNU GPL version 2 or later\n"
          "This is free software: you are free to change"
          " and redistribute it.\n"
          "There is NO WARRANTY, to the extent permitted by law.")
852
853
854 View Code Duplication
def main(name, klass):
    """ OSPD Main function.

    Parses the common command-line arguments, instantiates the wrapper class,
    prints the version and exits if requested, sets up logging (console, log
    file or syslog, the latter two followed by forking to the background)
    and finally runs the server.

    @param name: Description passed to the command-line parser.
    @param klass: Wrapper class, instantiated with the certfile, keyfile and
                  cafile keyword arguments.
    @return: 1 if wrapper.check() fails, otherwise what wrapper.run() returns.
    """
    parser = create_args_parser(name)
    cargs = get_common_args(parser)

    root_logger = logging.getLogger()
    root_logger.setLevel(cargs['log_level'])
    wrapper = klass(certfile=cargs['certfile'], keyfile=cargs['keyfile'],
                    cafile=cargs['cafile'])

    if cargs['version']:
        print_version(wrapper)
        sys.exit()

    timed_format = '%(asctime)s %(name)s: %(levelname)s: %(message)s'
    if cargs['foreground']:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter(timed_format))
        root_logger.addHandler(handler)
    elif cargs['log_file']:
        handler = logging.handlers.WatchedFileHandler(cargs['log_file'])
        handler.setFormatter(logging.Formatter(timed_format))
        root_logger.addHandler(handler)
        go_to_background()
    else:
        handler = logging.handlers.SysLogHandler('/dev/log')
        handler.setFormatter(
            logging.Formatter('%(name)s: %(levelname)s: %(message)s'))
        root_logger.addHandler(handler)
        # Duplicate syslog's file descriptor to stdout/stderr.
        syslog_fd = handler.socket.fileno()
        for std_fd in (1, 2):
            os.dup2(syslog_fd, std_fd)
        go_to_background()

    if not wrapper.check():
        return 1
    return wrapper.run(cargs['address'], cargs['port'], cargs['unix_socket'])
897