-
Notifications
You must be signed in to change notification settings - Fork 330
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add-support-for-non-trainable-params #456
add-support-for-non-trainable-params #456
Conversation
gsplat/strategy/ops.py
Outdated
new_param = param_fn(name, param, param.requires_grad) | ||
params[name] = new_param | ||
if name not in optimizers: | ||
assert not param.requires_grad |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a msg to explain. sth like assert not param.requires_grad, f"param {name} does not present in the optimizer, it's requires_grad should be False, but found True"
gsplat/strategy/ops.py
Outdated
@@ -46,7 +46,7 @@ def _multinomial_sample(weights: Tensor, n: int, replacement: bool = True) -> Te | |||
|
|||
@torch.no_grad() | |||
def _update_param_with_optimizer( | |||
param_fn: Callable[[str, Tensor], Tensor], | |||
param_fn: Callable[[str, Tensor, bool], Tensor], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This breaks the API of this _update_param_with_optimizer
function. Though it is kinda ok because it's an internal function, there seem to be a simple way that wouldn't break it:
Instead of
def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
return torch.nn.Parameter(p[sel], requires_grad=requires_grad)
We could do
def param_fn(name: str, p: Tensor) -> Tensor:
return torch.nn.Parameter(p[sel], requires_grad=p.requires_grad)
Happy to merge it after resolving these minor comments. And thanks for writting the test! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Adds support for non-trainable parameters in the strategy classes and operations.
Currently there is an assumption that all parameters are trainable in two places -
Assume there is a one to one mapping between parameter keys and optimiser keys
When updating parameters during densification all parameters are replaced with trainable counterparts.
This PR supports non trainable parameters by -
Asserting all trainable parameters have a 1-1 mapping with optimisers (this is inferred via the requires grad flag)
Persisting the retain grad flags when parameters are replaced during densification
Adds some tests to ensure that all parameters are densified and requires grad flag is persisted