-
Notifications
You must be signed in to change notification settings - Fork 0
/
training.py
executable file
·136 lines (104 loc) · 4.34 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import pandas as pd
import tqdm
from allennlp.training.trainer import Trainer
from allennlp.data.iterators import BasicIterator
from allennlp.data.vocabulary import Vocabulary
import torch
import torch.optim as optim
from citeomatic.models.text_embedding import Text_Embedding
from citeomatic.models.paper_embedding import Paper_Embedding
from citeomatic.models.embeddingmodel import EmbeddingModel
from citeomatic.models.citationranker import CitationRanker
from citeomatic.citeomaticreader import CiteomaticReader, SimpleReader
from citeomatic.neighbors import ANN
from citeomatic.eval_metrics import eval_text_model
def from_pkl(n_files, path_prefix="../clean_pickle/df0"):
    """Load and concatenate pickled DataFrame shards.

    Reads files named ``{path_prefix}{i}.pkl`` for ``i`` in ``0..n_files-1``
    and concatenates them into one DataFrame, showing a progress bar.

    Args:
        n_files: number of pickle shards to load (must be >= 1, since
            ``pd.concat`` raises on an empty list).
        path_prefix: path prefix of the shard files; the default matches
            the previously hard-coded location, so existing callers are
            unaffected.

    Returns:
        pd.DataFrame: all shards concatenated in file order.
    """
    # Comprehension instead of a manual append loop; tqdm wraps the range
    # so loading progress is still visible.
    shards = [
        pd.read_pickle(f"{path_prefix}{i}.pkl")
        for i in tqdm.tqdm(range(n_files))
    ]
    return pd.concat(shards)
# --- Load data -------------------------------------------------------------
print("loading data into memory")
df = from_pkl(1)  # NOTE(review): loads only the first pickle shard — confirm intended
print("building index dictionary")
# Build a lookup from the first data column to the DataFrame index.
# With itertuples(), row[0] is the index and row[1] the first column;
# presumably this maps corpus position -> paper id (or the reverse) —
# TODO confirm against the pickle schema.
idx_to_id_dict = {}
for row in tqdm.tqdm(df.itertuples()):
    idx_to_id_dict[row[1]] = row[0]

# --- Build models ----------------------------------------------------------
print("initializing model")
# Shared vocabulary persisted to disk by an earlier preprocessing step.
vocab = Vocabulary.from_files("../vocabulary")
# Embedder: encodes a paper into a dense vector (used for ANN retrieval).
text_embedder = Text_Embedding(vocab)
paper_embedder = Paper_Embedding()
embedder = EmbeddingModel(vocab, text_embedder, paper_embedder)
# Ranker: separate title/abstract text embeddings feed the citation ranker.
rank_abs_embed = Text_Embedding(vocab)
rank_title_embed = Text_Embedding(vocab)
ranker = CitationRanker(vocab,rank_title_embed,rank_abs_embed)
# Pre-built approximate-nearest-neighbour index over paper embeddings.
ann = ANN.load("../bigger_ann",75)

# --- Datasets & training machinery -----------------------------------------
# Train/validation split: the trailing 0.5% of rows is held out.
train_frac = 0.995
train_set = int(len(df)*train_frac)  # count of training rows
ann_reader = SimpleReader(df)
# Instances over the full corpus; only needed if the per-epoch ANN rebuild
# (commented out in the training loop) is re-enabled.
ann_data = ann_reader.read("")
#val_reader = CiteomaticReader(df,idx_to_id_dict,ann,train_frac=train_frac, validation = True)
#val_reader.set_compute_nnrank_features(True)
simple = SimpleReader(df.iloc[train_set:])
val_data = list(simple.read(""))  # materialise held-out validation instances
iterator = BasicIterator(batch_size=16)
iterator.index_with(vocab)
n_epochs=40
# NOTE(review): this Adam instance registers only the *embedder's* parameters,
# yet it is also handed to the ranker's Trainer below — verify the ranker's
# weights are actually meant to be optimised.
optimizer = optim.Adam(embedder.parameters(), lr=0.001)
if torch.cuda.is_available():
    cuda_device = 0
    embedder = embedder.cuda(cuda_device)
    ranker = ranker.cuda(cuda_device)
else:
    cuda_device = -1  # allennlp convention: -1 selects CPU
# allennlp doesn't have callbacks, so we call the trainer one epoch at a time.
print("beginning training")
# FIX: the Adam instance created above registered only the embedder's
# parameters; reusing it for the ranker's Trainer would silently leave the
# ranker untrained. Create it once here (not per epoch) so Adam's moment
# estimates persist across epochs.
rank_optimizer = optim.Adam(ranker.parameters(), lr=0.001)
for e_i in range(n_epochs):
    #(re)build ann
    #print("bulding annoy trees")
    #ann = ANN.build(embedder, ann_data, vec_size=text_embedder.get_output_dim(), ann_trees=10)
    # Evaluate on the held-out split before each training epoch.
    print("Evaluating model performance...")
    val_reader = CiteomaticReader(df, idx_to_id_dict, ann,
                                  train_frac=train_frac, validation=True)
    val_reader.set_compute_nnrank_features(True)
    valid = eval_text_model(val_data, val_reader, embedder, ann, ranker,
                            df, idx_to_id_dict)
    print(valid)
    # Fresh readers each epoch (they would pick up a rebuilt ANN if the
    # rebuild above were re-enabled).
    embed_training_reader = CiteomaticReader(df, idx_to_id_dict, ann,
                                             train_frac=train_frac,
                                             validation=False)
    embed_training_reader.set_compute_nnrank_features(False)
    embed_training_data = embed_training_reader.read("")
    rank_training_reader = CiteomaticReader(df, idx_to_id_dict, ann,
                                            train_frac=train_frac,
                                            validation=False)
    rank_training_reader.set_compute_nnrank_features(True)
    # FIX: originally this read from embed_training_reader, so the ranker was
    # trained on instances WITHOUT nnrank features and rank_training_reader
    # was never used.
    rank_training_data = rank_training_reader.read("")
    embed_trainer = Trainer(model=embedder,
                            optimizer=optimizer,
                            iterator=iterator,
                            train_dataset=embed_training_data,
                            #validation_dataset=val_data,
                            patience=10,
                            num_epochs=1,
                            shuffle=False,
                            cuda_device=cuda_device)
    rank_trainer = Trainer(model=ranker,
                           optimizer=rank_optimizer,  # FIX: was the embedder's optimizer
                           iterator=iterator,
                           train_dataset=rank_training_data,
                           #validation_dataset=val_data,
                           patience=10,
                           num_epochs=1,
                           shuffle=False,
                           cuda_device=cuda_device)
    # Run exactly one epoch of each model, alternating.
    print("Beginning Embedder Epoch")
    embed_trainer.train()
    print("Beginning NNRank Epoch")
    rank_trainer.train()
    # Checkpoint both models after every epoch.
    # NOTE(review): "/chkpt" is an absolute path — confirm it exists and is
    # writable (possibly intended to be a relative "./chkpt").
    print("Making model checkpoint")
    with open("/chkpt/embedder_e"+str(e_i)+".th", 'wb') as f:
        torch.save(embedder.state_dict(), f)
    with open("/chkpt/ranker_e"+str(e_i)+".th", 'wb') as f:
        torch.save(ranker.state_dict(), f)