From 693f96f6fea5571492fab25c90490a3263be1db7 Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Sun, 1 Oct 2023 10:52:15 +0200 Subject: [PATCH] add model and size params --- .../files_from_makeathon/dvc.lock | 26 ++- .../files_from_makeathon/dvc.yaml | 4 + .../files_from_makeathon/params.yaml | 2 + .../pipelines/export/dvc.lock | 4 +- .../files_from_makeathon/poetry.lock | 205 +++++++----------- .../files_from_makeathon/predict.py | 14 +- .../files_from_makeathon/pyproject.toml | 3 +- .../files_from_makeathon/run_ray_tune.py | 14 +- .../files_from_makeathon/train.py | 73 +++++-- 9 files changed, 171 insertions(+), 174 deletions(-) diff --git a/classification_model_training/files_from_makeathon/dvc.lock b/classification_model_training/files_from_makeathon/dvc.lock index 3e9fbd1..70c6532 100644 --- a/classification_model_training/files_from_makeathon/dvc.lock +++ b/classification_model_training/files_from_makeathon/dvc.lock @@ -1,10 +1,11 @@ schema: '2.0' stages: train: - cmd: python train.py --num_classes 4 --pretrained_on_ImageNet --fold 1 --dataset - ./data/train/images --gt ./data/train/gt.csv --outdir models --epochs 30 --seed - 1 --lr 0.001 --dropout_rate 0.3 --drop_connect_rate 0.2 --batch_norm_momentum - 0.99 --batch_norm_epsilon 0.001 --metrics_file_path models/metrics-training.json + cmd: python train.py --model efficientnet-b2 --image_size 224 --num_classes 4 + --pretrained_on_ImageNet --fold 1 --dataset ./data/train/images --gt ./data/train/gt.csv + --outdir models --epochs 30 --seed 1 --lr 0.001 --dropout_rate 0.3 --drop_connect_rate + 0.2 --batch_norm_momentum 0.99 --batch_norm_epsilon 0.001 --metrics_file_path + models/metrics-training.json deps: - path: data/train hash: md5 @@ -13,8 +14,8 @@ stages: nfiles: 362 - path: train.py hash: md5 - md5: d0b7beb3fc65200f6a9c3a51e9307a24 - size: 9750 + md5: cf9e5a2efba47f0da725d7f991d8e424 + size: 10428 params: params.yaml: base: @@ -35,6 +36,8 @@ stages: drop_connect_rate: 0.2 batch_norm_momentum: 0.99 batch_norm_epsilon: 0.001 + model: efficientnet-b2 + image_size: 224 outs: - path: models/metrics-training.json hash: md5 @@ -45,8 +48,9 @@ stages: md5: 43b5c6b74a5b036831c4195bf8679e90 size: 31270857 evaluate: - cmd: python predict.py --num_classes 4 --dataset ./data/test/images --gt ./data/test/gt.csv - --single_model_path models/model_best.pt --metrics_file_path models/metrics-evaluate.json + cmd: python predict.py --model efficientnet-b2 --image_size 224 --num_classes + 4 --dataset ./data/test/images --gt ./data/test/gt.csv --single_model_path models/model_best.pt + --metrics_file_path models/metrics-evaluate.json deps: - path: data/test hash: md5 @@ -59,8 +63,8 @@ stages: size: 31270857 - path: predict.py hash: md5 - md5: 8423985a8b1f20117aecc13fd9e81fa3 - size: 5868 + md5: 43fb163a6a3546d1bd6b47cf92fc6ff5 + size: 6001 params: params.yaml: base: @@ -86,6 +90,8 @@ stages: drop_connect_rate: 0.2 batch_norm_momentum: 0.99 batch_norm_epsilon: 0.001 + model: efficientnet-b2 + image_size: 224 outs: - path: models/metrics-evaluate.json hash: md5 diff --git a/classification_model_training/files_from_makeathon/dvc.yaml b/classification_model_training/files_from_makeathon/dvc.yaml index 88700dc..97ba0b2 100644 --- a/classification_model_training/files_from_makeathon/dvc.yaml +++ b/classification_model_training/files_from_makeathon/dvc.yaml @@ -2,6 +2,8 @@ stages: train: cmd: >- python train.py + --model ${train.model} + --image_size ${train.image_size} --num_classes ${train.num_classes} --pretrained_on_ImageNet --fold ${train.fold} @@ -39,6 +41,8 @@ stages: evaluate: cmd: >- python predict.py + --model ${train.model} + --image_size ${train.image_size} --num_classes ${train.num_classes} --dataset ${evaluate.dataset} --gt ${evaluate.gt} diff --git a/classification_model_training/files_from_makeathon/params.yaml b/classification_model_training/files_from_makeathon/params.yaml index 06b83fa..b91d35b 100644 --- a/classification_model_training/files_from_makeathon/params.yaml +++ b/classification_model_training/files_from_makeathon/params.yaml @@ -19,6 +19,8 @@ train: batch_norm_epsilon: 1e-3 # pretrained_on_ImageNet: # pretrained_own: + model: efficientnet-b2 + image_size: 224 evaluate: # single_model_path: models/model_best.pt diff --git a/classification_model_training/files_from_makeathon/pipelines/export/dvc.lock b/classification_model_training/files_from_makeathon/pipelines/export/dvc.lock index 45af65f..86f39b6 100644 --- a/classification_model_training/files_from_makeathon/pipelines/export/dvc.lock +++ b/classification_model_training/files_from_makeathon/pipelines/export/dvc.lock @@ -6,8 +6,8 @@ stages: deps: - path: ../../export.py hash: md5 - md5: f955cd2a80cff3babb9183dbee02d68d - size: 1428 + md5: 7ac103807a6ff8f1927e64691cf913ff + size: 1467 - path: ../../models/model_best.pt hash: md5 md5: 43b5c6b74a5b036831c4195bf8679e90 diff --git a/classification_model_training/files_from_makeathon/poetry.lock b/classification_model_training/files_from_makeathon/poetry.lock index afe0c93..3676147 100644 --- a/classification_model_training/files_from_makeathon/poetry.lock +++ b/classification_model_training/files_from_makeathon/poetry.lock @@ -767,88 +767,18 @@ files = [ [[package]] name = "charset-normalizer" -version = "3.2.0" +version = "2.1.1" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." optional = false -python-versions = ">=3.7.0" +python-versions = ">=3.6.0" files = [ - {file = "charset-normalizer-3.2.0.tar.gz", hash = "sha256:3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b87549028f680ca955556e3bd57013ab47474c3124dc069faa0b6545b6c9710"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7c70087bfee18a42b4040bb9ec1ca15a08242cf5867c58726530bdf3945672ed"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a103b3a7069b62f5d4890ae1b8f0597618f628b286b03d4bc9195230b154bfa9"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94aea8eff76ee6d1cdacb07dd2123a68283cb5569e0250feab1240058f53b623"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db901e2ac34c931d73054d9797383d0f8009991e723dab15109740a63e7f902a"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b0dac0ff919ba34d4df1b6131f59ce95b08b9065233446be7e459f95554c0dc8"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:193cbc708ea3aca45e7221ae58f0fd63f933753a9bfb498a3b474878f12caaad"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09393e1b2a9461950b1c9a45d5fd251dc7c6f228acab64da1c9c0165d9c7765c"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:baacc6aee0b2ef6f3d308e197b5d7a81c0e70b06beae1f1fcacffdbd124fe0e3"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bf420121d4c8dce6b889f0e8e4ec0ca34b7f40186203f06a946fa0276ba54029"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:c04a46716adde8d927adb9457bbe39cf473e1e2c2f5d0a16ceb837e5d841ad4f"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:aaf63899c94de41fe3cf934601b0f7ccb6b428c6e4eeb80da72c58eab077b19a"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62e51710986674142526ab9f78663ca2b0726066ae26b78b22e0f5e571238dd"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-win32.whl", hash = "sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:48021783bdf96e3d6de03a6e39a1171ed5bd7e8bb93fc84cc649d11490f87cea"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4957669ef390f0e6719db3613ab3a7631e68424604a7b448f079bee145da6e09"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:46fb8c61d794b78ec7134a715a3e564aafc8f6b5e338417cb19fe9f57a5a9bf2"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f779d3ad205f108d14e99bb3859aa7dd8e9c68874617c72354d7ecaec2a054ac"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f25c229a6ba38a35ae6e25ca1264621cc25d4d38dca2942a7fce0b67a4efe918"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2efb1bd13885392adfda4614c33d3b68dee4921fd0ac1d3988f8cbb7d589e72a"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f30b48dd7fa1474554b0b0f3fdfdd4c13b5c737a3c6284d3cdc424ec0ffff3a"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:246de67b99b6851627d945db38147d1b209a899311b1305dd84916f2b88526c6"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd9b3b31adcb054116447ea22caa61a285d92e94d710aa5ec97992ff5eb7cf3"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8c2f5e83493748286002f9369f3e6607c565a6a90425a3a1fef5ae32a36d749d"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3170c9399da12c9dc66366e9d14da8bf7147e1e9d9ea566067bbce7bb74bd9c2"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7a4826ad2bd6b07ca615c74ab91f32f6c96d08f6fcc3902ceeedaec8cdc3bcd6"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:3b1613dd5aee995ec6d4c69f00378bbd07614702a315a2cf6c1d21461fe17c23"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9e608aafdb55eb9f255034709e20d5a83b6d60c054df0802fa9c9883d0a937aa"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-win32.whl", hash = "sha256:f2a1d0fd4242bd8643ce6f98927cf9c04540af6efa92323e9d3124f57727bfc1"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:681eb3d7e02e3c3655d1b16059fbfb605ac464c834a0c629048a30fad2b27489"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c57921cda3a80d0f2b8aec7e25c8aa14479ea92b5b51b6876d975d925a2ea346"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41b25eaa7d15909cf3ac4c96088c1f266a9a93ec44f87f1d13d4a0e86c81b982"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f058f6963fd82eb143c692cecdc89e075fa0828db2e5b291070485390b2f1c9c"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7647ebdfb9682b7bb97e2a5e7cb6ae735b1c25008a70b906aecca294ee96cf4"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eef9df1eefada2c09a5e7a40991b9fc6ac6ef20b1372abd48d2794a316dc0449"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e03b8895a6990c9ab2cdcd0f2fe44088ca1c65ae592b8f795c3294af00a461c3"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ee4006268ed33370957f55bf2e6f4d263eaf4dc3cfc473d1d90baff6ed36ce4a"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c4983bf937209c57240cff65906b18bb35e64ae872da6a0db937d7b4af845dd7"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:3bb7fda7260735efe66d5107fb7e6af6a7c04c7fce9b2514e04b7a74b06bf5dd"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:72814c01533f51d68702802d74f77ea026b5ec52793c791e2da806a3844a46c3"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:70c610f6cbe4b9fce272c407dd9d07e33e6bf7b4aa1b7ffb6f6ded8e634e3592"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-win32.whl", hash = "sha256:a401b4598e5d3f4a9a811f3daf42ee2291790c7f9d74b18d75d6e21dda98a1a1"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c0b21078a4b56965e2b12f247467b234734491897e99c1d51cee628da9786959"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95eb302ff792e12aba9a8b8f8474ab229a83c103d74a750ec0bd1c1eea32e669"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1a100c6d595a7f316f1b6f01d20815d916e75ff98c27a01ae817439ea7726329"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6339d047dab2780cc6220f46306628e04d9750f02f983ddb37439ca47ced7149"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4b749b9cc6ee664a3300bb3a273c1ca8068c46be705b6c31cf5d276f8628a94"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a38856a971c602f98472050165cea2cdc97709240373041b69030be15047691f"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89f1b185a01fe560bc8ae5f619e924407efca2191b56ce749ec84982fc59a32a"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1c8a2f4c69e08e89632defbfabec2feb8a8d99edc9f89ce33c4b9e36ab63037"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2f4ac36d8e2b4cc1aa71df3dd84ff8efbe3bfb97ac41242fbcfc053c67434f46"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a386ebe437176aab38c041de1260cd3ea459c6ce5263594399880bbc398225b2"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ccd16eb18a849fd8dcb23e23380e2f0a354e8daa0c984b8a732d9cfaba3a776d"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:e6a5bf2cba5ae1bb80b154ed68a3cfa2fa00fde979a7f50d6598d3e17d9ac20c"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:45de3f87179c1823e6d9e32156fb14c1927fcc9aba21433f088fdfb555b77c10"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-win32.whl", hash = "sha256:1000fba1057b92a65daec275aec30586c3de2401ccdcd41f8a5c1e2c87078706"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b2c760cfc7042b27ebdb4a43a4453bd829a5742503599144d54a032c5dc7e9e"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:855eafa5d5a2034b4621c74925d89c5efef61418570e5ef9b37717d9c796419c"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:203f0c8871d5a7987be20c72442488a0b8cfd0f43b7973771640fc593f56321f"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e857a2232ba53ae940d3456f7533ce6ca98b81917d47adc3c7fd55dad8fab858"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e86d77b090dbddbe78867a0275cb4df08ea195e660f1f7f13435a4649e954e5"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4fb39a81950ec280984b3a44f5bd12819953dc5fa3a7e6fa7a80db5ee853952"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dee8e57f052ef5353cf608e0b4c871aee320dd1b87d351c28764fc0ca55f9f4"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8700f06d0ce6f128de3ccdbc1acaea1ee264d2caa9ca05daaf492fde7c2a7200"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1920d4ff15ce893210c1f0c0e9d19bfbecb7983c76b33f046c13a8ffbd570252"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c1c76a1743432b4b60ab3358c937a3fe1341c828ae6194108a94c69028247f22"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f7560358a6811e52e9c4d142d497f1a6e10103d3a6881f18d04dbce3729c0e2c"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:c8063cf17b19661471ecbdb3df1c84f24ad2e389e326ccaf89e3fb2484d8dd7e"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:cd6dbe0238f7743d0efe563ab46294f54f9bc8f4b9bcf57c3c666cc5bc9d1299"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1249cbbf3d3b04902ff081ffbb33ce3377fa6e4c7356f759f3cd076cc138d020"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-win32.whl", hash = "sha256:6c409c0deba34f147f77efaa67b8e4bb83d2f11c8806405f76397ae5b8c0d1c9"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:7095f6fbfaa55defb6b733cfeb14efaae7a29f0b59d8cf213be4e7ca0b857b80"}, - {file = "charset_normalizer-3.2.0-py3-none-any.whl", hash = "sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6"}, + {file = "charset-normalizer-2.1.1.tar.gz", hash = "sha256:5a3d016c7c547f69d6f81fb0db9449ce888b418b5b9952cc5e6e66843e9dd845"}, + {file = "charset_normalizer-2.1.1-py3-none-any.whl", hash = "sha256:83e9a75d1911279afd89352c68b45348559d1fc0506b054b346651b5e7fee29f"}, ] +[package.extras] +unicode-backport = ["unicodedata2"] + [[package]] name = "chex" version = "0.1.7" @@ -3456,39 +3386,33 @@ test = ["pytest", "pytest-console-scripts", "pytest-jupyter", "pytest-tornasync" [[package]] name = "numpy" -version = "1.23.5" +version = "1.23.0" description = "NumPy is the fundamental package for array computing with Python." optional = false python-versions = ">=3.8" files = [ - {file = "numpy-1.23.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9c88793f78fca17da0145455f0d7826bcb9f37da4764af27ac945488116efe63"}, - {file = "numpy-1.23.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e9f4c4e51567b616be64e05d517c79a8a22f3606499941d97bb76f2ca59f982d"}, - {file = "numpy-1.23.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7903ba8ab592b82014713c491f6c5d3a1cde5b4a3bf116404e08f5b52f6daf43"}, - {file = "numpy-1.23.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e05b1c973a9f858c74367553e236f287e749465f773328c8ef31abe18f691e1"}, - {file = "numpy-1.23.5-cp310-cp310-win32.whl", hash = "sha256:522e26bbf6377e4d76403826ed689c295b0b238f46c28a7251ab94716da0b280"}, - {file = "numpy-1.23.5-cp310-cp310-win_amd64.whl", hash = "sha256:dbee87b469018961d1ad79b1a5d50c0ae850000b639bcb1b694e9981083243b6"}, - {file = "numpy-1.23.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ce571367b6dfe60af04e04a1834ca2dc5f46004ac1cc756fb95319f64c095a96"}, - {file = "numpy-1.23.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:56e454c7833e94ec9769fa0f86e6ff8e42ee38ce0ce1fa4cbb747ea7e06d56aa"}, - {file = "numpy-1.23.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5039f55555e1eab31124a5768898c9e22c25a65c1e0037f4d7c495a45778c9f2"}, - {file = "numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58f545efd1108e647604a1b5aa809591ccd2540f468a880bedb97247e72db387"}, - {file = "numpy-1.23.5-cp311-cp311-win32.whl", hash = "sha256:b2a9ab7c279c91974f756c84c365a669a887efa287365a8e2c418f8b3ba73fb0"}, - {file = "numpy-1.23.5-cp311-cp311-win_amd64.whl", hash = "sha256:0cbe9848fad08baf71de1a39e12d1b6310f1d5b2d0ea4de051058e6e1076852d"}, - {file = "numpy-1.23.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f063b69b090c9d918f9df0a12116029e274daf0181df392839661c4c7ec9018a"}, - {file = "numpy-1.23.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0aaee12d8883552fadfc41e96b4c82ee7d794949e2a7c3b3a7201e968c7ecab9"}, - {file = "numpy-1.23.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92c8c1e89a1f5028a4c6d9e3ccbe311b6ba53694811269b992c0b224269e2398"}, - {file = "numpy-1.23.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d208a0f8729f3fb790ed18a003f3a57895b989b40ea4dce4717e9cf4af62c6bb"}, - {file = "numpy-1.23.5-cp38-cp38-win32.whl", hash = "sha256:06005a2ef6014e9956c09ba07654f9837d9e26696a0470e42beedadb78c11b07"}, - {file = "numpy-1.23.5-cp38-cp38-win_amd64.whl", hash = "sha256:ca51fcfcc5f9354c45f400059e88bc09215fb71a48d3768fb80e357f3b457e1e"}, - {file = "numpy-1.23.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8969bfd28e85c81f3f94eb4a66bc2cf1dbdc5c18efc320af34bffc54d6b1e38f"}, - {file = "numpy-1.23.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a7ac231a08bb37f852849bbb387a20a57574a97cfc7b6cabb488a4fc8be176de"}, - {file = "numpy-1.23.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf837dc63ba5c06dc8797c398db1e223a466c7ece27a1f7b5232ba3466aafe3d"}, - {file = "numpy-1.23.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33161613d2269025873025b33e879825ec7b1d831317e68f4f2f0f84ed14c719"}, - {file = "numpy-1.23.5-cp39-cp39-win32.whl", hash = "sha256:af1da88f6bc3d2338ebbf0e22fe487821ea4d8e89053e25fa59d1d79786e7481"}, - {file = "numpy-1.23.5-cp39-cp39-win_amd64.whl", hash = "sha256:09b7847f7e83ca37c6e627682f145856de331049013853f344f37b0c9690e3df"}, - {file = "numpy-1.23.5-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:abdde9f795cf292fb9651ed48185503a2ff29be87770c3b8e2a14b0cd7aa16f8"}, - {file = "numpy-1.23.5-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9a909a8bae284d46bbfdefbdd4a262ba19d3bc9921b1e76126b1d21c3c34135"}, - {file = "numpy-1.23.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:01dd17cbb340bf0fc23981e52e1d18a9d4050792e8fb8363cecbf066a84b827d"}, - {file = "numpy-1.23.5.tar.gz", hash = "sha256:1b1766d6f397c18153d40015ddfc79ddb715cabadc04d2d228d4e5a8bc4ded1a"}, + {file = "numpy-1.23.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:58bfd40eb478f54ff7a5710dd61c8097e169bc36cc68333d00a9bcd8def53b38"}, + {file = "numpy-1.23.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:196cd074c3f97c4121601790955f915187736f9cf458d3ee1f1b46aff2b1ade0"}, + {file = "numpy-1.23.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1d88ef79e0a7fa631bb2c3dda1ea46b32b1fe614e10fedd611d3d5398447f2f"}, + {file = "numpy-1.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d54b3b828d618a19779a84c3ad952e96e2c2311b16384e973e671aa5be1f6187"}, + {file = "numpy-1.23.0-cp310-cp310-win32.whl", hash = "sha256:2b2da66582f3a69c8ce25ed7921dcd8010d05e59ac8d89d126a299be60421171"}, + {file = "numpy-1.23.0-cp310-cp310-win_amd64.whl", hash = "sha256:97a76604d9b0e79f59baeca16593c711fddb44936e40310f78bfef79ee9a835f"}, + {file = "numpy-1.23.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d8cc87bed09de55477dba9da370c1679bd534df9baa171dd01accbb09687dac3"}, + {file = "numpy-1.23.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f0f18804df7370571fb65db9b98bf1378172bd4e962482b857e612d1fec0f53e"}, + {file = "numpy-1.23.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac86f407873b952679f5f9e6c0612687e51547af0e14ddea1eedfcb22466babd"}, + {file = "numpy-1.23.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ae8adff4172692ce56233db04b7ce5792186f179c415c37d539c25de7298d25d"}, + {file = "numpy-1.23.0-cp38-cp38-win32.whl", hash = "sha256:fe8b9683eb26d2c4d5db32cd29b38fdcf8381324ab48313b5b69088e0e355379"}, + {file = "numpy-1.23.0-cp38-cp38-win_amd64.whl", hash = "sha256:5043bcd71fcc458dfb8a0fc5509bbc979da0131b9d08e3d5f50fb0bbb36f169a"}, + {file = "numpy-1.23.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1c29b44905af288b3919803aceb6ec7fec77406d8b08aaa2e8b9e63d0fe2f160"}, + {file = "numpy-1.23.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:98e8e0d8d69ff4d3fa63e6c61e8cfe2d03c29b16b58dbef1f9baa175bbed7860"}, + {file = "numpy-1.23.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79a506cacf2be3a74ead5467aee97b81fca00c9c4c8b3ba16dbab488cd99ba10"}, + {file = "numpy-1.23.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:092f5e6025813e64ad6d1b52b519165d08c730d099c114a9247c9bb635a2a450"}, + {file = "numpy-1.23.0-cp39-cp39-win32.whl", hash = "sha256:d6ca8dabe696c2785d0c8c9b0d8a9b6e5fdbe4f922bde70d57fa1a2848134f95"}, + {file = "numpy-1.23.0-cp39-cp39-win_amd64.whl", hash = "sha256:fc431493df245f3c627c0c05c2bd134535e7929dbe2e602b80e42bf52ff760bc"}, + {file = "numpy-1.23.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f9c3fc2adf67762c9fe1849c859942d23f8d3e0bee7b5ed3d4a9c3eeb50a2f07"}, + {file = "numpy-1.23.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0d2094e8f4d760500394d77b383a1b06d3663e8892cdf5df3c592f55f3bff66"}, + {file = "numpy-1.23.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:94b170b4fa0168cd6be4becf37cb5b127bd12a795123984385b8cd4aca9857e5"}, + {file = "numpy-1.23.0.tar.gz", hash = "sha256:bd3fa4fe2e38533d5336e1272fc4e765cabbbde144309ccee8675509d5cd7b05"}, ] [[package]] @@ -3588,24 +3512,24 @@ tensorflow-addons = "*" [[package]] name = "opencv-python-headless" -version = "4.8.0.76" +version = "4.5.5.64" description = "Wrapper package for OpenCV python bindings." optional = false python-versions = ">=3.6" files = [ - {file = "opencv-python-headless-4.8.0.76.tar.gz", hash = "sha256:bc15726187dae26d8a08777faf6bc71d38f20c785c102677f58ba0e935003afb"}, - {file = "opencv_python_headless-4.8.0.76-cp37-abi3-macosx_10_16_x86_64.whl", hash = "sha256:f85d2e3b9d952db35d31f9db8882d073c903921b72b8db1cfed8bbc75e8d3e63"}, - {file = "opencv_python_headless-4.8.0.76-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:8ee3bf1c9086493c340c6a87899f1c7778d729de92bce8560b8c31ab8a9cdf79"}, - {file = "opencv_python_headless-4.8.0.76-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c675b8dec6298ba6a1eec2ce24077a393b4236a043f68dfacb06bf594354ce06"}, - {file = "opencv_python_headless-4.8.0.76-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:220d2e292fa45ef0582aab730460bbc15cfe61f2089208167a372ccf76f01e21"}, - {file = "opencv_python_headless-4.8.0.76-cp37-abi3-win32.whl", hash = "sha256:df0608de207ae9b094ad9eaf1a475cf6e9a069fb12cd289d4a18cefdab2f8aa8"}, - {file = "opencv_python_headless-4.8.0.76-cp37-abi3-win_amd64.whl", hash = "sha256:9c094faf6ec7bd360244647b26ebdf8f54edec1d9292cb9179fff9badcca7be8"}, + {file = "opencv-python-headless-4.5.5.64.tar.gz", hash = "sha256:c3c2dda44d601757a508b07d628537c49f31223ad4edd0f747a70d4c852a7b98"}, + {file = "opencv_python_headless-4.5.5.64-cp36-abi3-macosx_10_15_x86_64.whl", hash = "sha256:62e31878641a8f96e773118d1eea9f34bdda87c9990a0faab04ebaafb5ae015c"}, + {file = "opencv_python_headless-4.5.5.64-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:567a54c1919bcf5b3d20a9830e3c511e57134de8def286ce137c3544a892f98c"}, + {file = "opencv_python_headless-4.5.5.64-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f330468c29882dbbec5af25695c5e575572c6b855cb0f9fe53e14116fd46bfc"}, + {file = "opencv_python_headless-4.5.5.64-cp36-abi3-win32.whl", hash = "sha256:4bdf982574bf2fefc5f82c86df7cb42e56ad627874c7c0f4d94ecf4ae8885304"}, + {file = "opencv_python_headless-4.5.5.64-cp36-abi3-win_amd64.whl", hash = "sha256:a60e9ff48854ec37be391e19dd634883cc26c2f0f814e5325b3deca33420912c"}, + {file = "opencv_python_headless-4.5.5.64-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:ca4f013fa958f60fb2327fe87e6127c1ac0ab536890b1d4b00847f417b7af1ba"}, ] [package.dependencies] numpy = [ - {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""}, - {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" or python_version > \"3.9\" or python_version >= \"3.9\" and platform_system != \"Darwin\" or python_version >= \"3.9\" and platform_machine != \"arm64\""}, + {version = ">=1.21.2", markers = "python_version >= \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""}, + {version = ">=1.19.3", markers = "python_version >= \"3.8\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\" and platform_system != \"Darwin\" or python_version >= \"3.9\" and platform_machine != \"arm64\""}, ] [[package]] @@ -4654,14 +4578,20 @@ aiosignal = "*" click = ">=7.0" filelock = "*" frozenlist = "*" -fsspec = {version = "*", optional = true, markers = "extra == \"tune\""} +fsspec = {version = "*", optional = true, markers = "extra == \"data\" or extra == \"tune\""} jsonschema = "*" msgpack = ">=1.0.0,<2.0.0" -numpy = {version = ">=1.19.3", markers = "python_version >= \"3.9\""} +numpy = [ + {version = ">=1.19.3", optional = true, markers = "extra != \"data\" and python_version >= \"3.9\""}, + {version = ">=1.20", optional = true, markers = "python_version >= \"3.9\" and extra == \"data\""}, +] packaging = "*" -pandas = {version = "*", optional = true, markers = "extra == \"tune\""} +pandas = [ + {version = ">=1.3", optional = true, markers = "extra == \"data\""}, + {version = "*", optional = true, markers = "extra == \"tune\""}, +] protobuf = ">=3.15.3,<3.19.5 || >3.19.5" -pyarrow = {version = ">=6.0.1", optional = true, markers = "extra == \"tune\""} +pyarrow = {version = ">=6.0.1", optional = true, markers = "extra == \"data\" or extra == \"tune\""} pyyaml = "*" requests = "*" tensorboardX = {version = ">=1.9", optional = true, markers = "extra == \"tune\""} @@ -4697,20 +4627,20 @@ rpds-py = ">=0.7.0" [[package]] name = "requests" -version = "2.31.0" +version = "2.28.1" description = "Python HTTP for Humans." optional = false -python-versions = ">=3.7" +python-versions = ">=3.7, <4" files = [ - {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, - {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, + {file = "requests-2.28.1-py3-none-any.whl", hash = "sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349"}, + {file = "requests-2.28.1.tar.gz", hash = "sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983"}, ] [package.dependencies] certifi = ">=2017.4.17" -charset-normalizer = ">=2,<4" +charset-normalizer = ">=2,<3" idna = ">=2.5,<4" -urllib3 = ">=1.21.1,<3" +urllib3 = ">=1.21.1,<1.27" [package.extras] socks = ["PySocks (>=1.5.6,!=1.5.7)"] @@ -5185,6 +5115,27 @@ shortuuid = ">=0.5.0" dev = ["mock (==5.0.1)", "mypy (==0.971)", "paramiko (==3.2.0)", "pylint (==2.15.0)", "pytest (==7.2.0)", "pytest-asyncio (==0.18.3)", "pytest-cov (==3.0.0)", "pytest-docker (==0.12.0)", "pytest-mock (==3.8.2)", "pytest-sugar (==0.9.5)", "pytest-test-utils (==0.0.8)", "types-certifi (==2021.10.8.3)", "types-mock (==5.0.0.6)", "types-paramiko (==3.2.0.1)"] tests = ["mock (==5.0.1)", "mypy (==0.971)", "paramiko (==3.2.0)", "pylint (==2.15.0)", "pytest (==7.2.0)", "pytest-asyncio (==0.18.3)", "pytest-cov (==3.0.0)", "pytest-docker (==0.12.0)", "pytest-mock (==3.8.2)", "pytest-sugar (==0.9.5)", "pytest-test-utils (==0.0.8)", "types-certifi (==2021.10.8.3)", "types-mock (==5.0.0.6)", "types-paramiko (==3.2.0.1)"] +[[package]] +name = "seaborn" +version = "0.13.0" +description = "Statistical data visualization" +optional = false +python-versions = ">=3.8" +files = [ + {file = "seaborn-0.13.0-py3-none-any.whl", hash = "sha256:70d740828c48de0f402bb17234e475eda687e3c65f4383ea25d0cc4728f7772e"}, + {file = "seaborn-0.13.0.tar.gz", hash = "sha256:0e76abd2ec291c655b516703c6a022f0fd5afed26c8e714e8baef48150f73598"}, +] + +[package.dependencies] +matplotlib = ">=3.3,<3.6.1 || >3.6.1" +numpy = ">=1.20,<1.24.0 || >1.24.0" +pandas = ">=1.2" + +[package.extras] +dev = ["flake8", "flit", "mypy", "pandas-stubs", "pre-commit", "pytest", "pytest-cov", "pytest-xdist"] +docs = ["ipykernel", "nbconvert", "numpydoc", "pydata_sphinx_theme (==0.10.0rc2)", "pyyaml", "sphinx (<6.0.0)", "sphinx-copybutton", "sphinx-design", "sphinx-issues"] +stats = ["scipy (>=1.7)", "statsmodels (>=0.12)"] + [[package]] name = "send2trash" version = "1.8.2" @@ -6479,4 +6430,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "~3.9" -content-hash = "2e019ed7849699b77691c24af0e9fb7d9826221a0f90d520c0becc09427841cf" +content-hash = "82bc257ee9b5922a049609c820816e90119641a1adea84eae4fa765b73528aa8" diff --git a/classification_model_training/files_from_makeathon/predict.py b/classification_model_training/files_from_makeathon/predict.py index a590d2d..9ad8cdf 100644 --- a/classification_model_training/files_from_makeathon/predict.py +++ b/classification_model_training/files_from_makeathon/predict.py @@ -137,6 +137,10 @@ def predict(model, device, test_loader, ensemble): default=None, type=str, ) + parser.add_argument( + "--model", default="efficientnet-b2", type=str, help="model name" + ) + parser.add_argument("--image_size", default=224, type=int, help="image size") opt = parser.parse_args() @@ -144,7 +148,7 @@ def predict(model, device, test_loader, ensemble): prediction_aug = transforms.Compose( [ - transforms.Resize((224, 224)), + transforms.Resize((opt.image_size, opt.image_size)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ] @@ -162,9 +166,7 @@ def predict(model, device, test_loader, ensemble): if opt.single_model_path is not None: - model = EfficientNet.from_name( - "efficientnet-b2", in_channels=3, num_classes=1 - ) + model = EfficientNet.from_name(opt.model, in_channels=3, num_classes=1) model.load_state_dict(torch.load(opt.single_model_path)) model.to(opt.device) @@ -179,9 +181,7 @@ def predict(model, device, test_loader, ensemble): if opt.single_model_path is not None: - model = EfficientNet.from_name( - "efficientnet-b2", in_channels=3, num_classes=4 - ) + model = EfficientNet.from_name(opt.model, in_channels=3, num_classes=4) model.load_state_dict(torch.load(opt.single_model_path)) model.to(opt.device) diff --git a/classification_model_training/files_from_makeathon/pyproject.toml b/classification_model_training/files_from_makeathon/pyproject.toml index a7b20d7..48986d2 100644 --- a/classification_model_training/files_from_makeathon/pyproject.toml +++ b/classification_model_training/files_from_makeathon/pyproject.toml @@ -27,8 +27,9 @@ wandb = "<0.15.8" captum = "^0.6.0" tensorflow = ">=2.13,<2.14" protobuf = ">=3.20.3,<4.21.0" -ray = {extras = ["tune"], version = "^2.7.0"} +ray = {extras = ["data", "tune"], version = "^2.7.0"} torchmetrics = "^1.2.0" +seaborn = "^0.13.0" [tool.poetry.group.dev.dependencies] pytest = "^7.4.1" diff --git a/classification_model_training/files_from_makeathon/run_ray_tune.py b/classification_model_training/files_from_makeathon/run_ray_tune.py index 98c4808..94dd05f 100644 --- a/classification_model_training/files_from_makeathon/run_ray_tune.py +++ b/classification_model_training/files_from_makeathon/run_ray_tune.py @@ -46,6 +46,7 @@ def setup(self, config): train_bs = config.get("train_bs", 16) val_bs = config.get("val_bs", 16) + image_size = config.get("image_size", 224) gt = config.get("args").gt training_data_path = config.get("args").dataset @@ -54,11 +55,11 @@ def setup(self, config): self.num_classes = config.get("args").num_classes self.train_loader, self.test_loader = get_data_loaders( - gt, fold, training_data_path, train_bs, val_bs + gt, fold, training_data_path, train_bs, val_bs, image_size=image_size ) self.model = EfficientNet.from_pretrained( - "efficientnet-b2", + config.get("model", "efficientnet-b2"), in_channels=3, num_classes=self.num_classes, dropout_rate=config.get("dropout_rate", 0.3), @@ -178,12 +179,13 @@ def load_checkpoint(self, checkpoint_dir): metric="val_f1", mode="max", scheduler=sched, - num_samples=1 if args.smoke_test else 10, + num_samples=1 if args.smoke_test else 1, ), param_space={ "args": args, # "lr": tune.loguniform(1e-4, 1e-3), - "seed": tune.randint(0, 42), + "seed": 1, + # "seed": tune.randint(0, 42), # "lr": tune.quniform(1e-4, 1e-3, 1e-4), # "weight_decay": tune.uniform(0.0, 1e-4), # "dropout_rate": tune.quniform(0.10, 0.4, 0.05), @@ -191,6 +193,10 @@ def load_checkpoint(self, checkpoint_dir): # "batch_norm_momentum": tune.choice([0.9, 0.997, 0.99]), # "batch_norm_epsilon": tune.choice([1e-3, 1e-5, 1e-6]) # "momentum": tune.uniform(0.1, 0.9), + # "model": tune.choice( + # ["efficientnet-b0", "efficientnet-b1", "efficientnet-b2"] + # ), + # "image_size": tune.choice([224, 240, 260]), }, ) results = tuner.fit() diff --git a/classification_model_training/files_from_makeathon/train.py b/classification_model_training/files_from_makeathon/train.py index 625465e..b87b4ff 100644 --- a/classification_model_training/files_from_makeathon/train.py +++ b/classification_model_training/files_from_makeathon/train.py @@ -2,12 +2,14 @@ import torch from torch import nn + import argparse import pandas as pd import numpy as np import json from datetime import datetime import random +import shutil from torchvision import transforms from efficientnet_pytorch import EfficientNet @@ -17,7 +19,12 @@ def get_data_loaders( - gt: str, fold: int, training_data_path: str, train_bs: int, val_bs: int + gt: str, + fold: int, + training_data_path: str, + train_bs: int, + val_bs: int, + image_size: int = 224, ): df = pd.read_csv(gt) @@ -27,7 +34,7 @@ def get_data_loaders( # Set up the train_loader and val_loader train_aug = transforms.Compose( [ - transforms.Resize((224, 224)), + transforms.Resize((image_size, image_size)), transforms.RandomHorizontalFlip(0.5), transforms.RandomRotation((0, 360)), transforms.ToTensor(), @@ -37,7 +44,7 @@ def get_data_loaders( val_aug = transforms.Compose( [ - transforms.Resize((224, 224)), + transforms.Resize((image_size, image_size)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ] @@ -70,6 +77,8 @@ def get_data_loaders( def train( + model_name, + image_size, fold, training_data_path, gt, @@ -92,13 +101,13 @@ def train( ): train_loader, val_loader = get_data_loaders( - gt, fold, training_data_path, train_bs, val_bs + gt, fold, training_data_path, train_bs, val_bs, image_size ) if pretrained_on_ImageNet: print("Using on ImageNet pretrained model") model = EfficientNet.from_pretrained( - "efficientnet-b2", + model_name, in_channels=3, num_classes=num_classes, dropout_rate=dropout_rate, @@ -111,7 +120,7 @@ def train( else: print("Using NOT pretrained model") model = EfficientNet.from_name( - "efficientnet-b2", + model_name, in_channels=3, num_classes=num_classes, dropout_rate=dropout_rate, @@ -160,31 +169,43 @@ def train( ) ( - test_running_loss, - test_num_correct, - test_num_total, - test_running_steps, - test_f1, + val_running_loss, + val_num_correct, + val_num_total, + val_running_steps, + val_f1, ) = val_loop(model, num_classes, device, val_loader, loss_function) - print(f"Epoch = {epoch+1}, train_f1 = {train_f1}, val_f1 = {test_f1}") + train_loss = train_running_loss / train_running_steps + # train_acc = train_num_correct / train_num_total + val_loss = val_running_loss / val_running_steps + val_acc = val_num_correct / val_num_total + + print( + f"Epoch = {epoch+1}, train_loss = {train_loss}, train_f1 = {train_f1}, val_loss = {val_loss}, val_f1 = {val_f1}" + ) + + if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + scheduler.step(val_f1) + else: + scheduler.step() - scheduler.step(test_f1) + is_best = val_f1 > best_score + best_score = max(val_f1, best_score) - if test_f1 > best_score and epoch >= 3: - best_score = test_f1 - torch.save( - model.state_dict(), - os.path.join(outdir, f"model_fold_{fold}_{epoch}.pt"), + torch.save(model.state_dict(), os.path.join(outdir, "model.pt")) + if is_best: + shutil.copyfile( + os.path.join(outdir, "model.pt"), os.path.join(outdir, "model_best.pt") ) - torch.save(model.state_dict(), os.path.join(outdir, "model_best.pt")) + if opt.metrics_file_path is not None: json.dump( obj={ - "f1_score": test_f1, - "accuracy": test_num_correct / test_num_total, - "train_loss": train_running_loss / train_running_steps, - "val_loss": test_running_loss / test_running_steps, + "f1_score": val_f1, + "accuracy": val_acc, + "train_loss": train_loss, + "val_loss": val_loss, "epoch": epoch + 1, }, fp=open(opt.metrics_file_path, "w"), @@ -273,6 +294,10 @@ def enable_determinism(): default=None, type=str, ) + parser.add_argument( + "--model", default="efficientnet-b2", type=str, help="model name" + ) + parser.add_argument("--image_size", default=224, type=int, help="image size") opt = parser.parse_args() @@ -307,6 +332,8 @@ def enable_determinism(): os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" train( + opt.model, + opt.image_size, opt.fold, opt.dataset, opt.gt,