데이터 셋
총 124개의 서로 다른 주제의 documents
문장들, mention cluster 등이 json 형태로 구성
train data와 validation data의 비율은 8:2로 임의로 배분
이용 버전
torch==1.2.0
transformers==3.0.2
numpy==1.18.1
코드
1 . Config
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
|
#library
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import torch.nn.functional as F
import json
import os
from collections import OrderedDict
from easydict import EasyDict
import math
import re
from pathlib import Path
import random
import tqdm
from collections import Counter
from scipy.optimize import linear_sum_assignment
import time
import _pickle
import copy
import pprint
from torch.utils.data import BatchSampler, SequentialSampler, RandomSampler
from transformers import AutoTokenizer, AutoModel, AdamW
# config — every hyper-parameter and path for the coreference model in one EasyDict.
config = EasyDict({
    "embedding_dim": 768,  # hidden size of the transformer encoder output
    "max_span_width": 30,  # longest candidate mention, in subtokens
    # max training sentences depends on size of memory
    "max_training_sentences": 10,
    # max seq length
    "max_seq_length": 128,
    "bert_max_seq_length": 512,
    "device": "cuda",
    "checkpoint_path": "./Coreference_Resolution_dataset/checkpoint/checkpoint",
    "lr": 0.0002,
    "weight_decay": 0.0005,
    "dropout": 0.3,
    "report_frequency": 200,
    "eval_frequency": 200,
    # dir
    "root_dir": "./Coreference_Resolution_dataset",
    "train_file_path": "./Coreference_Resolution_dataset/preprocessing/train.json",
    "test_file_path": "./Coreference_Resolution_dataset/preprocessing/test.json",
    "val_file_path": "./Coreference_Resolution_dataset/preprocessing/val.json",
    # max candidate mentions size in first/second stage
    "top_span_ratio": 0.4,
    "max_top_antecedents": 50,
    # use coarse to fine pruning
    "coarse_to_fine": True,
    # high order coref depth
    "coref_depth": 2,
    # FFNN config
    "ffnn_depth": 1,
    "ffnn_size": 3000,
    # use span features, such as distance
    "use_features": True,
    "feature_dim": 20,
    "model_heads": True,
    "use_metadata": False,  # my dataset has no genre or speaker metadata, so this stays False
    "genres": ["bc", "bn", "mz", "nw", "tc", "wb","dummy"],
    "extract_spans": True,
    "transformer_model_name": 'bert-base-multilingual-cased',
    "transformer_lr": 0.00001,
})
c = config
tokenizer = AutoTokenizer.from_pretrained(c["transformer_model_name"], do_lower_case=False)
root_path = Path(c["root_dir"])
# make every run reproducible
random_seed = 333
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
#torch.cuda.manual_seed_all(random_seed) # if use multi-GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)
|
cs |
2. Loading and Preprocessing Data
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
|
def mydata_target_cluster(json_data):
    """Extract sentences and coreference clusters from one document's JSON.

    Args:
        json_data: dict with an 'entity' list (each entry holding a 'mention'
            list of {'sent_id', 'start_eid', 'end_eid'}) and a 'sentence' list
            of {'text'}.

    Returns:
        (raw_text, clusters) where raw_text is the list of sentence strings and
        clusters is a list of mention groups, each mention being
        [sent_id, start_eid, end_eid]. Groups with a single mention (no
        antecedent) are dropped.
    """
    raw_text = [sentence['text'] for sentence in json_data['sentence']]
    clusters = []
    for entity_group in json_data['entity']:
        mentions = [
            [mention['sent_id'], mention['start_eid'], mention['end_eid']]
            for mention in entity_group['mention']
        ]
        # keep only clusters that actually link two or more mentions
        if len(mentions) > 1:
            clusters.append(mentions)
    return raw_text, clusters
def formatting_data(raw_text, clusters, doc_key):
    """Convert one document's sentences + word-level clusters into a
    subtoken-level training example.

    Three passes:
      1. record each mention as [sent_id, start, end+1] (end becomes exclusive)
         and index mentions per sentence,
      2. tokenize word by word, shifting mention boundaries from word indices
         to subtoken indices in place,
      3. pack consecutive sentences into segments of at most
         config.max_seq_length subtokens and shift mention indices to
         document-global subtoken positions.

    NOTE(review): indentation was reconstructed from a whitespace-stripped
    paste; on segment overflow the current sentence is assumed to be
    re-processed in the next loop iteration — confirm against the original.
    """
    # NOTE(review): re-loads the tokenizer on every call; the module-level
    # `tokenizer` could be reused instead.
    tokenizer = AutoTokenizer.from_pretrained(config["transformer_model_name"])
    formatted_coref_chains = list()
    line_corefs = dict()          # sent_id -> mention records for that sentence (aliased, mutated in place below)
    sentence_words = list()
    for sentence in raw_text:
        words = sentence.split(" ")  # split each sentence on single spaces
        sentence_words.append(words)
    for cluster in clusters:
        formatted_coref_chain = list()
        for coref in cluster:
            line_num = coref[0]
            line_start, line_end = coref[1], coref[2]
            # end index becomes exclusive here; restored to inclusive at the end
            formatted_coref = [line_num, int(line_start), int(line_end) + 1]
            formatted_coref_chain.append(formatted_coref)
            if line_num not in line_corefs:
                line_corefs[line_num] = list()
            line_corefs[line_num].append(formatted_coref)
        formatted_coref_chains.append(formatted_coref_chain)
    sentences = list()
    subtoken_map = list()
    sentence_map = list()
    speaker_ids = list()
    subtoken_index = 0
    for line_num, words in enumerate(sentence_words):  # per sentence: id + its whitespace-split words
        tokens = list()
        sentence_subtoken_map = list()
        for word in words:  # tokenize word by word so word->subtoken alignment is known
            tokenized_word = tokenizer.tokenize(word)
            subtoken_index += 1
            for _ in range(len(tokenized_word)):
                # all subtokens of one word share the same (document-wide) word index
                sentence_subtoken_map.append(subtoken_index)
            if line_num in line_corefs:  # shift mention boundaries by the extra subtokens this word adds
                for coref in line_corefs[line_num]:
                    if coref[1] > len(tokens):
                        # mention starts after the current position: both ends shift
                        coref[1] += len(tokenized_word) - 1
                        coref[2] += len(tokenized_word) - 1
                    elif coref[2] > len(tokens):
                        # mention spans the current word: only the end shifts
                        coref[2] += len(tokenized_word) - 1
            tokens.extend(tokenized_word)
        sentences.append(tokens)
        subtoken_map.append(sentence_subtoken_map)
        sentence_map.append([line_num] * len(tokens))
        # the dataset has no speaker information, so a dummy speaker is used
        speaker_list = ['_']*1000
        speaker_ids.append([speaker_list[line_num]] * len(tokens))
    # pack sentences into segments of at most config.max_seq_length subtokens
    output_sentences = list()
    output_subtoken_map = list()
    output_sentence_map = list()
    output_speaker_ids = list()
    tmp_sentences = list()
    tmp_subtoken_map = list()
    tmp_sentence_map = list()
    tmp_speaker_ids = list()
    bias = 0        # subtokens emitted so far: turns per-sentence indices into global ones
    line_num = 0
    while line_num < len(sentences):
        sentence = sentences[line_num]
        sentence_token_num = len(sentence)
        if len(tmp_sentences) > 0 and sentence_token_num + len(tmp_sentences) > config.max_seq_length:
            # segment full: flush it; the current sentence is handled next iteration
            output_sentences.append(tmp_sentences)
            output_subtoken_map.append(tmp_subtoken_map)
            output_sentence_map.append(tmp_sentence_map)
            output_speaker_ids.append(tmp_speaker_ids)
            tmp_sentences = list()
            tmp_subtoken_map = list()
            tmp_sentence_map = list()
            tmp_speaker_ids = list()
        else:
            tmp_sentences.extend(sentence)
            tmp_subtoken_map.extend(subtoken_map[line_num])
            tmp_sentence_map.extend(sentence_map[line_num])
            tmp_speaker_ids.extend(speaker_ids[line_num])
            if line_num in line_corefs:
                # shift this sentence's mentions to document-global subtoken indices
                for coref in line_corefs[line_num]:
                    coref[1] += bias
                    coref[2] += bias
            bias += sentence_token_num
            line_num += 1
    if len(tmp_sentences) > 0:  # flush the final partial segment
        output_sentences.append(tmp_sentences)
        output_subtoken_map.append(tmp_subtoken_map)
        output_sentence_map.append(tmp_sentence_map)
        output_speaker_ids.append(tmp_speaker_ids)
    # drop empty mentions, restore inclusive end indices, keep only real clusters
    output_clusters = list()
    for formatted_coref_chain in formatted_coref_chains:
        if len(formatted_coref_chain) > 1:
            cluster = list()
            for coref in formatted_coref_chain:
                if coref[1] != coref[2]:
                    cluster.append([coref[1], coref[2] - 1])
            if len(cluster) > 1:
                output_clusters.append(cluster)
    # genre is a dummy value ("dummy") since the dataset has none
    example = {"sentences": output_sentences, "clusters": output_clusters, "speaker_ids": output_speaker_ids, "sentence_map": output_sentence_map, "subtoken_map": output_subtoken_map, 'doc_key':doc_key, 'genre':c['genres'][6]}
    return example
