-
Notifications
You must be signed in to change notification settings - Fork 237
/
Copy pathtest_model.py
70 lines (57 loc) · 2.38 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import argparse
import numpy as np
from pathlib import Path
import cv2
from model import get_model
from noise_model import get_noise_model
def get_args():
    """Build the command-line parser for model testing and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``image_dir`` (required),
        ``model`` (default ``"srresnet"``), ``weight_file`` (required),
        ``test_noise_model`` (default ``"gaussian,25,25"``) and
        ``output_dir`` (default ``None``).
    """
    parser = argparse.ArgumentParser(
        description="Test trained model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # (flag, add_argument keyword arguments), kept in declaration order so
    # the generated --help text lists the options in the same sequence.
    option_specs = [
        ("--image_dir", {"type": str, "required": True,
                         "help": "test image dir"}),
        ("--model", {"type": str, "default": "srresnet",
                     "help": "model architecture ('srresnet' or 'unet')"}),
        ("--weight_file", {"type": str, "required": True,
                           "help": "trained weight file"}),
        ("--test_noise_model", {"type": str, "default": "gaussian,25,25",
                                "help": "noise model for test images"}),
        ("--output_dir", {"type": str, "default": None,
                          "help": "if set, save resulting images otherwise show result using imshow"}),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
def get_image(image):
    """Turn a raw model output array into an 8-bit image.

    Values are clamped to the closed range [0, 255] and then cast to
    ``np.uint8``; the cast truncates any fractional part (e.g. 12.7 -> 12).
    """
    return np.clip(image, 0, 255).astype(np.uint8)
def main():
    """Run the trained denoiser over every image file in ``--image_dir``.

    For each readable image the frame is cropped so both sides are
    multiples of 16, corrupted with the ``--test_noise_model`` noise, and
    passed through the model loaded from ``--weight_file``. The clean,
    noisy and denoised images are placed side by side in one strip. With
    ``--output_dir`` the strip is written as ``<stem>.png`` into that
    directory; otherwise it is shown in an OpenCV window and pressing "q"
    stops early.

    Returns:
        0 if the user quits with "q" in interactive mode, otherwise None.
    """
    args = get_args()
    val_noise_model = get_noise_model(args.test_noise_model)
    model = get_model(args.model)
    model.load_weights(args.weight_file)

    output_dir = None
    if args.output_dir:
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    # Sorted for a deterministic processing order. "*.*" also matches
    # non-image files; those are filtered out below when imread fails.
    image_paths = sorted(Path(args.image_dir).glob("*.*"))

    for image_path in image_paths:
        image = cv2.imread(str(image_path))
        if image is None:
            # cv2.imread returns None for unreadable/unsupported files
            # instead of raising; skip them rather than crash on .shape.
            continue

        # For stride (maximum 16): crop both sides down to a multiple of 16.
        h, w, _ = image.shape
        h, w = (h // 16) * 16, (w // 16) * 16
        if h == 0 or w == 0:
            # Image is smaller than 16 px on some side; the crop would be
            # empty, so there is nothing to feed the model.
            continue
        image = image[:h, :w]

        noise_image = val_noise_model(image)
        pred = model.predict(np.expand_dims(noise_image, 0))
        denoised_image = get_image(pred[0])

        # Side-by-side strip: original | noisy | denoised.
        out_image = np.zeros((h, w * 3, 3), dtype=np.uint8)
        out_image[:, :w] = image
        out_image[:, w:w * 2] = noise_image
        out_image[:, w * 2:] = denoised_image

        if output_dir is not None:
            # with_suffix copes with any extension length; slicing off the
            # last 4 characters only worked for 3-letter extensions like
            # ".jpg" (".jpeg" produced "name..png").
            out_path = output_dir.joinpath(image_path.name).with_suffix(".png")
            cv2.imwrite(str(out_path), out_image)
        else:
            cv2.imshow("result", out_image)
            key = cv2.waitKey(-1)
            # Compare only the low byte: some OpenCV builds report extra
            # bits in waitKey's result. "q" quits.
            if key & 0xFF == ord("q"):
                return 0
if __name__ == "__main__":
    # Entry point when executed as a script; main()'s return value is unused.
    main()