main()   F
last analyzed

Complexity

Conditions 11

Size

Total Lines 50

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 1 Features 0
Metric Value
cc 11
c 1
b 1
f 0
dl 0
loc 50
rs 3.375

How to fix   Complexity   

Complexity

Complex classes like main() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
#!/usr/bin/env python
2
from __future__ import print_function
3
4
import argparse
5
import errno
6
import os
7
import re
8
import readline
9
import select
10
import signal
11
import socket
12
import sys
13
import threading
14
import time
15
16
# Python 2 compatibility: rebind ``input`` to ``raw_input`` so that it returns
# the typed line as-is. ``raw_input`` does not exist on Python 3, in which
# case the NameError is ignored and the builtin ``input`` is used unchanged.
try:
    input = raw_input
except NameError:
    pass
20
21
# Lookup tables used to validate signals given on the command line, built from
# the attributes of the ``signal`` module.
#
# SIG_NAMES maps both the full name ("SIGUSR1") and the name without the "SIG"
# prefix ("USR1") to the signal value; SIG_NUMBERS is the set of those values.
# Attributes containing an underscore (e.g. SIG_DFL, SIG_IGN) are skipped.
SIG_NAMES = {}
SIG_NUMBERS = set()
for sig, num in vars(signal).items():
    if sig.startswith('SIG') and '_' not in sig:
        SIG_NAMES[sig] = num
        SIG_NAMES[sig[3:]] = num
        SIG_NUMBERS.add(num)
28
29
30
def parse_pid(value, regex=re.compile(r'^(/tmp/manhole-)?(?P<pid>\d+)$')):
    """Convert a command-line PID argument to an integer.

    Accepts either a bare number ("1234") or a socket path of the form
    "/tmp/manhole-1234".

    :param value: The raw argument string.
    :param regex: Compiled pattern used for validation; it is a default
        argument so that it is compiled only once, at definition time.
    :returns: The process id as an ``int``.
    :raises argparse.ArgumentTypeError: If ``value`` matches neither form.
    """
    match = regex.match(value)
    if not match:
        raise argparse.ArgumentTypeError("PID must be in one of these forms: 1234 or /tmp/manhole-1234")

    return int(match.group('pid'))
36
37
38
def parse_signal(value):
    """Convert a command-line signal argument to a signal number.

    ``value`` may be a decimal signal number ("10") or a signal name, with or
    without the "SIG" prefix and in any letter case ("usr1", "SIGUSR1").

    :param value: The raw argument string.
    :returns: The signal number as an ``int``.
    :raises argparse.ArgumentTypeError: If the number is not in
        ``SIG_NUMBERS`` or the name is not in ``SIG_NAMES``.
    """
    try:
        number = int(value)
    except ValueError:
        # Not numeric: handled below as a signal name.
        pass
    else:
        if number in SIG_NUMBERS:
            return number
        # Show the valid numbers sorted and as plain integers so the message
        # is stable and readable regardless of how the constants print.
        valid = sorted(int(i) for i in SIG_NUMBERS)
        raise argparse.ArgumentTypeError("Invalid signal number %s. Expected one of: %s" % (
            number, ', '.join(str(i) for i in valid)
        ))

    name = value.upper()
    if name in SIG_NAMES:
        # Normalise to a plain int, like the -1/-2 option constants.
        return int(SIG_NAMES[name])
    raise argparse.ArgumentTypeError("Invalid signal name %r." % name)
55
56
57
# Command-line interface definition.
parser = argparse.ArgumentParser(description='Connect to a manhole.')
parser.add_argument('pid', metavar='PID', type=parse_pid,
                    help='A numerical process id, or a path in the form: /tmp/manhole-1234')
parser.add_argument('-t', '--timeout', dest='timeout', default=1, type=float,
                    help='Timeout to use. Default: %(default)s seconds.')
# At most one signal option may be given; all three store into ``args.signal``.
group = parser.add_mutually_exclusive_group()
group.add_argument('-1', '-USR1', dest='signal', action='store_const', const=int(signal.SIGUSR1),
                   help='Send USR1 (%(const)s) to the process before connecting.')
group.add_argument('-2', '-USR2', dest='signal', action='store_const', const=int(signal.SIGUSR2),
                   help='Send USR2 (%(const)s) to the process before connecting.')
group.add_argument('-s', '--signal', dest='signal', type=parse_signal, metavar="SIGNAL",
                   help='Send the given SIGNAL to the process before connecting.')