def add_bias(src, bias):
    """Recursively add `bias` to every scalar in an arbitrarily nested list.

    Args:
        src: a number, or a (possibly nested) list of numbers.
        bias: value added to every scalar.

    Returns:
        A new nested list mirroring `src` (lists are copied, never mutated),
        or `src + bias` when `src` is a scalar.
    """
    # isinstance instead of `type(src) is list` — same behavior for the
    # lists this project passes, and the idiomatic type check.
    if isinstance(src, list):
        return [add_bias(element, bias) for element in src]
    return src + bias
def final_formatting_data(example):
    """Re-assemble one formatted example into a fresh dict.

    The cluster lists go through add_bias with bias 0, which deep-copies the
    nested lists (so the result does not alias `example`'s clusters); every
    other field is copied at the top level only.

    Args:
        example: dict produced by formatting_data.

    Returns:
        A new example dict with the same keys and contents.
    """
    final_example = {"sentences": list(), "clusters": list(), "speaker_ids": list(), "sentence_map": list(), "subtoken_map": list(), "genre": '', "doc_key": ''}
    final_example["sentences"].extend(example["sentences"])
    # bias is always 0 for a single document; add_bias is kept for its copy semantics
    final_example["clusters"].extend(add_bias(example["clusters"], 0))
    final_example["speaker_ids"].extend(example["speaker_ids"])
    final_example["sentence_map"].extend(example["sentence_map"])
    final_example["subtoken_map"].extend(example["subtoken_map"])
    final_example['genre'] = example['genre']  # always the same dummy genre
    final_example['doc_key'] = example['doc_key']
    # NOTE: the original accumulated a clusters_bias total right before
    # returning; the value was never read, so that dead code is removed.
    return final_example
# Build the line-delimited JSON train/val files from the per-document inputs.
train_filenames = os.listdir('./Coreference_Resolution_dataset/train')
val_filenames = os.listdir('./Coreference_Resolution_dataset/val')
test_filenames = os.listdir('./Coreference_Resolution_dataset/test')

def _convert_documents(dir_path, filenames, skip_name):
    """Parse every per-document JSON under dir_path into a formatted example.

    Files that fail to parse are reported by name and skipped, matching the
    original best-effort behavior.
    """
    examples = []
    for filename in filenames:
        if filename == skip_name:  # skip a previously generated output file
            continue
        try:
            with open(os.path.join(dir_path, filename)) as file:
                doc_json = json.load(file)
            # BUG FIX: derive doc_key from the CURRENT file — the original used
            # filenames[0] so every document received the same key.
            doc_key = filename.split()[1][:7]
            raw_text, clusters = mydata_target_cluster(doc_json)
            examples.append(formatting_data(raw_text, clusters, doc_key))
        except Exception:  # narrowed from a bare except; keep best-effort skip
            print(filename)
    return examples

def _write_examples(path, examples):
    """Write examples as one JSON object per line (UTF-8, Korean unescaped)."""
    with open(path, "w", encoding="utf-8") as fd:
        fd.write("".join(json.dumps(example, ensure_ascii=False) + "\n" for example in examples))

train_list = _convert_documents('./Coreference_Resolution_dataset/train', train_filenames, 'train.json')
val_list = _convert_documents('./Coreference_Resolution_dataset/val', val_filenames, 'val.json')
test_list = []

_write_examples(c["train_file_path"], train_list)
_write_examples(c["val_file_path"], val_list)
|
cs |
3. Tools
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
|
def print_spans_loc(spans_start, spans_end):
    """Debug helper: print each (start, end) span pair on its own line.

    Accepts parallel sequences of tensors (printed via .item()) or plain
    numbers (printed as-is).
    """
    assert len(spans_start) == len(spans_end)
    for start, end in zip(spans_start, spans_end):
        # Narrowed from a bare except: only a missing .item() (plain number
        # instead of a 0-d tensor) falls through to the raw print.
        try:
            print((start.item(), end.item()))
        except AttributeError:
            print((start, end))
