-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathema.py
32 lines (26 loc) · 1.21 KB
/
ema.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from copy import deepcopy
import torch
import torch.nn as nn
class ModelEMA(nn.Module):
    """Keep an exponential moving average (EMA) of a model's state.

    Based on Ross Wightman's ModelEmaV2:
    https://github.com/rwightman/pytorch-image-models/blob/02aaa785b97af5cbf22295033b4d3cc0137d8553/timm/utils/model_ema.py#L82

    A deep copy of ``model`` is stored in ``self.module`` (switched to eval
    mode). Each call to :meth:`update` blends the live model's ``state_dict``
    entries into that copy as ``ema = decay * ema + (1 - decay) * model``;
    :meth:`set` overwrites the copy with the live model's values.

    Args:
        model: module to track. It is deep-copied; the passed-in module is
            not modified.
        decay: EMA decay factor. Values closer to 1 make the average move
            more slowly.
        device: optional device for the EMA copy. When set, the copy is moved
            there and model tensors are transferred to it before each update.
    """

    def __init__(self, model, decay=0.9999, device=None):
        super(ModelEMA, self).__init__()
        # Make a copy of the model for accumulating the moving average of weights.
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        # Perform the EMA on a different device from the model if set.
        self.device = device
        if self.device is not None:
            self.module.to(device=device)

    def forward(self, *args, **kwargs):
        """Run the EMA copy of the model on the given inputs."""
        return self.module(*args, **kwargs)

    def _update(self, model, update_fn):
        """Write ``update_fn(ema_tensor, model_tensor)`` into every EMA state entry.

        Entries are paired positionally by zipping the two ``state_dict``
        value sequences. ``model`` is therefore expected to share the EMA
        copy's architecture, with entries in the same order. ``zip`` stops
        at the shorter sequence.

        Args:
            model: the live module whose state is read.
            update_fn: callable ``(ema_v, model_v) -> tensor`` giving the new
                value for a floating-point (or complex) entry.
        """
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                if ema_v.is_floating_point() or ema_v.is_complex():
                    ema_v.copy_(update_fn(ema_v, model_v))
                else:
                    # Integer/bool buffers (e.g. BatchNorm's num_batches_tracked)
                    # cannot hold a fractional blend: copy_ would truncate it and
                    # leave the buffer stuck, so take the model's value directly.
                    ema_v.copy_(model_v)

    def update(self, model):
        """Blend ``model``'s current state into the EMA copy using ``self.decay``."""
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        """Overwrite the EMA copy's state with ``model``'s current state."""
        self._update(model, update_fn=lambda e, m: m)