e2edutch.bert.load_bert() (rated A)

Complexity
  Conditions: 4

Size
  Total Lines: 16
  Code Lines: 15

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0

Metric  Value
eloc    15
dl      0
loc     16
rs      9.65
c       0
b       0
f       0
cc      4
nop     1
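These counters appear to come from a radon-style static analysis; below is a minimal sketch of reproducing the size and complexity figures. The radon field names are an assumption and do not map one-to-one onto the abbreviations above.

    from radon.complexity import cc_visit
    from radon.raw import analyze

    with open('e2edutch/bert.py') as f:   # path assumed from the module name
        source = f.read()

    raw = analyze(source)                  # raw size counters
    print('loc:', raw.loc, 'sloc:', raw.sloc)
    for block in cc_visit(source):         # per-function cyclomatic complexity
        print(block.name, 'cc =', block.complexity)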
"""Utilities for loading Dutch BERT models and encoding sentences
into word-level embeddings."""
import numpy as np
import torch
from transformers import (BertConfig, BertForPreTraining, BertModel,
                          BertTokenizer, RobertaModel, RobertaTokenizer)


def load_bert(model_name):
    """Return (tokenizer, model) for a supported Dutch BERT variant.

    model_name must be one of 'robbert', 'bertje' or 'bert-nl'.
    """
    if model_name == 'robbert':
        tokenizer = RobertaTokenizer.from_pretrained("pdelobelle/robBERT-base")
        model = RobertaModel.from_pretrained("pdelobelle/robBERT-base")
    elif model_name == 'bertje':
        tokenizer = BertTokenizer.from_pretrained(
            "wietsedv/bert-base-dutch-cased")
        model = BertModel.from_pretrained("wietsedv/bert-base-dutch-cased")
    elif model_name == 'bert-nl':
        # Local checkpoint: tokenizer files and config under data/bert-nl
        tokenizer = BertTokenizer.from_pretrained("data/bert-nl")
        config = BertConfig.from_json_file("data/bert-nl/config.json")
        model = BertForPreTraining(config).bert
    else:
        raise ValueError('Invalid model name: {}'.format(model_name))
    # Inference only: disable dropout
    model.eval()
    return tokenizer, model

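For reference, a minimal usage sketch (assuming the Hugging Face weights can be downloaded; for 'bert-nl' the local data/bert-nl checkpoint must exist):

    tokenizer, model = load_bert('bertje')
    print(tokenizer.tokenize('Dit is een zin.'))  # word-piece subtokens
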
def encode_sentences(sentences, tokenizer, model):
    """Encode word-tokenized sentences into per-word BERT embeddings.

    sentences is a list of sentences, each a list of word strings.
    Returns a numpy array of shape
    (nr_sentences, max_sentence_length, hidden_size, 1).
    """
    # Split every word into word-piece subtokens with the BERT tokenizer
    sentences_tokenized = [
        [tokenizer.tokenize(word) for word in sentence]
        for sentence in sentences]
    # Flatten the subtokens of each sentence into one sequence
    sentences_tokenized_flat = [
        [tok for word in sentence for tok in word]
        for sentence in sentences_tokenized]
    # For every subtoken, record the index of the word it came from
    indices_flat = [
        [i for i, word in enumerate(sentence) for tok in word]
        for sentence in sentences_tokenized]
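    # Illustration (the word pieces here are hypothetical; real splits
    # depend on the tokenizer): ['Dit', 'wandelen'] gives
    #   sentences_tokenized       [['Dit'], ['wande', '##len']]
    #   sentences_tokenized_flat  ['Dit', 'wande', '##len']
    #   indices_flat              [0, 1, 1]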

    max_nrtokens = max(len(s) for s in sentences_tokenized_flat)
    # Zero-pad the id matrix so all sentences share the same length
    indexed_tokens = np.zeros((len(sentences), max_nrtokens), dtype=int)
    for i, sent in enumerate(sentences_tokenized_flat):
        idx = tokenizer.convert_tokens_to_ids(sent)
        indexed_tokens[i, :len(idx)] = np.array(idx)

    # Convert inputs to a PyTorch tensor
    tokens_tensor = torch.tensor(indexed_tokens)
    with torch.no_grad():
        # Last hidden state: torch tensor of shape
        # (nr_sentences, sequence_length, hidden_size=768).
        # Index [0] works for both tuple and ModelOutput return types.
        bert_output = model(tokens_tensor)[0]

    # Sum the vectors of subtokens coming from the same word
    max_sentence_length = max(len(s) for s in sentences)
    bert_final = torch.zeros(bert_output.shape[0],
                             max_sentence_length,
                             bert_output.shape[2],
                             dtype=bert_output.dtype)
    for sent_id in range(len(sentences)):
        for tok_id, word_id in enumerate(indices_flat[sent_id]):
            bert_final[sent_id, word_id, :] += bert_output[sent_id, tok_id, :]
    bert_final = bert_final.numpy()
    # Add a trailing singleton axis
    bert_final = np.expand_dims(bert_final, axis=3)
    return bert_final
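
A minimal end-to-end sketch of the two functions together (assuming the 'bertje' weights can be downloaded; the sentences and the printed shape are illustrative):

    tokenizer, model = load_bert('bertje')
    sentences = [['Dit', 'is', 'een', 'zin', '.'],
                 ['Nog', 'een', 'zin', '.']]
    embeddings = encode_sentences(sentences, tokenizer, model)
    # One summed BERT vector per word plus a trailing singleton axis
    print(embeddings.shape)  # (2, 5, 768, 1)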