def truncate_example(sentences_ids, sentences_masks, sentences_valid_masks, clusters, speaker_ids, sentence_map, subtoken_map, max_training_sentences):
    """Randomly keep a contiguous window of max_training_sentences segments
    and remap gold clusters into that window.

    Args:
        sentences_ids / sentences_masks / sentences_valid_masks: LongTensors,
            one row per segment.
        clusters: list of clusters; each mention is [start, end] with INCLUSIVE
            end, indexed over the valid (non-[CLS]/[SEP]) subtokens.
        speaker_ids / sentence_map / subtoken_map: per-segment lists.
        max_training_sentences: window size in segments.

    Returns:
        The seven truncated inputs in the same order, with cluster indices
        shifted to the window's local coordinates. Mentions falling outside
        the window are dropped; clusters that keep at least one mention stay.
    """
    line_offset = random.randint(0, len(sentences_ids) - max_training_sentences)
    window_end = line_offset + max_training_sentences
    truncated_sentences_ids = sentences_ids[line_offset:window_end]
    truncated_sentences_masks = sentences_masks[line_offset:window_end]
    truncated_sentences_valid_masks = sentences_valid_masks[line_offset:window_end]
    truncated_speaker_ids = speaker_ids[line_offset:window_end]
    truncated_sentence_map = sentence_map[line_offset:window_end]
    truncated_subtoken_map = subtoken_map[line_offset:window_end]
    # valid-token counts give the window's offset and length in subtoken space
    token_offset = torch.sum(sentences_valid_masks[:line_offset]).item()
    token_num = torch.sum(truncated_sentences_valid_masks).item()
    truncated_clusters = list()
    for cluster in clusters:
        truncated_cluster = [
            [start_loc - token_offset, end_loc - token_offset]
            for start_loc, end_loc in cluster
            # BUG FIX: end is inclusive, so the last valid index is
            # token_offset + token_num - 1; the original `<=` admitted an
            # index one past the window.
            if start_loc >= token_offset and end_loc < token_offset + token_num
        ]
        if len(truncated_cluster) > 0:
            truncated_clusters.append(truncated_cluster)
    return truncated_sentences_ids, truncated_sentences_masks, truncated_sentences_valid_masks, truncated_clusters, truncated_speaker_ids, truncated_sentence_map, truncated_subtoken_map
def add_bias_to_clusters(clusters, bias, coref_filter):
    """Shift, in place, both endpoints of every mention that passes the filter.

    Args:
        clusters: list of clusters; each mention is a mutable [start, end] list.
        bias: value added to both endpoints.
        coref_filter: predicate taking a mention and returning truthy when the
            mention should be shifted.
    """
    for cluster in clusters:
        for coref in cluster:
            if coref_filter(coref):  # `== True` comparison dropped; truthiness is sufficient
                coref[0] += bias
                coref[1] += bias
def tokenize_example(example, tokenizer, c):
    """Turn one formatted example into model-ready tensors.

    Adds [CLS]/[SEP] (ids 101/102), pads every segment to the longest one,
    maps the genre name and speaker strings to integer ids, and builds a
    valid-token mask that zeroes out the [CLS]/[SEP] positions.

    Raises:
        Exception: when a segment exceeds c["bert_max_seq_length"], or when a
            segment's parallel lists disagree in length.
    """
    sentences = example["sentences"]
    clusters = example["clusters"]
    speaker_ids = example["speaker_ids"]
    sentence_map = example["sentence_map"]
    subtoken_map = example["subtoken_map"]
    # genre name -> index; unknown names map to len(genres)
    try:
        genre = c["genres"].index(example["genre"])
    except ValueError:
        genre = len(c["genres"])
    # pad to the longest segment plus [CLS] and [SEP]
    padded_len = max(len(s) for s in sentences) + 2
    id_rows = list()
    mask_rows = list()
    for sentence in sentences:
        row = [101] + tokenizer.convert_tokens_to_ids(sentence) + [102]
        row_len = len(row)
        if row_len > c["bert_max_seq_length"]:
            raise Exception("the length of sentence is out the range of bert_max_seq_length.")
        else:
            id_rows.append(row + [0] * (padded_len - row_len))
            mask_rows.append([1] * row_len + [0] * (padded_len - row_len))
    sentences_ids = torch.LongTensor(id_rows)
    sentences_masks = torch.LongTensor(mask_rows)
    # speaker strings -> integer ids in order of first appearance
    speaker_dict = dict()
    speaker_ids_long = list()
    for sentence_speakers in speaker_ids:
        converted = [speaker_dict.setdefault(speaker, len(speaker_dict)) for speaker in sentence_speakers]
        speaker_ids_long.append(converted)
    speaker_ids = speaker_ids_long
    # valid mask: like the attention mask but with [CLS] and [SEP] zeroed
    sentence_tokens_num = torch.sum(sentences_masks, dim=1)
    sentences_valid_masks = sentences_masks.clone()
    sentences_valid_masks[:, 0] = 0
    for i in range(len(sentences_valid_masks)):
        sentences_valid_masks[i][sentence_tokens_num[i] - 1] = 0
    # sanity check: all per-segment annotation lists must stay parallel
    for i in range(len(sentences)):
        if not (len(sentences[i]) == len(speaker_ids[i]) == len(sentence_map[i]) == len(subtoken_map[i])):
            raise Exception("The length of sentence/speaker_ids/sentence_map/subtoken_map is inconsistent.")
    return sentences_ids, sentences_masks, sentences_valid_masks, clusters, speaker_ids, sentence_map, subtoken_map, genre
