Passed
Pull Request — master (#209)
by Juan José
03:26
created

ospd.main.exit_cleanup()   A

Complexity

Conditions 4

Size

Total Lines 16
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 12
nop 4
dl 0
loc 16
rs 9.8
c 0
b 0
f 0
1
# Copyright (C) 2019 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: GPL-2.0-or-later
4
#
5
# This program is free software; you can redistribute it and/or
6
# modify it under the terms of the GNU General Public License
7
# as published by the Free Software Foundation; either version 2
8
# of the License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU General Public License for more details.
14
#
15
# You should have received a copy of the GNU General Public License
16
# along with this program; if not, write to the Free Software
17
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18
19
import logging
20
21
from logging.handlers import SysLogHandler, WatchedFileHandler
22
23
import os
24
import sys
25
import atexit
26
import signal
27
28
from functools import partial
29
30
from typing import Type, Optional
31
from pathlib import Path
32
33
from ospd.misc import go_to_background, create_pid
34
from ospd.ospd import OSPDaemon
35
from ospd.parser import create_parser, ParserType
36
from ospd.server import TlsServer, UnixSocketServer, BaseServer
37
38
# License/copyright banner; print_version() prints it after the version lines.
COPYRIGHT = """Copyright (C) 2014, 2015, 2018, 2019 Greenbone Networks GmbH
39
License GPLv2+: GNU GPL version 2 or later
40
This is free software: you are free to change and redistribute it.
41
There is NO WARRANTY, to the extent permitted by law."""
42
43
# Logger for this module (named after the module via __name__).
LOGGER = logging.getLogger(__name__)
44
45
46
def print_version(daemon: OSPDaemon, file=sys.stdout):
    """ Write version details of the daemon followed by the license banner.

    All version information is queried from ``daemon`` first; then the
    scanner/server line, the OSP protocol line, the daemon line, an empty
    line and the COPYRIGHT banner are written to ``file``, one per line.

    Args:
        daemon: Daemon instance providing the version getters.
        file: Writable text stream; defaults to ``sys.stdout``.
    """
    banner_lines = [
        "OSP Server for {0}: {1}".format(
            daemon.get_scanner_name(), daemon.get_server_version()
        ),
        "OSP: {0}".format(daemon.get_protocol_version()),
        "{0}: {1}".format(
            daemon.get_daemon_name(), daemon.get_daemon_version()
        ),
        "",
        COPYRIGHT,
    ]

    for text in banner_lines:
        print(text, file=file)
63
64
65
def init_logging(
    name: str,
    log_level: int,
    *,
    log_file: Optional[str] = None,
    foreground: Optional[bool] = False
):
    """ Configure the root logger with exactly one output handler.

    The root logger level is set to ``log_level`` and one handler is added:
    a console stream handler when ``foreground`` is true, otherwise a
    WatchedFileHandler on ``log_file`` when given, otherwise a syslog handler
    on '/dev/log'. In the syslog case, stdout and stderr (fds 1 and 2) are
    redirected onto the syslog socket.

    Args:
        name: Prefix placed in every formatted log record.
        log_level: Level applied to the root logger.
        log_file: Optional path of a log file (ignored when ``foreground``).
        foreground: Log to the console instead of a file or syslog.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)

    # Console and file output share this timestamped layout; the syslog
    # layout below omits the timestamp.
    timestamped_layout = (
        '%(asctime)s {}: %(levelname)s: (%(name)s) %(message)s'.format(name)
    )

    if foreground:
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(logging.Formatter(timestamped_layout))
        root_logger.addHandler(console_handler)
        return

    if log_file:
        file_handler = WatchedFileHandler(log_file)
        file_handler.setFormatter(logging.Formatter(timestamped_layout))
        root_logger.addHandler(file_handler)
        return

    syslog_handler = SysLogHandler('/dev/log')
    syslog_handler.setFormatter(
        logging.Formatter(
            '{}: %(levelname)s: (%(name)s) %(message)s'.format(name)
        )
    )
    root_logger.addHandler(syslog_handler)

    # Point stdout (1) and stderr (2) at the syslog socket so stray
    # prints/tracebacks end up in syslog too.
    socket_fd = syslog_handler.socket.fileno()
    for std_fd in (1, 2):
        os.dup2(socket_fd, std_fd)
108
109
110
def exit_cleanup(
    pidfile: str, server: BaseServer, _signum=None, _frame=None
) -> None:
    """ Close the server and remove the pidfile if this process owns it.

    main() registers this both as an ``atexit`` handler and, through
    ``functools.partial``, as the SIGTERM handler; ``_signum`` and ``_frame``
    only exist to accept the signal-handler arguments. Cleanup happens only
    when the pidfile exists and holds the PID of the current process;
    otherwise (missing, unparsable, or foreign pidfile) the call is a no-op.

    Args:
        pidfile: Path of the daemon's pidfile.
        server: Server to close before the process ends.

    Raises:
        SystemExit: after a successful cleanup, to end the process.
    """
    pidpath = Path(pidfile)

    if not pidpath.is_file():
        return

    try:
        with pidpath.open() as pid_fd:
            pid_content = pid_fd.read()
    except FileNotFoundError:
        # The pidfile vanished between the check above and the open; raising
        # here would surface as a traceback from an exit/signal handler.
        return

    try:
        owner_pid = int(pid_content)
    except ValueError:
        # An unparsable pidfile cannot prove this process owns it, so treat
        # it like a pidfile of another process instead of raising from
        # inside an exit/signal handler.
        LOGGER.warning(
            "Ignoring pidfile %s with invalid content %r", pidfile, pid_content
        )
        return

    if owner_pid != os.getpid():
        return

    LOGGER.info("Shutting-down server ...")
    try:
        server.close()
        LOGGER.debug("Finishing daemon process")
    finally:
        # Remove the pidfile even when closing the server fails, so it is
        # not left behind.
        pidpath.unlink()
    sys.exit()
126
127
128
def main(
    name: str,
    daemon_class: Type[OSPDaemon],
    parser: Optional[ParserType] = None,
):
    """ Shared entry point for OSP daemons.

    Steps performed by this function:

    - Parse the command line and configure logging.
    - Build the server: a Unix socket server when ``args.port`` is 0, a TLS
      server otherwise.
    - Instantiate ``daemon_class`` with all parsed arguments.
    - If ``args.version`` or ``args.list_commands`` is set, print the
      requested information and exit.
    - Otherwise, unless running in the foreground, move to the background.
    - Create the pidfile and install exit/SIGTERM cleanup.
    - Check, initialize and run the daemon.

    Args:
        name: Daemon name, used for the default parser and log prefixes.
        daemon_class: OSPDaemon subclass to instantiate.
        parser: Pre-built parser; created from ``name`` when not given.

    Returns:
        1 if ``daemon.check()`` fails, 0 once ``daemon.run()`` returns.
    """
    arg_parser = parser or create_parser(name)
    args = arg_parser.parse_arguments()

    # Force console logging for version output. Note the modified flag is
    # also forwarded to the daemon through vars(args) below.
    if args.version:
        args.foreground = True

    init_logging(
        name, args.log_level, log_file=args.log_file, foreground=args.foreground
    )

    use_unix_socket = args.port == 0
    if use_unix_socket:
        server = UnixSocketServer(
            args.unix_socket, args.socket_mode, args.stream_timeout
        )
    else:
        server = TlsServer(
            args.address,
            args.port,
            args.cert_file,
            args.key_file,
            args.ca_file,
            args.stream_timeout,
        )

    daemon = daemon_class(**vars(args))

    # Informational options: print and stop before daemonizing.
    if args.version:
        print_version(daemon)
        sys.exit()
    if args.list_commands:
        print(daemon.get_help_text())
        sys.exit()

    if not args.foreground:
        go_to_background()

    pid_file = args.pid_file
    if not create_pid(pid_file):
        sys.exit()

    # Run exit_cleanup both at interpreter exit and on SIGTERM.
    atexit.register(exit_cleanup, pidfile=pid_file, server=server)
    signal.signal(signal.SIGTERM, partial(exit_cleanup, pid_file, server))

    if daemon.check():
        daemon.init(server)
        daemon.run()
        return 0
    return 1
187