Completed
Push — master ( 3df76a...5af33a )
by Thomas
10:27
created

FlagLS.unpack()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 4
nop 3
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
# encoding: utf-8
2
"""
3
Copyright (c) 2016 Evelio Vila <[email protected]>
4
Copyright (c) 2009-2017 Exa Networks. All rights reserved.
5
License: 3-clause BSD. (See the COPYRIGHT file)
6
"""
7
8
import json
9
import binascii
10
import itertools
11
from struct import unpack
12
13
from exabgp.bgp.message.notification import Notify
14
from exabgp.bgp.message.update.attribute.attribute import Attribute
15
16
17
@Attribute.register()
class LinkState(Attribute):
    """BGP-LS path attribute: a container of link-state TLVs.

    The attribute value is a sequence of TLVs (2-byte type, 2-byte length,
    value). Each TLV is decoded by the class registered for its type via
    :meth:`register`, or by ``GenericLSID`` when no class is registered.
    """

    ID = Attribute.CODE.BGP_LS
    FLAG = Attribute.Flag.OPTIONAL
    TLV = -1

    # Registered subclasses we know how to decode, keyed by TLV code
    registered_lsids = dict()

    # what this implementation knows as LS attributes
    node_lsids = []
    link_lsids = []
    prefix_lsids = []

    def __init__(self, ls_attrs):
        # ls_attrs: list of decoded TLV objects, in the order they were read
        self.ls_attrs = ls_attrs

    @classmethod
    def register(cls, lsid=None, flag=None):
        """Return a class decorator that registers a TLV decoder class.

        :param lsid: TLV code to register under; defaults to the decorated
            class's ``TLV`` attribute.
        :param flag: accepted for signature compatibility; not used here.
        :raises RuntimeError: if a class is already registered for the code.
        """

        def register_lsid(klass):
            # Decoders default to "do not merge repeated TLVs".
            if not hasattr(klass, 'MERGE'):
                klass.MERGE = False
            scode = klass.TLV if lsid is None else lsid
            if scode in cls.registered_lsids:
                raise RuntimeError('only one class can be registered per BGP link state attribute type')
            cls.registered_lsids[scode] = klass
            return klass

        return register_lsid

    @classmethod
    def klass(cls, code):
        """Return the decoder class for TLV ``code`` (``GenericLSID`` if unknown)."""
        return cls.registered_lsids.get(code, GenericLSID)

    @classmethod
    def registered(cls, lsid, flag=None):
        """Return True if a decoder class is registered for ``lsid``."""
        return lsid in cls.registered_lsids

    @classmethod
    def unpack(cls, data, negotiated):
        """Decode the BGP-LS attribute value into a ``LinkState``.

        :param data: raw attribute value (concatenated TLVs).
        :param negotiated: negotiated session state; not used here.
        :returns: a ``LinkState`` holding the decoded TLV objects.
        :raises Notify: (3, 5) if a TLV header or value is truncated.
        """
        ls_attrs = []
        while data:
            if len(data) < 4:
                raise Notify(3, 5, 'Unable to decode attribute, truncated BGP-LS TLV header')
            scode, length = unpack('!HH', data[:4])
            payload = data[4: length + 4]
            if len(payload) != length:
                raise Notify(3, 5, f'Unable to decode attribute, truncated BGP-LS TLV {scode}')
            data = data[length + 4:]

            decoder = cls.klass(scode)
            if decoder is GenericLSID:
                # GenericLSID.unpack takes (code, data), unlike registered
                # decoders which take (data, length).
                instance = decoder.unpack(scode, payload)
            else:
                instance = decoder.unpack(payload, length)
            instance.TLV = scode

            # GenericLSID is never registered, so it may lack MERGE.
            if getattr(instance, 'MERGE', False):
                # Fold a repeated mergeable TLV into its first occurrence
                # rather than listing it twice.
                for existing in ls_attrs:
                    if existing.TLV == instance.TLV:
                        existing.merge(instance)
                        break
                else:
                    ls_attrs.append(instance)
            else:
                ls_attrs.append(instance)

        return cls(ls_attrs=ls_attrs)

    def json(self, compact=None):
        """Return the decoded TLVs as a JSON object string."""
        content = ', '.join(d.json() for d in self.ls_attrs)
        return '{ %s }' % (content)

    def __str__(self):
        return ', '.join(str(d) for d in self.ls_attrs)
