-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_synth.py
87 lines (80 loc) · 3.04 KB
/
gen_synth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import scipy.io
import os
import glob
import numpy as np
def _expand_block_corners(corners, block_size):
    """Expand top-left block corners into per-entry (row, col) index arrays.

    For every corner ``(r, c)`` the result lists the ``block_size ** 2``
    entries ``(r + i, c + j)`` with ``i`` as the outer and ``j`` as the inner
    offset; blocks are emitted in the order given.  Both returned arrays are
    1-D ``int64``, also when ``corners`` is empty.
    """
    corners = np.asarray(corners, dtype=np.int64).reshape(-1, 2)
    offsets = np.arange(block_size, dtype=np.int64)
    row_offsets = np.repeat(offsets, block_size)  # 0,0,0,1,1,1,2,2,2 for size 3
    col_offsets = np.tile(offsets, block_size)    # 0,1,2,0,1,2,0,1,2 for size 3
    rows = (corners[:, :1] + row_offsets).reshape(-1)
    cols = (corners[:, 1:] + col_offsets).reshape(-1)
    return rows, cols


def gen_data(base_path="./datasets/data-synth-mat",
             save_path="./datasets/data-synth",
             block_size=3):
    """Convert every ``*.mat`` file in ``base_path`` into ``.pt`` tensor files.

    Each input file must provide the variables ``Z_incomplete``, ``Z_correct``
    and ``Omega``.  For an input named ``<name>.mat`` three files are written
    to ``save_path``:

    * ``<name>_gt.pt``    -- the matrix ``Z_correct @ Z_correct.T``.
    * ``<name>_obs.pt``   -- ``[(rows, cols), values]`` where ``values`` are
      read from ``Z_incomplete`` at the entries of every block whose top-left
      ``Omega`` entry equals 1.
    * ``<name>_unobs.pt`` -- the same layout for all remaining blocks, with
      ``values`` read from the ground-truth matrix instead.

    Blocks are ``block_size x block_size`` squares whose top-left corners lie
    on a grid with stride ``block_size`` over
    ``range(Z_incomplete.shape[0])`` in both directions; they are visited in
    row-major order.

    Args:
        base_path: Directory that is searched for ``*.mat`` files.
        save_path: Output directory; created if it does not exist.
        block_size: Side length of one block (3 by default).

    Raises:
        ValueError: If ``Omega`` has fewer rows or columns than
            ``Z_incomplete`` has rows.
        IndexError: If a block reaches outside the indexed matrix, e.g. when
            the side length is not a multiple of ``block_size``.
    """
    # Loop-invariant, and exist_ok avoids the check-then-create race.
    os.makedirs(save_path, exist_ok=True)

    # Iterate through all the files in the folder
    for mat_file in glob.glob(os.path.join(base_path, "*.mat")):
        name = os.path.splitext(os.path.basename(mat_file))[0]
        mat = scipy.io.loadmat(mat_file)

        Z_incomplete = mat["Z_incomplete"]
        Z_correct = mat["Z_correct"]
        Omega = mat["Omega"]
        complete_matrix_gt = np.matmul(Z_correct, Z_correct.T)

        side = Z_incomplete.shape[0]
        if Omega.shape[0] < side or Omega.shape[1] < side:
            # Slicing below would otherwise silently drop blocks.
            raise ValueError(
                f"Omega in {mat_file} has shape {Omega.shape}, "
                f"expected at least ({side}, {side})"
            )

        # Only the top-left entry of each block decides whether the whole
        # block counts as observed.
        observed = Omega[0:side:block_size, 0:side:block_size] == 1
        # argwhere enumerates in row-major order, matching a nested i/j loop.
        obs_corners = np.argwhere(observed) * block_size
        unobs_corners = np.argwhere(~observed) * block_size

        torch.save(
            torch.from_numpy(complete_matrix_gt),
            os.path.join(save_path, name + "_gt.pt"),
        )

        # Index arrays are always int64, so an empty selection still yields
        # integer index tensors rather than default-float ones.
        rows, cols = _expand_block_corners(obs_corners, block_size)
        values = torch.tensor(Z_incomplete[rows, cols])
        torch.save(
            [(torch.from_numpy(rows), torch.from_numpy(cols)), values],
            os.path.join(save_path, name + "_obs.pt"),
        )

        rows_un, cols_un = _expand_block_corners(unobs_corners, block_size)
        values_un = torch.tensor(complete_matrix_gt[rows_un, cols_un])
        torch.save(
            [(torch.from_numpy(rows_un), torch.from_numpy(cols_un)), values_un],
            os.path.join(save_path, name + "_unobs.pt"),
        )
def gen_config(base_path="datasets/data-synth-mat",
               data_base_path="datasets/data-synth",
               save_path="configs/data-synth"):
    """Write one ``<name>.toml`` config per ``*.mat`` file found in ``base_path``.

    For every ``<name>.mat`` the ground-truth tensor
    ``<data_base_path>/<name>_gt.pt`` is loaded to read its size, and a config
    with the keys ``problem``, ``gt_path``, ``shape``, ``obs_path``,
    ``dataset``, ``unobs_path`` and ``gt_mat`` is written to
    ``<save_path>/<name>.toml``.

    Args:
        base_path: Directory that is searched for ``*.mat`` files.
        data_base_path: Directory holding the ``*_gt.pt`` / ``*_obs.pt`` /
            ``*_unobs.pt`` tensor files referenced by the config.
        save_path: Output directory for the ``.toml`` files; created if it
            does not exist.

    Raises:
        FileNotFoundError: If the ``*_gt.pt`` file for an input is missing.
    """
    # Without this, writing the first config fails when the directory is absent.
    os.makedirs(save_path, exist_ok=True)

    for mat_file in glob.glob(os.path.join(base_path, "*.mat")):
        name = os.path.splitext(os.path.basename(mat_file))[0]

        gt_path = os.path.join(data_base_path, name + "_gt.pt")
        obs_path = os.path.join(data_base_path, name + "_obs.pt")
        unobs_path = os.path.join(data_base_path, name + "_unobs.pt")

        # The tensor is only loaded for its size; both entries of ``shape``
        # use the first dimension.
        side = torch.load(gt_path).shape[0]

        lines = [
            "problem = \"matrix-completion\" \n",
            f"gt_path = \"{gt_path}\" \n",
            f"shape = [{side}, {side}] \n",
            f"obs_path = \"{obs_path}\" \n",
            f"dataset = \"{name}\" \n",
            f"unobs_path = \"{unobs_path}\" \n",
            f"gt_mat = \"{mat_file}\" \n",
        ]

        # Context manager closes the handle even if the write raises.
        config_path = os.path.join(save_path, name + ".toml")
        with open(config_path, "w", encoding="utf-8") as config_file:
            config_file.writelines(lines)
# Script entry point: generate the tensor files first, then the configs,
# because gen_config loads the *_gt.pt files that gen_data writes.
if __name__ == "__main__":
    for step in (gen_data, gen_config):
        step()