-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstep1_feature_ranking.py
160 lines (135 loc) · 6.15 KB
/
step1_feature_ranking.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os
import sys
import torch
import numpy as np
import pickle as pkl
import random
from sklearn.preprocessing import StandardScaler
from configs import id_ood_envs, embeddings_main_path, metadata_main_path
from stylist_utils import get_stylist_wass_distance, get_stylist_kl_distance
from stylist_utils import get_stylist_mean_order, get_stylist_medianranking_order, get_stylist_weightedranking_order
from baselines import get_random_features, get_pca_loadings_features, get_infogain_features, get_mad_features, get_variance_features, get_fisherscore_features, get_dispersion_features
from read_data import get_samples
from configs import DATASETS, FEATURES_TYPES, RANKING_METHODS
from configs import env_aware_baselines
from configs import STANDARDIZE_BEFORE_RANKING
# Dispatch table: name of a baseline ranking method -> function computing it.
# `run` looks the selected method up here for every method other than
# 'random' and 'Stylist'.
BASELINE_RANKING_METHODS_FUNCTIONS = {
    'PCA_loadings': get_pca_loadings_features,
    'InfoGain': get_infogain_features,
    'FisherScore': get_fisherscore_features,
    'MAD': get_mad_features,
    'Dispersion': get_dispersion_features,
    'Variance': get_variance_features,
}
def get_Stylist_features(samples, env_labels, method_out_path, distance_name, ranking_method):
    """Get Stylist feature ranking

    Groups the samples by environment, passes the per-environment groups to
    the selected distance function, feeds its result to the selected ranking
    function and saves the returned ordering as ``ranking.npy`` inside
    ``method_out_path``.

    Parameters
    ----------
    samples : numpy array
        Features matrix for all samples
    env_labels : list
        List of environment labels for all samples (``env_labels[i]`` is the
        environment of ``samples[i]``)
    method_out_path : str
        Path to the output directory (created if it does not exist)
    distance_name : str
        Name of the distance metric (one of ['Wass', 'KL'])
    ranking_method : str
        Name of the ranking method (one of ['mean', 'medianranking', 'weightedranking'])

    Raises
    ------
    KeyError
        If ``distance_name`` or ``ranking_method`` is not one of the
        supported names.
    """
    # Resolve both callables before doing any work so that a misspelled
    # name fails immediately instead of after the distance computation.
    distance_fn = {
        'Wass': get_stylist_wass_distance,
        'KL': get_stylist_kl_distance,
    }[distance_name]
    ranking_fn = {
        'mean': get_stylist_mean_order,
        'medianranking': get_stylist_medianranking_order,
        'weightedranking': get_stylist_weightedranking_order,
    }[ranking_method]

    # Single pass over the labels. Environments are kept in order of first
    # appearance, which (unlike iterating a set of strings) is identical on
    # every run, so the dict handed to the distance function is reproducible.
    rows_by_env = {}
    for idx, env in enumerate(env_labels):
        rows_by_env.setdefault(env, []).append(samples[idx])
    per_env_samples = {env: np.array(rows)
                       for env, rows in rows_by_env.items()}

    distances = distance_fn(per_env_samples)
    indexes = ranking_fn(distances)

    os.makedirs(method_out_path, exist_ok=True)
    out_path = os.path.join(method_out_path, 'ranking.npy')
    np.save(out_path, indexes)
    print('Feature ranking saved at: ', out_path, flush=True)