78
79
80
class BaseLS(object):
    """Base class for decoded BGP-LS TLVs.

    Subclasses override ``TLV`` (type code), ``JSON`` (key used in JSON
    output), ``REPR`` (label used by ``repr``) and optionally ``LEN`` (the
    exact value length they accept; ``None`` disables the check).
    """

    TLV = -1
    JSON = 'json-name-unset'
    REPR = 'repr name unset'
    LEN = None

    def __init__(self, content):
        self.content = content

    def json(self, compact=None):
        """Return a ``"<JSON>": <content as JSON>`` fragment."""
        return '"{}": {}'.format(self.JSON, json.dumps(self.content))

    def __repr__(self):
        return "%s: %s" % (self.REPR, self.content)

    @classmethod
    def check(cls, length):
        """Validate a TLV value length against ``LEN``.

        :raises Notify: (3, 5) when ``LEN`` is set and ``length`` differs.
        """
        if cls.LEN is not None and length != cls.LEN:
            raise Notify(3, 5, f'Unable to decode attribute, wrong size for {cls.REPR}')
100
101
102
class GenericLSID(BaseLS):
    """Fallback for BGP-LS TLVs that have no registered decoder class.

    The raw TLV value is kept as lowercase hex text so it can be rendered
    in JSON output.
    """

    def __init__(self, code, content):
        BaseLS.__init__(self, content)
        self.code = code

    def __repr__(self):
        return "Attribute with code [ %s ] not implemented" % (self.code)

    def json(self, compact=None):
        """Return a ``"generic-LSID-<code>": "<hex>"`` JSON fragment."""
        return '"generic-LSID-{}": {}'.format(self.code, json.dumps(self.content))

    @classmethod
    def unpack(cls, scode, data):
        """Build a GenericLSID for TLV code ``scode`` from raw bytes ``data``."""
        # Hex text, not uuencoded bytes: json.dumps cannot serialize bytes,
        # and binascii.b2a_uu refuses more than 45 input bytes.
        return cls(scode, binascii.b2a_hex(data).decode('ascii'))
116
117
118
class FlagLS(BaseLS):
    """Base class for BGP-LS TLVs whose value is a flag bitmask.

    Subclasses define ``FLAGS``: flag names in most-significant-bit-first
    order, with ``'RSV'`` marking reserved bit positions.
    """

    def __init__(self, flags):
        BaseLS.__init__(self, flags)
        self.flags = flags

    def __repr__(self):
        return "%s: %s" % (self.REPR, self.flags)

    def json(self, compact=None):
        return '"{}": {}'.format(self.JSON, json.dumps(self.flags))

    @classmethod
    def unpack_flags(cls, data):
        """Decode ``data`` into a ``{flag_name: 0 or 1}`` dict.

        Bits are matched to ``cls.FLAGS`` starting at the most significant
        bit. The following must be zero:
        - bits at ``'RSV'`` positions;
        - any bits beyond the end of ``FLAGS``.

        :raises Notify: (3, 5) if ``data`` is too short for ``FLAGS`` or a
            reserved bit is set.
        """
        width = 8 * len(data)
        if not data or len(cls.FLAGS) > width:
            raise Notify(3, 5, "Invalid SR flags mask")
        # Zero-pad to the full byte width so leading zero bits are kept.
        bits = format(int.from_bytes(data, 'big'), f'0{width}b')
        for position, bit in enumerate(bits):
            reserved = position >= len(cls.FLAGS) or cls.FLAGS[position] == 'RSV'
            if reserved and bit != '0':
                raise Notify(3, 5, "Invalid SR flags mask")
        flags = dict.fromkeys(cls.FLAGS, 0)
        flags.update((name, int(bit)) for name, bit in zip(cls.FLAGS, bits))
        return flags

    @classmethod
    def unpack(cls, data, length):
        """Decode the flags from the first byte of the TLV value.

        :raises Notify: (3, 5) on a size mismatch or an invalid mask.
        """
        cls.check(length)
        # We only support IS-IS for now.
        return cls(cls.unpack_flags(data[0:1]))
152