|
cs |
4. Model
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
|
class CorefModel(torch.nn.Module):
    """Coarse-to-fine, higher-order coreference model.

    Scores candidate spans with a mention FFNN, prunes to the top spans,
    selects candidate antecedents with a cheap bilinear score, refines them
    with a pairwise FFNN, and iteratively updates span representations
    (coref_depth rounds) via attention over predicted antecedents.
    """
    def __init__(self, config):
        super(CorefModel, self).__init__()
        self.config = config
        # span representation = [start embed; end embed], optionally extended
        # with a width feature and a head-attention embedding below
        self.span_dim = 2 * self.config["embedding_dim"]
        mm_input_dim = 0
        if self.config["use_features"]:
            self.span_dim += self.config["feature_dim"]
            self.span_width_embeddings = torch.nn.Embedding(self.config["max_span_width"], self.config["feature_dim"])
            self.bucket_distance_embeddings = torch.nn.Embedding(10, self.config["feature_dim"])
            mm_input_dim += self.config["feature_dim"]
        if self.config["model_heads"]:
            self.span_dim += self.config["embedding_dim"]
            self.Sh = torch.nn.Linear(self.config["embedding_dim"], 1) # token head score
        if self.config["use_metadata"]:
            self.genre_embeddings = torch.nn.Embedding(len(self.config["genres"]) + 1, self.config["feature_dim"])
            self.same_speaker_emb = torch.nn.Embedding(2, self.config["feature_dim"])
            mm_input_dim += 2 * self.config["feature_dim"]
        # pairwise input = [span_i; span_j; span_i * span_j] (+ features above)
        mm_input_dim += 3 * self.span_dim
        self.Sm = self._create_score_ffnn(self.span_dim) # mention score
        self.Smm = self._create_score_ffnn(mm_input_dim) # pairwise score between spans
        self.c2fP = torch.nn.Linear(self.span_dim, self.span_dim) # coarse to fine pruning span projection
        self.hoP = torch.nn.Linear(2 * self.span_dim, self.span_dim) # high order projection
    def _create_ffnn(self, input_size, output_size, ffnn_size, ffnn_depth, dropout=0):
        """Build the FFNN used for mention and pairwise scoring:
        ffnn_depth x (Linear -> ReLU -> Dropout) followed by a Linear output."""
        current_size = input_size
        model_seq = OrderedDict()
        for i in range(ffnn_depth):
            model_seq['fc' + str(i)] = torch.nn.Linear(current_size, ffnn_size)
            model_seq['relu' + str(i)] = torch.nn.ReLU()
            model_seq['dropout' + str(i)] = torch.nn.Dropout(dropout)
            current_size = ffnn_size
        model_seq['output'] = torch.nn.Linear(current_size, output_size)
        return torch.nn.Sequential(model_seq)
    def _create_score_ffnn(self, input_size):
        """Scalar-output FFNN configured from the model config."""
        return self._create_ffnn(input_size, 1, self.config["ffnn_size"], self.config["ffnn_depth"], self.config["dropout"])
    def bucket_distance(self, distances):
        """
        Places the given values (designed for distances) into 10 semi-logscale buckets:
        [0, 1, 2, 3, 4, 5-7, 8-15, 16-31, 32-63, 64+].
        """
        float_distances = distances.float()
        combined_idx = torch.floor(torch.log(float_distances) / math.log(2)) + 3
        use_identity = distances <= 4  # small distances keep their own bucket
        combined_idx[use_identity] = float_distances[use_identity]
        combined_idx = combined_idx.long()
        return torch.clamp(combined_idx, 0, 9)
    def get_span_embed(self, tokens_embed, spans_start, spans_end):
        """Build span embeddings: [start; end] (+ width feature, + head-attention
        embedding when enabled). Returns num_spans * span_dim."""
        span_embed_list = list()
        start_embed = tokens_embed[spans_start]
        end_embed = tokens_embed[spans_end]
        span_embed_list.append(start_embed)
        span_embed_list.append(end_embed)
        if self.config["use_features"]:
            spans_width = (spans_end - spans_start).to(device=torch.device(self.config["device"]))
            span_width_embed = self.span_width_embeddings(spans_width)
            span_width_embed = torch.nn.functional.dropout(span_width_embed, p=self.config["dropout"], training=self.training)
            span_embed_list.append(span_width_embed)
        if self.config["model_heads"]:
            # attention over the span's tokens, masked outside [start, end]
            tokens_score = self.Sh(tokens_embed).view(-1) # size: num_tokens
            tokens_locs = torch.arange(start=0, end=len(tokens_embed), dtype=torch.long).repeat(len(spans_start), 1) # size: num_spans * num_tokens
            tokens_masks = (tokens_locs >= spans_start.view(-1, 1)) & (tokens_locs <= spans_end.view(-1, 1)) # size: num_spans * num_tokens
            # log(0) = -inf zeroes out-of-span tokens after the softmax
            tokens_weights = torch.nn.functional.softmax(
                (tokens_score + torch.log(tokens_masks.float()).to(device=torch.device(self.config["device"]))),
                dim=1
            )
            span_head_emb = torch.matmul(tokens_weights, tokens_embed)
            span_embed_list.append(span_head_emb)
        return torch.cat(span_embed_list, dim=1)
    def extract_spans(self, spans_score, spans_start, spans_end, m):
        """ top m spans
        Greedily accepts spans in score order, rejecting any span that
        partially overlaps (crosses) an already accepted span.
        """
        sorted_spans_score, indices = torch.sort(spans_score, 0, True)
        top_m_spans_index = list()
        top_m_spans_start = torch.zeros(m, dtype=torch.long)
        top_m_spans_end = torch.zeros(m, dtype=torch.long)
        top_m_len = 0
        i = 0
        while top_m_len < m and i < len(sorted_spans_score):
            span_index = indices[i]
            span_start = spans_start[span_index]
            span_end = spans_end[span_index]
            # res is nonzero when the candidate crosses an accepted span
            # (overlaps it without containment in either direction)
            res = (((span_start < top_m_spans_start) & (span_end < top_m_spans_end) & (span_end >= top_m_spans_start))
            | ((span_start > top_m_spans_start) & (span_start <= top_m_spans_end) & (span_end > top_m_spans_end)))
            if torch.sum(res) == 0:
                top_m_spans_index.append(span_index)
                top_m_spans_start[top_m_len] = span_start
                top_m_spans_end[top_m_len] = span_end
                top_m_len += 1
            i += 1
        return torch.stack(top_m_spans_index)
    def coarse_to_fine_pruning(self, k, spans_masks, spans_embed, spans_score):
        """
        parameters:
            k: int
            spans_masks: m * m
            spans_embed: m * span_dim
            spans_score: m * 1
        return
            score: FloatTensor m * k
            index: Long m * k
        """
        m = len(spans_embed)
        all_score = torch.zeros(m, m).to(device=torch.device(self.config["device"]))
        # invalid antecedent pairs (span j not strictly before span i) get -inf
        all_score[~spans_masks] = float("-inf")
        # add span score
        all_score += spans_score
        antecedents_offset = (torch.arange(0, m).view(-1, 1) - torch.arange(0, m).view(1, -1))
        antecedents_offset = antecedents_offset.to(device=torch.device(self.config["device"]))
        if self.config["coarse_to_fine"] == True:
            # cheap bilinear compatibility score used only for pruning
            source_top_span_emb = torch.nn.functional.dropout(self.c2fP(spans_embed), p=self.config["dropout"], training=self.training)
            target_top_span_emb = torch.nn.functional.dropout(spans_embed, p=self.config["dropout"], training=self.training)
            all_score += source_top_span_emb.matmul(target_top_span_emb.t()) # m * m
        else:
            k = m  # no pruning: keep every antecedent
        top_antecedents_fast_score, top_antecedents_index = torch.topk(all_score, k)
        top_antecedents_offset = torch.gather(antecedents_offset, dim=1, index=top_antecedents_index)
        return top_antecedents_fast_score, top_antecedents_index, top_antecedents_offset
    def get_spans_similarity_score(self, top_antecedents_index, spans_embed, top_antecedents_offset, speaker_ids, genre_emb):
        """
        parameters:
            top_antecedents_index: Long m * k
            spans_embed: m * span_dim
            top_antecedents_offset: m * k
            speaker_ids: m
            genre_emb: feature_dim
        return:
            score: FloatTensor m * k
        """
        m = len(spans_embed)
        k = top_antecedents_index.shape[1]
        span_index = torch.arange(0, m, dtype=torch.long).repeat(k, 1).t()
        # pairwise FFNN input: [span_i; antecedent_j; span_i * antecedent_j] (+ features)
        mm_ffnn_input_list = list()
        mm_ffnn_input_list.append(spans_embed[span_index])
        mm_ffnn_input_list.append(spans_embed[top_antecedents_index])
        mm_ffnn_input_list.append(mm_ffnn_input_list[0] * mm_ffnn_input_list[1])
        if self.config["use_features"]:
            top_antecedents_distance_bucket = self.bucket_distance(top_antecedents_offset)
            top_antecedents_distance_emb = self.bucket_distance_embeddings(top_antecedents_distance_bucket)
            mm_ffnn_input_list.append(top_antecedents_distance_emb)
        if self.config["use_metadata"]:
            same_speaker_ids = (speaker_ids.view(-1, 1) == speaker_ids[top_antecedents_index]).long().to(device=torch.device(self.config["device"]))
            speaker_emb = self.same_speaker_emb(same_speaker_ids)
            mm_ffnn_input_list.append(speaker_emb)
            mm_ffnn_input_list.append(genre_emb.repeat(m, k, 1))
        mm_ffnn_input = torch.cat(mm_ffnn_input_list, dim=2)
        mm_slow_score = self.Smm(mm_ffnn_input)
        return mm_slow_score.squeeze()
    def forward(self, sentences_ids, sentences_masks, sentences_valid_masks, speaker_ids, sentence_map, subtoken_map, genre, transformer_model):
        """
        parameters:
            sentences_ids: num_sentence * max_sentence_len
            sentences_masks: num_sentence * max_sentence_len
            sentences_valid_masks: num_sentence * max_sentence_len
            speaker_ids: list[list]
            sentence_map: list[list]
            subtoken_map: list[list]
            genre: genre_id
            transformer_model: AutoModel
        """
        # transformers 3.x returns a tuple; [0] is the last hidden state
        sentences_embed, _ = transformer_model(sentences_ids.to(device=torch.device(self.config["device"])), sentences_masks.to(device=torch.device(self.config["device"]))) # num_sentence * max_sentence_len * embed_dim
        # keep only real subtokens (valid mask excludes [CLS]/[SEP]/padding)
        tokens_embed = sentences_embed[sentences_valid_masks.bool()] # num_tokens * embed_dim
        flattened_sentence_indices = list()
        for sm in sentence_map:
            flattened_sentence_indices += sm
        flattened_sentence_indices = torch.LongTensor(flattened_sentence_indices)
        # enumerate every candidate span of width 1..max_span_width
        candidate_spans_start = torch.arange(0, len(tokens_embed)).repeat(self.config["max_span_width"], 1).t()
        candidate_spans_end = candidate_spans_start + torch.arange(0, self.config["max_span_width"])
        candidate_spans_start_sentence_indices = flattened_sentence_indices[candidate_spans_start]
        candidate_spans_end_sentence_indices = flattened_sentence_indices[torch.min(candidate_spans_end, torch.tensor(len(tokens_embed) - 1))]
        # a candidate must stay in range and inside a single sentence
        candidate_spans_mask = (candidate_spans_end < len(tokens_embed)) & (candidate_spans_start_sentence_indices == candidate_spans_end_sentence_indices)
        spans_start = candidate_spans_start[candidate_spans_mask]
        spans_end = candidate_spans_end[candidate_spans_mask]
        spans_embed = self.get_span_embed(tokens_embed, spans_start, spans_end) # size: num_spans * span_dim
        spans_score = self.Sm(spans_embed) # size: num_spans * 1
        spans_len = len(spans_score)
        spans_score_mask = torch.zeros(spans_len, dtype=torch.bool)
        # keep the top m spans (ratio of the token count)
        m = min(int(self.config["top_span_ratio"] * len(tokens_embed)), spans_len)
        if self.config["extract_spans"]:
            span_indices = self.extract_spans(spans_score, spans_start, spans_end, m)
        else:
            _, span_indices = torch.topk(spans_score, m, dim=0, largest=True)
        m = len(span_indices)
        spans_score_mask[span_indices] = True
        top_m_spans_embed = spans_embed[spans_score_mask] # size: m * span_dim
        top_m_spans_score = spans_score[spans_score_mask] # size: m * 1
        top_m_spans_start = spans_start[spans_score_mask] # size: m
        top_m_spans_end = spans_end[spans_score_mask] # size: m
        # BoolTensor, size: m * m — True where span j starts strictly before span i
        top_m_spans_masks = \
        ((top_m_spans_start.repeat(m, 1).t() > top_m_spans_start) \
        | ((top_m_spans_start.repeat(m, 1).t() == top_m_spans_start) \
        & (top_m_spans_end.repeat(m, 1).t() > top_m_spans_end)))
        if self.config['use_metadata']:
            flattened_speaker_ids = list()
            for si in speaker_ids:
                flattened_speaker_ids += si
            flattened_speaker_ids = torch.LongTensor(flattened_speaker_ids)
            top_m_spans_speaker_ids = flattened_speaker_ids[top_m_spans_start]
            genre_emb = self.genre_embeddings(torch.LongTensor([genre]).to(device=torch.device(self.config["device"]))).squeeze()
        else:
            top_m_spans_speaker_ids = None
            genre_emb = None
        k = min(self.config["max_top_antecedents"], m)
        top_antecedents_fast_score, top_antecedents_index, top_antecedents_offset = self.coarse_to_fine_pruning(
            k, top_m_spans_masks, top_m_spans_embed, top_m_spans_score
        )
        # higher-order refinement: re-score after attending over antecedents
        for i in range(self.config["coref_depth"]):
            top_antecedents_slow_score = self.get_spans_similarity_score(top_antecedents_index, top_m_spans_embed, top_antecedents_offset, top_m_spans_speaker_ids, genre_emb)
            top_antecedents_score = top_antecedents_fast_score + top_antecedents_slow_score # size: m * k
            dummy_score = torch.zeros(m, 1).to(device=torch.device(self.config["device"])) # add dummy
            top_antecedents_score = torch.cat((top_antecedents_score, dummy_score), dim=1) # size: m * (k+1)
            top_antecedents_weight = torch.nn.functional.softmax(top_antecedents_score, dim=1) # size: m * (k+1)
            top_antecedents_emb = torch.cat((top_m_spans_embed[top_antecedents_index], top_m_spans_embed.unsqueeze(1)), dim=1) # size: m * (k+1) * embed
            attended_spans_emb = torch.sum(top_antecedents_weight.unsqueeze(2) * top_antecedents_emb, dim=1) # size: m * embed
            # gated update of the span representation
            f = torch.sigmoid(self.hoP(torch.cat([top_m_spans_embed, attended_spans_emb], dim=1))) # size: m * embed
            top_m_spans_embed = f * attended_spans_emb + (1 - f) * top_m_spans_embed # size: m * embed
        return top_antecedents_score, top_antecedents_index, top_m_spans_masks, top_m_spans_start, top_m_spans_end
    def evaluate(self, data, transformer_model):
        """Run the model over `data` and return (precision, recall, f1)
        averaged over MUC / B-cubed / CEAF-e via CorefEvaluator."""
        coref_evaluator = CorefEvaluator()
        with torch.no_grad():
            for idx, data_i in enumerate(data):
                sentences_ids, sentences_masks, sentences_valid_masks, gold_clusters, speaker_ids, sentence_map, subtoken_map, genre = data_i
                # long documents are truncated the same way as in training
                if len(sentences_ids) > c["max_training_sentences"]:
                    sentences_ids, sentences_masks, sentences_valid_masks, gold_clusters, speaker_ids, sentence_map, subtoken_map = truncate_example(sentences_ids, sentences_masks, sentences_valid_masks, gold_clusters, speaker_ids, sentence_map, subtoken_map, c["max_training_sentences"])
                top_antecedents_score, top_antecedents_index, top_m_spans_masks, top_m_spans_start, top_m_spans_end = self.forward(sentences_ids, sentences_masks, sentences_valid_masks, speaker_ids, sentence_map, subtoken_map, genre, transformer_model)
                predicted_antecedents = self.get_predicted_antecedents(top_antecedents_index, top_antecedents_score)
                top_m_spans = list()
                for i in range(len(top_m_spans_start)):
                    top_m_spans.append(tuple([top_m_spans_start[i].item(), top_m_spans_end[i].item()]))
                # all spans
                gold_clusters = [tuple(tuple([m[0], m[1]]) for m in gc) for gc in gold_clusters]
                mention_to_gold = {}
                for gc in gold_clusters:
                    for mention in gc:
                        mention_to_gold[tuple(mention)] = gc
                predicted_clusters, mention_to_predicted = self.get_predicted_clusters(top_m_spans, predicted_antecedents)
                coref_evaluator.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold)
        return coref_evaluator.get_prf()
    def get_predicted_antecedents(self, top_antecedents_index, top_antecedents_score):
        """Pick each span's best antecedent; -1 means the dummy (no antecedent),
        which sits in the last score column."""
        predicted_antecedents = []
        for i, index in enumerate(torch.argmax(top_antecedents_score, axis=1)):
            if index == len(top_antecedents_score[i]) - 1:
                predicted_antecedents.append(-1)
            else:
                predicted_antecedents.append(top_antecedents_index[i][index])
        return predicted_antecedents
    def get_predicted_clusters(self, top_m_spans, predicted_antecedents):
        """Union spans linked by predicted antecedents into clusters.

        Returns (predicted_clusters, mention_to_predicted) where clusters of
        size 1 are discarded and mentions are (start, end) tuples.
        """
        idx_to_clusters = {}
        predicted_clusters = []
        for i in range(len(top_m_spans)):
            idx_to_clusters[i] = set([i])
        for i, predicted_index in enumerate(predicted_antecedents):
            if predicted_index < 0:
                continue
            else:
                # merge the two clusters and point every member at the union
                union_cluster = idx_to_clusters[predicted_index.item()] | idx_to_clusters[i]
                for j in union_cluster:
                    idx_to_clusters[j] = union_cluster
        tagged_index = set()
        for i in idx_to_clusters:
            if (len(idx_to_clusters[i]) == 1) or (i in tagged_index):
                continue
            cluster = idx_to_clusters[i]
            predicted_cluster = list()
            for j in cluster:
                tagged_index.add(j)
                predicted_cluster.append(tuple(top_m_spans[j]))
            predicted_clusters.append(tuple(predicted_cluster))
        mention_to_predicted = {}
        for pc in predicted_clusters:
            for mention in pc:
                mention_to_predicted[tuple(mention)] = pc
        return predicted_clusters, mention_to_predicted
