Completed
Pull Request — master (#2596)
by Edward
07:10 queued 01:24
created

delete_dir()   B

Complexity

Conditions 2

Size

Total Lines 29

Duplication

Lines 0
Ratio 0 %
Metric Value
dl 0
loc 29
rs 8.8571
cc 2
1
# Licensed to the Apache Software Foundation (ASF) under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import os
17
import posixpath
18
from StringIO import StringIO
19
import time
20
21
import eventlet
22
from oslo_config import cfg
23
24
import paramiko
25
26
# Depending on your version of Paramiko, it may cause a deprecation
27
# warning on Python 2.6.
28
# Ref: https://bugs.launchpad.net/paramiko/+bug/392973
29
30
from st2common.log import logging
31
from st2common.util.misc import strip_shell_chars
32
from st2common.util.shell import quote_unix
33
34
__all__ = [
    'ParamikoSSHClient',
    'SSHCommandTimeoutError'
]

# Lower-cased tail of a PEM private key header/footer line (e.g.
# "-----BEGIN RSA PRIVATE KEY-----"). It is compared against lower-cased key
# material to tell private key contents apart from something that looks like
# a file path.
PRIVATE_KEY_HEADER = 'PRIVATE KEY-----'.lower()


class SSHCommandTimeoutError(Exception):
    """
    Exception which is raised when an SSH command times out.

    Carries the command, the timeout which was exceeded and any output which
    was consumed before giving up.
    """

    def __init__(self, cmd, timeout, stdout=None, stderr=None):
        """
        :param cmd: Command which was being executed.
        :type cmd: ``str``

        :param timeout: Timeout (in seconds) which was exceeded.
        :type timeout: ``float``

        :param stdout: Stdout which was consumed until the timeout occurred.
        :type stdout: ``str``

        :param stderr: Stderr which was consumed until the timeout occurred.
        :type stderr: ``str``
        """
        self.cmd = cmd
        self.timeout = timeout
        self.stdout = stdout
        self.stderr = stderr
        message = 'Command didn\'t finish in %s seconds' % (timeout)
        super(SSHCommandTimeoutError, self).__init__(message)
        # Store the message explicitly: ``BaseException.message`` is deprecated
        # since Python 2.6 and absent on Python 3, so ``__str__`` must not rely
        # on the base class populating it.
        self.message = message

    def __repr__(self):
        return ('<SSHCommandTimeoutError: cmd="%s",timeout=%s>' %
                (self.cmd, self.timeout))

    def __str__(self):
        return self.message


class ParamikoSSHClient(object):
    """
    An SSH client powered by Paramiko.

    Wraps a :class:`paramiko.SSHClient` (optionally tunnelled through a bastion
    host) together with an SFTP session opened on the same connection.
    """

    # Maximum number of bytes to read at once from a socket
    CHUNK_SIZE = 1024
    # How long to sleep while waiting for command to finish
    SLEEP_DELAY = 1.5
    # Connect socket timeout
    CONNECT_TIMEOUT = 60

    def __init__(self, hostname, port=22, username=None, password=None, bastion_host=None,
                 key=None, key_files=None, key_material=None, timeout=None, passphrase=None):
        """
        Authentication is always attempted in the following order:

        - The key passed in (if key_files or key_material is provided)
        - Any key we can find through an SSH agent (only if neither a password
          nor a key is provided)
        - Any "id_rsa" or "id_dsa" key discoverable in ~/.ssh/ (only if neither
          a password nor a key is provided)
        - Plain username/password auth, if a password was given

        :param hostname: Host to connect to.
        :type hostname: ``str``

        :param port: SSH port on ``hostname``.
        :type port: ``int``

        :param username: Username to authenticate as. Defaults to the
                         ``system_user.user`` config option.
        :type username: ``str``

        :param password: Password to authenticate with.
        :type password: ``str``

        :param bastion_host: Optional host to tunnel the connection through.
        :type bastion_host: ``str``

        :param key: Deprecated alias for ``key_files``.

        :param key_files: Private key file path(s), passed to paramiko as
                          ``key_filename``.
        :type key_files: ``str`` or ``list`` of ``str``

        :param key_material: Private key contents. Mutually exclusive with
                             ``key_files``.
        :type key_material: ``str``

        :param timeout: Connect timeout in seconds. Defaults to
                        ``CONNECT_TIMEOUT``.

        :param passphrase: Passphrase for ``key_material``; only valid together
                           with ``key_material``.
        :type passphrase: ``str``

        :raises ValueError: If both ``key_files`` and ``key_material`` are
                            given, or ``passphrase`` is given without
                            ``key_material``.
        """
        if key_files and key_material:
            raise ValueError(('key_files and key_material arguments are '
                              'mutually exclusive'))

        if passphrase and not key_material:
            raise ValueError('passphrase should accompany private key material')

        self.hostname = hostname
        self.port = port
        # ``cfg.CONF.system_user`` is an option group (see ``ssh_key_file``
        # below); the default username is its ``user`` option, not the group.
        self.username = username if username else cfg.CONF.system_user.user
        self.password = password
        self.key = key if key else cfg.CONF.system_user.ssh_key_file
        self.key_files = key_files
        if not self.key_files and self.key:
            # `key` arg is deprecated.
            # NOTE(review): this assigns the raw ``key`` argument, not
            # ``self.key``, so the configured ``ssh_key_file`` fallback never
            # reaches ``key_files``. Confirm whether that is intended.
            self.key_files = key
        self.timeout = timeout or ParamikoSSHClient.CONNECT_TIMEOUT
        self.key_material = key_material
        self.client = None
        self.logger = logging.getLogger(__name__)
        self.sftp = None
        self.bastion_host = bastion_host
        self.bastion_client = None
        self.bastion_socket = None
        self.passphrase = passphrase

    def connect(self):
        """
        Connect to the remote node over SSH and open an SFTP session on it.

        If ``bastion_host`` was given, a connection to the bastion is made first
        and the connection to ``hostname`` is tunnelled through a
        ``direct-tcpip`` channel opened on it.

        :return: Always ``True``; connection failures propagate as exceptions.
        :rtype: ``bool``
        """
        if self.bastion_host:
            self.logger.debug('Bastion host specified, connecting')
            self.bastion_client = self._connect(host=self.bastion_host)
            transport = self.bastion_client.get_transport()
            real_addr = (self.hostname, self.port)
            # fabric uses ('', 0) for direct-tcpip, this duplicates that behaviour
            # see https://github.com/fabric/fabric/commit/c2a9bbfd50f560df6c6f9675603fb405c4071cad
            local_addr = ('', 0)
            self.bastion_socket = transport.open_channel('direct-tcpip', real_addr, local_addr)

        self.client = self._connect(host=self.hostname, socket=self.bastion_socket)
        self.sftp = self.client.open_sftp()
        return True

    def put(self, local_path, remote_path, mode=None, mirror_local_mode=False):
        """
        Upload a file to the remote node.

        :type local_path: ``str``
        :param local_path: File path on the local node.

        :type remote_path: ``str``
        :param remote_path: File path on the remote node.

        :type mode: ``int`` or ``str``
        :param mode: Permissions mode for the file. E.g. 0744. A string is
                     parsed as an octal number.

        :type mirror_local_mode: ``bool``
        :param mirror_local_mode: Should remote file mirror local mode. Takes
                                  precedence over ``mode``.

        :return: Attributes of the remote file, as returned by
                 ``SFTPClient.put``.
        :rtype: :class:`paramiko.SFTPAttributes`

        :raises Exception: If either path is empty or ``local_path`` does not
                           exist locally.
        """
        if not local_path or not remote_path:
            # Both values must be passed as one tuple; ``fmt % a, b`` would
            # format with ``a`` alone and fail with a TypeError instead.
            raise Exception('Need both local_path and remote_path. local: %s, remote: %s' %
                            (local_path, remote_path))

        # NOTE(review): quote_unix presumably shell-escapes its argument, but the
        # os.* and SFTP calls below do not go through a shell, so an escaped
        # path may not match the real file. Kept for backward compatibility.
        local_path = quote_unix(local_path)
        remote_path = quote_unix(remote_path)

        extra = {'_local_path': local_path, '_remote_path': remote_path, '_mode': mode,
                 '_mirror_local_mode': mirror_local_mode}
        self.logger.debug('Uploading file', extra=extra)

        if not os.path.exists(local_path):
            raise Exception('Path %s does not exist locally.' % local_path)

        rattrs = self.sftp.put(local_path, remote_path)

        if mode or mirror_local_mode:
            # Mirroring the local file's mode wins over an explicit ``mode``.
            local_mode = os.stat(local_path).st_mode if mirror_local_mode else mode

            # Cast to octal integer in case of string
            if isinstance(local_mode, basestring):
                local_mode = int(local_mode, 8)
            # Keep only permission bits (including setuid/setgid/sticky).
            local_mode = local_mode & 0o7777
            remote_mode = rattrs.st_mode
            # Only mask if we actually got a remote_mode
            if remote_mode is not None:
                remote_mode = (remote_mode & 0o7777)
            if local_mode != remote_mode:
                self.sftp.chmod(remote_path, local_mode)

        return rattrs

    def put_dir(self, local_path, remote_path, mode=None, mirror_local_mode=False):
        """
        Upload a directory tree to the remote node.

        The last component of ``local_path`` is recreated under
        ``remote_path`` (e.g. ``/a/b`` -> ``<remote_path>/b``). Missing remote
        directories are created and every file is uploaded via :meth:`put`.

        :type local_path: ``str``
        :param local_path: Dir path on the local node.

        :type remote_path: ``str``
        :param remote_path: Base dir path on the remote node.

        :type mode: ``int``
        :param mode: Permissions mode for the files. E.g. 0744.

        :type mirror_local_mode: ``bool``
        :param mirror_local_mode: Should remote files mirror local mode.

        :return: The value returned by :meth:`put` (remote file attributes)
                 for each uploaded file, in upload order.
        :rtype: ``list``
        """
        extra = {'_local_path': local_path, '_remote_path': remote_path, '_mode': mode,
                 '_mirror_local_mode': mirror_local_mode}
        self.logger.debug('Uploading dir', extra=extra)

        # Prefix removed from each walked directory so remote paths are relative
        # to local_path's parent. A trailing separator ("/a/b/") makes basename
        # empty, hence the extra dirname() in that case.
        if os.path.basename(local_path):
            strip = os.path.dirname(local_path)
        else:
            strip = os.path.dirname(os.path.dirname(local_path))

        remote_paths = []

        for context, dirs, files in os.walk(local_path):
            rcontext = context.replace(strip, '', 1)
            # normalize pathname separators with POSIX separator
            rcontext = rcontext.replace(os.sep, '/')
            rcontext = rcontext.lstrip('/')
            rcontext = posixpath.join(remote_path, rcontext)

            if not self.exists(rcontext):
                self.sftp.mkdir(rcontext)

            for d in dirs:
                remote_dir = posixpath.join(rcontext, d)
                if not self.exists(remote_dir):
                    self.sftp.mkdir(remote_dir)

            for f in files:
                local_file = os.path.join(context, f)
                remote_file = posixpath.join(rcontext, f)
                # Note that quote_unix is done by put anyways.
                p = self.put(local_path=local_file, remote_path=remote_file,
                             mirror_local_mode=mirror_local_mode, mode=mode)
                remote_paths.append(p)

        return remote_paths

    def exists(self, remote_path):
        """
        Validate whether a remote file or directory exists.

        Uses ``lstat``, so a symlink counts as existing even if its target
        does not.

        :param remote_path: Path to remote file.
        :type remote_path: ``str``

        :rtype: ``bool``
        """
        try:
            self.sftp.lstat(remote_path).st_mode
        except IOError:
            return False

        return True

    def mkdir(self, dir_path):
        """
        Create a directory on remote box.

        :param dir_path: Path to remote directory to be created.
        :type dir_path: ``str``

        :return: Returns nothing if successful else raises IOError exception.

        :rtype: ``None``
        """
        # NOTE(review): shell-quoting a path passed to SFTP (no shell involved)
        # may alter paths with special characters -- confirm.
        dir_path = quote_unix(dir_path)
        extra = {'_dir_path': dir_path}
        self.logger.debug('mkdir', extra=extra)
        return self.sftp.mkdir(dir_path)

    def delete_file(self, path):
        """
        Delete a file on remote box.

        :param path: Path to remote file to be deleted.
        :type path: ``str``

        :return: ``True`` once the file has been deleted; failures propagate
                 as exceptions raised by ``sftp.unlink``.
        :rtype: ``bool``
        """
        # NOTE(review): shell-quoting a path passed to SFTP (no shell involved)
        # may alter paths with special characters -- confirm.
        path = quote_unix(path)
        extra = {'_path': path}
        self.logger.debug('Deleting file', extra=extra)
        self.sftp.unlink(path)
        return True

    def delete_dir(self, path, force=False, timeout=None):
        """
        Delete a dir on remote box.

        :param path: Path to remote dir to be deleted.
        :type path: ``str``

        :param force: Optional. Remove the dir and its contents by running
                      ``rm -rf`` on the remote host instead of an SFTP rmdir.
        :type force: ``bool``

        :param timeout: Optional time to wait for the dir to be deleted. Only
                        relevant for force.
        :type timeout: ``int``

        :return: With ``force``, the ``[stdout, stderr, status]`` list returned
                 by :meth:`run` (a failed ``rm`` shows up as a non-zero status,
                 not an exception). Otherwise, whatever ``sftp.rmdir`` returns;
                 failures propagate as exceptions.
        :rtype: ``list`` or ``None``
        """
        # Quoting is required for the shell command below; for the SFTP branch
        # see the NOTE there.
        path = quote_unix(path)
        extra = {'_path': path}
        if force:
            command = 'rm -rf %s' % path
            extra['_command'] = command
            extra['_force'] = force
            self.logger.debug('Deleting dir', extra=extra)
            return self.run(command, timeout=timeout)

        # NOTE(review): shell-quoting a path passed to SFTP (no shell involved)
        # may alter paths with special characters -- confirm.
        self.logger.debug('Deleting dir', extra=extra)
        return self.sftp.rmdir(path)

    def run(self, cmd, timeout=None, quote=False):
        """
        Run a command on the remote node and wait for it to finish.

        Note: This function is based on paramiko's exec_command()
        method.

        :param cmd: Command to run.
        :type cmd: ``str``

        :param timeout: How long to wait (in seconds) for the command to
                        finish (optional).
        :type timeout: ``float``

        :param quote: Pass ``cmd`` through ``quote_unix`` before running it.
        :type quote: ``bool``

        :return: ``[stdout, stderr, status]`` with shell chars stripped from
                 both output strings.
        :rtype: ``list``

        :raises SSHCommandTimeoutError: If ``timeout`` elapses first. The
                                        channel is closed and the partial
                                        output attached to the exception.
        """
        if quote:
            cmd = quote_unix(cmd)

        extra = {'_cmd': cmd}
        self.logger.info('Executing command', extra=extra)

        # Use the system default buffer size
        bufsize = -1

        transport = self.client.get_transport()
        chan = transport.open_session()

        start_time = time.time()
        if cmd.startswith('sudo'):
            # Note that fabric does this as well. If you set pty, stdout and stderr
            # streams will be combined into one.
            chan.get_pty()
        chan.exec_command(cmd)

        stdout = StringIO()
        stderr = StringIO()

        # Create a stdin file and immediately close it to prevent any
        # interactive script from hanging the process.
        stdin = chan.makefile('wb', bufsize)
        stdin.close()

        # Receive all the output
        # Note #1: This is used instead of chan.makefile approach to prevent
        # buffering issues and hanging if the executed command produces a lot
        # of output.
        #
        # Note #2: If you are going to remove "ready" checks inside the loop
        # you are going to have a bad time. Trying to consume from a channel
        # which is not ready will block for indefinitely.
        #
        # NOTE(review): each poll decodes its own batch of bytes, so a
        # multi-byte UTF-8 sequence split across two polls could raise
        # UnicodeDecodeError.
        exit_status_ready = chan.exit_status_ready()

        while not exit_status_ready:
            elapsed_time = time.time() - start_time

            if timeout and (elapsed_time > timeout):
                # TODO: Is this the right way to clean up?
                chan.close()

                stdout = strip_shell_chars(stdout.getvalue())
                stderr = strip_shell_chars(stderr.getvalue())
                raise SSHCommandTimeoutError(cmd=cmd, timeout=timeout, stdout=stdout,
                                             stderr=stderr)

            stdout.write(self._consume_stdout(chan).getvalue())
            stderr.write(self._consume_stderr(chan).getvalue())

            # We need to check the exit status here, because the command could
            # print some output and exit during the sleep below.
            exit_status_ready = chan.exit_status_ready()

            if exit_status_ready:
                break

            # Short sleep to prevent busy waiting
            eventlet.sleep(self.SLEEP_DELAY)

        # Output can arrive between the last consume above and the exit status
        # becoming ready (or before the first check); drain whatever is still
        # buffered so it is not silently dropped. Both reads are guarded by
        # "ready" checks, so this never blocks.
        stdout.write(self._consume_stdout(chan).getvalue())
        stderr.write(self._consume_stderr(chan).getvalue())

        # Receive the exit status code of the command we ran.
        status = chan.recv_exit_status()

        stdout = strip_shell_chars(stdout.getvalue())
        stderr = strip_shell_chars(stderr.getvalue())

        extra = {'_status': status, '_stdout': stdout, '_stderr': stderr}
        self.logger.debug('Command finished', extra=extra)

        return [stdout, stderr, status]

    def close(self):
        """
        Close the SSH connection, the SFTP session and the bastion connection.

        Safe to call when :meth:`connect` was never called or failed part way.

        :return: Always ``True``.
        :rtype: ``bool``
        """
        self.logger.debug('Closing server connection')

        if self.client:
            self.client.close()
        if self.sftp:
            self.sftp.close()
        if self.bastion_client:
            self.bastion_client.close()
        return True

    def _recv_available(self, is_ready, receive):
        """
        Read, in CHUNK_SIZE pieces, whatever ``receive`` can return while
        ``is_ready()`` reports data is available, without ever blocking.
        """
        out = bytearray()
        if not is_ready():
            return out

        data = receive(self.CHUNK_SIZE)
        out += data
        # Stop on an empty read (EOF) or when no more data is ready.
        while data and is_ready():
            data = receive(self.CHUNK_SIZE)
            out += data

        return out

    def _consume_stdout(self, chan):
        """
        Try to consume stdout data from chan if it's receive ready.

        :return: Decoded data read so far (possibly empty).
        :rtype: ``StringIO``
        """
        stdout = StringIO()
        data = self._recv_available(chan.recv_ready, chan.recv)
        stdout.write(self._get_decoded_data(data))
        return stdout

    def _consume_stderr(self, chan):
        """
        Try to consume stderr data from chan if it's receive ready.

        :return: Decoded data read so far (possibly empty).
        :rtype: ``StringIO``
        """
        stderr = StringIO()
        data = self._recv_available(chan.recv_stderr_ready, chan.recv_stderr)
        stderr.write(self._get_decoded_data(data))
        return stderr

    def _get_decoded_data(self, data):
        """
        Decode ``data`` as UTF-8, logging the raw data before re-raising if it
        is not valid UTF-8.
        """
        try:
            return data.decode('utf-8')
        except UnicodeDecodeError:
            self.logger.exception('Non UTF-8 character found in data: %s', data)
            raise

    def _get_pkey_object(self, key_material, passphrase):
        """
        Try to detect private key type and return paramiko.PKey object.

        RSA, DSS and ECDSA are tried in that order.

        :raises paramiko.ssh_exception.SSHException: If no key type could parse
            ``key_material``. The message hints at the likely cause: a file path
            instead of key contents, a wrong passphrase, or an unsupported key.
        """
        for cls in [paramiko.RSAKey, paramiko.DSSKey, paramiko.ECDSAKey]:
            try:
                key = cls.from_private_key(StringIO(key_material), password=passphrase)
            except paramiko.ssh_exception.SSHException:
                # Invalid key, try other key type
                pass
            else:
                return key

        # If a user passes in something which looks like file path we throw a more friendly
        # exception letting the user know we expect the contents a not a path.
        # Note: We do it here and not up the stack to avoid false positives.
        contains_header = PRIVATE_KEY_HEADER in key_material.lower()
        if not contains_header and ('/' in key_material or '\\' in key_material):
            msg = ('"private_key" parameter needs to contain private key data / content and not '
                   'a path')
        elif passphrase:
            msg = 'Invalid passphrase or invalid/unsupported key type'
        else:
            msg = 'Invalid or unsupported key type'

        raise paramiko.ssh_exception.SSHException(msg)

    def _connect(self, host, socket=None):
        """
        Open an authenticated SSH connection to ``host``.

        Agent and ~/.ssh key lookup are enabled only when neither a password nor
        a key (files or material) was configured. Unknown host keys are
        auto-accepted (``AutoAddPolicy``).

        :type host: ``str``
        :param host: Host to connect to

        :type socket: :class:`paramiko.Channel` or an opened :class:`socket.socket`
        :param socket: If specified, won't open a socket for communication to the specified host
                       and will use this instead

        :return: A connected SSHClient
        :rtype: :class:`paramiko.SSHClient`
        """
        conninfo = {'hostname': host,
                    'port': self.port,
                    'username': self.username,
                    'allow_agent': False,
                    'look_for_keys': False,
                    'timeout': self.timeout}

        if self.password:
            conninfo['password'] = self.password

        if self.key_files:
            conninfo['key_filename'] = self.key_files

        if self.key_material:
            conninfo['pkey'] = self._get_pkey_object(key_material=self.key_material,
                                                     passphrase=self.passphrase)

        if not self.password and not (self.key_files or self.key_material):
            conninfo['allow_agent'] = True
            conninfo['look_for_keys'] = True

        # Password / key data are intentionally not logged.
        extra = {'_hostname': host, '_port': self.port,
                 '_username': self.username, '_timeout': self.timeout}
        self.logger.debug('Connecting to server', extra=extra)

        if socket:
            conninfo['sock'] = socket

        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        client.connect(**conninfo)

        return client

    def __repr__(self):
        return ('<ParamikoSSHClient hostname=%s,port=%s,username=%s,id=%s>' %
                (self.hostname, self.port, self.username, id(self)))