Commit fc11025
fixed some bugs and updated spec
- models/GroundingDINO/ms_deform_attn.py: fix the custom-ops import error.
- models/GroundingDINO/ops/test.py: do not test large channel counts.
- main.py: rework the args/dataset setup so that eval no longer requires a training dataset.
- requirements.txt: update version pins.
- test_dist.sh: add testing with torch.distributed.launch.
- README.md: update the spec and slurm/dist usage.
BIGBALLON committed Oct 20, 2023
1 parent 0ad46ce commit fc11025
Showing 8 changed files with 61 additions and 87 deletions.
README.md: 74 changes (17 additions & 57 deletions)
@@ -1,19 +1,15 @@
<div align="center">
-<img src="figs/cute_dino.png" width="54%">
+<img src="figs/cute_dino.png" width="40%">
</div>


# Open GroundingDino

-This is the third party implementation of the paper **[Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection](https://arxiv.org/abs/2303.05499)** by [Zuwei Long]() and [Wei Li](https://github.com/bigballon)
+This is the third-party implementation of the paper **[Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection](https://arxiv.org/abs/2303.05499)** by [Zuwei Long]() and [Wei Li](https://github.com/bigballon).

**You can use this code to fine-tune a model on your own dataset, or start pretraining a model from scratch.**





***
# Supported Features

| | Official release version | The Version We Replicated |
@@ -27,21 +23,16 @@ This is the third party implementation of the paper **[Grounding DINO: Marrying


# Setup

We test our models under ```python=3.7.11, pytorch=1.11.0, cuda=11.3```. Other versions may work as well.
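For reproducibility, one way to build that environment (a sketch; the cu113 wheel index is the usual one for this PyTorch release, adjust to your driver):

```bash
conda create -n open-gdino python=3.7.11 -y
conda activate open-gdino
# torchvision 0.12.0 is the release paired with torch 1.11.0
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 \
    -f https://download.pytorch.org/whl/torch_stable.html
```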

1. Clone the GroundingDINO repository from GitHub.

```bash
-git clone https://github.com/longzw1997/Open-GroundingDino.git
-```
-
-2. Change the current directory to the GroundingDINO folder.
-
-```bash
-cd Open-GroundingDino/
+git clone https://github.com/longzw1997/Open-GroundingDino.git && cd Open-GroundingDino/
```

-3. Install the required dependencies.
+2. Install the required dependencies.

```bash
pip install -r requirements.txt
@@ -52,35 +43,26 @@
python test.py
cd ../../..
```

-4. Download pre-trained model weights.
+3. Download [pre-trained model](https://github.com/IDEA-Research/GroundingDINO/releases) and [BERT](https://huggingface.co/bert-base-uncased) weights.

```bash
mkdir weights
cd weights
wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
cd ..
```

-5. Download BERT as the language model.
-
-```bash
-mkdir bert_weights
-cd bert_weights
-wget -q https://drive.google.com/drive/folders/1eM1HYf2K161YPzIcRDDMzE7S4WBGmDLM?usp=share_link
-cd ..
-```
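The Drive link above is removed in this commit in favor of the Hugging Face weights linked in step 3. A sketch of fetching them to the path that test_dist.sh expects (assumes git-lfs is installed; the huggingface_hub CLI works as well):

```bash
mkdir -p bert_weights
git lfs install
git clone https://huggingface.co/bert-base-uncased bert_weights/bert-base-uncased
```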



# Dataset

- Dataset Format -> [data format](data_format.md)
- See [datasets_mixed_odvg.json](config/datasets_mixed_odvg.json) | [coco2odvg.py](./tools/coco2odvg.py) | [grit2odvg](./tools/grit2odvg.py) for more details
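The linked converters turn COCO-style OD and GRIT-style VG annotations into the ODVG format; assuming they expose the usual argparse interface, their exact options can be inspected directly:

```bash
python tools/coco2odvg.py -h   # OD (COCO) -> ODVG
python tools/grit2odvg.py -h   # VG (GRIT) -> ODVG
```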




<details>
<summary>mixed dataset</summary>
<br/>
@@ -164,38 +146,19 @@
config/datasets_mixed_odvg.json # support mixed dataset for both OD and VG


``` bash
-#train on slrum:
-bash train_multi_node.sh ${PARTITION} ${GPU_NUM} ${CFG} ${DATASETS} ${OUTPUT_DIR}

-# e.g. check train_multi_node.sh for more details
-# bash train_multi_node.sh v100_32g 32 config/cfg_odvg.py config/datasets_mixed_odvg.json ./logs
-# bash train_multi_node.sh v100_32g 8 config/cfg_coco.py config/datasets_od_example.json ./logs
-# bash train_multi_node.sh v100_32g 8 config/cfg_odvg.py config/datasets_vg_example.json ./logs

-#train on dist:

+# train/eval on a slurm cluster:
+bash train_slrum.sh ${PARTITION} ${GPU_NUM} ${CFG} ${DATASETS} ${OUTPUT_DIR}
+bash test_slrum.sh ${PARTITION} ${GPU_NUM} ${CFG} ${DATASETS} ${OUTPUT_DIR}
+# e.g. check train_slrum.sh for more details
+# bash train_slrum.sh v100_32g 32 config/cfg_odvg.py config/datasets_mixed_odvg.json ./logs
+# bash train_slrum.sh v100_32g 8 config/cfg_coco.py config/datasets_od_example.json ./logs

+# train/eval with torch.distributed.launch:
+bash train_dist.sh ${GPU_NUM} ${CFG} ${DATASETS} ${OUTPUT_DIR}

+# e.g. check train_dist.sh for more details
+# bash train_dist.sh 8 config/cfg_odvg.py config/datasets_mixed_odvg.json ./logs
+# bash train_dist.sh 8 config/cfg_coco.py config/datasets_od_example.json ./logs
+# bash train_dist.sh 8 config/cfg_odvg.py config/datasets_vg_example.json ./logs

-#eval:
-bash test.sh ${PARTITION} ${GPU_NUM} ${CFG} ${DATASETS} ${OUTPUT_DIR}

-# e.g. check train_multi_node.sh for more details
-# bash train_multi_node.sh v100_32g 32 config/cfg_odvg.py config/datasets_mixed_odvg.json ./logs
-# bash train_multi_node.sh v100_32g 8 config/cfg_coco.py config/datasets_od_example.json ./logs
-# bash train_multi_node.sh v100_32g 8 config/cfg_odvg.py config/datasets_vg_example.json ./logs
+bash test_dist.sh ${GPU_NUM} ${CFG} ${DATASETS} ${OUTPUT_DIR}
```
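test_dist.sh (added by this commit) defines NNODES, NODE_RANK, PORT, and MASTER_ADDR overrides; assuming train_dist.sh mirrors them and forwards them to the launcher (see the note under test_dist.sh below), a two-node launch would be sketched as:

```bash
# node 0
NNODES=2 NODE_RANK=0 MASTER_ADDR=10.0.0.1 PORT=29500 \
    bash train_dist.sh 8 config/cfg_odvg.py config/datasets_mixed_odvg.json ./logs
# node 1: same command with NODE_RANK=1
```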



# Results and Models

<!-- insert a table -->
@@ -249,8 +212,8 @@ bash test.sh ${PARTITION} ${GPU_NUM} ${CFG} ${DATASETS} ${OUTPUT_DIR}
</tr>
</tbody>
</table>
-GRIT-200K generated by GLIP and spaCy
+GRIT-200K generated by [GLIP](https://github.com/microsoft/GLIP) and [spaCy](https://spacy.io/).


# Contact
@@ -268,14 +231,11 @@
Provided code was adapted from:
- [IDEA-Research/GroundingDINO](https://github.com/IDEA-Research/GroundingDINO)





# Citation

```
@misc{Open Grounding Dino,
-author = {Zuwei Long,Wei Li},
+author = {Zuwei Long, Wei Li},
title = {Open Grounding Dino: The third-party implementation of the paper Grounding DINO},
howpublished = {\url{https://github.com/longzw1997/Open-GroundingDino}},
year = {2023}
main.py: 47 changes (24 additions & 23 deletions)
@@ -66,9 +66,9 @@ def get_args_parser():
    parser.add_argument('--rank', default=0, type=int,
                        help='number of distributed processes')
-    parser.add_argument("--local_rank", type=int, help='local rank for DistributedDataParallel')
+    parser.add_argument("--local-rank", type=int, help='local rank for DistributedDataParallel')
    parser.add_argument('--amp', action='store_true',
                        help="Train with mixed precision")

return parser
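The hyphenated flag matters because torch.distributed.launch in PyTorch 2.x injects `--local-rank` rather than `--local_rank`; argparse maps the hyphen back to an underscore, so downstream reads of `args.local_rank` are unchanged. A standalone sketch of that behavior:

```bash
python - <<'PY'
import argparse

parser = argparse.ArgumentParser()
# argparse derives the dest "local_rank" from "--local-rank"
parser.add_argument("--local-rank", type=int, default=0)
args = parser.parse_args(["--local-rank", "3"])
print(args.local_rank)  # -> 3
PY
```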


@@ -171,34 +171,35 @@ def main(args):
                                  weight_decay=args.weight_decay)

    logger.debug("build dataset ... ...")
-    num_of_dataset_train = len(dataset_meta["train"])
-    if num_of_dataset_train == 1:
-        dataset_train = build_dataset(image_set='train', args=args, datasetinfo=dataset_meta["train"][0])
-    else:
-        from torch.utils.data import ConcatDataset
-        dataset_train_list = []
-        for idx in range(len(dataset_meta["train"])):
-            dataset_train_list.append(build_dataset(image_set='train', args=args, datasetinfo=dataset_meta["train"][idx]))
-        dataset_train = ConcatDataset(dataset_train_list)
-    dataset_val = build_dataset(image_set='val', args=args, datasetinfo=dataset_meta["val"][0])
-    logger.debug("build dataset, done.")
-    logger.debug(f'number of training dataset: {num_of_dataset_train}, samples: {len(dataset_train)}')
+    if not args.eval:
+        num_of_dataset_train = len(dataset_meta["train"])
+        if num_of_dataset_train == 1:
+            dataset_train = build_dataset(image_set='train', args=args, datasetinfo=dataset_meta["train"][0])
+        else:
+            from torch.utils.data import ConcatDataset
+            dataset_train_list = []
+            for idx in range(len(dataset_meta["train"])):
+                dataset_train_list.append(build_dataset(image_set='train', args=args, datasetinfo=dataset_meta["train"][idx]))
+            dataset_train = ConcatDataset(dataset_train_list)
+        logger.debug("build dataset, done.")
+        logger.debug(f'number of training dataset: {num_of_dataset_train}, samples: {len(dataset_train)}')
+
+    dataset_val = build_dataset(image_set='val', args=args, datasetinfo=dataset_meta["val"][0])

    if args.distributed:
-        sampler_train = DistributedSampler(dataset_train)
+        if not args.eval:
+            sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
-        sampler_train = torch.utils.data.RandomSampler(dataset_train)
+        if not args.eval:
+            sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

-    batch_sampler_train = torch.utils.data.BatchSampler(
-        sampler_train, args.batch_size, drop_last=True)
-    data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
-                                   collate_fn=utils.collate_fn, num_workers=args.num_workers)
+    if not args.eval:
+        batch_sampler_train = torch.utils.data.BatchSampler(
+            sampler_train, args.batch_size, drop_last=True)
+        data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
+                                       collate_fn=utils.collate_fn, num_workers=args.num_workers)

    data_loader_val = DataLoader(dataset_val, 4, sampler=sampler_val,
                                 drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)
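With these guards, an eval-only run builds only the validation dataset, so the "train" entries in the datasets JSON are never touched. A sketch of such a run, with paths following the README examples:

```bash
python -m torch.distributed.launch --nproc_per_node=1 main.py \
    --eval \
    -c config/cfg_odvg.py \
    --datasets config/datasets_mixed_odvg.json \
    --output_dir ./logs \
    --pretrain_model_path ./weights/groundingdino_swint_ogc.pth \
    --options text_encoder_type=./bert_weights/bert-base-uncased
```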
models/GroundingDINO/ms_deform_attn.py: 3 changes (2 additions & 1 deletion)
@@ -26,7 +26,8 @@
from torch.nn.init import constant_, xavier_uniform_

try:
-    from groundingdino import _C
+    # from groundingdino import _C
+    import MultiScaleDeformableAttention as _C
except:
    warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only!")
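MultiScaleDeformableAttention is the module name produced by building the ops package in this repo, rather than the `_C` shipped inside a groundingdino wheel. A sketch of building and sanity-checking it (the setup.py step is the customary build for these deformable-attention ops):

```bash
cd models/GroundingDINO/ops
python setup.py build install
python -c "import MultiScaleDeformableAttention as _C; print('ops loaded')"
cd ../../..
```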

models/GroundingDINO/ops/test.py: 5 changes (1 addition & 4 deletions)
@@ -82,8 +82,5 @@ def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True
    check_forward_equal_with_pytorch_double()
    check_forward_equal_with_pytorch_float()

-    for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
+    for channels in [30, 32, 64, 71]:
        check_gradient_numerical(channels, True, True, True)
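The commit message only says large channels are no longer tested; a plausible reason is that the numerical gradient check becomes prohibitively slow and memory-hungry as the channel count grows. The trimmed sweep keeps the check fast:

```bash
cd models/GroundingDINO/ops && python test.py   # sweeps channels 30, 32, 64, 71
```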



requirements.txt: 3 changes (1 addition & 2 deletions)
@@ -3,7 +3,7 @@
submitit
scipy
termcolor
addict
-yapf
+yapf==0.40.1
timm
torch
torchvision
@@ -13,5 +13,4 @@
opencv-python
supervision==0.6.0
pycocotools
pyyaml>3.10
-h5py>3.0
colorlog
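The yapf pin is most likely defensive: yapf 0.40.2 removed the `verify` keyword of FormatCode, which breaks the mmcv-style config utilities this repo adapts (an assumption from the common failure mode; the commit only says "update version"). To apply the pins:

```bash
pip install -r requirements.txt
python -c "import yapf; print(yapf.__version__)"   # expect 0.40.1
```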
test_dist.sh: 16 changes (16 additions & 0 deletions)
@@ -0,0 +1,16 @@
+GPU_NUM=$1
+CFG=$2
+DATASETS=$3
+OUTPUT_DIR=$4
+NNODES=${NNODES:-1}
+NODE_RANK=${NODE_RANK:-0}
+PORT=${PORT:-29500}
+MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
+
+python -m torch.distributed.launch --nproc_per_node=${GPU_NUM} main.py \
+        --output_dir ${OUTPUT_DIR} \
+        --eval \
+        -c ${CFG} \
+        --datasets ${DATASETS} \
+        --pretrain_model_path ./weights/groundingdino_swint_ogc.pth \
+        --options text_encoder_type=./bert_weights/bert-base-uncased
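Single-node usage matches the README, e.g. `bash test_dist.sh 8 config/cfg_odvg.py config/datasets_mixed_odvg.json ./logs`. Note that NNODES, NODE_RANK, PORT, and MASTER_ADDR are read from the environment but never forwarded to the launcher, so as committed they have no effect; wiring them in would use the standard torch.distributed.launch flags, sketched as:

```bash
python -m torch.distributed.launch --nproc_per_node=${GPU_NUM} \
    --nnodes=${NNODES} --node_rank=${NODE_RANK} \
    --master_addr=${MASTER_ADDR} --master_port=${PORT} \
    main.py --eval -c ${CFG} --datasets ${DATASETS} --output_dir ${OUTPUT_DIR} \
    --pretrain_model_path ./weights/groundingdino_swint_ogc.pth \
    --options text_encoder_type=./bert_weights/bert-base-uncased
```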
File renamed without changes.
File renamed without changes.