|
cs |
5. Metric
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
|
class Evaluator(object):
    """Accumulates precision/recall counts for one coreference metric."""

    def __init__(self, metric, beta=1):
        self.p_num, self.p_den = 0, 0
        self.r_num, self.r_den = 0, 0
        self.metric = metric
        self.beta = beta

    def update(self, predicted, gold, mention_to_predicted, mention_to_gold):
        """Fold one document's counts into the running totals."""
        # ceafe yields both directions in a single call; the other metrics
        # are evaluated twice with the arguments swapped.
        if self.metric == ceafe:
            pn, pd, rn, rd = self.metric(predicted, gold)
        else:
            pn, pd = self.metric(predicted, mention_to_gold)
            rn, rd = self.metric(gold, mention_to_predicted)
        self.p_num += pn
        self.p_den += pd
        self.r_num += rn
        self.r_den += rd

    def get_f1(self):
        return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta)

    def get_recall(self):
        if self.r_num == 0:
            return 0
        return self.r_num / float(self.r_den)

    def get_precision(self):
        if self.p_num == 0:
            return 0
        return self.p_num / float(self.p_den)

    def get_prf(self):
        return self.get_precision(), self.get_recall(), self.get_f1()

    def get_counts(self):
        return self.p_num, self.p_den, self.r_num, self.r_den
