diff --git a/passl/data/preprocess/basic_transforms.py b/passl/data/preprocess/basic_transforms.py index 7be2b26a..374b05b3 100644 --- a/passl/data/preprocess/basic_transforms.py +++ b/passl/data/preprocess/basic_transforms.py @@ -57,6 +57,7 @@ "SimCLRGaussianBlur", "BYOLSolarize", "MAERandCropImage", + "GaussianBlur", ] @@ -941,3 +942,21 @@ def __call__(self, img): else: img = ImageOps.solarize(img) return img + +class GaussianBlur(object): + """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" + + def __init__(self, sigma=[.1, 2.], p=1.0): + self.p = p + self.sigma = sigma + + def __call__(self, img): + if random.random() < self.p: + if not isinstance(img, Image.Image): + img = np.ascontiguousarray(img) + img = Image.fromarray(img) + sigma = random.uniform(self.sigma[0], self.sigma[1]) + img = img.filter(ImageFilter.GaussianBlur(radius=sigma)) + if isinstance(img, Image.Image): + img = np.asarray(img) + return img \ No newline at end of file diff --git a/passl/engine/loops/contrastive_learning_loop.py b/passl/engine/loops/contrastive_learning_loop.py index 428d4853..80c1c395 100644 --- a/passl/engine/loops/contrastive_learning_loop.py +++ b/passl/engine/loops/contrastive_learning_loop.py @@ -68,6 +68,9 @@ def train_one_step(self, batch): # remove label batch = batch[0] + for idx, value in enumerate(batch): + if isinstance(value,paddle.Tensor): + batch[idx] = batch[idx].cuda() # do forward and backward loss_dict = self.forward_backward(batch) diff --git a/passl/engine/loops/loop.py b/passl/engine/loops/loop.py index 45745484..fedc7531 100644 --- a/passl/engine/loops/loop.py +++ b/passl/engine/loops/loop.py @@ -286,6 +286,7 @@ def train_one_epoch(self): paddle.to_tensor(batch[0]['label']) ] + self.global_step += 1 # do forward and backward diff --git a/passl/models/__init__.py b/passl/models/__init__.py index 73cc75e0..1b988d41 100644 --- a/passl/models/__init__.py +++ b/passl/models/__init__.py @@ -29,7 +29,7 @@ from .mocov3 import * from .swav import * from .simsiam import * - +from .mocov2 import * __all__ = ["build_model"] diff --git a/passl/models/mocov2.py b/passl/models/mocov2.py new file mode 100644 index 00000000..af24d096 --- /dev/null +++ b/passl/models/mocov2.py @@ -0,0 +1,353 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable + +import os +import copy +import numpy as np + +import paddle +import paddle.nn as nn +from passl.nn import init +import paddle.nn.functional as F +from passl.models.base_model import Model +from paddle.nn.initializer import Constant, Normal +from functools import partial, reduce +from passl.models.resnet import ResNet +from paddle.vision.models.resnet import resnet50 +import random +__all__ = [ + 'mocov2_resnet50_linearprobe', + 'mocov2_resnet50_pretrain', +] + +class MoCoV2Projector(nn.Layer): + def __init__(self, with_pool, in_dim, out_dim): + super().__init__() + + self.with_pool = with_pool + if with_pool: + self.avgpool = nn.Sequential( + nn.AdaptiveAvgPool2D((1, 1)), nn.Flatten(start_axis=1)) + + self.mlp = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU()) + + def forward(self, x): + + if self.with_pool: + x = self.avgpool(x) + + x = self.mlp(x) + return x + + +class MoCoClassifier(nn.Layer): + def __init__(self, with_pool, num_features, class_num): + super().__init__() + + self.with_pool = with_pool + if with_pool: + self.avgpool = nn.Sequential( + nn.AdaptiveAvgPool2D((1, 1)), nn.Flatten(start_axis=1)) + + self.fc = nn.Linear(num_features, class_num) + normal_ = Normal(std=0.01) + zeros_ = Constant(value=0.) + + normal_(self.fc.weight) + zeros_(self.fc.bias) + + def save(self,path): + paddle.save(self.fc.state_dict(),path + ".pdparams") + def load(self,path): + self.fc.set_state_dict(paddle.load(path+".pdparams")) + + + def forward(self, x): + + if self.with_pool: + x = self.avgpool(x) + x = self.fc(x) + return x + + +class MoCoV2Pretain(Model): + """ MoCo v1, v2 + + ref: https://github.com/facebookresearch/moco/blob/main/moco/builder.py + ref: https://github.com/PaddlePaddle/PASSL/blob/main/passl/modeling/architectures/moco.py + """ + + def __init__(self, + base_encoder, + base_projector, + base_classifier, + momentum_encoder, + momentum_projector, + momentum_classifier, + dim=128, + K=65536, + m=0.999, + T=0.07, + **kwargs): + super(MoCoV2Pretain, self).__init__() + + self.m = m + self.T = T + self.K = K + + self.base_encoder = nn.Sequential(base_encoder(), base_projector(), + base_classifier()) + self.momentum_encoder = nn.Sequential( + momentum_encoder(), momentum_projector(), momentum_classifier()) + + for param_b, param_m in zip(self.base_encoder.parameters(), + self.momentum_encoder.parameters()): + param_m.copy_(param_b, False) # initialize + param_m.stop_gradient = True # not update by gradient + + # create the queue + self.register_buffer("queue", paddle.randn([dim, K])) + self.queue = F.normalize(self.queue, axis=0) + + self.register_buffer("queue_ptr", paddle.zeros([1], 'int64')) + + self.loss_fuc = nn.CrossEntropyLoss() + + def save(self, path, local_rank=0, rank=0): + paddle.save(self.state_dict(), path + ".pdparams") + + # rename moco pre-trained keys + state_dict = self.state_dict() + for k in list(state_dict.keys()): + # retain only base_encoder up to before the embedding layer + if k.startswith('base_encoder') and not k.startswith( + 'base_encoder.head'): + # remove prefix + state_dict[k[len("base_encoder."):]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + + paddle.save(state_dict, path + "_base_encoder.pdparams") + + @paddle.no_grad() + def _update_momentum_encoder(self): + """Momentum update of the momentum encoder""" + #Note(GuoxiaWang): disable auto cast when use mix_precision + with paddle.amp.auto_cast(False): + for param_b, param_m in zip(self.base_encoder.parameters(), + self.momentum_encoder.parameters()): + paddle.assign((param_m * self.m + param_b * (1. - self.m)), + param_m) + param_m.stop_gradient = True + + # utils + @paddle.no_grad() + def concat_all_gather(self, tensor): + """ + Performs all_gather operation on the provided tensors. + """ + if paddle.distributed.get_world_size() < 2: + return tensor + tensors_gather = [] + paddle.distributed.all_gather(tensors_gather, tensor) + + output = paddle.concat(tensors_gather, axis=0) + return output + + @paddle.no_grad() + def _dequeue_and_enqueue(self, keys): + keys = self.concat_all_gather(keys) + + batch_size = keys.shape[0] + + ptr = int(self.queue_ptr[0]) + assert self.K % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.queue[:, ptr:ptr + batch_size] = keys.transpose([1, 0]) + ptr = (ptr + batch_size) % self.K # move pointer + + self.queue_ptr[0] = ptr + + @paddle.no_grad() + def _batch_shuffle_ddp(self, x): + """ + Batch shuffle, for making use of BatchNorm. + *** Only support DistributedDataParallel (DDP) model. *** + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = self.concat_all_gather(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # random shuffle index + idx_shuffle = paddle.randperm(batch_size_all) + + # broadcast to all gpus + if paddle.distributed.get_world_size() > 1: + paddle.distributed.broadcast(idx_shuffle, src=0) + + # index for restoring + idx_unshuffle = paddle.argsort(idx_shuffle) + + # shuffled index for this gpu + gpu_idx = paddle.distributed.get_rank() + idx_this = idx_shuffle.reshape([num_gpus, -1])[gpu_idx] + return paddle.gather(x_gather, idx_this, axis=0), idx_unshuffle + + @paddle.no_grad() + def _batch_unshuffle_ddp(self, x, idx_unshuffle): + """ + Undo batch shuffle. + *** Only support DistributedDataParallel (DDP) model. *** + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = self.concat_all_gather(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # restored index for this gpu + gpu_idx = paddle.distributed.get_rank() + idx_this = idx_unshuffle.reshape([num_gpus, -1])[gpu_idx] + + return paddle.gather(x_gather, idx_this, axis=0) + + def forward(self, inputs): + assert isinstance(inputs, list) + x1 = inputs[0] + x2 = inputs[1] + # compute query features + q = self.base_encoder(x1) # queries: NxC + q = F.normalize(q, axis=1) + + # compute key features + with paddle.no_grad(): # no gradient + self._update_momentum_encoder() # update the momentum encoder + + # shuffle for making use of BN + k, idx_unshuffle = self._batch_shuffle_ddp(x2) + + k = self.momentum_encoder(k) # keys: NxC + k = F.normalize(k, axis=1) + + # undo shuffle + k = self._batch_unshuffle_ddp(k, idx_unshuffle) + + # compute logits + # Einstein sum is more intuitive + # positive logits: Nx1 + l_pos = paddle.sum(q * k, axis=1).unsqueeze(-1) + # negative logits: NxK + l_neg = paddle.matmul(q, self.queue.clone().detach()) + + # logits: Nx(1+K) + logits = paddle.concat((l_pos, l_neg), axis=1) + + # apply temperature + logits /= self.T + + # labels: positive key indicators + labels = paddle.zeros([logits.shape[0]], dtype=paddle.int64) + + # dequeue and enqueue + self._dequeue_and_enqueue(k) + + return self.loss_fuc(logits, labels) + +class MoCoV2LinearProbe(ResNet): + """ MoCo v1, v2 + + ref: https://github.com/facebookresearch/moco/blob/main/moco/builder.py + ref: https://github.com/PaddlePaddle/PASSL/blob/main/passl/modeling/architectures/moco.py + """ + + def __init__(self, + **kwargs): + super().__init__() + # freeze all layers but the last fc + for name, param in self.named_parameters(): + if name not in ['fc.weight', 'fc.bias']: + param.stop_gradient = True + + # optimize only the linear classifier + parameters = list( + filter(lambda p: not p.stop_gradient, self.parameters())) + assert len(parameters) == 2 # weight, bias + + init.normal_(self.fc.weight, mean=0.0, std=0.01) + init.zeros_(self.fc.bias) + self.apply(self._freeze_norm) + + def _freeze_norm(self, layer): + if isinstance(layer, (nn.layer.norm._BatchNormBase)): + layer._use_global_stats = True + + def load_pretrained(self, path, rank=0, finetune=False): + if not os.path.exists(path + '.pdparams'): + raise ValueError("Model pretrain path {} does not " + "exists.".format(path)) + + path = path + ".pdparams" + base_encoder_dict = paddle.load(path) + for k in list(base_encoder_dict.keys()): + # retain only encoder_q up to before the embedding layer + if k.startswith('0.'): + # remove prefix + base_encoder_dict[k[len( + "0."):]] = base_encoder_dict[k] + # delete renamed + del base_encoder_dict[k] + + for name, param in self.state_dict().items(): + if name in base_encoder_dict and param.dtype != base_encoder_dict[ + name].dtype: + base_encoder_dict[name] = base_encoder_dict[name].cast( + param.dtype) + + self.set_state_dict(base_encoder_dict) + +def mocov2_resnet50_linearprobe(**kwargs): + # **kwargs specify numclass + resnet = MoCoV2LinearProbe(with_pool=True,**kwargs) + return resnet +def mocov2_resnet50_pretrain(**kwargs): + # prepare all layer here + base_encoder = partial(resnet50, with_pool=False,num_classes=0) + base_projector = partial(MoCoV2Projector, with_pool=True, in_dim=2048,out_dim=2048) + base_classifier = partial(MoCoClassifier, with_pool=False, num_features=2048, class_num=128) + momentum_encoder = partial(resnet50, with_pool=False, num_classes=0) + momentum_projector = partial(MoCoV2Projector,with_pool=True,in_dim=2048,out_dim=2048) + momentum_classifier = partial(MoCoClassifier,with_pool=False,num_features=2048,class_num=128) + model = MoCoV2Pretain( + base_encoder=base_encoder, + base_projector=base_projector, + base_classifier=base_classifier, + momentum_encoder=momentum_encoder, + momentum_projector=momentum_projector, + momentum_classifier=momentum_classifier, + T=0.2, + **kwargs) + return model + +if __name__ == "__main__": + model = mocov2_resnet50_pretrain() + model.save("./mocov2") + model_lineprobe = mocov2_resnet50_linearprobe() + model_lineprobe.load_pretrained("./mocov2_base_encoder") diff --git a/passl/models/resnet.py b/passl/models/resnet.py index f15f3443..211c2878 100644 --- a/passl/models/resnet.py +++ b/passl/models/resnet.py @@ -52,14 +52,19 @@ class ResNet(PDResNet, Model): def __init__( self, - block, + block=None, depth=50, width=64, class_num=1000, with_pool=True, groups=1, zero_init_residual=True, - ): + ): + if block == None: + if depth <= 34: + block=BasicBlock + else: + block=BottleneckBlock super().__init__(block, depth=depth, width=width, num_classes=class_num, with_pool=with_pool, groups=groups) # Zero-initialize the last BN in each residual branch, diff --git a/passl/scheduler/__init__.py b/passl/scheduler/__init__.py index 002ee755..8a194fd9 100644 --- a/passl/scheduler/__init__.py +++ b/passl/scheduler/__init__.py @@ -15,7 +15,7 @@ from passl.utils import logger -from .lr_scheduler import TimmCosine, ViTLRScheduler, Step, Poly, MultiStepDecay +from .lr_scheduler import TimmCosine, ViTLRScheduler, Step, Poly, MultiStepDecay, CosineDecay from .lr_callable import LRCallable diff --git a/passl/scheduler/lr_scheduler.py b/passl/scheduler/lr_scheduler.py index 2b91405a..46370d96 100644 --- a/passl/scheduler/lr_scheduler.py +++ b/passl/scheduler/lr_scheduler.py @@ -200,7 +200,58 @@ def get_lr(self): return self.base_lr * pow(1 - float(self.last_epoch - self.warmups) / float(self.T_max - self.warmups), 2) +class MultiStepDecay(lr.LRScheduler): + def __init__(self, + learning_rate, + step_each_epoch, + epochs, + milestones, + gamma=0.1, + last_epoch=-1, + verbose=False, + decay_unit='epoch', + **kwargs): + self.milestones = milestones + assert decay_unit in ['step', 'epoch'] + if decay_unit=='step': + milestones = [mile*step_each_epoch for mile in milestones] + self.gamma = gamma + super().__init__(learning_rate, last_epoch, verbose) + def get_lr(self): + for i in range(len(self.milestones)): + if self.last_epoch < self.milestones[i]: + return self.base_lr * (self.gamma**i) + return self.base_lr * (self.gamma ** len(self.milestones)) -class MultiStepDecay(lr.MultiStepDecay): - def __init__(self, learning_rate, milestones, gamma, last_epoch, **kwargs): - super().__init__(learning_rate, milestones, gamma, last_epoch) +class CosineDecay(lr.LRScheduler): + def __init__(self, + learning_rate, + step_each_epoch, + epochs, + decay_unit='epoch', + warmups=0, + verbose=False, + last_epoch=-1, + **kwargs): + + assert decay_unit in ['step', 'epoch'] + self.T_max = epochs if decay_unit == 'epoch' else step_each_epoch * epochs + self.warmups = warmups if decay_unit == 'epoch' else step_each_epoch * warmups + + assert self.warmups < self.T_max + + self.last_epoch = last_epoch + super(CosineDecay, self).__init__(learning_rate, last_epoch, verbose) + + def get_lr(self): + + progress = ( + self.last_epoch - self.warmups) / float(self.T_max - self.warmups) + progress = min(1.0, max(0.0, progress)) + + if self.warmups: + lr = lr * min(1.0, self.last_epoch / self.warmups) + else: + lr = 0.5 * self.base_lr * (1.0 + math.cos(math.pi * progress)) + + return lr diff --git a/tasks/ssl/mocov2/README.md b/tasks/ssl/mocov2/README.md new file mode 100644 index 00000000..a104ca53 --- /dev/null +++ b/tasks/ssl/mocov2/README.md @@ -0,0 +1,106 @@ +# MoCov2 +![MoCo](https://user-images.githubusercontent.com/11435359/71603927-0ca98d00-2b14-11ea-9fd8-10d984a2de45.png) + +This is a PaddlePaddle implementation of the +[MoCov2](https://arxiv.org/abs/2003.04297). + + +## Install Preparation + +MoCoV2 requires `PaddlePaddle >= 2.4`. +```shell +git clone https://github.com/PaddlePaddle/PASSL.git +cd /path/to/PASSL +python setup.py install +``` + +All commands are executed in the `tasks/ssl/mocov2/` directory. + + +## Data Preparation + +The imagenet 1k dataset needs to be prepared first and will be organized into the following directory structure. + +```shell +ILSVRC2012 +├── train/ +└── val/ +``` + +Then configure the path. + +```shell +mkdir -p dataset +ln -s /path/to/ILSVRC2012 dataset/ILSVRC2012 +``` + +## Unsupervised Training + +To do unsupervised pre-training of a ResNet-50 model on ImageNet in an 8-gpu machine, you can run the script: + +### MoCo V2 (Single Node with 8 GPUs) +```shell +python -m paddle.distributed.launch \ + --nnodes=$PADDLE_NNODES \ + --master=$PADDLE_MASTER \ + --devices=$CUDA_VISIBLE_DEVICES \ + passl-train \ + -c ./configs/mocov2_resnet50_pt_in1k_1n8c.yaml +``` + +## Linear Classification + +When the unsupervised pre-training is complete, or directly download the provided pre-training checkpoint, you can use the following script to train a supervised linear classifier. +### MoCo v2 + +#### Linear Classification Training (Single Node with 8 GPUs) + +```shell +python -m paddle.distributed.launch \ + --nnodes=$PADDLE_NNODES \ + --master=$PADDLE_MASTER \ + --devices=$CUDA_VISIBLE_DEVICES \ + passl-train \ + -c ./configs/mocov2_resnet50_lp_in1k_1n8c.yaml +``` + + +#### [Optional] Download checkpoint & Modify yaml configure +```shell +mkdir -p pretrained/moco/ +wget -O ./pretrained/moco/mocov2_pt_imagenet2012_resnet50.pdparams https://paddlefleetx.bj.bcebos.com/model/vision/moco/mocov2_pt_imagenet2012_resnet50.pdparams +``` + +#### Linear Classification Training (Single Node with 8 GPUs) + +```shell +python -m paddle.distributed.launch \ + --nnodes=$PADDLE_NNODES \ + --master=$PADDLE_MASTER \ + --devices=$CUDA_VISIBLE_DEVICES \ + passl-train \ + -c ./configs/mocov2_resnet50_lp_in1k_1n8c.yaml + -o Global.pretrained_model=./pretrained/mocov3/mocov3_vit_base_in1k_300ep_pretrained + +``` +## Other Configurations +We provide more directly runnable configurations, see [MoCoV2 Configurations](./configs/). + +## Models + +| Model | Phase | Epochs | Top1 Acc | Checkpoint | Log | +| ------- | --------------------- | ------ | -------- | ------------------------------------------------------------ | ------------------------------------------------------------ | +| MoCo v2 | Unsupervised Training | 200 | - | [download](https://paddlefleetx.bj.bcebos.com/model/vision/moco/mocov2_pt_imagenet2012_resnet50.pdparams) | [log](https://paddlefleetx.bj.bcebos.com/model/vision/moco/mocov2_pt_imagenet2012_resnet50.log) | +| MoCo v2 | Linear Classification | 100 | 0.676595 | [download](https://paddlefleetx.bj.bcebos.com/model/vision/moco/mocov2_lincls_imagenet2012_resnet50.pdparams) | [log](https://paddlefleetx.bj.bcebos.com/model/vision/moco/mocov2_lincls_imagenet2012_resnet50.log) | + + +## Citations + +``` +@Article{chen2020mocov2, + author = {Xinlei Chen and Haoqi Fan and Ross Girshick and Kaiming He}, + title = {Improved Baselines with Momentum Contrastive Learning}, + journal = {arXiv preprint arXiv:2003.04297}, + year = {2020}, +} +``` diff --git a/tasks/ssl/mocov2/configs/mocov2_resnet50_lp_in1k_1n8c.yaml b/tasks/ssl/mocov2/configs/mocov2_resnet50_lp_in1k_1n8c.yaml new file mode 100644 index 00000000..c39f6c2a --- /dev/null +++ b/tasks/ssl/mocov2/configs/mocov2_resnet50_lp_in1k_1n8c.yaml @@ -0,0 +1,115 @@ +# global configs +Global: + task_type: Classification + train_loop: ClassificationTrainingEpochLoop + validate_loop: ClassificationEvaluationLoop + checkpoint: null + pretrained_model: ./output/mocov2_resnet50_pretrain/latest_base_encoder + output_dir: ./output/ + device: gpu + save_interval: 1 + max_num_latest_checkpoint: 0 + eval_during_train: True + eval_interval: 1 + eval_unit: "epoch" + accum_steps: 1 + epochs: 100 + print_batch_step: 10 + use_visualdl: False + seed: 2022 + +# FP16 setting +FP16: + level: O0 + +DistributedStrategy: + data_parallel: True + +# model architecture +Model: + name: mocov2_resnet50_linearprobe + class_num: 1000 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + Eval: + - CELoss: + weight: 1.0 + +LRScheduler: + name: MultiStepDecay + decay_unit: epoch + learning_rate: 30.0 + gamma: 0.1 + milestones: [60, 80] + +Optimizer: + name: Momentum + momentum: 0.9 + weight_decay: 0.0 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageFolder + root: ./dataset/ILSVRC2012/train + transform: + - RandomResizedCrop: + size: 224 + - RandFlipImage: + flip_code: 1 + - ToTensor: + - Normalize: + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + + sampler: + name: DistributedBatchSampler + batch_size: 32 + drop_last: True + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True + + Eval: + dataset: + name: ImageFolder + root: ./dataset/ILSVRC2012/val + transform: + - ResizeImage: + resize_short: 256 + interpolation: bilinear + backend: pil + - CenterCropImage: + size: 224 + - ToTensor: + - Normalize: + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + + loader: + num_workers: 8 + use_shared_memory: True + +Metric: + Train: + - TopkAcc: + topk: [1, 5] + Eval: + - TopkAcc: + topk: [1, 5] + +Export: + export_type: paddle + input_shape: [None, 3, 224, 224] diff --git a/tasks/ssl/mocov2/configs/mocov2_resnet50_pt_in1k_1n8c.yaml b/tasks/ssl/mocov2/configs/mocov2_resnet50_pt_in1k_1n8c.yaml new file mode 100644 index 00000000..27befd50 --- /dev/null +++ b/tasks/ssl/mocov2/configs/mocov2_resnet50_pt_in1k_1n8c.yaml @@ -0,0 +1,97 @@ +# global configs +Global: + task_type: ContrastiveLearning + train_loop: ContrastiveLearningTrainingEpochLoop + validate_loop: None + checkpoint: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + max_num_latest_checkpoint: 0 + eval_during_train: False + eval_interval: 1 + eval_unit: "epoch" + accum_steps: 1 + epochs: 200 + print_batch_step: 10 + use_visualdl: False + seed: 2023 + +DistributedStrategy: + data_parallel: True + +# model architecture +Model: + name: mocov2_resnet50_pretrain + +LRScheduler: + name: CosineDecay + decay_unit: epoch + learning_rate: 0.03 + +Optimizer: + name: Momentum + momentum: 0.9 + weight_decay: 0.0001 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageFolder + root: ./dataset/ILSVRC2012/train + transform: + - TwoViewsTransform: + base_transform1: + - RandomResizedCrop: + size: 224 + scale: [0.2, 1.0] + interpolation: bicubic + - ColorJitter: + brightness: 0.4 + contrast: 0.4 + saturation: 0.4 + hue: 0.1 + p: 0.8 + - RandomGrayscale: + p: 0.2 + - GaussianBlur: + sigma: [.1, 2.] + p: 0.5 + - RandFlipImage: + flip_code: 1 + - ToTensor: + - Normalize: + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + base_transform2: + - RandomResizedCrop: + size: 224 + scale: [0.2, 1.0] + interpolation: bicubic + - ColorJitter: + brightness: 0.4 + contrast: 0.4 + saturation: 0.4 + hue: 0.1 + p: 0.8 + - RandomGrayscale: + p: 0.2 + - GaussianBlur: + sigma: [.1, 2.] + p: 0.5 + - RandFlipImage: + flip_code: 1 + - ToTensor: + - Normalize: + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + sampler: + name: DistributedBatchSampler + batch_size: 32 + drop_last: True + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True diff --git a/tasks/ssl/mocov2/linearprobe.sh b/tasks/ssl/mocov2/linearprobe.sh new file mode 100644 index 00000000..f0cb339e --- /dev/null +++ b/tasks/ssl/mocov2/linearprobe.sh @@ -0,0 +1,26 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# unset PADDLE_TRAINER_ENDPOINTS +# export PADDLE_NNODES=1 +# export PADDLE_MASTER="xxx.xxx.xxx.xxx:12538" +# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export FLAGS_stop_check_timeout=3600 + +python -m paddle.distributed.launch \ + --nnodes=$PADDLE_NNODES \ + --master=$PADDLE_MASTER \ + --devices=$CUDA_VISIBLE_DEVICES \ + passl-train \ + -c ./configs/mocov2_resnet50_lp_in1k_1n8c.yaml diff --git a/tasks/ssl/mocov2/pretrain.sh b/tasks/ssl/mocov2/pretrain.sh new file mode 100644 index 00000000..aeac93e3 --- /dev/null +++ b/tasks/ssl/mocov2/pretrain.sh @@ -0,0 +1,26 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# unset PADDLE_TRAINER_ENDPOINTS +# export PADDLE_NNODES=1 +# #export PADDLE_MASTER="xxx.xxx.xxx.xxx:12538" +# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export FLAGS_stop_check_timeout=3600 + +python -m paddle.distributed.launch \ + --nnodes=$PADDLE_NNODES \ + --master=$PADDLE_MASTER \ + --devices=$CUDA_VISIBLE_DEVICES \ + passl-train \ + -c ./configs/mocov2_resnet50_pt_in1k_1n8c.yaml \ No newline at end of file diff --git a/tests/CI/case.sh b/tests/CI/case.sh index 3d10679a..ec97cb4e 100644 --- a/tests/CI/case.sh +++ b/tests/CI/case.sh @@ -40,9 +40,12 @@ function model_list(){ mocov3_vit_base_patch16_224_lp_in1k_1n8c_dp_fp16o1 simsiam_resnet50_pt_in1k_1n8c_dp_fp32 simsiam_resnet50_lp_in1k_1n8c_dp_fp32 + mocov2_resnet50_pt_in1k_1n8c_dp_fp32 + mocov2_resnet50_lp_in1k_1n8c_dp_fp32 swav_resnet50_224_ft_in1k_1n4c_dp_fp32 swav_resnet50_224_lp_in1k_1n8c_dp_fp32 swav_resnet50_224_pt_in1k_1n8c_dp_fp16o1 + } ############ case start ############ @@ -390,6 +393,24 @@ function simsiam_resnet50_lp_in1k_1n8c_dp_fp32() { echo "=========== $FUNCNAME run end ===========" } + +function simsiam_resnet50_pt_in1k_1n8c_dp_fp32() { + echo "=========== $FUNCNAME run begin ===========" + rm -rf log + bash ./ssl/simsiam/simsiam_resnet50_pt_in1k_1n8c_dp_fp32.sh + + loss=`cat log/workerlog.0 | grep '50/2502' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` + ips=`cat log/workerlog.0 | grep 'ips: ' | awk -F 'ips: ' '{print $2}' | awk -F ' images/sec,' '{print $1}'| awk 'NR>1 {print}' | awk '{a+=$1}END{print a/NR}'` + mem=`cat log/workerlog.0 | grep '50/2502' | awk -F 'max mem: ' '{print $2}' | awk -F ' GB,' '{print $1}'` + loss_base=-0.32798 + ips_base=1731.37 + mem_base=10.55 + check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} + echo "=========== $FUNCNAME run end ===========" +} + +###### swav ###### + function swav_resnet50_224_ft_in1k_1n4c_dp_fp32() { echo "=========== $FUNCNAME run begin ===========" rm -rf log @@ -405,6 +426,7 @@ function swav_resnet50_224_ft_in1k_1n4c_dp_fp32() { echo "=========== $FUNCNAME run end ===========" } + function swav_resnet50_224_lp_in1k_1n8c_dp_fp32() { echo "=========== $FUNCNAME run begin ===========" rm -rf log @@ -420,7 +442,6 @@ function swav_resnet50_224_lp_in1k_1n8c_dp_fp32() { echo "=========== $FUNCNAME run end ===========" } - function swav_resnet50_224_pt_in1k_1n8c_dp_fp16o1() { echo "=========== $FUNCNAME run begin ===========" rm -rf log @@ -436,6 +457,38 @@ function swav_resnet50_224_pt_in1k_1n8c_dp_fp16o1() { echo "=========== $FUNCNAME run end ===========" } +###### MocoV2 ###### + +function mocov2_resnet50_lp_in1k_1n8c_dp_fp32() { + echo "=========== $FUNCNAME run begin ===========" + rm -rf log + bash ./ssl/mocov2/mocov2_resnet50_lp_in1k_1n8c.sh + + loss=`cat log/workerlog.0 | grep '49/2502' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` + ips=`cat log/workerlog.0 | grep 'ips: ' | awk -F 'ips: ' '{print $2}' | awk -F ' images/sec,' '{print $1}'| awk 'NR>1 {print}' | awk '{a+=$1}END{print a/NR}'` + mem=`cat log/workerlog.0 | grep '49/2502' | awk -F 'max mem: ' '{print $2}' | awk -F ' GB,' '{print $1}'` + loss_base=4.12551 + ips_base=6449.01604 + mem_base=0.77 + check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} + echo "=========== $FUNCNAME run end ===========" +} + +function mocov2_resnet50_pt_in1k_1n8c_dp_fp32() { + echo "=========== $FUNCNAME run begin ===========" + rm -rf log + bash ./ssl/mocov2/mocov2_resnet50_pt_in1k_1n8c.sh + + loss=`cat log/workerlog.0 | grep '49/2502' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` + ips=`cat log/workerlog.0 | grep 'ips: ' | awk -F 'ips: ' '{print $2}' | awk -F ' images/sec,' '{print $1}'| awk 'NR>1 {print}' | awk '{a+=$1}END{print a/NR}'` + mem=`cat log/workerlog.0 | grep '49/2502' | awk -F 'max mem: ' '{print $2}' | awk -F ' GB,' '{print $1}'` + loss_base=10.05231 + ips_base=2045.23616 + mem_base=6.17 + check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} + echo "=========== $FUNCNAME run end ===========" +} + function check_result() { if [ $? -ne 0 ];then echo -e "\033 $1 model runs failed! \033" | tee -a $log_path/result.log diff --git a/tests/CI/ssl/mocov2/mocov2_resnet50_lp_in1k_1n8c_dp.sh b/tests/CI/ssl/mocov2/mocov2_resnet50_lp_in1k_1n8c_dp.sh new file mode 100644 index 00000000..f17b52e4 --- /dev/null +++ b/tests/CI/ssl/mocov2/mocov2_resnet50_lp_in1k_1n8c_dp.sh @@ -0,0 +1,33 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# unset PADDLE_TRAINER_ENDPOINTS +# export PADDLE_NNODES=1 +# export PADDLE_MASTER="xxx.xxx.xxx.xxx:12538" +# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export FLAGS_stop_check_timeout=3600 + +python -m paddle.distributed.launch \ + --nnodes=$PADDLE_NNODES \ + --master=$PADDLE_MASTER \ + --devices=$CUDA_VISIBLE_DEVICES \ + passl-train \ + -c ../../tasks/ssl/mocov2/configs/mocov2_resnet50_lp_in1k_1n8c.yaml \ + -o Global.print_batch_step=1 \ + -o Global.max_train_step=50 \ + -o Global.flags.FLAGS_cudnn_exhaustive_search=0 \ + -o Global.flags.FLAGS_cudnn_deterministic=1 \ + -o DataLoader.Train.sampler.batch_size=64 \ + -o Global.pretrained_model=./pretrained/mocov2/mocov2_latest_base_encoder + diff --git a/tests/CI/ssl/mocov2/mocov2_resnet50_pt_in1k_1n8c_dp.sh b/tests/CI/ssl/mocov2/mocov2_resnet50_pt_in1k_1n8c_dp.sh new file mode 100644 index 00000000..1e7d0be0 --- /dev/null +++ b/tests/CI/ssl/mocov2/mocov2_resnet50_pt_in1k_1n8c_dp.sh @@ -0,0 +1,31 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# unset PADDLE_TRAINER_ENDPOINTS +# export PADDLE_NNODES=1 +# #export PADDLE_MASTER="xxx.xxx.xxx.xxx:12538" +# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export FLAGS_stop_check_timeout=3600 + +python -m paddle.distributed.launch \ + --nnodes=$PADDLE_NNODES \ + --master=$PADDLE_MASTER \ + --devices=$CUDA_VISIBLE_DEVICES \ + passl-train \ + -c ../../tasks/ssl/mocov2/configs/mocov2_resnet50_pt_in1k_1n8c.yaml \ + -o Global.print_batch_step=1 \ + -o Global.max_train_step=50 \ + -o Global.flags.FLAGS_cudnn_exhaustive_search=0 \ + -o Global.flags.FLAGS_cudnn_deterministic=1 \ + -o DataLoader.Train.sampler.batch_size=64 \ No newline at end of file