Completed
Pull Request — master (#83)
by
unknown
01:34
created

compute_pvals()   A

Complexity

Conditions 2

Size

Total Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 2
c 1
b 0
f 0
dl 0
loc 14
rs 9.4285
1
#!/usr/bin/env python
2
# -*- coding: UTF-8 -*-
3
4
"""
5
python %prog study.file population.file gene-association.file
6
7
This program returns P-values for functional enrichment in a cluster of
8
study genes using Fisher's exact test, and corrected for multiple testing
9
(including Bonferroni, Holm, Sidak, and false discovery rate)
10
"""
11
12
from __future__ import absolute_import
13
14
__copyright__ = "Copyright (C) 2010-2017, H Tang et al., All rights reserved."
15
__author__ = "various"
16
17
import sys
18
import collections as cx
19
import datetime
20
from functools import partial
21
import multiprocessing
22
23
24
from goatools.multiple_testing import Methods, Bonferroni, Sidak, HolmBonferroni, FDR, calc_qval
25
from goatools.ratio import get_terms, count_terms, is_ratio_different
26
import goatools.wr_tbl as RPT
27
from goatools.pvalcalc import FisherFactory
28
from .multiprocessing_tools import p_map
29
30
31
def compute_pvals(allterms, calc_pvalue, go2studyitems, go2popitems,
                  study_n, pop_n):
    """Compute one uncorrected p-value per GO term.

    This is a module-level function so it can be bound with functools.partial
    and shipped to worker processes.

    Args:
        allterms: iterable of GO term ids to evaluate.
        calc_pvalue: callable invoked as
            calc_pvalue(study_count, study_n, pop_count, pop_n).
        go2studyitems: mapping of GO id -> collection of study items.
        go2popitems: mapping of GO id -> collection of population items.
        study_n: total number of study items.
        pop_n: total number of population items.

    Returns:
        dict mapping each term in allterms to the value returned by calc_pvalue.
        A term missing from either mapping is counted as having 0 items there.
    """
    results = {}
    for term in allterms:
        study_count = len(go2studyitems.get(term, ()))
        pop_count = len(go2popitems.get(term, ()))
        results[term] = calc_pvalue(study_count, study_n, pop_count, pop_n)
    return results
45
46
47
class GOEnrichmentRecord(object):
    """Represents one result (from a single GOTerm) in the GOEnrichmentStudy.

    Every keyword argument given to the constructor becomes an attribute.
    'ratio_in_study' and 'ratio_in_pop' are (count, total) pairs which are
    additionally unpacked into study_count/study_n and pop_count/pop_n.
    Corrected p-values are attached later through set_corrected_pval().
    """
    namespace2NS = cx.OrderedDict([
        ('biological_process', 'BP'),
        ('molecular_function', 'MF'),
        ('cellular_component', 'CC')])

    # Fields seen in every enrichment result
    _fldsdefprt = [
        "GO",
        "NS",
        "enrichment",
        "name",
        "ratio_in_study",
        "ratio_in_pop",
        "p_uncorrected",
        "depth",
        "study_count",
        "study_items"]
    # One printf-style format per entry of _fldsdefprt (same order)
    _fldsdeffmt = ["%s"]*3 + ["%-30s"] + ["%d/%d"] * 2 + ["%.3g"] + ["%d"] * 2 + ["%15s"]

    _flds = set(_fldsdefprt).intersection(
        set(['study_items', 'study_count', 'study_n', 'pop_items', 'pop_count', 'pop_n']))

    def __init__(self, **kwargs):
        """Store all keyword arguments as attributes and derive counts/enrichment.

        Both 'ratio_in_study' and 'ratio_in_pop' must be supplied because
        _init_enrichment() reads the counts derived from them.
        """
        # Methods seen in current enrichment result
        self._methods = []
        for k, v in kwargs.items():
            setattr(self, k, v)
            if k == 'ratio_in_study':
                setattr(self, 'study_count', v[0])
                setattr(self, 'study_n', v[1])
            if k == 'ratio_in_pop':
                setattr(self, 'pop_count', v[0])
                setattr(self, 'pop_n', v[1])
        self._init_enrichment()
        self.goterm = None  # the reference to the GOTerm

    def get_method_name(self):
        """Return name of first method in the _methods list."""
        return self._methods[0].fieldname

    def get_pvalue(self):
        """Returns pval for 1st method, if it exists. Else returns uncorrected pval."""
        if self._methods:
            return getattr(self, "p_{m}".format(m=self.get_method_name()))
        return getattr(self, "p_uncorrected")

    def set_corrected_pval(self, nt_method, pvalue):
        """Add object attribute ("p_" + nt_method.fieldname) based on method name."""
        self._methods.append(nt_method)
        fieldname = "".join(["p_", nt_method.fieldname])
        setattr(self, fieldname, pvalue)

    def __str__(self, indent=False):
        """Return the record as one tab-separated line.

        Columns are the default fields, with one corrected p-value column per
        registered method inserted before the final (study_items) column.
        If indent is True, the line is prefixed with the dots from
        get_indent_dots().
        """
        last_fld = self._fldsdefprt[-1]
        # str() each item so non-string ids (e.g. integer gene ids) can be joined
        items_str = ", ".join(str(v) for v in sorted(getattr(self, last_fld, set())))
        field_data = [getattr(self, f, "n.a.") for f in self._fldsdefprt[:-1]] + \
                     [getattr(self, "p_{}".format(m.fieldname)) for m in self._methods] + \
                     [items_str]
        fldsdeffmt = self._fldsdeffmt
        field_formatter = fldsdeffmt[:-1] + ["%.3g"]*len(self._methods) + [fldsdeffmt[-1]]
        self._chk_fields(field_data, field_formatter)

        # default formatting only works for non-"n.a" data
        for i, f in enumerate(field_data):
            if f == "n.a.":
                field_formatter[i] = "%s"

        # print dots to show the level of the term
        dots = self.get_indent_dots() if indent else ""
        prtdata = "\t".join(a % b for (a, b) in zip(field_formatter, field_data))
        return "".join([dots, prtdata])

    def get_indent_dots(self):
        """Get a string of dots ("....") representing the level of the GO term."""
        return "." * self.goterm.level if self.goterm is not None else ""

    @staticmethod
    def _chk_fields(field_data, field_formatter):
        """Check that expected fields are present; raise Exception on a length mismatch."""
        if len(field_data) == len(field_formatter):
            return
        len_dat = len(field_data)
        len_fmt = len(field_formatter)
        msg = [
            "FIELD DATA({d}) != FORMATTER({f})".format(d=len_dat, f=len_fmt),
            "DAT({N}): {D}".format(N=len_dat, D=field_data),
            "FMT({N}): {F}".format(N=len_fmt, F=field_formatter)]
        raise Exception("\n".join(msg))

    def __repr__(self):
        return "GOEnrichmentRecord({GO})".format(GO=self.GO)

    def set_goterm(self, goid):
        """Set goterm and copy GOTerm's name and namespace.

        goid is a mapping from GO id to GOTerm (it is queried with .get()).
        If this record's GO id is absent, name becomes "n.a." and NS "XX".
        """
        self.goterm = goid.get(self.GO, None)
        present = self.goterm is not None
        self.name = self.goterm.name if present else "n.a."
        self.NS = self.namespace2NS[self.goterm.namespace] if present else "XX"

    def _init_enrichment(self):
        """Mark as 'enriched' ('e') or 'purified' ('p').

        The study ratio is compared with the population ratio by
        cross-multiplication, so an empty study or population (a total of 0)
        yields 'p' instead of raising ZeroDivisionError.
        """
        enriched = self.study_count * self.pop_n > self.pop_count * self.study_n
        self.enrichment = 'e' if enriched else 'p'

    def update_remaining_fldsdefprt(self, min_ratio=None):
        """Finish updating self (GOEnrichmentRecord) field, is_ratio_different."""
        self.is_ratio_different = is_ratio_different(min_ratio, self.study_count,
                                                     self.study_n, self.pop_count, self.pop_n)

    # -------------------------------------------------------------------------------------
    # Methods for getting flat namedtuple values from GOEnrichmentRecord object
    def get_prtflds_default(self):
        """Get default fields, with "p_<method>" fields placed before the last field."""
        return self._fldsdefprt[:-1] + \
               ["p_{M}".format(M=m.fieldname) for m in self._methods] + \
               [self._fldsdefprt[-1]]

    def get_prtflds_all(self):
        """When converting to a namedtuple, get all possible fields in their original order."""
        flds = []
        dont_add = set(['_parents', '_methods'])
        # Fields: GO NS enrichment name ratio_in_study ratio_in_pop p_uncorrected
        #         depth study_count p_sm_bonferroni p_fdr_bh study_items
        self._flds_append(flds, self.get_prtflds_default(), dont_add)
        # Fields: GO NS goterm
        #         ratio_in_pop pop_n pop_count pop_items name
        #         ratio_in_study study_n study_count study_items
        #         _methods enrichment p_uncorrected p_sm_bonferroni p_fdr_bh
        self._flds_append(flds, vars(self).keys(), dont_add)
        # Fields: name level is_obsolete namespace id depth parents children _parents alt_ids
        # goterm stays None until set_goterm() finds a matching term
        if self.goterm is not None:
            self._flds_append(flds, vars(self.goterm).keys(), dont_add)
        return flds

    @staticmethod
    def _flds_append(flds, addthese, dont_add):
        """Retain order of fields as we add them once to the list."""
        for fld in addthese:
            if fld not in flds and fld not in dont_add:
                flds.append(fld)

    def get_field_values(self, fldnames, rpt_fmt=True):
        """Get flat namedtuple fields for one GOEnrichmentRecord.

        Each field is looked up on the record first and on its GOTerm second.
        If rpt_fmt is true, values go through _get_rpt_fmt().
        Raises Exception (via _err_fld) for a field found in neither place.
        """
        row = []
        # Loop through each user field desired
        for fld in fldnames:
            # 1. Check the GOEnrichmentRecord's attributes
            val = getattr(self, fld, None)
            if val is None:
                # 2. Check the GO object for the field
                val = getattr(self.goterm, fld, None)
                if val is None:
                    # 3. Field not found, raise Exception
                    self._err_fld(fld, fldnames)
            if rpt_fmt:
                val = self._get_rpt_fmt(fld, val)
                assert not isinstance(val, list), \
                   "UNEXPECTED LIST: FIELD({F}) VALUE({V}) FMT({P})".format(
                       P=rpt_fmt, F=fld, V=val)
            row.append(val)
        return row

    @staticmethod
    def _get_rpt_fmt(fld, val):
        """Return values in a format amenable to printing in a table."""
        if fld.startswith("ratio_"):
            return "{N}/{TOT}".format(N=val[0], TOT=val[1])
        elif fld in set(['study_items', 'pop_items', 'alt_ids']):
            return ", ".join([str(v) for v in sorted(val)])
        return val

    def _err_fld(self, fld, fldnames):
        """Unrecognized field. Always raises Exception with a detailed failure message."""
        msg = ['ERROR. UNRECOGNIZED FIELD({F})'.format(F=fld)]
        # list() is required on Python 3, where dict.keys() cannot be added to a list;
        # goterm may still be None if set_goterm() found no matching term.
        goterm_flds = list(self.goterm.__dict__.keys()) if self.goterm is not None else []
        actual_flds = set(self.get_prtflds_default() + goterm_flds)
        bad_flds = set(fldnames).difference(actual_flds)
        if bad_flds:
            msg.append("\nGOEA RESULT FIELDS: {}".format(" ".join(self._fldsdefprt)))
            msg.append("GO FIELDS: {}".format(" ".join(goterm_flds)))
            msg.append("\nFATAL: {N} UNEXPECTED FIELDS({F})\n".format(
                N=len(bad_flds), F=" ".join(bad_flds)))
            msg.append("  {N} User-provided fields:".format(N=len(fldnames)))
            for idx, usrfld in enumerate(fldnames, 1):
                mrk = "ERROR -->" if usrfld in bad_flds else ""
                msg.append("  {M:>9} {I:>2}) {F}".format(M=mrk, I=idx, F=usrfld))
        raise Exception("\n".join(msg))
240
241
242
class GOEnrichmentStudy(object):
    """Runs Fisher's exact test, as well as multiple corrections.

    The population, associations and GO DAG are fixed at construction;
    run_study() may then be called for one or more study sets.
    """
    # Default Excel table column widths for GOEA results
    default_fld2col_widths = {
        'NS'        :  3,
        'GO'        : 12,
        'level'     :  3,
        'enrichment':  1,
        'name'      : 60,
        'ratio_in_study':  8,
        'ratio_in_pop'  : 12,
        'study_items'   : 15,
    }

    def __init__(self, pop, assoc, obo_dag, propagate_counts=True, alpha=.05, methods=None, **kws):
        """Store the population data and prepare the p-value calculator.

        Args:
            pop: population items; len(pop) is stored as pop_n.
            assoc: associations, passed to get_terms() and, when
                propagate_counts is true, to obo_dag.update_association().
            obo_dag: GO DAG object; also used as the GO id -> GOTerm mapping.
            propagate_counts: if true, propagate term counts to parents.
            alpha: default test-wise alpha for multiple-test corrections.
            methods: correction method names; defaults to
                ["bonferroni", "sidak", "holm"].
            **kws: 'log' (stream, default sys.stdout) and 'n_cores'
                (default None) are read here; all kws are forwarded to
                FisherFactory.
        """
        self.log = kws['log'] if 'log' in kws else sys.stdout
        self.n_cores = kws['n_cores'] if 'n_cores' in kws else None
        # Dispatch table: correction-method source -> function running the correction
        self._run_multitest = {
            'local':lambda iargs: self._run_multitest_local(iargs),
            'statsmodels':lambda iargs: self._run_multitest_statsmodels(iargs)}
        self.pop = pop
        self.pop_n = len(pop)
        self.assoc = assoc
        self.obo_dag = obo_dag
        self.alpha = alpha
        if methods is None:
            methods = ["bonferroni", "sidak", "holm"]
        self.methods = Methods(methods)
        self.pval_obj = FisherFactory(**kws).pval_obj

        if propagate_counts:
            sys.stderr.write("Propagating term counts to parents ..\n")
            obo_dag.update_association(assoc)
        self.go2popitems = get_terms("population", pop, assoc, obo_dag, self.log)

    def run_study(self, study, **kws):
        """Run Gene Ontology Enrichment Study (GOEA) on study ids.

        Key-word arguments 'methods', 'alpha' and 'log' override the values
        given to the constructor; 'keep_if' filters the results.
        Returns a list of GOEnrichmentRecord objects sorted by NS, then by
        uncorrected p-value.
        """
        # Key-word arguments:
        methods = Methods(kws['methods']) if 'methods' in kws else self.methods
        alpha = kws['alpha'] if 'alpha' in kws else self.alpha
        log = kws['log'] if 'log' in kws else self.log
        # Calculate uncorrected pvalues
        results = self._get_pval_uncorr(study, log)
        if not results:
            return []

        # Do multipletest corrections on uncorrected pvalues and update results
        self._run_multitest_corr(results, methods, alpha, study, log)

        for rec in results:
            # get go term for name and level
            rec.set_goterm(self.obo_dag)

        # 'keep_if' can be used to keep only significant GO terms. Example:
        #     >>> keep_if = lambda nt: nt.p_fdr_bh < 0.05 # if results are significant
        #     >>> goea_results = goeaobj.run_study(geneids_study, keep_if=keep_if)
        if 'keep_if' in kws:
            keep_if = kws['keep_if']
            results = [r for r in results if keep_if(r)]

        # Default sort order: First, sort by BP, MF, CC. Second, sort by pval
        results.sort(key=lambda r: [r.NS, r.p_uncorrected])

        if log is not None:
            log.write("  {MSG}\n".format(MSG="\n  ".join(self.get_results_msg(results, study))))

        return results # list of GOEnrichmentRecord objects

    def run_study_nts(self, study, **kws):
        """Run GOEA on study ids. Return results as a list of namedtuples."""
        goea_results = self.run_study(study, **kws)
        return get_goea_nts_all(goea_results)

    def get_results_msg(self, results, study):
        """Return summary for GOEA results as a list of message lines."""
        # To convert msg list to string: "\n".join(msg)
        msg = []
        if results:
            stu_items, num_gos_stu = self.get_item_cnt(results, "study_items")
            pop_items, num_gos_pop = self.get_item_cnt(results, "pop_items")
            msg.append("{M:,} GO terms are associated with {N:,} of {NT:,} study items".format(
                N=len(stu_items), NT=len(set(study)), M=num_gos_stu))
            msg.append("{M:,} GO terms are associated with {N:,} of {NT:,} population items".format(
                N=len(pop_items), NT=self.pop_n, M=num_gos_pop))
        return msg

    def _get_pval_uncorr(self, study, log=sys.stdout):
        """Calculate the uncorrected pvalues for study items.

        The GO terms are split into n_cores fragments which are evaluated by
        compute_pvals() in a multiprocessing pool.  Pass log=None to suppress
        progress messages.

        Returns a list with one GOEnrichmentRecord per GO term seen in either
        the study or the population.
        """
        if log is not None:
            log.write("Calculating uncorrected p-values using {PFNC}\n".format(PFNC=self.pval_obj.name))
        go2studyitems = get_terms("study", study, self.assoc, self.obo_dag, log)
        pop_n, study_n = self.pop_n, len(study)
        allterms = set(go2studyitems.keys()).union(set(self.go2popitems.keys()))

        # -1 avoids freezing of the machine:
        if self.n_cores is None:
            n_cores = multiprocessing.cpu_count() - 1
        else:
            n_cores = self.n_cores
        # A single-core host would otherwise give 0 workers: Pool(0) raises
        # ValueError and no fragments would be created.
        n_cores = max(1, n_cores)

        if log is not None:
            log.write("use {} cores for computing pvalues\n".format(n_cores))

        allterms = list(allterms)
        fragments = [allterms[i::n_cores] for i in range(n_cores)]

        # bind arguments, so that remote_func only depends on fragment of terms to process:
        calc_pvalue = self.pval_obj.calc_pvalue
        remote_func = partial(compute_pvals, calc_pvalue=calc_pvalue, go2studyitems=go2studyitems,
                              go2popitems=self.go2popitems, study_n=study_n, pop_n=pop_n)

        # if self.pval_obj.log is a file handle, which we can not serialize, we could not transfer
        # self.pval_obj.calc_pvalue to another python process with multiprocessing.  therefore we
        # "patch" the object which will later be restored again.
        old = self.pval_obj.log
        self.pval_obj.log = None
        try:
            pool = multiprocessing.Pool(n_cores)
            try:
                # Materialize every result before the pool is shut down.
                # NOTE(review): p_map is a project helper; this assumes it returns an
                # iterable of dicts (one per fragment), as the loop below requires.
                all_p_values = list(p_map(pool, remote_func, fragments))
            finally:
                # Release the worker processes instead of leaking them on each call
                pool.terminate()
                pool.join()
        finally:
            # restore patched file handle
            self.pval_obj.log = old

        results = []

        for p_values_map in all_p_values:

            for term, p_value in p_values_map.items():

                study_items = go2studyitems.get(term, set())
                study_count = len(study_items)
                pop_items = self.go2popitems.get(term, set())
                pop_count = len(pop_items)

                one_record = GOEnrichmentRecord(
                    GO=term,
                    p_uncorrected=p_value,
                    study_items=study_items,
                    pop_items=pop_items,
                    ratio_in_study=(study_count, study_n),
                    ratio_in_pop=(pop_count, pop_n))

                results.append(one_record)

        return results

    def _run_multitest_corr(self, results, usr_methods, alpha, study, log):
        """Do multiple-test corrections on uncorrected pvalues.

        Each method in usr_methods is dispatched on its 'source' attribute;
        the corrected p-values are stored on the records in results.
        """
        assert 0 < alpha < 1, "Test-wise alpha must fall between (0, 1)"
        pvals = [r.p_uncorrected for r in results]
        NtMt = cx.namedtuple("NtMt", "results pvals alpha nt_method study")

        for nt_method in usr_methods:
            ntmt = NtMt(results, pvals, alpha, nt_method, study)
            if log is not None:
                log.write("Running multitest correction: {MSRC} {METHOD}\n".format(
                    MSRC=ntmt.nt_method.source, METHOD=ntmt.nt_method.method))
            self._run_multitest[nt_method.source](ntmt)

    def _run_multitest_statsmodels(self, ntmt):
        """Use multitest methods that have been implemented in statsmodels."""
        # Only load statsmodels if it is used
        multipletests = self.methods.get_statsmodels_multipletests()
        results = multipletests(ntmt.pvals, ntmt.alpha, ntmt.nt_method.method)
        pvals_corrected = results[1] # reject_lst, pvals_corrected, alphacSidak, alphacBonf
        self._update_pvalcorr(ntmt, pvals_corrected)

    def _run_multitest_local(self, ntmt):
        """Use multitest methods that have been implemented locally.

        Handles "bonferroni", "sidak", "holm" and "fdr"; any other method name
        leaves the records unchanged.
        """
        corrected_pvals = None
        method = ntmt.nt_method.method
        if method == "bonferroni":
            corrected_pvals = Bonferroni(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "sidak":
            corrected_pvals = Sidak(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "holm":
            corrected_pvals = HolmBonferroni(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "fdr":
            # get the empirical p-value distributions for FDR
            term_pop = getattr(self, 'term_pop', None)
            if term_pop is None:
                term_pop = count_terms(self.pop, self.assoc, self.obo_dag)
            p_val_distribution = calc_qval(len(ntmt.study),
                                           self.pop_n,
                                           self.pop, self.assoc,
                                           term_pop, self.obo_dag)
            corrected_pvals = FDR(p_val_distribution,
                                  ntmt.results, ntmt.alpha).corrected_pvals

        self._update_pvalcorr(ntmt, corrected_pvals)

    @staticmethod
    def _update_pvalcorr(ntmt, corrected_pvals):
        """Add data members to store multiple test corrections."""
        if corrected_pvals is None:
            return
        for rec, val in zip(ntmt.results, corrected_pvals):
            rec.set_corrected_pval(ntmt.nt_method, val)

    # Methods for writing results into tables: text, tab-separated, Excel spreadsheets
    def wr_txt(self, fout_txt, goea_results, prtfmt=None, **kws):
        """Print GOEA results to text file. Nothing is written if there are no results."""
        if not goea_results:
            sys.stdout.write("      0 GOEA results. NOT WRITING {FOUT}\n".format(FOUT=fout_txt))
            return
        with open(fout_txt, 'w') as prt:
            data_nts = self.prt_txt(prt, goea_results, prtfmt, **kws)
            log = self.log if self.log is not None else sys.stdout
            log.write("  {N:>5} GOEA results for {CUR:5} study items. WROTE: {F}\n".format(
                N=len(data_nts),
                CUR=len(get_study_items(goea_results)),
                F=fout_txt))

    def prt_txt(self, prt, goea_results, prtfmt=None, **kws):
        """Print GOEA results in text format. Returns the namedtuples that were printed."""
        if prtfmt is None:
            prtfmt = "{GO} {NS} {p_uncorrected:5.2e} {study_count:>5} {name}\n"
        prtfmt = self.adjust_prtfmt(prtfmt)
        prt_flds = RPT.get_fmtflds(prtfmt)
        data_nts = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.prt_txt(prt, data_nts, prtfmt, prt_flds, **kws)
        return data_nts

    def wr_xlsx(self, fout_xlsx, goea_results, **kws):
        """Write a xlsx file."""
        # kws: prt_if indent
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        xlsx_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        if 'fld2col_widths' not in kws:
            # Columns without a default width get a width of 8
            kws['fld2col_widths'] = {f:self.default_fld2col_widths.get(f, 8) for f in prt_flds}
        RPT.wr_xlsx(fout_xlsx, xlsx_data, **kws)

    def wr_tsv(self, fout_tsv, goea_results, **kws):
        """Write tab-separated table data to file"""
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        tsv_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.wr_tsv(fout_tsv, tsv_data, **kws)

    def prt_tsv(self, prt, goea_results, **kws):
        """Write tab-separated table data"""
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        tsv_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.prt_tsv(prt, tsv_data, prt_flds, **kws)

    @staticmethod
    def adjust_prtfmt(prtfmt):
        """Adjust format_strings for legal values (hyphenated names become underscored)."""
        prtfmt = prtfmt.replace("{p_holm-sidak", "{p_holm_sidak")
        prtfmt = prtfmt.replace("{p_simes-hochberg", "{p_simes_hochberg")
        return prtfmt

    @staticmethod
    def get_NS2nts(results, fldnames=None, **kws):
        """Get namedtuples of GOEA results, split into BP, MF, CC."""
        NS2nts = cx.defaultdict(list)
        nts = get_goea_nts_all(results, fldnames, **kws)
        for nt in nts:
            NS2nts[nt.NS].append(nt)
        return NS2nts

    @staticmethod
    def get_item_cnt(results, attrname="study_items"):
        """Get all study or population items (e.g., geneids).

        Returns (items, go_cnt): the union of the attrname sets over all
        results and the number of results whose set is non-empty.
        """
        items = set()
        go_cnt = 0
        for rec in results:
            if hasattr(rec, attrname):
                items_cur = getattr(rec, attrname)
                # Only count GO term if there are items in the set.
                if len(items_cur) != 0:
                    items |= items_cur
                    go_cnt += 1
        return items, go_cnt

    @staticmethod
    def get_prtflds_default(results):
        """Get default fields names. Used in printing GOEA results.

           Researchers can control which fields they want to print in the GOEA results
           or they can use the default fields.
        """
        if results:
            return results[0].get_prtflds_default()
        return []

    @staticmethod
    def print_summary(results, min_ratio=None, indent=False, pval=0.05):
        """Print summary.

        A record is printed only if its uncorrected p-value is below pval
        (when pval is not None) and its is_ratio_different flag is set.
        """
        from .version import __version__ as version

        # Header contains provenance and parameters
        print("# Generated by GOATOOLS v{0} ({1})".format(version, datetime.date.today()))
        print("# min_ratio={0} pval={1}".format(min_ratio, pval))

        # field names for output
        if results:
            print("\t".join(GOEnrichmentStudy.get_prtflds_default(results)))

        for rec in results:
            # calculate some additional statistics
            # (over_under, is_ratio_different)
            rec.update_remaining_fldsdefprt(min_ratio=min_ratio)

            if pval is not None and rec.p_uncorrected >= pval:
                continue

            if rec.is_ratio_different:
                print(rec.__str__(indent=indent))

    def wr_py_goea_results(self, fout_py, goea_results, **kws):
        """Save GOEA results into Python package containing list of namedtuples.

        kws: 'var_name' (default "goea_results"), 'docstring' (default ""),
        'sortby' (key function; default sorts by p_uncorrected).
        """
        var_name = kws.get("var_name", "goea_results")
        docstring = kws.get("docstring", "")
        sortby = kws.get("sortby", None)
        if goea_results:
            from goatools.nt_utils import wr_py_nts
            nts_goea = goea_results
            # If list has GOEnrichmentRecords or verbose namedtuples, exclude some fields.
            if hasattr(goea_results[0], "_fldsdefprt") or hasattr(goea_results[0], 'goterm'):
                # Exclude some attributes from the namedtuple when saving results
                # to a Python file because the information is redundant or verbose.
                nts_goea = get_goea_nts_prt(goea_results)
            docstring = "\n".join([docstring, "# {VER}\n\n".format(VER=self.obo_dag.version)])
            assert hasattr(nts_goea[0], '_fields')
            if sortby is None:
                sortby = lambda nt: getattr(nt, 'p_uncorrected')
            nts_goea = sorted(nts_goea, key=sortby)
            wr_py_nts(fout_py, nts_goea, docstring, var_name)
570
571
def get_study_items(goea_results):
    """Get all study items (e.g., geneids).

    Returns the union of the 'study_items' sets of all given results
    (an empty set if goea_results is empty).
    """
    study_items = set()
    for rec in goea_results:
        study_items |= rec.study_items
    return study_items
577
578
def get_goea_nts_prt(goea_results, fldnames=None, **usr_kws):
    """Return list of namedtuples removing fields which are redundant or verbose.

    Wraps get_goea_nts_all() with two defaults, each applied only if the
    caller did not supply the key-word argument:
      not_fldnames = ['goterm', 'parents', 'children', 'id']
      rpt_fmt      = True
    """
    # Work on a copy so the defaults are added to a fresh dict
    kws = usr_kws.copy()
    kws.setdefault('not_fldnames', ['goterm', 'parents', 'children', 'id'])
    kws.setdefault('rpt_fmt', True)
    return get_goea_nts_all(goea_results, fldnames, **kws)
586
587
def get_goea_nts_all(goea_results, fldnames=None, **kws):
    """Get namedtuples containing user-specified (or default) data from GOEA results.

        Reformats data from GOEnrichmentRecord objects into lists of
        namedtuples so the generic table writers may be used.

        Key-word arguments:
          keep_if       function(namedtuple) -> bool; keep only matching rows
          rpt_fmt       passed on to get_field_values (default False)
          indent        if true, prefix the GO id with dots for the term's level;
                        ignored when "GO" is not one of the output fields
          not_fldnames  field names to exclude from the namedtuples
    """
    data_nts = [] # A list of namedtuples containing GOEA results
    if not goea_results:
        return data_nts
    keep_if = kws.get('keep_if', None)
    rpt_fmt = kws.get('rpt_fmt', False)
    indent = kws.get('indent', False)
    # I. FIELD (column) NAMES
    not_fldnames = kws.get('not_fldnames', None)
    if fldnames is None:
        fldnames = get_fieldnames(goea_results[0])
    # Ia. Explicitly exclude specific fields from named tuple
    if not_fldnames is not None:
        fldnames = [f for f in fldnames if f not in not_fldnames]
    nttyp = cx.namedtuple("NtGoeaResults", " ".join(fldnames))
    goid_idx = fldnames.index("GO") if 'GO' in fldnames else None
    # Dots can only be prefixed when a GO column exists; without this guard
    # vals[None] would raise TypeError.
    add_dots = indent and goid_idx is not None
    # II. Loop through GOEA results stored in a GOEnrichmentRecord object
    for goerec in goea_results:
        vals = get_field_values(goerec, fldnames, rpt_fmt)
        if add_dots:
            vals[goid_idx] = "".join([goerec.get_indent_dots(), vals[goid_idx]])
        ntobj = nttyp._make(vals)
        if keep_if is None or keep_if(ntobj):
            data_nts.append(ntobj)
    return data_nts
617
618
def get_field_values(item, fldnames, rpt_fmt=None):
    """Return the values of fldnames from either a namedtuple or GOEnrichmentRecord.

    A GOEnrichmentRecord is recognized by its '_fldsdefprt' attribute and a
    namedtuple by its '_fields' attribute.  rpt_fmt is only used for
    GOEnrichmentRecords.  Returns None if item is neither.
    """
    if hasattr(item, "_fldsdefprt"): # Is a GOEnrichmentRecord
        return item.get_field_values(fldnames, rpt_fmt)
    if hasattr(item, "_fields"): # Is a namedtuple
        return [getattr(item, f) for f in fldnames]
    return None
624
625
def get_fieldnames(item):
    """Return fieldnames of either a namedtuple or GOEnrichmentRecord.

    A GOEnrichmentRecord (recognized by '_fldsdefprt') yields
    get_prtflds_all(); a namedtuple yields its '_fields' tuple.
    Returns None if item is neither.
    """
    if hasattr(item, "_fldsdefprt"): # Is a GOEnrichmentRecord
        return item.get_prtflds_all()
    if hasattr(item, "_fields"): # Is a namedtuple
        return item._fields
    return None
631
632
# Copyright (C) 2010-2017, H Tang et al., All rights reserved.
633