
BERT for Coreference Resolution: Coding the Model #pytorch #korean

Dataset

 

A total of 124 documents, each on a different topic

 

Sentences, mention clusters, and related fields are stored as JSON

 

Train and validation data were split randomly at an 8:2 ratio
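A sketch of one document's JSON, with field names inferred from the preprocessing code in section 2 (other fields may exist; values are illustrative):

{
    "sentence": [ {"text": "..."}, ... ],
    "entity": [
        { "mention": [ {"sent_id": 2, "start_eid": 110, "end_eid": 112},
                       {"sent_id": 2, "start_eid": 124, "end_eid": 124} ] },
        ...
    ]
}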

 

 

 

 

Mention cluster format
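Each cluster is a list of [sent_id, start_eid, end_eid] spans that refer to the same entity, mirroring the example in the code comments below:

[[2, 110, 112], [2, 124, 124]]   # in sentence 2, the word span 110-112 and the
                                 # single-word span at 124 form one cluster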

Versions used

 

torch==1.2.0

transformers==3.0.2

numpy==1.18.1
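The code below also imports easydict, scipy, and pandas, so one possible environment setup is:

pip install torch==1.2.0 transformers==3.0.2 numpy==1.18.1 easydict scipy pandas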

 

 

Code

 

 

1. Config

 

#library
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import torch.nn.functional as F
import json
import os
from collections import OrderedDict
from easydict import EasyDict
import math
import re
from pathlib import Path
import random
import tqdm
from collections import Counter
from scipy.optimize import linear_sum_assignment
import time
import _pickle
import copy
import pprint
from torch.utils.data import BatchSampler, SequentialSampler, RandomSampler
from transformers import AutoTokenizer, AutoModel, AdamW
 
#config
config = EasyDict({
    "embedding_dim": 768,
    "max_span_width": 30,
    # max training sentences depends on the size of GPU memory
    "max_training_sentences": 10,
    # max seq length
    "max_seq_length": 128,
    "bert_max_seq_length": 512,

    "device": "cuda",
    "checkpoint_path": "./Coreference_Resolution_dataset/checkpoint/checkpoint",
    "lr": 0.0002,
    "weight_decay": 0.0005,
    "dropout": 0.3,

    "report_frequency": 200,
    "eval_frequency": 200,

    # dir
    "root_dir": "./Coreference_Resolution_dataset",
    "train_file_path": "./Coreference_Resolution_dataset/preprocessing/train.json",
    "test_file_path": "./Coreference_Resolution_dataset/preprocessing/test.json",
    "val_file_path": "./Coreference_Resolution_dataset/preprocessing/val.json",

    # max candidate mention size in the first/second stage
    "top_span_ratio": 0.4,
    "max_top_antecedents": 50,
    # use coarse-to-fine pruning
    "coarse_to_fine": True,
    # higher-order coref depth
    "coref_depth": 2,

    # FFNN config
    "ffnn_depth": 1,
    "ffnn_size": 3000,

    # use span features, such as distance
    "use_features": True,
    "feature_dim": 20,
    "model_heads": True,
    "use_metadata": False,  # my dataset has no genre or speaker info, so this is set to False
    "genres": ["bc", "bn", "mz", "nw", "tc", "wb", "dummy"],
    "extract_spans": True,
    "transformer_model_name": 'bert-base-multilingual-cased',
    "transformer_lr": 0.00001,
})

c = config
tokenizer = AutoTokenizer.from_pretrained(c["transformer_model_name"], do_lower_case=False)
root_path = Path(c["root_dir"])

# make results reproducible
random_seed = 333
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
# torch.cuda.manual_seed_all(random_seed)  # if using multi-GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)
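Since bert-base-multilingual-cased splits Korean words into subword pieces, a quick check like the one below (illustrative; the exact pieces depend on the vocabulary) shows why the preprocessing in section 2 keeps a subtoken_map from subwords back to the original words:

print(tokenizer.tokenize("한국어"))   # e.g. ['한국', '##어'] (actual pieces may differ)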

 

 

2. Loading and Preprocessing Data

 

def mydata_target_cluster(json_data):  # takes one document's JSON and returns its sentences plus the clusters that have antecedents
    clusters = []
    raw_text = []
    for entities in json_data['entity']:
        cluster = []
        for entity in entities['mention']:
            sent_id = entity['sent_id']
            cluster.append([sent_id, entity['start_eid'], entity['end_eid']])
        if len(cluster) > 1:  # skip singleton clusters, which have no antecedent
            clusters.append(cluster)  # span ids in one list form one cluster, e.g. [[2,110,112], [2,124,124]] means the 110-112 span and the 124-124 span of sentence 2 corefer
    
    for text in json_data['sentence']:
        raw_text.append(text['text'])
        
    return raw_text, clusters
 
 
def formatting_data(raw_text, clusters, doc_key):
    tokenizer = AutoTokenizer.from_pretrained(config["transformer_model_name"])  # use the multilingual BERT tokenizer
    formatted_coref_chains = list()
    line_corefs = dict()
    
    
    sentence_words = list()
    for sentence in raw_text: 
        words = sentence.split(" ")  # split the sentence on whitespace
        sentence_words.append(words)
    
    for cluster in clusters:
        formatted_coref_chain = list()
        for coref in cluster:
            line_num = coref[0]
            line_start, line_end = coref[1], coref[2]
            formatted_coref = [line_num, int(line_start), int(line_end) + 1]
            formatted_coref_chain.append(formatted_coref)
            if line_num not in line_corefs:
                line_corefs[line_num] = list()
            line_corefs[line_num].append(formatted_coref) 
 
        formatted_coref_chains.append(formatted_coref_chain) 
 
    sentences = list()
    subtoken_map = list()
    sentence_map = list()
    speaker_ids = list()
    subtoken_index = 0
 
    for line_num, words in enumerate(sentence_words):  # iterate over sentences (as word lists) together with their sent_id
        tokens = list()
        sentence_subtoken_map = list()
 
        for word in words:  # iterate over the whitespace-separated words

            tokenized_word = tokenizer.tokenize(word)  # tokenize the word into subwords
            subtoken_index += 1


            for _ in range(len(tokenized_word)):
                sentence_subtoken_map.append(subtoken_index)  # subtokens of the same word share one index

            # shift coref span boundaries from word indices to subtoken indices
            if line_num in line_corefs:  # {'line_num': [line_num, line_start, line_end+1], ...}
                for coref in line_corefs[line_num]:
                    if coref[1] > len(tokens):
                        coref[1] += len(tokenized_word) - 1
                        coref[2] += len(tokenized_word) - 1

                    elif coref[2] > len(tokens):
                        coref[2] += len(tokenized_word) - 1
 
            tokens.extend(tokenized_word)
 
        sentences.append(tokens)
        subtoken_map.append(sentence_subtoken_map)
        sentence_map.append([line_num] * len(tokens))
        speaker_list = ['_'] * 1000  # my dataset has no speakers, so use dummies
        speaker_ids.append([speaker_list[line_num]] * len(tokens))  # speakers are filled with '_'
 
    output_sentences = list()
    output_subtoken_map = list()
    output_sentence_map = list()
    output_speaker_ids = list()
    tmp_sentences = list()
    tmp_subtoken_map = list()
    tmp_sentence_map = list()
    tmp_speaker_ids = list()
    bias = 0            
    line_num = 0
 
    while line_num < len(sentences):
        sentence = sentences[line_num]
        sentence_token_num = len(sentence)
        if len(tmp_sentences) > 0 and sentence_token_num + len(tmp_sentences) > config.max_seq_length:
 
            output_sentences.append(tmp_sentences)
            output_subtoken_map.append(tmp_subtoken_map)
            output_sentence_map.append(tmp_sentence_map)
            output_speaker_ids.append(tmp_speaker_ids)
            tmp_sentences = list()
            tmp_subtoken_map = list()
            tmp_sentence_map = list()
            tmp_speaker_ids = list()
        else:
 
            tmp_sentences.extend(sentence)
            tmp_subtoken_map.extend(subtoken_map[line_num])
            tmp_sentence_map.extend(sentence_map[line_num])
            tmp_speaker_ids.extend(speaker_ids[line_num])
 
            if line_num in line_corefs:
                for coref in line_corefs[line_num]:
                    coref[1] += bias
                    coref[2] += bias
            bias += sentence_token_num
            line_num += 1
 
    if len(tmp_sentences) > 0:
        output_sentences.append(tmp_sentences)
        output_subtoken_map.append(tmp_subtoken_map)
        output_sentence_map.append(tmp_sentence_map)
        output_speaker_ids.append(tmp_speaker_ids)
 
    output_clusters = list()
    for formatted_coref_chain in formatted_coref_chains:
        if len(formatted_coref_chain) > 1:
            cluster = list()
            for coref in formatted_coref_chain:
                if coref[1] != coref[2]:
                    cluster.append([coref[1], coref[2] - 1])
            if len(cluster) > 1:
                output_clusters.append(cluster)
 
    example = {"sentences": output_sentences, "clusters": output_clusters, "speaker_ids": output_speaker_ids, "sentence_map": output_sentence_map, "subtoken_map": output_subtoken_map, 'doc_key': doc_key, 'genre': c['genres'][6]}  # genre is a dummy value
    return example
 
def add_bias(src, bias):
    if type(src) is list:
        res = list()
        for element in src:
            new_element = add_bias(element, bias)
            res.append(new_element)
        return res
    else:
        return src + bias
    
def final_formatting_data(example):
    final_example = {"sentences": list(), "clusters": list(), "speaker_ids": list(), "sentence_map": list(), "subtoken_map": list(), "genre": '', "doc_key": ''}
    clusters_bias = 0

    final_example["sentences"].extend(example["sentences"])
    final_example["clusters"].extend(add_bias(example["clusters"], clusters_bias))
    final_example["speaker_ids"].extend(example["speaker_ids"])
    final_example["sentence_map"].extend(example["sentence_map"])
    final_example["subtoken_map"].extend(example["subtoken_map"])
    final_example['genre'] = example['genre']  # all genres get the same dummy value
    final_example['doc_key'] = example['doc_key']
    clusters_bias += sum([len(s) for s in example["sentences"]])


    return final_example
 
 
 
 
 
 
train_filenames = os.listdir('./Coreference_Resolution_dataset/train')
val_filenames = os.listdir('./Coreference_Resolution_dataset/val')
test_filenames = os.listdir('./Coreference_Resolution_dataset/test')
 
train_list = []
val_list = []
test_list = []
 
 
for train_filename in train_filenames:
    if train_filename != 'train.json':
        try:
            with open('./Coreference_Resolution_dataset/train/' + train_filename) as file:
                json_data_train = json.load(file)
            doc_key = train_filename.split()[1][:7]
            raw_text, clusters = mydata_target_cluster(json_data_train)
            example = formatting_data(raw_text, clusters, doc_key)
            train_list.append(example)
        except Exception:
            print(train_filename)
            continue

train_file_fd = open(c["train_file_path"], "w", encoding="utf-8")

examples_json_train = ''
for example in train_list:
    examples_json_train += json.dumps(example, ensure_ascii=False) + "\n"

train_file_fd.write(examples_json_train)
    
      
    
 
for val_filename in val_filenames:
    if val_filename != 'val.json':
        try:
            with open('./Coreference_Resolution_dataset/val/' + val_filename) as file:
                json_data_val = json.load(file)
            doc_key = val_filename.split()[1][:7]
            raw_text, clusters = mydata_target_cluster(json_data_val)
            example = formatting_data(raw_text, clusters, doc_key)
            val_list.append(example)
        except Exception:
            print(val_filename)
            continue

val_file_fd = open(c["val_file_path"], "w", encoding="utf-8")

examples_json_val = ''
for example in val_list:
    examples_json_val += json.dumps(example, ensure_ascii=False) + "\n"

val_file_fd.write(examples_json_val)
 
 
 
train_file_fd.close()
val_file_fd.close()
 
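For reference, each line of train.json holds one document shaped like the sketch below (values are illustrative; see the example dict built in formatting_data):

{"sentences": [["나", "##는", ...], ...],          # subtoken segments, each under max_seq_length
 "clusters": [[[3, 5], [12, 12]], ...],            # token-level [start, end] spans per cluster
 "speaker_ids": [["_", "_", ...], ...],
 "sentence_map": [[0, 0, 1, ...], ...],            # which original sentence each token came from
 "subtoken_map": [[1, 1, 2, ...], ...],            # which original word each token came from
 "doc_key": "doc0001", "genre": "dummy"}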

 

 

 

3. Tools

 

def print_spans_loc(spans_start, spans_end):
 
    assert len(spans_start) == len(spans_end)
    for i in range(len(spans_start)):
        try:
            print((spans_start[i].item(), spans_end[i].item()))
        except:
            print((spans_start[i], spans_end[i]))
 
 
 
def truncate_example(sentences_ids, sentences_masks, sentences_valid_masks, clusters, speaker_ids, sentence_map, subtoken_map, max_training_sentences):
 
    line_offset = random.randint(0, len(sentences_ids) - max_training_sentences)
    truncated_sentences_ids = sentences_ids[line_offset:line_offset + max_training_sentences]
    truncated_sentences_masks = sentences_masks[line_offset:line_offset + max_training_sentences]
    truncated_sentences_valid_masks = sentences_valid_masks[line_offset:line_offset + max_training_sentences]
    truncated_speaker_ids = speaker_ids[line_offset:line_offset + max_training_sentences]
    truncated_sentence_map = sentence_map[line_offset:line_offset + max_training_sentences]
    truncated_subtoken_map = subtoken_map[line_offset:line_offset + max_training_sentences]
 
    token_offset = torch.sum(sentences_valid_masks[:line_offset]).item()
    token_num = torch.sum(truncated_sentences_valid_masks).item()
 
    truncated_clusters = list()
    for cluster in clusters:
        truncated_cluster = list()
        for start_loc, end_loc in cluster:
            if start_loc - token_offset >= 0 and end_loc <= token_offset + token_num:
                truncated_cluster.append([start_loc-token_offset, end_loc-token_offset])
        if len(truncated_cluster) > 0:
            truncated_clusters.append(truncated_cluster)
    return truncated_sentences_ids, truncated_sentences_masks, truncated_sentences_valid_masks, truncated_clusters, truncated_speaker_ids, truncated_sentence_map, truncated_subtoken_map
 
 
def add_bias_to_clusters(clusters, bias, coref_filter):
    for cluster in clusters:
        for coref in cluster:
            if coref_filter(coref):
                coref[0] += bias
                coref[1] += bias
 
 
def tokenize_example(example, tokenizer, c):
    """ tokenize example
    """
 
    sentences = example["sentences"]
    clusters = example["clusters"]
    speaker_ids = example["speaker_ids"]
    sentence_map = example["sentence_map"]
    subtoken_map = example["subtoken_map"]
    example_genre_name = example["genre"]
 
    # genre to idx
    genre = 0
    while genre < len(c["genres"]):
        if example_genre_name == c["genres"][genre]:
            break
        genre += 1
    
    # token to ids
    sentences_ids = list()
    sentences_masks = list()
 
    max_seq_len = max([len(s) for s in sentences]) + 2      #[CLS] [SEP]
 
    for sentence in sentences:
        sentence_ids = tokenizer.convert_tokens_to_ids(sentence)
        sentence_ids = [101] + sentence_ids + [102]      # [CLS] ... [SEP]
        token_num = len(sentence_ids)
        sentence_masks = [1] * token_num
        if token_num > c["bert_max_seq_length"]:
            raise Exception("the length of sentence is out of the range of bert_max_seq_length.")
        else:
            sentence_ids += [0] * (max_seq_len - token_num)
            sentence_masks += [0] * (max_seq_len - token_num)
        sentences_ids.append(sentence_ids)
        sentences_masks.append(sentence_masks)
 
    sentences_ids = torch.LongTensor(sentences_ids)
    sentences_masks = torch.LongTensor(sentences_masks)
 
    # convert speaker_ids to long type
    speaker_dict = dict()
    speaker_index = 0
    speaker_ids_long = list()
    for sentence_speaker_ids in speaker_ids:
        sentence_speaker_ids_long = list()
        for speaker in sentence_speaker_ids:
            if speaker not in speaker_dict:
                speaker_dict[speaker] = speaker_index
                speaker_index += 1
            speaker_id = speaker_dict[speaker]
            sentence_speaker_ids_long.append(speaker_id)
        speaker_ids_long.append(sentence_speaker_ids_long)
    speaker_ids = speaker_ids_long
 
    sentence_tokens_num = torch.sum(sentences_masks, dim=1)
 
    sentences_valid_masks = sentences_masks.clone()
    sentences_valid_masks[:, 0] = 0      # mask out [CLS]

    for i in range(len(sentences_valid_masks)):
        sentences_valid_masks[i][sentence_tokens_num[i] - 1] = 0      # mask out [SEP]
 
    
    for i in range(len(sentences)):
        if not (len(sentences[i]) == len(speaker_ids[i]) == len(sentence_map[i]) == len(subtoken_map[i])):
            raise Exception("The length of sentence/speaker_ids/sentence_map/subtoken_map is inconsistent.")
 
 
    return sentences_ids, sentences_masks, sentences_valid_masks, clusters, speaker_ids, sentence_map, subtoken_map, genre
 
 
 
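A minimal usage sketch for tokenize_example, assuming the config, tokenizer, and the train_list built in section 2 are in scope:

example = train_list[0]
(sentences_ids, sentences_masks, sentences_valid_masks,
 clusters, speaker_ids, sentence_map, subtoken_map, genre) = tokenize_example(example, tokenizer, config)
print(sentences_ids.shape)    # (num_segments, longest_segment + 2); +2 for [CLS]/[SEP]
print(genre)                  # 6, the index of "dummy" in config["genres"]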

 

 

 

4. Model

 

class CorefModel(torch.nn.Module):
    def __init__(self, config):
        super(CorefModel, self).__init__()
        self.config = config
        self.span_dim = 2 * self.config["embedding_dim"]
        mm_input_dim = 0
 
        if self.config["use_features"]:
            self.span_dim += self.config["feature_dim"]
            self.span_width_embeddings = torch.nn.Embedding(self.config["max_span_width"], self.config["feature_dim"])
            self.bucket_distance_embeddings = torch.nn.Embedding(10, self.config["feature_dim"])
            mm_input_dim += self.config["feature_dim"]
 
        if self.config["model_heads"]:
            self.span_dim += self.config["embedding_dim"]
            self.Sh = torch.nn.Linear(self.config["embedding_dim"], 1)      # token head score
 
        if self.config["use_metadata"]:
            self.genre_embeddings = torch.nn.Embedding(len(self.config["genres"]) + 1, self.config["feature_dim"])
            self.same_speaker_emb = torch.nn.Embedding(2, self.config["feature_dim"])
            mm_input_dim += 2 * self.config["feature_dim"]
 
        mm_input_dim += 3 * self.span_dim
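        # with the config above: span_dim = 2*768 (start/end) + 20 (width) + 768 (head) = 2324,
        # and mm_input_dim = 20 (distance) + 3 * 2324 = 6992 (use_metadata is False)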
 
 
        self.Sm = self._create_score_ffnn(self.span_dim)     # mention score        
        self.Smm = self._create_score_ffnn(mm_input_dim)      # pairwise score between spans
        self.c2fP = torch.nn.Linear(self.span_dim, self.span_dim)       # coarse to fine pruning span projection
        self.hoP = torch.nn.Linear(2 * self.span_dim, self.span_dim)    # high order projection
 
 
 
    def _create_ffnn(self, input_size, output_size, ffnn_size, ffnn_depth, dropout=0):  # builds the FFNN used for the mention score and the pairwise score
 
        current_size = input_size
        model_seq = OrderedDict()
        for i in range(ffnn_depth):
            model_seq['fc' + str(i)] = torch.nn.Linear(current_size, ffnn_size)
            model_seq['relu' + str(i)] = torch.nn.ReLU()
            model_seq['dropout' + str(i)] = torch.nn.Dropout(dropout)
            current_size = ffnn_size
        model_seq['output'] = torch.nn.Linear(current_size, output_size)
 
        return torch.nn.Sequential(model_seq)
 
 
    def _create_score_ffnn(self, input_size):
 
        return self._create_ffnn(input_size, 1, self.config["ffnn_size"], self.config["ffnn_depth"], self.config["dropout"])
 
    
    def bucket_distance(self, distances):
        """
        Places the given values (designed for distances) into 10 semi-logscale buckets:
        [0, 1, 2, 3, 4, 5-7, 8-15, 16-31, 32-63, 64+].
        """
        float_distances = distances.float()
        combined_idx = torch.floor(torch.log(float_distances) / math.log(2)) + 3
        use_identity = distances <= 4
        combined_idx[use_identity] = float_distances[use_identity]
        combined_idx = combined_idx.long()
 
        return torch.clamp(combined_idx, 0, 9)
 
 
    def get_span_embed(self, tokens_embed, spans_start, spans_end):
        span_embed_list = list()
        start_embed = tokens_embed[spans_start]        
        end_embed = tokens_embed[spans_end]            
 
        span_embed_list.append(start_embed)
        span_embed_list.append(end_embed)
 
        if self.config["use_features"]:
            spans_width = (spans_end - spans_start).to(device=torch.device(self.config["device"]))
            span_width_embed = self.span_width_embeddings(spans_width)
            span_width_embed = torch.nn.functional.dropout(span_width_embed, p=self.config["dropout"], training=self.training)
            span_embed_list.append(span_width_embed)
 
        if self.config["model_heads"]:
            tokens_score = self.Sh(tokens_embed).view(-1)           # size: num_tokens
            tokens_locs = torch.arange(start=0, end=len(tokens_embed), dtype=torch.long).repeat(len(spans_start), 1)        # size: num_spans * num_tokens
            tokens_masks = (tokens_locs >= spans_start.view(-1, 1)) & (tokens_locs <= spans_end.view(-1, 1))             # size: num_spans * num_tokens
            tokens_weights = torch.nn.functional.softmax(
                (tokens_score + torch.log(tokens_masks.float()).to(device=torch.device(self.config["device"]))), 
                dim=1
            )
            
            span_head_emb = torch.matmul(tokens_weights, tokens_embed)
            span_embed_list.append(span_head_emb)
 
        return torch.cat(span_embed_list, dim=1)
 
 
    def extract_spans(self, spans_score, spans_start, spans_end, m):
        """ top m spans
        """
        sorted_spans_score, indices = torch.sort(spans_score, 0, descending=True)
        top_m_spans_index = list()
        top_m_spans_start = torch.zeros(m, dtype=torch.long)
        top_m_spans_end = torch.zeros(m, dtype=torch.long)
        top_m_len = 0
        i = 0
        while top_m_len < m and i < len(sorted_spans_score):
            span_index = indices[i]
            span_start = spans_start[span_index]
            span_end = spans_end[span_index]
 
            res = (((span_start < top_m_spans_start) & (span_end < top_m_spans_end) & (span_end >= top_m_spans_start)) 
                | ((span_start > top_m_spans_start) & (span_start <= top_m_spans_end) & (span_end > top_m_spans_end)))
 
            if torch.sum(res) == 0:
                top_m_spans_index.append(span_index)
                top_m_spans_start[top_m_len] = span_start
                top_m_spans_end[top_m_len] = span_end
                top_m_len += 1
            i += 1
 
        return torch.stack(top_m_spans_index)
 
 
    def coarse_to_fine_pruning(self, k, spans_masks, spans_embed, spans_score):
        """
 
        parameters:
            k: int
            spans_masks: m * m
            spans_embed: m * span_dim 
            spans_score: m * 1
 
        return
            score: FloatTensor m * k
            index: Long m * k
        """
 
        m = len(spans_embed)
        all_score = torch.zeros(m, m).to(device=torch.device(self.config["device"]))
        all_score[~spans_masks] = float("-inf")
 
        # add span score
        all_score += spans_score
 
        antecedents_offset = (torch.arange(0, m).view(-1, 1) - torch.arange(0, m).view(1, -1))
        antecedents_offset = antecedents_offset.to(device=torch.device(self.config["device"]))
 
        if self.config["coarse_to_fine"== True:
            source_top_span_emb = torch.nn.functional.dropout(self.c2fP(spans_embed), p=self.config["dropout"], training=self.training)
            target_top_span_emb = torch.nn.functional.dropout(spans_embed, p=self.config["dropout"], training=self.training)
            all_score += source_top_span_emb.matmul(target_top_span_emb.t())     # m * m
        else:
            k = m
 
        top_antecedents_fast_score, top_antecedents_index = torch.topk(all_score, k)
        top_antecedents_offset = torch.gather(antecedents_offset, dim=1, index=top_antecedents_index)
 
        return top_antecedents_fast_score, top_antecedents_index, top_antecedents_offset
 
 
    def get_spans_similarity_score(self, top_antecedents_index, spans_embed, top_antecedents_offset, speaker_ids, genre_emb):
        ""
 
        parameters:
            top_antecedents_index: Long m * k
            spans_embed: m * span_dim
            top_antecedents_offset: m * k
            speaker_ids: m
            genre_emb: feature_dim
        return:
            score: FloatTensor m * k
        """
        m = len(spans_embed)
        k = top_antecedents_index.shape[1]
 
        span_index = torch.arange(0, m, dtype=torch.long).repeat(k, 1).t()
 
        mm_ffnn_input_list = list()
        mm_ffnn_input_list.append(spans_embed[span_index])
        mm_ffnn_input_list.append(spans_embed[top_antecedents_index])
        mm_ffnn_input_list.append(mm_ffnn_input_list[0] * mm_ffnn_input_list[1])
 
        if self.config["use_features"]:
            top_antecedents_distance_bucket = self.bucket_distance(top_antecedents_offset)
            top_antecedents_distance_emb = self.bucket_distance_embeddings(top_antecedents_distance_bucket)
            mm_ffnn_input_list.append(top_antecedents_distance_emb)
 
        if self.config["use_metadata"]:
            same_speaker_ids = (speaker_ids.view(-1, 1) == speaker_ids[top_antecedents_index]).long().to(device=torch.device(self.config["device"]))
            speaker_emb = self.same_speaker_emb(same_speaker_ids)
            mm_ffnn_input_list.append(speaker_emb)
            mm_ffnn_input_list.append(genre_emb.repeat(m, k, 1))
 
        mm_ffnn_input = torch.cat(mm_ffnn_input_list, dim=2)
        mm_slow_score = self.Smm(mm_ffnn_input)
 
        return mm_slow_score.squeeze()
 
 
    def forward(self, sentences_ids, sentences_masks, sentences_valid_masks, speaker_ids, sentence_map, subtoken_map, genre, transformer_model):
        ""
        parameters:
            sentences_ids: num_sentence * max_sentence_len
            sentences_masks: num_sentence * max_sentence_len
            sentences_valid_masks: num_sentence * max_sentence_len
            speaker_ids: list[list]
            sentence_map: list[list]
            subtoken_map: list[list]
            genre: genre_id
            transformer_model: AutoModel
        """
 
 
        sentences_embed, _ = transformer_model(sentences_ids.to(device=torch.device(self.config["device"])), sentences_masks.to(device=torch.device(self.config["device"])))      # num_sentence * max_sentence_len * embed_dim
 
        tokens_embed = sentences_embed[sentences_valid_masks.bool()]          # num_tokens * embed_dim
 
        flattened_sentence_indices = list()
        for sm in sentence_map:
            flattened_sentence_indices += sm
        flattened_sentence_indices = torch.LongTensor(flattened_sentence_indices)
 
        candidate_spans_start = torch.arange(0, len(tokens_embed)).repeat(self.config["max_span_width"], 1).t()
        candidate_spans_end = candidate_spans_start + torch.arange(0, self.config["max_span_width"])
        candidate_spans_start_sentence_indices = flattened_sentence_indices[candidate_spans_start]
        candidate_spans_end_sentence_indices = flattened_sentence_indices[torch.min(candidate_spans_end, torch.tensor(len(tokens_embed) - 1))]
        candidate_spans_mask = (candidate_spans_end < len(tokens_embed)) & (candidate_spans_start_sentence_indices == candidate_spans_end_sentence_indices)
 
        spans_start = candidate_spans_start[candidate_spans_mask]
        spans_end = candidate_spans_end[candidate_spans_mask]
        
        spans_embed = self.get_span_embed(tokens_embed, spans_start, spans_end)       # size: num_spans * span_dim
        spans_score = self.Sm(spans_embed)          # size: num_spans * 1
 
        spans_len = len(spans_score)
        spans_score_mask = torch.zeros(spans_len, dtype=torch.bool)
        m = min(int(self.config["top_span_ratio"] * len(tokens_embed)), spans_len)
 
        if self.config["extract_spans"]:
            span_indices = self.extract_spans(spans_score, spans_start, spans_end, m)
        else:
            _, span_indices = torch.topk(spans_score, m, dim=0, largest=True)
 
        m = len(span_indices)
        spans_score_mask[span_indices] = True
 
        top_m_spans_embed = spans_embed[spans_score_mask]        # size: m * span_dim
        top_m_spans_score = spans_score[spans_score_mask]        # size: m * 1
        top_m_spans_start = spans_start[spans_score_mask]        # size: m
        top_m_spans_end = spans_end[spans_score_mask]            # size: m
 
        # BoolTensor, size: m * m
        top_m_spans_masks = \
            ((top_m_spans_start.repeat(m, 1).t() > top_m_spans_start) \
            | ((top_m_spans_start.repeat(m, 1).t() == top_m_spans_start) \
                & (top_m_spans_end.repeat(m, 1).t() > top_m_spans_end)))
 
        if self.config['use_metadata']:
            flattened_speaker_ids = list()
            for si in speaker_ids:
                flattened_speaker_ids += si
            flattened_speaker_ids = torch.LongTensor(flattened_speaker_ids)
            top_m_spans_speaker_ids = flattened_speaker_ids[top_m_spans_start]
            genre_emb = self.genre_embeddings(torch.LongTensor([genre]).to(device=torch.device(self.config["device"]))).squeeze()
        else:
            top_m_spans_speaker_ids = None
            genre_emb = None
 
 
        k = min(self.config["max_top_antecedents"], m)
 
        top_antecedents_fast_score, top_antecedents_index, top_antecedents_offset = self.coarse_to_fine_pruning(
            k, top_m_spans_masks, top_m_spans_embed, top_m_spans_score
        )
 
        for i in range(self.config["coref_depth"]):
            top_antecedents_slow_score = self.get_spans_similarity_score(top_antecedents_index, top_m_spans_embed, top_antecedents_offset, top_m_spans_speaker_ids, genre_emb)
            top_antecedents_score = top_antecedents_fast_score + top_antecedents_slow_score         # size: m * k
            dummy_score = torch.zeros(m, 1).to(device=torch.device(self.config["device"]))          # add dummy
            top_antecedents_score = torch.cat((top_antecedents_score, dummy_score), dim=1)          # size: m * (k+1)
            top_antecedents_weight = torch.nn.functional.softmax(top_antecedents_score, dim=1)      # size: m * (k+1)
            top_antecedents_emb = torch.cat((top_m_spans_embed[top_antecedents_index], top_m_spans_embed.unsqueeze(1)), dim=1)     # size: m * (k+1) * embed
            attended_spans_emb = torch.sum(top_antecedents_weight.unsqueeze(2) * top_antecedents_emb, dim=1)            # size: m * embed
            f = torch.sigmoid(self.hoP(torch.cat([top_m_spans_embed, attended_spans_emb], dim=1)))  # size: m * embed
            top_m_spans_embed = f * attended_spans_emb + (1 - f) * top_m_spans_embed                # size: m * embed
 
        return top_antecedents_score, top_antecedents_index, top_m_spans_masks, top_m_spans_start, top_m_spans_end
    
 
    def evaluate(self, data, transformer_model):
 
        coref_evaluator = CorefEvaluator()
 
        with torch.no_grad():
            for idx, data_i in enumerate(data):
                sentences_ids, sentences_masks, sentences_valid_masks, gold_clusters, speaker_ids, sentence_map, subtoken_map, genre = data_i
                
                if len(sentences_ids) > c["max_training_sentences"]:
                    sentences_ids, sentences_masks, sentences_valid_masks, gold_clusters, speaker_ids, sentence_map, subtoken_map = truncate_example(sentences_ids, sentences_masks, sentences_valid_masks, gold_clusters, speaker_ids, sentence_map, subtoken_map, c["max_training_sentences"])
               
                
                top_antecedents_score, top_antecedents_index, top_m_spans_masks, top_m_spans_start, top_m_spans_end = self.forward(sentences_ids, sentences_masks, sentences_valid_masks, speaker_ids, sentence_map, subtoken_map, genre, transformer_model)
                predicted_antecedents = self.get_predicted_antecedents(top_antecedents_index, top_antecedents_score)
                top_m_spans = list()
                for i in range(len(top_m_spans_start)):
                    top_m_spans.append(tuple([top_m_spans_start[i].item(), top_m_spans_end[i].item()]))
 
                # all spans
                gold_clusters = [tuple(tuple([m[0], m[1]]) for m in gc) for gc in gold_clusters]
                mention_to_gold = {}
                for gc in gold_clusters:
                    for mention in gc:
                        mention_to_gold[tuple(mention)] = gc
                predicted_clusters, mention_to_predicted = self.get_predicted_clusters(top_m_spans, predicted_antecedents)
                coref_evaluator.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold)
 
        return coref_evaluator.get_prf()
 
 
    def get_predicted_antecedents(self, top_antecedents_index, top_antecedents_score):
 
        predicted_antecedents = []
        for i, index in enumerate(torch.argmax(top_antecedents_score, axis=1)):
            if index == len(top_antecedents_score[i]) - 1:
 
                predicted_antecedents.append(-1)
            else:
                predicted_antecedents.append(top_antecedents_index[i][index])
        return predicted_antecedents
 
    def get_predicted_clusters(self, top_m_spans, predicted_antecedents):
 
        idx_to_clusters = {}
        predicted_clusters = []
        for i in range(len(top_m_spans)):
            idx_to_clusters[i] = set([i])
 
        for i, predicted_index in enumerate(predicted_antecedents):
            if predicted_index < 0:
                continue
            else:
                union_cluster = idx_to_clusters[predicted_index.item()] | idx_to_clusters[i]
                for j in union_cluster:
                    idx_to_clusters[j] = union_cluster
        
        tagged_index = set()
        for i in idx_to_clusters:
            if (len(idx_to_clusters[i]) == 1) or (i in tagged_index):
                continue
            cluster = idx_to_clusters[i]
            predicted_cluster = list()
            for j in cluster:
                tagged_index.add(j)
                predicted_cluster.append(tuple(top_m_spans[j]))
 
            predicted_clusters.append(tuple(predicted_cluster))
 
        mention_to_predicted = {}
        for pc in predicted_clusters:
            for mention in pc:
                mention_to_predicted[tuple(mention)] = pc
 
        return predicted_clusters, mention_to_predicted
 

 

 

5. Metric

 

 

class Evaluator(object):
    def __init__(self, metric, beta=1):
        self.p_num = 0
        self.p_den = 0
        self.r_num = 0
        self.r_den = 0
        self.metric = metric
        self.beta = beta
 
    def update(self, predicted, gold, mention_to_predicted, mention_to_gold):
        if self.metric == ceafe:
            pn, pd, rn, rd = self.metric(predicted, gold)
        else:
            pn, pd = self.metric(predicted, mention_to_gold)
            rn, rd = self.metric(gold, mention_to_predicted)
        self.p_num += pn
        self.p_den += pd
        self.r_num += rn
        self.r_den += rd
 
    def get_f1(self):
        return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta)
 
    def get_recall(self):
        return 0 if self.r_num == 0 else self.r_num / float(self.r_den)
 
    def get_precision(self):
        return 0 if self.p_num == 0 else self.p_num / float(self.p_den)
 
    def get_prf(self):
        return self.get_precision(), self.get_recall(), self.get_f1()
 
    def get_counts(self):
        return self.p_num, self.p_den, self.r_num, self.r_den
 
 
def evaluate_documents(documents, metric, beta=1):
    evaluator = Evaluator(metric, beta=beta)
    for document in documents:
        # each document is expected as a (predicted, gold, mention_to_predicted, mention_to_gold) tuple
        evaluator.update(*document)
    return evaluator.get_precision(), evaluator.get_recall(), evaluator.get_f1()
 
 
def b_cubed(clusters, mention_to_gold):
    num, dem = 0, 0
 
    for c in clusters:
        if len(c) == 1:
            continue
 
        gold_counts = Counter()
        correct = 0
        for m in c:
            if m in mention_to_gold:
                gold_counts[tuple(mention_to_gold[m])] += 1
        for c2, count in gold_counts.items():
            if len(c2) != 1:
                correct += count * count
 
        num += correct / float(len(c))
        dem += len(c)
 
    return num, dem
 
 
def muc(clusters, mention_to_gold):
    tp, p = 0, 0
    for c in clusters:
        p += len(c) - 1
        tp += len(c)
        linked = set()
        for m in c:
            if m in mention_to_gold:
                linked.add(mention_to_gold[m])
            else:
                tp -= 1
        tp -= len(linked)
    return tp, p
 
 
def phi4(c1, c2):
    return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2))
 
 
def ceafe(clusters, gold_clusters):
    clusters = [c for c in clusters if len(c) != 1]
    scores = np.zeros((len(gold_clusters), len(clusters)))
    for i in range(len(gold_clusters)):
        for j in range(len(clusters)):
            scores[i, j] = phi4(gold_clusters[i], clusters[j])
    matching = np.transpose(np.asarray(linear_sum_assignment(-scores)))
 
    similarity = sum(scores[matching[:, 0], matching[:, 1]])
    return similarity, len(clusters), similarity, len(gold_clusters)
 
 
def lea(clusters, mention_to_gold):
    num, dem = 0, 0
 
    for c in clusters:
        if len(c) == 1:
            continue
 
        common_links = 0
        all_links = len(c) * (len(c) - 1) / 2.0
        for i, m in enumerate(c):
            if m in mention_to_gold:
                for m2 in c[i + 1:]:
                    if m2 in mention_to_gold and mention_to_gold[m] == mention_to_gold[m2]:
                        common_links += 1
 
        num += len(c) * common_links / float(all_links)
        dem += len(c)
 
    return num, dem


def f1(p_num, p_den, r_num, r_den, beta=1):
    p = 0 if p_den == 0 else p_num / float(p_den)
    r = 0 if r_den == 0 else r_num / float(r_den)
    return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r)
 
 
 
class CorefEvaluator(object):
    def __init__(self):
        self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)]
 
    def update(self, predicted, gold, mention_to_predicted, mention_to_gold):
        for e in self.evaluators:
            e.update(predicted, gold, mention_to_predicted, mention_to_gold)
 
    def get_f1(self):
        return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators)
 
    def get_recall(self):
        return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators)
 
    def get_precision(self):
        return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators)
 
    def get_prf(self):
        return self.get_precision(), self.get_recall(), self.get_f1()
 
 
 
 
 
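A quick sanity check of the MUC counts (illustrative): a predicted cluster whose mentions all belong to one gold cluster earns full link credit.

gold = [((0, 1), (5, 6))]                          # one gold cluster with two mentions
mention_to_gold = {m: gold[0] for m in gold[0]}
print(muc(gold, mention_to_gold))                  # (1, 1): 1 correct link out of 1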

 

 

 

6. Train

 

torch.backends.cudnn.benchmark = True
torch.cuda.empty_cache()
 
 
 
def find_clusters(span_loc, index_clusters):
 
    if span_loc in index_clusters:
        return index_clusters[span_loc]
    return -1
 
 
def train(config):
 
    c = config
    tokenizer = AutoTokenizer.from_pretrained(c["transformer_model_name"])
    transformer_model = AutoModel.from_pretrained(c["transformer_model_name"]).to(c["device"])
    transformer_optim = AdamW(transformer_model.parameters(), lr=c["transformer_lr"])
    transformer_model.train()
 
    print("config:")
    pprint.pprint(c)
 
    print("preparing data...")
 
    
 
 
    train_data = list()
    with open(c["train_file_path"], "r", encoding="utf-8"as fd:
        for line in fd:
            item = json.loads(line.strip())
            train_data.append(item)
    val_data = list()
    with open(c["val_file_path"], "r", encoding="utf-8"as fd:
        for line in fd:
            item = json.loads(line.strip())
            val_data.append(item)
 
 
 
    tokenized_train_data = list()
    for data_i in train_data:
        if len(data_i["sentences"]) == 0:
            print("Warning: `sentences` in %s is empty." % data_i["doc_key"])
        else:
            tokenized_train_data.append((tokenize_example(data_i, tokenizer, c)))
 
    tokenized_val_data = list()
    for data_i in val_data:
        if len(data_i["sentences"]) == 0:
            print("Warning: `sentences` in %s is empty." % data_i["doc_key"])
        else:
            tokenized_val_data.append((tokenize_example(data_i, tokenizer, c)))
 
 
    coref_model = CorefModel(c).to(c["device"])
    optimizer = torch.optim.Adam(coref_model.parameters(), lr=c["lr"], weight_decay=c["weight_decay"])
    accumulated_loss = 0.0
    max_f1 = 0.0
 
    print("start training...")
 
    init_time = time.time()
    steps = 0
    while True:
        for idx in RandomSampler(tokenized_train_data):      # visit documents in random order
            steps += 1
            sentences_ids, sentences_masks, sentences_valid_masks, clusters, speaker_ids, sentence_map, subtoken_map, genre = tokenized_train_data[idx]
            if len(sentences_ids) > c["max_training_sentences"]:
                sentences_ids, sentences_masks, sentences_valid_masks, clusters, speaker_ids, sentence_map, subtoken_map = truncate_example(sentences_ids, sentences_masks, sentences_valid_masks, clusters, speaker_ids, sentence_map, subtoken_map, c["max_training_sentences"])
 
            top_antecedents_score, top_antecedents_index, top_m_spans_masks, top_m_spans_start, top_m_spans_end = coref_model(sentences_ids, sentences_masks, sentences_valid_masks, speaker_ids, sentence_map, subtoken_map, genre, transformer_model)
 
            num_spans = len(top_m_spans_start)
            index_clusters = dict()
            for i, cluster in enumerate(clusters):
                for loc in cluster:
                    index_clusters[tuple(loc)] = i
            top_m_spans_cluster_idx = list()        # size: m
            for i in range(num_spans):
                cluster_idx = find_clusters((top_m_spans_start[i].item(), top_m_spans_end[i].item()), index_clusters)
                top_m_spans_cluster_idx.append(cluster_idx)
            top_m_spans_cluster_idx = torch.LongTensor(top_m_spans_cluster_idx).to(device=torch.device(c["device"]))
            top_m_spans_in_gold_clusters = (top_m_spans_cluster_idx != -1)
            # gold labels, size: m * k
            top_antecedents_label = (top_m_spans_cluster_idx[top_antecedents_index].t() == top_m_spans_cluster_idx).t()
            top_antecedents_label[~top_m_spans_in_gold_clusters] = False
            top_antecedents_label[~torch.gather(top_m_spans_masks.to(device=torch.device(c["device"])), 1, top_antecedents_index)] = False 
            # gold labels with dummy label, size: m * (k+1)
            top_antecedents_label = torch.cat((top_antecedents_label, ~torch.sum(top_antecedents_label, dim=1).bool().view(-1,1)), dim=1)
 
 
            gold_scores = top_antecedents_score + torch.log(top_antecedents_label.float())     # size: m * (k+1)
            marginalized_gold_scores = torch.logsumexp(gold_scores, 1)                  # size: m
            log_norm = torch.logsumexp(top_antecedents_score, 1)                          # size: m
 
            loss = torch.sum(log_norm - marginalized_gold_scores)
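            # this is the marginal log-likelihood objective over all gold antecedents:
            # loss = -log( sum_{gold} exp(score) / sum_{all} exp(score) )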
 
            optimizer.zero_grad()
            transformer_optim.zero_grad()
            loss.backward()
            optimizer.step()
            transformer_optim.step()
 
            accumulated_loss += loss.item()
 
            if steps % c["report_frequency"== 0:
                total_time = time.time() - init_time
                print("[%d] loss=%.7f, steps/s=%.4f" % (steps, accumulated_loss / c["report_frequency"], steps / total_time))
                accumulated_loss = 0.0
 
            if steps % c["eval_frequency"== 0:
               
                coref_model.eval()
                transformer_model.eval()
                torch.save({"model": coref_model.state_dict(), "optimizer": optimizer.state_dict(), "steps": steps}, c["checkpoint_path"] + "." + str(steps))
                transformer_model.save_pretrained(c["checkpoint_path"] + ".transformer." + str(steps))
                try:
                    p, r, f = coref_model.evaluate(tokenized_val_data, transformer_model)
                    if f >= max_f1:
                        max_f1 = f
                        torch.save({"model": coref_model.state_dict(), "optimizer": optimizer.state_dict(), "steps": steps}, c["checkpoint_path"] + ".max")
                        transformer_model.save_pretrained(c["checkpoint_path"] + ".transformer.max")
                    print("evaluation result:\np:%.4f,r:%.4f,f:%.4f(max f:%.4f)" % (p, r, f, max_f1))
                except Exception as e:
                    print("Error: evaluation error:", e)
                coref_model.train()
                transformer_model.train()
                torch.cuda.empty_cache()
                
 
 
 
train(config)

 

 

 

7. Prediction

 

 

c = config
 
print("load model...")
coref_model = CorefModel(c).eval().to(c["device"])
tokenizer = AutoTokenizer.from_pretrained(c["transformer_model_name"])
 
transformer_model = AutoModel.from_pretrained(c["checkpoint_path"] + ".transformer.max").eval().to(c["device"])
checkpoint = torch.load(c["checkpoint_path"] + ".max", map_location=c["device"])
coref_model.load_state_dict(checkpoint["model"])
print("success!")
 
 
while True:
 
    sentence = input("input: ")
    res = tokenizer(sentence)
    sentence_ids, sentence_masks = res["input_ids"], res["attention_mask"]
    sentence_ids = torch.LongTensor([sentence_ids])
    sentence_masks = torch.LongTensor([sentence_masks])
 
 
    # mask out the [CLS] and [SEP] positions
    sentence_valid_masks = sentence_masks.clone()
    sentence_valid_masks[0][0] = 0
    sentence_valid_masks[0][-1] = 0

    valid_num = torch.sum(sentence_valid_masks).item()
    speaker_ids = [[0] * valid_num]
    sentence_map = [[0] * valid_num]
    subtoken_map = [list(range(valid_num))]
 
    top_antecedents_score, top_antecedents_index, top_m_units_masks, top_m_units_start, top_m_units_end = coref_model(sentence_ids, sentence_masks, sentence_valid_masks, speaker_ids, sentence_map, subtoken_map, len(c["genres"]), transformer_model)
    predicted_antecedents = coref_model.get_predicted_antecedents(top_antecedents_index, top_antecedents_score)
    top_m_units = list()
    for i in range(len(top_m_units_start)):
        top_m_units.append([top_m_units_start[i], top_m_units_end[i]])
    predicted_clusters, _ = coref_model.get_predicted_clusters(top_m_units, predicted_antecedents)
 
    print("============================")
    print("tokenized context:")
    tokens = tokenizer.convert_ids_to_tokens(sentence_ids[sentence_valid_masks.bool()])
    print(tokens)
    print("predicted clusters:", predicted_clusters)
 
 
mentions = []
for coref_clusters in predicted_clusters:
    temp_token=[]
    for cluster in coref_clusters:
        temp_token.append(tokens[cluster[0].numpy():cluster[1].numpy()+1])                          
    mentions.append(temp_token)
 
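Note on the output shape: each element of predicted_clusters is a tuple of (start, end) token-index pairs into tokens, so a hypothetical result like [((1, 2), (9, 9))] would mean subword tokens 1-2 and token 9 are predicted to corefer; mentions maps those indices back to the subword strings.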

 

 

 

 

Brief training results

 

Precision, recall, and F1 are the averages over the metrics (MUC, B3, CEAF).

(results table image)

Code reference

 

github.com/cheniison/e2e-coref-pytorch (BERT for End-to-end Neural Coreference Resolution in PyTorch)