def evaluate_documents(documents, metric, beta=1):
    """Score a batch of documents with one metric.

    Args:
        documents: iterable of 4-tuples
            (predicted, gold, mention_to_predicted, mention_to_gold),
            matching Evaluator.update's signature.
        metric: one of muc / b_cubed / ceafe / lea.
        beta: F-score beta.

    Returns:
        (precision, recall, f_beta) over all documents.
    """
    evaluator = Evaluator(metric, beta=beta)
    for document in documents:
        # BUG FIX: Evaluator.update takes four positional arguments; the
        # original passed the whole tuple as a single argument (TypeError).
        evaluator.update(*document)
    return evaluator.get_precision(), evaluator.get_recall(), evaluator.get_f1()
def b_cubed(clusters, mention_to_gold):
    """B-cubed metric: per-mention overlap between system and gold clusters.

    Returns a (numerator, denominator) pair; singleton clusters are skipped,
    matching the reference coreference scorer.
    """
    numerator, denominator = 0, 0
    for cluster in clusters:
        if len(cluster) == 1:
            continue
        # Count how many mentions of this cluster fall into each gold cluster.
        gold_counts = Counter(
            tuple(mention_to_gold[m]) for m in cluster if m in mention_to_gold
        )
        # Singleton gold clusters do not contribute to the correct count.
        correct = sum(n * n for gold_c, n in gold_counts.items() if len(gold_c) != 1)
        numerator += correct / float(len(cluster))
        denominator += len(cluster)
    return numerator, denominator
def muc(clusters, mention_to_gold):
    """MUC link-based metric: (matched links, possible links) over clusters.

    A cluster of size s has s-1 links; matched links are s minus the number
    of gold partitions it is split into (unmapped mentions each count as
    their own partition).
    """
    matched, possible = 0, 0
    for cluster in clusters:
        possible += len(cluster) - 1
        # Gold clusters touched by this system cluster.
        partitions = set()
        unmapped = 0
        for mention in cluster:
            if mention in mention_to_gold:
                partitions.add(mention_to_gold[mention])
            else:
                unmapped += 1
        matched += len(cluster) - unmapped - len(partitions)
    return matched, possible
def phi4(c1, c2):
    """CEAF phi4 similarity: 2*|c1 ∩ c2| / (|c1| + |c2|)."""
    overlap = sum(1 for mention in c1 if mention in c2)
    return 2 * overlap / float(len(c1) + len(c2))
