# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn


class ErnieForCSC(nn.Layer):
r"""
    ErnieForCSC is a model designed for the Chinese Spelling Correction (CSC) task.
    It integrates phonetic features into the language model by leveraging the power
    of pre-training and fine-tuning.
    See more details at https://aclanthology.org/2021.findings-acl.198.pdf.

Args:
ernie (ErnieModel):
An instance of `paddlenlp.transformers.ErnieModel`.
        pinyin_vocab_size (int):
            The size of the pinyin vocabulary.
        pad_pinyin_id (int, optional):
            The pad token id of the pinyin vocabulary. Defaults to `0`.
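
    Example:
        A minimal construction sketch. The `ernie-1.0` checkpoint name and the
        pinyin vocab size below are illustrative assumptions, not values fixed
        by this file:

        .. code-block:: python

            from paddlenlp.transformers import ErnieModel

            ernie = ErnieModel.from_pretrained('ernie-1.0')
            model = ErnieForCSC(ernie, pinyin_vocab_size=230, pad_pinyin_id=0)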
"""

    def __init__(self, ernie, pinyin_vocab_size, pad_pinyin_id=0):
super(ErnieForCSC, self).__init__()
self.ernie = ernie
emb_size = self.ernie.config["hidden_size"]
hidden_size = self.ernie.config["hidden_size"]
vocab_size = self.ernie.config["vocab_size"]
self.pad_token_id = self.ernie.config["pad_token_id"]
self.pinyin_vocab_size = pinyin_vocab_size
self.pad_pinyin_id = pad_pinyin_id
self.pinyin_embeddings = nn.Embedding(
self.pinyin_vocab_size, emb_size, padding_idx=pad_pinyin_id)
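        # Detection head: per-token binary classification (correct vs. erroneous).
        # Correction head: per-token projection onto the word vocabulary.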
self.detection_layer = nn.Linear(hidden_size, 2)
self.correction_layer = nn.Linear(hidden_size, vocab_size)
self.softmax = nn.Softmax()

    def forward(self,
                input_ids,
                pinyin_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
r"""
Args:
input_ids (Tensor):
Indices of input sequence tokens in the vocabulary. They are
numerical representations of tokens that build the input sequence.
                Its data type should be `int64` and its shape is [batch_size, sequence_length].
pinyin_ids (Tensor):
Indices of pinyin tokens of input sequence in the pinyin vocabulary. They are
numerical representations of tokens that build the pinyin input sequence.
                Its data type should be `int64` and its shape is [batch_size, sequence_length].
token_type_ids (Tensor, optional):
Segment token indices to indicate first and second portions of the inputs.
Indices can be either 0 or 1:
- 0 corresponds to a **sentence A** token,
- 1 corresponds to a **sentence B** token.
                Its data type should be `int64` and its shape is [batch_size, sequence_length].
                Defaults to `None`, which means no segment embeddings are added to the token embeddings.
            position_ids (Tensor, optional):
                Indices of positions of each input sequence token in the position embeddings.
                Selected in the range ``[0, config.max_position_embeddings - 1]``.
                Its data type should be `int32` or `int64` and its shape is [batch_size, sequence_length].
                Defaults to `None`.
attention_mask (Tensor, optional):
Mask to indicate whether to perform attention on each input token or not.
The values should be either 0 or 1. The attention scores will be set
to **-infinity** for any positions in the mask that are **0**, and will be
**unchanged** for positions that are **1**.
- **1** for tokens that are **not masked**,
- **0** for tokens that are **masked**.
                Its data type should be `float32` and its shape is [batch_size, sequence_length].
                Defaults to `None`.
        Returns:
            detection_error_probs (Tensor):
                A Tensor of the detection probability of each token.
                Its shape is `(batch_size, sequence_length, 2)` and its dtype is `float32`.
            correction_logits (Tensor):
                A Tensor of the correction logits of each token.
                Its shape is `(batch_size, sequence_length, vocab_size)` and its dtype is `float32`.
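
        Example:
            A minimal end-to-end sketch with assumed inputs; the `ernie-1.0`
            weights and the pinyin vocab size of 230 are illustrative, not
            prescribed by this file:

            .. code-block:: python

                import paddle
                from paddlenlp.transformers import ErnieModel

                ernie = ErnieModel.from_pretrained('ernie-1.0')
                model = ErnieForCSC(ernie, pinyin_vocab_size=230)
                input_ids = paddle.randint(low=1, high=100, shape=[2, 16], dtype='int64')
                pinyin_ids = paddle.randint(low=1, high=230, shape=[2, 16], dtype='int64')
                det_probs, corr_logits = model(input_ids, pinyin_ids)
                # det_probs:   [2, 16, 2] float32 probabilities
                # corr_logits: [2, 16, vocab_size] float32 logits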
"""
if attention_mask is None:
attention_mask = paddle.unsqueeze(
(input_ids == self.pad_token_id
).astype(self.detection_layer.weight.dtype) * -1e9,
axis=[1, 2])
embedding_output = self.ernie.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids)
pinyin_embedding_output = self.pinyin_embeddings(pinyin_ids)
        # Detection module aims to detect whether each Chinese character has a spelling error.
detection_outputs = self.ernie.encoder(embedding_output, attention_mask)
        # detection_error_probs shape: [B, T, 2]. It gives the probability (from 0 to 1)
        # that each token in the sequence is correct (column 0) or erroneous (column 1).
detection_error_probs = self.softmax(
self.detection_layer(detection_outputs))
        # Correction module aims to correct each potentially wrong character to the right one.
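        # Soft-masking: blend the word embedding (weighted by the correct probability)
        # with the pinyin embedding (weighted by the error probability), so tokens that
        # look erroneous lean more on phonetic information.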
        word_pinyin_embedding_output = (
            detection_error_probs[:, :, 0:1] * embedding_output
            + detection_error_probs[:, :, 1:2] * pinyin_embedding_output)
correction_outputs = self.ernie.encoder(word_pinyin_embedding_output,
attention_mask)
        # correction_logits shape: [B, T, V]. For each position in the sequence, it scores
        # every token in the vocabulary as a candidate correction.
correction_logits = self.correction_layer(correction_outputs)
return detection_error_probs, correction_logits