Completed
Pull Request — master (#83)
by
unknown
01:06
created

GOEnrichmentStudy._get_pval_uncorr()   C

Complexity

Conditions 7

Size

Total Lines 65

Duplication

Lines 0
Ratio 0 %

Importance

Changes 3
Bugs 0 Features 0
Metric Value
cc 7
c 3
b 0
f 0
dl 0
loc 65
rs 6.0105

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include: Extract Method.

1
#!/usr/bin/env python
2
# -*- coding: UTF-8 -*-
3
4
"""
5
python %prog study.file population.file gene-association.file
6
7
This program returns P-values for functional enrichment in a cluster of
8
study genes using Fisher's exact test, and corrected for multiple testing
9
(including Bonferroni, Holm, Sidak, and false discovery rate)
10
"""
11
12
from __future__ import absolute_import
13
14
__copyright__ = "Copyright (C) 2010-2017, H Tang et al., All rights reserved."
15
__author__ = "various"
16
17
import sys
18
import collections as cx
19
import datetime
20
from functools import partial
21
import multiprocessing
22
23
24
from goatools.multiple_testing import Methods, Bonferroni, Sidak, HolmBonferroni, FDR, calc_qval
25
from goatools.ratio import get_terms, count_terms, is_ratio_different
26
import goatools.wr_tbl as RPT
27
from goatools.pvalcalc import FisherFactory
28
from .multiprocessing_tools import p_map
29
30
31
def compute_pvals(allterms, calc_pvalue, go2studyitems, go2popitems,
                  study_n, pop_n):
    """Map each GO term in `allterms` to its uncorrected p-value.

    For every term, the study/population counts are the number of items the
    term maps to in `go2studyitems`/`go2popitems` (0 when the term is absent).
    The value stored is calc_pvalue(study_count, study_n, pop_count, pop_n).

    Kept at module level: it is wrapped with functools.partial and handed to
    worker processes when p-values are computed in parallel.
    """
    return {
        goid: calc_pvalue(len(go2studyitems.get(goid, set())), study_n,
                          len(go2popitems.get(goid, set())), pop_n)
        for goid in allterms}
45
46
47
class GOEnrichmentRecord(object):
    """Represents one result (from a single GOTerm) in the GOEnrichmentStudy
    """
    # GO namespace -> two-letter abbreviation, in BP, MF, CC order
    namespace2NS = cx.OrderedDict([
        ('biological_process', 'BP'),
        ('molecular_function', 'MF'),
        ('cellular_component', 'CC')])

    # Fields seen in every enrichment result
    _fldsdefprt = [
        "GO",
        "NS",
        "enrichment",
        "name",
        "ratio_in_study",
        "ratio_in_pop",
        "p_uncorrected",
        "depth",
        "study_count",
        "study_items"]
    # printf-style formats, one per entry of _fldsdefprt
    _fldsdeffmt = ["%s"]*3 + ["%-30s"] + ["%d/%d"] * 2 + ["%.3g"] + ["%d"] * 2 + ["%15s"]

    _flds = set(_fldsdefprt).intersection(
        set(['study_items', 'study_count', 'study_n', 'pop_items', 'pop_count', 'pop_n']))

    def __init__(self, **kwargs):
        """Store every keyword argument as an attribute.

        `ratio_in_study` and `ratio_in_pop` are (count, total) tuples; they are
        also unpacked into study_count/study_n and pop_count/pop_n.
        """
        # Methods seen in current enrichment result
        self._methods = []
        for k, v in kwargs.items():
            setattr(self, k, v)
            if k == 'ratio_in_study':
                setattr(self, 'study_count', v[0])
                setattr(self, 'study_n', v[1])
            if k == 'ratio_in_pop':
                setattr(self, 'pop_count', v[0])
                setattr(self, 'pop_n', v[1])
        self._init_enrichment()
        self.goterm = None  # the reference to the GOTerm

    def get_method_name(self):
        """Return name of first method in the _methods list."""
        return self._methods[0].fieldname

    def get_pvalue(self):
        """Returns pval for 1st method, if it exists. Else returns uncorrected pval."""
        if self._methods:
            return getattr(self, "p_{m}".format(m=self.get_method_name()))
        return getattr(self, "p_uncorrected")

    def set_corrected_pval(self, nt_method, pvalue):
        """Add object attribute p_<fieldname> based on method name."""
        self._methods.append(nt_method)
        fieldname = "".join(["p_", nt_method.fieldname])
        setattr(self, fieldname, pvalue)

    def __str__(self, indent=False):
        """Return one tab-separated line: default fields, corrected pvals, study items."""
        # str() each item so non-string ids (e.g., integer gene ids) can be joined,
        # consistent with _get_rpt_fmt.
        study_items = sorted(getattr(self, self._fldsdefprt[-1], set()))
        field_data = [getattr(self, f, "n.a.") for f in self._fldsdefprt[:-1]] + \
                     [getattr(self, "p_{}".format(m.fieldname)) for m in self._methods] + \
                     [", ".join(str(item) for item in study_items)]
        fldsdeffmt = self._fldsdeffmt
        field_formatter = fldsdeffmt[:-1] + ["%.3g"]*len(self._methods) + [fldsdeffmt[-1]]
        self._chk_fields(field_data, field_formatter)

        # default formatting only works for non-"n.a" data
        for i, f in enumerate(field_data):
            if f == "n.a.":
                field_formatter[i] = "%s"

        # print dots to show the level of the term
        dots = self.get_indent_dots() if indent else ""
        prtdata = "\t".join(a % b for (a, b) in zip(field_formatter, field_data))
        return "".join([dots, prtdata])

    def get_indent_dots(self):
        """Get a string of dots ("....") representing the level of the GO term."""
        return "." * self.goterm.level if self.goterm is not None else ""

    @staticmethod
    def _chk_fields(field_data, field_formatter):
        """Check that expected fields are present."""
        if len(field_data) == len(field_formatter):
            return
        len_dat = len(field_data)
        len_fmt = len(field_formatter)
        msg = [
            "FIELD DATA({d}) != FORMATTER({f})".format(d=len_dat, f=len_fmt),
            "DAT({N}): {D}".format(N=len_dat, D=field_data),
            "FMT({N}): {F}".format(N=len_fmt, F=field_formatter)]
        raise Exception("\n".join(msg))

    def __repr__(self):
        return "GOEnrichmentRecord({GO})".format(GO=self.GO)

    def set_goterm(self, go2obj):
        """Set goterm and copy GOTerm's name and namespace."""
        self.goterm = go2obj.get(self.GO, None)
        present = self.goterm is not None
        self.name = self.goterm.name if present else "n.a."
        self.NS = self.namespace2NS[self.goterm.namespace] if present else "XX"

    def _init_enrichment(self):
        """Mark as 'enriched' ('e') or 'purified' ('p').

        'e' when study_count/study_n > pop_count/pop_n. The ratios are compared
        by cross-multiplication (a/b > c/d <=> a*d > c*b for positive b, d),
        which is exact for integer counts and avoids ZeroDivisionError: an
        empty study or population is marked 'p'.
        """
        study_side = self.study_count * self.pop_n
        pop_side = self.pop_count * self.study_n
        self.enrichment = 'e' if study_side > pop_side else 'p'

    def update_remaining_fldsdefprt(self, min_ratio=None):
        """Finish updating self (GOEnrichmentRecord) field, is_ratio_different."""
        self.is_ratio_different = is_ratio_different(min_ratio, self.study_count,
                                                     self.study_n, self.pop_count, self.pop_n)


    # -------------------------------------------------------------------------------------
    # Methods for getting flat namedtuple values from GOEnrichmentRecord object
    def get_prtflds_default(self):
        """Get default fields: default fields, p_<method> per method, then study_items."""
        return self._fldsdefprt[:-1] + \
               ["p_{M}".format(M=m.fieldname) for m in self._methods] + \
               [self._fldsdefprt[-1]]

    def get_prtflds_all(self):
        """When converting to a namedtuple, get all possible fields in their original order."""
        flds = []
        dont_add = set(['_parents', '_methods'])
        # Fields: GO NS enrichment name ratio_in_study ratio_in_pop p_uncorrected
        #         depth study_count p_sm_bonferroni p_fdr_bh study_items
        self._flds_append(flds, self.get_prtflds_default(), dont_add)
        # Fields: GO NS goterm
        #         ratio_in_pop pop_n pop_count pop_items name
        #         ratio_in_study study_n study_count study_items
        #         _methods enrichment p_uncorrected p_sm_bonferroni p_fdr_bh
        self._flds_append(flds, vars(self).keys(), dont_add)
        # Fields: name level is_obsolete namespace id depth parents children _parents alt_ids
        # (only available once set_goterm has found a GO term)
        if self.goterm is not None:
            self._flds_append(flds, vars(self.goterm).keys(), dont_add)
        return flds

    @staticmethod
    def _flds_append(flds, addthese, dont_add):
        """Retain order of fields as we add them once to the list."""
        for fld in addthese:
            if fld not in flds and fld not in dont_add:
                flds.append(fld)

    def get_field_values(self, fldnames, rpt_fmt=True, itemid2name=None):
        """Get flat namedtuple fields for one GOEnrichmentRecord.

        Each field is looked up first on this record, then on its GO term.
        If rpt_fmt is True, values are converted to printable form (see
        _get_rpt_fmt). Raises Exception (via _err_fld) for an unknown field.
        """
        row = []
        # Loop through each user field desired
        for fld in fldnames:
            # 1. Check the GOEnrichmentRecord's attributes
            val = getattr(self, fld, None)
            if val is None:
                # 2. Check the GO object for the field
                val = getattr(self.goterm, fld, None)
            if val is None:
                # 3. Field not found, raise Exception (before formatting a None value)
                self._err_fld(fld, fldnames)
            if rpt_fmt:
                val = self._get_rpt_fmt(fld, val, itemid2name)
                assert not isinstance(val, list), \
                   "UNEXPECTED LIST: FIELD({F}) VALUE({V}) FMT({P})".format(
                       P=rpt_fmt, F=fld, V=val)
            row.append(val)
        return row

    @staticmethod
    def _get_rpt_fmt(fld, val, itemid2name=None):
        """Return values in a format amenable to printing in a table."""
        if fld.startswith("ratio_"):
            return "{N}/{TOT}".format(N=val[0], TOT=val[1])
        elif fld in set(['study_items', 'pop_items', 'alt_ids']):
            if itemid2name is not None:
                val = [itemid2name.get(v, v) for v in val]
            return ", ".join([str(v) for v in sorted(val)])
        return val

    def _err_fld(self, fld, fldnames):
        """Unrecognized field. Raise an Exception with a detailed failure message."""
        msg = ['ERROR. UNRECOGNIZED FIELD({F})'.format(F=fld)]
        # list(): on Python 3 dict keys cannot be concatenated to a list with '+'
        go_flds = list(self.goterm.__dict__.keys()) if self.goterm is not None else []
        actual_flds = set(self.get_prtflds_default() + go_flds)
        bad_flds = set(fldnames).difference(actual_flds)
        if bad_flds:
            msg.append("\nGOEA RESULT FIELDS: {}".format(" ".join(self._fldsdefprt)))
            msg.append("GO FIELDS: {}".format(" ".join(go_flds)))
            msg.append("\nFATAL: {N} UNEXPECTED FIELDS({F})\n".format(
                N=len(bad_flds), F=" ".join(sorted(bad_flds))))
            msg.append("  {N} User-provided fields:".format(N=len(fldnames)))
            for idx, usr_fld in enumerate(fldnames, 1):
                mrk = "ERROR -->" if usr_fld in bad_flds else ""
                msg.append("  {M:>9} {I:>2}) {F}".format(M=mrk, I=idx, F=usr_fld))
        raise Exception("\n".join(msg))
242
243
244
class GOEnrichmentStudy(object):
    """Runs Fisher's exact test, as well as multiple corrections
    """
    # Default Excel table column widths for GOEA results
    default_fld2col_widths = {
        'NS'        :  3,
        'GO'        : 12,
        'level'     :  3,
        'enrichment':  1,
        'name'      : 60,
        'ratio_in_study':  8,
        'ratio_in_pop'  : 12,
        'study_items'   : 15,
    }

    def __init__(self, pop, assoc, obo_dag, propagate_counts=True, alpha=.05, methods=None, **kws):
        """Prepare a GOEA for one population.

        pop: population items (e.g., geneids); pop_n is len(pop).
        assoc: item-to-GO associations (used by get_terms / update_association).
        obo_dag: GO DAG; if propagate_counts, its update_association(assoc) is called.
        alpha: default test-wise alpha for multiple-test corrections.
        methods: default correction methods (default: bonferroni, sidak, holm).
        kws: 'log' (default sys.stdout), 'n_cores', and keywords for FisherFactory.
        """
        self.log = kws.get('log', sys.stdout)
        # Leave one core free to avoid freezing the machine, but never fall to
        # zero cores (cpu_count() - 1 is 0 on a single-core machine).
        self.n_cores = kws.get('n_cores', max(1, multiprocessing.cpu_count() - 1))
        # Multiple-test correction runners, keyed by the method's source
        self._run_multitest = {
            'local':lambda iargs: self._run_multitest_local(iargs),
            'statsmodels':lambda iargs: self._run_multitest_statsmodels(iargs)}
        self.pop = pop
        self.pop_n = len(pop)
        self.assoc = assoc
        self.obo_dag = obo_dag
        self.alpha = alpha
        if methods is None:
            methods = ["bonferroni", "sidak", "holm"]
        self.methods = Methods(methods)
        self.pval_obj = FisherFactory(**kws).pval_obj

        if propagate_counts:
            sys.stderr.write("Propagating term counts to parents ..\n")
            obo_dag.update_association(assoc)
        self.go2popitems = get_terms("population", pop, assoc, obo_dag, self.log)

    def run_study(self, study, **kws):
        """Run Gene Ontology Enrichment Study (GOEA) on study ids.

        kws: methods, alpha, log (override the object defaults) and keep_if.
        Returns a list of GOEnrichmentRecord objects sorted by NS, then p_uncorrected.
        """
        # Key-word arguments:
        methods = Methods(kws['methods']) if 'methods' in kws else self.methods
        alpha = kws['alpha'] if 'alpha' in kws else self.alpha
        log = kws['log'] if 'log' in kws else self.log
        # Calculate uncorrected pvalues
        results = self._get_pval_uncorr(study, log)
        if not results:
            return []

        # Do multipletest corrections on uncorrected pvalues and update results
        self._run_multitest_corr(results, methods, alpha, study, log)

        for rec in results:
            # get go term for name and level
            rec.set_goterm(self.obo_dag)

        # 'keep_if' can be used to keep only significant GO terms. Example:
        #     >>> keep_if = lambda nt: nt.p_fdr_bh < 0.05 # if results are significant
        #     >>> goea_results = goeaobj.run_study(geneids_study, keep_if=keep_if)
        if 'keep_if' in kws:
            keep_if = kws['keep_if']
            results = [r for r in results if keep_if(r)]

        # Default sort order: First, sort by BP, MF, CC. Second, sort by pval
        results.sort(key=lambda r: [r.NS, r.p_uncorrected])

        if log is not None:
            log.write("  {MSG}\n".format(MSG="\n  ".join(self.get_results_msg(results, study))))

        return results # list of GOEnrichmentRecord objects

    def run_study_nts(self, study, **kws):
        """Run GOEA on study ids. Return results as a list of namedtuples."""
        goea_results = self.run_study(study, **kws)
        return get_goea_nts_all(goea_results)

    def get_results_msg(self, results, study):
        """Return summary for GOEA results."""
        # To convert msg list to string: "\n".join(msg)
        msg = []
        if results:
            stu_items, num_gos_stu = self.get_item_cnt(results, "study_items")
            pop_items, num_gos_pop = self.get_item_cnt(results, "pop_items")
            msg.append("{M:,} GO terms are associated with {N:,} of {NT:,} study items".format(
                N=len(stu_items), NT=len(set(study)), M=num_gos_stu))
            msg.append("{M:,} GO terms are associated with {N:,} of {NT:,} population items".format(
                N=len(pop_items), NT=self.pop_n, M=num_gos_pop))
        return msg

    def _get_pval_uncorr(self, study, log=sys.stdout):
        """Calculate the uncorrected pvalues for study items.

        Returns one GOEnrichmentRecord per GO term associated with the study
        or the population. `log` may be None to suppress messages.
        """
        if log is not None:
            log.write("Calculating uncorrected p-values using {PFNC}\n".format(PFNC=self.pval_obj.name))
        go2studyitems = get_terms("study", study, self.assoc, self.obo_dag, log)
        study_n = len(study)
        allterms = list(set(go2studyitems.keys()).union(set(self.go2popitems.keys())))

        n_cores = self._get_n_cores(len(allterms))
        if log is not None:
            log.write("use {} cores for computing pvalues\n".format(n_cores))

        # bind arguments, so that remote_func only depends on fragment of terms to process:
        remote_func = partial(compute_pvals, calc_pvalue=self.pval_obj.calc_pvalue,
                              go2studyitems=go2studyitems, go2popitems=self.go2popitems,
                              study_n=study_n, pop_n=self.pop_n)
        all_p_values = self._calc_pvals(remote_func, allterms, n_cores)
        return self._mk_records(all_p_values, go2studyitems, study_n)

    def _get_n_cores(self, num_terms):
        """Return the number of processes (always >= 1) used to compute num_terms pvalues."""
        # For small sizes, startup/shutdown of a multiprocessing Pool is too expensive;
        # the limit is based on some heuristics on a recent Mac.
        if num_terms < 100:
            return 1
        return max(1, self.n_cores)

    def _calc_pvals(self, remote_func, allterms, n_cores):
        """Apply remote_func to n_cores interleaved fragments of allterms.

        Returns a list of {GO term: uncorrected pvalue} dicts, one per fragment.
        """
        if n_cores == 1:
            return [remote_func(allterms)]
        fragments = [allterms[i::n_cores] for i in range(n_cores)]
        # If self.pval_obj.log is a file handle, which cannot be serialized,
        # self.pval_obj.calc_pvalue cannot be sent to another python process.
        # Therefore temporarily "patch" the object; it is restored even if the
        # Pool cannot be created.
        old_log = self.pval_obj.log
        self.pval_obj.log = None
        try:
            pool = multiprocessing.Pool(n_cores)
            try:
                return p_map(pool, remote_func, fragments)
            finally:
                # release memory
                pool.terminate()
        finally:
            # restore patched file handle
            self.pval_obj.log = old_log

    def _mk_records(self, all_p_values, go2studyitems, study_n):
        """Create one GOEnrichmentRecord per (GO term, uncorrected pvalue) pair."""
        results = []
        for p_values_map in all_p_values:
            for term, p_value in p_values_map.items():
                study_items = go2studyitems.get(term, set())
                pop_items = self.go2popitems.get(term, set())
                results.append(GOEnrichmentRecord(
                    GO=term,
                    p_uncorrected=p_value,
                    study_items=study_items,
                    pop_items=pop_items,
                    ratio_in_study=(len(study_items), study_n),
                    ratio_in_pop=(len(pop_items), self.pop_n)))
        return results

    def _run_multitest_corr(self, results, usr_methods, alpha, study, log):
        """Do multiple-test corrections on uncorrected pvalues."""
        assert 0 < alpha < 1, "Test-wise alpha must fall between (0, 1)"
        pvals = [r.p_uncorrected for r in results]
        NtMt = cx.namedtuple("NtMt", "results pvals alpha nt_method study")

        for nt_method in usr_methods:
            ntmt = NtMt(results, pvals, alpha, nt_method, study)
            if log is not None:
                log.write("Running multitest correction: {MSRC} {METHOD}\n".format(
                    MSRC=ntmt.nt_method.source, METHOD=ntmt.nt_method.method))
            self._run_multitest[nt_method.source](ntmt)

    def _run_multitest_statsmodels(self, ntmt):
        """Use multitest methods that have been implemented in statsmodels."""
        # Only load statsmodels if it is used
        multipletests = self.methods.get_statsmodels_multipletests()
        results = multipletests(ntmt.pvals, ntmt.alpha, ntmt.nt_method.method)
        pvals_corrected = results[1] # reject_lst, pvals_corrected, alphacSidak, alphacBonf
        self._update_pvalcorr(ntmt, pvals_corrected)

    def _run_multitest_local(self, ntmt):
        """Use multitest methods that have been implemented locally."""
        corrected_pvals = None
        method = ntmt.nt_method.method
        if method == "bonferroni":
            corrected_pvals = Bonferroni(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "sidak":
            corrected_pvals = Sidak(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "holm":
            corrected_pvals = HolmBonferroni(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "fdr":
            # get the empirical p-value distributions for FDR
            term_pop = getattr(self, 'term_pop', None)
            if term_pop is None:
                term_pop = count_terms(self.pop, self.assoc, self.obo_dag)
            p_val_distribution = calc_qval(len(ntmt.study),
                                           self.pop_n,
                                           self.pop, self.assoc,
                                           term_pop, self.obo_dag)
            corrected_pvals = FDR(p_val_distribution,
                                  ntmt.results, ntmt.alpha).corrected_pvals

        self._update_pvalcorr(ntmt, corrected_pvals)

    @staticmethod
    def _update_pvalcorr(ntmt, corrected_pvals):
        """Add data members to store multiple test corrections (no-op if None)."""
        if corrected_pvals is None:
            return
        for rec, val in zip(ntmt.results, corrected_pvals):
            rec.set_corrected_pval(ntmt.nt_method, val)

    # Methods for writing results into tables: text, tab-separated, Excel spreadsheets
    def wr_txt(self, fout_txt, goea_results, prtfmt=None, **kws):
        """Print GOEA results to text file."""
        if not goea_results:
            sys.stdout.write("      0 GOEA results. NOT WRITING {FOUT}\n".format(FOUT=fout_txt))
            return
        with open(fout_txt, 'w') as prt:
            data_nts = self.prt_txt(prt, goea_results, prtfmt, **kws)
            log = self.log if self.log is not None else sys.stdout
            log.write("  {N:>5} GOEA results for {CUR:5} study items. WROTE: {F}\n".format(
                N=len(data_nts),
                CUR=len(get_study_items(goea_results)),
                F=fout_txt))

    def prt_txt(self, prt, goea_results, prtfmt=None, **kws):
        """Print GOEA results in text format. Returns the namedtuples printed."""
        if prtfmt is None:
            prtfmt = "{GO} {NS} {p_uncorrected:5.2e} {study_count:>5} {name}\n"
        prtfmt = self.adjust_prtfmt(prtfmt)
        prt_flds = RPT.get_fmtflds(prtfmt)
        data_nts = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.prt_txt(prt, data_nts, prtfmt, prt_flds, **kws)
        return data_nts

    def wr_xlsx(self, fout_xlsx, goea_results, **kws):
        """Write a xlsx file."""
        # kws: prt_if indent itemid2name(study_items)
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        xlsx_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        if 'fld2col_widths' not in kws:
            kws['fld2col_widths'] = {f:self.default_fld2col_widths.get(f, 8) for f in prt_flds}
        RPT.wr_xlsx(fout_xlsx, xlsx_data, **kws)

    def wr_tsv(self, fout_tsv, goea_results, **kws):
        """Write tab-separated table data to file"""
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        tsv_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.wr_tsv(fout_tsv, tsv_data, **kws)

    def prt_tsv(self, prt, goea_results, **kws):
        """Write tab-separated table data"""
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        tsv_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.prt_tsv(prt, tsv_data, prt_flds, **kws)

    @staticmethod
    def adjust_prtfmt(prtfmt):
        """Adjust format_strings for legal values ('-' is not legal in a field name)."""
        prtfmt = prtfmt.replace("{p_holm-sidak", "{p_holm_sidak")
        prtfmt = prtfmt.replace("{p_simes-hochberg", "{p_simes_hochberg")
        return prtfmt

    @staticmethod
    def get_NS2nts(results, fldnames=None, **kws):
        """Get namedtuples of GOEA results, split into BP, MF, CC."""
        NS2nts = cx.defaultdict(list)
        nts = get_goea_nts_all(results, fldnames, **kws)
        for nt in nts:
            NS2nts[nt.NS].append(nt)
        return NS2nts

    @staticmethod
    def get_item_cnt(results, attrname="study_items"):
        """Get all study or population items (e.g., geneids).

        Returns (union of items, number of results with a non-empty item set).
        """
        items = set()
        go_cnt = 0
        for rec in results:
            if hasattr(rec, attrname):
                items_cur = getattr(rec, attrname)
                # Only count GO term if there are items in the set.
                if len(items_cur) != 0:
                    items |= items_cur
                    go_cnt += 1
        return items, go_cnt

    @staticmethod
    def get_prtflds_default(results):
        """Get default fields names. Used in printing GOEA results.

           Researchers can control which fields they want to print in the GOEA results
           or they can use the default fields.
        """
        if results:
            return results[0].get_prtflds_default()
        return []

    @staticmethod
    def print_summary(results, min_ratio=None, indent=False, pval=0.05):
        """Print summary."""
        from .version import __version__ as version

        # Header contains provenance and parameters
        print("# Generated by GOATOOLS v{0} ({1})".format(version, datetime.date.today()))
        print("# min_ratio={0} pval={1}".format(min_ratio, pval))

        # field names for output
        if results:
            print("\t".join(GOEnrichmentStudy.get_prtflds_default(results)))

        for rec in results:
            # calculate some additional statistics
            # (over_under, is_ratio_different)
            rec.update_remaining_fldsdefprt(min_ratio=min_ratio)

            if pval is not None and rec.p_uncorrected >= pval:
                continue

            if rec.is_ratio_different:
                print(rec.__str__(indent=indent))

    def wr_py_goea_results(self, fout_py, goea_results, **kws):
        """Save GOEA results into Python package containing list of namedtuples."""
        var_name = kws.get("var_name", "goea_results")
        docstring = kws.get("docstring", "")
        sortby = kws.get("sortby", None)
        if goea_results:
            from goatools.nt_utils import wr_py_nts
            nts_goea = goea_results
            # If list has GOEnrichmentRecords or verbose namedtuples, exclude some fields.
            if hasattr(goea_results[0], "_fldsdefprt") or hasattr(goea_results[0], 'goterm'):
                # Exclude some attributes from the namedtuple when saving results
                # to a Python file because the information is redundant or verbose.
                nts_goea = get_goea_nts_prt(goea_results)
            docstring = "\n".join([docstring, "# {VER}\n\n".format(VER=self.obo_dag.version)])
            assert hasattr(nts_goea[0], '_fields')
            if sortby is None:
                sortby = lambda nt: getattr(nt, 'p_uncorrected')
            nts_goea = sorted(nts_goea, key=sortby)
            wr_py_nts(fout_py, nts_goea, docstring, var_name)
580
581
def get_study_items(goea_results):
    """Return the union of the `study_items` sets of all GOEA results (e.g., geneids)."""
    collected = set()
    for goea_rec in goea_results:
        collected |= goea_rec.study_items
    return collected
587
588
def get_goea_nts_prt(goea_results, fldnames=None, **usr_kws):
    """Return list of namedtuples removing fields which are redundant or verbose.

    Unless the caller supplies them, 'not_fldnames' defaults to
    goterm/parents/children/id and 'rpt_fmt' defaults to True; the caller's
    keyword dict is not modified.
    """
    kws = dict(usr_kws)
    kws.setdefault('not_fldnames', ['goterm', 'parents', 'children', 'id'])
    kws.setdefault('rpt_fmt', True)
    return get_goea_nts_all(goea_results, fldnames, **kws)
596
597
def get_goea_nts_all(goea_results, fldnames=None, **kws):
    """Get namedtuples containing user-specified (or default) data from GOEA results.

        Reformats data from GOEnrichmentRecord objects into lists of
        namedtuples so the generic table writers may be used.

        kws: keep_if (predicate on each result), rpt_fmt, indent,
             not_fldnames (fields to drop), itemid2name (for study_items).
    """
    data_nts = [] # A list of namedtuples containing GOEA results
    if not goea_results:
        return data_nts
    keep_if = kws.get('keep_if', None)
    rpt_fmt = kws.get('rpt_fmt', False)
    indent = kws.get('indent', False)
    itemid2name = kws.get('itemid2name', None)
    # I. FIELD (column) NAMES
    not_fldnames = kws.get('not_fldnames', None)
    if fldnames is None:
        fldnames = get_fieldnames(goea_results[0])
    # Ia. Explicitly exclude specific fields from named tuple
    if not_fldnames is not None:
        fldnames = [f for f in fldnames if f not in not_fldnames]
    nttyp = cx.namedtuple("NtGoeaResults", " ".join(fldnames))
    goid_idx = fldnames.index("GO") if 'GO' in fldnames else None
    # II. Loop through GOEA results stored in a GOEnrichmentRecord object
    for goerec in goea_results:
        # Skip rejected results before formatting their fields
        if keep_if is not None and not keep_if(goerec):
            continue
        vals = get_field_values(goerec, fldnames, rpt_fmt, itemid2name)
        # Indentation dots are prefixed to the GO id, so only when GO is a column
        if indent and goid_idx is not None:
            vals[goid_idx] = "".join([goerec.get_indent_dots(), vals[goid_idx]])
        data_nts.append(nttyp._make(vals))
    return data_nts
628
629
def get_field_values(item, fldnames, rpt_fmt=None, itemid2name=None):
    """Return the values of `fldnames` for a namedtuple or a GOEnrichmentRecord.

    Returns None when `item` is neither kind.
    """
    is_goea_record = hasattr(item, "_fldsdefprt")
    if is_goea_record:
        return item.get_field_values(fldnames, rpt_fmt, itemid2name)
    if not hasattr(item, "_fields"):
        return None
    # A namedtuple: read the requested fields directly
    return [getattr(item, name) for name in fldnames]
635
636
def get_fieldnames(item):
    """Return the field names of a namedtuple or GOEnrichmentRecord (None otherwise)."""
    if hasattr(item, "_fldsdefprt"):
        # GOEnrichmentRecord: every available field, in its natural order
        return item.get_prtflds_all()
    return item._fields if hasattr(item, "_fields") else None
642
643
# Copyright (C) 2010-2017, H Tang et al., All rights reserved.
644