69
70
71
class ConnectionHandler(threading.Thread):
    """Thread that relays data between a connected socket and the terminal.

    Bytes received on ``sock`` are decoded as UTF-8 and written to
    ``sys.stdout``; bytes read from ``read_fd`` are sent to ``sock``.

    :param timeout: Seconds; used as the timeout of every poll call and as
        the length of the final drain period (see ``wait_the_end``).
    :param sock: Connected socket object.
    :param read_fd: Optional file descriptor to read outgoing data from.
    :param wait_the_end: If true, keep polling for up to ``timeout`` seconds
        after ``should_run`` has been cleared.

    The thread stops on its own when the peer closes the connection
    (``recv`` returns no data): ``should_run`` is set to ``False``.
    """

    def __init__(self, timeout, sock, read_fd=None, wait_the_end=True):
        super(ConnectionHandler, self).__init__()
        import codecs  # Only needed here, so imported locally.

        self.sock = sock
        self.read_fd = read_fd
        self.conn_fd = sock.fileno()
        self.timeout = timeout
        self.should_run = True
        self._poller = select.poll()
        self.wait_the_end = wait_the_end
        # An incremental decoder copes with a multi-byte character that is
        # split across two recv() calls; undecodable bytes are replaced
        # rather than raising and killing the thread.
        self._decoder = codecs.getincrementaldecoder('utf8')(errors='replace')
        # Cleared once the peer has closed the connection.
        self._connected = True

    def run(self):
        """Poll until asked to stop, then optionally drain for ``timeout`` seconds."""
        if self.read_fd is not None:
            self._poller.register(self.read_fd, select.POLLIN)
        self._poller.register(self.conn_fd, select.POLLIN)

        while self.should_run:
            self.poll()
        if self.wait_the_end:
            t = time.time()
            # No point draining a connection that is already closed.
            while self._connected and time.time() - t < self.timeout:
                self.poll()

    def poll(self):
        """Wait up to ``timeout`` seconds for activity and relay what arrives.

        :raises RuntimeError: If the poller reports an unknown descriptor.
        """
        milliseconds = self.timeout * 1000
        for fd, _ in self._poller.poll(milliseconds):
            if fd == self.conn_fd:
                data = self.sock.recv(1024*1024)
                if not data:
                    # Peer closed the connection. Without this the socket
                    # stays readable forever and the loop would spin.
                    self._handle_disconnect()
                    return
                sys.stdout.write(self._decoder.decode(data))
                sys.stdout.flush()
                readline.redisplay()
            elif fd == self.read_fd:
                data = os.read(self.read_fd, 1024)
                if not data:
                    # Write end closed: stop watching it to avoid spinning.
                    self._poller.unregister(self.read_fd)
                    continue
                self.sock.sendall(data)
            else:
                raise RuntimeError("Unknown FD %s" % fd)

    def _handle_disconnect(self):
        """Flush pending output and stop the thread after the peer hung up."""
        tail = self._decoder.decode(b'', final=True)
        if tail:
            sys.stdout.write(tail)
            sys.stdout.flush()
        self._poller.unregister(self.conn_fd)
        self._connected = False
        self.should_run = False
107
108
109
def _setup_history():
    """Load the readline history file and register saving it at exit."""
    import atexit

    histfile = os.path.join(os.path.expanduser("~"), ".manhole_history")
    try:
        readline.read_history_file(histfile)
    except IOError:
        # No usable history file yet; start with an empty history.
        pass
    atexit.register(readline.write_history_file, histfile)


def _connect(uds_path, timeout):
    """Connect to the unix socket at ``uds_path``, retrying for ``timeout`` seconds.

    A missing socket file (ENOENT) or a refused connection (ECONNREFUSED) is
    retried until the timeout elapses. Any other socket error, or running out
    of time, prints a message to stderr and exits with status 5.

    :returns: The connected socket.
    """
    start = time.time()
    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    sock.settimeout(timeout)
    while time.time() - start < timeout:
        try:
            sock.connect(uds_path)
        except socket.error as exc:
            if getattr(exc, 'errno', None) not in (errno.ENOENT, errno.ECONNREFUSED):
                # Not a "socket not there yet" error: retrying will not help.
                print("Failed to connect to %r: %r" % (uds_path, exc), file=sys.stderr)
                sys.exit(5)
            # Brief pause so the retry loop does not spin at full speed.
            time.sleep(0.01)
        else:
            return sock
    print("Failed to connect to %r: Timeout" % uds_path, file=sys.stderr)
    sys.exit(5)


def main():
    """Command-line entry point.

    Parses the arguments, optionally sends a signal to the target process,
    connects to ``/tmp/manhole-<pid>`` and then forwards every line read with
    ``input()`` to the connection through a pipe, while a ``ConnectionHandler``
    thread prints whatever comes back. Stops on EOF, Ctrl-C, or when the
    handler thread ends.
    """
    args = parser.parse_args()

    _setup_history()

    if args.signal:
        os.kill(args.pid, args.signal)

    sock = _connect('/tmp/manhole-%s' % args.pid, args.timeout)

    # Lines typed by the user reach the handler thread through this pipe.
    read_fd, write_fd = os.pipe()

    thread = ConnectionHandler(args.timeout, sock, read_fd, not sys.stdin.isatty())
    thread.start()

    try:
        while thread.is_alive():
            try:
                data = input()
            except EOFError:
                break
            # On Python 2 ``input`` (raw_input) already returns bytes.
            if not isinstance(data, bytes):
                data = data.encode('utf8')
            os.write(write_fd, data + b'\n')
    except KeyboardInterrupt:
        pass
    finally:
        thread.should_run = False
        thread.join()
        # The thread is finished; release the descriptors it was using.
        os.close(write_fd)
        os.close(read_fd)
        sock.close()
159