Completed
Pull Request — master (#848)
by
unknown
01:18
created

prompt_choice()   A

Complexity

Conditions 3

Size

Total Lines 22

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 3
c 1
b 0
f 0
dl 0
loc 22
rs 9.2
1
# -*- coding: utf-8 -*-
2
3
import codecs
4
import collections
5
import json
6
import pprint
7
8
import click
9
from jinja2 import Environment
10
11
# Fallback question text; formatted with ``variable=<Variable>`` when a
# variable definition does not supply its own ``prompt``.
DEFAULT_PROMPT = 'Please enter a value for "{variable.name}"'

# Every value accepted for a variable's ``type`` field.
VALID_TYPES = ['boolean', 'yes_no', 'int', 'json', 'string']
20
21
22
def prompt_string(variable, default):
    """Ask the user for a string value.

    The question text and the hide-input flag are read from ``variable``;
    ``default`` is offered as the default answer.  Returns the result of
    ``click.prompt`` for a ``click.STRING`` prompt.
    """
    question = variable.prompt
    options = {
        'default': default,
        'hide_input': variable.hide_input,
        'type': click.STRING,
    }
    return click.prompt(question, **options)
29
30
31
def prompt_boolean(variable, default):
    """Ask the user for a boolean value.

    The question text and the hide-input flag are read from ``variable``;
    ``default`` is offered as the default answer.  Returns the result of
    ``click.prompt`` for a ``click.BOOL`` prompt.
    """
    question = variable.prompt
    options = {
        'default': default,
        'hide_input': variable.hide_input,
        'type': click.BOOL,
    }
    return click.prompt(question, **options)
38
39
40
def prompt_int(variable, default):
    """Ask the user for an integer value.

    The question text and the hide-input flag are read from ``variable``;
    ``default`` is offered as the default answer.  Returns the result of
    ``click.prompt`` for a ``click.INT`` prompt.
    """
    question = variable.prompt
    options = {
        'default': default,
        'hide_input': variable.hide_input,
        'type': click.INT,
    }
    return click.prompt(question, **options)
47
48
49
def prompt_json(variable, default):
    """Ask the user for a JSON value, or let them keep ``default``.

    The prompt shows the short placeholder ``'default'`` instead of the
    real default.  Anything else the user types is decoded with
    ``json.loads`` into ``collections.OrderedDict`` objects; undecodable
    input raises ``click.UsageError`` inside the value processor so that
    click can ask again.

    Returns ``default`` unchanged when the placeholder is accepted,
    otherwise the decoded JSON value.
    """
    # The JSON object from cookiecutter.json might be very large
    # We only show 'default'
    DEFAULT_JSON = 'default'

    def process_json(user_value):
        if user_value == DEFAULT_JSON:
            # The placeholder itself is not valid JSON.  Let it through
            # untouched so that typing it, or a click version that hands
            # the accepted default to value_proc, does not end in a
            # decode error and another prompt.
            return user_value
        try:
            return json.loads(
                user_value,
                object_pairs_hook=collections.OrderedDict,
            )
        except ValueError:
            # json's decode error is a ValueError subclass.
            # Leave it up to click to ask the user again
            raise click.UsageError('Unable to decode to JSON.')

    dict_value = click.prompt(
        variable.prompt,
        default=DEFAULT_JSON,
        hide_input=variable.hide_input,
        type=click.STRING,
        value_proc=process_json,
    )

    if dict_value == DEFAULT_JSON:
        # Return the given default w/o any processing
        return default
    return dict_value
76
77
78
def prompt_yes_no(variable, default):
    """Ask the user a yes/no question.

    The default shown is ``'y'`` only when ``default`` is exactly ``True``;
    every other value is shown as ``'n'``.  Returns the result of
    ``click.prompt`` for a ``click.BOOL`` prompt.
    """
    default_display = 'y' if default is True else 'n'

    question = variable.prompt
    options = {
        'default': default_display,
        'hide_input': variable.hide_input,
        'type': click.BOOL,
    }
    return click.prompt(question, **options)
90
91
92
def prompt_choice(variable, default):
    """Ask the user to pick one of ``variable.choices`` by number.

    The choices are numbered from 1 in the order they appear and listed
    below the variable's prompt text as ``<number> - <value>`` lines.  The
    user answers with one of the numbers; the default answer is the number
    of ``default`` within ``variable.choices``.

    Returns the entry of ``variable.choices`` that was selected.

    Raises ``ValueError`` (from ``list.index``) when ``default`` is not one
    of ``variable.choices``.
    """
    choice_map = collections.OrderedDict(
        (u'{}'.format(i), value)
        for i, value in enumerate(variable.choices, 1)
    )
    # A real list rather than a dict view: it is joined into the prompt
    # below and handed to click.Choice, which expects a sequence.
    choices = list(choice_map)

    choice_lines = [
        u'{} - {}'.format(number, value)
        for number, value in choice_map.items()
    ]
    prompt = u'\n'.join((
        variable.prompt,
        u'\n'.join(choice_lines),
        u'Choose from {}'.format(u', '.join(choices)),
    ))

    # Answers are the 1-based numbers, so the default must be one as well;
    # format it exactly like the keys of choice_map.
    default_choice = u'{}'.format(variable.choices.index(default) + 1)

    user_choice = click.prompt(
        prompt,
        default=default_choice,
        hide_input=variable.hide_input,
        type=click.Choice(choices),
    )
    return choice_map[user_choice]
114
115
116
# Prompt function for each variable type.  Variables that define choices
# do not go through this table.
PROMPTS = dict(
    string=prompt_string,
    boolean=prompt_boolean,
    int=prompt_int,
    json=prompt_json,
    yes_no=prompt_yes_no,
)
123
124
125
def deserialize_string(value):
    """Return ``value`` converted with ``str``."""
    text = str(value)
    return text
127
128
129
def deserialize_boolean(value):
    """Convert ``value`` to a ``bool``.

    Non-string values use plain truthiness.  Strings are treated specially
    because ``bool('False')`` is ``True``: a string spelling a negative
    answer (``0``, ``false``, ``f``, ``no``, ``n``, ``off``; case and
    surrounding whitespace ignored) yields ``False``.  Every other string
    keeps its ordinary truthiness.
    """
    if isinstance(value, str):
        false_words = ('0', 'false', 'f', 'no', 'n', 'off')
        return bool(value) and value.strip().lower() not in false_words
    return bool(value)
131
132
133
def deserialize_yes_no(value):
    """Convert a yes/no ``value`` to a ``bool``.

    Non-string values use plain truthiness.  Strings are treated specially
    because ``bool('n')`` is ``True``: a string spelling a negative answer
    (``0``, ``false``, ``f``, ``no``, ``n``, ``off``; case and surrounding
    whitespace ignored) yields ``False``.  Every other string keeps its
    ordinary truthiness.
    """
    if isinstance(value, str):
        false_words = ('0', 'false', 'f', 'no', 'n', 'off')
        return bool(value) and value.strip().lower() not in false_words
    return bool(value)
135
136
137
def deserialize_int(value):
    """Return ``value`` converted with ``int``."""
    number = int(value)
    return number
139
140
141
def deserialize_json(value):
    """Return ``value`` as given; JSON values need no conversion here."""
    unchanged = value
    return unchanged
143
144
145
# Conversion function applied to the final value of each variable type.
DESERIALIZERS = dict(
    string=deserialize_string,
    boolean=deserialize_boolean,
    int=deserialize_int,
    json=deserialize_json,
    yes_no=deserialize_yes_no,
)
152
153
154
class Variable(object):
    """One variable definition, validated on construction.

    ``name`` and ``default`` are required.  The remaining keyword
    arguments are optional and fall back to these values when absent:
    ``description`` (None), ``prompt`` (DEFAULT_PROMPT formatted with this
    variable), ``hide_input`` (False), ``type`` ('string', stored as
    ``var_type``), ``skip_if`` (''), ``prompt_user`` (True) and
    ``choices`` ([]).

    Raises ``ValueError`` when ``type`` is not in VALID_TYPES, ``skip_if``
    is not a str, ``prompt_user`` is not a bool, or choices are given and
    ``default`` is not one of them.
    """

    def __init__(self, name, default, **info):
        # Required fields.
        self.name = name
        self.default = default

        # Optional presentation fields.  The fallback prompt reads
        # self.name, so it has to be built after the assignment above.
        fallback_prompt = DEFAULT_PROMPT.format(variable=self)
        self.description = info.get('description', None)
        self.prompt = info.get('prompt', fallback_prompt)
        self.hide_input = info.get('hide_input', False)

        self.var_type = info.get('type', 'string')
        if self.var_type not in VALID_TYPES:
            raise ValueError(
                'Invalid type {var_type} for variable'.format(
                    var_type=self.var_type,
                )
            )

        self.skip_if = info.get('skip_if', '')
        if not isinstance(self.skip_if, str):
            # Only a value given in cookiecutter.json can end up here.
            raise ValueError(
                'Field skip_if is required to be a str, got {value}'.format(
                    value=self.skip_if,
                )
            )

        self.prompt_user = info.get('prompt_user', True)
        if not isinstance(self.prompt_user, bool):
            # Only a value given in cookiecutter.json can end up here.
            raise ValueError(
                'Field prompt_user is required to be a bool, '
                'got {value}'.format(value=self.prompt_user)
            )

        # Choices are somewhat special as they can be of any type.
        self.choices = info.get('choices', [])
        if self.choices and default not in self.choices:
            raise ValueError(
                'Invalid default value {default} for choice variable'.format(
                    default=self.default,
                )
            )

    def __repr__(self):
        template = "<{class_name} {variable_name}>"
        return template.format(
            class_name=type(self).__name__,
            variable_name=self.name,
        )
194
195
196
class CookiecutterTemplate(object):
    """A template description together with its variables.

    ``name``, ``cookiecutter_version`` and ``variables`` are required;
    each entry of ``variables`` is expanded as keyword arguments into a
    ``Variable``.  The optional keyword arguments ``authors``,
    ``description``, ``keywords``, ``license``, ``url`` and ``version``
    default to an empty list (``authors``, ``keywords``) or ``None``.

    Iterating over an instance yields its ``Variable`` objects in order.
    """

    def __init__(self, name, cookiecutter_version, variables, **info):
        # Required fields.
        self.name = name
        self.cookiecutter_version = cookiecutter_version
        self.variables = [Variable(**spec) for spec in variables]

        # Optional metadata.
        self.authors = info.get('authors', [])
        self.description = info.get('description', None)
        self.keywords = info.get('keywords', [])
        self.license = info.get('license', None)
        self.url = info.get('url', None)
        self.version = info.get('version', None)

    def __repr__(self):
        template = "<{class_name} {template_name}>"
        return template.format(
            class_name=type(self).__name__,
            template_name=self.name,
        )

    def __iter__(self):
        yield from self.variables
220
221
222
def load_context(json_object, verbose):
    """Resolve every variable of a template into an ordered context.

    ``json_object`` is expanded as keyword arguments into a
    ``CookiecutterTemplate``.  Its variables are handled in order:

    * a variable whose ``skip_if`` template renders to the string
      ``'True'`` is left out of the context;
    * a string default is rendered as a Jinja2 template, with the values
      collected so far available as ``cookiecutter``;
    * a variable with ``prompt_user`` set to False takes its default
      without asking; any other variable is prompted for, using the choice
      prompt when it defines choices.

    Every stored value goes through the deserializer for the variable's
    type.  With ``verbose`` true, a variable's description is echoed
    before its prompt and a terminal-wide separator line after it.

    Returns a ``collections.OrderedDict`` mapping variable names to values.
    """
    # Imported here so the block does not depend on a file-level import.
    # click.get_terminal_size() is deprecated/removed in newer click
    # releases; the standard library call provides the same width.
    import shutil

    env = Environment(extensions=['jinja2_time.TimeExtension'])
    context = collections.OrderedDict()

    for variable in CookiecutterTemplate(**json_object):
        if variable.skip_if:
            skip_template = env.from_string(variable.skip_if)
            if skip_template.render(cookiecutter=context) == 'True':
                continue

        default = variable.default

        if isinstance(default, str):
            # String defaults may refer to previously collected values.
            template = env.from_string(default)
            default = template.render(cookiecutter=context)

        deserialize = DESERIALIZERS[variable.var_type]

        if not variable.prompt_user:
            context[variable.name] = deserialize(default)
            continue

        if variable.choices:
            prompt = prompt_choice
        else:
            prompt = PROMPTS[variable.var_type]

        if verbose and variable.description:
            click.echo(variable.description)

        value = prompt(variable, default)

        if verbose:
            width = shutil.get_terminal_size().columns
            click.echo('-' * width)

        context[variable.name] = deserialize(value)

    return context
261
262
263
def main(file_path):
    """Load the json object and prompt the user for input.

    Reads ``file_path`` as UTF-8 JSON (key order preserved), builds the
    context with verbose output enabled and pretty-prints the result.
    """
    with codecs.open(file_path, 'r', encoding='utf8') as handle:
        json_object = json.load(
            handle,
            object_pairs_hook=collections.OrderedDict,
        )

    context = load_context(json_object, True)
    pprint.pprint(context)


if __name__ == '__main__':
    main('tests/new-context/cookiecutter.json')
273