Passed
Push — master ( 7ae5f8...8189ff )
by Plexxi
03:56
created

ParamikoSSHClient.delete_file()   A

Complexity

Conditions 1

Size

Total Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
dl 0
loc 17
rs 9.4285
c 0
b 0
f 0
1
# Licensed to the Apache Software Foundation (ASF) under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import os
17
import posixpath
18
from StringIO import StringIO
19
import time
20
21
import eventlet
22
from oslo_config import cfg
23
24
import paramiko
25
26
# Depending on your version of Paramiko, it may cause a deprecation
27
# warning on Python 2.6.
28
# Ref: https://bugs.launchpad.net/paramiko/+bug/392973
29
30
from st2common.log import logging
31
from st2common.util.misc import strip_shell_chars
32
from st2common.util.shell import quote_unix
33
from st2common.constants.runners import REMOTE_RUNNER_PRIVATE_KEY_HEADER
34
35
# Public API of this module.
__all__ = [
    'ParamikoSSHClient',
    'SSHCommandTimeoutError',
]
40
41
42
class SSHCommandTimeoutError(Exception):
    """
    Exception which is raised when an SSH command times out.
    """

    def __init__(self, cmd, timeout, stdout=None, stderr=None):
        """
        :param cmd: Command which timed out.
        :type cmd: ``str``

        :param timeout: Timeout (in seconds) which was exceeded.
        :type timeout: ``int`` or ``float``

        :param stdout: Stdout which was consumed until the timeout occurred.
        :type stdout: ``str``

        :param stderr: Stderr which was consumed until the timeout occurred.
        :type stderr: ``str``
        """
        self.cmd = cmd
        self.timeout = timeout
        self.stdout = stdout
        self.stderr = stderr
        message = 'Command didn\'t finish in %s seconds' % (timeout)
        # Store the message explicitly: BaseException.message is deprecated since
        # Python 2.6 and does not exist at all on Python 3.
        self.message = message
        super(SSHCommandTimeoutError, self).__init__(message)

    def __repr__(self):
        return ('<SSHCommandTimeoutError: cmd="%s",timeout=%s>' %
                (self.cmd, self.timeout))

    def __str__(self):
        return self.message