def ceafe(clusters, gold_clusters):
    """Entity-based CEAF: optimal one-to-one alignment of system and gold clusters.

    Returns the four counts (p_num, p_den, r_num, r_den): the total phi4
    similarity of the best alignment over the number of system clusters
    (precision) and gold clusters (recall).
    """
    # Singletons are excluded from the system side, as in the reference scorer.
    clusters = [c for c in clusters if len(c) != 1]
    scores = np.zeros((len(gold_clusters), len(clusters)))
    for gi, gold_c in enumerate(gold_clusters):
        for si, sys_c in enumerate(clusters):
            scores[gi, si] = phi4(gold_c, sys_c)
    # Hungarian algorithm minimizes cost, so negate to maximize similarity.
    row_ind, col_ind = linear_sum_assignment(-scores)
    similarity = sum(scores[row_ind, col_ind])
    return similarity, len(clusters), similarity, len(gold_clusters)
def lea(clusters, mention_to_gold):
    """LEA (link-based entity-aware) metric: (numerator, denominator).

    Each cluster contributes its size times the fraction of its pairwise
    links that are confirmed by the gold clustering; singletons are skipped.
    """
    numerator, denominator = 0, 0
    for cluster in clusters:
        size = len(cluster)
        if size == 1:
            continue
        total_links = size * (size - 1) / 2.0
        confirmed = 0
        for idx, m1 in enumerate(cluster):
            if m1 not in mention_to_gold:
                continue
            gold_c = mention_to_gold[m1]
            for m2 in cluster[idx + 1:]:
                if m2 in mention_to_gold and mention_to_gold[m2] == gold_c:
                    confirmed += 1
        numerator += size * confirmed / float(total_links)
        denominator += size
    return numerator, denominator
def f1(p_num, p_den, r_num, r_den, beta=1):
    """F-beta score from raw precision/recall counts; 0 when undefined."""
    precision = p_num / float(p_den) if p_den != 0 else 0
    recall = r_num / float(r_den) if r_den != 0 else 0
    if precision + recall == 0:
        return 0
    beta_sq = beta * beta
    return (1 + beta_sq) * precision * recall / (beta_sq * precision + recall)
class CorefEvaluator(object):
    """Averages MUC, B-cubed and CEAF-e into a single score (CoNLL average)."""

    def __init__(self):
        self.evaluators = [Evaluator(metric) for metric in (muc, b_cubed, ceafe)]

    def update(self, predicted, gold, mention_to_predicted, mention_to_gold):
        """Feed one document's clusters to every underlying evaluator."""
        for evaluator in self.evaluators:
            evaluator.update(predicted, gold, mention_to_predicted, mention_to_gold)

    def _average(self, getter):
        # Unweighted mean of one statistic across the three metrics.
        scores = [getter(evaluator) for evaluator in self.evaluators]
        return sum(scores) / len(scores)

    def get_f1(self):
        return self._average(lambda e: e.get_f1())

    def get_recall(self):
        return self._average(lambda e: e.get_recall())

    def get_precision(self):
        return self._average(lambda e: e.get_precision())

    def get_prf(self):
        """Return (precision, recall, f1) averaged over all metrics."""
        return self.get_precision(), self.get_recall(), self.get_f1()
|
cs |
6. Train
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
|
# Let cuDNN auto-tune kernels per input shape (fastest when shapes repeat).
# NOTE(review): documents here vary in length, so the benchmarked choice may
# be re-done often — confirm this is a net win for this workload.
torch.backends.cudnn.benchmark = True
# Drop any cached GPU allocations from earlier work in this process.
torch.cuda.empty_cache()
def find_clusters(span_loc, index_clusters):
    """Return the gold-cluster index for a span, or -1 if it is unclustered.

    Args:
        span_loc: hashable span key, e.g. a (start, end) tuple.
        index_clusters: mapping from span key to gold cluster index.

    Returns:
        The cluster index, or -1 when the span is not a gold mention.
    """
    # Idiomatic replacement for the explicit membership test + lookup.
    return index_clusters.get(span_loc, -1)