def run(dataset_name, features_type, feature_ranking_method, main_out_path,
        random_seed=115, distance_name='Wass', ranking_method='mean'):
    """Run Step 1 - Feature Ranking

    Loads the embeddings and metadata of the dataset, selects the normal
    samples of the ID split from the ID environments, fits a
    ``StandardScaler`` on them (saved as ``scaler.pkl``) and runs the
    selected feature ranking method. Outputs are written under
    ``<main_out_path>/<dataset_name>_features_<features_type>/``, with the
    method-specific results in a sub-directory named after the method.

    Parameters
    ----------
    dataset_name : str
        Name of the dataset (one of DATASETS from configs.py)
    features_type : str
        Type of features (one of FEATURES_TYPES from configs.py)
    feature_ranking_method : str
        Feature ranking method (one of RANKING_METHODS from configs.py)
    main_out_path : str
        Path to the output directory
    random_seed : int, optional
        Seed passed to the 'random' ranking method (default 115; ignored by
        every other method)
    distance_name : str, optional
        Distance metric used by the 'Stylist' method, one of
        ['Wass', 'KL'] (default 'Wass'; ignored by every other method)
    ranking_method : str, optional
        Ranking method used by the 'Stylist' method, one of
        ['mean', 'medianranking', 'weightedranking'] (default 'mean';
        ignored by every other method)
    """
    ranking_out_path = os.path.join(
        main_out_path, '%s_features_%s' % (dataset_name, features_type))
    os.makedirs(ranking_out_path, exist_ok=True)

    id_envs = id_ood_envs[dataset_name]['id_envs']

    # NOTE(review): both torch.load and pickle can execute arbitrary code
    # while deserializing; only point the configured paths at trusted files.
    embeddings_current_path = os.path.join(
        embeddings_main_path[dataset_name], 'embeddings_%s.pt' % features_type)
    with open(embeddings_current_path, "rb") as fd:
        all_features = torch.load(fd)
    with open(metadata_main_path[dataset_name], "rb") as fd:
        all_metadata = pkl.load(fd)

    # Rank features using normal ID-split samples from the ID environments.
    samples, _, env_labels = get_samples(
        dataset_name=dataset_name,
        all_metadata=all_metadata, all_features=all_features,
        add_normal_samples=True, add_anomaly_samples=False,
        selected_splits=['ID'], selected_envs=id_envs)

    # The scaler is fitted and saved for every method, even those that rank
    # the raw (non-standardized) features.
    scaler = StandardScaler(with_mean=True, with_std=True)
    scaler.fit(samples)

    # Whether this method ranks standardized features is set per method in
    # configs.STANDARDIZE_BEFORE_RANKING.
    standardize = STANDARDIZE_BEFORE_RANKING[feature_ranking_method]
    if standardize == 1:
        samples_std = scaler.transform(samples)
    else:
        samples_std = samples

    scaler_out_path = os.path.join(ranking_out_path, 'scaler.pkl')
    with open(scaler_out_path, "wb") as fd:
        pkl.dump(scaler, fd)

    method_out_path = os.path.join(
        ranking_out_path, feature_ranking_method)
    os.makedirs(method_out_path, exist_ok=True)

    if feature_ranking_method == 'random':
        # seeds listed in the original script: 115 / 42 / 10 / 0 / 15 / 300
        get_random_features(samples=samples_std, method_out_path=method_out_path,
                            seed=random_seed)
    elif feature_ranking_method == 'Stylist':
        get_Stylist_features(samples=samples_std, env_labels=env_labels,
                             method_out_path=method_out_path,
                             distance_name=distance_name,
                             ranking_method=ranking_method)
    elif feature_ranking_method in env_aware_baselines:
        # Environment-aware baselines additionally receive the env labels.
        BASELINE_RANKING_METHODS_FUNCTIONS[feature_ranking_method](
            samples_std, env_labels, method_out_path)
    else:
        BASELINE_RANKING_METHODS_FUNCTIONS[feature_ranking_method](
            samples_std, method_out_path)
# Usage examples:
# python step1_feature_ranking.py COCOShift95 resnet18 Stylist ./results/ranking_methods
# python step1_feature_ranking.py DomainNet resnet18 Stylist ./results/ranking_methods
# python step1_feature_ranking.py FMoW resnet18 Stylist ./results/ranking_methods
if __name__ == '__main__':
    # Expected positional arguments:
    #   1. name of the dataset (one of DATASETS from configs.py)
    #   2. type of features (one of FEATURES_TYPES from configs.py)
    #   3. feature ranking method (one of RANKING_METHODS from configs.py)
    #   4. path to the output directory
    if len(sys.argv) < 5:
        sys.exit('Usage: python step1_feature_ranking.py <dataset_name> '
                 '<features_type> <feature_ranking_method> <main_out_path>')

    dataset_name = sys.argv[1]
    features_type = sys.argv[2]
    feature_ranking_method = sys.argv[3]
    main_out_path = sys.argv[4]

    # Explicit checks instead of `assert`, which is removed under `python -O`
    # and would let invalid arguments through unvalidated.
    if dataset_name not in DATASETS:
        sys.exit('invalid dataset name: %r' % dataset_name)
    if features_type not in FEATURES_TYPES:
        sys.exit('invalid features type: %r' % features_type)
    if feature_ranking_method not in RANKING_METHODS:
        sys.exit('invalid feature ranking method: %r' % feature_ranking_method)

    run(dataset_name, features_type, feature_ranking_method, main_out_path)