68
69
70
class ParamikoSSHClient(object):
    """
    A SSH Client powered by Paramiko.

    Supports an optional bastion (jump) host, password / key file / in-memory key
    authentication, remote command execution with a timeout and basic SFTP file operations.
    """

    # Maximum number of bytes to read at once from a socket
    CHUNK_SIZE = 1024

    # How long to sleep while waiting for command to finish
    SLEEP_DELAY = 1.5

    # Connect socket timeout
    CONNECT_TIMEOUT = 60

    def __init__(self, hostname, port=22, username=None, password=None, bastion_host=None,
                 key_files=None, key_material=None, timeout=None, passphrase=None):
        """
        Store connection settings. No network activity happens until :meth:`connect`.

        Fallbacks taken from the config:

        - If no password, key file or key material is provided, the key file configured in
          ``system_user.ssh_key_file`` (if any) is used.
        - If no username is provided, the user configured in ``system_user.user`` is used.

        Credentials handed to paramiko (see :meth:`_connect`):

        - The private key passed in via ``key_material`` or ``key_files`` (if provided).
        - SSH agent keys and keys discoverable in ~/.ssh/ (only if neither a password nor
          any key is available).
        - Plain username/password auth (if a password is provided).

        :raises ValueError: If both ``key_files`` and ``key_material`` are given, or if a
                            ``passphrase`` is given without any private key.
        """
        if key_files and key_material:
            raise ValueError(('key_files and key_material arguments are '
                              'mutually exclusive'))

        if passphrase and not (key_files or key_material):
            raise ValueError('passphrase should accompany private key material')

        credentials_provided = password or key_files or key_material
        if not credentials_provided and cfg.CONF.system_user.ssh_key_file:
            key_files = cfg.CONF.system_user.ssh_key_file

        self.hostname = hostname
        self.port = port
        # ``system_user`` is a config option group; the actual user name is its ``user``
        # option. Using the group itself would pass a non-string object as the username.
        self.username = username or cfg.CONF.system_user.user
        self.password = password
        self.key_files = key_files
        self.timeout = timeout or ParamikoSSHClient.CONNECT_TIMEOUT
        self.key_material = key_material
        self.bastion_host = bastion_host
        self.passphrase = passphrase

        self.logger = logging.getLogger(__name__)

        self.client = None
        self.sftp_client = None

        self.bastion_client = None
        self.bastion_socket = None

    def connect(self):
        """
        Connect to the remote node over SSH (through the bastion host if one is set).

        :return: True if the connection has been successfully established. Connection
                 failures propagate as exceptions from paramiko.
        :rtype: ``bool``
        """
        if self.bastion_host:
            self.logger.debug('Bastion host specified, connecting')
            self.bastion_client = self._connect(host=self.bastion_host)
            transport = self.bastion_client.get_transport()
            real_addr = (self.hostname, self.port)
            # fabric uses ('', 0) for direct-tcpip, this duplicates that behaviour
            # see https://github.com/fabric/fabric/commit/c2a9bbfd50f560df6c6f9675603fb405c4071cad
            local_addr = ('', 0)
            self.bastion_socket = transport.open_channel('direct-tcpip', real_addr, local_addr)

        self.client = self._connect(host=self.hostname, socket=self.bastion_socket)
        return True

    def put(self, local_path, remote_path, mode=None, mirror_local_mode=False):
        """
        Upload a file to the remote node over SFTP.

        Paths are used verbatim. SFTP requests are not interpreted by a shell, so the paths
        must not be shell-quoted. Quoting would add literal quote characters to file names
        and break the local existence check for paths containing spaces.

        :type local_path: ``str``
        :param local_path: File path on the local node.

        :type remote_path: ``str``
        :param remote_path: File path on the remote node.

        :type mode: ``int`` or ``str``
        :param mode: Permissions mode for the file. E.g. 0744. A string is parsed as octal.

        :type mirror_local_mode: ``bool``
        :param mirror_local_mode: Should remote file mirror local mode. Takes precedence
                                  over ``mode``.

        :return: Attributes of the remote file, as returned by the SFTP client's ``put``.
        :rtype: :class:`paramiko.SFTPAttributes`

        :raises Exception: If either path is empty or ``local_path`` does not exist.
        """
        if not local_path or not remote_path:
            # Both values must be passed as one tuple; otherwise the format operator
            # only receives local_path and raises TypeError.
            raise Exception('Need both local_path and remote_path. local: %s, remote: %s' %
                            (local_path, remote_path))

        extra = {'_local_path': local_path, '_remote_path': remote_path, '_mode': mode,
                 '_mirror_local_mode': mirror_local_mode}
        self.logger.debug('Uploading file', extra=extra)

        if not os.path.exists(local_path):
            raise Exception('Path %s does not exist locally.' % local_path)

        rattrs = self.sftp.put(local_path, remote_path)

        if mode or mirror_local_mode:
            local_mode = mode
            if not mode or mirror_local_mode:
                local_mode = os.stat(local_path).st_mode

            # Cast to octal integer in case of string
            if isinstance(local_mode, basestring):
                local_mode = int(local_mode, 8)
            # Keep only the permission bits (including setuid / setgid / sticky).
            local_mode = local_mode & 0o7777
            remote_mode = rattrs.st_mode
            # Only mask if we actually got a remote mode back
            if remote_mode is not None:
                remote_mode = (remote_mode & 0o7777)
            if local_mode != remote_mode:
                self.sftp.chmod(remote_path, local_mode)

        return rattrs

    def put_dir(self, local_path, remote_path, mode=None, mirror_local_mode=False):
        """
        Upload a dir (recursively) to the remote node.

        The last component of ``local_path`` is recreated under ``remote_path``. Missing
        remote directories are created.

        :type local_path: ``str``
        :param local_path: Dir path on the local node.

        :type remote_path: ``str``
        :param remote_path: Base dir path on the remote node.

        :type mode: ``int``
        :param mode: Permissions mode for the file. E.g. 0744.

        :type mirror_local_mode: ``bool``
        :param mirror_local_mode: Should remote file mirror local mode.

        :return: One entry per uploaded file: the value returned by :meth:`put` for it.
        :rtype: ``list``
        """
        extra = {'_local_path': local_path, '_remote_path': remote_path, '_mode': mode,
                 '_mirror_local_mode': mirror_local_mode}
        self.logger.debug('Uploading dir', extra=extra)

        # Prefix removed from every walked local directory, so that the last component of
        # local_path (with or without a trailing separator) is kept on the remote side.
        if os.path.basename(local_path):
            strip = os.path.dirname(local_path)
        else:
            strip = os.path.dirname(os.path.dirname(local_path))

        remote_paths = []

        for context, dirs, files in os.walk(local_path):
            rcontext = context.replace(strip, '', 1)
            # normalize pathname separators with POSIX separator
            rcontext = rcontext.replace(os.sep, '/')
            rcontext = rcontext.lstrip('/')
            rcontext = posixpath.join(remote_path, rcontext)

            if not self.exists(rcontext):
                self.sftp.mkdir(rcontext)

            for d in dirs:
                n = posixpath.join(rcontext, d)
                if not self.exists(n):
                    self.sftp.mkdir(n)

            for f in files:
                file_local_path = os.path.join(context, f)
                n = posixpath.join(rcontext, f)
                p = self.put(local_path=file_local_path, remote_path=n,
                             mirror_local_mode=mirror_local_mode, mode=mode)
                remote_paths.append(p)

        return remote_paths

    def exists(self, remote_path):
        """
        Validate whether a remote file or directory exists.

        :param remote_path: Path to remote file.
        :type remote_path: ``str``

        :return: False if ``lstat`` on the path raises IOError, True otherwise.
        :rtype: ``bool``
        """
        try:
            self.sftp.lstat(remote_path)
        except IOError:
            return False

        return True

    def mkdir(self, dir_path):
        """
        Create a directory on remote box over SFTP.

        The path is used verbatim (no shell quoting) because SFTP requests are not
        interpreted by a shell.

        :param dir_path: Path to remote directory to be created.
        :type dir_path: ``str``

        :return: Returns nothing if successful else raises IOError exception.

        :rtype: ``None``
        """
        extra = {'_dir_path': dir_path}
        self.logger.debug('mkdir', extra=extra)
        return self.sftp.mkdir(dir_path)

    def delete_file(self, path):
        """
        Delete a file on remote box over SFTP.

        The path is used verbatim (no shell quoting) because SFTP requests are not
        interpreted by a shell.

        :param path: Path to remote file to be deleted.
        :type path: ``str``

        :return: ``True`` once the file has been deleted. Failures are not signalled through
                 the return value; the SFTP client's exception (IOError) propagates instead.
        :rtype: ``bool``
        """
        extra = {'_path': path}
        self.logger.debug('Deleting file', extra=extra)
        self.sftp.unlink(path)
        return True

    def delete_dir(self, path, force=False, timeout=None):
        """
        Delete a dir on remote box.

        :param path: Path to remote dir to be deleted.
        :type path: ``str``

        :param force: Optional Forcefully remove dir (runs ``rm -rf`` through a shell).
        :type force: ``bool``

        :param timeout: Optional Time to wait for dir to be deleted. Only relevant for force.
        :type timeout: ``int``

        :return: With ``force``, the ``[stdout, stderr, status]`` list returned by
                 :meth:`run`. Otherwise, whatever the SFTP client's ``rmdir`` returns
                 (errors propagate as exceptions).
        """
        extra = {'_path': path}
        if force:
            # This command is executed by a remote shell, so the path MUST be quoted.
            command = 'rm -rf %s' % quote_unix(path)
            extra['_command'] = command
            extra['_force'] = force
            self.logger.debug('Deleting dir', extra=extra)
            return self.run(command, timeout=timeout)

        self.logger.debug('Deleting dir', extra=extra)
        # SFTP is not interpreted by a shell: pass the path verbatim.
        return self.sftp.rmdir(path)

    def run(self, cmd, timeout=None, quote=False):
        """
        Execute a command on the remote node and wait for it to finish.

        Note: This function is based on paramiko's exec_command() method.

        :param cmd: Command to execute.
        :type cmd: ``str``

        :param timeout: How long to wait (in seconds) for the command to
                        finish (optional).
        :type timeout: ``float``

        :param quote: If True, the whole ``cmd`` is shell-quoted with ``quote_unix``
                      before it is executed.
        :type quote: ``bool``

        :return: ``[stdout, stderr, status]``. stdout and stderr are passed through
                 ``strip_shell_chars``; status is the command exit status.
        :rtype: ``list``

        :raises SSHCommandTimeoutError: If the command doesn't finish within ``timeout``.
        """
        if quote:
            cmd = quote_unix(cmd)

        extra = {'_cmd': cmd}
        self.logger.info('Executing command', extra=extra)

        # Use the system default buffer size
        bufsize = -1

        transport = self.client.get_transport()
        chan = transport.open_session()

        start_time = time.time()
        if cmd.startswith('sudo'):
            # Note that fabric does this as well. If you set pty, stdout and stderr
            # streams will be combined into one.
            chan.get_pty()
        chan.exec_command(cmd)

        stdout = StringIO()
        stderr = StringIO()

        # Create a stdin file and immediately close it to prevent any
        # interactive script from hanging the process.
        stdin = chan.makefile('wb', bufsize)
        stdin.close()

        # Receive all the output
        # Note #1: This is used instead of chan.makefile approach to prevent
        # buffering issues and hanging if the executed command produces a lot
        # of output.
        #
        # Note #2: If you are going to remove "ready" checks inside the loop
        # you are going to have a bad time. Trying to consume from a channel
        # which is not ready will block for indefinitely.
        exit_status_ready = chan.exit_status_ready()

        while not exit_status_ready:
            elapsed_time = time.time() - start_time

            if timeout and (elapsed_time > timeout):
                # TODO: Is this the right way to clean up?
                chan.close()

                stdout = strip_shell_chars(stdout.getvalue())
                stderr = strip_shell_chars(stderr.getvalue())
                raise SSHCommandTimeoutError(cmd=cmd, timeout=timeout, stdout=stdout,
                                             stderr=stderr)

            stdout.write(self._consume_stdout(chan).getvalue())
            stderr.write(self._consume_stderr(chan).getvalue())

            # We need to check the exit status here, because the command could
            # print some output and exit during the sleep below.
            exit_status_ready = chan.exit_status_ready()

            if exit_status_ready:
                break

            # Short sleep to prevent busy waiting
            eventlet.sleep(self.SLEEP_DELAY)

        # Drain whatever is still buffered on the channel. Without this, output that
        # arrived between the last consume above and the exit status becoming ready would
        # be silently dropped. This also covers a command that finished before the loop.
        stdout.write(self._consume_stdout(chan).getvalue())
        stderr.write(self._consume_stderr(chan).getvalue())

        # Receive the exit status code of the command we ran.
        status = chan.recv_exit_status()

        stdout = strip_shell_chars(stdout.getvalue())
        stderr = strip_shell_chars(stderr.getvalue())

        extra = {'_status': status, '_stdout': stdout, '_stderr': stderr}
        self.logger.debug('Command finished', extra=extra)

        return [stdout, stderr, status]

    def close(self):
        """
        Close the SFTP session, the SSH connection and the bastion connection (if any).

        Only clients which were actually created are closed. This makes it safe to call
        when :meth:`connect` was never called or failed.

        :return: Always ``True``.
        :rtype: ``bool``
        """
        self.logger.debug('Closing server connection')

        # Close the SFTP session before the SSH client whose transport it runs over.
        if self.sftp_client:
            self.sftp_client.close()

        if self.client:
            self.client.close()

        if self.bastion_client:
            self.bastion_client.close()

        return True

    @property
    def sftp(self):
        """
        Method which lazily establishes SFTP connection if one is not established yet when this
        variable is accessed.
        """
        if not self.sftp_client:
            self.sftp_client = self.client.open_sftp()

        return self.sftp_client

    def _consume_stdout(self, chan):
        """
        Try to consume stdout data from chan if it's receive ready.

        :return: Buffer holding the decoded data read (empty if nothing was ready).
        :rtype: ``StringIO``
        """
        return self._consume_stream(chan.recv_ready, chan.recv)

    def _consume_stderr(self, chan):
        """
        Try to consume stderr data from chan if it's receive ready.

        :return: Buffer holding the decoded data read (empty if nothing was ready).
        :rtype: ``StringIO``
        """
        return self._consume_stream(chan.recv_stderr_ready, chan.recv_stderr)

    def _consume_stream(self, ready_func, recv_func):
        """
        Read one channel stream in CHUNK_SIZE pieces while it reports data as ready.

        :param ready_func: Callable returning True when data can be read without blocking.
        :param recv_func: Callable taking a byte count and returning up to that many bytes.

        :return: Buffer holding the UTF-8 decoded data read.
        :rtype: ``StringIO``
        """
        out = bytearray()
        # Never call recv_func unless ready_func says so: reading from a channel that
        # is not ready blocks (see the note in run()). An empty read ends the stream.
        while ready_func():
            data = recv_func(self.CHUNK_SIZE)
            if not data:
                break
            out += data

        result = StringIO()
        result.write(self._get_decoded_data(out))
        return result

    def _get_decoded_data(self, data):
        """
        Decode raw channel bytes as UTF-8.

        :raises UnicodeDecodeError: If the data is not valid UTF-8 (logged first).
        """
        try:
            return data.decode('utf-8')
        except UnicodeDecodeError:
            self.logger.exception('Non UTF-8 character found in data: %s', data)
            raise

    def _get_pkey_object(self, key_material, passphrase):
        """
        Try to detect private key type and return paramiko.PKey object.

        RSA, DSS and ECDSA are tried in that order.

        :raises paramiko.ssh_exception.SSHException: If no key type could load the material.
        """
        for cls in [paramiko.RSAKey, paramiko.DSSKey, paramiko.ECDSAKey]:
            try:
                key = cls.from_private_key(StringIO(key_material), password=passphrase)
            except paramiko.ssh_exception.SSHException:
                # Invalid key, try other key type
                pass
            else:
                return key

        # If a user passes in something which looks like file path we throw a more friendly
        # exception letting the user know we expect the contents a not a path.
        # Note: We do it here and not up the stack to avoid false positives.
        contains_header = REMOTE_RUNNER_PRIVATE_KEY_HEADER in key_material.lower()
        if not contains_header and ('/' in key_material or '\\' in key_material):
            msg = ('"private_key" parameter needs to contain private key data / content and not '
                   'a path')
        elif passphrase:
            msg = 'Invalid passphrase or invalid/unsupported key type'
        else:
            msg = 'Invalid or unsupported key type'

        raise paramiko.ssh_exception.SSHException(msg)

    def _connect(self, host, socket=None):
        """
        Open an authenticated SSH connection to ``host`` using the stored credentials.

        :type host: ``str``
        :param host: Host to connect to

        :type socket: :class:`paramiko.Channel` or an opened :class:`socket.socket`
        :param socket: If specified, won't open a socket for communication to the specified host
                       and will use this instead

        :return: A connected SSHClient
        :rtype: :class:`paramiko.SSHClient`
        """
        conninfo = {'hostname': host,
                    'port': self.port,
                    'username': self.username,
                    'allow_agent': False,
                    'look_for_keys': False,
                    'timeout': self.timeout}

        if self.password:
            conninfo['password'] = self.password

        if self.key_files:
            conninfo['key_filename'] = self.key_files

            if self.passphrase:
                # Optional passphrase for unlocking the private key
                conninfo['password'] = self.passphrase

        if self.key_material:
            conninfo['pkey'] = self._get_pkey_object(key_material=self.key_material,
                                                     passphrase=self.passphrase)

        # Only fall back to agent / ~/.ssh key discovery when no credentials were given.
        if not self.password and not (self.key_files or self.key_material):
            conninfo['allow_agent'] = True
            conninfo['look_for_keys'] = True

        extra = {'_hostname': host, '_port': self.port,
                 '_username': self.username, '_timeout': self.timeout}
        self.logger.debug('Connecting to server', extra=extra)

        if socket:
            conninfo['sock'] = socket

        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        client.connect(**conninfo)

        return client

    def __repr__(self):
        return ('<ParamikoSSHClient hostname=%s,port=%s,username=%s,id=%s>' %
                (self.hostname, self.port, self.username, id(self)))
585