def train(config):
    """Train the coreference model jointly with the transformer encoder.

    Reads JSON-lines train/val files (one document object per line),
    tokenizes each document, then optimizes the span-ranking coreference
    loss (marginalized log-likelihood over gold antecedents, as in the
    end-to-end coref line of work). Periodically reports the running loss,
    checkpoints both models, and evaluates on the validation set, keeping a
    ".max" copy of the best-F1 checkpoint. Runs indefinitely (`while True`);
    stop it manually.
    """
    c = config
    tokenizer = AutoTokenizer.from_pretrained(c["transformer_model_name"])
    transformer_model = AutoModel.from_pretrained(c["transformer_model_name"]).to(c["device"])
    # Separate AdamW optimizer with its own learning rate for the encoder.
    transformer_optim = AdamW(transformer_model.parameters(), lr=c["transformer_lr"])
    transformer_model.train()
    print("config:")
    pprint.pprint(c)
    print("preparing data...")
    # Load JSONL: one JSON document per line.
    train_data = list()
    with open(c["train_file_path"], "r", encoding="utf-8") as fd:
        for line in fd:
            item = json.loads(line.strip())
            train_data.append(item)
    val_data = list()
    with open(c["val_file_path"], "r", encoding="utf-8") as fd:
        for line in fd:
            item = json.loads(line.strip())
            val_data.append(item)
    # Tokenize up front, skipping documents with no sentences.
    tokenized_train_data = list()
    for data_i in train_data:
        if len(data_i["sentences"]) == 0:
            print("Warning: `sentences` in %s is empty." % data_i["doc_key"])
        else:
            tokenized_train_data.append((tokenize_example(data_i, tokenizer, c)))
    tokenized_val_data = list()
    for data_i in val_data:
        if len(data_i["sentences"]) == 0:
            print("Warning: `sentences` in %s is empty." % data_i["doc_key"])
        else:
            tokenized_val_data.append((tokenize_example(data_i, tokenizer, c)))
    coref_model = CorefModel(c).to(c["device"])
    optimizer = torch.optim.Adam(coref_model.parameters(), lr=c["lr"], weight_decay=c["weight_decay"])
    accumulated_loss = 0.0
    max_f1 = 0.0
    print("start training...")
    init_time = time.time()
    steps = 0
    while True:
        # NOTE(review): RandomSampler over a SequentialSampler yields a random
        # permutation of dataset indices; RandomSampler(tokenized_train_data)
        # directly would be equivalent here.
        for idx in RandomSampler(SequentialSampler(tokenized_train_data)):
            steps += 1
            sentences_ids, sentences_masks, sentences_valid_masks, clusters, speaker_ids, sentence_map, subtoken_map, genre = tokenized_train_data[idx]
            # Truncate long documents to fit GPU memory.
            if len(sentences_ids) > c["max_training_sentences"]:
                sentences_ids, sentences_masks, sentences_valid_masks, clusters, speaker_ids, sentence_map, subtoken_map = truncate_example(sentences_ids, sentences_masks, sentences_valid_masks, clusters, speaker_ids, sentence_map, subtoken_map, c["max_training_sentences"])
            top_antecedents_score, top_antecedents_index, top_m_spans_masks, top_m_spans_start, top_m_spans_end = coref_model(sentences_ids, sentences_masks, sentences_valid_masks, speaker_ids, sentence_map, subtoken_map, genre, transformer_model)
            num_spans = len(top_m_spans_start)
            # Map every gold mention (start, end) tuple to its cluster id.
            index_clusters = dict()
            for i, cluster in enumerate(clusters):
                for loc in cluster:
                    index_clusters[tuple(loc)] = i
            top_m_spans_cluster_idx = list() # size: m
            for i in range(num_spans):
                cluster_idx = find_clusters((top_m_spans_start[i].item(), top_m_spans_end[i].item()), index_clusters)
                top_m_spans_cluster_idx.append(cluster_idx)
            top_m_spans_cluster_idx = torch.LongTensor(top_m_spans_cluster_idx).to(device=torch.device(c["device"]))
            # -1 marks predicted spans that are not gold mentions.
            top_m_spans_in_gold_clusters = (top_m_spans_cluster_idx != -1)
            # gold labels, size: m * k — True where a span and its candidate
            # antecedent share the same gold cluster id.
            top_antecedents_label = (top_m_spans_cluster_idx[top_antecedents_index].t() == top_m_spans_cluster_idx).t()
            # Non-gold spans and masked (invalid) antecedent slots cannot be
            # positive labels.
            top_antecedents_label[~top_m_spans_in_gold_clusters] = False
            top_antecedents_label[~torch.gather(top_m_spans_masks.to(device=torch.device(c["device"])), 1, top_antecedents_index)] = False
            # gold labels with dummy label, size: m * (k+1) — the dummy
            # ("no antecedent") column is gold iff no real antecedent is.
            top_antecedents_label = torch.cat((top_antecedents_label, ~torch.sum(top_antecedents_label, dim=1).bool().view(-1,1)), dim=1)
            # Marginalized cross-entropy: log(0) = -inf zeroes out non-gold
            # candidates inside the logsumexp.
            gold_scores = top_antecedents_score + torch.log(top_antecedents_label.float()) # size: m * (k+1)
            marginalized_gold_scores = torch.logsumexp(gold_scores, 1) # size: m
            log_norm = torch.logsumexp(top_antecedents_score, 1) # size: m
            loss = torch.sum(log_norm - marginalized_gold_scores)
            # Step both optimizers together (coref head + encoder).
            optimizer.zero_grad()
            transformer_optim.zero_grad()
            loss.backward()
            optimizer.step()
            transformer_optim.step()
            accumulated_loss += loss.item()
            if steps % c["report_frequency"] == 0:
                total_time = time.time() - init_time
                print("[%d] loss=%.7f, steps/s=%.4f" % (steps, accumulated_loss / c["report_frequency"], steps / total_time))
                accumulated_loss = 0.0
            if steps % c["eval_frequency"] == 0:
                coref_model.eval()
                transformer_model.eval()
                # Always checkpoint at eval time; additionally keep a ".max"
                # copy of the best validation F1 so far.
                torch.save({"model": coref_model.state_dict(), "optimizer": optimizer.state_dict(), "steps": steps}, c["checkpoint_path"] + "." + str(steps))
                transformer_model.save_pretrained(c["checkpoint_path"] + ".transformer." + str(steps))
                try:
                    p, r, f = coref_model.evaluate(tokenized_val_data, transformer_model)
                    if f >= max_f1:
                        max_f1 = f
                        torch.save({"model": coref_model.state_dict(), "optimizer": optimizer.state_dict(), "steps": steps}, c["checkpoint_path"] + ".max")
                        transformer_model.save_pretrained(c["checkpoint_path"] + ".transformer.max")
                    print("evaluation result:\np:%.4f,r:%.4f,f:%.4f(max f:%.4f)" % (p, r, f, max_f1))
                except Exception as e:
                    # Keep training even if evaluation fails (e.g. OOM).
                    print("Error: evaluation error:", e)
                coref_model.train()
                transformer_model.train()
                torch.cuda.empty_cache()
train(config)
|
cs |
7. Prediction
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
|
# Interactive prediction: load the best (".max") checkpoint and resolve
# coreference for sentences typed on stdin. Loops forever; interrupt to quit.
c = config
print("load model...")
coref_model = CorefModel(c).eval().to(c["device"])
tokenizer = AutoTokenizer.from_pretrained(c["transformer_model_name"])
transformer_model = AutoModel.from_pretrained(c["checkpoint_path"] + ".transformer.max").eval().to(c["device"])
checkpoint = torch.load(c["checkpoint_path"] + ".max", map_location=c["device"])
coref_model.load_state_dict(checkpoint["model"])
print("success!")
while True:
    sentence = input("input: ")
    res = tokenizer(sentence)
    sentence_ids, sentence_masks = res["input_ids"], res["attention_mask"]
    sentence_ids = torch.LongTensor([sentence_ids])
    sentence_masks = torch.LongTensor([sentence_masks])
    # Zero out the [CLS] and [SEP] positions so the special tokens are
    # excluded from mention candidates.
    sentence_valid_masks = sentence_masks.clone()
    sentence_valid_masks[0][0] = 0
    sentence_valid_masks[0][-1] = 0
    valid_num = torch.sum(sentence_valid_masks).item()
    # Single speaker, single sentence; subtoken_map is the identity mapping.
    speaker_ids = [[0] * valid_num]
    sentence_map = [[0] * valid_num]
    subtoken_map = [list(range(valid_num))]
    # NOTE(review): genre is passed as len(c["genres"]) — presumably an
    # out-of-vocabulary/default genre id; confirm against the genre embedding.
    top_antecedents_score, top_antecedents_index, top_m_units_masks, top_m_units_start, top_m_units_end = coref_model(sentence_ids, sentence_masks, sentence_valid_masks, speaker_ids, sentence_map, subtoken_map, len(c["genres"]), transformer_model)
    predicted_antecedents = coref_model.get_predicted_antecedents(top_antecedents_index, top_antecedents_score)
    # Collect candidate mention spans as [start, end] pairs.
    top_m_units = list()
    for i in range(len(top_m_units_start)):
        top_m_units.append([top_m_units_start[i], top_m_units_end[i]])
    predicted_clusters, _ = coref_model.get_predicted_clusters(top_m_units, predicted_antecedents)
    print("============================")
    print("tokenized context:")
    tokens = tokenizer.convert_ids_to_tokens(sentence_ids[sentence_valid_masks.bool()])
    print(tokens)
    print("predicted clusters:", predicted_clusters)
    # Resolve each cluster's spans back to token strings for display.
    mentions = []
    for coref_clusters in predicted_clusters:
        temp_token=[]
        for cluster in coref_clusters:
            # NOTE(review): cluster bounds appear to be 0-dim tensors; .numpy()
            # works as a slice index via __index__, but .item() would be clearer.
            temp_token.append(tokens[cluster[0].numpy():cluster[1].numpy()+1])
        mentions.append(temp_token)
|
cs |
간단한 학습 결과 표
precision, recall, f1 score는 세 가지 metric(MUC, B³, CEAF)의 average이다.
코드 참조
github.com/cheniison/e2e-coref-pytorch
cheniison/e2e-coref-pytorch
Bert for End-to-end Neural Coreference Resolution in Pytorch - cheniison/e2e-coref-pytorch
github.com