Passed
Push — master ( 6bd570...b7fa82 )
by Plexxi
02:49
created

BaseParallelSSHRunner.__init__()   A

Complexity

Conditions 1

Size

Total Lines 21

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 1
c 1
b 0
f 0
dl 0
loc 21
rs 9.3142
1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
from oslo_config import cfg
17
import six
18
19
from st2common.runners.base import ShellRunnerMixin
20
from st2common.runners.base import ActionRunner
21
from st2common.constants.runners import REMOTE_RUNNER_PRIVATE_KEY_HEADER
22
from st2common.runners.parallel_ssh import ParallelSSHClient
23
from st2common import log as logging
24
from st2common.constants.action import LIVEACTION_STATUS_SUCCEEDED
25
from st2common.constants.action import LIVEACTION_STATUS_TIMED_OUT
26
from st2common.constants.action import LIVEACTION_STATUS_FAILED
27
from st2common.constants.runners import REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT
28
from st2common.exceptions.actionrunner import ActionRunnerPreRunError
29
30
# Public API of this module: only the runner base class is exported.
__all__ = [
    'BaseParallelSSHRunner',
]

LOG = logging.getLogger(__name__)
35
36
# Keys looked up in ``runner_parameters`` by BaseParallelSSHRunner.pre_run().
# Exception: RUNNER_ON_BEHALF_USER is looked up in the liveaction context instead.
# RUNNER_REMOTE_DIR and RUNNER_COMMAND are not read in this module; presumably they
# are consumed by subclasses -- confirm before removing.

# Target hosts and connection settings.
RUNNER_HOSTS = 'hosts'
RUNNER_SSH_PORT = 'port'
RUNNER_BASTION_HOST = 'bastion_host'
RUNNER_TIMEOUT = 'timeout'
RUNNER_PARALLEL = 'parallel'

# Authentication.
RUNNER_USERNAME = 'username'
RUNNER_PASSWORD = 'password'
RUNNER_PRIVATE_KEY = 'private_key'
RUNNER_PASSPHRASE = 'passphrase'

