Completed
Pull Request — master (#3058)
by Lakshmi
04:51
created

ParamikoSSHClient._get_pkey_object()   D

Complexity

Conditions 8

Size

Total Lines 27

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 8
dl 0
loc 27
rs 4
c 0
b 0
f 0
1
# Licensed to the Apache Software Foundation (ASF) under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import os
17
import posixpath
18
from StringIO import StringIO
19
import time
20
21
import eventlet
22
from oslo_config import cfg
23
24
import paramiko
25
26
# Depending on your version of Paramiko, it may cause a deprecation
27
# warning on Python 2.6.
28
# Ref: https://bugs.launchpad.net/paramiko/+bug/392973
29
30
from st2common.log import logging
31
from st2common.util.misc import strip_shell_chars
32
from st2common.util.shell import quote_unix
33
from st2common.constants.runners import DEFAULT_SSH_PORT, REMOTE_RUNNER_PRIVATE_KEY_HEADER
34
35
# Names exported by "from <module> import *": the SSH client and its timeout error.
__all__ = [
    'ParamikoSSHClient',
    'SSHCommandTimeoutError',
]
40
41
42
class SSHCommandTimeoutError(Exception):
    """
    Exception which is raised when an SSH command times out.
    """

    def __init__(self, cmd, timeout, stdout=None, stderr=None):
        """
        :param cmd: Command which didn't finish in time.
        :type cmd: ``str``

        :param timeout: Timeout (in seconds) which was exceeded.
        :type timeout: ``int`` or ``float``

        :param stdout: Stdout which was consumed until the timeout occurred.
        :type stdout: ``str``

        :param stderr: Stderr which was consumed until the timeout occurred.
        :type stderr: ``str``
        """
        self.cmd = cmd
        self.timeout = timeout
        self.stdout = stdout
        self.stderr = stderr
        # Store the message explicitly. BaseException.message is deprecated since
        # Python 2.6 and doesn't exist on Python 3, so __str__ must not rely on it.
        self.message = 'Command didn\'t finish in %s seconds' % (timeout,)
        super(SSHCommandTimeoutError, self).__init__(self.message)

    def __repr__(self):
        return ('<SSHCommandTimeoutError: cmd="%s",timeout=%s>' %
                (self.cmd, self.timeout))

    def __str__(self):
        return self.message
68
69
70
class ParamikoSSHClient(object):
    """
    A SSH Client powered by Paramiko.
    """

    # Maximum number of bytes to read at once from a socket
    CHUNK_SIZE = 1024

    # How long to sleep while waiting for command to finish
    SLEEP_DELAY = 1.5

    # Connect socket timeout
    CONNECT_TIMEOUT = 60

    def __init__(self, hostname, port=DEFAULT_SSH_PORT, username=None, password=None,
                 bastion_host=None,
                 key_files=None, key_material=None, timeout=None, passphrase=None):
        """
        Authentication is always attempted in the following order:

        - The key passed in (if key is provided)
        - Any key we can find through an SSH agent (only if neither a password
          nor a key is provided)
        - Any "id_rsa" or "id_dsa" key discoverable in ~/.ssh/ (only if neither
          a password nor a key is provided)
        - Plain username/password auth, if a password was given (if password is
          provided)

        :param bastion_host: Optional host to connect through. The connection to
                             ``hostname`` is tunneled over a "direct-tcpip" channel
                             opened on the bastion host.
        :param key_files: Private key file path(s). Mutually exclusive with key_material.
        :param key_material: Private key contents. Mutually exclusive with key_files.
        :param timeout: Connect timeout in seconds. Defaults to CONNECT_TIMEOUT.
        :param passphrase: Optional passphrase for the private key.
        """
        self.hostname = hostname
        self.port = port
        self.username = username
        self.password = password
        self.key_files = key_files
        self.timeout = timeout or ParamikoSSHClient.CONNECT_TIMEOUT
        self.key_material = key_material
        self.bastion_host = bastion_host
        self.passphrase = passphrase
        self.ssh_config_file = os.path.expanduser(
            cfg.CONF.ssh_runner.ssh_config_path or
            '~/.ssh/config'
        )
        self.logger = logging.getLogger(__name__)

        self.client = None
        self.sftp_client = None

        self.bastion_client = None
        self.bastion_socket = None

    def connect(self):
        """
        Connect to the remote node over SSH (through the bastion host if one is set).

        :return: True once the connection has been established. Failures are
                 raised as exceptions.
        :rtype: ``bool``
        """
        if self.bastion_host:
            self.logger.debug('Bastion host specified, connecting')
            self.bastion_client = self._connect(host=self.bastion_host)
            transport = self.bastion_client.get_transport()
            real_addr = (self.hostname, self.port)
            # fabric uses ('', 0) for direct-tcpip, this duplicates that behaviour
            # see https://github.com/fabric/fabric/commit/c2a9bbfd50f560df6c6f9675603fb405c4071cad
            local_addr = ('', 0)
            self.bastion_socket = transport.open_channel('direct-tcpip', real_addr, local_addr)

        self.client = self._connect(host=self.hostname, socket=self.bastion_socket)
        return True

    def put(self, local_path, remote_path, mode=None, mirror_local_mode=False):
        """
        Upload a file to the remote node.

        :type local_path: ``str``
        :param local_path: File path on the local node.

        :type remote_path: ``str``
        :param remote_path: File path on the remote node.

        :type mode: ``int`` or ``str``
        :param mode: Permissions mode for the file. E.g. 0744. Strings are parsed as octal.

        :type mirror_local_mode: ``bool``
        :param mirror_local_mode: Should remote file mirror local mode. Takes
                                  precedence over ``mode``.

        :return: Attributes of the remote file.
        :rtype: :class:`paramiko.SFTPAttributes`
        """
        if not local_path or not remote_path:
            # The format arguments must be a tuple; otherwise "%" only receives
            # local_path and raises TypeError instead of this exception.
            raise Exception('Need both local_path and remote_path. local: %s, remote: %s' %
                            (local_path, remote_path))

        # NOTE(review): quote_unix() applies shell quoting, but neither os.path nor SFTP
        # goes through a shell. Paths containing characters which need quoting will not
        # be found locally / will be created under a quoted name remotely. Confirm.
        local_path = quote_unix(local_path)
        remote_path = quote_unix(remote_path)

        extra = {'_local_path': local_path, '_remote_path': remote_path, '_mode': mode,
                 '_mirror_local_mode': mirror_local_mode}
        self.logger.debug('Uploading file', extra=extra)

        if not os.path.exists(local_path):
            raise Exception('Path %s does not exist locally.' % local_path)

        rattrs = self.sftp.put(local_path, remote_path)

        if mode or mirror_local_mode:
            # Mirroring wins over an explicit mode.
            local_mode = os.stat(local_path).st_mode if mirror_local_mode else mode

            # Cast to octal integer in case of string
            if isinstance(local_mode, basestring):
                local_mode = int(local_mode, 8)
            local_mode = local_mode & 0o7777

            remote_mode = rattrs.st_mode
            # Only mask if we actually got a remote_mode
            if remote_mode is not None:
                remote_mode = remote_mode & 0o7777

            if local_mode != remote_mode:
                self.sftp.chmod(remote_path, local_mode)

        return rattrs

    def put_dir(self, local_path, remote_path, mode=None, mirror_local_mode=False):
        """
        Upload a dir to the remote node.

        :type local_path: ``str``
        :param local_path: Dir path on the local node.

        :type remote_path: ``str``
        :param remote_path: Base dir path on the remote node.

        :type mode: ``int``
        :param mode: Permissions mode for the file. E.g. 0744.

        :type mirror_local_mode: ``bool``
        :param mirror_local_mode: Should remote file mirror local mode.

        :return: Attributes of every uploaded file (the values returned by put()).
        :rtype: ``list`` of :class:`paramiko.SFTPAttributes`
        """
        extra = {'_local_path': local_path, '_remote_path': remote_path, '_mode': mode,
                 '_mirror_local_mode': mirror_local_mode}
        self.logger.debug('Uploading dir', extra=extra)

        # Prefix which is stripped from local paths so that the last component of
        # local_path ends up directly under remote_path.
        if os.path.basename(local_path):
            strip = os.path.dirname(local_path)
        else:
            strip = os.path.dirname(os.path.dirname(local_path))

        remote_paths = []

        for context, dirs, files in os.walk(local_path):
            rcontext = context.replace(strip, '', 1)
            # normalize pathname separators with POSIX separator
            rcontext = rcontext.replace(os.sep, '/')
            rcontext = rcontext.lstrip('/')
            rcontext = posixpath.join(remote_path, rcontext)

            if not self.exists(rcontext):
                self.sftp.mkdir(rcontext)

            for d in dirs:
                n = posixpath.join(rcontext, d)
                if not self.exists(n):
                    self.sftp.mkdir(n)

            for f in files:
                file_local_path = os.path.join(context, f)
                n = posixpath.join(rcontext, f)
                # Note that quote_unix is done by put anyways.
                p = self.put(local_path=file_local_path, remote_path=n,
                             mirror_local_mode=mirror_local_mode, mode=mode)
                remote_paths.append(p)

        return remote_paths

    def exists(self, remote_path):
        """
        Validate whether a remote file or directory exists.

        :param remote_path: Path to remote file.
        :type remote_path: ``str``

        :rtype: ``bool``
        """
        try:
            self.sftp.lstat(remote_path)
        except IOError:
            return False

        return True

    def mkdir(self, dir_path):
        """
        Create a directory on remote box.

        :param dir_path: Path to remote directory to be created.
        :type dir_path: ``str``

        :return: Returns nothing if successful else raises IOError exception.

        :rtype: ``None``
        """
        dir_path = quote_unix(dir_path)
        extra = {'_dir_path': dir_path}
        self.logger.debug('mkdir', extra=extra)
        return self.sftp.mkdir(dir_path)

    def delete_file(self, path):
        """
        Delete a file on remote box.

        :param path: Path to remote file to be deleted.
        :type path: ``str``

        :return: True once the file has been deleted (failures are raised by SFTP).
        :rtype: ``bool``
        """
        path = quote_unix(path)
        extra = {'_path': path}
        self.logger.debug('Deleting file', extra=extra)
        self.sftp.unlink(path)
        return True

    def delete_dir(self, path, force=False, timeout=None):
        """
        Delete a dir on remote box.

        :param path: Path to remote dir to be deleted.
        :type path: ``str``

        :param force: Optional Forcefully remove dir (runs "rm -rf" instead of an SFTP rmdir).
        :type force: ``bool``

        :param timeout: Optional Time to wait for dir to be deleted. Only relevant for force.
        :type timeout: ``int``

        :return: With ``force``, the ``[stdout, stderr, status]`` list returned by run();
                 otherwise the return value of the SFTP rmdir call.
        """
        path = quote_unix(path)
        extra = {'_path': path}
        if force:
            command = 'rm -rf %s' % path
            extra['_command'] = command
            extra['_force'] = force
            self.logger.debug('Deleting dir', extra=extra)
            return self.run(command, timeout=timeout)

        self.logger.debug('Deleting dir', extra=extra)
        return self.sftp.rmdir(path)

    def run(self, cmd, timeout=None, quote=False):
        """
        Run a command on the remote node and wait for it to finish.

        Note: This function is based on paramiko's exec_command()
        method.

        :param cmd: Command to run.
        :type cmd: ``str``

        :param timeout: How long to wait (in seconds) for the command to
                        finish (optional).
        :type timeout: ``float``

        :param quote: If True, the whole command is passed through quote_unix().
        :type quote: ``bool``

        :return: ``[stdout, stderr, exit status]``
        :rtype: ``list``

        :raises SSHCommandTimeoutError: If the command doesn't finish in ``timeout`` seconds.
        """
        if quote:
            cmd = quote_unix(cmd)

        extra = {'_cmd': cmd}
        self.logger.info('Executing command', extra=extra)

        # Use the system default buffer size
        bufsize = -1

        transport = self.client.get_transport()
        chan = transport.open_session()

        start_time = time.time()
        if cmd.startswith('sudo'):
            # Note that fabric does this as well. If you set pty, stdout and stderr
            # streams will be combined into one.
            chan.get_pty()
        chan.exec_command(cmd)

        stdout = StringIO()
        stderr = StringIO()

        # Create a stdin file and immediately close it to prevent any
        # interactive script from hanging the process.
        stdin = chan.makefile('wb', bufsize)
        stdin.close()

        # Receive all the output
        # Note #1: This is used instead of chan.makefile approach to prevent
        # buffering issues and hanging if the executed command produces a lot
        # of output.
        #
        # Note #2: If you are going to remove "ready" checks inside the loop
        # you are going to have a bad time. Trying to consume from a channel
        # which is not ready will block for indefinitely.
        while not chan.exit_status_ready():
            elapsed_time = time.time() - start_time

            if timeout and (elapsed_time > timeout):
                # TODO: Is this the right way to clean up?
                chan.close()

                stdout = strip_shell_chars(stdout.getvalue())
                stderr = strip_shell_chars(stderr.getvalue())
                raise SSHCommandTimeoutError(cmd=cmd, timeout=timeout, stdout=stdout,
                                             stderr=stderr)

            stdout.write(self._consume_stdout(chan).getvalue())
            stderr.write(self._consume_stderr(chan).getvalue())

            # Check the exit status again before sleeping so we don't wait
            # SLEEP_DELAY seconds for a command which has already finished.
            if chan.exit_status_ready():
                break

            # Short sleep to prevent busy waiting
            eventlet.sleep(self.SLEEP_DELAY)

        # Final drain. This picks up all output if the command finished immediately,
        # and any output which arrived between the last read above and the exit
        # status becoming ready (previously that output was silently dropped).
        stdout.write(self._consume_stdout(chan).getvalue())
        stderr.write(self._consume_stderr(chan).getvalue())

        # Receive the exit status code of the command we ran.
        status = chan.recv_exit_status()

        stdout = strip_shell_chars(stdout.getvalue())
        stderr = strip_shell_chars(stderr.getvalue())

        extra = {'_status': status, '_stdout': stdout, '_stderr': stderr}
        self.logger.debug('Command finished', extra=extra)

        return [stdout, stderr, status]

    def close(self):
        """
        Close the SFTP, SSH and (if used) bastion connections.

        Safe to call when connect() was never called or failed.

        :rtype: ``bool``
        """
        self.logger.debug('Closing server connection')

        # The SFTP session runs on top of the SSH transport, so close it first.
        if self.sftp_client:
            self.sftp_client.close()

        if self.client:
            self.client.close()

        if self.bastion_client:
            self.bastion_client.close()

        return True

    @property
    def sftp(self):
        """
        Method which lazily establishes SFTP connection if one is not established yet when this
        variable is accessed.
        """
        if not self.sftp_client:
            self.sftp_client = self.client.open_sftp()

        return self.sftp_client

    def _consume_stdout(self, chan):
        """
        Try to consume stdout data from chan if it's receive ready.

        :rtype: ``StringIO``
        """
        return self._consume_stream(chan.recv_ready, chan.recv)

    def _consume_stderr(self, chan):
        """
        Try to consume stderr data from chan if it's receive ready.

        :rtype: ``StringIO``
        """
        return self._consume_stream(chan.recv_stderr_ready, chan.recv_stderr)

    def _consume_stream(self, is_ready, recv):
        """
        Read all currently buffered data from one channel stream without blocking.

        :param is_ready: Callable returning True when data can be read without blocking.
        :param recv: Callable taking a maximum byte count and returning the data read.

        :return: UTF-8 decoded data wrapped in a StringIO.
        :rtype: ``StringIO``
        """
        out = bytearray()
        # recv() is only ever called after a positive ready check; reading from a
        # stream which isn't ready blocks.
        while is_ready():
            data = recv(self.CHUNK_SIZE)
            if not data:
                break
            out += data

        result = StringIO()
        result.write(self._get_decoded_data(out))
        return result

    def _get_decoded_data(self, data):
        """
        Decode raw channel data as UTF-8, logging the data before re-raising on failure.
        """
        try:
            return data.decode('utf-8')
        except UnicodeDecodeError:
            self.logger.exception('Non UTF-8 character found in data: %s', data)
            raise

    def _get_pkey_object(self, key_material, passphrase):
        """
        Try to detect private key type and return paramiko.PKey object.

        :param key_material: Private key contents (not a path).
        :type key_material: ``str``

        :param passphrase: Optional passphrase used to decrypt the key.
        :type passphrase: ``str``

        :rtype: :class:`paramiko.PKey`

        :raises paramiko.ssh_exception.PasswordRequiredException: If the key is
            passphrase protected and no passphrase was supplied.
        :raises paramiko.ssh_exception.SSHException: If no supported key type can load it.
        """
        passphrase_required = False

        for cls in [paramiko.RSAKey, paramiko.DSSKey, paramiko.ECDSAKey]:
            try:
                return cls.from_private_key(StringIO(key_material), password=passphrase)
            except paramiko.ssh_exception.PasswordRequiredException:
                # Key type matched, but the key is encrypted and no passphrase was given.
                # This must be caught before SSHException, which is its base class.
                passphrase_required = True
            except paramiko.ssh_exception.SSHException:
                # Invalid key, try other key type
                continue

        if passphrase_required:
            raise paramiko.ssh_exception.PasswordRequiredException(
                'Private key material is passphrase protected. Supply a passphrase.')

        raise paramiko.ssh_exception.SSHException(
            self._get_invalid_key_message(key_material, passphrase))

    @staticmethod
    def _get_invalid_key_message(key_material, passphrase):
        """
        Return a user-facing message explaining why key material couldn't be loaded.
        """
        # If a user passes in something which looks like file path we return a more friendly
        # message letting the user know we expect the contents and not a path.
        # Note: The header check avoids false positives (base64 key data can contain "/").
        contains_header = REMOTE_RUNNER_PRIVATE_KEY_HEADER in key_material.lower()
        looks_like_path = '/' in key_material or '\\' in key_material

        if looks_like_path and not contains_header:
            return ('"private_key" parameter needs to contain private key data / content and not '
                    'a path')
        if passphrase:
            return 'Invalid passphrase or invalid/unsupported key type'
        return 'Invalid or unsupported key type'

    def _connect(self, host, socket=None):
        """
        Order of precedence for SSH connection parameters:

        1. If user supplies parameters via action parameters, we use them to connect.
        2. For parameters not supplied via action parameters, if there is an entry
           for host in SSH config file, we use those. Note that this is a merge operation.
        3. If user does not supply certain action parameters (username and key file location)
           and there is no entry for host in SSH config file, we use values supplied in
           st2 config file for those parameters.

        Note: the resolved username, port and key files are stored back on ``self``.

        :type host: ``str``
        :param host: Host to connect to

        :type socket: :class:`paramiko.Channel` or an opened :class:`socket.socket`
        :param socket: If specified, won't open a socket for communication to the specified host
                       and will use this instead

        :return: A connected SSHClient
        :rtype: :class:`paramiko.SSHClient`
        """
        conninfo = {'hostname': host,
                    'allow_agent': False,
                    'look_for_keys': False,
                    'timeout': self.timeout}

        ssh_config_file_info = {}
        if cfg.CONF.ssh_runner.use_ssh_config:
            ssh_config_file_info = self._get_ssh_config_for_host(host)

        # cfg.CONF.system_user is an option group (see ssh_key_file below); the
        # username itself lives in its "user" option.
        self.username = (self.username or ssh_config_file_info.get('user', None) or
                         cfg.CONF.system_user.user)
        self.port = self.port or ssh_config_file_info.get('port', None) or DEFAULT_SSH_PORT

        # If both key file and key material are provided as action parameters,
        # throw an error informing user only one is required.
        if self.key_files and self.key_material:
            msg = ('key_files (%s) and key_material arguments are '
                   'mutually exclusive. Supply only one.' % self.key_files)
            raise ValueError(msg)

        # If key material is not provided, only then we look at key file and decide
        # if we want to use the user supplied one or the one in SSH config.
        # Note: _get_ssh_config_for_host() stores IdentityFile under "key_filename".
        if not self.key_material:
            self.key_files = (self.key_files or ssh_config_file_info.get('key_filename', None) or
                              cfg.CONF.system_user.ssh_key_file)

        if self.passphrase and not (self.key_files or self.key_material):
            raise ValueError('passphrase should accompany private key material')

        credentials_provided = self.password or self.key_files or self.key_material

        if not credentials_provided:
            msg = ('Either password or key file location or key material should be supplied '
                   'for action. You can also add an entry for host %s in SSH config file %s.' %
                   (host, self.ssh_config_file))
            raise ValueError(msg)

        conninfo['username'] = self.username
        conninfo['port'] = self.port

        if self.password:
            conninfo['password'] = self.password

        if self.key_files:
            conninfo['key_filename'] = self.key_files

            passphrase_reqd = self._is_key_file_needs_passphrase(self.key_files)
            if passphrase_reqd and not self.passphrase:
                msg = ('Private key file %s is passphrase protected. Supply a passphrase.' %
                       self.key_files)
                raise paramiko.ssh_exception.PasswordRequiredException(msg)

            if self.passphrase:
                # Optional passphrase for unlocking the private key
                conninfo['password'] = self.passphrase

        if self.key_material:
            conninfo['pkey'] = self._get_pkey_object(key_material=self.key_material,
                                                     passphrase=self.passphrase)

        if not self.password and not (self.key_files or self.key_material):
            conninfo['allow_agent'] = True
            conninfo['look_for_keys'] = True

        extra = {'_hostname': host, '_port': self.port,
                 '_username': self.username, '_timeout': self.timeout}
        self.logger.debug('Connecting to server', extra=extra)

        socket = socket or ssh_config_file_info.get('sock', None)
        if socket:
            conninfo['sock'] = socket

        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        # Never write the password / key passphrase into the logs.
        masked_conninfo = dict(conninfo)
        if 'password' in masked_conninfo:
            masked_conninfo['password'] = '********'
        extra = {'_conninfo': masked_conninfo}
        self.logger.debug('Connection info', extra=extra)
        client.connect(**conninfo)

        return client

    def _get_ssh_config_for_host(self, host):
        """
        Parse the SSH config file and return the settings which apply to ``host``.

        :return: Dict with any of the keys "hostname", "user", "port", "sock"
                 (a ProxyCommand) and "key_filename" (the IdentityFile value).
        :rtype: ``dict``
        """
        ssh_config_info = {}
        ssh_config_parser = paramiko.SSHConfig()

        try:
            with open(self.ssh_config_file) as f:
                ssh_config_parser.parse(f)
        except IOError as e:
            raise Exception('Error accessing ssh config file %s. Code: %s Reason %s' %
                            (self.ssh_config_file, e.errno, e.strerror))

        ssh_config = ssh_config_parser.lookup(host)
        if ssh_config:
            for k in ('hostname', 'user', 'port'):
                if k in ssh_config:
                    ssh_config_info[k] = ssh_config[k]

            if 'proxycommand' in ssh_config:
                ssh_config_info['sock'] = paramiko.ProxyCommand(ssh_config['proxycommand'])

            if 'identityfile' in ssh_config:
                ssh_config_info['key_filename'] = ssh_config['identityfile']

        return ssh_config_info

    @staticmethod
    def _is_key_file_needs_passphrase(file):
        """
        Return True if the private key file is passphrase protected.

        :param file: Path to a private key file, or a list/tuple of paths (an SSH
                     config IdentityFile entry may be a list). For a list, True is
                     returned if any of the files needs a passphrase.
        :rtype: ``bool``
        """
        if isinstance(file, (list, tuple)):
            return any(ParamikoSSHClient._is_key_file_needs_passphrase(path) for path in file)

        for cls in [paramiko.RSAKey, paramiko.DSSKey, paramiko.ECDSAKey]:
            try:
                cls.from_private_key_file(file, password=None)
            except paramiko.ssh_exception.PasswordRequiredException:
                return True
            except paramiko.ssh_exception.SSHException:
                continue
            else:
                # Loaded fine without a passphrase.
                return False

        return False

    def __repr__(self):
        return ('<ParamikoSSHClient hostname=%s,port=%s,username=%s,id=%s>' %
                (self.hostname, self.port, self.username, id(self)))
667