Test Failed
Push — master ( e380d0...f5671d )
by W
02:58
created

st2common/st2common/runners/parallel_ssh.py (2 issues)

1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
from __future__ import absolute_import
17
import json
18
import re
19
import os
20
import traceback
21
22
import eventlet
23
from paramiko.ssh_exception import SSHException
24
25
from st2common.constants.secrets import MASKED_ATTRIBUTE_VALUE
26
from st2common.runners.paramiko_ssh import ParamikoSSHClient
27
from st2common.runners.paramiko_ssh import SSHCommandTimeoutError
28
from st2common import log as logging
29
from st2common.exceptions.ssh import NoHostsConnectedToException
30
import st2common.util.jsonify as jsonify
31
from st2common.util import ip_utils
32
33
# Module-level logger named after this module. ``logging`` here is ``st2common.log``
# (imported as ``logging`` above), not the standard-library module.
LOG = logging.getLogger(__name__)
34
35
36
class ParallelSSHClient(object):
    """
    Run SSH operations (commands, file copies, directory management) against several hosts
    concurrently.

    Work is scheduled on an eventlet green thread pool, with one ``ParamikoSSHClient`` per
    successfully connected host. Every operation returns a ``dict`` keyed by hostname. Hosts
    which failed to connect are remembered in ``self._bad_hosts``, and their connection error
    dict is included in the result of every subsequent operation.
    """

    # Keys of a command result dict which are passed through ``jsonify.json_loads`` in
    # ``_handle_command_result``.
    KEYS_TO_TRANSFORM = ['stdout', 'stderr']
    CONNECT_ERROR = 'Cannot connect to host.'

    def __init__(self, hosts, user=None, password=None, pkey_file=None, pkey_material=None, port=22,
                 bastion_host=None, concurrency=10, raise_on_any_error=False, connect=True,
                 passphrase=None, handle_stdout_line_func=None, handle_stderr_line_func=None,
                 sudo_password=False):
        """
        :param hosts: Hosts to operate on. Each entry is split into hostname and port by
                      ``ip_utils.split_host_port``; entries without a port use ``port``.
        :type hosts: ``list`` of ``str``

        :param user: SSH username.
        :type user: ``str``

        :param password: Optional SSH password.
        :type password: ``str``

        :param pkey_file: Optional path to a private key file.
        :type pkey_file: ``str``

        :param pkey_material: Optional private key contents.
        :type pkey_material: ``str``

        :param port: Default SSH port for hosts which don't specify one.
        :type port: ``int``

        :param bastion_host: Optional bastion (jump) host passed to every per-host client.
        :type bastion_host: ``str``

        :param concurrency: Size of the green thread pool, i.e. the number of hosts worked on
                            at the same time.
        :type concurrency: ``int``

        :param raise_on_any_error: Passed to ``connect()`` when ``connect`` is True.
        :type raise_on_any_error: ``boolean``

        :param connect: Connect to all hosts immediately.
        :type connect: ``boolean``

        :param passphrase: Optional passphrase for the private key.
        :type passphrase: ``str``

        :param handle_stdout_line_func: Callback function which is called dynamically each time a
                                        new stdout line is received.
        :type handle_stdout_line_func: ``func``

        :param handle_stderr_line_func: Callback function which is called dynamically each time a
                                        new stderr line is received.
        :type handle_stderr_line_func: ``func``

        :param sudo_password: When truthy, command results are checked for a rejected sudo
                              password (see ``_handle_command_result``).
        :type sudo_password: ``boolean``

        :raises Exception: If ``hosts`` is empty.
        :raises NoHostsConnectedToException: If ``connect`` is True and no host could be
                                             connected to.
        """
        self._ssh_user = user
        self._ssh_key_file = pkey_file
        self._ssh_key_material = pkey_material
        self._ssh_password = password
        self._hosts = hosts
        self._successful_connects = 0
        self._ssh_port = port
        self._bastion_host = bastion_host
        self._passphrase = passphrase
        self._handle_stdout_line_func = handle_stdout_line_func
        self._handle_stderr_line_func = handle_stderr_line_func
        self._sudo_password = sudo_password

        if not hosts:
            raise Exception('Need an non-empty list of hosts to talk to.')

        self._pool = eventlet.GreenPool(concurrency)
        # hostname -> connected ParamikoSSHClient
        self._hosts_client = {}
        # hostname -> error dict produced when connecting to that host failed
        self._bad_hosts = {}
        # Seconds to sleep while waiting for a free slot in the pool.
        self._scan_interval = 0.1

        if connect:
            connect_results = self.connect(raise_on_any_error=raise_on_any_error)
            extra = {'_connect_results': connect_results}
            LOG.debug('Connect to hosts complete.', extra=extra)

    def connect(self, raise_on_any_error=False):
        """
        Connect to hosts in hosts list. Returns status of connect as a dict.

        :param raise_on_any_error: Optional Raise an exception even if connecting to one
                                   of the hosts fails.
        :type raise_on_any_error: ``boolean``

        :rtype: ``dict`` of ``str`` to ``dict``

        :raises NoHostsConnectedToException: If no host could be connected to.
        """
        results = {}
        greenthreads = []

        for host in self._hosts:
            while not self._pool.free():
                eventlet.sleep(self._scan_interval)
            greenthread = self._pool.spawn(self._connect, host=host, results=results,
                                           raise_on_any_error=raise_on_any_error)
            greenthreads.append(greenthread)

        self._pool.waitall()

        if self._successful_connects < 1:
            # We definitely have to raise an exception in this case.
            LOG.error('Unable to connect to any of the hosts.',
                      extra={'connect_results': results})
            msg = ('Unable to connect to any one of the hosts: %s.\n\n connect_errors=%s' %
                   (self._hosts, json.dumps(results, indent=2)))
            raise NoHostsConnectedToException(msg)

        if raise_on_any_error:
            # An exception raised by ``_connect`` inside a green thread is not propagated by
            # ``waitall()``. ``wait()`` on the finished green thread re-raises it here, in
            # the caller's context.
            for greenthread in greenthreads:
                greenthread.wait()

        return results

    def run(self, cmd, timeout=None):
        """
        Run a command on remote hosts. Returns a dict containing results
        of execution from all hosts.

        :param cmd: Command to run. Must be shlex quoted.
        :type cmd: ``str``

        :param timeout: Optional Timeout for the command.
        :type timeout: ``int``

        :rtype: ``dict`` of ``str`` to ``dict``
        """
        options = {
            'cmd': cmd,
            'timeout': timeout
        }
        results = self._execute_in_pool(self._run_command, **options)
        return results

    def put(self, local_path, remote_path, mode=None, mirror_local_mode=False):
        """
        Copy a file or folder to remote host.

        :param local_path: Path to local file or dir. Must be shlex quoted.
        :type local_path: ``str``

        :param remote_path: Path to remote file or dir. Must be shlex quoted.
        :type remote_path: ``str``

        :param mode: Optional mode to use for the file or dir.
        :type mode: ``int``

        :param mirror_local_mode: Optional Flag to mirror the mode
                                           on local file/dir on remote host.
        :type mirror_local_mode: ``boolean``

        :rtype: ``dict`` of ``str`` to ``dict``

        :raises Exception: If ``local_path`` does not exist.
        """
        if not os.path.exists(local_path):
            raise Exception('Local path %s does not exist.' % local_path)

        options = {
            'local_path': local_path,
            'remote_path': remote_path,
            'mode': mode,
            'mirror_local_mode': mirror_local_mode
        }

        return self._execute_in_pool(self._put_files, **options)

    def mkdir(self, path):
        """
        Create a directory on remote hosts.

        :param path: Path to remote dir that must be created. Must be shlex quoted.
        :type path: ``str``

        :rtype: ``dict`` of ``str`` to ``dict``
        """
        options = {
            'path': path
        }
        return self._execute_in_pool(self._mkdir, **options)

    def delete_file(self, path):
        """
        Delete a file on remote hosts.

        :param path: Path to remote file that must be deleted. Must be shlex quoted.
        :type path: ``str``

        :rtype: ``dict`` of ``str`` to ``dict``
        """
        options = {
            'path': path
        }
        return self._execute_in_pool(self._delete_file, **options)

    def delete_dir(self, path, force=False, timeout=None):
        """
        Delete a dir on remote hosts.

        :param path: Path to remote dir that must be deleted. Must be shlex quoted.
        :type path: ``str``

        :param force: Passed through to each host client's ``delete_dir``.
        :type force: ``boolean``

        :param timeout: Optional timeout passed through to each host client's ``delete_dir``.
        :type timeout: ``int``

        :rtype: ``dict`` of ``str`` to ``dict``
        """
        options = {
            'path': path,
            'force': force,
            # Previously dropped, so the caller-supplied timeout never reached the host client.
            'timeout': timeout
        }
        return self._execute_in_pool(self._delete_dir, **options)

    def close(self):
        """
        Close all open SSH connections to hosts.

        A failure to close one connection is logged and does not prevent closing the others.
        """
        for host in self._hosts_client.keys():
            try:
                self._hosts_client[host].close()
            except Exception:
                LOG.exception('Failed shutting down SSH connection to host: %s', host)

    def _execute_in_pool(self, execute_method, **kwargs):
        """
        Run ``execute_method`` for every connected host in the green thread pool.

        Hosts which failed to connect are not contacted. Their stored connection error dict
        is copied into the result instead.

        :param execute_method: Per-host worker, called as
                               ``execute_method(host=..., results=..., **kwargs)``. It is
                               expected to store its outcome in ``results[host]``.

        :rtype: ``dict`` of ``str`` to ``dict``
        """
        results = {}

        for host in self._bad_hosts.keys():
            results[host] = self._bad_hosts[host]

        for host in self._hosts_client.keys():
            while not self._pool.free():
                eventlet.sleep(self._scan_interval)
            self._pool.spawn(execute_method, host=host, results=results, **kwargs)

        self._pool.waitall()
        return results

    def _connect(self, host, results, raise_on_any_error=False):
        """
        Connect to a single host and record the outcome in ``results[hostname]``.

        On success the client is stored in ``self._hosts_client``. On failure the error dict
        is stored in ``self._bad_hosts`` and ``results``, unless ``raise_on_any_error`` is set,
        in which case the exception is re-raised.
        """
        (hostname, port) = self._get_host_port_info(host)

        extra = {'host': host, 'port': port, 'user': self._ssh_user}
        if self._ssh_password:
            extra['password'] = '<redacted>'
        elif self._ssh_key_file:
            extra['key_file_path'] = self._ssh_key_file
        else:
            extra['private_key'] = '<redacted>'

        LOG.debug('Connecting to host.', extra=extra)

        client = ParamikoSSHClient(hostname=hostname, port=port,
                                   username=self._ssh_user,
                                   password=self._ssh_password,
                                   bastion_host=self._bastion_host,
                                   key_files=self._ssh_key_file,
                                   key_material=self._ssh_key_material,
                                   passphrase=self._passphrase,
                                   handle_stdout_line_func=self._handle_stdout_line_func,
                                   handle_stderr_line_func=self._handle_stderr_line_func)
        try:
            client.connect()
        except SSHException as ex:
            LOG.exception(ex)
            if raise_on_any_error:
                raise
            error_dict = self._generate_error_result(exc=ex, message='Connection error.')
            self._bad_hosts[hostname] = error_dict
            results[hostname] = error_dict
        except Exception as ex:
            error = 'Failed connecting to host %s.' % hostname
            LOG.exception(error)
            if raise_on_any_error:
                raise
            error_dict = self._generate_error_result(exc=ex, message=error)
            self._bad_hosts[hostname] = error_dict
            results[hostname] = error_dict
        else:
            self._successful_connects += 1
            self._hosts_client[hostname] = client
            results[hostname] = {'message': 'Connected to host.'}

    def _run_command(self, host, cmd, results, timeout=None):
        """
        Run ``cmd`` on ``host`` and store the result (or an error dict) in ``results[host]``.
        """
        try:
            # Sanitize before logging so that secrets in the command (the auth token) never
            # reach the debug log.
            LOG.debug('Running command: %s on host: %s.',
                      self._sanitize_command_string(cmd=cmd), host)
            client = self._hosts_client[host]
            (stdout, stderr, exit_code) = client.run(cmd, timeout=timeout,
                                                     call_line_handler_func=True)

            result = self._handle_command_result(stdout=stdout, stderr=stderr, exit_code=exit_code)
            results[host] = result
        except Exception as ex:
            cmd = self._sanitize_command_string(cmd=cmd)
            error = 'Failed executing command "%s" on host "%s"' % (cmd, host)
            LOG.exception(error)
            results[host] = self._generate_error_result(exc=ex, message=error)

    def _put_files(self, local_path, remote_path, host, results, mode=None,
                   mirror_local_mode=False):
        """
        Copy ``local_path`` (a file or a directory) to ``host`` and store the outcome in
        ``results[host]``.
        """
        try:
            # Pass format arguments to the logger so formatting is deferred until needed.
            LOG.debug('Copying file to host: %s', host)
            if os.path.isdir(local_path):
                result = self._hosts_client[host].put_dir(local_path, remote_path)
            else:
                result = self._hosts_client[host].put(local_path, remote_path,
                                                      mirror_local_mode=mirror_local_mode,
                                                      mode=mode)
            LOG.debug('Result of copy: %s', result)
            results[host] = result
        except Exception as ex:
            error = 'Failed sending file(s) in path %s to host %s' % (local_path, host)
            LOG.exception(error)
            results[host] = self._generate_error_result(exc=ex, message=error)

    def _mkdir(self, host, path, results):
        """Create ``path`` on ``host`` and store the outcome in ``results[host]``."""
        try:
            result = self._hosts_client[host].mkdir(path)
            results[host] = result
        except Exception as ex:
            error = 'Failed "mkdir %s" on host %s.' % (path, host)
            LOG.exception(error)
            results[host] = self._generate_error_result(exc=ex, message=error)

    def _delete_file(self, host, path, results):
        """Delete the file ``path`` on ``host`` and store the outcome in ``results[host]``."""
        try:
            result = self._hosts_client[host].delete_file(path)
            results[host] = result
        except Exception as ex:
            error = 'Failed deleting file %s on host %s.' % (path, host)
            LOG.exception(error)
            results[host] = self._generate_error_result(exc=ex, message=error)

    def _delete_dir(self, host, path, results, force=False, timeout=None):
        """Delete the directory ``path`` on ``host`` and store the outcome in ``results[host]``."""
        try:
            result = self._hosts_client[host].delete_dir(path, force=force, timeout=timeout)
            results[host] = result
        except Exception as ex:
            error = 'Failed deleting dir %s on host %s.' % (path, host)
            LOG.exception(error)
            results[host] = self._generate_error_result(exc=ex, message=error)

    def _get_host_port_info(self, host_str):
        """
        Split ``host_str`` into ``(hostname, port)``. Falls back to the client's default port
        when the string carries none.
        """
        (hostname, port) = ip_utils.split_host_port(host_str)
        if not port:
            port = self._ssh_port

        return (hostname, port)

    def _handle_command_result(self, stdout, stderr, exit_code):
        """
        Build the result dict for a finished command.

        :raises ValueError: If sudo-password mode is on and ``stderr`` shows that sudo rejected
                            the password.
        """
        # Detect if user provided an invalid sudo password or sudo is not configured for that user
        if self._sudo_password:
            if re.search(r'sudo: \d+ incorrect password attempts', stderr):
                match = re.search(r'\[sudo\] password for (.+?)\:', stderr)

                if match:
                    username = match.groups()[0]
                else:
                    username = 'unknown'

                error = ('Invalid sudo password provided or sudo is not configured for this user '
                         '(%s)' % (username))
                raise ValueError(error)
        is_succeeded = (exit_code == 0)
        result_dict = {'stdout': stdout, 'stderr': stderr, 'return_code': exit_code,
                       'succeeded': is_succeeded, 'failed': not is_succeeded}

        result = jsonify.json_loads(result_dict, ParallelSSHClient.KEYS_TO_TRANSFORM)
        return result

    @staticmethod
    def _sanitize_command_string(cmd):
        """
        Remove any potentially sensitive information from the command string.

        For now we only mask the values of the sensitive environment variables.
        The value is everything up to the next whitespace character, or up to the end of the
        string when the token is the last thing in the command.
        """
        if not cmd:
            return cmd

        # ``\S+`` (instead of a lazy match that required trailing whitespace) also masks a
        # token that appears at the very end of the command string.
        result = re.sub(r'ST2_ACTION_AUTH_TOKEN=(\S+)', 'ST2_ACTION_AUTH_TOKEN=%s' %
                        (MASKED_ATTRIBUTE_VALUE), cmd)
        return result

    @staticmethod
    def _generate_error_result(exc, message):
        """
        Build a failed-result dict for an exception.

        Meant to be called from inside an ``except`` block: the traceback is taken from
        ``traceback.format_exc()``.

        :param exc: Raised exception.
        :type exc: Exception.

        :param message: Error message which will be prefixed to the exception message.
        :type message: ``str``

        :return: Dict with ``failed``/``succeeded``/``timeout``/``return_code``/``stdout``/
                 ``stderr``/``error``/``traceback`` keys. ``return_code`` is -9 and ``timeout``
                 is True for ``SSHCommandTimeoutError``; otherwise 255 and False.
        :rtype: ``dict``
        """
        exc_message = getattr(exc, 'message', str(exc))
        error_message = '%s %s' % (message, exc_message)
        traceback_message = traceback.format_exc()

        if isinstance(exc, SSHCommandTimeoutError):
            return_code = -9
            timeout = True
        else:
            timeout = False
            return_code = 255

        stdout = getattr(exc, 'stdout', None) or ''
        stderr = getattr(exc, 'stderr', None) or ''

        error_dict = {
            'failed': True,
            'succeeded': False,
            'timeout': timeout,
            'return_code': return_code,
            'stdout': stdout,
            'stderr': stderr,
            'error': error_message,
            'traceback': traceback_message,
        }
        return error_dict

    def __repr__(self):
        return ('<ParallelSSHClient hosts=%s,user=%s,id=%s>' %
                (repr(self._hosts), self._ssh_user, id(self)))
421