# Execution environment of the remote command.
RUNNER_SUDO = 'sudo'
RUNNER_ON_BEHALF_USER = 'user'
RUNNER_REMOTE_DIR = 'dir'
RUNNER_COMMAND = 'cmd'
RUNNER_CWD = 'cwd'
RUNNER_ENV = 'env'
RUNNER_KWARG_OP = 'kwarg_op'
53
54
55
class BaseParallelSSHRunner(ActionRunner, ShellRunnerMixin):
    """
    Base runner which runs an action on one or more remote hosts through a
    ``ParallelSSHClient`` that is created in :meth:`pre_run`.
    """

    def __init__(self, runner_id):
        """
        Initialize runner state with defaults.

        The actual connection / execution settings are read from ``runner_parameters``
        (and the liveaction context) in :meth:`pre_run`.
        """
        super(BaseParallelSSHRunner, self).__init__(runner_id=runner_id)
        self._hosts = None
        self._parallel = True
        self._sudo = False
        self._username = None
        self._password = None
        self._private_key = None
        self._passphrase = None
        self._kwarg_op = '--'
        self._cwd = None
        self._env = None
        self._ssh_port = None
        self._timeout = None
        self._bastion_host = None
        # Default "on behalf of" user comes from the system_user config section; it can be
        # overridden by the liveaction context in pre_run().
        self._on_behalf_user = cfg.CONF.system_user.user

        self._ssh_key_file = None
        self._parallel_ssh_client = None
        # Upper bound for the number of concurrent SSH connections.
        self._max_concurrency = cfg.CONF.ssh_runner.max_parallel_actions

    def pre_run(self):
        """
        Read the runner parameters and create the parallel SSH client.

        :raises ActionRunnerPreRunError: If no (non-blank) host is specified.
        """
        super(BaseParallelSSHRunner, self).pre_run()

        LOG.debug('Entering BaseParallelSSHRunner.pre_run() for liveaction_id="%s"',
                  self.liveaction_id)

        # "hosts" is a comma separated string. Strip every entry *before* discarding empty
        # ones so inputs such as "host1, ,host2" or "host1," don't produce blank hosts.
        hosts = self.runner_parameters.get(RUNNER_HOSTS, None) or ''
        self._hosts = [host.strip() for host in hosts.split(',') if host.strip()]
        if not self._hosts:
            # Interpolate explicitly - extra exception args are not formatted into the message.
            raise ActionRunnerPreRunError('No hosts specified to run action for action %s.' %
                                          (self.liveaction_id))

        self._username = self.runner_parameters.get(RUNNER_USERNAME, None)
        self._password = self.runner_parameters.get(RUNNER_PASSWORD, None)
        self._private_key = self.runner_parameters.get(RUNNER_PRIVATE_KEY, None)
        self._passphrase = self.runner_parameters.get(RUNNER_PASSPHRASE, None)

        self._ssh_port = self.runner_parameters.get(RUNNER_SSH_PORT, None)
        self._ssh_key_file = self._private_key
        self._parallel = self.runner_parameters.get(RUNNER_PARALLEL, True)
        # Normalize any falsy value (e.g. an explicit None) to False.
        self._sudo = self.runner_parameters.get(RUNNER_SUDO, False) or False
        if self.context:
            self._on_behalf_user = self.context.get(RUNNER_ON_BEHALF_USER, self._on_behalf_user)
        self._cwd = self.runner_parameters.get(RUNNER_CWD, None)
        self._env = self.runner_parameters.get(RUNNER_ENV, {})
        self._kwarg_op = self.runner_parameters.get(RUNNER_KWARG_OP, '--')
        self._timeout = self.runner_parameters.get(RUNNER_TIMEOUT,
                                                   REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT)
        self._bastion_host = self.runner_parameters.get(RUNNER_BASTION_HOST, None)

        # Roughly one connection per three hosts when running in parallel, capped by config.
        concurrency = len(self._hosts) // 3 + 1 if self._parallel else 1
        if concurrency > self._max_concurrency:
            LOG.debug('Limiting parallel SSH concurrency to %d.', self._max_concurrency)
            concurrency = self._max_concurrency

        client_kwargs = {
            'hosts': self._hosts,
            'user': self._username,
            'port': self._ssh_port,
            'concurrency': concurrency,
            'bastion_host': self._bastion_host,
            'raise_on_any_error': False,
            'connect': True
        }

        if self._password:
            client_kwargs['password'] = self._password
        elif self._private_key:
            # Determine if the private_key is a path to the key file or the raw key material.
            if self._is_private_key_material(private_key=self._private_key):
                client_kwargs['pkey_material'] = self._private_key
            else:
                # Treated as a path to the key file (no existence check is done here).
                client_kwargs['pkey_file'] = self._private_key

            if self._passphrase:
                client_kwargs['passphrase'] = self._passphrase
        else:
            # NOTE(review): _ssh_key_file mirrors private_key, which is falsy in this
            # branch; presumably ParallelSSHClient then falls back to the key file from
            # the config (e.g. the stanley key) - confirm.
            client_kwargs['pkey_file'] = self._ssh_key_file

        self._parallel_ssh_client = ParallelSSHClient(**client_kwargs)

        LOG.info('[BaseParallelSSHRunner="%s", liveaction_id="%s"] Finished pre_run.',
                 self.runner_id, self.liveaction_id)

    def _is_private_key_material(self, private_key):
        """
        Return ``True`` if ``private_key`` looks like raw key material rather than a path,
        i.e. it contains ``REMOTE_RUNNER_PRIVATE_KEY_HEADER`` after lower-casing.

        :rtype: ``bool``
        """
        return bool(private_key) and REMOTE_RUNNER_PRIVATE_KEY_HEADER in private_key.lower()

    def _get_env_vars(self):
        """
        Return the environment variables for the remote command: the user supplied
        ``env`` parameter merged with the common st2 action variables (the latter win on
        key conflicts).

        :rtype: ``dict``
        """
        env_vars = {}

        if self._env:
            env_vars.update(self._env)

        # Include common st2 env vars
        env_vars.update(self._get_common_action_env_variables())

        return env_vars

    @staticmethod
    def _get_result_status(result, allow_partial_failure):
        """
        Compute the overall liveaction status from a parallel SSH result.

        :param result: Either a global failure dictionary (contains both ``error`` and
                       ``traceback``) or a dictionary with one (possibly empty) result
                       dictionary per host.
        :type result: ``dict``

        :param allow_partial_failure: If ``True`` the action succeeds when at least one
                                      host succeeded, otherwise every host must succeed.
        :type allow_partial_failure: ``bool``

        :rtype: ``str``
        """
        if 'error' in result and 'traceback' in result:
            # Global failure where the result dictionary doesn't contain an entry per host.
            success = result.get('succeeded', False)
            return BaseParallelSSHRunner._get_status_for_success_and_timeout(success=success,
                                                                            timeout=False)

        host_results = list(six.itervalues(result))
        succeeded = [bool(r.get('succeeded', False)) if r else False for r in host_results]
        timed_out = [bool(r.get('timeout', False)) if r else False for r in host_results]

        # Evaluate every host (no early exit): "timeout" must mean that *all* hosts timed
        # out, not just the hosts inspected before the first failure.
        timeout = all(timed_out)
        success = any(succeeded) if allow_partial_failure else all(succeeded)

        return BaseParallelSSHRunner._get_status_for_success_and_timeout(success=success,
                                                                        timeout=timeout)

    @staticmethod
    def _get_status_for_success_and_timeout(success, timeout):
        """
        Map success / timeout flags to a liveaction status; success takes precedence over
        timeout, and anything else is a failure.

        :rtype: ``str``
        """
        if success:
            status = LIVEACTION_STATUS_SUCCEEDED
        elif timeout:
            # Note: Right now we only set status to timeout if all the hosts have timed out
            status = LIVEACTION_STATUS_TIMED_OUT
        else:
            status = LIVEACTION_STATUS_FAILED
        return status
210