Fixed ball detection module architecture problem. #45

Open · wants to merge 10 commits into base: master
1 change: 1 addition & 0 deletions .gitignore
@@ -10,6 +10,7 @@ __pycache__/
# C extensions
*.so

*.out
# Distribution / packaging
.Python
build/
38 changes: 18 additions & 20 deletions README.md
@@ -13,6 +13,7 @@ An introduction of the project could be found [here (from the authors)](https://
![demo](./docs/demo.gif)

## 1. Features

- [x] Ball detection global stage
- [x] Ball detection local stage (refinement)
- [x] Events Spotting detection (Bounce and Net hit)
@@ -23,20 +24,12 @@ An introduction of the project could be found [here (from the authors)](https://
- [x] Smooth labeling for event spotting
- [x] TensorboardX

- **(Update 2020.06.23)**: Training much faster, achieve _**> 120 FPS**_ in the inference phase on a single
GPU (GTX1080Ti). <br>

- **(Update 2020.07.03)**: The implementation could achieve comparative results with the reported results in the TTNet paper. <br>
- **(Update 2024.09.12)**: The implementation achieves results comparable to those reported in the TTNet paper. Moreover, I have fixed the wrong implementation in the ball detection module and changed it back to the original design described in the paper. I have also provided complete training code in train.sh<br>

- **(Update 2020.07.06)**: There are several limitations of the TTNet Paper (hints: Loss function, input size, and 2 more). I have implemented the task with a new
approach and a new model. Now the new model could achieve:
- `>` **130FPS** inference,
- **~0.96** IoU score for the segmentation task
- `<` **4 pixels** (in the full HD resolution *(1920x1080)*) of Root Mean Square Error (RMSE) for the ball detection task<br>
- **~97%** percentage of correction events **(PCE)** and smooth PCE **(SPCE)**.
- **(2024.09.12)**: The model achieves an average IoU of **0.9632**, global RMSE of **8.9**, local RMSE of **2.3**, overall RMSE of **54.4**, PCE of **0.8918**, and SPCE of **0.9808**.


## 2. Getting Started

### Requirement

@@ -55,6 +48,7 @@
```shell script
$ pip install PyTurboJPEG
```

Other instructions for setting up virtual environments are [here](https://github.com/maudzung/virtual_environment_python3)

### 2.1. Preparing the dataset

The instruction for the dataset preparation is [here](./prepare_dataset/README.md)

### 2.2. Model & Input tensors
@@ -70,13 +64,14 @@ The instruction for the dataset preparation is [here](./prepare_dataset/README.m
### 2.3. How to run

#### 2.3.1. Training

##### 2.3.1.1. Single machine, single GPU

```shell script
python main.py --gpu_idx 0
```

By default (as the above command), there are 4 modules in the TTNet model: *global stage, local stage, event spotting, segmentation*.
By default (as the above command), there are 4 modules in the TTNet model: _global stage, local stage, event spotting, segmentation_.
You can disable one of the modules, except the global stage module.<br>
An important note: if you disable the local stage module, the event spotting module will also be disabled.

@@ -99,7 +94,8 @@
```shell script
python main.py --gpu_idx 0 --no_local --no_seg --no_event
```

##### 2.3.1.2. Multi-processing Distributed Data Parallel Training

We should always use the `nccl` backend for multi-processing distributed training since it currently provides the best
distributed training performance.

- **Single machine (node), multiple GPUs**
@@ -115,6 +111,7 @@ _**First machine**_
```shell script
python main.py --dist-url 'tcp://IP_OF_NODE1:FREEPORT' --dist-backend 'nccl' --multiprocessing-distributed --world-size 2 --rank 0
```

_**Second machine**_

@@ -123,11 +120,11 @@
```shell script
python main.py --dist-url 'tcp://IP_OF_NODE2:FREEPORT' --dist-backend 'nccl' --multiprocessing-distributed --world-size 2 --rank 1
```

#### 2.3.2. Training strategy

The performance of the TTNet strongly depends on the global stage for ball detection. Hence, it's necessary to train the
`global ball stage module` of the TTNet model first.

- **1st phase**: Train the global and segmentation modules with 30 epochs

```shell script
./train_1st_phase.sh
```
@@ -145,18 +142,18 @@ the global stage. In this phase, we train and just update weights of the local a
```shell script
./train_3rd_phase.sh
```


#### 2.3.3. Visualizing training progress

TensorBoard is used to log loss values on the training and validation sets.
Execute the command below in the working terminal:

```
cd logs/<task directory>/tensorboard/
tensorboard --logdir=./
```

Then open the web browser and go to: [http://localhost:6006/](http://localhost:6006/)


#### 2.3.4. Evaluation

The thresholds of the segmentation and event spotting tasks can be set in the `test.sh` bash script.
@@ -165,7 +162,7 @@
```shell script
./test_3rd_phase.sh
```
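Purely for reference, a hypothetical version of the threshold settings inside `test.sh` might look like the sketch below; the flag names are assumptions for illustration and are not confirmed by this diff, so check the Usage section or `config.py` for the real option names.

```shell script
# Hypothetical sketch only; option names are assumed
python main.py --gpu_idx 0 \
  --seg_thresh 0.5 \
  --event_thresh 0.5
```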

#### 2.3.5. Demo:
#### 2.3.5. Demo

Run a demonstration with an input video:

@@ -192,6 +189,7 @@ If you find any errors or have any suggestions, please contact me. Thank you!
```

## Usage

```
usage: main.py [-h] [--seed SEED] [--saved_fn FN] [-a ARCH] [--dropout_p P]
[--multitask_learning] [--no_local] [--no_event] [--no_seg]
@@ -323,7 +321,7 @@ optional arguments:
saved
```

[python-image]: https://img.shields.io/badge/Python-3.6-ff69b4.svg
[python-image]: https://img.shields.io/badge/Python-3.9-ff69b4.svg
[python-url]: https://www.python.org/
[pytorch-image]: https://img.shields.io/badge/PyTorch-1.5-2BAF2B.svg
[pytorch-image]: https://img.shields.io/badge/PyTorch-2.4-2BAF2B.svg
[pytorch-url]: https://pytorch.org/
48 changes: 29 additions & 19 deletions prepare_dataset/extract_all_images.py
@@ -16,28 +16,38 @@ def extract_images_from_videos(video_path, out_images_dir):
make_folder(sub_images_dir)

video_cap = cv2.VideoCapture(video_path)
n_frames = video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
f_width = video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)
f_height = video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
print('video_fn: {}.mp4, number of frames: {}, f_width: {}, f_height: {}'.format(video_fn, n_frames, f_width,
f_height))

frame_cnt = -1
while True:
if not video_cap.isOpened():
print(f"Error: Cannot open video file {video_path}")
return

n_frames = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
f_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
f_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
print(f'Processing video: {video_fn}.mp4')
print(f'Number of frames: {n_frames}, Width: {f_width}, Height: {f_height}')

frame_idx = 0
while frame_idx < n_frames:
ret, img = video_cap.read()
if ret:
frame_cnt += 1
image_path = os.path.join(sub_images_dir, 'img_{:06d}.jpg'.format(frame_cnt))
if os.path.isfile(image_path):
print('video {} had been already extracted'.format(video_path))
break
cv2.imwrite(image_path, img)
else:
break
if cv2.waitKey(10) & 0xFF == ord('q'):
if not ret:
print(f"Warning: Failed to read frame {frame_idx} from video {video_path}")
break

image_path = os.path.join(sub_images_dir, f'img_{frame_idx:06d}.jpg')
if os.path.isfile(image_path):
# Image already exists, skip writing but continue extracting
print(f"Frame {frame_idx} already exists. Skipping...")
else:
success = cv2.imwrite(image_path, img)
if not success:
print(f"Error: Failed to write frame {frame_idx} to {image_path}")
# Optionally, you can choose to break or continue based on your needs
# break

frame_idx += 1

video_cap.release()
print('done extraction: {}'.format(video_path))
print(f'Done extracting frames from: {video_path}')


if __name__ == '__main__':
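The `__main__` driver of this script is collapsed in the diff, so the call site is not shown. Purely as a hypothetical illustration consistent with the function signature above (the dataset paths are assumptions), a driver could look like:

```python
import os
from glob import glob

# Hypothetical driver; the paths are assumptions, not the PR's actual __main__ block
if __name__ == '__main__':
    video_dir = '../dataset/videos'      # assumed location of the .mp4 files
    out_images_dir = '../dataset/images'
    for video_path in sorted(glob(os.path.join(video_dir, '*.mp4'))):
        extract_images_from_videos(video_path, out_images_dir)
```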
21 changes: 12 additions & 9 deletions requirement.txt
@@ -1,9 +1,12 @@
wget==3.2
torch==1.5.0
torchvision==0.6.0
easydict==1.9
opencv-python==4.2.0.34
numpy==1.18.3
torchsummary==1.5.1
tensorboard==2.2.1
scikit-learn==0.22.2
wget
torch
torchvision
easydict
opencv-python
numpy==2.0.1
torchsummary
tensorboard
scikit-learn
tqdm
matplotlib
PyTurboJPEG
6 changes: 6 additions & 0 deletions src/bash_slurm_job.sh
@@ -0,0 +1,6 @@
#!/bin/bash
#SBATCH --partition=gpu
#SBATCH --gres=gpu:1
#SBATCH --job-name=AugustTest


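As committed, `bash_slurm_job.sh` contains only `#SBATCH` directives and never launches anything. A complete job script would end with the actual training command; the following is a minimal sketch, assuming the job is submitted from `src/` where `main.py` lives:

```shell script
#!/bin/bash
#SBATCH --partition=gpu
#SBATCH --gres=gpu:1
#SBATCH --job-name=AugustTest

# Minimal sketch: launch single-GPU training (see Section 2.3.1.1 of the README)
python main.py --gpu_idx 0
```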
19 changes: 11 additions & 8 deletions src/config/config.py
@@ -51,15 +51,15 @@ def parse_configs():
####################################################################
############## Dataloader and Running configs #######
####################################################################
parser.add_argument('--working-dir', type=str, default='../../', metavar='PATH',
parser.add_argument('--working-dir', type=str, default='../', metavar='PATH',
help='the ROOT working directory')
parser.add_argument('--no-val', action='store_true',
parser.add_argument('--no_val', action='store_true',
help='If true, use all data for training, no validation set')
parser.add_argument('--no-test', action='store_true',
parser.add_argument('--no_test', action='store_true',
help='If true, dont evaluate the model on the test set')
parser.add_argument('--val-size', type=float, default=0.2,
help='The size of validation set')
parser.add_argument('--smooth-labelling', action='store_true',
parser.add_argument('--smooth_labelling', action='store_true',
help='If true, smoothly make the labels of event spotting')
parser.add_argument('--num_samples', type=int, default=None,
help='Take a subset of the dataset to run and debug')
@@ -69,6 +69,8 @@ def parse_configs():
help='mini-batch size (default: 8), this is the total'
'batch size of all GPUs on the current node when using'
'Data Parallel or Distributed Data Parallel')
parser.add_argument('--distributed', type=bool, default=False,
help='If true, the model is trained on multiple GPUs')
parser.add_argument('--print_freq', type=int, default=50, metavar='N',
help='print frequency (default: 50)')
parser.add_argument('--checkpoint_freq', type=int, default=2, metavar='N',
@@ -131,19 +133,19 @@ def parse_configs():
####################################################################
############## Distributed Data Parallel ############
####################################################################
parser.add_argument('--world-size', default=-1, type=int, metavar='N',
parser.add_argument('--world_size', default=-1, type=int, metavar='N',
help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int, metavar='N',
help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://127.0.0.1:29500', type=str,
parser.add_argument('--dist_url', default='tcp://127.0.0.1:29500', type=str,
help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
parser.add_argument('--dist_backend', default='nccl', type=str,
help='distributed backend')
parser.add_argument('--gpu_idx', default=None, type=int,
help='GPU index to use.')
parser.add_argument('--no_cuda', action='store_true',
help='If true, cuda is not used.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
parser.add_argument('--multiprocessing_distributed', action='store_true',
help='Use multi-processing distributed training to launch '
'N processes per node, which has N GPUs. This is the '
'fastest way to use PyTorch for either single node or '
@@ -204,6 +206,7 @@ def parse_configs():
configs.events_weights_loss = (configs.events_weights_loss_dict['bounce'], configs.events_weights_loss_dict['net'])
configs.num_events = len(configs.events_weights_loss_dict) # Just "bounce" and "net hits"
configs.num_frames_sequence = 9
configs.interval_between_frames = 5

configs.org_size = (1920, 1080)
configs.input_size = (320, 128)
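One caveat with the new `--distributed` option above: argparse applies `type=bool` by calling `bool()` on the raw string, so any non-empty value, including the literal string `False`, parses as `True`. A standalone demonstration (not part of the diff):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--distributed', type=bool, default=False)

# bool('False') is truthy, so this prints True
print(parser.parse_args(['--distributed', 'False']).distributed)

# Only the bare default yields False
print(parser.parse_args([]).distributed)

# The action='store_true' pattern used elsewhere in config.py avoids the trap:
# parser.add_argument('--distributed', action='store_true')
```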
72 changes: 72 additions & 0 deletions src/data_process/transformation.py
@@ -159,3 +159,75 @@ def __call__(self, imgs, ball_position_xy, seg_img):
ball_position_xy[0] = w - ball_position_xy[0]

return imgs, ball_position_xy, seg_img


import random
import numpy as np
import cv2

class Random_Ball_Mask:
def __init__(self, mask_size=(20, 20), p=0.5, mask_type='mean'):
"""
Args:
mask_size (tuple): Height and width of the mask area (blackout area).
p (float): Probability of applying the mask.
mask_type (str): Type of mask ('zero', 'noise', 'mean').
"""
self.mask_size = mask_size
self.p = p
self.mask_type = mask_type

def __call__(self, imgs, ball_position_xy, seg_img):
"""
Args:
imgs : Numpy array of shape [H, W, num_frames].
ball_position_xy (numpy): (x, y) ball position for the labeled frame.
seg_img: Corresponding segmentation mask.

Returns:
Tuple containing:
- masked_imgs: Numpy array with masked frames.
- ball_position_xy: Updated ball position.
- seg_img: Unmodified segmentation image.
"""
H, W, num_frames = imgs.shape # Extract shape from stacked array

# Ensure the mask size is valid
mask_h = random.randint(max(1, self.mask_size[0] - 10), self.mask_size[0] + 10)
mask_w = random.randint(max(1, self.mask_size[1] - 10), self.mask_size[1] + 10)

# Iterate over all frames and apply masking with some probability
for i in range(num_frames):
if random.random() <= self.p:
if i == num_frames - 1:
# Use the given ball position for the last frame
x, y = int(ball_position_xy[0]), int(ball_position_xy[1])
else:
# Apply mask at a random position for non-labeled frames
x = random.randint(0, W - mask_w)
y = random.randint(0, H - mask_h)

# Ensure the mask is within the image boundaries
top = max(0, min(H - mask_h, y - mask_h // 2))
left = max(0, min(W - mask_w, x - mask_w // 2))

# Check if the selected region has valid pixels
region = imgs[top:top + mask_h, left:left + mask_w, i]
if region.size == 0:
print(f"Warning: Empty slice for frame {i}. Skipping mask.")
continue

# Apply the chosen mask type
if self.mask_type == 'zero':
imgs[top:top + mask_h, left:left + mask_w, i] = 0

elif self.mask_type == 'noise':
noise = np.random.randn(mask_h, mask_w) * 255 # Generate noise
imgs[top:top + mask_h, left:left + mask_w, i] = noise.clip(0, 255)

elif self.mask_type == 'mean':
mean_value = np.nanmean(region) # Mean intensity of the region (empty regions were skipped above)
noise = np.random.randn(mask_h, mask_w) * 10 # Small noise
imgs[top:top + mask_h, left:left + mask_w, i] = (mean_value + noise).clip(0, 255)

return imgs, ball_position_xy, seg_img
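A minimal standalone sketch of how the new transform could be exercised; the array shapes follow the class docstring and the configured input size (320x128), while the actual dataset wiring is not shown in this diff:

```python
import numpy as np

# Assumes Random_Ball_Mask from above is importable/in scope.
# imgs is [H, W, num_frames], per the docstring.
imgs = np.random.randint(0, 256, size=(128, 320, 9), dtype=np.uint8)
ball_xy = np.array([160.0, 64.0])   # (x, y) ball position in the last (labeled) frame
seg_img = np.zeros((128, 320), dtype=np.uint8)

transform = Random_Ball_Mask(mask_size=(20, 20), p=0.5, mask_type='mean')
masked_imgs, ball_xy, seg_img = transform(imgs, ball_xy, seg_img)
print(masked_imgs.shape)  # (128, 320, 9); each frame is masked with probability ~p
```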