GODagSmallPlot._get_item_str()   F
last analyzed

Complexity

Conditions 14

Size

Total Lines 23

Duplication

Lines 23
Ratio 100 %

Importance

Changes 0
Metric Value
cc 14
dl 23
loc 23
rs 3.6
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex methods like GODagSmallPlot._get_item_str() often do a lot of different things. To break the enclosing class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""Plot a GODagSmall."""
2
3
__copyright__ = "Copyright (C) 2016-2018, DV Klopfenstein, H Tang, All rights reserved."
4
__author__ = "DV Klopfenstein"
5
6
import sys
7
import os
8
import collections as cx
9
from collections import OrderedDict
10
from goatools.godag_obosm import OboToGoDagSmall
11
12
def plot_gos(fout_png, goids, obo_dag, *args, **kws):
    """Given GO ids and the obo_dag, create a plot of paths from GO ids.

    Args:
        fout_png: name of the image file to write.
        goids: GO ids handed to OboToGoDagSmall as ``goids``.
        obo_dag: GO DAG handed to OboToGoDagSmall as ``obodag``.
        *args, **kws: forwarded unchanged to GODagSmallPlot. The optional
            keyword ``engine`` (default 'pydot') also selects the plot engine.
    """
    engine = kws.get('engine', 'pydot')
    godagsmall = OboToGoDagSmall(goids=goids, obodag=obo_dag).godag
    godagplot = GODagSmallPlot(godagsmall, *args, **kws)
    godagplot.plt(fout_png, engine)
18
19
def plot_goid2goobj(fout_png, goid2goobj, *args, **kws):
    """Given a dict containing GO id and its goobj, create a plot of paths from GO ids.

    Args:
        fout_png: name of the image file to write.
        goid2goobj: dict handed to OboToGoDagSmall as ``goid2goobj``.
        *args, **kws: forwarded unchanged to GODagSmallPlot. The optional
            keyword ``engine`` (default 'pydot') also selects the plot engine.
    """
    engine = kws.get('engine', 'pydot')
    godagsmall = OboToGoDagSmall(goid2goobj=goid2goobj).godag
    godagplot = GODagSmallPlot(godagsmall, *args, **kws)
    godagplot.plt(fout_png, engine)
25
26
def plot_results(fout_png, goea_results, *args, **kws):
    """Given a list of GOEA results, plot result GOs up to top.

    If ``fout_png`` contains the placeholder "{NS}", the results are grouped
    by their ``NS`` attribute and one image is written per group, with the
    placeholder replaced by the group name. Otherwise a single image holding
    all results is written.

    Args:
        fout_png: output file name, optionally containing "{NS}".
        goea_results: iterable of GOEA result records (each read for ``.NS``
            when plotting per namespace).
        *args, **kws: forwarded unchanged to plt_goea_results.
    """
    if "{NS}" not in fout_png:
        plt_goea_results(fout_png, goea_results, *args, **kws)
        return
    # Plot separately by NS: BP, MF, CC
    ns2goea_results = cx.defaultdict(list)
    for rec in goea_results:
        ns2goea_results[rec.NS].append(rec)
    for ns_name, ns_res in ns2goea_results.items():
        png = fout_png.format(NS=ns_name)
        plt_goea_results(png, ns_res, *args, **kws)
38
39
def plt_goea_results(fout_png, goea_results, *args, **kws):
    """Plot a single page.

    Args:
        fout_png: name of the image file to write.
        goea_results: GOEA results; used both to build the small DAG and,
            passed as the ``goea_results`` keyword, by GODagSmallPlot.
        *args, **kws: forwarded unchanged to GODagSmallPlot. The optional
            keyword ``engine`` (default 'pydot') also selects the plot engine.
    """
    engine = kws.get('engine', 'pydot')
    godagsmall = OboToGoDagSmall(goea_results=goea_results).godag
    godagplot = GODagSmallPlot(godagsmall, *args, goea_results=goea_results, **kws)
    godagplot.plt(fout_png, engine)
45
46
class GODagPltVars(object):
    """Holds plotting parameters."""

    # Edge color per relationship name.
    # Color names: http://www.graphviz.org/doc/info/colors.html
    rel2col = {
        'is_a':      'black',
        'part_of':   'blue',
        'regulates': 'gold',
        'positively_regulates': 'green',
        'negatively_regulates': 'red',
        'occurs_in':            'aquamarine4',
        'capable_of':           'dodgerblue',
        'capable_of_part_of':   'darkorange',
    }

    # Fill color per p-value threshold. Order matters: thresholds are listed
    # in ascending order so the first threshold a p-value satisfies is used.
    alpha2col = OrderedDict([
        # GOEA GO terms that are significant
        (0.005, 'mistyrose'),
        (0.010, 'moccasin'),
        (0.050, 'lemonchiffon1'),
        # GOEA GO terms that are not significant
        (1.000, 'grey95'),
    ])

    # Fill color for GO terms selected by role rather than by p-value
    key2col = {
        'level_01': 'lightcyan',
        'go_sources': 'palegreen',
    }

    # Header line of a GO term box, e.g. "GO0036464 L04 D06"
    fmthdr = "{GO} L{level:>02} D{depth:>02}"
    # Study line of a GO term box, e.g. "24 genes"
    fmtres = "{study_count} genes"
    # study items per line on GO Terms:
    items_p_line = 5
79
80
81
class GODagSmallPlot(object):
    """Plot a graph contained in an object of type GODagSmall.

    The ``godagsmall`` object is read for ``go_sources``, ``go2obj`` and
    ``get_edges()``.

    Keyword arguments read by the constructor:
        log            file-like object for status messages (default sys.stdout)
        title          graph label (default None)
        goea_results   list of GOEA result objects (read for ``.GO``)
        go2nt          dict of GO id -> result record; used if no goea_results
        pval_name      attribute name holding the p-value on each result
        id2symbol      dict mapping a study item id to a printable symbol
        study_items    None: print gene counts; True: print all study items;
                       int: print at most that many study items
        alpha_str      stored as-is (default None)
        GODagPltVars   plotting parameters object (default GODagPltVars())
        items_p_line   overrides ``items_p_line`` on the plotting parameters
        dpi            image resolution (default 150)
    """

    def __init__(self, godagsmall, *args, **kws):
        self.args = args
        self.log = kws.get('log', sys.stdout)
        self.title = kws.get('title')
        # GOATOOLs results as objects
        self.go2res = self._init_go2res(**kws)
        # GOATOOLs results as a list of namedtuples
        self.pval_name = self._init_pval_name(**kws)
        # Gene Symbol names
        self.id2symbol = kws.get('id2symbol', {})
        self.study_items = kws.get('study_items')
        self.study_items_max = self._init_study_items_max()
        self.alpha_str = kws.get('alpha_str')
        # Only build the default when the caller did not supply one
        self.pltvars = kws['GODagPltVars'] if 'GODagPltVars' in kws else GODagPltVars()
        if 'items_p_line' in kws:
            self.pltvars.items_p_line = kws['items_p_line']
        self.dpi = kws.get('dpi', 150)
        self.godag = godagsmall
        # Needs go2res, pval_name, pltvars and godag, so it is set after them
        self.goid2color = self._init_goid2color()
        # pydot module; imported on first use by _get_pydot()
        self.pydot = None

    def _init_study_items_max(self):
        """User can limit the number of genes printed in a GO term.

        Returns the limit when ``study_items`` is an int other than True;
        returns None (no limit) for None, True, or any non-int value.
        """
        items = self.study_items
        # True is an int too, but it means "print every study item"
        if items is None or items is True:
            return None
        if isinstance(items, int):
            return items
        return None

    @staticmethod
    def _init_go2res(**kws):
        """Initialize GOEA results.

        Returns a dict of GO id -> result built from ``goea_results``, else
        the ``go2nt`` dict as given, else None when neither was supplied.
        """
        if 'goea_results' in kws:
            return {res.GO: res for res in kws['goea_results']}
        if 'go2nt' in kws:
            return kws['go2nt']
        return None

    @staticmethod
    def _init_pval_name(**kws):
        """Initialize pvalue attribute name.

        An explicit ``pval_name`` wins. Otherwise the name is derived from the
        first method field of the first GOEA result ("p_<fieldname>").
        Returns None when neither source is available.
        """
        if 'pval_name' in kws:
            return kws['pval_name']
        goea = kws.get('goea_results')
        if goea:
            return "p_{M}".format(M=goea[0].method_flds[0].fieldname)
        return None

    def _init_goid2color(self):
        """Set colors of GO terms.

        Precedence, highest first: p-value color, source-GO color, level-01
        color. Returns a dict of GO id -> color name; GO terms matching none
        of the rules are absent from it.
        """
        goid2color = {}
        # 1. colors based on p-value override colors based on source GO
        # Without a p-value attribute name there is nothing to look up
        # (getattr would raise TypeError on a None name), so skip this step.
        if self.go2res is not None and self.pval_name is not None:
            alpha2col = self.pltvars.alpha2col
            pval_name = self.pval_name
            for goid, res in self.go2res.items():
                pval = getattr(res, pval_name, None)
                if pval is None:
                    continue
                for alpha, color in alpha2col.items():
                    if pval <= alpha:
                        # Only the first threshold satisfied can set the color,
                        # and GO terms with no study genes are never colored.
                        if res.study_count != 0:
                            goid2color[goid] = color
                        break
        # 2. GO source color
        color = self.pltvars.key2col['go_sources']
        for goid in self.godag.go_sources:
            goid2color.setdefault(goid, color)
        # 3. Level-01 GO color
        color = self.pltvars.key2col['level_01']
        for goid, goobj in self.godag.go2obj.items():
            if goobj.level == 1:
                goid2color.setdefault(goid, color)
        return goid2color

    def plt(self, fout_img, engine="pydot"):
        """Plot using pydot, graphviz, or GML.

        Raises:
            NotImplementedError: engine is "pygraphviz" (not implemented).
            ValueError: engine is not recognized.
        """
        if engine == "pydot":
            self._plt_pydot(fout_img)
        elif engine == "pygraphviz":
            raise NotImplementedError("TO BE IMPLEMENTED SOON: ENGINE pygraphvis")
        else:
            raise ValueError("UNKNOWN ENGINE({E})".format(E=engine))

    # ----------------------------------------------------------------------------------
    # pydot
    def _plt_pydot(self, fout_img):
        """Plot using the pydot graphics engine."""
        dag = self._get_pydot_graph()
        # The image format is the file extension without its leading dot
        img_fmt = os.path.splitext(fout_img)[1][1:]
        dag.write(fout_img, format=img_fmt)
        self.log.write("  {GO_USR:>3} usr {GO_ALL:>3} GOs  WROTE: {F}\n".format(
            F=fout_img,
            GO_USR=len(self.godag.go_sources),
            GO_ALL=len(self.godag.go2obj)))

    def _get_pydot_graph(self):
        """Given a DAG, return a pydot digraph object."""
        rel = "is_a"
        pydot = self._get_pydot()
        # Initialize empty dag
        dag = pydot.Dot(label=self.title, graph_type='digraph', dpi="{}".format(self.dpi))
        # Initialize nodes
        go2node = self._get_go2pydotnode()
        # Add nodes to graph
        for node in go2node.values():
            dag.add_node(node)
        # Add edges to graph
        edge_color = self.pltvars.rel2col[rel]
        for src, tgt in self.godag.get_edges():
            dag.add_edge(pydot.Edge(
                go2node[tgt], go2node[src],
                shape="normal",
                color=edge_color,
                dir="back")) # invert arrow direction for obo dag convention
        return dag

    def _get_go2pydotnode(self):
        """Create pydot Nodes. Returns a dict of GO id -> pydot Node."""
        # Go through _get_pydot() so this works even if pydot is not loaded yet
        pydot = self._get_pydot()
        go2node = {}
        for goid, goobj in self.godag.go2obj.items():
            txt = self._get_node_text(goid, goobj)
            fillcolor = self.goid2color.get(goid, "white")
            go2node[goid] = pydot.Node(
                txt,
                shape="box",
                style="rounded, filled",
                fillcolor=fillcolor,
                color="mediumseagreen")
        return go2node

    def _get_pydot(self):
        """Return pydot package. Load pydot, if necessary."""
        if not self.pydot:
            self.pydot = __import__("pydot")
        return self.pydot

    # ----------------------------------------------------------------------------------
    # Methods for text printed inside GO terms
    def _get_node_text(self, goid, goobj):
        """Return a string to be printed in a GO term box."""
        txt = []
        # Header line: "GO0036464 L04 D06"
        txt.append(self.pltvars.fmthdr.format(
            GO=goobj.id.replace("GO:", "GO"),
            level=goobj.level,
            depth=goobj.depth))
        # GO name line: "cytoplamic ribonucleoprotein"
        txt.append(goobj.name.replace(",", "\n"))
        # study info line: "24 genes"
        study_txt = self._get_study_txt(goid)
        if study_txt is not None:
            txt.append(study_txt)
        # return text string
        return "\n".join(txt)

    def _get_study_txt(self, goid):
        """Get GO text from GOEA study.

        Returns None when there are no results at all or none for ``goid``.
        Otherwise returns the study items text when ``study_items`` was
        given, else the gene-count text.
        """
        if self.go2res is None:
            return None
        res = self.go2res.get(goid, None)
        if res is None:
            return None
        if self.study_items is not None:
            return self._get_item_str(res)
        return self.pltvars.fmtres.format(study_count=res.study_count)

    @staticmethod
    def _join_items(items, npl):
        """Join items with ", ", placing at most ``npl`` items on each line."""
        rows = [items[i:i+npl] for i in range(0, len(items), npl)]
        return "\n".join(", ".join(str(e) for e in row) for row in rows)

    def _get_item_str(self, res):
        """Return the sorted study items of ``res`` in one of these formats:
              1. No limit set:       7) Ptprc, Mif, Cd81, Bcl2, Sash3, ...
              2. Within the limit:   Ptprc, Mif, Cd81, Bcl2, Sash3, Tnfrsf4, Cdkn1a
              3. Over the limit:     7 genes; Ptprc, Mif, Cd81, Bcl2, Sash3...
           Items are printed as gene symbols where ``id2symbol`` has one,
           otherwise as ids, with ``items_p_line`` items per line.
        """
        npl = self.pltvars.items_p_line  # Number of items Per Line
        prt_items = sorted([self.__get_genestr(itemid) for itemid in res.study_items])
        num_items = len(prt_items)
        max_items = self.study_items_max
        if max_items is None:
            return "{N}) {GENES}".format(N=num_items, GENES=self._join_items(prt_items, npl))
        if num_items <= max_items:
            return self._join_items(prt_items, npl)
        short_str = self._join_items(prt_items[:max_items], npl)
        return "".join(["{N} genes; ".format(N=num_items), short_str, "..."])

    def __get_genestr(self, itemid):
        """Given a geneid, return the string geneid or a gene symbol."""
        if self.id2symbol is not None:
            symbol = self.id2symbol.get(itemid, None)
            if symbol is not None:
                return symbol
        if isinstance(itemid, int):
            return str(itemid)
        return itemid
288
289
# Copyright (C) 2016-2018, DV Klopfenstein, H Tang, All rights reserved.
290