Completed
Pull Request — master (#81)
by
unknown
08:27 queued 02:48
created

compute_pvals()   A

Complexity

Conditions 2

Size

Total Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 2
c 1
b 0
f 0
dl 0
loc 17
rs 9.4285
1
#!/usr/bin/env python
2
# -*- coding: UTF-8 -*-
3
4
"""
5
python %prog study.file population.file gene-association.file
6
7
This program returns P-values for functional enrichment in a cluster of
8
study genes using Fisher's exact test, corrected for multiple testing
9
(including Bonferroni, Holm, Sidak, and false discovery rate)
10
"""
11
12
from __future__ import absolute_import
13
14
__copyright__ = "Copyright (C) 2010-2017, H Tang et al., All rights reserved."
15
__author__ = "various"
16
17
import sys
18
import collections as cx
19
import datetime
20
from functools import partial
21
import multiprocessing
22
23
24
from goatools.multiple_testing import Methods, Bonferroni, Sidak, HolmBonferroni, FDR, calc_qval
25
from goatools.ratio import get_terms, count_terms, is_ratio_different
26
import goatools.wr_tbl as RPT
27
from goatools.pvalcalc import FisherFactory
28
29
30
31
32
def compute_pvals(allterms, calc_pvalue, go2studyitems, go2popitems,
                  study_n, pop_n):
    """Calculate uncorrected p-values for a batch of GO terms.

    This function is handed to multiprocessing.Pool.map (via functools.partial)
    by GOEnrichmentStudy, so it must stay a picklable, module-level function
    and must not write to stdout.

    Args:
        allterms: iterable of GO IDs to evaluate.
        calc_pvalue: callable(study_count, study_n, pop_count, pop_n) -> p-value.
        go2studyitems: dict mapping GO ID -> collection of study items.
        go2popitems: dict mapping GO ID -> collection of population items.
        study_n: total number of study items.
        pop_n: total number of population items.

    Returns:
        dict mapping each GO ID in allterms to its uncorrected p-value.
    """
    results = {}
    for term in allterms:
        # A term missing from either mapping simply has a count of zero there.
        study_count = len(go2studyitems.get(term, ()))
        pop_count = len(go2popitems.get(term, ()))
        results[term] = calc_pvalue(study_count, study_n, pop_count, pop_n)
    return results
49
50
51
class GOEnrichmentRecord(object):
    """Represents one result (from a single GOTerm) in the GOEnrichmentStudy
    """
    namespace2NS = cx.OrderedDict([
        ('biological_process', 'BP'),
        ('molecular_function', 'MF'),
        ('cellular_component', 'CC')])

    # Fields seen in every enrichment result
    _fldsdefprt = [
        "GO",
        "NS",
        "enrichment",
        "name",
        "ratio_in_study",
        "ratio_in_pop",
        "p_uncorrected",
        "depth",
        "study_count",
        "study_items"]
    # printf-style formats, one per entry in _fldsdefprt (used by __str__)
    _fldsdeffmt = ["%s"]*3 + ["%-30s"] + ["%d/%d"] * 2 + ["%.3g"] + ["%d"] * 2 + ["%15s"]

    _flds = set(_fldsdefprt).intersection(
        set(['study_items', 'study_count', 'study_n', 'pop_items', 'pop_count', 'pop_n']))

    def __init__(self, **kwargs):
        """Store every keyword argument as an attribute.

        'ratio_in_study' and 'ratio_in_pop' are (count, total) tuples; they are
        additionally unpacked into study_count/study_n and pop_count/pop_n.
        Both ratios are required because the enrichment direction is computed
        from them at construction time.
        """
        # Methods seen in current enrichment result
        self._methods = []
        for k, v in kwargs.items():
            setattr(self, k, v)
            if k == 'ratio_in_study':
                setattr(self, 'study_count', v[0])
                setattr(self, 'study_n', v[1])
            if k == 'ratio_in_pop':
                setattr(self, 'pop_count', v[0])
                setattr(self, 'pop_n', v[1])
        self._init_enrichment()
        self.goterm = None  # the reference to the GOTerm

    def get_method_name(self):
        """Return name of first method in the _methods list."""
        return self._methods[0].fieldname

    def get_pvalue(self):
        """Returns pval for 1st method, if it exists. Else returns uncorrected pval."""
        if self._methods:
            return getattr(self, "p_{m}".format(m=self.get_method_name()))
        return getattr(self, "p_uncorrected")

    def set_corrected_pval(self, nt_method, pvalue):
        """Add object attribute based on method name."""
        self._methods.append(nt_method)
        fieldname = "".join(["p_", nt_method.fieldname])
        setattr(self, fieldname, pvalue)

    def __str__(self, indent=False):
        """Return the default fields plus corrected p-values as one tab-separated line."""
        field_data = [getattr(self, f, "n.a.") for f in self._fldsdefprt[:-1]] + \
                     [getattr(self, "p_{}".format(m.fieldname)) for m in self._methods] + \
                     [", ".join(sorted(getattr(self, self._fldsdefprt[-1], set())))]
        fldsdeffmt = self._fldsdeffmt
        field_formatter = fldsdeffmt[:-1] + ["%.3g"]*len(self._methods) + [fldsdeffmt[-1]]
        self._chk_fields(field_data, field_formatter)

        # default formatting only works for non-"n.a" data
        for i, f in enumerate(field_data):
            if f == "n.a.":
                field_formatter[i] = "%s"

        # print dots to show the level of the term
        dots = self.get_indent_dots() if indent else ""
        prtdata = "\t".join(a % b for (a, b) in zip(field_formatter, field_data))
        return "".join([dots, prtdata])

    def get_indent_dots(self):
        """Get a string of dots ("....") representing the level of the GO term."""
        return "." * self.goterm.level if self.goterm is not None else ""

    @staticmethod
    def _chk_fields(field_data, field_formatter):
        """Check that expected fields are present."""
        if len(field_data) == len(field_formatter):
            return
        len_dat = len(field_data)
        len_fmt = len(field_formatter)
        msg = [
            "FIELD DATA({d}) != FORMATTER({f})".format(d=len_dat, f=len_fmt),
            "DAT({N}): {D}".format(N=len_dat, D=field_data),
            "FMT({N}): {F}".format(N=len_fmt, F=field_formatter)]
        raise Exception("\n".join(msg))

    def __repr__(self):
        return "GOEnrichmentRecord({GO})".format(GO=self.GO)

    def set_goterm(self, goid):
        """Set goterm and copy GOTerm's name and namespace."""
        self.goterm = goid.get(self.GO, None)
        present = self.goterm is not None
        self.name = self.goterm.name if present else "n.a."
        self.NS = self.namespace2NS[self.goterm.namespace] if present else "XX"

    def _init_enrichment(self):
        """Mark as 'enriched' or 'purified'."""
        self.enrichment = 'e' if ((1.0 * self.study_count / self.study_n) >
                                  (1.0 * self.pop_count / self.pop_n)) else 'p'

    def update_remaining_fldsdefprt(self, min_ratio=None):
        """Finish updating self (GOEnrichmentRecord) field, is_ratio_different."""
        self.is_ratio_different = is_ratio_different(min_ratio, self.study_count,
                                                     self.study_n, self.pop_count, self.pop_n)


    # -------------------------------------------------------------------------------------
    # Methods for getting flat namedtuple values from GOEnrichmentRecord object
    def get_prtflds_default(self):
        """Get default fields."""
        return self._fldsdefprt[:-1] + \
               ["p_{M}".format(M=m.fieldname) for m in self._methods] + \
               [self._fldsdefprt[-1]]

    def get_prtflds_all(self):
        """When converting to a namedtuple, get all possible fields in their original order."""
        flds = []
        dont_add = set(['_parents', '_methods'])
        # Fields: GO NS enrichment name ratio_in_study ratio_in_pop p_uncorrected
        #         depth study_count p_sm_bonferroni p_fdr_bh study_items
        self._flds_append(flds, self.get_prtflds_default(), dont_add)
        # Fields: GO NS goterm
        #         ratio_in_pop pop_n pop_count pop_items name
        #         ratio_in_study study_n study_count study_items
        #         _methods enrichment p_uncorrected p_sm_bonferroni p_fdr_bh
        self._flds_append(flds, vars(self).keys(), dont_add)
        # Fields: name level is_obsolete namespace id depth parents children _parents alt_ids
        # Only available once set_goterm() has attached a GO term.
        if self.goterm is not None:
            self._flds_append(flds, vars(self.goterm).keys(), dont_add)
        return flds

    @staticmethod
    def _flds_append(flds, addthese, dont_add):
        """Retain order of fields as we add them once to the list."""
        for fld in addthese:
            if fld not in flds and fld not in dont_add:
                flds.append(fld)

    def get_field_values(self, fldnames, rpt_fmt=True):
        """Get flat namedtuple fields for one GOEnrichmentRecord.

        Each field is looked up first on this record, then on its GO term.
        When rpt_fmt is true, values are formatted for table printing.

        Raises:
            Exception: if a field is None/absent on both the record and the GO term.
        """
        row = []
        # Loop through each user field desired
        for fld in fldnames:
            # 1. Check the GOEnrichmentRecord's attributes
            val = getattr(self, fld, None)
            if val is None:
                # 2. Check the GO object for the field
                val = getattr(self.goterm, fld, None)
            if val is None:
                # 3. Field not found, raise Exception.
                # This check precedes formatting: _get_rpt_fmt cannot handle None.
                self._err_fld(fld, fldnames, row)
            if rpt_fmt:
                val = self._get_rpt_fmt(fld, val)
                assert not isinstance(val, list), \
                   "UNEXPECTED LIST: FIELD({F}) VALUE({V}) FMT({P})".format(
                       P=rpt_fmt, F=fld, V=val)
            row.append(val)
        return row

    @staticmethod
    def _get_rpt_fmt(fld, val):
        """Return values in a format amenable to printing in a table."""
        if fld.startswith("ratio_"):
            return "{N}/{TOT}".format(N=val[0], TOT=val[1])
        elif fld in set(['study_items', 'pop_items', 'alt_ids']):
            return ", ".join([str(v) for v in sorted(val)])
        return val

    def _err_fld(self, fld, fldnames, row=None):
        """Unrecognized field. Raise an Exception with a detailed failure message.

        Args:
            fld: the field that could not be found.
            fldnames: all user-requested field names.
            row: values already collected before the failure (optional).
        """
        msg = ['ERROR. UNRECOGNIZED FIELD({F})'.format(F=fld)]
        # list() is needed: on Python 3 dict.keys() is a view and cannot be
        # concatenated to a list. The GO term may not be set yet.
        go_flds = list(vars(self.goterm).keys()) if self.goterm is not None else []
        actual_flds = set(self.get_prtflds_default()).union(vars(self), go_flds)
        bad_flds = set(fldnames).difference(actual_flds)
        if bad_flds:
            msg.append("\nGOEA RESULT FIELDS: {}".format(" ".join(self._fldsdefprt)))
            msg.append("GO FIELDS: {}".format(" ".join(go_flds)))
            msg.append("\nFATAL: {N} UNEXPECTED FIELDS({F})\n".format(
                N=len(bad_flds), F=" ".join(bad_flds)))
            msg.append("  {N} User-provided fields:".format(N=len(fldnames)))
            for idx, usr_fld in enumerate(fldnames, 1):
                mrk = "ERROR -->" if usr_fld in bad_flds else ""
                msg.append("  {M:>9} {I:>2}) {F}".format(M=mrk, I=idx, F=usr_fld))
        if row:
            msg.append("  VALUES READ BEFORE ERROR: {R}".format(R=row))
        raise Exception("\n".join(msg))
244
245
246
class GOEnrichmentStudy(object):
    """Runs Fisher's exact test, as well as multiple corrections
    """
    # Default Excel table column widths for GOEA results
    default_fld2col_widths = {
        'NS'        :  3,
        'GO'        : 12,
        'level'     :  3,
        'enrichment':  1,
        'name'      : 60,
        'ratio_in_study':  8,
        'ratio_in_pop'  : 12,
        'study_items'   : 15,
    }

    def __init__(self, pop, assoc, obo_dag, propagate_counts=True, alpha=.05, methods=None, **kws):
        self.log = kws['log'] if 'log' in kws else sys.stdout
        self._run_multitest = {
            'local':lambda iargs: self._run_multitest_local(iargs),
            'statsmodels':lambda iargs: self._run_multitest_statsmodels(iargs)}
        self.pop = pop
        self.pop_n = len(pop)
        self.assoc = assoc
        self.obo_dag = obo_dag
        self.alpha = alpha
        if methods is None:
            methods = ["bonferroni", "sidak", "holm"]
        self.methods = Methods(methods)
        self.pval_obj = FisherFactory(**kws).pval_obj

        if propagate_counts:
            sys.stderr.write("Propagating term counts to parents ..\n")
            obo_dag.update_association(assoc)
        self.go2popitems = get_terms("population", pop, assoc, obo_dag, self.log)

    def run_study(self, study, **kws):
        """Run Gene Ontology Enrichment Study (GOEA) on study ids."""
        # Key-word arguments:
        methods = Methods(kws['methods']) if 'methods' in kws else self.methods
        alpha = kws['alpha'] if 'alpha' in kws else self.alpha
        log = kws['log'] if 'log' in kws else self.log
        # Calculate uncorrected pvalues
        results = self._get_pval_uncorr(study, log)
        if not results:
            return []

        # Do multipletest corrections on uncorrected pvalues and update results
        self._run_multitest_corr(results, methods, alpha, study, log)

        for rec in results:
            # get go term for name and level
            rec.set_goterm(self.obo_dag)

        # 'keep_if' can be used to keep only significant GO terms. Example:
        #     >>> keep_if = lambda nt: nt.p_fdr_bh < 0.05 # if results are significant
        #     >>> goea_results = goeaobj.run_study(geneids_study, keep_if=keep_if)
        if 'keep_if' in kws:
            keep_if = kws['keep_if']
            results = [r for r in results if keep_if(r)]

        # Default sort order: First, sort by BP, MF, CC. Second, sort by pval
        results.sort(key=lambda r: [r.NS, r.p_uncorrected])

        if log is not None:
            log.write("  {MSG}\n".format(MSG="\n  ".join(self.get_results_msg(results, study))))

        return results # list of GOEnrichmentRecord objects

    def run_study_nts(self, study, **kws):
        """Run GOEA on study ids. Return results as a list of namedtuples."""
        goea_results = self.run_study(study, **kws)
        return get_goea_nts_all(goea_results)

    def get_results_msg(self, results, study):
        """Return summary for GOEA results."""
        # To convert msg list to string: "\n".join(msg)
        msg = []
        if results:
            stu_items, num_gos_stu = self.get_item_cnt(results, "study_items")
            pop_items, num_gos_pop = self.get_item_cnt(results, "pop_items")
            msg.append("{M:,} GO terms are associated with {N:,} of {NT:,} study items".format(
                N=len(stu_items), NT=len(set(study)), M=num_gos_stu))
            msg.append("{M:,} GO terms are associated with {N:,} of {NT:,} population items".format(
                N=len(pop_items), NT=self.pop_n, M=num_gos_pop))
        return msg

    def _get_pval_uncorr(self, study, log=sys.stdout):
        """Calculate the uncorrected pvalues for study items.

        Returns a list of GOEnrichmentRecord objects, one per GO term associated
        with either the study or the population.
        """
        if log is not None:
            log.write("Calculating uncorrected p-values using {PFNC}\n".format(PFNC=self.pval_obj.name))
        go2studyitems = get_terms("study", study, self.assoc, self.obo_dag, log)
        pop_n, study_n = self.pop_n, len(study)
        allterms = list(set(go2studyitems.keys()).union(
            set(self.go2popitems.keys())))

        all_p_values = self._calc_pvals_parallel(allterms, go2studyitems, study_n, pop_n)

        results = []

        for p_values in all_p_values:

            for term, p_value in p_values.items():

                study_items = go2studyitems.get(term, set())
                study_count = len(study_items)
                pop_items = self.go2popitems.get(term, set())
                pop_count = len(pop_items)

                one_record = GOEnrichmentRecord(
                    GO=term,
                    p_uncorrected=p_value,
                    study_items=study_items,
                    pop_items=pop_items,
                    ratio_in_study=(study_count, study_n),
                    ratio_in_pop=(pop_count, pop_n))

                results.append(one_record)

        return results

    def _calc_pvals_parallel(self, allterms, go2studyitems, study_n, pop_n):
        """Compute uncorrected p-values for allterms in a pool of worker processes.

        Returns a list of {GO ID: p-value} dicts, one per worker fragment.
        """
        # self.pval_obj.log may be a file handle, which cannot be pickled, so the
        # bound method self.pval_obj.calc_pvalue could not be sent to the worker
        # processes. Detach it temporarily; 'finally' restores it even when a
        # worker fails.
        old_log = self.pval_obj.log
        self.pval_obj.log = None
        try:
            # Leave one CPU free (avoids freezing the machine), but always use at
            # least one worker (cpu_count() - 1 is 0 on single-core hosts, and
            # Pool(0) raises ValueError) and never more workers than terms.
            n_procs = max(1, min(multiprocessing.cpu_count() - 1, len(allterms)))
            fragments = [allterms[i::n_procs] for i in range(n_procs)]
            remote_func = partial(compute_pvals, calc_pvalue=self.pval_obj.calc_pvalue,
                                  go2studyitems=go2studyitems, go2popitems=self.go2popitems,
                                  study_n=study_n, pop_n=pop_n)
            pool = multiprocessing.Pool(n_procs)
            try:
                return pool.map(remote_func, fragments)
            finally:
                # Release the worker processes; otherwise every run_study call
                # leaks a pool.
                pool.close()
                pool.join()
        finally:
            # restore patched file handle
            self.pval_obj.log = old_log

    def _run_multitest_corr(self, results, usr_methods, alpha, study, log):
        """Do multiple-test corrections on uncorrected pvalues."""
        assert 0 < alpha < 1, "Test-wise alpha must fall between (0, 1)"
        pvals = [r.p_uncorrected for r in results]
        NtMt = cx.namedtuple("NtMt", "results pvals alpha nt_method study")

        for nt_method in usr_methods:
            ntmt = NtMt(results, pvals, alpha, nt_method, study)
            if log is not None:
                log.write("Running multitest correction: {MSRC} {METHOD}\n".format(
                    MSRC=ntmt.nt_method.source, METHOD=ntmt.nt_method.method))
            self._run_multitest[nt_method.source](ntmt)

    def _run_multitest_statsmodels(self, ntmt):
        """Use multitest mthods that have been implemented in statsmodels."""
        # Only load statsmodels if it is used
        multipletests = self.methods.get_statsmodels_multipletests()
        results = multipletests(ntmt.pvals, ntmt.alpha, ntmt.nt_method.method)
        pvals_corrected = results[1] # reject_lst, pvals_corrected, alphacSidak, alphacBonf
        self._update_pvalcorr(ntmt, pvals_corrected)

    def _run_multitest_local(self, ntmt):
        """Use multitest mthods that have been implemented locally."""
        corrected_pvals = None
        method = ntmt.nt_method.method
        if method == "bonferroni":
            corrected_pvals = Bonferroni(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "sidak":
            corrected_pvals = Sidak(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "holm":
            corrected_pvals = HolmBonferroni(ntmt.pvals, ntmt.alpha).corrected_pvals
        elif method == "fdr":
            # get the empirical p-value distributions for FDR
            term_pop = getattr(self, 'term_pop', None)
            if term_pop is None:
                term_pop = count_terms(self.pop, self.assoc, self.obo_dag)
            p_val_distribution = calc_qval(len(ntmt.study),
                                           self.pop_n,
                                           self.pop, self.assoc,
                                           term_pop, self.obo_dag)
            corrected_pvals = FDR(p_val_distribution,
                                  ntmt.results, ntmt.alpha).corrected_pvals

        self._update_pvalcorr(ntmt, corrected_pvals)

    @staticmethod
    def _update_pvalcorr(ntmt, corrected_pvals):
        """Add data members to store multiple test corrections."""
        if corrected_pvals is None:
            return
        for rec, val in zip(ntmt.results, corrected_pvals):
            rec.set_corrected_pval(ntmt.nt_method, val)

    # Methods for writing results into tables: text, tab-separated, Excel spreadsheets
    def wr_txt(self, fout_txt, goea_results, prtfmt=None, **kws):
        """Print GOEA results to text file."""
        if not goea_results:
            sys.stdout.write("      0 GOEA results. NOT WRITING {FOUT}\n".format(FOUT=fout_txt))
            return
        with open(fout_txt, 'w') as prt:
            data_nts = self.prt_txt(prt, goea_results, prtfmt, **kws)
            log = self.log if self.log is not None else sys.stdout
            log.write("  {N:>5} GOEA results for {CUR:5} study items. WROTE: {F}\n".format(
                N=len(data_nts),
                CUR=len(get_study_items(goea_results)),
                F=fout_txt))

    def prt_txt(self, prt, goea_results, prtfmt=None, **kws):
        """Print GOEA results in text format."""
        if prtfmt is None:
            prtfmt = "{GO} {NS} {p_uncorrected:5.2e} {study_count:>5} {name}\n"
        prtfmt = self.adjust_prtfmt(prtfmt)
        prt_flds = RPT.get_fmtflds(prtfmt)
        data_nts = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.prt_txt(prt, data_nts, prtfmt, prt_flds, **kws)
        return data_nts

    def wr_xlsx(self, fout_xlsx, goea_results, **kws):
        """Write a xlsx file."""
        # kws: prt_if indent
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        xlsx_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        if 'fld2col_widths' not in kws:
            kws['fld2col_widths'] = {f:self.default_fld2col_widths.get(f, 8) for f in prt_flds}
        RPT.wr_xlsx(fout_xlsx, xlsx_data, **kws)

    def wr_tsv(self, fout_tsv, goea_results, **kws):
        """Write tab-separated table data to file"""
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        tsv_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.wr_tsv(fout_tsv, tsv_data, **kws)

    def prt_tsv(self, prt, goea_results, **kws):
        """Write tab-separated table data"""
        prt_flds = kws.get('prt_flds', self.get_prtflds_default(goea_results))
        tsv_data = get_goea_nts_prt(goea_results, prt_flds, **kws)
        RPT.prt_tsv(prt, tsv_data, prt_flds, **kws)

    @staticmethod
    def adjust_prtfmt(prtfmt):
        """Adjust format_strings for legal values."""
        prtfmt = prtfmt.replace("{p_holm-sidak", "{p_holm_sidak")
        prtfmt = prtfmt.replace("{p_simes-hochberg", "{p_simes_hochberg")
        return prtfmt

    @staticmethod
    def get_NS2nts(results, fldnames=None, **kws):
        """Get namedtuples of GOEA results, split into BP, MF, CC."""
        NS2nts = cx.defaultdict(list)
        nts = get_goea_nts_all(results, fldnames, **kws)
        for nt in nts:
            NS2nts[nt.NS].append(nt)
        return NS2nts

    @staticmethod
    def get_item_cnt(results, attrname="study_items"):
        """Get all study or population items (e.g., geneids)."""
        items = set()
        go_cnt = 0
        for rec in results:
            if hasattr(rec, attrname):
                items_cur = getattr(rec, attrname)
                # Only count GO term if there are items in the set.
                if len(items_cur) != 0:
                    items |= items_cur
                    go_cnt += 1
        return items, go_cnt

    @staticmethod
    def get_prtflds_default(results):
        """Get default fields names. Used in printing GOEA results.

           Researchers can control which fields they want to print in the GOEA results
           or they can use the default fields.
        """
        if results:
            return results[0].get_prtflds_default()
        return []

    @staticmethod
    def print_summary(results, min_ratio=None, indent=False, pval=0.05):
        """Print summary."""
        from .version import __version__ as version

        # Header contains provenance and parameters
        print("# Generated by GOATOOLS v{0} ({1})".format(version, datetime.date.today()))
        print("# min_ratio={0} pval={1}".format(min_ratio, pval))

        # field names for output
        if results:
            print("\t".join(GOEnrichmentStudy.get_prtflds_default(results)))

        for rec in results:
            # calculate some additional statistics
            # (over_under, is_ratio_different)
            rec.update_remaining_fldsdefprt(min_ratio=min_ratio)

            if pval is not None and rec.p_uncorrected >= pval:
                continue

            if rec.is_ratio_different:
                print(rec.__str__(indent=indent))

    def wr_py_goea_results(self, fout_py, goea_results, **kws):
        """Save GOEA results into Python package containing list of namedtuples."""
        var_name = kws.get("var_name", "goea_results")
        docstring = kws.get("docstring", "")
        sortby = kws.get("sortby", None)
        if goea_results:
            from goatools.nt_utils import wr_py_nts
            nts_goea = goea_results
            # If list has GOEnrichmentRecords or verbose namedtuples, exclude some fields.
            if hasattr(goea_results[0], "_fldsdefprt") or hasattr(goea_results[0], 'goterm'):
                # Exclude some attributes from the namedtuple when saving results
                # to a Python file because the information is redundant or verbose.
                nts_goea = get_goea_nts_prt(goea_results)
            docstring = "\n".join([docstring, "# {VER}\n\n".format(VER=self.obo_dag.version)])
            assert hasattr(nts_goea[0], '_fields')
            if sortby is None:
                sortby = lambda nt: getattr(nt, 'p_uncorrected')
            nts_goea = sorted(nts_goea, key=sortby)
            wr_py_nts(fout_py, nts_goea, docstring, var_name)
570
571
def get_study_items(goea_results):
    """Return the union of the study items (e.g., geneids) across all GOEA results.

    Each result must expose a set-valued 'study_items' attribute.
    """
    collected = set()
    for record in goea_results:
        collected |= record.study_items
    return collected
577
578
def get_goea_nts_prt(goea_results, fldnames=None, **usr_kws):
    """Return list of namedtuples removing fields which are redundant or verbose.

    Unless the caller supplies them, 'not_fldnames' defaults to excluding
    goterm/parents/children/id, and 'rpt_fmt' defaults to True so values are
    formatted for printing. The caller's keyword dict is not modified.
    """
    kws = dict(usr_kws)
    kws.setdefault('not_fldnames', ['goterm', 'parents', 'children', 'id'])
    kws.setdefault('rpt_fmt', True)
    return get_goea_nts_all(goea_results, fldnames, **kws)
586
587
def get_goea_nts_all(goea_results, fldnames=None, **kws):
    """Get namedtuples containing user-specified (or default) data from GOEA results.

        Reformats data from GOEnrichmentRecord objects into lists of
        namedtuples so the generic table writers may be used.

        Keyword args:
            keep_if: optional predicate; only namedtuples for which it is true are kept.
            rpt_fmt: format values for printing (default False).
            indent: prefix the GO column with level dots (default False). This is
                applied only when a "GO" column is present and the result is a
                GOEnrichmentRecord (which provides get_indent_dots).
            not_fldnames: field names to exclude from the namedtuple.
    """
    data_nts = [] # A list of namedtuples containing GOEA results
    if not goea_results:
        return data_nts
    keep_if = kws.get('keep_if', None)
    rpt_fmt = kws.get('rpt_fmt', False)
    indent = kws.get('indent', False)
    # I. FIELD (column) NAMES
    not_fldnames = kws.get('not_fldnames', None)
    if fldnames is None:
        fldnames = get_fieldnames(goea_results[0])
    # Ia. Explicitly exclude specific fields from named tuple
    if not_fldnames is not None:
        fldnames = [f for f in fldnames if f not in not_fldnames]
    nttyp = cx.namedtuple("NtGoeaResults", " ".join(fldnames))
    goid_idx = fldnames.index("GO") if 'GO' in fldnames else None
    # Dots are prepended to the GO column, so without one there is nothing to indent
    # (previously vals[None] raised TypeError).
    do_indent = indent and goid_idx is not None
    # II. Loop through GOEA results stored in a GOEnrichmentRecord object
    for goerec in goea_results:
        vals = get_field_values(goerec, fldnames, rpt_fmt)
        if do_indent and hasattr(goerec, "get_indent_dots"):
            vals[goid_idx] = "".join([goerec.get_indent_dots(), vals[goid_idx]])
        ntobj = nttyp._make(vals)
        if keep_if is None or keep_if(ntobj):
            data_nts.append(ntobj)
    return data_nts
617
618
def get_field_values(item, fldnames, rpt_fmt=None):
    """Return the values of fldnames from a GOEnrichmentRecord or a namedtuple.

    GOEnrichmentRecords (recognized by '_fldsdefprt') delegate to their own
    get_field_values; namedtuples (recognized by '_fields') are read attribute
    by attribute. Any other object yields None.
    """
    is_goea_record = hasattr(item, "_fldsdefprt")
    if is_goea_record:
        return item.get_field_values(fldnames, rpt_fmt)
    is_namedtuple = hasattr(item, "_fields")
    if not is_namedtuple:
        return None
    return [getattr(item, name) for name in fldnames]
624
625
def get_fieldnames(item):
    """Return the field names of a GOEnrichmentRecord or a namedtuple.

    GOEnrichmentRecords (recognized by '_fldsdefprt') report all of their
    printable fields via get_prtflds_all; namedtuples report '_fields'.
    Any other object yields None.
    """
    if hasattr(item, "_fldsdefprt"):
        return item.get_prtflds_all()
    return item._fields if hasattr(item, "_fields") else None
631
632
# Copyright (C) 2010-2017, H Tang et al., All rights reserved.
633