# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pypinyin import lazy_pinyin, Style
import paddle
from paddlenlp.transformers import is_chinese_char


def read_train_ds(data_path):
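    """Yield {'source': ..., 'target': ...} dicts from a tab-separated file."""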
with open(data_path, 'r', encoding='utf-8') as f:
for line in f:
source, target = line.strip('\n').split('\t')[0:2]
            yield {'source': source, 'target': target}


def read_test_ds(data_path):
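    """Yield {'source': ...} dicts from a tab-separated (id, text) file; the
    id column is not used for prediction."""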
with open(data_path, 'r', encoding='utf-8') as f:
for line in f:
ids, words = line.strip('\n').split('\t')[0:2]
            yield {'source': words}


def convert_example(example,
tokenizer,
pinyin_vocab,
max_seq_length=128,
ignore_label=-1,
is_test=False):
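    """Convert a raw example into model features: token ids, token type ids,
    aligned pinyin ids, and the unpadded length. In training mode, detection
    and correction labels (with `ignore_label` at [CLS]/[SEP]) are returned
    as well."""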
source = example["source"]
words = list(source)
if len(words) > max_seq_length - 2:
words = words[:max_seq_length - 2]
length = len(words)
words = ['[CLS]'] + words + ['[SEP]']
input_ids = tokenizer.convert_tokens_to_ids(words)
token_type_ids = [0] * len(input_ids)
    # [CLS] and [SEP] map to the pad id (0) in the pinyin embedding.
    # Style.TONE3 appends the tone digit to each syllable (e.g. 'zhong1');
    # neutral_tone_with_five=True renders the neutral tone as '5'.
    pinyins = lazy_pinyin(
        source, style=Style.TONE3, neutral_tone_with_five=True)
    pinyin_ids = [0]
    # Align pinyin syllables with Chinese characters; anything that is not a
    # single Chinese character falls back to '[UNK]'.
    pinyin_offset = 0
    for word in words[1:-1]:
        pinyin = '[UNK]' if word != '[PAD]' else '[PAD]'
        if len(word) == 1 and is_chinese_char(ord(word)):
            # Scan forward to the next syllable present in the pinyin vocab;
            # [:-1] strips the trailing tone digit.
            while pinyin_offset < len(pinyins):
                current_pinyin = pinyins[pinyin_offset][:-1]
                pinyin_offset += 1
                if current_pinyin in pinyin_vocab:
                    pinyin = current_pinyin
                    break
        pinyin_ids.append(pinyin_vocab[pinyin])
    pinyin_ids.append(0)
assert len(input_ids) == len(
pinyin_ids), "length of input_ids must be equal to length of pinyin_ids"
if not is_test:
target = example["target"]
correction_labels = list(target)
if len(correction_labels) > max_seq_length - 2:
correction_labels = correction_labels[:max_seq_length - 2]
correction_labels = tokenizer.convert_tokens_to_ids(correction_labels)
correction_labels = [ignore_label] + correction_labels + [ignore_label]
        # Detection label: 0 if the input token matches the target, 1 if typo.
        detection_labels = []
for input_id, label in zip(input_ids[1:-1], correction_labels[1:-1]):
detection_label = 0 if input_id == label else 1
detection_labels += [detection_label]
detection_labels = [ignore_label] + detection_labels + [ignore_label]
return input_ids, token_type_ids, pinyin_ids, detection_labels, correction_labels, length
else:
        return input_ids, token_type_ids, pinyin_ids, length


def create_dataloader(dataset,
mode='train',
batch_size=1,
batchify_fn=None,
trans_fn=None):
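    """Build a DataLoader; training mode uses a shuffled DistributedBatchSampler
    so each worker sees a distinct shard of the data."""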
if trans_fn:
dataset = dataset.map(trans_fn)
    shuffle = mode == 'train'
if mode == 'train':
batch_sampler = paddle.io.DistributedBatchSampler(
dataset, batch_size=batch_size, shuffle=shuffle)
else:
batch_sampler = paddle.io.BatchSampler(
dataset, batch_size=batch_size, shuffle=shuffle)
return paddle.io.DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
collate_fn=batchify_fn,
        return_list=True)


def parse_decode(words, corr_preds, det_preds, lengths, tokenizer,
max_seq_length):
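    """Decode one example's detector/corrector predictions back into text."""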
UNK = tokenizer.unk_token
UNK_id = tokenizer.convert_tokens_to_ids(UNK)
    # Strip [CLS] and anything beyond the true (unpadded) sequence length.
    corr_pred = corr_preds[1:1 + lengths].tolist()
    det_pred = det_preds[1:1 + lengths].tolist()
words = list(words)
    rest_words = []
    # Characters beyond the truncation point were never seen by the model and
    # are appended back unchanged at the end.
    if len(words) > max_seq_length - 2:
        rest_words = words[max_seq_length - 2:]
        words = words[:max_seq_length - 2]
pred_result = ""
for j, word in enumerate(words):
        # Out-of-vocab prediction ids fall back to [UNK].
        candidates = tokenizer.convert_ids_to_tokens(
            corr_pred[j] if corr_pred[j] < tokenizer.vocab_size else UNK_id)
        word_icc = is_chinese_char(ord(word))
        cand_icc = len(candidates) == 1 and is_chinese_char(ord(candidates))
        # Keep the original character unless it is Chinese, the detector flags
        # it, and the candidate is a usable single Chinese character.
        if (not word_icc or det_pred[j] == 0
                or candidates in [UNK, '[PAD]']
                or (word_icc and not cand_icc)):
            pred_result += word
        else:
            # Strip a possible wordpiece '##' prefix before splicing in.
            pred_result += candidates.lstrip("##")
pred_result += ''.join(rest_words)
return pred_result
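

# A minimal usage sketch showing how these helpers fit together. It is
# illustrative only: 'ernie-1.0', the file paths, and the batch size below
# are assumptions, and the batchify_fn that pads each batch is omitted.
#
#   from functools import partial
#   from paddlenlp.data import Vocab
#   from paddlenlp.datasets import load_dataset
#   from paddlenlp.transformers import ErnieTokenizer
#
#   tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')  # assumed model
#   pinyin_vocab = Vocab.load_vocabulary(
#       'pinyin_vocab.txt', unk_token='[UNK]', pad_token='[PAD]')
#   train_ds = load_dataset(read_train_ds, data_path='train.txt', lazy=False)
#   trans_fn = partial(
#       convert_example, tokenizer=tokenizer, pinyin_vocab=pinyin_vocab)
#   train_loader = create_dataloader(
#       train_ds, mode='train', batch_size=32, trans_fn=trans_fn)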