diff --git a/.gitignore b/.gitignore
index bc8b4d0..f91c0ef 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,7 @@ __pycache__/
# C extensions
*.so
+*.out
# Distribution / packaging
.Python
build/
diff --git a/README.md b/README.md
index 23d2c2f..db03234 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,7 @@ An introduction of the project could be found [here (from the authors)](https://
![demo](./docs/demo.gif)
## 1. Features
+
- [x] Ball detection global stage
- [x] Ball detection local stage (refinement)
- [x] Events Spotting detection (Bounce and Net hit)
@@ -23,20 +24,12 @@ An introduction of the project could be found [here (from the authors)](https://
- [x] Smooth labeling for event spotting
- [x] TensorboardX
-- **(Update 2020.06.23)**: Training much faster, achieve _**> 120 FPS**_ in the inference phase on a single
-GPU (GTX1080Ti).
-
-- **(Update 2020.07.03)**: The implementation could achieve comparative results with the reported results in the TTNet paper.
+- **(Update 2024.09.12)**: The implementation could achieve comparative results with the reported results in the TTNet paper. Moreover, I have fixed the wrong implementation in the ball detection module, changed it to the original implementation just like in the paper! I have also provided completely training code in train.sh
-- **(Update 2020.07.06)**: There are several limitations of the TTNet Paper (hints: Loss function, input size, and 2 more). I have implemented the task with a new
-approach and a new model. Now the new model could achieve:
- - `>` **130FPS** inference,
- - **~0.96** IoU score for the segmentation task
- - `<` **4 pixels** (in the full HD resolution *(1920x1080)*) of Root Mean Square Error (RMSE) for the ball detection task
- - **~97%** percentage of correction events **(PCE)** and smooth PCE **(SPCE)**.
+- **(2024.09.12)**: The model can achieve **0.9632** on average iou, rmse global **8.9**, rmse local **2.3** rmse_overall: **54.4**, pce: **0.8918** spce: **0.9808**
-
## 2. Getting Started
+
### Requirement
```shell script
@@ -55,6 +48,7 @@ $ pip install PyTurboJPEG
Other instruction for setting up virtual environments is [here](https://github.com/maudzung/virtual_environment_python3)
### 2.1. Preparing the dataset
+
The instruction for the dataset preparation is [here](./prepare_dataset/README.md)
### 2.2. Model & Input tensors
@@ -70,13 +64,14 @@ The instruction for the dataset preparation is [here](./prepare_dataset/README.m
### 2.3. How to run
#### 2.3.1. Training
+
##### 2.3.1.1. Single machine, single gpu
```shell script
python main.py --gpu_idx 0
```
-By default (as the above command), there are 4 modules in the TTNet model: *global stage, local stage, event spotting, segmentation*.
+By default (as the above command), there are 4 modules in the TTNet model: _global stage, local stage, event spotting, segmentation_.
You can disable one of the modules, except the global stage module.
An important note is if you disable the local stage module, the event spotting module will be also disabled.
@@ -99,7 +94,8 @@ python main.py --gpu_idx 0 --no_local --no_seg --no_event
```
##### 2.3.1.2. Multi-processing Distributed Data Parallel Training
-We should always use the `nccl` backend for multi-processing distributed training since it currently provides the best
+
+We should always use the `nccl` backend for multi-processing distributed training since it currently provides the best
distributed training performance.
- **Single machine (node), multiple GPUs**
@@ -115,6 +111,7 @@ _**First machine**_
```shell script
python main.py --dist-url 'tcp://IP_OF_NODE1:FREEPORT' --dist-backend 'nccl' --multiprocessing-distributed --world-size 2 --rank 0
```
+
_**Second machine**_
```shell script
@@ -123,11 +120,11 @@ python main.py --dist-url 'tcp://IP_OF_NODE2:FREEPORT' --dist-backend 'nccl' --m
#### 2.3.2. Training stratergy
-The performance of the TTNet strongly depends on the global stage for ball detection. Hence, It's necessary to train the
+The performance of the TTNet strongly depends on the global stage for ball detection. Hence, It's necessary to train the
`global ball stage module` of the TTNet model first.
- **1st phase**: Train the global and segmentation modules with 30 epochs
-
+
```shell script
./train_1st_phase.sh
```
@@ -145,10 +142,11 @@ the global stage. In this phase, we train and just update weights of the local a
./train_3rd_phase.sh
```
-
#### 2.3.3. Visualizing training progress
+
The Tensorboard was used to save loss values on the training set and the validation set.
Execute the below command on the working terminal:
+
```
cd logs//tensorboard/
tensorboard --logdir=./
@@ -156,7 +154,6 @@ Execute the below command on the working terminal:
Then open the web browser and go to: [http://localhost:6006/](http://localhost:6006/)
-
#### 2.3.4. Evaluation
The thresholds of the segmentation and event spotting tasks could be set in `test.sh` bash shell scripts.
@@ -165,7 +162,7 @@ The thresholds of the segmentation and event spotting tasks could be set in `tes
./test_3rd_phase.sh
```
-#### 2.3.5. Demo:
+#### 2.3.5. Demo
Run a demonstration with an input video:
@@ -192,6 +189,7 @@ If you find any errors or have any suggestions, please contact me. Thank you!
```
## Usage
+
```
usage: main.py [-h] [--seed SEED] [--saved_fn FN] [-a ARCH] [--dropout_p P]
[--multitask_learning] [--no_local] [--no_event] [--no_seg]
@@ -323,7 +321,7 @@ optional arguments:
saved
```
-[python-image]: https://img.shields.io/badge/Python-3.6-ff69b4.svg
+[python-image]: https://img.shields.io/badge/Python-3.9-ff69b4.svg
[python-url]: https://www.python.org/
-[pytorch-image]: https://img.shields.io/badge/PyTorch-1.5-2BAF2B.svg
+[pytorch-image]: https://img.shields.io/badge/PyTorch-2.4-2BAF2B.svg
[pytorch-url]: https://pytorch.org/
diff --git a/prepare_dataset/extract_all_images.py b/prepare_dataset/extract_all_images.py
index f49079b..b3c77a6 100644
--- a/prepare_dataset/extract_all_images.py
+++ b/prepare_dataset/extract_all_images.py
@@ -16,28 +16,38 @@ def extract_images_from_videos(video_path, out_images_dir):
make_folder(sub_images_dir)
video_cap = cv2.VideoCapture(video_path)
- n_frames = video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
- f_width = video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)
- f_height = video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
- print('video_fn: {}.mp4, number of frames: {}, f_width: {}, f_height: {}'.format(video_fn, n_frames, f_width,
- f_height))
-
- frame_cnt = -1
- while True:
+ if not video_cap.isOpened():
+ print(f"Error: Cannot open video file {video_path}")
+ return
+
+ n_frames = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ f_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ f_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ print(f'Processing video: {video_fn}.mp4')
+ print(f'Number of frames: {n_frames}, Width: {f_width}, Height: {f_height}')
+
+ frame_idx = 0
+ while frame_idx < n_frames:
ret, img = video_cap.read()
- if ret:
- frame_cnt += 1
- image_path = os.path.join(sub_images_dir, 'img_{:06d}.jpg'.format(frame_cnt))
- if os.path.isfile(image_path):
- print('video {} had been already extracted'.format(video_path))
- break
- cv2.imwrite(image_path, img)
- else:
- break
- if cv2.waitKey(10) & 0xFF == ord('q'):
+ if not ret:
+ print(f"Warning: Failed to read frame {frame_idx} from video {video_path}")
break
+
+ image_path = os.path.join(sub_images_dir, f'img_{frame_idx:06d}.jpg')
+ if os.path.isfile(image_path):
+ # Image already exists, skip writing but continue extracting
+ print(f"Frame {frame_idx} already exists. Skipping...")
+ else:
+ success = cv2.imwrite(image_path, img)
+ if not success:
+ print(f"Error: Failed to write frame {frame_idx} to {image_path}")
+ # Optionally, you can choose to break or continue based on your needs
+ # break
+
+ frame_idx += 1
+
video_cap.release()
- print('done extraction: {}'.format(video_path))
+ print(f'Done extracting frames from: {video_path}')
if __name__ == '__main__':
diff --git a/requirement.txt b/requirement.txt
index f4294ef..8796e52 100644
--- a/requirement.txt
+++ b/requirement.txt
@@ -1,9 +1,12 @@
-wget==3.2
-torch==1.5.0
-torchvision==0.6.0
-easydict==1.9
-opencv-python==4.2.0.34
-numpy==1.18.3
-torchsummary==1.5.1
-tensorboard==2.2.1
-scikit-learn==0.22.2
\ No newline at end of file
+wget
+torch
+torchvision
+easydict
+opencv-python
+numpy==2.0.1
+torchsummary
+tensorboard
+scikit-learn
+tqdm
+matplotlib
+PyTurboJPEG
diff --git a/src/bash_slurm_job.sh b/src/bash_slurm_job.sh
new file mode 100644
index 0000000..2b15838
--- /dev/null
+++ b/src/bash_slurm_job.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+#SBATCH --partition=gpu
+#SBATCH --gres=gpu:1
+#SBATCH --job-name=AugustTest
+
+
diff --git a/src/config/config.py b/src/config/config.py
index 5132b1b..ab7edd8 100644
--- a/src/config/config.py
+++ b/src/config/config.py
@@ -51,15 +51,15 @@ def parse_configs():
####################################################################
############## Dataloader and Running configs #######
####################################################################
- parser.add_argument('--working-dir', type=str, default='../../', metavar='PATH',
+ parser.add_argument('--working-dir', type=str, default='../', metavar='PATH',
help='the ROOT working directory')
- parser.add_argument('--no-val', action='store_true',
+ parser.add_argument('--no_val', action='store_true',
help='If true, use all data for training, no validation set')
- parser.add_argument('--no-test', action='store_true',
+ parser.add_argument('--no_test', action='store_true',
help='If true, dont evaluate the model on the test set')
parser.add_argument('--val-size', type=float, default=0.2,
help='The size of validation set')
- parser.add_argument('--smooth-labelling', action='store_true',
+ parser.add_argument('--smooth_labelling', action='store_true',
help='If true, smoothly make the labels of event spotting')
parser.add_argument('--num_samples', type=int, default=None,
help='Take a subset of the dataset to run and debug')
@@ -69,6 +69,8 @@ def parse_configs():
help='mini-batch size (default: 8), this is the total'
'batch size of all GPUs on the current node when using'
'Data Parallel or Distributed Data Parallel')
+ parser.add_argument('--distributed', type=bool, default=False,
+ help="if its trained using multiple gpu")
parser.add_argument('--print_freq', type=int, default=50, metavar='N',
help='print frequency (default: 50)')
parser.add_argument('--checkpoint_freq', type=int, default=2, metavar='N',
@@ -131,19 +133,19 @@ def parse_configs():
####################################################################
############## Distributed Data Parallel ############
####################################################################
- parser.add_argument('--world-size', default=-1, type=int, metavar='N',
+ parser.add_argument('--world_size', default=-1, type=int, metavar='N',
help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int, metavar='N',
help='node rank for distributed training')
- parser.add_argument('--dist-url', default='tcp://127.0.0.1:29500', type=str,
+ parser.add_argument('--dist_url', default='tcp://127.0.0.1:29500', type=str,
help='url used to set up distributed training')
- parser.add_argument('--dist-backend', default='nccl', type=str,
+ parser.add_argument('--dist_backend', default='nccl', type=str,
help='distributed backend')
parser.add_argument('--gpu_idx', default=None, type=int,
help='GPU index to use.')
parser.add_argument('--no_cuda', action='store_true',
help='If true, cuda is not used.')
- parser.add_argument('--multiprocessing-distributed', action='store_true',
+ parser.add_argument('--multiprocessing_distributed', action='store_true',
help='Use multi-processing distributed training to launch '
'N processes per node, which has N GPUs. This is the '
'fastest way to use PyTorch for either single node or '
@@ -204,6 +206,7 @@ def parse_configs():
configs.events_weights_loss = (configs.events_weights_loss_dict['bounce'], configs.events_weights_loss_dict['net'])
configs.num_events = len(configs.events_weights_loss_dict) # Just "bounce" and "net hits"
configs.num_frames_sequence = 9
+ configs.interval_between_frames = 5
configs.org_size = (1920, 1080)
configs.input_size = (320, 128)
diff --git a/src/data_process/transformation.py b/src/data_process/transformation.py
index 96169b3..31f5f80 100644
--- a/src/data_process/transformation.py
+++ b/src/data_process/transformation.py
@@ -159,3 +159,75 @@ def __call__(self, imgs, ball_position_xy, seg_img):
ball_position_xy[0] = w - ball_position_xy[0]
return imgs, ball_position_xy, seg_img
+
+
+import random
+import numpy as np
+import cv2
+
+class Random_Ball_Mask:
+ def __init__(self, mask_size=(20, 20), p=0.5, mask_type='mean'):
+ """
+ Args:
+ mask_size (tuple): Height and width of the mask area (blackout area).
+ p (float): Probability of applying the mask.
+ mask_type (str): Type of mask ('zero', 'noise', 'mean').
+ """
+ self.mask_size = mask_size
+ self.p = p
+ self.mask_type = mask_type
+
+ def __call__(self, imgs, ball_position_xy, seg_img):
+ """
+ Args:
+ imgs : Numpy array of shape [H, W, num_frames].
+ ball_position_xy (numpy): (x, y) ball position for the labeled frame.
+ seg_img: Corresponding segmentation mask.
+
+ Returns:
+ Tuple containing:
+ - masked_imgs: Numpy array with masked frames.
+ - ball_position_xy: Updated ball position.
+ - seg_img: Unmodified segmentation image.
+ """
+ H, W, num_frames = imgs.shape # Extract shape from stacked array
+
+ # Ensure the mask size is valid
+ mask_h = random.randint(max(1, self.mask_size[0] - 10), self.mask_size[0] + 10)
+ mask_w = random.randint(max(1, self.mask_size[1] - 10), self.mask_size[1] + 10)
+
+ # Iterate over all frames and apply masking with some probability
+ for i in range(num_frames):
+ if random.random() <= self.p:
+ if i == num_frames - 1:
+ # Use the given ball position for the last frame
+ x, y = int(ball_position_xy[0]), int(ball_position_xy[1])
+ else:
+ # Apply mask at a random position for non-labeled frames
+ x = random.randint(0, W - mask_w)
+ y = random.randint(0, H - mask_h)
+
+ # Ensure the mask is within the image boundaries
+ top = max(0, min(H - mask_h, y - mask_h // 2))
+ left = max(0, min(W - mask_w, x - mask_w // 2))
+
+ # Check if the selected region has valid pixels
+ region = imgs[top:top + mask_h, left:left + mask_w, i]
+ if region.size == 0:
+ print(f"Warning: Empty slice for frame {i}. Skipping mask.")
+ continue
+
+ # Apply the chosen mask type
+ if self.mask_type == 'zero':
+ imgs[top:top + mask_h, left:left + mask_w, i] = 0
+
+ elif self.mask_type == 'noise':
+ noise = np.random.randn(mask_h, mask_w) * 255 # Generate noise
+ imgs[top:top + mask_h, left:left + mask_w, i] = noise.clip(0, 255)
+
+ elif self.mask_type == 'mean':
+ mean_value = np.nanmean(region) # Handle empty slices safely
+ noise = np.random.randn(mask_h, mask_w) * 10 # Small noise
+ imgs[top:top + mask_h, left:left + mask_w, i] = (mean_value + noise).clip(0, 255)
+
+ return imgs, ball_position_xy, seg_img
diff --git a/src/data_process/ttnet_data_utils.py b/src/data_process/ttnet_data_utils.py
index 91c38f0..1a666ce 100644
--- a/src/data_process/ttnet_data_utils.py
+++ b/src/data_process/ttnet_data_utils.py
@@ -60,6 +60,34 @@ def create_target_ball(ball_position_xy, sigma, w, h, thresh_mask, device):
return target_ball_position
+def create_target_ball_right(ball_position_xy, sigma, w, h, thresh_mask, device):
+ """Create target for the ball detection stages
+
+ :param ball_position_xy: Position of the ball (x,y)
+ :param sigma: standard deviation (a hyperparameter)
+ :param w: width of the resize image
+ :param h: height of the resize image
+ :param thresh_mask: if values of 1D Gaussian < thresh_mask --> set to 0 to reduce computation
+ :param device: cuda() or cpu()
+ :return:
+ """
+ w, h = int(w), int(h)
+ target_ball_position_x = torch.zeros(w, device=device)
+ target_ball_position_y = torch.zeros(h, device=device)
+ # Only do the next step if the ball is existed
+ if (w > ball_position_xy[0] > 0) and (h > ball_position_xy[1] > 0):
+ # For x
+ x_pos = torch.arange(0, w, device=device)
+ target_ball_position_x = gaussian_1d(x_pos, ball_position_xy[0], sigma=sigma)
+ # For y
+ y_pos = torch.arange(0, h, device=device)
+ target_ball_position_y = gaussian_1d(y_pos, ball_position_xy[1], sigma=sigma)
+
+ target_ball_position_x[target_ball_position_x < thresh_mask] = 0.
+ target_ball_position_y[target_ball_position_y < thresh_mask] = 0.
+
+ return target_ball_position_x, target_ball_position_y
+
def smooth_event_labelling(event_class, smooth_idx, event_frameidx):
target_events = np.zeros((2,))
@@ -70,6 +98,7 @@ def smooth_event_labelling(event_class, smooth_idx, event_frameidx):
return target_events
+
def get_events_infor(game_list, configs, dataset_type):
"""Get information of sequences of images based on events
@@ -110,13 +139,14 @@ def get_events_infor(game_list, configs, dataset_type):
for sub_smooth_idx in sub_smooth_frame_indices:
img_path = os.path.join(images_dir, game_name, 'img_{:06d}.jpg'.format(sub_smooth_idx))
img_path_list.append(img_path)
+
last_f_idx = smooth_idx + num_frames_from_event
# Get ball position for the last frame in the sequence
if '{}'.format(last_f_idx) not in ball_annos.keys():
print('smooth_idx: {} - no ball position for the frame idx {}'.format(smooth_idx, last_f_idx))
continue
ball_position_xy = ball_annos['{}'.format(last_f_idx)]
- ball_position_xy = np.array([ball_position_xy['x'], ball_position_xy['y']], dtype=np.int)
+ ball_position_xy = np.array([ball_position_xy['x'], ball_position_xy['y']], dtype=int)
# Ignore the event without ball information
if (ball_position_xy[0] < 0) or (ball_position_xy[1] < 0):
continue
@@ -134,9 +164,11 @@ def get_events_infor(game_list, configs, dataset_type):
if (target_events[0] == 0) and (target_events[1] == 0):
event_class = 2
events_labels.append(event_class)
+
return events_infor, events_labels
+
def train_val_data_separation(configs):
"""Seperate data to training and validation sets"""
dataset_type = 'training'
@@ -155,6 +187,23 @@ def train_val_data_separation(configs):
stratify=events_labels)
return train_events_infor, val_events_infor, train_events_labels, val_events_labels
+def train_val_data_separation_detection(configs):
+ """Seperate data to training and validation sets"""
+ dataset_type = 'training'
+ events_infor, events_labels = get_events_infor(configs.train_game_list, configs, dataset_type)
+ if configs.no_val:
+ train_events_infor = events_infor
+ train_events_labels = events_labels
+ val_events_infor = None
+ val_events_labels = None
+ else:
+ train_events_infor, val_events_infor, train_events_labels, val_events_labels = train_test_split(events_infor,
+ events_labels,
+ shuffle=True,
+ test_size=configs.val_size,
+ random_state=configs.seed,
+ stratify=events_labels)
+ return train_events_infor, val_events_infor, train_events_labels, val_events_labels
if __name__ == '__main__':
from config.config import parse_configs
diff --git a/src/data_process/ttnet_dataloader.py b/src/data_process/ttnet_dataloader.py
index a979e38..816c71f 100644
--- a/src/data_process/ttnet_dataloader.py
+++ b/src/data_process/ttnet_dataloader.py
@@ -12,22 +12,23 @@
import sys
import torch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Subset
sys.path.append('../')
-from data_process.ttnet_dataset import TTNet_Dataset
+from data_process.ttnet_dataset import TTNet_Dataset, Occlusion_Dataset
from data_process.ttnet_data_utils import get_events_infor, train_val_data_separation
-from data_process.transformation import Compose, Random_Crop, Resize, Normalize, Random_Rotate, Random_HFlip
+from data_process.transformation import Compose, Random_Crop, Resize, Normalize, Random_Rotate, Random_HFlip, Random_Ball_Mask
def create_train_val_dataloader(configs):
"""Create dataloader for training and validate"""
train_transform = Compose([
- Random_Crop(max_reduction_percent=0.15, p=0.5),
- Random_HFlip(p=0.5),
- Random_Rotate(rotation_angle_limit=10, p=0.5),
+ # Random_Crop(max_reduction_percent=0.15, p=0.5),
+ # Random_HFlip(p=0.5),
+ # Random_Rotate(rotation_angle_limit=10, p=0.5),
+ Random_Ball_Mask(mask_size=(128//20, 320//20), p=0.25),
], p=1.)
train_events_infor, val_events_infor, *_ = train_val_data_separation(configs)
@@ -41,7 +42,10 @@ def create_train_val_dataloader(configs):
val_dataloader = None
if not configs.no_val:
- val_transform = None
+
+ val_transform = Compose([
+ Random_Ball_Mask(mask_size=(128//20, 320//20), p=0.25),
+ ], p=1.)
val_sampler = None
val_dataset = TTNet_Dataset(val_events_infor, configs.org_size, configs.input_size, transform=val_transform,
num_samples=configs.num_samples)
@@ -56,7 +60,9 @@ def create_train_val_dataloader(configs):
def create_test_dataloader(configs):
"""Create dataloader for testing phase"""
- test_transform = None
+ test_transform = Compose([
+ Random_Ball_Mask(mask_size=(128//20, 320//20), p=1.0),
+ ], p=1.)
dataset_type = 'test'
test_events_infor, test_events_labels = get_events_infor(configs.test_game_list, configs, dataset_type)
test_dataset = TTNet_Dataset(test_events_infor, configs.org_size, configs.input_size, transform=test_transform,
@@ -70,10 +76,92 @@ def create_test_dataloader(configs):
return test_dataloader
+def create_occlusion_train_val_dataloader(configs, subset_size=None):
+ """Create dataloader for training and validation, with an option to use a subset of the data."""
+
+ train_transform = Compose([
+ Resize(new_size=configs.img_size, p=1.0),
+ Random_Ball_Mask(mask_size=(128//20, 320//20), p=0.25),
+ ], p=1.)
+
+ # Load train and validation data information
+ train_events_infor, val_events_infor, train_events_label, val_events_label = train_val_data_separation(configs)
+
+ # Create train dataset
+ train_dataset = Occlusion_Dataset(train_events_infor, train_events_label, transform=train_transform,
+ num_samples=configs.num_samples)
+
+ # If subset_size is provided, create a subset for training
+ if subset_size is not None:
+ train_indices = torch.randperm(len(train_dataset))[:subset_size].tolist()
+ train_dataset = Subset(train_dataset, train_indices)
+
+ # Create train sampler if distributed
+ train_sampler = None
+ if configs.distributed:
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+
+ # Create train dataloader
+ train_dataloader = DataLoader(train_dataset, batch_size=configs.batch_size, shuffle=(train_sampler is None),
+ pin_memory=configs.pin_memory, num_workers=configs.num_workers,
+ sampler=train_sampler, drop_last=True)
+
+ # Create validation dataloader (without transformations)
+ val_dataloader = None
+ if not configs.no_val:
+ val_transform = Compose([
+ Resize(new_size=configs.img_size, p=1.0),
+ Random_Ball_Mask(mask_size=(128//5, 320//5), p=0.5),
+ ], p=1.)
+ val_dataset = Occlusion_Dataset(val_events_infor, val_events_label, transform=val_transform,
+ num_samples=configs.num_samples)
+
+ # If subset_size is provided, create a subset for validation
+ if subset_size is not None:
+ val_indices = torch.randperm(len(val_dataset))[:subset_size].tolist()
+ val_dataset = Subset(val_dataset, val_indices)
+
+ # Create validation sampler if distributed
+ val_sampler = None
+ if configs.distributed:
+ val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
+
+ # Create validation dataloader
+ val_dataloader = DataLoader(val_dataset, batch_size=configs.batch_size, shuffle=False,
+ pin_memory=configs.pin_memory, num_workers=configs.num_workers, sampler=val_sampler, drop_last=True)
+
+ return train_dataloader, val_dataloader, train_sampler
+
+
if __name__ == '__main__':
from config.config import parse_configs
-
configs = parse_configs()
configs.distributed = False # For testing
+
+ # Create dataloaders
train_dataloader, val_dataloader, train_sampler = create_train_val_dataloader(configs)
print('len train_dataloader: {}, val_dataloader: {}'.format(len(train_dataloader), len(val_dataloader)))
+
+ test_dataloader = create_test_dataloader(configs)
+ print(f"len test_loader {len(test_dataloader)}")
+
+ # Get one batch from train_dataloader
+ for batch in train_dataloader:
+ # Assuming batch contains both input data and labels
+ inputs, labels = batch
+ print(f"Train batch data shape: {inputs.shape}")
+ print(f"Train batch labels shape: {labels.shape}")
+ break # Exit after printing the first batch
+
+ # Get one batch from val_dataloader
+ for batch in val_dataloader:
+ inputs, labels = batch
+ print(f"Val batch data shape: {inputs.shape}")
+ print(f"Val batch labels shape: {labels.shape}")
+ break
+
+ # Get one batch from test_dataloader
+ for batch in test_dataloader:
+ # Test dataloader might have only inputs
+ print(f"Test batch data shape: {batch.shape}")
+ break
diff --git a/src/data_process/ttnet_dataset.py b/src/data_process/ttnet_dataset.py
index 6515cd1..e3a15b4 100644
--- a/src/data_process/ttnet_dataset.py
+++ b/src/data_process/ttnet_dataset.py
@@ -75,15 +75,62 @@ def __getitem__(self, index):
# Transpose (H, W, C) to (C, H, W) --> fit input of Pytorch model
resized_imgs = resized_imgs.transpose(2, 0, 1)
- target_seg = seg_img.transpose(2, 0, 1).astype(np.float)
+ target_seg = seg_img.transpose(2, 0, 1).astype(float)
# Segmentation mask should be 0 or 1
target_seg[target_seg < 75] = 0.
target_seg[target_seg >= 75] = 1.
- return resized_imgs, org_ball_pos_xy.astype(np.int), global_ball_pos_xy.astype(np.int), \
+ return resized_imgs, org_ball_pos_xy.astype(int), global_ball_pos_xy.astype(int), \
target_events, target_seg
+class Occlusion_Dataset(Dataset):
+ def __init__(self, events_infor, events_label, transform=None, num_samples=None):
+ self.events_infor = events_infor
+ self.events_label = events_label
+ self.transform = transform
+
+ if num_samples is not None:
+ self.events_infor = self.events_infor[:num_samples]
+
+ def __len__(self):
+ return len(self.events_infor)
+
+ def __getitem__(self, index):
+ img_path_list = self.events_infor[index]
+ ball_xy = self.events_label[index]
+ imgs = []
+ for img_path in img_path_list:
+ img = cv2.imread(img_path)
+
+ if img is None:
+ raise ValueError(f"Image not found or can't be read at path: {img_path}")
+ imgs.append(img)
+ # Apply augmentation
+ if self.transform:
+ imgs, ball_xy= self.transform(imgs, ball_xy)
+
+ converted_imgs = []
+ for img in imgs:
+ # after transform all images will be in shape (H, W, C)
+ img = np.transpose(img, (2, 0, 1)) # Now img is (C, H, W)
+ converted_imgs.append(img)
+ # stack them to form the shape (1,num_frames, C, H, W)
+ # numpy_imgs = np.stack(converted_imgs, axis=0) # Stack along the new axis (N)
+ # convert them into pairs formation
+ # add a padded frame so the number is equal and can be processed with, only when the images is in odd length
+ image_list=[]
+
+ masked_frameid = len(converted_imgs)//2
+ i = 0
+ while i < len(converted_imgs):
+ image_list.append(np.array(converted_imgs[i]))
+ i+=1
+
+ image_list_np = np.array(image_list)
+ return image_list_np, (masked_frameid, np.array(ball_xy.astype(int)))
+
+
if __name__ == '__main__':
import cv2
import matplotlib.pyplot as plt
diff --git a/src/demo.py b/src/demo.py
index 767cd1f..84e6729 100644
--- a/src/demo.py
+++ b/src/demo.py
@@ -16,6 +16,7 @@
import cv2
import numpy as np
import torch
+import time
sys.path.append('./')
@@ -79,8 +80,10 @@ def demo(configs):
ploted_img = cv2.cvtColor(ploted_img, cv2.COLOR_RGB2BGR)
if configs.show_image:
- cv2.imshow('ploted_img', ploted_img)
- cv2.waitKey(10)
+ # cv2.imshow('ploted_img', ploted_img)
+ cv2.imwrite('ploted_img.png', ploted_img)
+ # cv2.waitKey(10)
+ time.sleep(0.01)
if configs.save_demo_output:
cv2.imwrite(os.path.join(configs.frame_dir, '{:06d}.jpg'.format(frame_idx)), ploted_img)
diff --git a/src/demo.sh b/src/demo.sh
index d599cb4..048b0bc 100755
--- a/src/demo.sh
+++ b/src/demo.sh
@@ -9,6 +9,6 @@ python demo.py \
--seg_thresh 0.5 \
--event_thresh 0.5 \
--thresh_ball_pos_mask 0.05 \
- --video_path ../dataset/test/videos/test_6.mp4 \
+ --video_path ../dataset/test/videos/test_1.mp4 \
--show_image \
--save_demo_output
\ No newline at end of file
diff --git a/src/losses/losses.py b/src/losses/losses.py
index ef8b66a..07d176e 100644
--- a/src/losses/losses.py
+++ b/src/losses/losses.py
@@ -21,6 +21,35 @@ def forward(self, pred_ball_position, target_ball_position):
return loss_ball_x + loss_ball_y
+class Ball_Detection_Loss_right(nn.Module):
+ def __init__(self, w, h, epsilon=1e-9):
+ super(Ball_Detection_Loss_right, self).__init__()
+ self.w = w
+ self.h = h
+ self.epsilon = epsilon
+
+ def forward(self, pred_ball_position, target_ball_position):
+ # currently the pred_ball_position and target_ball_position is [8*([320],[128])]
+ loss_total = 0.0 # Initialize total loss
+ batch_size = len(pred_ball_position) # Determine the batch size
+
+ for (pred_ball, target_ball) in zip(pred_ball_position, target_ball_position):
+ x_pred = pred_ball[0]
+ y_pred = pred_ball[1]
+
+ x_target = target_ball[0]
+ y_target = target_ball[1]
+
+
+ loss_ball_x = - torch.mean(x_target * torch.log(x_pred + self.epsilon) + (1 - x_target) * torch.log(1 - x_pred + self.epsilon))
+ loss_ball_y = - torch.mean(y_target * torch.log(y_pred + self.epsilon) + (1 - y_target) * torch.log(1 - y_pred + self.epsilon))
+
+ # Accumulate the loss
+ loss_total += (loss_ball_x + loss_ball_y)
+
+ # Return the average loss over the batch
+ return loss_total / batch_size
+
class Events_Spotting_Loss(nn.Module):
def __init__(self, weights=(1, 3), num_events=2, epsilon=1e-9):
diff --git a/src/main.py b/src/main.py
index 8948696..00d4f24 100644
--- a/src/main.py
+++ b/src/main.py
@@ -14,19 +14,25 @@
sys.path.append('./')
-from data_process.ttnet_dataloader import create_train_val_dataloader, create_test_dataloader
+from data_process.ttnet_dataloader import create_train_val_dataloader, create_test_dataloader, create_occlusion_train_val_dataloader
from models.model_utils import create_model, load_pretrained_model, make_data_parallel, resume_model, get_num_parameters
from models.model_utils import freeze_model
from utils.train_utils import create_optimizer, create_lr_scheduler, get_saved_state, save_checkpoint
from utils.train_utils import reduce_tensor, to_python_float
from utils.misc import AverageMeter, ProgressMeter
from utils.logger import Logger
+from utils.post_processing import get_prediction_ball_pos_right
from config.config import parse_configs
def main():
configs = parse_configs()
+ if torch.cuda.is_available():
+ print(f"Number of GPUs: {torch.cuda.device_count()}")
+ for i in range(torch.cuda.device_count()):
+ print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
+
# Re-produce results
if configs.seed is not None:
random.seed(configs.seed)
@@ -123,6 +129,7 @@ def main_worker(gpu_idx, configs):
logger.info(">>> Loading dataset & getting dataloader...")
# Create dataloader
train_loader, val_loader, train_sampler = create_train_val_dataloader(configs)
+
test_loader = create_test_dataloader(configs)
if logger is not None:
logger.info('number of batches in train set: {}'.format(len(train_loader)))
@@ -197,6 +204,7 @@ def cleanup():
def train_one_epoch(train_loader, model, optimizer, epoch, configs, logger):
+ configs = parse_configs()
batch_time = AverageMeter('Time', ':6.3f')
data_time = AverageMeter('Data', ':6.3f')
losses = AverageMeter('Loss', ':.4e')
@@ -209,12 +217,14 @@ def train_one_epoch(train_loader, model, optimizer, epoch, configs, logger):
start_time = time.time()
for batch_idx, (resized_imgs, org_ball_pos_xy, global_ball_pos_xy, target_events, target_seg) in enumerate(
tqdm(train_loader)):
+
data_time.update(time.time() - start_time)
batch_size = resized_imgs.size(0)
target_seg = target_seg.to(configs.device, non_blocking=True)
resized_imgs = resized_imgs.to(configs.device, non_blocking=True).float()
pred_ball_global, pred_ball_local, pred_events, pred_seg, local_ball_pos_xy, total_loss, _ = model(
resized_imgs, org_ball_pos_xy, global_ball_pos_xy, target_events, target_seg)
+
# For torch.nn.DataParallel case
if (not configs.distributed) and (configs.gpu_idx is None):
total_loss = torch.mean(total_loss)
@@ -263,7 +273,6 @@ def evaluate_one_epoch(val_loader, model, epoch, configs, logger):
resized_imgs = resized_imgs.to(configs.device, non_blocking=True).float()
pred_ball_global, pred_ball_local, pred_events, pred_seg, local_ball_pos_xy, total_loss, _ = model(
resized_imgs, org_ball_pos_xy, global_ball_pos_xy, target_events, target_seg)
-
# For torch.nn.DataParallel case
if (not configs.distributed) and (configs.gpu_idx is None):
total_loss = torch.mean(total_loss)
diff --git a/src/models/TTNet.py b/src/models/TTNet.py
index daae31c..0bacd0e 100644
--- a/src/models/TTNet.py
+++ b/src/models/TTNet.py
@@ -95,7 +95,50 @@ def forward(self, x):
x = self.dropout1d(self.relu(self.fc1(x)))
x = self.dropout1d(self.relu(self.fc2(x)))
out = self.sigmoid(self.fc3(x))
+ return out, features, out_block2, out_block3, out_block4, out_block5
+class BallDetection_right(nn.Module):
+ def __init__(self, num_frames_sequence, dropout_p):
+ super(BallDetection_right, self).__init__()
+ self.conv1 = nn.Conv2d(num_frames_sequence * 3, 64, kernel_size=1, stride=1, padding=0)
+ self.batchnorm = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU()
+ self.convblock1 = ConvBlock(in_channels=64, out_channels=64)
+ self.convblock2 = ConvBlock(in_channels=64, out_channels=64)
+ self.dropout2d = nn.Dropout2d(p=dropout_p)
+ self.convblock3 = ConvBlock(in_channels=64, out_channels=128)
+ self.convblock4 = ConvBlock(in_channels=128, out_channels=128)
+ self.convblock5 = ConvBlock(in_channels=128, out_channels=256)
+ self.convblock6 = ConvBlock(in_channels=256, out_channels=256)
+ self.fc1 = nn.Linear(in_features=2560, out_features=1792)
+ self.fcx1 = nn.Linear(in_features=1792, out_features=640)
+ self.fcy1 = nn.Linear(in_features=1792, out_features=256)
+ self.fcx2 = nn.Linear(in_features=640, out_features=320)
+ self.fcy2 = nn.Linear(in_features=256, out_features=128)
+ self.dropout1d = nn.Dropout(p=dropout_p)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ x = self.relu(self.batchnorm(self.conv1(x)))
+ out_block2 = self.convblock2(self.convblock1(x))
+ x = self.dropout2d(out_block2)
+ out_block3 = self.convblock3(x)
+ out_block4 = self.convblock4(out_block3)
+ x = self.dropout2d(out_block4)
+ out_block5 = self.convblock5(out_block4)
+ features = self.convblock6(out_block5)
+
+ x = self.dropout2d(features)
+ x = x.contiguous().view(x.size(0), -1)
+ # input of fc1 is 2560 output is 1792, now x is a vector in shape 1792
+ feature = self.dropout1d(self.relu(self.fc1(x)))
+ # then parallel mode, makeing 1792 to 640 and 256
+ x = self.dropout1d(self.relu(self.fcx1(feature)))
+ y = self.dropout1d(self.relu(self.fcy1(feature)))
+ # now finally
+ coordx = self.sigmoid(self.fcx2(x))
+ coordy = self.sigmoid(self.fcy2(y))
+ out = (coordx, coordy)
return out, features, out_block2, out_block3, out_block4, out_block5
@@ -166,9 +209,11 @@ def __init__(self, dropout_p, tasks, input_size, thresh_ball_pos_mask, num_frame
super(TTNet, self).__init__()
self.tasks = tasks
self.ball_local_stage, self.events_spotting, self.segmentation = None, None, None
- self.ball_global_stage = BallDetection(num_frames_sequence=num_frames_sequence, dropout_p=dropout_p)
+ # self.ball_global_stage = BallDetection(num_frames_sequence=num_frames_sequence, dropout_p=dropout_p)
+ self.ball_global_stage = BallDetection_right(num_frames_sequence=num_frames_sequence, dropout_p=dropout_p)
if 'local' in tasks:
- self.ball_local_stage = BallDetection(num_frames_sequence=num_frames_sequence, dropout_p=dropout_p)
+ # self.ball_local_stage = BallDetection(num_frames_sequence=num_frames_sequence, dropout_p=dropout_p)
+ self.ball_local_stage = BallDetection_right(num_frames_sequence=num_frames_sequence, dropout_p=dropout_p)
if 'event' in tasks:
self.events_spotting = EventsSpotting(dropout_p=dropout_p)
if 'seg' in tasks:
@@ -188,21 +233,26 @@ def forward(self, resize_batch_input, org_ball_pos_xy):
pred_ball_local, pred_events, pred_seg, local_ball_pos_xy = None, None, None, None
# Normalize the input before compute forward propagation
+ # pred_ball_global, global_features, out_block2, out_block3, out_block4, out_block5 = self.ball_global_stage(
+ # self.__normalize__(resize_batch_input))
pred_ball_global, global_features, out_block2, out_block3, out_block4, out_block5 = self.ball_global_stage(
self.__normalize__(resize_batch_input))
+
if self.ball_local_stage is not None:
# Based on the prediction of the global stage, crop the original images
- input_ball_local, cropped_params = self.__crop_original_batch__(resize_batch_input, pred_ball_global)
+ input_ball_local, cropped_params = self.__crop_original_batch_right__(resize_batch_input, pred_ball_global)
# Get the ground truth of the ball for the local stage
local_ball_pos_xy = self.__get_groundtruth_local_ball_pos__(org_ball_pos_xy, cropped_params)
+
# Normalize the input before compute forward propagation
pred_ball_local, local_features, *_ = self.ball_local_stage(self.__normalize__(input_ball_local))
+
# Only consider the events spotting if the model has the local stage for ball detection
if self.events_spotting is not None:
pred_events = self.events_spotting(global_features, local_features)
if self.segmentation is not None:
pred_seg = self.segmentation(out_block2, out_block3, out_block4, out_block5)
-
+
return pred_ball_global, pred_ball_local, pred_events, pred_seg, local_ball_pos_xy
def run_demo(self, resize_batch_input):
@@ -228,6 +278,7 @@ def __normalize__(self, x):
def __get_groundtruth_local_ball_pos__(self, org_ball_pos_xy, cropped_params):
local_ball_pos_xy = torch.zeros_like(org_ball_pos_xy) # no grad for torch.zeros_like output
+
for idx, params in enumerate(cropped_params):
is_ball_detected, x_min, x_max, y_min, y_max, x_pad, y_pad = params
@@ -239,11 +290,13 @@ def __get_groundtruth_local_ball_pos__(self, org_ball_pos_xy, cropped_params):
# If the ball is outside of the cropped image --> set position to -1, -1 --> No ball
if (local_ball_pos_xy[idx, 0] >= self.w_resize) or (local_ball_pos_xy[idx, 1] >= self.h_resize) or (
local_ball_pos_xy[idx, 0] < 0) or (local_ball_pos_xy[idx, 1] < 0):
+
local_ball_pos_xy[idx, 0] = -1
local_ball_pos_xy[idx, 1] = -1
else:
local_ball_pos_xy[idx, 0] = -1
local_ball_pos_xy[idx, 1] = -1
+
return local_ball_pos_xy
def __crop_original_batch__(self, resize_batch_input, pred_ball_global):
@@ -311,6 +364,74 @@ def __get_crop_params__(self, x_center, y_center, w_resize, h_resize, w_original
y_max = min(h_original, y_min + h_resize)
return x_min, x_max, y_min, y_max
+
+
+ def __crop_original_batch_right__(self, resize_batch_input, pred_ball_global):
+ """Get input of the local stage by cropping the original images based on the predicted ball position
+ of the global stage
+ :param resize_batch_input: (batch_size, 27, 128, 320)
+ :param pred_ball_global: ((batch_size, 320),(batch_size, 128))
+ :param org_ball_pos_xy: (batch_size, 2)
+ :return: input_ball_local (batch_size, 27, 128, 320)
+ """
+ # Process input for local stage based on output of the global one
+ # converted_pred_ball_global is in shape [batch_size*([320],[128])]
+ converted_pred_ball_global = [(pred_ball_global[0][i], pred_ball_global[1][i]) for i in range(pred_ball_global[0].shape[0])]
+
+ batch_size = resize_batch_input.size(0)
+ h_original, w_original = 1080, 1920
+ h_ratio = h_original / self.h_resize
+ w_ratio = w_original / self.w_resize
+ for pred_ball_global_mask_coords in converted_pred_ball_global:
+ pred_ball_global_mask_coords = list(pred_ball_global_mask_coords)
+ pred_ball_global_mask_coords_x = pred_ball_global_mask_coords[0].clone().detach()
+ pred_ball_global_mask_coords_y = pred_ball_global_mask_coords[1].clone().detach()
+ pred_ball_global_mask_coords_x[pred_ball_global_mask_coords_x < self.thresh_ball_pos_mask] = 0.
+ pred_ball_global_mask_coords_y[pred_ball_global_mask_coords_y < self.thresh_ball_pos_mask] = 0.
+ pred_ball_global_mask_coords[0] = pred_ball_global_mask_coords_x
+ pred_ball_global_mask_coords[1] = pred_ball_global_mask_coords_y
+
+ # Crop the original images
+ input_ball_local = torch.zeros_like(resize_batch_input) # same shape with resize_batch_input, no grad
+ original_batch_input = F.interpolate(resize_batch_input, (h_original, w_original)) # On GPU
+ cropped_params = []
+ for idx in range(batch_size):
+ pred_ball_global_mask_coords = converted_pred_ball_global[idx]
+ pred_ball_pos_x = pred_ball_global_mask_coords[0]
+ pred_ball_pos_y = pred_ball_global_mask_coords[1]
+
+ # If the ball is not detected, we crop the center of the images, set ball_poss to [-1, -1]
+ if (torch.sum(pred_ball_pos_x) == 0.) or (torch.sum(pred_ball_pos_y) == 0.):
+ # Assume the ball is in the center image
+ x_center = int(self.w_resize / 2)
+ y_center = int(self.h_resize / 2)
+ is_ball_detected = False
+ else:
+ x_center = torch.argmax(pred_ball_pos_x) # Upper part
+ y_center = torch.argmax(pred_ball_pos_y) # Lower part
+ is_ball_detected = True
+
+ # Adjust ball position to the original size
+ x_center = int(x_center * w_ratio)
+ y_center = int(y_center * h_ratio)
+
+ x_min, x_max, y_min, y_max = self.__get_crop_params__(x_center, y_center, self.w_resize, self.h_resize,
+ w_original, h_original)
+ # Put image to the center
+ h_crop = y_max - y_min
+ w_crop = x_max - x_min
+ x_pad = 0
+ y_pad = 0
+ if (h_crop != self.h_resize) or (w_crop != self.w_resize):
+ x_pad = int((self.w_resize - w_crop) / 2)
+ y_pad = int((self.h_resize - h_crop) / 2)
+ input_ball_local[idx, :, y_pad:(y_pad + h_crop), x_pad:(x_pad + w_crop)] = original_batch_input[idx, :,
+ y_min:y_max, x_min: x_max]
+ else:
+ input_ball_local[idx, :, :, :] = original_batch_input[idx, :, y_min:y_max, x_min: x_max]
+ cropped_params.append([is_ball_detected, x_min, x_max, y_min, y_max, x_pad, y_pad])
+
+ return input_ball_local, cropped_params
if __name__ == '__main__':
diff --git a/src/models/unbalanced_loss_model.py b/src/models/unbalanced_loss_model.py
index 6578036..5c7faf1 100644
--- a/src/models/unbalanced_loss_model.py
+++ b/src/models/unbalanced_loss_model.py
@@ -16,9 +16,10 @@
sys.path.append('../')
-from losses.losses import Ball_Detection_Loss, Events_Spotting_Loss, Segmentation_Loss
-from data_process.ttnet_data_utils import create_target_ball
-
+from losses.losses import Ball_Detection_Loss, Events_Spotting_Loss, Segmentation_Loss, Ball_Detection_Loss_right
+from data_process.ttnet_data_utils import create_target_ball, create_target_ball_right
+from utils.post_processing import get_prediction_ball_pos_right
+from config.config import parse_configs
class Unbalance_Loss_Model(nn.Module):
def __init__(self, model, tasks_loss_weight, weights_events, input_size, sigma, thresh_ball_pos_mask, device):
@@ -32,34 +33,62 @@ def __init__(self, model, tasks_loss_weight, weights_events, input_size, sigma,
self.sigma = sigma
self.thresh_ball_pos_mask = thresh_ball_pos_mask
self.device = device
- self.ball_loss_criterion = Ball_Detection_Loss(self.w, self.h)
+ # self.ball_loss_criterion = Ball_Detection_Loss(self.w, self.h)
+ self.ball_loss_criterion = Ball_Detection_Loss_right(self.w, self.h)
self.event_loss_criterion = Events_Spotting_Loss(weights=weights_events, num_events=self.num_events)
self.seg_loss_criterion = Segmentation_Loss()
+ self.configs = parse_configs()
+
def forward(self, resize_batch_input, org_ball_pos_xy, global_ball_pos_xy, target_events, target_seg):
pred_ball_global, pred_ball_local, pred_events, pred_seg, local_ball_pos_xy = self.model(resize_batch_input,
org_ball_pos_xy)
# Create target for events spotting and ball position (local and global)
- batch_size = pred_ball_global.size(0)
- target_ball_global = torch.zeros_like(pred_ball_global)
+ # batch_size = pred_ball_global.size(0)
+ # target_ball_global = torch.zeros_like(pred_ball_global)
+ # task_idx = 0
+ # for sample_idx in range(batch_size):
+ # target_ball_global[sample_idx] = create_target_ball(global_ball_pos_xy[sample_idx], sigma=self.sigma,
+ # w=self.w, h=self.h,
+ # thresh_mask=self.thresh_ball_pos_mask,
+ # device=self.device)
+
+ # create a tensor which in is a list of lists [batch_size*([320],[128])] where the first is the x second is the y
+
+ converted_pred_ball_global = [(pred_ball_global[0][i], pred_ball_global[1][i]) for i in range(pred_ball_global[0].shape[0])]
+ batch_size = len(converted_pred_ball_global)
+
+ target_ball_global_x = torch.zeros_like(pred_ball_global[0])
+ target_ball_global_y = torch.zeros_like(pred_ball_global[1])
+ # Create a list of tuples for each batch
+ target_ball_global = [(target_ball_global_x[i], target_ball_global_y[i]) for i in range(batch_size)]
+
task_idx = 0
for sample_idx in range(batch_size):
- target_ball_global[sample_idx] = create_target_ball(global_ball_pos_xy[sample_idx], sigma=self.sigma,
+ target_ball_global[sample_idx] = create_target_ball_right(global_ball_pos_xy[sample_idx], sigma=self.sigma,
w=self.w, h=self.h,
thresh_mask=self.thresh_ball_pos_mask,
device=self.device)
- global_ball_loss = self.ball_loss_criterion(pred_ball_global, target_ball_global)
- total_loss = global_ball_loss * self.tasks_loss_weight[task_idx]
+
+ global_ball_loss = self.ball_loss_criterion(converted_pred_ball_global, target_ball_global)
+ total_loss = global_ball_loss * self.tasks_loss_weight[task_idx]
+
if pred_ball_local is not None:
task_idx += 1
- target_ball_local = torch.zeros_like(pred_ball_local)
+ converted_pred_ball_local = [(pred_ball_local[0][i], pred_ball_local[1][i]) for i in range(pred_ball_local[0].shape[0])]
+
+ target_ball_local_x = torch.zeros_like(pred_ball_local[0])
+ target_ball_local_y = torch.zeros_like(pred_ball_local[1])
+ # Create a list of tuples for each batch
+ target_ball_local = [(target_ball_local_x[i], target_ball_local_y[i]) for i in range(batch_size)]
for sample_idx in range(batch_size):
- target_ball_local[sample_idx] = create_target_ball(local_ball_pos_xy[sample_idx], sigma=self.sigma,
+ target_ball_local[sample_idx] = create_target_ball_right(local_ball_pos_xy[sample_idx], sigma=self.sigma,
w=self.w, h=self.h,
thresh_mask=self.thresh_ball_pos_mask,
device=self.device)
- local_ball_loss = self.ball_loss_criterion(pred_ball_local, target_ball_local)
+ local_ball_loss = self.ball_loss_criterion(converted_pred_ball_local, target_ball_local)
+
total_loss += local_ball_loss * self.tasks_loss_weight[task_idx]
if pred_events is not None:
diff --git a/src/ploted_img.png b/src/ploted_img.png
new file mode 100644
index 0000000..cb552ab
Binary files /dev/null and b/src/ploted_img.png differ
diff --git a/src/test.py b/src/test.py
index e6bd03c..2b602af 100644
--- a/src/test.py
+++ b/src/test.py
@@ -9,6 +9,7 @@
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.utils.data.distributed
+import math
from tqdm import tqdm
sys.path.append('./')
@@ -17,7 +18,7 @@
from models.model_utils import create_model, load_pretrained_model, make_data_parallel, get_num_parameters
from utils.misc import AverageMeter
from config.config import parse_configs
-from utils.post_processing import get_prediction_ball_pos, get_prediction_seg, prediction_get_events
+from utils.post_processing import get_prediction_ball_pos, get_prediction_seg, prediction_get_events, get_prediction_ball_pos_right, get_prediction_ball_pos_right_test
from utils.metrics import SPCE, PCE
@@ -104,19 +105,22 @@ def test(test_loader, model, configs):
pred_ball_global, pred_ball_local, pred_events, pred_seg, local_ball_pos_xy, total_loss, _ = model(
resized_imgs, org_ball_pos_xy, global_ball_pos_xy, target_events, target_seg)
+ # previsouly the pred_ball_global will be in shape ((b, 320), (b, 128)) convert them to [[b*[[320],[128]]]
+ converted_pred_ball_global = [(pred_ball_global[0][i], pred_ball_global[1][i]) for i in range(pred_ball_global[0].shape[0])]
+ converted_pred_ball_local = [(pred_ball_local[0][i], pred_ball_local[1][i]) for i in range(pred_ball_local[0].shape[0])]
org_ball_pos_xy = org_ball_pos_xy.numpy()
global_ball_pos_xy = global_ball_pos_xy.numpy()
# Transfer output to cpu
target_seg = target_seg.cpu().numpy()
-
for sample_idx in range(batch_size):
# Get target
sample_org_ball_pos_xy = org_ball_pos_xy[sample_idx]
sample_global_ball_pos_xy = global_ball_pos_xy[sample_idx] # Target
# Process the global stage
- sample_pred_ball_global = pred_ball_global[sample_idx]
- sample_prediction_ball_global_xy = get_prediction_ball_pos(sample_pred_ball_global, w,
+ sample_pred_ball_global = converted_pred_ball_global[sample_idx]
+
+ sample_prediction_ball_global_xy = get_prediction_ball_pos_right_test(sample_pred_ball_global,
configs.thresh_ball_pos_mask)
# Calculate the MSE
@@ -136,14 +140,14 @@ def test(test_loader, model, configs):
# Process local ball stage
if pred_ball_local is not None:
# Get target
- local_ball_pos_xy = local_ball_pos_xy.cpu().numpy() # Ground truth of the local stage
+ local_ball_pos_xy = local_ball_pos_xy # Ground truth of the local stage
sample_local_ball_pos_xy = local_ball_pos_xy[sample_idx] # Target
# Process the local stage
- sample_pred_ball_local = pred_ball_local[sample_idx]
- sample_prediction_ball_local_xy = get_prediction_ball_pos(sample_pred_ball_local, w,
+ sample_pred_ball_local = converted_pred_ball_local[sample_idx]
+ sample_prediction_ball_local_xy = get_prediction_ball_pos_right_test(sample_pred_ball_local,
configs.thresh_ball_pos_mask)
-
- # Calculate the MSE
+
+ # Calculate the MSE only if the ball exist
if (sample_local_ball_pos_xy[0] > 0) and (sample_local_ball_pos_xy[1] > 0):
mse = (sample_prediction_ball_local_xy[0] - sample_local_ball_pos_xy[0]) ** 2 + (
sample_prediction_ball_local_xy[1] - sample_local_ball_pos_xy[1]) ** 2
@@ -176,7 +180,7 @@ def test(test_loader, model, configs):
# Process segmentation stage
if pred_seg is not None:
- sample_target_seg = target_seg[sample_idx].transpose(1, 2, 0).astype(np.int)
+ sample_target_seg = target_seg[sample_idx].transpose(1, 2, 0).astype(int)
sample_prediction_seg = get_prediction_seg(pred_seg[sample_idx], configs.seg_thresh)
# Calculate the IoU
@@ -210,14 +214,14 @@ def test(test_loader, model, configs):
if ((batch_idx + 1) % configs.print_freq) == 0:
print(
'batch_idx: {} - Average iou_seg: {:.4f}, mse_global: {:.1f}, mse_local: {:.1f}, mse_overall: {:.1f}, pce: {:.4f} spce: {:.4f}'.format(
- batch_idx, iou_seg.avg, mse_global.avg, mse_local.avg, mse_overall.avg, pce.avg, spce.avg))
+ batch_idx, iou_seg.avg, math.sqrt(mse_global.avg), math.sqrt(mse_local.avg), math.sqrt(mse_overall.avg), pce.avg, spce.avg))
batch_time.update(time.time() - start_time)
start_time = time.time()
print(
- 'Average iou_seg: {:.4f}, mse_global: {:.1f}, mse_local: {:.1f}, mse_overall: {:.1f}, pce: {:.4f} spce: {:.4f}'.format(
- iou_seg.avg, mse_global.avg, mse_local.avg, mse_overall.avg, pce.avg, spce.avg))
+ 'Average iou_seg: {:.4f}, rmse_global: {:.1f}, rmse_local: {:.1f}, rmse_overall: {:.1f}, pce: {:.4f} spce: {:.4f}'.format(
+ iou_seg.avg, math.sqrt(mse_global.avg), math.sqrt(mse_local.avg), math.sqrt(mse_overall.avg), pce.avg, spce.avg))
print('Done testing')
diff --git a/src/test_3rd_phase.sh b/src/test_3rd_phase.sh
index cedf505..f0db93c 100755
--- a/src/test_3rd_phase.sh
+++ b/src/test_3rd_phase.sh
@@ -1,11 +1,30 @@
#!/bin/bash
+#SBATCH --partition=gpu
+#SBATCH --gres=gpu:1
+#SBATCH --job-name=August
+
+# python test.py \
+# --working-dir '../' \
+# --saved_fn 'ttnet_3rd_phase' \
+# --gpu_idx 0 \
+# --batch_size 1 \
+# --pretrained_path ../checkpoints/ttnet_3rd_phase/ttnet_3rd_phase_epoch_30.pth \
+# --seg_thresh 0.5 \
+# --event_thresh 0.5 \
+# --smooth-labelling \
+# --thresh_ball_pos_mask 0.0001
+
+
python test.py \
- --working_dir '../' \
- --saved_fn 'ttnet_3rd_phase' \
+ --working-dir '../' \
+ --saved_fn 'ttnet_3rd_phase_128_320' \
--gpu_idx 0 \
- --batch_size 1 \
- --pretrained_path ../checkpoints/ttnet_3rd_phase/ttnet_3rd_phase_epoch_30.pth \
+ --batch_size 32 \
+ --pretrained_path ../checkpoints/ttnet_3rd_phase_128_320/ttnet_3rd_phase_128_320_best.pth \
--seg_thresh 0.5 \
--event_thresh 0.5 \
- --smooth-labelling
\ No newline at end of file
+ --smooth_labelling \
+ --thresh_ball_pos_mask 0.00001 \
+ --no_seg \
+ --no_event \
diff --git a/src/train.sh b/src/train.sh
index 2aefbfa..8bc9173 100755
--- a/src/train.sh
+++ b/src/train.sh
@@ -1,32 +1,35 @@
#!/bin/bash
+#SBATCH --partition=gpu
+#SBATCH --gres=gpu:1
+#SBATCH --job-name=AugustTT
# The first phase: No local, no event
python main.py \
--working-dir '../' \
- --saved_fn 'ttnet_1st_phase' \
- --no-val \
- --batch_size 8 \
+ --saved_fn 'ttnet_1st_phase_128_320' \
+ --gpu_idx 0 \
+ --num_epochs 50\
+ --batch_size 128 \
--num_workers 4 \
--lr 0.001 \
--lr_type 'step_lr' \
--lr_step_size 10 \
--lr_factor 0.1 \
- --gpu_idx 0 \
--global_weight 5. \
--seg_weight 1. \
--no_local \
--no_event \
- --smooth-labelling
-
+ --no_test \
+ --smooth_labelling \
# The second phase: Freeze the segmentation and the global modules
python main.py \
--working-dir '../' \
- --saved_fn 'ttnet_2nd_phase' \
- --no-val \
- --batch_size 8 \
- --num_workers 4 \
+ --saved_fn 'ttnet_2nd_phase_128_320' \
+ --num_epochs 50\
+ --batch_size 32 \
+ --num_workers 10 \
--lr 0.001 \
--lr_type 'step_lr' \
--lr_step_size 10 \
@@ -36,28 +39,35 @@ python main.py \
--seg_weight 0. \
--event_weight 2. \
--local_weight 1. \
- --pretrained_path ../checkpoints/ttnet_1st_phase/ttnet_1st_phase_epoch_30.pth \
+ --pretrained_path ../checkpoints/ttnet_1st_phase_128_320/ttnet_1st_phase_128_320_best.pth \
--overwrite_global_2_local \
--freeze_seg \
--freeze_global \
- --smooth-labelling
+ --smooth_labelling \
+ --no_event \
+ --no_seg \
+ --no_test \
-# The third phase: Finetune all modules
+# # The third phase: Finetune all modules
python main.py \
--working-dir '../' \
- --saved_fn 'ttnet_3rd_phase' \
- --no-val \
- --batch_size 8 \
- --num_workers 4 \
+ --saved_fn 'ttnet_3rd_phase_128_320' \
+ --num_epochs 30\
+ --batch_size 32 \
+ --num_workers 10 \
--lr 0.0001 \
--lr_type 'step_lr' \
--lr_step_size 10 \
--lr_factor 0.2 \
--gpu_idx 0 \
--global_weight 1. \
+ --no_seg \
--seg_weight 1. \
--event_weight 1. \
--local_weight 1. \
- --pretrained_path ../checkpoints/ttnet_2nd_phase/ttnet_2nd_phase_epoch_30.pth \
- --smooth-labelling
\ No newline at end of file
+ --pretrained_path ../checkpoints/ttnet_2nd_phase_128_320/ttnet_2nd_phase_128_320_best.pth \
+ --smooth_labelling \
+ --no_event \
+ --no_seg \
+ --no_test \
\ No newline at end of file
diff --git a/src/train_1st_phase.sh b/src/train_1st_phase.sh
index f556647..a47840f 100755
--- a/src/train_1st_phase.sh
+++ b/src/train_1st_phase.sh
@@ -1,9 +1,11 @@
#!/bin/bash
+#SBATCH --partition=gpu
+#SBATCH --gres=gpu:1
+#SBATCH --job-name=August
python main.py \
--working-dir '../' \
--saved_fn 'ttnet_1st_phase' \
- --no-val \
--batch_size 8 \
--num_workers 4 \
--lr 0.001 \
diff --git a/src/train_2nd_phase.sh b/src/train_2nd_phase.sh
index 1e33b98..94aa521 100755
--- a/src/train_2nd_phase.sh
+++ b/src/train_2nd_phase.sh
@@ -1,4 +1,7 @@
#!/bin/bash
+#SBATCH --partition=gpu
+#SBATCH --gres=gpu:1
+#SBATCH --job-name=August
python main.py \
--working-dir '../' \
@@ -15,8 +18,9 @@ python main.py \
--seg_weight 0. \
--event_weight 2. \
--local_weight 1. \
- --pretrained_path ../checkpoints/ttnet_1st_phase/ttnet_1st_phase_epoch_30.pth \
+ --pretrained_path ../checkpoints/ttnet_1st_phase/ttnet_1st_phase_best.pth \
--overwrite_global_2_local \
--freeze_seg \
--freeze_global \
- --smooth-labelling
\ No newline at end of file
+ --smooth-labelling \
+ --thresh_ball_pos_mask 0.0001
\ No newline at end of file
diff --git a/src/train_3rd_phase.sh b/src/train_3rd_phase.sh
index 3459809..ba58afd 100755
--- a/src/train_3rd_phase.sh
+++ b/src/train_3rd_phase.sh
@@ -1,4 +1,8 @@
#!/bin/bash
+#SBATCH --partition=gpu
+#SBATCH --gres=gpu:1
+#SBATCH --job-name=August
+
python main.py \
--working-dir '../' \
diff --git a/src/utils/post_processing.py b/src/utils/post_processing.py
index e9ad8fd..804de85 100644
--- a/src/utils/post_processing.py
+++ b/src/utils/post_processing.py
@@ -22,6 +22,40 @@ def get_prediction_ball_pos(pred_ball, w, thresh_ball_pos_prob):
return (prediction_ball_x, prediction_ball_y)
+def get_prediction_ball_pos_right(pred_ball, thresh_ball_pos_prob):
+ # pred_ball is in shape ((b, w),(b,h))
+ # convert them into [b*(w,h)]
+ converted_pred_balls = [(pred_ball[0][i], pred_ball[1][i]) for i in range(pred_ball[0].shape[0])]
+ results = []
+ for converted_pred_ball in converted_pred_balls:
+ pred_ball_coords_x = converted_pred_ball[0].cpu()
+ pred_ball_coords_y = converted_pred_ball[1].cpu()
+ pred_ball_coords_x = pred_ball_coords_x.detach().numpy()
+ pred_ball_coords_y = pred_ball_coords_y.detach().numpy()
+ pred_ball_coords_x [pred_ball_coords_x < thresh_ball_pos_prob] = 0.
+ pred_ball_coords_y [pred_ball_coords_y < thresh_ball_pos_prob] = 0.
+
+ prediction_ball_x = np.argmax(pred_ball_coords_x)
+ prediction_ball_y = np.argmax(pred_ball_coords_y)
+ results.append([prediction_ball_x, prediction_ball_y])
+
+ return results
+
+def get_prediction_ball_pos_right_test(pred_ball, thresh_ball_pos_prob):
+ # pred_ball is in shape (h,w)
+ pred_ball_coords_x = pred_ball[0].cpu()
+ pred_ball_coords_y = pred_ball[1].cpu()
+ pred_ball_coords_x = pred_ball_coords_x.detach().numpy()
+ pred_ball_coords_y = pred_ball_coords_y.detach().numpy()
+ pred_ball_coords_x [pred_ball_coords_x < thresh_ball_pos_prob] = 0.
+ pred_ball_coords_y [pred_ball_coords_y < thresh_ball_pos_prob] = 0.
+
+ prediction_ball_x = np.argmax(pred_ball_coords_x)
+ prediction_ball_y = np.argmax(pred_ball_coords_y)
+
+
+ return (prediction_ball_x, prediction_ball_y)
+
def prediction_get_events(pred_events, event_thresh):
if pred_events.is_cuda:
@@ -36,6 +70,6 @@ def get_prediction_seg(pred_seg, seg_thresh):
if pred_seg.is_cuda:
pred_seg = pred_seg.cpu()
pred_seg = torch.squeeze(pred_seg).numpy().transpose(1, 2, 0)
- prediction_seg = (pred_seg > seg_thresh).astype(np.int)
+ prediction_seg = (pred_seg > seg_thresh).astype(int)
return prediction_seg