Conditions | 28 |
Total Lines | 248 |
Code Lines | 110 |
Lines | 0 |
Ratio | 0 % |
Changes | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
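In compare() below, the comments "# Get list of periodic sets from first input" and "# Get list of periodic sets from second input" head two nearly identical blocks, which makes them good candidates for extraction. The following is a minimal sketch of such an extraction, not code from the amd package. load_periodic_sets and its parameters are hypothetical names, and the CIF/CSD readers are passed in as plain callables (an assumption made so the example runs without amd installed).

```python
from typing import Any, Callable, Iterable, List


def load_periodic_sets(
    items: Any,
    csd_refcodes: bool,
    read_refcodes: Callable[[List[Any]], Iterable[Any]],
    read_paths: Callable[[List[Any]], Iterable[Any]],
    label: str,
) -> List[Any]:
    """Hypothetical helper extracted from the duplicated input handling in
    amd.compare.compare(): wrap a single item in a list, read it with the
    appropriate reader and reject an empty result."""
    if not isinstance(items, list):
        items = [items]
    reader = read_refcodes if csd_refcodes else read_paths
    sets = list(reader(items))
    if not sets:
        raise ValueError(
            f'{label} argument passed to amd.compare() contains no valid '
            'crystals/periodic sets to compare.'
        )
    return sets


# Stand-in readers (assumptions, not amd's CifReader/CSDReader).
def fake_read_paths(items: List[Any]) -> List[str]:
    return [f'set from {p}' for p in items]


def fake_read_refcodes(items: List[Any]) -> List[str]:
    return [f'set for {r}' for r in items]


first = load_periodic_sets(
    'data.cif', False, fake_read_refcodes, fake_read_paths, 'First'
)
print(first)  # ['set from data.cif']
```

Both input blocks in compare() would then reduce to one call each, differing only in the label passed for the error message.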
Commonly applied refactorings include:
If many parameters/temporary variables are present:
Complex functions and classes like amd.compare.compare() often do a lot of different things. To break such a unit down, identify a cohesive component within it. A common approach is to look for fields/methods (or, inside a long function, local variables and nested helpers) that share the same prefix or suffix, such as cifreader_kwargs, pdd_kwargs and the other *_kwargs dictionaries in compare().
Once you have determined which parts belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
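Applied to compare(), the dictionaries ending in _kwargs and the nested _default_kwargs helper form one such component: each collects a callable's keyword defaults and overrides them with matching user arguments. A minimal sketch of that component extracted into its own class follows. KwargsRouter and read_cif are hypothetical names, and read_cif only stands in for a reader such as CifReader, whose real parameters may differ.

```python
import inspect
from typing import Any, Callable, Dict, Iterable


class KwargsRouter:
    """Hypothetical class extracted from amd.compare.compare(): it holds the
    default keyword arguments of one target callable and overrides them with
    any matching user-supplied keyword arguments."""

    def __init__(self, func: Callable, exclude: Iterable[str] = ()):
        skip = set(exclude)
        # Keep only parameters that declare a default, minus excluded names.
        self.defaults: Dict[str, Any] = {
            name: param.default
            for name, param in inspect.signature(func).parameters.items()
            if param.default is not inspect.Parameter.empty and name not in skip
        }

    def resolve(self, user_kwargs: Dict[str, Any], **forced: Any) -> Dict[str, Any]:
        """Return the defaults, replaced first by matching user_kwargs and
        then by forced values (e.g. verbose)."""
        resolved = dict(self.defaults)
        resolved.update({k: v for k, v in user_kwargs.items() if k in resolved})
        resolved.update(forced)
        return resolved


def read_cif(path, reader='ase', remove_hydrogens=False, verbose=True):
    """Hypothetical stand-in for a CIF reader's signature."""


router = KwargsRouter(read_cif)
print(router.resolve({'remove_hydrogens': True, 'metric': 'chebyshev'}, verbose=False))
# {'reader': 'ase', 'remove_hydrogens': True, 'verbose': False}
```

Unknown keys such as metric are ignored by this router, which mirrors how compare() lets one **kwargs feed several callables.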
Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.
There are several approaches to avoid long parameter lists:
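One such approach is Introduce Parameter Object. The sketch below groups the scalar options of compare() into a frozen dataclass that also takes over normalizing and validating by. CompareOptions is a hypothetical name, and passing it to compare() as a single options argument is an assumption, not part of amd.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class CompareOptions:
    """Hypothetical parameter object grouping the scalar options of
    amd.compare.compare(); field names and defaults mirror its signature."""

    by: str = 'AMD'
    k: int = 100
    n_neighbors: Optional[int] = None
    csd_refcodes: bool = False
    verbose: bool = True

    def __post_init__(self) -> None:
        # Normalize and validate once, instead of inside the long method.
        by = self.by.upper()
        if by not in ('AMD', 'PDD'):
            raise ValueError(f"'by' must be 'AMD' or 'PDD' (passed '{self.by}')")
        # A frozen dataclass needs object.__setattr__ to update a field.
        object.__setattr__(self, 'by', by)


opts = CompareOptions(by='pdd', k=50)
print(opts)
# CompareOptions(by='PDD', k=50, n_neighbors=None, csd_refcodes=False, verbose=True)
```

A signature along the lines of compare(crystals, crystals_=None, options=None, **kwargs) would then carry the same information in three parameters instead of eight.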
"""Functions for comparing AMDs and PDDs of crystals.
...
"""

def compare(
    crystals: CompareInput,
    crystals_: Optional[CompareInput] = None,
    by: str = 'AMD',
    k: int = 100,
    n_neighbors: Optional[int] = None,
    csd_refcodes: bool = False,
    verbose: bool = True,
    **kwargs
) -> pd.DataFrame:
    r"""Given one or two sets of crystals, compare by AMD or PDD and
    return a pandas DataFrame of the distance matrix.

    Given one or two paths to CIFs, periodic sets, CSD refcodes or lists
    thereof, compare by AMD or PDD and return a pandas DataFrame of the
    distance matrix. Default is to compare by AMD with k = 100. Accepts
    any keyword arguments accepted by
    :class:`CifReader <.io.CifReader>`,
    :class:`CSDReader <.io.CSDReader>` and functions from
    :mod:`.compare`.

    Parameters
    ----------
    crystals : list of str or :class:`PeriodicSet <.periodicset.PeriodicSet>`
        A path, :class:`PeriodicSet <.periodicset.PeriodicSet>`, tuple
        or a list of those.
    crystals\_ : list of str or :class:`PeriodicSet <.periodicset.PeriodicSet>`, optional
        A path, :class:`PeriodicSet <.periodicset.PeriodicSet>`, tuple
        or a list of those.
    by : str, default 'AMD'
        Use AMD or PDD to compare crystals.
    k : int, default 100
        Parameter for AMD/PDD, the number of neighbor atoms to consider
        for each atom in a unit cell.
    n_neighbors : int, default None
        If given, find this many nearest neighbors of each crystal
        instead of a full distance matrix between crystals.
    csd_refcodes : bool, optional, csd-python-api only
        Interpret ``crystals`` and ``crystals_`` as CSD refcodes or
        lists thereof, rather than paths.
    verbose : bool, default True
        If True, prints a progress bar during reading, calculating and
        comparing items.
    **kwargs :
        Any keyword arguments accepted by the ``amd.CifReader``,
        ``amd.CSDReader``, ``amd.PDD`` and functions used to compare:
        ``reader``, ``remove_hydrogens``, ``disorder``,
        ``heaviest_component``, ``molecular_centres``,
        ``show_warnings`` (from :class:`CifReader <.io.CifReader>`),
        ``refcode_families`` (from :class:`CSDReader <.io.CSDReader>`),
        ``collapse_tol`` (from :func:`PDD <.calculate.PDD>`),
        ``metric``, ``low_memory``
        (from :func:`AMD_pdist <.compare.AMD_pdist>`), ``metric``,
        ``backend``, ``n_jobs``, ``verbose``
        (from :func:`PDD_pdist <.compare.PDD_pdist>`), ``algorithm``,
        ``leaf_size``, ``metric``, ``p``, ``metric_params``, ``n_jobs``
        (from :func:`_nearest_items <.compare._nearest_items>`).

    Returns
    -------
    df : :class:`pandas.DataFrame`
        DataFrame of the distance matrix for the given crystals compared
        by the chosen invariant.

    Raises
    ------
    ValueError
        If ``by`` is not 'AMD' or 'PDD', if either set given contains no
        valid crystals to compare, or if ``crystals`` or ``crystals_``
        contain an item of invalid type.
    TypeError
        If ``csd_refcodes`` is True and the inputs are not strings.

    Examples
    --------
    Compare everything in a .cif (default, AMD with k=100)::

        df = amd.compare('data.cif')

    Compare everything in one cif with all crystals in all cifs in a
    directory (PDD, k=50)::

        df = amd.compare('data.cif', 'dir/to/cifs', by='PDD', k=50)

    **Examples (csd-python-api only)**

    Compare two crystals by CSD refcode (PDD, k=50)::

        df = amd.compare('DEBXIT01', 'DEBXIT02', csd_refcodes=True, by='PDD', k=50)

    Compare everything in a refcode family (AMD, k=100)::

        df = amd.compare('DEBXIT', csd_refcodes=True, refcode_families=True)
    """

    def _default_kwargs(func: Callable) -> dict:
        """Return the keyword arguments of ``func`` that have default
        values, mapped to those defaults. Values passed in ``kwargs``
        are substituted for the defaults afterwards.
        """
        return {
            name: param.default
            for name, param in inspect.signature(func).parameters.items()
            if param.default is not inspect.Parameter.empty
        }

    def _unwrap_refcode_list(
        refcodes: List[str], **reader_kwargs
    ) -> List[PeriodicSet]:
        """Given a list of strings, interpret them as CSD refcodes and
        return a list of ``PeriodicSet`` objects.
        """
        if not all(isinstance(refcode, str) for refcode in refcodes):
            raise TypeError(
                'amd.compare(csd_refcodes=True) expects a string or list of '
                'strings.'
            )
        return list(CSDReader(refcodes, **reader_kwargs))

    def _unwrap_pset_list(
        psets: List[Union[str, PeriodicSet]], **reader_kwargs
    ) -> List[PeriodicSet]:
        """Given a list of strings or ``PeriodicSet`` objects, interpret
        strings as paths and unwrap all items into one list of
        ``PeriodicSet``s.
        """
        ret = []
        for item in psets:
            if isinstance(item, PeriodicSet):
                ret.append(item)
            else:
                try:
                    path = Path(item)
                except TypeError:
                    raise ValueError(
                        'amd.compare() expects strings or amd.PeriodicSets, '
                        f'got {item.__class__.__name__}'
                    )
                ret.extend(CifReader(path, **reader_kwargs))
        return ret

    by = by.upper()
    if by not in ('AMD', 'PDD'):
        raise ValueError(
            "'by' parameter of amd.compare() must be 'AMD' or 'PDD' (passed "
            f"'{by}')"
        )

    # Sort out keyword arguments
    cifreader_kwargs = _default_kwargs(CifReader.__init__)
    csdreader_kwargs = _default_kwargs(CSDReader.__init__)
    csdreader_kwargs.pop('refcodes', None)
    pdd_kwargs = _default_kwargs(PDD)
    pdd_kwargs.pop('return_row_groups', None)
    compare_amds_kwargs = _default_kwargs(AMD_pdist)
    compare_pdds_kwargs = _default_kwargs(PDD_pdist)
    nearest_items_kwargs = _default_kwargs(_nearest_items)
    nearest_items_kwargs.pop('XB', None)
    cifreader_kwargs['verbose'] = verbose
    csdreader_kwargs['verbose'] = verbose
    compare_pdds_kwargs['verbose'] = verbose

    for default_kwargs in (
        cifreader_kwargs, csdreader_kwargs, pdd_kwargs, compare_amds_kwargs,
        compare_pdds_kwargs, nearest_items_kwargs
    ):
        for kw in default_kwargs:
            if kw in kwargs:
                default_kwargs[kw] = kwargs[kw]

    # Get list of periodic sets from first input
    if not isinstance(crystals, list):
        crystals = [crystals]
    if csd_refcodes:
        crystals = _unwrap_refcode_list(crystals, **csdreader_kwargs)
    else:
        crystals = _unwrap_pset_list(crystals, **cifreader_kwargs)
    if not crystals:
        raise ValueError(
            'First argument passed to amd.compare() contains no valid '
            'crystals/periodic sets to compare.'
        )
    names = [s.name for s in crystals]
    if verbose:
        crystals = tqdm.tqdm(crystals, desc='Calculating', delay=1)

    # Get list of periodic sets from second input if given
    if crystals_ is None:
        names_ = names
    else:
        if not isinstance(crystals_, list):
            crystals_ = [crystals_]
        if csd_refcodes:
            crystals_ = _unwrap_refcode_list(crystals_, **csdreader_kwargs)
        else:
            crystals_ = _unwrap_pset_list(crystals_, **cifreader_kwargs)
        if not crystals_:
            raise ValueError(
                'Second argument passed to amd.compare() contains no '
                'valid crystals/periodic sets to compare.'
            )
        names_ = [s.name for s in crystals_]
        if verbose:
            crystals_ = tqdm.tqdm(crystals_, desc='Calculating', delay=1)

    if by == 'AMD':

        amds = np.empty((len(names), k), dtype=np.float64)
        for i, s in enumerate(crystals):
            amds[i] = AMD(s, k)

        if crystals_ is None:
            if n_neighbors is None:
                dm = squareform(AMD_pdist(amds, **compare_amds_kwargs))
                return pd.DataFrame(dm, index=names, columns=names_)
            else:
                nn_dm, inds = _nearest_items(
                    n_neighbors, amds, **nearest_items_kwargs
                )
                return _nearest_neighbors_dataframe(nn_dm, inds, names, names_)
        else:
            amds_ = np.empty((len(names_), k), dtype=np.float64)
            for i, s in enumerate(crystals_):
                amds_[i] = AMD(s, k)

            if n_neighbors is None:
                dm = AMD_cdist(amds, amds_, **compare_amds_kwargs)
                return pd.DataFrame(dm, index=names, columns=names_)
            else:
                nn_dm, inds = _nearest_items(
                    n_neighbors, amds, amds_, **nearest_items_kwargs
                )
                return _nearest_neighbors_dataframe(nn_dm, inds, names, names_)

    elif by == 'PDD':

        pdds = [PDD(s, k, **pdd_kwargs) for s in crystals]

        if crystals_ is None:
            dm = PDD_pdist(pdds, **compare_pdds_kwargs)
            if n_neighbors is None:
                dm = squareform(dm)
        else:
            pdds_ = [PDD(s, k, **pdd_kwargs) for s in crystals_]
            dm = PDD_cdist(pdds, pdds_, **compare_pdds_kwargs)

        if n_neighbors is None:
            return pd.DataFrame(dm, index=names, columns=names_)
        else:
            nn_dm, inds = _neighbors_from_distance_matrix(n_neighbors, dm)
            return _nearest_neighbors_dataframe(nn_dm, inds, names, names_)
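Part of the method's condition count comes from branching on by. One possible direction, shown as a hedged sketch rather than a drop-in replacement, is a lookup table keyed by the invariant name. Invariant, distance_matrix, toy_invariant and chebyshev_cdist are hypothetical stand-ins, not amd functions. Unlike compare(), the sketch always builds the full matrix with a cdist-style function, whereas compare() uses the condensed pdist form when only one set is given and, for AMDs, _nearest_items for neighbor searches.

```python
from typing import Callable, Dict, List, NamedTuple, Optional, Sequence

Vector = List[float]


class Invariant(NamedTuple):
    """Hypothetical bundle of the two callables compare() chooses via 'by'."""
    compute: Callable[[Sequence[float], int], Vector]
    cdist: Callable[[Sequence[Vector], Sequence[Vector]], List[List[float]]]


def distance_matrix(
    by: str,
    crystals: Sequence[Sequence[float]],
    crystals_: Optional[Sequence[Sequence[float]]],
    k: int,
    invariants: Dict[str, Invariant],
) -> List[List[float]]:
    """Look up the invariant by name instead of branching on it."""
    try:
        invariant = invariants[by.upper()]
    except KeyError:
        raise ValueError(
            f"'by' must be one of {sorted(invariants)} (passed '{by}')"
        ) from None
    xs = [invariant.compute(c, k) for c in crystals]
    # Compare the first set with itself when no second set is given.
    ys = xs if crystals_ is None else [invariant.compute(c, k) for c in crystals_]
    return invariant.cdist(xs, ys)


def toy_invariant(crystal: Sequence[float], k: int) -> Vector:
    """Stand-in invariant: the k smallest values of a list of numbers."""
    return sorted(crystal)[:k]


def chebyshev_cdist(xs: Sequence[Vector], ys: Sequence[Vector]) -> List[List[float]]:
    """Pairwise L-infinity distances between two lists of equal-length vectors."""
    return [[max(abs(a - b) for a, b in zip(x, y)) for y in ys] for x in xs]


INVARIANTS = {'TOY': Invariant(toy_invariant, chebyshev_cdist)}
print(distance_matrix('toy', [[3.0, 1.0], [2.0, 2.5]], None, 2, INVARIANTS))
# [[0.0, 1.0], [1.0, 0.0]]
```

In compare() itself, the table would hold the AMD and PDD callables instead of the toy invariant, and adding a new invariant would mean registering one entry rather than adding another branch.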