Passed
Push — master (c46b9b...2e8441)
by Matt
01:35
created

_load_from_yaml()   A

Complexity

Conditions 4

Size

Total Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0
Features 0
Metric   Value
cc       4
c        1
b        0
f        0
dl       0
loc      15
rs       9.2
# MIT License
#
# Copyright (c) 2017 Matt Boyer
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from . import constants
from . import PROJECT_NAME, PROJECT_DESCRIPTION, USER_YAML_PATH, BUILTIN_YAML

import argparse
import collections
import csv
import logging
import os
import os.path
import pdb
import pkg_resources
import re
import shutil
import sqlite3
import stat
import struct
import tempfile
import yaml


logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s')
_LOGGER = logging.getLogger('SQLite recovery')
_LOGGER.setLevel(logging.INFO)

SQLite_header = collections.namedtuple('SQLite_header', (
    'magic',
    'page_size',
    'write_format',
    'read_format',
    'reserved_length',
    'max_payload_fraction',
    'min_payload_fraction',
    'leaf_payload_fraction',
    'file_change_counter',
    'size_in_pages',
    'first_freelist_trunk',
    'freelist_pages',
    'schema_cookie',
    'schema_format',
    'default_page_cache_size',
    'largest_btree_page',
    'text_encoding',
    'user_version',
    'incremental_vacuum',
    'application_id',
    'version_valid',
    'sqlite_version',
))


SQLite_btree_page_header = collections.namedtuple('SQLite_btree_page_header', (
    'page_type',
    'first_freeblock_offset',
    'num_cells',
    'cell_content_offset',
    'num_fragmented_free_bytes',
    'right_most_page_idx',
))


SQLite_ptrmap_info = collections.namedtuple('SQLite_ptrmap_info', (
    'page_idx',
    'page_type',
    'page_ptr',
))


SQLite_record_field = collections.namedtuple('SQLite_record_field', (
    'col_type',
    'col_type_descr',
    'field_length',
    'field_bytes',
))


SQLite_master_record = collections.namedtuple('SQLite_master_record', (
    'type',
    'name',
    'tbl_name',
    'rootpage',
    'sql',
))


type_specs = {
    'INTEGER': int,
    'TEXT': str,
    'VARCHAR': str,
    'LONGVARCHAR': str,
    'REAL': float,
    'FLOAT': float,
    'LONG': int,
    'BLOB': bytes,
}


heuristics = {}
signatures = {}


def heuristic_factory(magic, offset):
    assert(isinstance(magic, bytes))
    assert(isinstance(offset, int))
    assert(offset >= 0)

    # We only need to compile the regex once
    magic_re = re.compile(magic)

    def generic_heuristic(freeblock_bytes):
        all_matches = [match for match in magic_re.finditer(freeblock_bytes)]
        for magic_match in all_matches[::-1]:
            header_start = magic_match.start()-offset
            if header_start < 0:
                _LOGGER.debug("Header start outside of freeblock!")
                break
            yield header_start
    return generic_heuristic
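
# For illustration only: a heuristic built with, say, heuristic_factory(b'\x09\x04', 1)
# yields candidate record-header offsets one byte before each match of that
# two-byte magic, walking the freeblock from its last match backwards.
# BTreePage.recover_freeblock_records() below consumes these generators via
# heuristics[table.name](freeblock_bytes).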


def load_heuristics():

    def _load_from_yaml(yaml_string):
        if isinstance(yaml_string, bytes):
            yaml_string = yaml_string.decode('utf-8')

        raw_yaml = yaml.load(yaml_string)
        for table_grouping, tables in raw_yaml.items():
            _LOGGER.info(
                "Loading raw_yaml for table grouping \"%s\"",
                table_grouping
            )
            for table_name, table_props in tables.items():
                heuristics[table_name] = heuristic_factory(
                    table_props['magic'], table_props['offset']
                )
                _LOGGER.debug("Loaded heuristics for \"%s\"", table_name)

    with pkg_resources.resource_stream(PROJECT_NAME, BUILTIN_YAML) as builtin:
        try:
            _load_from_yaml(builtin.read())
        except KeyError:
            raise SystemError("Malformed builtin magic file")

    if not os.path.exists(USER_YAML_PATH):
        return
    with open(USER_YAML_PATH, 'r') as user_yaml:
        try:
            _load_from_yaml(user_yaml.read())
        except KeyError:
            raise SystemError("Malformed user magic file")
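
# An illustrative sketch of the YAML layout load_heuristics() expects above
# (the values here are made up, not the contents of the builtin magic file):
#
#   some_grouping:
#       some_table:
#           magic: !!binary CQQ=
#           offset: 1
#
# Each table entry supplies the 'magic' bytes and the 'offset' that are passed
# on to heuristic_factory().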


class IndexDict(dict):
    def __iter__(self):
        for k in sorted(self.keys()):
            yield k


class SQLite_DB(object):
    def __init__(self, path):
        self._path = path
        self._page_types = {}
        self._header = self.parse_header()

        self._page_cache = None
        # Actual page objects go here
        self._pages = {}
        self.build_page_cache()

        self._ptrmap = {}

        # TODO Do we need all of these?
        self._table_roots = {}
        self._page_tables = {}
        self._tables = {}
        self._table_columns = {}
        self._freelist_leaves = []
        self._freelist_btree_pages = []

    @property
    def ptrmap(self):
        return self._ptrmap

    @property
    def header(self):
        return self._header

    @property
    def pages(self):
        return self._pages

    @property
    def tables(self):
        return self._tables

    @property
    def freelist_leaves(self):
        return self._freelist_leaves

    @property
    def table_columns(self):
        return self._table_columns

    def page_bytes(self, page_idx):
        try:
            return self._page_cache[page_idx]
        except (IndexError, KeyError):
            raise ValueError("No cache for page {}".format(page_idx))

    def map_table_page(self, page_idx, table):
        assert isinstance(page_idx, int)
        assert isinstance(table, Table)
        self._page_tables[page_idx] = table

    def get_page_table(self, page_idx):
        assert isinstance(page_idx, int)
        try:
            return self._page_tables[page_idx]
        except KeyError:
            return None

    def __repr__(self):
        return '<SQLite DB, page count: {} | page size: {}>'.format(
            self.header.size_in_pages,
            self.header.page_size
        )

    def parse_header(self):
        header_bytes = None
        file_size = None
        with open(self._path, 'br') as sqlite:
            header_bytes = sqlite.read(100)
            file_size = os.fstat(sqlite.fileno())[stat.ST_SIZE]

        if not header_bytes:
            raise ValueError("Couldn't read SQLite header")
        assert isinstance(header_bytes, bytes)
        # This DB header is always big-endian
        fields = SQLite_header(*struct.unpack(
            r'>16sHBBBBBBIIIIIIIIIIII20xII',
            header_bytes[:100]
        ))
        assert fields.page_size in constants.VALID_PAGE_SIZES
        db_size = fields.page_size * fields.size_in_pages
        assert db_size <= file_size
        assert (fields.page_size > 0) and \
            (fields.file_change_counter == fields.version_valid)

        if file_size < 1073741824:
            _LOGGER.debug("No lock-byte page in this file!")

        if fields.first_freelist_trunk > 0:
            self._page_types[fields.first_freelist_trunk] = \
                constants.FREELIST_TRUNK_PAGE
        _LOGGER.debug(fields)
        return fields

    def build_page_cache(self):
        # The SQLite docs use a numbering convention for pages where the
        # first page (the one that has the header) is page 1, with the next
        # ptrmap page being page 2, etc.
        page_cache = [None, ]
        with open(self._path, 'br') as sqlite:
            for page_idx in range(self._header.size_in_pages):
                page_offset = page_idx * self._header.page_size
                sqlite.seek(page_offset, os.SEEK_SET)
                page_cache.append(sqlite.read(self._header.page_size))
        self._page_cache = page_cache
        for page_idx in range(1, len(self._page_cache)):
            # We want these to be temporary objects, to be replaced with
            # more specialised objects as parsing progresses
            self._pages[page_idx] = Page(page_idx, self)

    def populate_freelist_pages(self):
        if 0 == self._header.first_freelist_trunk:
            _LOGGER.debug("This database has no freelist trunk page")
            return

        _LOGGER.info("Parsing freelist pages")
        parsed_trunks = 0
        parsed_leaves = 0
        freelist_trunk_idx = self._header.first_freelist_trunk

        while freelist_trunk_idx != 0:
            _LOGGER.debug(
                "Parsing freelist trunk page %d",
                freelist_trunk_idx
            )
            trunk_bytes = bytes(self.pages[freelist_trunk_idx])

            next_freelist_trunk_page_idx, num_leaf_pages = struct.unpack(
                r'>II',
                trunk_bytes[:8]
            )

            # Now that we know how long the array of freelist page pointers is,
            # let's read it again
            trunk_array = struct.unpack(
                r'>{count}I'.format(count=2+num_leaf_pages),
                trunk_bytes[:(4*(2+num_leaf_pages))]
            )

            # We're skipping the first entries as they are really the next trunk
            # index and the leaf count
            # TODO Fix that
            leaves_in_trunk = []
            for page_idx in trunk_array[2:]:
                # Let's prepare a specialised object for this freelist leaf
                # page
                leaf_page = FreelistLeafPage(
                    page_idx, self, freelist_trunk_idx
                )
                leaves_in_trunk.append(leaf_page)
                self._freelist_leaves.append(page_idx)
                self._pages[page_idx] = leaf_page

                self._page_types[page_idx] = constants.FREELIST_LEAF_PAGE

            trunk_page = FreelistTrunkPage(
                freelist_trunk_idx,
                self,
                leaves_in_trunk
            )
            self._pages[freelist_trunk_idx] = trunk_page
            # We've parsed this trunk page
            parsed_trunks += 1
            # ...And every leaf in it
            parsed_leaves += num_leaf_pages

            freelist_trunk_idx = next_freelist_trunk_page_idx

        assert (parsed_trunks + parsed_leaves) == self._header.freelist_pages
        _LOGGER.info(
            "Freelist summary: %d trunk pages, %d leaf pages",
            parsed_trunks,
            parsed_leaves
        )

    def populate_overflow_pages(self):
        # Knowledge of the overflow pages can come from the pointer map (easy),
        # or the parsing of individual cells in table leaf pages (hard)
        #
        # For now, assume we already have a page type dict populated from the
        # ptrmap
        _LOGGER.info("Parsing overflow pages")
        overflow_count = 0
        for page_idx in sorted(self._page_types):
            page_type = self._page_types[page_idx]
            if page_type not in constants.OVERFLOW_PAGE_TYPES:
                continue
            overflow_page = OverflowPage(page_idx, self)
            self.pages[page_idx] = overflow_page
            overflow_count += 1

        _LOGGER.info("Overflow summary: %d pages", overflow_count)

    def populate_ptrmap_pages(self):
        if self._header.largest_btree_page == 0:
            # We don't have ptrmap pages in this DB. That sucks.
            _LOGGER.warning("%r does not have ptrmap pages!", self)
            for page_idx in range(1, self._header.size_in_pages):
                self._page_types[page_idx] = constants.UNKNOWN_PAGE
            return

        _LOGGER.info("Parsing ptrmap pages")

        ptrmap_page_idx = 2
        usable_size = self._header.page_size - self._header.reserved_length
        num_ptrmap_entries_in_page = usable_size // 5
        ptrmap_page_indices = []

        ptrmap_page_idx = 2
        while ptrmap_page_idx <= self._header.size_in_pages:
            page_bytes = self._page_cache[ptrmap_page_idx]
            ptrmap_page_indices.append(ptrmap_page_idx)
            self._page_types[ptrmap_page_idx] = constants.PTRMAP_PAGE
            page_ptrmap_entries = {}

            ptrmap_bytes = page_bytes[:5 * num_ptrmap_entries_in_page]
            for entry_idx in range(num_ptrmap_entries_in_page):
                ptr_page_idx = ptrmap_page_idx + entry_idx + 1
                page_type, page_ptr = struct.unpack(
                    r'>BI',
                    ptrmap_bytes[5*entry_idx:5*(entry_idx+1)]
                )
                if page_type == 0:
                    break

                ptrmap_entry = SQLite_ptrmap_info(
                    ptr_page_idx, page_type, page_ptr
                )
                assert ptrmap_entry.page_type in constants.PTRMAP_PAGE_TYPES
                if page_type == constants.BTREE_ROOT_PAGE:
                    assert page_ptr == 0
                    self._page_types[ptr_page_idx] = page_type

                elif page_type == constants.FREELIST_PAGE:
                    # Freelist pages are assumed to be known already
                    assert self._page_types[ptr_page_idx] in \
                        constants.FREELIST_PAGE_TYPES
                    assert page_ptr == 0

                elif page_type == constants.FIRST_OFLOW_PAGE:
                    assert page_ptr != 0
                    self._page_types[ptr_page_idx] = page_type

                elif page_type == constants.NON_FIRST_OFLOW_PAGE:
                    assert page_ptr != 0
                    self._page_types[ptr_page_idx] = page_type

                elif page_type == constants.BTREE_NONROOT_PAGE:
                    assert page_ptr != 0
                    self._page_types[ptr_page_idx] = page_type

                # _LOGGER.debug("%r", ptrmap_entry)
                self._ptrmap[ptr_page_idx] = ptrmap_entry
                page_ptrmap_entries[ptr_page_idx] = ptrmap_entry

            page = PtrmapPage(ptrmap_page_idx, self, page_ptrmap_entries)
            self._pages[ptrmap_page_idx] = page
            _LOGGER.debug("%r", page)
            ptrmap_page_idx += num_ptrmap_entries_in_page + 1

        _LOGGER.info(
            "Ptrmap summary: %d pages, %r",
            len(ptrmap_page_indices), ptrmap_page_indices
        )

    def populate_btree_pages(self):
        # TODO Should this use table information instead of scanning all pages?
        page_idx = 1
        while page_idx <= self._header.size_in_pages:
            try:
                if self._page_types[page_idx] in \
                        constants.NON_BTREE_PAGE_TYPES:
                    page_idx += 1
                    continue
            except KeyError:
                pass

            try:
                page_obj = BTreePage(page_idx, self)
            except ValueError:
                # This page isn't a valid btree page. This can happen if we
                # don't have a ptrmap to guide us
                _LOGGER.warning(
                    "Page %d (%s) is not a btree page",
                    page_idx,
                    self._page_types[page_idx]
                )
                page_idx += 1
                continue

            page_obj.parse_cells()
            self._page_types[page_idx] = page_obj.page_type
            self._pages[page_idx] = page_obj
            page_idx += 1

    def _parse_master_leaf_page(self, page):
        for cell_idx in page.cells:
            _, master_record = page.cells[cell_idx]
            assert isinstance(master_record, Record)
            fields = [
                master_record.fields[idx].value for idx in master_record.fields
            ]
            master_record = SQLite_master_record(*fields)
            if 'table' != master_record.type:
                continue

            self._table_roots[master_record.name] = \
                self.pages[master_record.rootpage]

            # This record describes a table in the schema, which means it
            # includes a SQL statement that defines the table's columns
            # We need to parse the field names out of that statement
            assert master_record.sql.startswith('CREATE TABLE')
            columns_re = re.compile(r'^CREATE TABLE (\S+) \((.*)\)$')
            match = columns_re.match(master_record.sql)
            if match:
                assert match.group(1) == master_record.name
                column_list = match.group(2)
                csl_between_parens_re = re.compile(r'\([^)]+\)')
                expunged = csl_between_parens_re.sub('', column_list)

                cols = [
                    statement.strip() for statement in expunged.split(',')
                ]
                cols = [
                    statement for statement in cols if not (
                        statement.startswith('PRIMARY') or
                        statement.startswith('UNIQUE')
                    )
                ]
                columns = [col.split()[0] for col in cols]
                signature = []

                # Some column definitions lack a type
                for col_def in cols:
                    def_tokens = col_def.split()
                    try:
                        col_type = def_tokens[1]
                    except IndexError:
                        signature.append(object)
                        continue

                    _LOGGER.debug(
                        "Column \"%s\" is defined as \"%s\"",
                        def_tokens[0], col_type
                    )
                    try:
                        signature.append(type_specs[col_type])
                    except KeyError:
                        _LOGGER.warning("No native type for \"%s\"", col_def)
                        signature.append(object)
                _LOGGER.info(
                    "Signature for table \"%s\": %r",
                    master_record.name, signature
                )
                signatures[master_record.name] = signature

                _LOGGER.info(
                    "Columns for table \"%s\": %r",
                    master_record.name, columns
                )
                self._table_columns[master_record.name] = columns

    def map_tables(self):
        first_page = self.pages[1]
        assert isinstance(first_page, BTreePage)

        master_table = Table('sqlite_master', self, first_page)
        self._table_columns.update(constants.SQLITE_TABLE_COLUMNS)

        for master_leaf in master_table.leaves:
            self._parse_master_leaf_page(master_leaf)

        assert all(
            isinstance(root, BTreePage) for root in self._table_roots.values()
        )
        assert all(
            root.parent is None for root in self._table_roots.values()
        )

        self.map_table_page(1, master_table)
        self._table_roots['sqlite_master'] = self.pages[1]

        for table_name, rootpage in self._table_roots.items():
            try:
                table_obj = Table(table_name, self, rootpage)
            except Exception as ex:  # pylint:disable=W0703
                pdb.set_trace()
                _LOGGER.warning(
                    "Caught %r while instantiating table object for \"%s\"",
                    ex, table_name
                )
            else:
                self._tables[table_name] = table_obj

    def reparent_orphaned_table_leaf_pages(self):
        reparented_pages = []
        for page in self.pages.values():
            if not isinstance(page, BTreePage):
                continue
            if page.page_type != "Table Leaf":
                continue

            table = page.table
            if not table:
                parent = page
                root_table = None
                while parent:
                    root_table = parent.table
                    parent = parent.parent
                if root_table is None:
                    self._freelist_btree_pages.append(page)

                if root_table is None:
                    if not page.cells:
                        continue

                    first_record = page.cells[0][1]
                    matches = []
                    for table_name in signatures:
                        # All records within a given page are for the same
                        # table
                        if self.tables[table_name].check_signature(
                                first_record):
                            matches.append(self.tables[table_name])
                    if not matches:
                        _LOGGER.error(
                            "Couldn't find a matching table for %r",
                            page
                        )
                        continue
                    if len(matches) > 1:
                        _LOGGER.error(
                            "Multiple matching tables for %r: %r",
                            page, matches
                        )
                        continue
                    elif len(matches) == 1:
                        root_table = matches[0]

                _LOGGER.debug(
                    "Reparenting %r to table \"%s\"",
                    page, root_table.name
                )
                root_table.add_leaf(page)
                self.map_table_page(page.idx, root_table)
                reparented_pages.append(page)

        if reparented_pages:
            _LOGGER.info(
                "Reparented %d pages: %r",
                len(reparented_pages), [p.idx for p in reparented_pages]
            )

    def grep(self, needle):
        match_found = False
        page_idx = 1
        needle_re = re.compile(needle.encode('utf-8'))
        while (page_idx <= self.header.size_in_pages):
            page = self.pages[page_idx]
            page_offsets = []
            for match in needle_re.finditer(bytes(page)):
                needle_offset = match.start()
                page_offsets.append(needle_offset)
            if page_offsets:
                match_found = True
                _LOGGER.info(
                    "Found search term in page %r @ offset(s) %s",
                    page, ', '.join(str(offset) for offset in page_offsets)
                )
            page_idx += 1
        if not match_found:
            _LOGGER.warning(
                "Search term not found",
            )


class Table(object):
    def __init__(self, name, db, rootpage):
        self._name = name
        self._db = db
        assert(isinstance(rootpage, BTreePage))
        self._root = rootpage
        self._leaves = []
        try:
            self._columns = self._db.table_columns[self.name]
        except KeyError:
            self._columns = None

        # We want this to be a list of leaf-type pages, sorted in the order of
        # their smallest rowid
        self._populate_pages()

    @property
    def name(self):
        return self._name

    def add_leaf(self, leaf_page):
        self._leaves.append(leaf_page)

    @property
    def columns(self):
        return self._columns

    def __repr__(self):
        return "<SQLite table \"{}\", root: {}, leaves: {}>".format(
            self.name, self._root.idx, len(self._leaves)
        )

    def _populate_pages(self):
        _LOGGER.info("Page %d is root for %s", self._root.idx, self.name)
        table_pages = [self._root]

        if self._root.btree_header.right_most_page_idx is not None:
            rightmost_idx = self._root.btree_header.right_most_page_idx
            rightmost_page = self._db.pages[rightmost_idx]
            if rightmost_page is not self._root:
                _LOGGER.info(
                    "Page %d is rightmost for %s",
                    rightmost_idx, self.name
                )
                table_pages.append(rightmost_page)

        page_queue = list(table_pages)
        while page_queue:
            table_page = page_queue.pop(0)
            # table_pages is initialised with the table's rootpage, which
            # may be a leaf page for a very small table
            if table_page.page_type != 'Table Interior':
                self._leaves.append(table_page)
                continue

            for cell_idx in table_page.cells:
                page_ptr, max_row_in_page = table_page.cells[cell_idx]

                page = self._db.pages[page_ptr]
                _LOGGER.debug("B-Tree cell: (%r, %d)", page, max_row_in_page)
                table_pages.append(page)
                if page.page_type == 'Table Interior':
                    page_queue.append(page)
                elif page.page_type == 'Table Leaf':
                    self._leaves.append(page)

        assert(all(p.page_type == 'Table Leaf' for p in self._leaves))
        for page in table_pages:
            self._db.map_table_page(page.idx, self)

    @property
    def leaves(self):
        for leaf_page in self._leaves:
            yield leaf_page

    def recover_records(self):
        for page in self.leaves:
            assert isinstance(page, BTreePage)
            if not page.freeblocks:
                continue

            _LOGGER.info("%r", page)
            page.recover_freeblock_records()
            page.print_recovered_records()

    def csv_dump(self, out_dir):
        csv_path = os.path.join(out_dir, self.name + '.csv')
        if os.path.exists(csv_path):
            raise ValueError("Output file {} exists!".format(csv_path))

        _LOGGER.info("Dumping table \"%s\" to CSV", self.name)
        with tempfile.TemporaryFile('w+', newline='') as csv_temp:
            writer = csv.DictWriter(csv_temp, fieldnames=self._columns)
            writer.writeheader()

            for leaf_page in self.leaves:
                for cell_idx in leaf_page.cells:
                    rowid, record = leaf_page.cells[cell_idx]
                    # assert(self.check_signature(record))

                    _LOGGER.debug('Record %d: %r', rowid, record.header)
                    fields_iter = (
                        repr(record.fields[idx]) for idx in record.fields
                    )
                    _LOGGER.debug(', '.join(fields_iter))

                    values_iter = (
                        record.fields[idx].value for idx in record.fields
                    )
                    writer.writerow(dict(zip(self._columns, values_iter)))

                if not leaf_page.recovered_records:
                    continue

                # Recovered records are in an unordered set because their rowid
                # has been lost, making sorting impossible
                for record in leaf_page.recovered_records:
                    values_iter = (
                        record.fields[idx].value for idx in record.fields
                    )
                    writer.writerow(dict(zip(self._columns, values_iter)))

            if csv_temp.tell() > 0:
                csv_temp.seek(0)
                with open(csv_path, 'w') as csv_file:
                    csv_file.write(csv_temp.read())

    def build_insert_SQL(self, record):
        column_placeholders = (
            ':' + col_name for col_name in self._columns
        )
        insert_statement = 'INSERT INTO {} VALUES ({})'.format(
            self.name,
            ', '.join(c for c in column_placeholders),
        )
        value_kwargs = {}
        for col_idx, col_name in enumerate(self._columns):
            try:
                if record.fields[col_idx].value is None:
                    value_kwargs[col_name] = None
                else:
                    value_kwargs[col_name] = record.fields[col_idx].value
            except KeyError:
                value_kwargs[col_name] = None

        return insert_statement, value_kwargs
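
    # Illustrative use of the pair returned above (not exercised in this
    # module): the statement uses sqlite3's named-placeholder style, so it can
    # be executed directly, e.g.
    #
    #   statement, kwargs = table.build_insert_SQL(record)
    #   cursor.execute(statement, kwargs)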

    def check_signature(self, record):
        assert isinstance(record, Record)
        try:
            sig = signatures[self.name]
        except KeyError:
            # The sqlite schema tables don't have a signature (or need one)
            return True
        if len(record.fields) > len(self.columns):
            return False

        # It's OK for a record to have fewer fields than there are columns in
        # this table, this is seen when NULLable or default-valued columns are
        # added in an ALTER TABLE statement.
        for field_idx, field in record.fields.items():
            # NULL can be a value for any column type
            if field.value is None:
                continue
            if not isinstance(field.value, sig[field_idx]):
                return False
        return True


class Page(object):
    def __init__(self, page_idx, db):
        self._page_idx = page_idx
        self._db = db
        self._bytes = db.page_bytes(self.idx)

    @property
    def idx(self):
        return self._page_idx

    @property
    def usable_size(self):
        return self._db.header.page_size - self._db.header.reserved_length

    def __bytes__(self):
        return self._bytes

    @property
    def parent(self):
        try:
            parent_idx = self._db.ptrmap[self.idx].page_ptr
        except KeyError:
            return None

        if 0 == parent_idx:
            return None
        else:
            return self._db.pages[parent_idx]

    def __repr__(self):
        return "<SQLite Page {0}>".format(self.idx)


class FreelistTrunkPage(Page):
    # XXX Maybe it would make sense to expect a Page instance as constructor
    # argument?
    def __init__(self, page_idx, db, leaves):
        super().__init__(page_idx, db)
        self._leaves = leaves

    def __repr__(self):
        return "<SQLite Freelist Trunk Page {0}: {1} leaves>".format(
            self.idx, len(self._leaves)
        )


class FreelistLeafPage(Page):
    # XXX Maybe it would make sense to expect a Page instance as constructor
    # argument?
    def __init__(self, page_idx, db, trunk_idx):
        super().__init__(page_idx, db)
        self._trunk = self._db.pages[trunk_idx]

    def __repr__(self):
        return "<SQLite Freelist Leaf Page {0}. Trunk: {1}>".format(
            self.idx, self._trunk.idx
        )


class PtrmapPage(Page):
    # XXX Maybe it would make sense to expect a Page instance as constructor
    # argument?
    def __init__(self, page_idx, db, ptr_array):
        super().__init__(page_idx, db)
        self._pointers = ptr_array

    @property
    def pointers(self):
        return self._pointers

    def __repr__(self):
        return "<SQLite Ptrmap Page {0}. {1} pointers>".format(
            self.idx, len(self.pointers)
        )


class OverflowPage(Page):
    # XXX Maybe it would make sense to expect a Page instance as constructor
    # argument?
    def __init__(self, page_idx, db):
        super().__init__(page_idx, db)
        self._parse()

    def _parse(self):
        # TODO We should have parsing here for the next page index in the
        # overflow chain
        pass

    def __repr__(self):
        return "<SQLite Overflow Page {0}. Continuation of {1}>".format(
            self.idx, self.parent.idx
        )


class BTreePage(Page):
    btree_page_types = {
        0x02:   "Index Interior",
        0x05:   "Table Interior",
        0x0A:   "Index Leaf",
        0x0D:   "Table Leaf",
    }

    def __init__(self, page_idx, db):
        # XXX We don't know a page's type until we've had a look at the header.
        # Or do we?
        super().__init__(page_idx, db)
        self._header_size = 8
        page_header_bytes = self._get_btree_page_header()
        self._btree_header = SQLite_btree_page_header(
            # Set the right-most page index to None in the 1st pass
            *struct.unpack(r'>BHHHB', page_header_bytes), None
        )
        self._cell_ptr_array = []
        self._freeblocks = IndexDict()
        self._cells = IndexDict()
        self._recovered_records = set()
        self._overflow_threshold = self.usable_size - 35

        if self._btree_header.page_type not in BTreePage.btree_page_types:
            # pdb.set_trace()
            raise ValueError

        # We have a twelve-byte header, need to read it again
        if self._btree_header.page_type in (0x02, 0x05):
            self._header_size = 12
            page_header_bytes = self._get_btree_page_header()
            self._btree_header = SQLite_btree_page_header(*struct.unpack(
                r'>BHHHBI', page_header_bytes
            ))

        # Page 1 (and page 2, but that's the 1st ptrmap page) does not have a
        # ptrmap entry.
        # The first ptrmap page will contain back pointer information for pages
        # 3 through J+2, inclusive.
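        # For example, with a 4096-byte page and no reserved bytes, a ptrmap
        # page holds 4096 // 5 = 819 five-byte entries, so the ptrmap page at
        # index 2 covers pages 3 through 821 (see populate_ptrmap_pages above).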
        if self._db.ptrmap:
            if self.idx >= 3 and self.idx not in self._db.ptrmap:
                _LOGGER.warning(
                    "BTree page %d doesn't have ptrmap entry!", self.idx
                )

        if self._btree_header.num_cells > 0:
            cell_ptr_bytes = self._get_btree_ptr_array(
                self._btree_header.num_cells
            )
            self._cell_ptr_array = struct.unpack(
                r'>{count}H'.format(count=self._btree_header.num_cells),
                cell_ptr_bytes
            )
            smallest_cell_offset = min(self._cell_ptr_array)
            if self._btree_header.cell_content_offset != smallest_cell_offset:
                _LOGGER.warning(
                    (
                        "Inconsistent cell ptr array in page %d! Cell content "
                        "starts at offset %d, but min cell pointer is %d"
                    ),
                    self.idx,
                    self._btree_header.cell_content_offset,
                    smallest_cell_offset
                )

    @property
    def btree_header(self):
        return self._btree_header

    @property
    def page_type(self):
        try:
            return self.btree_page_types[self._btree_header.page_type]
        except KeyError:
            pdb.set_trace()
            _LOGGER.warning(
                "Unknown B-Tree page type: %d", self._btree_header.page_type
            )
            raise

    @property
    def freeblocks(self):
        return self._freeblocks

    @property
    def cells(self):
        return self._cells

    def __repr__(self):
        # TODO Include table in repr, where available
        return "<SQLite B-Tree Page {0} ({1}) {2} cells>".format(
            self.idx, self.page_type, len(self._cell_ptr_array)
        )

    @property
    def table(self):
        return self._db.get_page_table(self.idx)

    def _get_btree_page_header(self):
        header_offset = 0
        if self.idx == 1:
            header_offset += 100
        return bytes(self)[header_offset:self._header_size + header_offset]

    def _get_btree_ptr_array(self, num_cells):
        array_offset = self._header_size
        if self.idx == 1:
            array_offset += 100
        return bytes(self)[array_offset:2 * num_cells + array_offset]

    def parse_cells(self):
        if self.btree_header.page_type == 0x05:
            self.parse_table_interior_cells()
        elif self.btree_header.page_type == 0x0D:
            self.parse_table_leaf_cells()
        self.parse_freeblocks()

    def parse_table_interior_cells(self):
        if self.btree_header.page_type != 0x05:
            assert False

        _LOGGER.debug("Parsing cells in table interior page %d", self.idx)
        for cell_idx, offset in enumerate(self._cell_ptr_array):
            _LOGGER.debug("Parsing cell %d @ offset %d", cell_idx, offset)
            left_ptr_bytes = bytes(self)[offset:offset + 4]
            left_ptr, = struct.unpack(r'>I', left_ptr_bytes)

            offset += 4
            integer_key = Varint(bytes(self)[offset:offset+9])
            self._cells[cell_idx] = (left_ptr, int(integer_key))

    def parse_table_leaf_cells(self):
        if self.btree_header.page_type != 0x0d:
            assert False

        _LOGGER.debug("Parsing cells in table leaf page %d", self.idx)
        for cell_idx, cell_offset in enumerate(self._cell_ptr_array):
            _LOGGER.debug("Parsing cell %d @ offset %d", cell_idx, cell_offset)

            # This is the total size of the payload, which may include overflow
            offset = cell_offset
            payload_length_varint = Varint(bytes(self)[offset:offset+9])
            total_payload_size = int(payload_length_varint)

            overflow = False
            # Let X be U-35. If the payload size P is less than or equal to X
            # then the entire payload is stored on the b-tree leaf page. Let M
            # be ((U-12)*32/255)-23 and let K be M+((P-M)%(U-4)). If P is
            # greater than X then the number of bytes stored on the table
            # b-tree leaf page is K if K is less or equal to X or M otherwise.
            # The number of bytes stored on the leaf page is never less than M.
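            # Worked example: with U = 1024 and no reserved bytes, X = 989 and
            # M = 103. A payload of P = 2000 bytes gives
            # K = 103 + ((2000 - 103) % 1020) = 980, and since 980 <= 989,
            # 980 bytes stay in this cell and the remaining 1020 bytes spill
            # into overflow pages.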
            cell_payload_size = 0
            if total_payload_size > self._overflow_threshold:
                m = int(((self.usable_size - 12) * 32/255)-23)
                k = m + ((total_payload_size - m) % (self.usable_size - 4))
                if k <= self._overflow_threshold:
                    cell_payload_size = k
                else:
                    cell_payload_size = m
                overflow = True
            else:
                cell_payload_size = total_payload_size

            offset += len(payload_length_varint)

            integer_key = Varint(bytes(self)[offset:offset+9])
            offset += len(integer_key)

            overflow_bytes = bytes()
            if overflow:
                first_oflow_page_bytes = bytes(self)[
                    offset + cell_payload_size:offset + cell_payload_size + 4
                ]
                if not first_oflow_page_bytes:
                    continue

                first_oflow_idx, = struct.unpack(
                    r'>I', first_oflow_page_bytes
                )
                next_oflow_idx = first_oflow_idx
                while next_oflow_idx != 0:
                    oflow_page_bytes = self._db.page_bytes(next_oflow_idx)

                    len_overflow = min(
                        len(oflow_page_bytes) - 4,
                        (
                            total_payload_size - cell_payload_size -
                            len(overflow_bytes)
                        )
                    )
                    overflow_bytes += oflow_page_bytes[4:4 + len_overflow]

                    first_four_bytes = oflow_page_bytes[:4]
                    next_oflow_idx, = struct.unpack(
                        r'>I', first_four_bytes
                    )

            try:
                cell_data = bytes(self)[offset:offset + cell_payload_size]
                if overflow_bytes:
                    cell_data += overflow_bytes

                # All payload bytes should be accounted for
                assert len(cell_data) == total_payload_size

                record_obj = Record(cell_data)
                _LOGGER.debug("Created record: %r", record_obj)

            except TypeError as ex:
                _LOGGER.warning(
                    "Caught %r while instantiating record %d",
                    ex, int(integer_key)
                )
                pdb.set_trace()
                raise

            self._cells[cell_idx] = (int(integer_key), record_obj)

    def parse_freeblocks(self):
        # The first 2 bytes of a freeblock are a big-endian integer which is
        # the offset in the b-tree page of the next freeblock in the chain, or
        # zero if the freeblock is the last on the chain. The third and fourth
        # bytes of each freeblock form a big-endian integer which is the size
        # of the freeblock in bytes, including the 4-byte header. Freeblocks
        # are always connected in order of increasing offset. The second field
        # of the b-tree page header is the offset of the first freeblock, or
        # zero if there are no freeblocks on the page. In a well-formed b-tree
        # page, there will always be at least one cell before the first
        # freeblock.
        #
        # TODO But what about deleted records that exceeded the overflow
        # threshold in the past?
        block_offset = self.btree_header.first_freeblock_offset
        while block_offset != 0:
            freeblock_header = bytes(self)[block_offset:block_offset + 4]
            # Freeblock_size includes the 4-byte header
            next_freeblock_offset, freeblock_size = struct.unpack(
                r'>HH',
                freeblock_header
            )
            freeblock_bytes = bytes(self)[
                block_offset + 4:block_offset + freeblock_size - 4
            ]
            self._freeblocks[block_offset] = freeblock_bytes
            block_offset = next_freeblock_offset

    def print_cells(self):
        for cell_idx in self.cells.keys():
            rowid, record = self.cells[cell_idx]
            _LOGGER.info(
                "Cell %d, rowid: %d, record: %r",
                cell_idx, rowid, record
            )
            record.print_fields(table=self.table)

    def recover_freeblock_records(self):
        # If we're lucky (i.e. if no overwriting has taken place), we should be
        # able to find whole record headers in freeblocks.
        # We need to start from the end of the freeblock and work our way back
        # to the start. That means we don't know where a cell header will
        # start, but I suppose we can take a guess
        table = self.table
        if not table or table.name not in heuristics:
            return

        _LOGGER.info("Attempting to recover records from freeblocks")
        for freeblock_idx, freeblock_offset in enumerate(self._freeblocks):
            freeblock_bytes = self._freeblocks[freeblock_offset]
            if 0 == len(freeblock_bytes):
                continue
            _LOGGER.debug(
                "Freeblock %d/%d in page, offset %d, %d bytes",
                1 + freeblock_idx,
                len(self._freeblocks),
                freeblock_offset,
                len(freeblock_bytes)
            )

            recovered_bytes = 0
            recovered_in_freeblock = 0

            # TODO Maybe we need to guess the record header lengths rather than
            # try and read them from the freeblocks
            for header_start in heuristics[table.name](freeblock_bytes):
                _LOGGER.debug(
                    (
                        "Trying potential record header start at "
                        "freeblock offset %d/%d"
                    ),
                    header_start, len(freeblock_bytes)
                )
                _LOGGER.debug("%r", freeblock_bytes)
                try:
                    # We don't know how to handle overflow in deleted records,
                    # so we'll have to truncate the bytes object used to
                    # instantiate the Record object
                    record_bytes = freeblock_bytes[
                        header_start:header_start+self._overflow_threshold
                    ]
                    record_obj = Record(record_bytes)
                except MalformedRecord:
                    # This isn't a well-formed record, let's move to the next
                    # candidate
                    continue

                field_lengths = sum(
                    len(field_obj) for field_obj in record_obj.fields.values()
                )
                record_obj.truncate(field_lengths + len(record_obj.header))
                self._recovered_records.add(record_obj)

                recovered_bytes += len(bytes(record_obj))
                recovered_in_freeblock += 1

            _LOGGER.info(
                (
                    "Recovered %d record(s): %d bytes out of %d "
                    "freeblock bytes @ offset %d"
                ),
                recovered_in_freeblock,
                recovered_bytes,
                len(freeblock_bytes),
                freeblock_offset,
            )

    @property
    def recovered_records(self):
        return self._recovered_records

    def print_recovered_records(self):
        if not self._recovered_records:
            return

        for record_obj in self._recovered_records:
            _LOGGER.info("Recovered record: %r", record_obj)
            _LOGGER.info("Recovered record header: %s", record_obj.header)
            record_obj.print_fields(table=self.table)


class Record(object):

    column_types = {
        0: (0, "NULL"),
        1: (1, "8-bit twos-complement integer"),
        2: (2, "big-endian 16-bit twos-complement integer"),
        3: (3, "big-endian 24-bit twos-complement integer"),
        4: (4, "big-endian 32-bit twos-complement integer"),
        5: (6, "big-endian 48-bit twos-complement integer"),
        6: (8, "big-endian 64-bit twos-complement integer"),
        7: (8, "Floating point"),
        8: (0, "Integer 0"),
        9: (0, "Integer 1"),
    }
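
    # Serial types 12 and over encode variable-length values and are handled
    # in the KeyError branch of _parse() below: an even type N is a BLOB of
    # (N - 12) // 2 bytes, an odd type N >= 13 is UTF-8 text of (N - 13) // 2
    # bytes. For example, serial type 25 denotes a 6-byte string.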

    def __init__(self, record_bytes):
        self._bytes = record_bytes
        self._header_bytes = None
        self._fields = IndexDict()
        self._parse()

    def __bytes__(self):
        return self._bytes

    @property
    def header(self):
        return self._header_bytes

    @property
    def fields(self):
        return self._fields

    def truncate(self, new_length):
        self._bytes = self._bytes[:new_length]
        self._parse()

    def _parse(self):
        header_offset = 0

        header_length_varint = Varint(
            # A varint is encoded on *at most* 9 bytes
            bytes(self)[header_offset:9 + header_offset]
        )

        # Let's keep track of how many bytes of the Record header (including
        # the header length itself) we've successfully parsed
        parsed_header_bytes = len(header_length_varint)

        if len(bytes(self)) < int(header_length_varint):
            raise MalformedRecord(
                "Not enough bytes to fully read the record header!"
            )

        header_offset += len(header_length_varint)
        self._header_bytes = bytes(self)[:int(header_length_varint)]

        col_idx = 0
        field_offset = int(header_length_varint)
        while header_offset < int(header_length_varint):
            serial_type_varint = Varint(
                bytes(self)[header_offset:9 + header_offset]
            )
            serial_type = int(serial_type_varint)
            col_length = None

            try:
                col_length, _ = self.column_types[serial_type]
            except KeyError:
                if serial_type >= 13 and (1 == serial_type % 2):
                    col_length = (serial_type - 13) // 2
                elif serial_type >= 12 and (0 == serial_type % 2):
                    col_length = (serial_type - 12) // 2
                else:
                    raise ValueError(
                        "Unknown serial type {}".format(serial_type)
                    )

            try:
                field_obj = Field(
                    col_idx,
                    serial_type,
                    bytes(self)[field_offset:field_offset + col_length]
                )
            except MalformedField as ex:
                _LOGGER.warning(
                    "Caught %r while instantiating field %d (%d)",
                    ex, col_idx, serial_type
                )
                raise MalformedRecord
            except Exception as ex:
                _LOGGER.warning(
                    "Caught %r while instantiating field %d (%d)",
                    ex, col_idx, serial_type
                )
                pdb.set_trace()
                raise

            self._fields[col_idx] = field_obj
            col_idx += 1
            field_offset += col_length

            parsed_header_bytes += len(serial_type_varint)
            header_offset += len(serial_type_varint)

            if field_offset > len(bytes(self)):
                raise MalformedRecord

        # assert(parsed_header_bytes == int(header_length_varint))

    def print_fields(self, table=None):
        for field_idx in self._fields:
            field_obj = self._fields[field_idx]
            if not table or table.columns is None:
                _LOGGER.info(
                    "\tField %d (%d bytes), type %d: %s",
                    field_obj.index,
                    len(field_obj),
                    field_obj.serial_type,
                    field_obj.value
                )
            else:
                _LOGGER.info(
                    "\t%s: %s",
                    table.columns[field_obj.index],
                    field_obj.value
                )

    def __repr__(self):
        return '<Record {} fields, {} bytes, header: {} bytes>'.format(
            len(self._fields), len(bytes(self)), len(self.header)
        )


class MalformedField(Exception):
    pass


class MalformedRecord(Exception):
    pass


class Field(object):
    def __init__(self, idx, serial_type, serial_bytes):
        self._index = idx
        self._type = serial_type
        self._bytes = serial_bytes
        self._value = None
        self._parse()

    def _check_length(self, expected_length):
        if len(self) != expected_length:
            raise MalformedField

    # TODO Raise a specific exception when bad bytes are encountered for the
    # fields and then use this to weed out bad freeblock records
    def _parse(self):
        if self._type == 0:
            self._value = None
        # Integer types
        elif self._type == 1:
            self._check_length(1)
            self._value = decode_twos_complement(bytes(self)[0:1], 8)
        elif self._type == 2:
            self._check_length(2)
            self._value = decode_twos_complement(bytes(self)[0:2], 16)
        elif self._type == 3:
            self._check_length(3)
            self._value = decode_twos_complement(bytes(self)[0:3], 24)
        elif self._type == 4:
            self._check_length(4)
            self._value = decode_twos_complement(bytes(self)[0:4], 32)
        elif self._type == 5:
            self._check_length(6)
            self._value = decode_twos_complement(bytes(self)[0:6], 48)
        elif self._type == 6:
            self._check_length(8)
            self._value = decode_twos_complement(bytes(self)[0:8], 64)

        elif self._type == 7:
            self._value = struct.unpack(r'>d', bytes(self)[0:8])[0]
        elif self._type == 8:
            self._value = 0
        elif self._type == 9:
            self._value = 1
        elif self._type >= 13 and (1 == self._type % 2):
            try:
                self._value = bytes(self).decode('utf-8')
            except UnicodeDecodeError:
                raise MalformedField

        elif self._type >= 12 and (0 == self._type % 2):
            self._value = bytes(self)

    def __bytes__(self):
        return self._bytes

    def __repr__(self):
        return "<Field {}: {} ({} bytes)>".format(
            self._index, self._value, len(bytes(self))
        )

    def __len__(self):
        return len(bytes(self))

    @property
    def index(self):
        return self._index

    @property
    def value(self):
        return self._value

    @property
    def serial_type(self):
        return self._type
1481
1482
1483
class Varint(object):
1484
    def __init__(self, varint_bytes):
1485
        self._bytes = varint_bytes
1486
        self._len = 0
1487
        self._value = 0
1488
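        # SQLite varints pack up to 64 bits into 1-9 bytes, most significant
        # group first, 7 bits per byte; a set high bit means "another byte
        # follows".  (In a full 9-byte varint the last byte contributes all
        # 8 of its bits, a case this loop does not special-case.)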
1489
        varint_bits = []
1490
        for b in self._bytes:
1491
            self._len += 1
1492
            if b & 0x80:
1493
                varint_bits.append(b & 0x7F)
1494
            else:
1495
                varint_bits.append(b)
1496
                break
1497
1498
        varint_twos_complement = 0
1499
        for position, b in enumerate(varint_bits[::-1]):
1500
            varint_twos_complement += b * (1 << (7*position))
1501
1502
        self._value = decode_twos_complement(
1503
            int.to_bytes(varint_twos_complement, 8, byteorder='big'), 64
1504
        )
1505
1506
    def __int__(self):
1507
        return self._value
1508
1509
    def __len__(self):
1510
        return self._len
1511
1512
    def __repr__(self):
1513
        return "<Varint {} ({} bytes)>".format(int(self), len(self))
1514
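# Illustrative Varint examples (not part of the original module):
#   int(Varint(b'\x07'))     == 7,   len(Varint(b'\x07'))     == 1
#   int(Varint(b'\x81\x00')) == 128, len(Varint(b'\x81\x00')) == 2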
1515
1516
def decode_twos_complement(encoded, bit_length):
1517
    assert(0 == bit_length % 8)
1518
    encoded_int = int.from_bytes(encoded, byteorder='big')
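    # The sign bit carries weight -(2 ** (bit_length - 1)) in two's
    # complement, so split it off and negate its contribution.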
1519
    mask = 2**(bit_length - 1)
1520
    value = -(encoded_int & mask) + (encoded_int & ~mask)
1521
    return value
1522
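# Example (illustrative): decode_twos_complement(b'\x85', 8) == -123,
# while decode_twos_complement(b'\x7b', 8) == 123.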
1523
1524
def gen_output_dir(db_path):
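    # Pick a sibling directory named after the database file (dots replaced
    # with underscores); if it already exists, try suffixes _1 through _10
    # before giving up.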
1525
    db_abspath = os.path.abspath(db_path)
1526
    db_dir, db_name = os.path.split(db_abspath)
1527
1528
    munged_name = db_name.replace('.', '_')
1529
    out_dir = os.path.join(db_dir, munged_name)
1530
    if not os.path.exists(out_dir):
1531
        return out_dir
1532
    suffix = 1
1533
    while suffix <= 10:
1534
        out_dir = os.path.join(db_dir, "{}_{}".format(munged_name, suffix))
1535
        if not os.path.exists(out_dir):
1536
            return out_dir
1537
        suffix += 1
1538
    raise SystemError(
1539
        "Unreasonable number of output directories for {}".format(db_path)
1540
    )
1541
1542
1543
def _load_db(sqlite_path):
1544
    _LOGGER.info("Processing %s", sqlite_path)
1545
1546
    load_heuristics()
1547
1548
    db = SQLite_DB(sqlite_path)
1549
    _LOGGER.info("Database: %r", db)
1550
1551
    db.populate_freelist_pages()
1552
    db.populate_ptrmap_pages()
1553
    db.populate_overflow_pages()
1554
1555
    # Should we aim to instantiate specialised b-tree objects here, or is the
1556
    # use of generic btree page objects acceptable?
1557
    db.populate_btree_pages()
1558
1559
    db.map_tables()
1560
1561
    # We need a first pass to process table leaf pages that are disconnected
1562
    # from their table's root page
1563
    db.reparent_orphaned_table_leaf_pages()
1564
1565
    # All pages should now be represented by specialised objects
1566
    assert(all(isinstance(p, Page) for p in db.pages.values()))
1567
    assert(not any(type(p) is Page for p in db.pages.values()))
1568
    return db
1569
1570
1571
def dump_to_csv(args):
1572
    out_dir = args.output_dir or gen_output_dir(args.sqlite_path)
1573
    db = _load_db(args.sqlite_path)
1574
1575
    if os.path.exists(out_dir):
1576
        raise ValueError("Output directory {} exists!".format(out_dir))
1577
    os.mkdir(out_dir)
1578
1579
    for table_name in sorted(db.tables):
1580
        table = db.tables[table_name]
1581
        _LOGGER.info("Table \"%s\"", table)
1582
        table.recover_records()
1583
        table.csv_dump(out_dir)
1584
1585
1586
def undelete(args):
1587
    db_abspath = os.path.abspath(args.sqlite_path)
1588
    db = _load_db(db_abspath)
1589
1590
    output_path = os.path.abspath(args.output_path)
1591
    if os.path.exists(output_path):
1592
        raise ValueError("Output file {} exists!".format(output_path))
1593
1594
    shutil.copyfile(db_abspath, output_path)
1595
    with sqlite3.connect(output_path) as output_db_connection:
1596
        cursor = output_db_connection.cursor()
1597
        for table_name in sorted(db.tables):
1598
            table = db.tables[table_name]
1599
            _LOGGER.info("Table \"%s\"", table)
1600
            table.recover_records()
1601
1602
            failed_inserts = 0
1603
            constraint_violations = 0
1604
            successful_inserts = 0
1605
            for leaf_page in table.leaves:
1606
                if not leaf_page.recovered_records:
1607
                    continue
1608
1609
                for record in leaf_page.recovered_records:
1610
                    insert_statement, values = table.build_insert_SQL(record)
1611
1612
                    try:
1613
                        cursor.execute(insert_statement, values)
1614
                    except sqlite3.IntegrityError:
1615
                        # We gotta soldier on, there's not much we can do if a
1616
                        # constraint is violated by this insert
1617
                        constraint_violations += 1
1618
                    except (
1619
                                sqlite3.ProgrammingError,
1620
                                sqlite3.OperationalError,
1621
                                sqlite3.InterfaceError
1622
                            ) as insert_ex:
1623
                        _LOGGER.warning(
1624
                            (
1625
                                "Caught %r while executing INSERT statement "
1626
                                "in \"%s\""
1627
                            ),
1628
                            insert_ex,
1629
                            table
1630
                        )
1631
                        failed_inserts += 1
1632
                        # pdb.set_trace()
1633
                    else:
1634
                        successful_inserts += 1
1635
            if failed_inserts > 0:
1636
                _LOGGER.warning(
1637
                    "%d failed INSERT statements in \"%s\"",
1638
                    failed_inserts, table
1639
                )
1640
            if constraint_violations > 0:
1641
                _LOGGER.warning(
1642
                    "%d constraint violations statements in \"%s\"",
1643
                    constraint_violations, table
1644
                )
1645
            _LOGGER.info(
1646
                "%d successful INSERT statements in \"%s\"",
1647
                successful_inserts, table
1648
            )
1649
1650
1651
def find_in_db(args):
1652
    db = _load_db(args.sqlite_path)
1653
    db.grep(args.needle)
1654
1655
1656
subcmd_actions = {
1657
    'csv':  dump_to_csv,
1658
    'grep': find_in_db,
1659
    'undelete': undelete,
1660
}
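# Example invocations (illustrative only; "<prog>" stands for whichever
# console-script name the package installs for this module):
#   <prog> csv /path/to/db.sqlite [output_dir]
#   <prog> grep /path/to/db.sqlite needle
#   <prog> undelete /path/to/db.sqlite recovered.sqlite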
1661
1662
1663
def subcmd_dispatcher(arg_ns):
1664
    return subcmd_actions[arg_ns.subcmd](arg_ns)
1665
1666
1667
def main():
1668
1669
    verbose_parser = argparse.ArgumentParser(add_help=False)
1670
    verbose_parser.add_argument(
1671
        '-v', '--verbose',
1672
        action='count',
1673
        help='Give *A LOT* more output.',
1674
    )
1675
1676
    cli_parser = argparse.ArgumentParser(
1677
        description=PROJECT_DESCRIPTION,
1678
        parents=[verbose_parser],
1679
    )
1680
1681
    subcmd_parsers = cli_parser.add_subparsers(
1682
        title='Subcommands',
1683
        description='%(prog)s implements the following subcommands:',
1684
        dest='subcmd',
1685
    )
1686
1687
    csv_parser = subcmd_parsers.add_parser(
1688
        'csv',
1689
        parents=[verbose_parser],
1690
        help='Dumps visible and recovered records to CSV files',
1691
        description=(
1692
            'Recovers as many records as possible from the database passed as '
1693
            'argument and outputs all visible and recovered records to CSV '
1694
            'files in output_dir'
1695
        ),
1696
    )
1697
    csv_parser.add_argument(
1698
        'sqlite_path',
1699
        help='sqlite3 file path'
1700
    )
1701
    csv_parser.add_argument(
1702
        'output_dir',
1703
        nargs='?',
1704
        default=None,
1705
        help='Output directory'
1706
    )
1707
1708
    grep_parser = subcmd_parsers.add_parser(
1709
        'grep',
1710
        parents=[verbose_parser],
1711
        help='Matches a string in one or more pages of the database',
1712
        description='Matches a string in one or more pages of the database',
1713
    )
1714
    grep_parser.add_argument(
1715
        'sqlite_path',
1716
        help='sqlite3 file path'
1717
    )
1718
    grep_parser.add_argument(
1719
        'needle',
1720
        help='String to match in the database'
1721
    )
1722
1723
    undelete_parser = subcmd_parsers.add_parser(
1724
        'undelete',
1725
        parents=[verbose_parser],
1726
        help='Inserts recovered records into a copy of the database',
1727
        description=(
1728
            'Recovers as many records as possible from the database passed as '
1729
            'argument and inserts all recovered records into a copy of '
1730
            'the database.'
1731
        ),
1732
    )
1733
    undelete_parser.add_argument(
1734
        'sqlite_path',
1735
        help='sqlite3 file path'
1736
    )
1737
    undelete_parser.add_argument(
1738
        'output_path',
1739
        help='Output database path'
1740
    )
1741
1742
    cli_args = cli_parser.parse_args()
1743
    if cli_args.verbose:
1744
        _LOGGER.setLevel(logging.DEBUG)
1745
1746
    if cli_args.subcmd:
1747
        subcmd_dispatcher(cli_args)
1748
    else:
1749
        # No subcommand specified, print the usage and bail
1750
        cli_parser.print_help()
1751