-
Notifications
You must be signed in to change notification settings - Fork 0
/
earlystopping.py
82 lines (74 loc) · 2.91 KB
/
earlystopping.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from typing import List
import copy
import operator
from enum import Enum, auto
import numpy as np
from torch.nn import Module
class StopVariable(Enum):
LOSS = auto()
ACCURACY = auto()
NONE = auto()
class Best(Enum):
RANKED = auto()
ALL = auto()
stopping_args = dict(
stop_varnames=[StopVariable.ACCURACY],
patience=300, max_epochs=500, remember=Best.RANKED) # for cora
class EarlyStopping:
def __init__(
self, model: Module, # PPNP
stop_varnames: List[StopVariable],
patience: int = 10,
max_epochs: int = 200,
remember: Best = Best.ALL):
self.model = model
self.comp_ops = []
self.stop_vars = []
self.best_vals = []
for stop_varname in stop_varnames:
if stop_varname is StopVariable.LOSS:
self.stop_vars.append('loss')
self.comp_ops.append(operator.le)
self.best_vals.append(np.inf)
elif stop_varname is StopVariable.ACCURACY:
self.stop_vars.append('acc')
self.comp_ops.append(operator.ge)
self.best_vals.append(-np.inf)
self.remember = remember
self.remembered_vals = copy.copy(self.best_vals)
self.max_patience = patience
self.patience = self.max_patience
self.max_epochs = max_epochs
self.best_epoch = None
self.best_state = None
def check(self, values: List[np.floating], epoch: int) -> bool:
checks = [self.comp_ops[i](val, self.best_vals[i])
for i, val in enumerate(values)]
if any(checks):
self.best_vals = np.choose(checks, [self.best_vals, values])
self.patience = self.max_patience
comp_remembered = [
self.comp_ops[i](val, self.remembered_vals[i])
for i, val in enumerate(values)]
if self.remember is Best.ALL:
if all(comp_remembered):
self.best_epoch = epoch
self.remembered_vals = copy.copy(values)
self.best_state = {
key: value.cpu() for key, value
in self.model.state_dict().items()}
elif self.remember is Best.RANKED:
for i, comp in enumerate(comp_remembered):
if comp:
if not(self.remembered_vals[i] == values[i]):
self.best_epoch = epoch
self.remembered_vals = copy.copy(values)
self.best_state = {
key: value.cpu() for key, value
in self.model.state_dict().items()}
break
else:
break
else:
self.patience -= 1
return self.patience == 0