Completed
Pull Request — master (#83)
by
unknown
01:23
created

compute_pvals()   A

Complexity

Conditions 2

Size

Total Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 2
c 1
b 0
f 0
dl 0
loc 14
rs 9.4285
1
#!/usr/bin/env python
2
# -*- coding: UTF-8 -*-
3
4
"""
5
python %prog study.file population.file gene-association.file
6
7
This program returns P-values for functional enrichment in a cluster of
8
study genes using Fisher's exact test, and corrected for multiple testing
9
(including Bonferroni, Holm, Sidak, and false discovery rate)
10
"""
11
12
from __future__ import absolute_import
13
14
__copyright__ = "Copyright (C) 2010-2017, H Tang et al., All rights reserved."
15
__author__ = "various"
16
17
import sys
18
import collections as cx
19
import datetime
20
from functools import partial
21
import multiprocessing
22
23
24
from goatools.multiple_testing import Methods, Bonferroni, Sidak, HolmBonferroni, FDR, calc_qval
25
from goatools.ratio import get_terms, count_terms, is_ratio_different
26
import goatools.wr_tbl as RPT
27
from goatools.pvalcalc import FisherFactory
28
from .multiprocessing_tools import p_map
29
30
31
def compute_pvals(allterms, calc_pvalue, go2studyitems, go2popitems,
                  study_n, pop_n):
    """Compute one uncorrected p-value per GO term.

    This is a module-level function so that it can be bound with
    ``functools.partial`` and handed to worker processes.

    Args:
        allterms: iterable of GO term ids to process.
        calc_pvalue: callable invoked as
            ``calc_pvalue(study_count, study_n, pop_count, pop_n)``.
        go2studyitems: mapping of GO term id -> collection of study items.
        go2popitems: mapping of GO term id -> collection of population items.
        study_n: total number of study items.
        pop_n: total number of population items.

    Returns:
        dict mapping each term in ``allterms`` to the value returned by
        ``calc_pvalue``. A term missing from a mapping counts as 0 items.
    """
    no_items = set()
    return {
        goid: calc_pvalue(len(go2studyitems.get(goid, no_items)), study_n,
                          len(go2popitems.get(goid, no_items)), pop_n)
        for goid in allterms}
45
46
47
class GOEnrichmentRecord(object):
    """Represents one result (from a single GOTerm) in the GOEnrichmentStudy.

    Every keyword argument given to the constructor is stored as an attribute.
    The 2-tuples ``ratio_in_study`` and ``ratio_in_pop`` are additionally
    unpacked into ``study_count``/``study_n`` and ``pop_count``/``pop_n``.
    Corrected p-values are attached later as ``p_<method fieldname>``
    attributes through ``set_corrected_pval``.
    """
    # GO namespace -> two-letter abbreviation, in the default report order
    namespace2NS = cx.OrderedDict([
        ('biological_process', 'BP'),
        ('molecular_function', 'MF'),
        ('cellular_component', 'CC')])

    # Fields seen in every enrichment result
    _fldsdefprt = [
        "GO",
        "NS",
        "enrichment",
        "name",
        "ratio_in_study",
        "ratio_in_pop",
        "p_uncorrected",
        "depth",
        "study_count",
        "study_items"]
    # printf-style format for each entry of _fldsdefprt (same order, same length)
    _fldsdeffmt = ["%s"]*3 + ["%-30s"] + ["%d/%d"] * 2 + ["%.3g"] + ["%d"] * 2 + ["%15s"]

    _flds = set(_fldsdefprt).intersection(
        set(['study_items', 'study_count', 'study_n', 'pop_items', 'pop_count', 'pop_n']))

    def __init__(self, **kwargs):
        """Store all keyword arguments as attributes and classify the enrichment.

        ``ratio_in_study`` and ``ratio_in_pop`` must be supplied because
        ``_init_enrichment`` reads the counts derived from them.
        """
        # Methods seen in current enrichment result
        self._methods = []
        for k, v in kwargs.items():
            setattr(self, k, v)
            if k == 'ratio_in_study':
                setattr(self, 'study_count', v[0])
                setattr(self, 'study_n', v[1])
            if k == 'ratio_in_pop':
                setattr(self, 'pop_count', v[0])
                setattr(self, 'pop_n', v[1])
        self._init_enrichment()
        self.goterm = None  # the reference to the GOTerm

    def get_method_name(self):
        """Return name of first method in the _methods list.

        Raises IndexError if no multiple-test correction has been recorded yet.
        """
        return self._methods[0].fieldname

    def get_pvalue(self):
        """Returns pval for 1st method, if it exists. Else returns uncorrected pval."""
        if self._methods:
            return getattr(self, "p_{m}".format(m=self.get_method_name()))
        return getattr(self, "p_uncorrected")

    def set_corrected_pval(self, nt_method, pvalue):
        """Add object attribute based on method name.

        ``nt_method`` is remembered in ``_methods`` and ``pvalue`` is stored
        as attribute ``p_<nt_method.fieldname>``.
        """
        self._methods.append(nt_method)
        fieldname = "".join(["p_", nt_method.fieldname])
        setattr(self, fieldname, pvalue)

    def __str__(self, indent=False):
        """Return the record as one tab-separated line.

        Columns are the default fields, with one corrected p-value column per
        recorded method inserted before the trailing study-items column.
        Missing attributes are printed as "n.a.". If ``indent`` is true, the
        line is prefixed with the dots returned by ``get_indent_dots``.
        """
        last_fld = self._fldsdefprt[-1]
        # str() each item so that non-string ids (e.g. integer gene ids) can be joined
        items_str = ", ".join(str(v) for v in sorted(getattr(self, last_fld, set())))
        field_data = [getattr(self, f, "n.a.") for f in self._fldsdefprt[:-1]] + \
                     [getattr(self, "p_{}".format(m.fieldname)) for m in self._methods] + \
                     [items_str]
        fldsdeffmt = self._fldsdeffmt
        # List concatenation builds a new list, so the class-level formats stay untouched
        field_formatter = fldsdeffmt[:-1] + ["%.3g"]*len(self._methods) + [fldsdeffmt[-1]]
        self._chk_fields(field_data, field_formatter)

        # default formatting only works for non-"n.a" data
        for i, f in enumerate(field_data):
            if f == "n.a.":
                field_formatter[i] = "%s"

        # print dots to show the level of the term
        dots = self.get_indent_dots() if indent else ""
        prtdata = "\t".join(a % b for (a, b) in zip(field_formatter, field_data))
        return "".join([dots, prtdata])

    def get_indent_dots(self):
        """Get a string of dots ("....") representing the level of the GO term."""
        return "." * self.goterm.level if self.goterm is not None else ""

    @staticmethod
    def _chk_fields(field_data, field_formatter):
        """Check that expected fields are present.

        Raises Exception if the data and formatter lists differ in length.
        """
        if len(field_data) == len(field_formatter):
            return
        len_dat = len(field_data)
        len_fmt = len(field_formatter)
        msg = [
            "FIELD DATA({d}) != FORMATTER({f})".format(d=len_dat, f=len_fmt),
            "DAT({N}): {D}".format(N=len_dat, D=field_data),
            "FMT({N}): {F}".format(N=len_fmt, F=field_formatter)]
        raise Exception("\n".join(msg))

    def __repr__(self):
        return "GOEnrichmentRecord({GO})".format(GO=self.GO)

    def set_goterm(self, goid):
        """Set goterm and copy GOTerm's name and namespace.

        ``goid`` is a mapping (it must provide ``.get``) from GO id to GO term
        object. If this record's GO id is absent, ``name`` becomes "n.a." and
        ``NS`` becomes "XX".
        """
        self.goterm = goid.get(self.GO, None)
        present = self.goterm is not None
        self.name = self.goterm.name if present else "n.a."
        self.NS = self.namespace2NS[self.goterm.namespace] if present else "XX"

    def _init_enrichment(self):
        """Mark as 'enriched' or 'purified'.

        Sets ``enrichment`` to 'e' when study_count/study_n is greater than
        pop_count/pop_n, else 'p'. The ratios are compared by
        cross-multiplication so that an empty study or population
        (``study_n == 0`` or ``pop_n == 0``) yields 'p' instead of raising
        ZeroDivisionError.
        """
        is_enriched = (self.study_count * self.pop_n >
                       self.pop_count * self.study_n)
        self.enrichment = 'e' if is_enriched else 'p'

    def update_remaining_fldsdefprt(self, min_ratio=None):
        """Finish updating self (GOEnrichmentRecord) field, is_ratio_different."""
        self.is_ratio_different = is_ratio_different(min_ratio, self.study_count,
                                                     self.study_n, self.pop_count, self.pop_n)

    # -------------------------------------------------------------------------------------
    # Methods for getting flat namedtuple values from GOEnrichmentRecord object
    def get_prtflds_default(self):
        """Get default fields.

        Returns the default field names with one ``p_<method>`` name per
        recorded correction method inserted before the last default field.
        """
        return self._fldsdefprt[:-1] + \
               ["p_{M}".format(M=m.fieldname) for m in self._methods] + \
               [self._fldsdefprt[-1]]

    def get_prtflds_all(self):
        """When converting to a namedtuple, get all possible fields in their original order."""
        flds = []
        dont_add = set(['_parents', '_methods'])
        # Fields: GO NS enrichment name ratio_in_study ratio_in_pop p_uncorrected
        #         depth study_count p_sm_bonferroni p_fdr_bh study_items
        self._flds_append(flds, self.get_prtflds_default(), dont_add)
        # Fields: GO NS goterm
        #         ratio_in_pop pop_n pop_count pop_items name
        #         ratio_in_study study_n study_count study_items
        #         _methods enrichment p_uncorrected p_sm_bonferroni p_fdr_bh
        self._flds_append(flds, vars(self).keys(), dont_add)
        # Fields: name level is_obsolete namespace id depth parents children _parents alt_ids
        # goterm stays None until set_goterm finds the GO id; vars(None) would raise TypeError
        if self.goterm is not None:
            self._flds_append(flds, vars(self.goterm).keys(), dont_add)
        return flds

    @staticmethod
    def _flds_append(flds, addthese, dont_add):
        """Retain order of fields as we add them once to the list."""
        for fld in addthese:
            if fld not in flds and fld not in dont_add:
                flds.append(fld)

    def get_field_values(self, fldnames, rpt_fmt=True):
        """Get flat namedtuple fields for one GOEnrichmentRecord.

        Each field is looked up on this record first and on ``self.goterm``
        second. If ``rpt_fmt`` is true, values are converted with
        ``_get_rpt_fmt``.

        Returns:
            list of values, in the order of ``fldnames``.

        Raises:
            Exception: (via ``_err_fld``) if a field is found in neither place,
                or its value is None.
        """
        row = []
        # Loop through each user field desired
        for fld in fldnames:
            # 1. Check the GOEnrichmentRecord's attributes
            val = getattr(self, fld, None)
            if val is None:
                # 2. Check the GO object for the field
                val = getattr(self.goterm, fld, None)
                if val is None:
                    # 3. Field not found, raise Exception.
                    #    This must happen before any report formatting:
                    #    _get_rpt_fmt cannot handle None for ratio/item fields.
                    self._err_fld(fld, fldnames, row)
            if rpt_fmt:
                val = self._get_rpt_fmt(fld, val)
                assert not isinstance(val, list), \
                   "UNEXPECTED LIST: FIELD({F}) VALUE({V}) FMT({P})".format(
                       P=rpt_fmt, F=fld, V=val)
            row.append(val)
        return row

    @staticmethod
    def _get_rpt_fmt(fld, val):
        """Return values in a format amenable to printing in a table.

        ``ratio_*`` fields become "N/TOT"; item collections become a sorted,
        comma-separated string; any other value is returned unchanged.
        """
        if fld.startswith("ratio_"):
            return "{N}/{TOT}".format(N=val[0], TOT=val[1])
        elif fld in ('study_items', 'pop_items', 'alt_ids'):
            return ", ".join([str(v) for v in sorted(val)])
        return val

    def _err_fld(self, fld, fldnames, row=None):
        """Unrecognized field. Print detailed Failure message.

        Args:
            fld: the field name that could not be resolved.
            fldnames: all field names requested by the user.
            row: optional list of values already resolved before the failure;
                only its length is reported.

        Raises:
            Exception: always, carrying the diagnostic message.
        """
        msg = ['ERROR. UNRECOGNIZED FIELD({F})'.format(F=fld)]
        if row is not None:
            msg.append("  {N} fields resolved before failure".format(N=len(row)))
        # goterm may be None; list() is needed since dict.keys() is a view on Python 3
        go_flds = list(getattr(self.goterm, '__dict__', {}).keys())
        actual_flds = set(self.get_prtflds_default() + go_flds)
        bad_flds = set(fldnames).difference(actual_flds)
        if bad_flds:
            msg.append("\nGOEA RESULT FIELDS: {}".format(" ".join(self._fldsdefprt)))
            msg.append("GO FIELDS: {}".format(" ".join(go_flds)))
            msg.append("\nFATAL: {N} UNEXPECTED FIELDS({F})\n".format(
                N=len(bad_flds), F=" ".join(sorted(bad_flds))))
            msg.append("  {N} User-provided fields:".format(N=len(fldnames)))
            for idx, usr_fld in enumerate(fldnames, 1):
                mrk = "ERROR -->" if usr_fld in bad_flds else ""
                msg.append("  {M:>9} {I:>2}) {F}".format(M=mrk, I=idx, F=usr_fld))
        raise Exception("\n".join(msg))
240
241
242
class GOEnrichmentStudy(object):
    """Runs Fisher's exact test, as well as multiple corrections.

    The population, the associations and the GO DAG are fixed when the object
    is created; ``run_study`` can then be called for any number of study sets.
    """
    # Default Excel table column widths for GOEA results
    default_fld2col_widths = {
        'NS'        :  3,
        'GO'        : 12,
        'level'     :  3,
        'enrichment':  1,
        'name'      : 60,
        'ratio_in_study':  8,
        'ratio_in_pop'  : 12,
        'study_items'   : 15,
    }

    def __init__(self, pop, assoc, obo_dag, propagate_counts=True, alpha=.05, methods=None, **kws):
        """Store the population data and precompute GO term -> population items.

        Args:
            pop: collection of population items; its length is used as pop_n.
            assoc: associations, passed through to ``get_terms``.
            obo_dag: GO DAG; also used as a GO id -> GO term mapping.
            propagate_counts: if true, ``obo_dag.update_association(assoc)``
                is called before population terms are collected.
            alpha: default test-wise alpha for multiple-test corrections.
            methods: correction method names; defaults to
                ["bonferroni", "sidak", "holm"].
            **kws: ``log`` (stream, default sys.stdout), ``n_cores`` (number of
                worker processes, default: detected automatically). All
                keywords are also forwarded to ``FisherFactory``.
        """
        self.log = kws['log'] if 'log' in kws else sys.stdout
        self.n_cores = kws['n_cores'] if 'n_cores' in kws else None
        # Dispatch table: method source -> function running that correction
        self._run_multitest = {
            'local':lambda iargs: self._run_multitest_local(iargs),
            'statsmodels':lambda iargs: self._run_multitest_statsmodels(iargs)}
        self.pop = pop
        self.pop_n = len(pop)
        self.assoc = assoc
        self.obo_dag = obo_dag
        self.alpha = alpha
        if methods is None:
            methods = ["bonferroni", "sidak", "holm"]
        self.methods = Methods(methods)
        self.pval_obj = FisherFactory(**kws).pval_obj

        if propagate_counts:
            sys.stderr.write("Propagating term counts to parents ..\n")
            obo_dag.update_association(assoc)
        self.go2popitems = get_terms("population", pop, assoc, obo_dag, self.log)

    def run_study(self, study, **kws):
        """Run Gene Ontology Enrichment Study (GOEA) on study ids.

        Keyword arguments ``methods``, ``alpha`` and ``log`` override the
        values given at construction; ``keep_if`` is an optional predicate
        used to filter the results.

        Returns:
            list of GOEnrichmentRecord objects sorted by (NS, p_uncorrected);
            an empty list if no p-values could be calculated.
        """
        # Key-word arguments:
        methods = Methods(kws['methods']) if 'methods' in kws else self.methods
        alpha = kws['alpha'] if 'alpha' in kws else self.alpha
        log = kws['log'] if 'log' in kws else self.log
        # Calculate uncorrected pvalues
        results = self._get_pval_uncorr(study, log)
        if not results:
            return []

        # Do multipletest corrections on uncorrected pvalues and update results
        self._run_multitest_corr(results, methods, alpha, study, log)

        for rec in results:
            # get go term for name and level
            rec.set_goterm(self.obo_dag)

        # 'keep_if' can be used to keep only significant GO terms. Example:
        #     >>> keep_if = lambda nt: nt.p_fdr_bh < 0.05 # if results are significant
        #     >>> goea_results = goeaobj.run_study(geneids_study, keep_if=keep_if)
        if 'keep_if' in kws:
            keep_if = kws['keep_if']
            results = [r for r in results if keep_if(r)]

        # Default sort order: First, sort by BP, MF, CC. Second, sort by pval
        results.sort(key=lambda r: [r.NS, r.p_uncorrected])

        if log is not None:
            log.write("  {MSG}\n".format(MSG="\n  ".join(self.get_results_msg(results, study))))

        return results # list of GOEnrichmentRecord objects

    def run_study_nts(self, study, **kws):
        """Run GOEA on study ids. Return results as a list of namedtuples."""
        goea_results = self.run_study(study, **kws)
        return get_goea_nts_all(goea_results)

    def get_results_msg(self, results, study):
        """Return summary for GOEA results.

        Returns a list of message lines; it is empty when ``results`` is empty.
        """
        # To convert msg list to string: "\n".join(msg)
        msg = []
        if results:
            stu_items, num_gos_stu = self.get_item_cnt(results, "study_items")
            pop_items, num_gos_pop = self.get_item_cnt(results, "pop_items")
            msg.append("{M:,} GO terms are associated with {N:,} of {NT:,} study items".format(
                N=len(stu_items), NT=len(set(study)), M=num_gos_stu))
            msg.append("{M:,} GO terms are associated with {N:,} of {NT:,} population items".format(
                N=len(pop_items), NT=self.pop_n, M=num_gos_pop))
        return msg

    def _get_n_cores(self):
        """Return the number of worker processes to start.

        A user-supplied ``n_cores`` is returned unchanged. Otherwise one CPU
        is left free (this avoids freezing the machine), but at least one
        worker is always used so a single-CPU machine does not end up with
        zero workers.
        """
        if self.n_cores is not None:
            return self.n_cores
        try:
            n_cpus = multiprocessing.cpu_count()
        except NotImplementedError:
            n_cpus = 1
        return max(1, n_cpus - 1)

    def _get_pval_uncorr(self, study, log=sys.stdout):
        """Calculate the uncorrected pvalues for study items.

        The GO terms are split into ``n_cores`` fragments and the p-values of
        each fragment are computed by ``compute_pvals`` in a worker process.

        Returns:
            list of GOEnrichmentRecord, one per GO term found in the study or
            in the population.
        """
        if log is not None:
            log.write("Calculating uncorrected p-values using {PFNC}\n".format(PFNC=self.pval_obj.name))
        go2studyitems = get_terms("study", study, self.assoc, self.obo_dag, log)
        pop_n, study_n = self.pop_n, len(study)
        allterms = list(set(go2studyitems.keys()).union(set(self.go2popitems.keys())))

        n_cores = self._get_n_cores()
        # Round-robin split: fragment i receives terms i, i+n_cores, i+2*n_cores, ...
        fragments = [allterms[i::n_cores] for i in range(n_cores)]

        # bind arguments, so that remote_func only depends on fragment of terms to process:
        calc_pvalue = self.pval_obj.calc_pvalue
        remote_func = partial(compute_pvals, calc_pvalue=calc_pvalue, go2studyitems=go2studyitems,
                              go2popitems=self.go2popitems, study_n=study_n, pop_n=pop_n)

        # if self.pval_obj.log is a file handle, which we can not serialize, we could not transfer
        # self.pval_obj.calc_pvalue to another python process with multiprocessing.  therefore we
        # "patch" the object which will later be restored again.
        old = self.pval_obj.log
        self.pval_obj.log = None
        try:
            # The pool is created inside the try so that a failure here still restores the log
            pool = multiprocessing.Pool(n_cores)
            try:
                # NOTE(review): p_map is assumed to return once all fragments
                # have been processed -- confirm in multiprocessing_tools.
                all_p_values = p_map(pool, remote_func, fragments)
            except BaseException:
                # Do not wait for outstanding work after a failure or interrupt
                pool.terminate()
                raise
            else:
                pool.close()
            finally:
                # Wait for the workers to exit so no processes are leaked
                pool.join()
        finally:
            # restore patched file handle
            self.pval_obj.log = old

        results = []
        for p_values_map in all_p_values:
            for term, p_value in p_values_map.items():
                study_items = go2studyitems.get(term, set())
                study_count = len(study_items)
                pop_items = self.go2popitems.get(term, set())
                pop_count = len(pop_items)

                one_record = GOEnrichmentRecord(
                    GO=term,
                    p_uncorrected=p_value,
                    study_items=study_items,
                    pop_items=pop_items,
                    ratio_in_study=(study_count, study_n),
                    ratio_in_pop=(pop_count, pop_n))

                results.append(one_record)

        return results

    def _run_multitest_corr(self, results, usr_methods, alpha, study, log):
        """Do multiple-test corrections on uncorrected pvalues.

        For each method in ``usr_methods`` the function registered in
        ``self._run_multitest`` for the method's source is called; it stores
        the corrected p-values on the records in ``results``.
        """
        assert 0 < alpha < 1, "Test-wise alpha must fall between (0, 1)"
        pvals = [r.p_uncorrected for r in results]
        NtMt = cx.namedtuple("NtMt", "results pvals alpha nt_method study")

        for nt_method in usr_methods:
            ntmt = NtMt(results, pvals, alpha, nt_method, study)
            if log is not None:
                log.write("Running multitest correction: {MSRC} {METHOD}\n".format(
                    MSRC=ntmt.nt_method.source, METHOD=ntmt.nt_method.method))
            self._run_multitest[nt_method.source](ntmt)

    def _run_multitest_statsmodels(self, ntmt):
        """Use multitest mthods that have been implemented in statsmodels."""
        # Only load statsmodels if it is used
        multipletests = self.methods.get_statsmodels_multipletests()
        results = multipletests(ntmt.pvals, ntmt.alpha, ntmt.nt_method.method)
        pvals_corrected = results[1] # reject_lst, pvals_corrected, alphacSidak, alphacBonf
        self._update_pvalcorr(ntmt, pvals_corrected)

    def _run_multitest_local(self, ntmt):
        """Use multitest mthods that have been implemented locally.

        Handles "bonferroni", "sidak", "holm" and "fdr". For any other method
        name no corrected p-values are stored.
        """
        corrected_pvals = None
        method = ntmt.nt_method.method
        if method == "bonferroni":
            corrected_pvals = Bonferroni(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "sidak":
            corrected_pvals = Sidak(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "holm":
            corrected_pvals = HolmBonferroni(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "fdr":
            # get the empirical p-value distributions for FDR
            term_pop = getattr(self, 'term_pop', None)
            if term_pop is None:
                term_pop = count_terms(self.pop, self.assoc, self.obo_dag)
            p_val_distribution = calc_qval(len(ntmt.study),
                                           self.pop_n,
                                           self.pop, self.assoc,
                                           term_pop, self.obo_dag)
            corrected_pvals = FDR(p_val_distribution,
                                  ntmt.results, ntmt.alpha).corrected_pvals

        self._update_pvalcorr(ntmt, corrected_pvals)

    @staticmethod
    def _update_pvalcorr(ntmt, corrected_pvals):
        """Add data members to store multiple test corrections.

        Does nothing when ``corrected_pvals`` is None.
        """
        if corrected_pvals is None:
            return
        for rec, val in zip(ntmt.results, corrected_pvals):
            rec.set_corrected_pval(ntmt.nt_method, val)

    # Methods for writing results into tables: text, tab-separated, Excel spreadsheets
    def wr_txt(self, fout_txt, goea_results, prtfmt=None, **kws):
        """Print GOEA results to text file.

        No file is written when ``goea_results`` is empty.
        """
        if not goea_results:
            sys.stdout.write("      0 GOEA results. NOT WRITING {FOUT}\n".format(FOUT=fout_txt))
            return
        with open(fout_txt, 'w') as prt:
            data_nts = self.prt_txt(prt, goea_results, prtfmt, **kws)
            log = self.log if self.log is not None else sys.stdout
            log.write("  {N:>5} GOEA results for {CUR:5} study items. WROTE: {F}\n".format(
                N=len(data_nts),
                CUR=len(get_study_items(goea_results)),
                F=fout_txt))

    def prt_txt(self, prt, goea_results, prtfmt=None, **kws):
        """Print GOEA results in text format.

        Returns the list of namedtuples that was printed.
        """
        if prtfmt is None:
            prtfmt = "{GO} {NS} {p_uncorrected:5.2e} {study_count:>5} {name}\n"
        prtfmt = self.adjust_prtfmt(prtfmt)
        prt_flds = RPT.get_fmtflds(prtfmt)
        data_nts = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.prt_txt(prt, data_nts, prtfmt, prt_flds, **kws)
        return data_nts

    def wr_xlsx(self, fout_xlsx, goea_results, **kws):
        """Write a xlsx file."""
        # kws: prt_if indent
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        xlsx_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        if 'fld2col_widths' not in kws:
            # Columns without a default width get width 8
            kws['fld2col_widths'] = {f:self.default_fld2col_widths.get(f, 8) for f in prt_flds}
        RPT.wr_xlsx(fout_xlsx, xlsx_data, **kws)

    def wr_tsv(self, fout_tsv, goea_results, **kws):
        """Write tab-separated table data to file"""
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        tsv_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.wr_tsv(fout_tsv, tsv_data, **kws)

    def prt_tsv(self, prt, goea_results, **kws):
        """Write tab-separated table data"""
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        tsv_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.prt_tsv(prt, tsv_data, prt_flds, **kws)

    @staticmethod
    def adjust_prtfmt(prtfmt):
        """Adjust format_strings for legal values.

        Replaces the hyphen in "{p_holm-sidak" and "{p_simes-hochberg" with an
        underscore.
        """
        prtfmt = prtfmt.replace("{p_holm-sidak", "{p_holm_sidak")
        prtfmt = prtfmt.replace("{p_simes-hochberg", "{p_simes_hochberg")
        return prtfmt

    @staticmethod
    def get_NS2nts(results, fldnames=None, **kws):
        """Get namedtuples of GOEA results, split into BP, MF, CC."""
        NS2nts = cx.defaultdict(list)
        nts = get_goea_nts_all(results, fldnames, **kws)
        for nt in nts:
            NS2nts[nt.NS].append(nt)
        return NS2nts

    @staticmethod
    def get_item_cnt(results, attrname="study_items"):
        """Get all study or population items (e.g., geneids).

        Returns:
            (items, go_cnt): the union of the ``attrname`` sets of all results
            and the number of results whose set is not empty.
        """
        items = set()
        go_cnt = 0
        for rec in results:
            if hasattr(rec, attrname):
                items_cur = getattr(rec, attrname)
                # Only count GO term if there are items in the set.
                if len(items_cur) != 0:
                    items |= items_cur
                    go_cnt += 1
        return items, go_cnt

    @staticmethod
    def get_prtflds_default(results):
        """Get default fields names. Used in printing GOEA results.

           Researchers can control which fields they want to print in the GOEA results
           or they can use the default fields.
        """
        if results:
            return results[0].get_prtflds_default()
        return []

    @staticmethod
    def print_summary(results, min_ratio=None, indent=False, pval=0.05):
        """Print summary.

        Prints a two-line header, the field names and then every record whose
        uncorrected p-value is below ``pval`` (unless ``pval`` is None) and
        whose ``is_ratio_different`` is true.
        """
        from .version import __version__ as version

        # Header contains provenance and parameters
        print("# Generated by GOATOOLS v{0} ({1})".format(version, datetime.date.today()))
        print("# min_ratio={0} pval={1}".format(min_ratio, pval))

        # field names for output
        if results:
            print("\t".join(GOEnrichmentStudy.get_prtflds_default(results)))

        for rec in results:
            # calculate some additional statistics
            # (over_under, is_ratio_different)
            rec.update_remaining_fldsdefprt(min_ratio=min_ratio)

            if pval is not None and rec.p_uncorrected >= pval:
                continue

            if rec.is_ratio_different:
                print(rec.__str__(indent=indent))

    def wr_py_goea_results(self, fout_py, goea_results, **kws):
        """Save GOEA results into Python package containing list of namedtuples.

        Keyword arguments: ``var_name`` (default "goea_results"),
        ``docstring`` (default "") and ``sortby`` (key function; default sorts
        by p_uncorrected). Nothing is written when ``goea_results`` is empty.
        """
        var_name = kws.get("var_name", "goea_results")
        docstring = kws.get("docstring", "")
        sortby = kws.get("sortby", None)
        if goea_results:
            from goatools.nt_utils import wr_py_nts
            nts_goea = goea_results
            # If list has GOEnrichmentRecords or verbose namedtuples, exclude some fields.
            if hasattr(goea_results[0], "_fldsdefprt") or hasattr(goea_results[0], 'goterm'):
                # Exclude some attributes from the namedtuple when saving results
                # to a Python file because the information is redundant or verbose.
                nts_goea = get_goea_nts_prt(goea_results)
            docstring = "\n".join([docstring, "# {VER}\n\n".format(VER=self.obo_dag.version)])
            assert hasattr(nts_goea[0], '_fields')
            if sortby is None:
                sortby = lambda nt: getattr(nt, 'p_uncorrected')
            nts_goea = sorted(nts_goea, key=sortby)
            wr_py_nts(fout_py, nts_goea, docstring, var_name)
568
569
def get_study_items(goea_results):
    """Return the union of the ``study_items`` sets of all GOEA results.

    An empty set is returned when ``goea_results`` is empty.
    """
    all_items = set()
    for goerec in goea_results:
        all_items |= goerec.study_items
    return all_items
575
576
def get_goea_nts_prt(goea_results, fldnames=None, **usr_kws):
    """Return list of namedtuples removing fields which are redundant or verbose.

    Thin wrapper around ``get_goea_nts_all``. Unless the caller supplies them,
    ``not_fldnames`` defaults to ['goterm', 'parents', 'children', 'id'] and
    ``rpt_fmt`` defaults to True. The caller's keyword dict is not modified.
    """
    kws = dict(usr_kws)
    kws.setdefault('not_fldnames', ['goterm', 'parents', 'children', 'id'])
    kws.setdefault('rpt_fmt', True)
    return get_goea_nts_all(goea_results, fldnames, **kws)
584
585
def get_goea_nts_all(goea_results, fldnames=None, **kws):
    """Get namedtuples containing user-specified (or default) data from GOEA results.

        Reformats data from GOEnrichmentRecord objects into lists of
        namedtuples so the generic table writers may be used.

        Keyword arguments:
            keep_if: optional predicate applied to each namedtuple; only
                namedtuples for which it returns true are kept.
            rpt_fmt: if true, values are converted to a printable format
                (default False).
            indent: if true, the GO id is prefixed with the dots returned by
                the record's ``get_indent_dots``. Ignored when no 'GO' field
                is requested or the item has no ``get_indent_dots`` method.
            not_fldnames: optional field names to exclude.

        Returns a list of namedtuples; it is empty when ``goea_results`` is empty.
    """
    data_nts = [] # A list of namedtuples containing GOEA results
    if not goea_results:
        return data_nts
    keep_if = kws.get('keep_if', None)
    rpt_fmt = kws.get('rpt_fmt', False)
    indent = kws.get('indent', False)
    # I. FIELD (column) NAMES
    not_fldnames = kws.get('not_fldnames', None)
    if fldnames is None:
        fldnames = get_fieldnames(goea_results[0])
    # Ia. Explicitly exclude specific fields from named tuple
    if not_fldnames is not None:
        fldnames = [f for f in fldnames if f not in not_fldnames]
    nttyp = cx.namedtuple("NtGoeaResults", " ".join(fldnames))
    goid_idx = fldnames.index("GO") if 'GO' in fldnames else None
    # Indenting needs a GO column to prefix; without one, vals[None] would raise TypeError
    do_indent = indent and goid_idx is not None
    # II. Loop through GOEA results stored in a GOEnrichmentRecord object
    for goerec in goea_results:
        vals = get_field_values(goerec, fldnames, rpt_fmt)
        # Plain namedtuple items have no get_indent_dots; leave their GO id as-is
        if do_indent and hasattr(goerec, 'get_indent_dots'):
            vals[goid_idx] = "".join([goerec.get_indent_dots(), vals[goid_idx]])
        ntobj = nttyp._make(vals)
        if keep_if is None or keep_if(ntobj):
            data_nts.append(ntobj)
    return data_nts
615
616
def get_field_values(item, fldnames, rpt_fmt=None):
    """Return the values of ``fldnames`` from a namedtuple or GOEnrichmentRecord.

    A GOEnrichmentRecord (recognized by its ``_fldsdefprt`` attribute) is asked
    for its values through its own ``get_field_values``; for a namedtuple
    (recognized by ``_fields``) the attributes are read directly and
    ``rpt_fmt`` is not used. Any other object yields None.
    """
    is_goea_record = hasattr(item, "_fldsdefprt")
    if is_goea_record:
        return item.get_field_values(fldnames, rpt_fmt)
    if not hasattr(item, "_fields"):
        return None
    return [getattr(item, name) for name in fldnames]
622
623
def get_fieldnames(item):
    """Return the field names of a namedtuple or GOEnrichmentRecord.

    A GOEnrichmentRecord (recognized by its ``_fldsdefprt`` attribute) reports
    the result of ``get_prtflds_all``; a namedtuple reports its ``_fields``.
    Any other object yields None.
    """
    is_goea_record = hasattr(item, "_fldsdefprt")
    if is_goea_record:
        return item.get_prtflds_all()
    if not hasattr(item, "_fields"):
        return None
    return item._fields
629
630
# Copyright (C) 2010-2017, H Tang et al., All rights reserved.
631