Completed
Push — master ( f36c2e...ca95a9 )
by Ionel Cristian
41s
created

main()   F

Complexity

Conditions 11

Size

Total Lines 50

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 1 Features 0
Metric Value
cc 11
c 1
b 1
f 0
dl 0
loc 50
rs 3.375

How to fix   Complexity   

Complexity

Complex functions like main() often do several different things. To break such a function down, identify a cohesive group of steps within it. A common approach is to look for statements that work on the same variables or share one purpose — in main(), for example, loading the command history, connecting to the socket, and relaying user input are separate concerns.

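Once you have determined the statements that belong together, you can apply the Extract Method refactoring and move each group into its own well-named helper function. (For classes, the equivalent refactorings are Extract Class and, when the component makes sense as a subclass, Extract Subclass, which is often faster.)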

1
#!/usr/bin/env python
2
from __future__ import print_function
3
4
import argparse
5
import errno
6
import os
7
import re
8
import readline
9
import select
10
import signal
11
import socket
12
import sys
13
import time
14
import threading
15
16
# Python 2 compatibility: when the ``raw_input`` builtin exists, rebind
# ``input`` to it so the name returns the typed line as a string (Python 2's
# own ``input()`` evaluates it). On Python 3 the probe raises NameError and
# the builtin ``input`` is left untouched.
try:
    raw_input
except NameError:
    pass
else:
    input = raw_input
20
21
# Lookup tables built from the ``signal`` module's ``SIG*`` constants.
# Names containing an underscore (``SIG_DFL``, ``SIG_IGN``, ``SIG_BLOCK`` ...)
# are skipped.
#   SIG_NAMES   -- maps both "SIGUSR1" and "USR1" style names to the number.
#   SIG_NUMBERS -- the set of all accepted signal numbers.
SIG_NAMES = {}
SIG_NUMBERS = set()
for sig, num in vars(signal).items():
    if not sig.startswith('SIG') or '_' in sig:
        continue
    SIG_NAMES[sig] = SIG_NAMES[sig[3:]] = num
    SIG_NUMBERS.add(num)
28
29
30
def parse_pid(value, regex=re.compile(r'^(/tmp/manhole-)?(?P<pid>\d+)$')):
    """Turn a PID argument into an integer (argparse ``type=`` callback).

    Accepts either a bare number (``1234``) or the manhole socket path
    (``/tmp/manhole-1234``); the digits are returned as an ``int``.

    :raises argparse.ArgumentTypeError: if *value* has any other form.
    """
    found = regex.match(value)
    if found is None:
        raise argparse.ArgumentTypeError("PID must be in one of these forms: 1234 or /tmp/manhole-1234")
    return int(found.group('pid'))
36
37
38
def parse_signal(value):
    """Turn a signal argument into a signal number (argparse ``type=`` callback).

    *value* may be a number contained in ``SIG_NUMBERS`` or a name contained
    in ``SIG_NAMES``. Names are matched case-insensitively; ``SIG_NAMES``
    holds each name both with and without its ``SIG`` prefix, so ``usr1`` and
    ``SIGUSR1`` are both accepted.

    :raises argparse.ArgumentTypeError: for an unknown number or name.
    """
    try:
        number = int(value)
    except ValueError:
        pass  # not numeric: fall through to the name lookup below
    else:
        if number in SIG_NUMBERS:
            return number
        # On Python 3 the members of SIG_NUMBERS may be ``signal.Signals``
        # enum values whose str() is "Signals.SIGINT" on some versions;
        # int() normalises them to plain numbers, and sorting keeps the
        # list readable.
        valid = ', '.join(str(n) for n in sorted(int(i) for i in SIG_NUMBERS))
        raise argparse.ArgumentTypeError("Invalid signal number %s. Expected one of: %s" % (
            number, valid
        ))
    name = value.upper()
    if name in SIG_NAMES:
        return SIG_NAMES[name]
    raise argparse.ArgumentTypeError("Invalid signal name %r." % name)
55
56
57
# Command-line interface. The PID may be given as "1234" or as the socket
# path "/tmp/manhole-1234" (parsed by parse_pid).
parser = argparse.ArgumentParser(description='Connect to a manhole.')
parser.add_argument(
    'pid',
    metavar='PID',
    type=parse_pid,
    help='A numerical process id, or a path in the form: /tmp/manhole-1234',
)
parser.add_argument(
    '-t', '--timeout',
    type=float,
    default=1,
    dest='timeout',
    help='Timeout to use. Default: %(default)s seconds.',
)
# At most one of -1/-USR1, -2/-USR2 or -s/--signal may be given; each one
# stores a signal number in args.signal.
group = parser.add_mutually_exclusive_group()
group.add_argument(
    '-1', '-USR1',
    action='store_const',
    const=int(signal.SIGUSR1),
    dest='signal',
    help='Send USR1 (%(const)s) to the process before connecting.',
)
group.add_argument(
    '-2', '-USR2',
    action='store_const',
    const=int(signal.SIGUSR2),
    dest='signal',
    help='Send USR2 (%(const)s) to the process before connecting.',
)
group.add_argument(
    '-s', '--signal',
    type=parse_signal,
    metavar="SIGNAL",
    dest='signal',
    help='Send the given SIGNAL to the process before connecting.',
)
69
70
71
class ConnectionHandler(threading.Thread):
    """Thread that relays data between a connected manhole socket and the terminal.

    Bytes received on ``sock`` are decoded as UTF-8, written to ``sys.stdout``,
    and readline is asked to redisplay the current input line. Bytes that
    become readable on ``read_fd`` (when given) are forwarded to ``sock``.

    Clear ``should_run`` from another thread to stop the relay; it is checked
    after every poll, i.e. at least once per ``timeout`` seconds. The thread
    also stops on its own when the peer closes the socket.

    :param timeout: poll timeout in seconds; also the length of the final
        drain period when ``wait_the_end`` is true.
    :param sock: a connected socket.
    :param read_fd: optional file descriptor whose data is sent to ``sock``.
    :param wait_the_end: after ``should_run`` is cleared, keep polling for up
        to ``timeout`` seconds so trailing output is still printed.
    """

    _EVENTS = select.POLLIN | select.POLLPRI | select.POLLERR | select.POLLHUP

    def __init__(self, timeout, sock, read_fd=None, wait_the_end=True):
        import codecs  # local import: the module header does not import codecs

        super(ConnectionHandler, self).__init__()
        self.sock = sock
        self.read_fd = read_fd
        self.conn_fd = sock.fileno()
        self.timeout = timeout
        self.should_run = True
        self._poller = select.poll()
        self.wait_the_end = wait_the_end
        # Becomes False once recv() reports EOF (the peer closed the socket).
        self._connected = True
        # Incremental decoder: a UTF-8 character split across two recv()
        # calls is buffered until complete instead of raising
        # UnicodeDecodeError on the first half.
        self._decoder = codecs.getincrementaldecoder('utf8')()

    def run(self):
        if self.read_fd is not None:
            self._poller.register(self.read_fd, self._EVENTS)
        self._poller.register(self.conn_fd, self._EVENTS)

        while self.should_run:
            self.poll()
        if self.wait_the_end:
            started = time.time()
            # No point draining a socket the peer has already closed.
            while self._connected and time.time() - started < self.timeout:
                self.poll()

    def poll(self):
        """Wait up to ``timeout`` seconds for activity and service every ready fd."""
        # select.poll().poll() expects its timeout in milliseconds.
        for fd, _ in self._poller.poll(self.timeout * 1000):
            if fd == self.conn_fd:
                data = self.sock.recv(1024 * 1024)
                if not data:
                    # EOF: without this the fd stays readable (POLLHUP) and
                    # every poll returns immediately with another empty read.
                    self._poller.unregister(self.conn_fd)
                    self._connected = False
                    self.should_run = False
                    return
                sys.stdout.write(self._decoder.decode(data))
                sys.stdout.flush()
                readline.redisplay()
            elif fd == self.read_fd:
                data = os.read(self.read_fd, 1024)
                if not data:
                    # Write end closed: stop watching it to avoid spinning.
                    self._poller.unregister(self.read_fd)
                    continue
                self.sock.sendall(data)
            else:
                raise RuntimeError("Unknown FD %s" % fd)
106
107
108
def _setup_history():
    """Load ``~/.manhole_history`` into readline and register saving it at exit."""
    import atexit

    histfile = os.path.join(os.path.expanduser("~"), ".manhole_history")
    try:
        readline.read_history_file(histfile)
    except IOError:
        pass  # no (readable) history file yet: start with an empty history
    atexit.register(readline.write_history_file, histfile)


def _connect(uds_path, timeout):
    """Connect to the UNIX socket *uds_path*, retrying for up to *timeout* seconds.

    ENOENT and ECONNREFUSED are retried silently (presumably the target has
    not finished setting up its socket yet, e.g. right after being
    signalled); other socket errors are reported on stderr and retried too.
    A fresh socket is used for every attempt, because a socket's state after
    a failed connect() is not guaranteed to allow another attempt.

    Returns the connected socket; exits the process with status 5 on timeout.
    """
    start = time.time()
    while time.time() - start < timeout:
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.settimeout(timeout)
        try:
            sock.connect(uds_path)
        except socket.error as exc:
            sock.close()
            if exc.errno not in (errno.ENOENT, errno.ECONNREFUSED):
                print("Failed to connect to %r: %r" % (uds_path, exc), file=sys.stderr)
            time.sleep(0.05)  # brief pause instead of hammering connect()
        else:
            return sock
    print("Failed to connect to %r: Timeout" % uds_path, file=sys.stderr)
    sys.exit(5)


def main():
    """Command-line entry point: signal the target if asked, connect, and relay a session.

    Lines typed on stdin are forwarded to the manhole socket through a pipe
    read by a ConnectionHandler thread, which prints everything the remote
    side sends. Ends on EOF, Ctrl-C, or when the handler thread stops.
    """
    args = parser.parse_args()
    _setup_history()

    if args.signal is not None:
        os.kill(args.pid, args.signal)

    sock = _connect('/tmp/manhole-%s' % args.pid, args.timeout)
    read_fd, write_fd = os.pipe()

    # When stdin is not a terminal, let the handler keep reading for a while
    # after input ends so the responses to the last lines are printed.
    thread = ConnectionHandler(args.timeout, sock, read_fd, not sys.stdin.isatty())
    thread.start()

    try:
        while thread.is_alive():
            try:
                data = input()
            except EOFError:
                break
            # Python 3's input() returns text; Python 2's raw_input() already
            # returns bytes, and calling .encode() on those would first
            # ASCII-decode them and fail on non-ASCII input.
            if not isinstance(data, bytes):
                data = data.encode('utf8')
            os.write(write_fd, data + b'\n')
    except KeyboardInterrupt:
        pass
    finally:
        thread.should_run = False
        thread.join()
        sock.close()
        os.close(read_fd)
        os.close(write_fd)
158