-
Notifications
You must be signed in to change notification settings - Fork 1
/
NeurocircuitX_wm_mix.py
76 lines (68 loc) · 3.75 KB
/
NeurocircuitX_wm_mix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# !/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Author : Ziyuan Ye
@Email : [email protected]
'''
import numpy as np
import os
def stdlize(data: np.ndarray) -> np.ndarray:
    """Min-max scale ``data`` into [0, 1] and return it as a column vector.

    The global minimum maps to 0 and the global maximum maps to 1. The result
    is flattened (C order) and reshaped to ``(N, 1)``, where ``N`` is the
    total number of elements.

    Args:
        data: array-like of numbers; any shape is accepted.

    Returns:
        A ``(N, 1)`` array of scaled values. Empty input yields an empty
        ``(0, 1)`` float array. Constant input (max == min) yields zeros
        rather than the NaNs a 0/0 division would produce.
    """
    arr = np.asarray(data)
    if arr.size == 0:
        # Nothing to scale; np.min/np.max would raise on an empty array.
        return np.asarray(arr, dtype=float).reshape(-1, 1)
    lo = np.min(arr)
    span = np.max(arr) - lo
    if span == 0:
        # Constant input: avoid 0/0 -> NaN ending up in the saved CSV.
        return np.zeros((arr.size, 1))
    # Vectorized: min/max computed once instead of once per element.
    return ((arr - lo) / span).reshape(-1, 1)
def mix_02(root='.', n_regions=379):
    """Average the '0'- and '2'-prefixed saliency maps for each category.

    For every category, model and strategy ('keep' / 'ablation'), reads
    ``0<cate>.csv`` and ``2<cate>.csv`` from
    ``<root>/result_cv/WM_task/diff_stimuli/<model>/saliency_result/yeo17/<strategy>/``
    (presumably the 0-back and 2-back conditions — confirm). It averages
    their first ``n_regions`` rows with equal 0.5 weights, min-max scales the
    result with ``stdlize`` and writes it to ``<cate minus 'bk_'>.csv`` in
    the same directory.

    Args:
        root: directory that contains ``result_cv``; defaults to the current
            working directory.
        n_regions: number of leading rows mixed from each file (default 379).
            Any further rows are ignored.

    Raises:
        ValueError: if an input file has fewer than ``n_regions`` rows.
    """
    cate_list = ['bk_body', 'bk_faces', 'bk_places', 'bk_tools']
    model_list = ['gcn', 'gat', 'stgcn', 'stpgcn']
    strategy = ['keep', 'ablation']
    for cate in cate_list:
        for model_name in model_list:
            for stra in strategy:
                # os.path.join instead of hard-coded '\' separators so the
                # script also works on POSIX systems.
                result_dir = os.path.join(root, 'result_cv', 'WM_task', 'diff_stimuli',
                                          model_name, 'saliency_result', 'yeo17', stra)
                bk_0_pth = os.path.join(result_dir, '0' + cate + '.csv')
                bk_2_pth = os.path.join(result_dir, '2' + cate + '.csv')
                # cate[3:] drops the 'bk_' prefix, e.g. 'bk_body' -> 'body'.
                save_pth = os.path.join(result_dir, cate[3:] + '.csv')
                bk_0_weight = np.loadtxt(bk_0_pth)
                bk_2_weight = np.loadtxt(bk_2_pth)
                for pth, weight in ((bk_0_pth, bk_0_weight), (bk_2_pth, bk_2_weight)):
                    if weight.ndim == 0 or weight.shape[0] < n_regions:
                        raise ValueError('{} has fewer than {} rows'.format(pth, n_regions))
                # Equal-weight average of the first n_regions rows, element-wise.
                mix_weight = bk_0_weight[:n_regions] * 0.5 + bk_2_weight[:n_regions] * 0.5
                np.savetxt(save_pth,
                           stdlize(mix_weight),
                           delimiter=",")
    print('finish')
def mix_acc(root='.', n_regions=379):
    """Average the 'keep' and 'ablation' saliency maps for each category.

    For every category and model, reads ``keep/<cate>.csv`` and
    ``ablation/<cate>.csv`` (GBK-encoded text) under
    ``<root>/result_cv/WM_task/diff_stimuli/<model>/saliency_result/yeo17/``.
    It averages their first ``n_regions`` rows with equal 0.5 weights,
    min-max scales the result with ``stdlize`` and writes it to
    ``mix/<cate>.csv``. The ``mix`` directory is created if it is missing.

    Args:
        root: directory that contains ``result_cv``; defaults to the current
            working directory.
        n_regions: number of leading rows mixed from each file (default 379).
            Any further rows are ignored.

    Raises:
        ValueError: if an input file has fewer than ``n_regions`` rows.
    """
    cate_list = ['body', 'faces', 'places', 'tools']
    model_list = ['gcn', 'gat', 'stgcn', 'stpgcn']
    for cate in cate_list:
        for model_name in model_list:
            # os.path.join instead of hard-coded '\' separators so the
            # script also works on POSIX systems.
            yeo_dir = os.path.join(root, 'result_cv', 'WM_task', 'diff_stimuli',
                                   model_name, 'saliency_result', 'yeo17')
            keep_pth = os.path.join(yeo_dir, 'keep', cate + '.csv')
            ablation_pth = os.path.join(yeo_dir, 'ablation', cate + '.csv')
            save_dir = os.path.join(yeo_dir, 'mix')
            save_pth = os.path.join(save_dir, cate + '.csv')
            keep_weight = np.loadtxt(keep_pth, encoding='gbk')
            ablation_weight = np.loadtxt(ablation_pth, encoding='gbk')
            for pth, weight in ((keep_pth, keep_weight), (ablation_pth, ablation_weight)):
                if weight.ndim == 0 or weight.shape[0] < n_regions:
                    raise ValueError('{} has fewer than {} rows'.format(pth, n_regions))
            # Equal-weight average of the first n_regions rows, element-wise.
            mix_weight = keep_weight[:n_regions] * 0.5 + ablation_weight[:n_regions] * 0.5
            # np.savetxt does not create parent directories.
            os.makedirs(save_dir, exist_ok=True)
            np.savetxt(save_pth,
                       stdlize(mix_weight),
                       delimiter=",")
    print('finish')
if __name__ == '__main__':
    # mix_02() writes keep/<cate>.csv and ablation/<cate>.csv, which are the
    # inputs mix_acc() reads. It is not run here; call it first if those
    # files need regenerating.
    mix_acc()