
Commit

fix typo of latest square
qiauil committed Sep 15, 2024
1 parent 33f1113 commit 9050c3f
Showing 2 changed files with 8 additions and 8 deletions.
14 changes: 7 additions & 7 deletions conflictfree/grad_operator.py
@@ -69,7 +69,7 @@ def ConFIG_update(
grads:Union[torch.Tensor,Sequence[torch.Tensor]],
weight_model:WeightModel=EqualWeight(),
length_model:LengthModel=ProjectionLength(),
-    use_latest_square:bool=True,
+    use_least_square:bool=True,
losses:Optional[Sequence]=None)-> torch.Tensor:
"""
Performs the standard ConFIG update step.
@@ -81,7 +81,7 @@ def ConFIG_update(
Defaults to EqualWeight(), which will make the final update gradient not biased towards any gradient.
length_model (LengthModel, optional): The length model for rescaling the length of the final gradient.
Defaults to ProjectionLength(), which will project each gradient vector onto the final gradient vector to get the final length.
-    use_latest_square (bool, optional): Whether to use the latest square method for calculating the best direction.
+    use_least_square (bool, optional): Whether to use the least square method for calculating the best direction.
If set to False, we will directly calculate the pseudo-inverse of the gradient matrix. See `torch.linalg.pinv` and `torch.linalg.lstsq` for more details.
Recommended to set to True. Defaults to True.
losses (Optional[Sequence], optional): The losses associated with the gradients.
@@ -113,7 +113,7 @@ def ConFIG_update(
with torch.no_grad():
weights=weight_model.get_weights(gradients=grads,losses=losses,device=grads.device)
units=torch.nan_to_num((grads/(grads.norm(dim=1)).unsqueeze(1)),0)
-    if use_latest_square:
+    if use_least_square:
best_direction=torch.linalg.lstsq(units, weights).solution
else:
best_direction=torch.linalg.pinv(units)@weights
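
For reference, here is a small, self-contained sketch (not part of this commit) of what the two branches above compute. It uses a toy square, full-rank system so that the least-squares solve and the explicit pseudo-inverse demonstrably agree; the real gradient matrix in ConFIG is usually non-square, and the variable names below are purely illustrative.

import torch

torch.manual_seed(0)
units = torch.randn(4, 4)      # toy matrix of gradient directions (one per row)
weights = torch.ones(4, 1)     # equal weights, as EqualWeight() would produce

# use_least_square=True path: solve the system with a least-squares routine
direction_lstsq = torch.linalg.lstsq(units, weights).solution

# use_least_square=False path: form the pseudo-inverse explicitly
direction_pinv = torch.linalg.pinv(units) @ weights

print(torch.allclose(direction_lstsq, direction_pinv, atol=1e-4))  # True for this full-rank toy system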
@@ -185,7 +185,7 @@ class ConFIGOperator(GradientOperator):
Defaults to ProjectionLength(), which will project each gradient vector onto the final gradient vector to get the final length.
allow_simplified_model (bool, optional): Whether to allow simplified model for calculating the gradient.
If set to True, will use simplified form of ConFIG method when there are only two losses (ConFIG_update_double). Defaults to True.
-    use_latest_square (bool, optional): Whether to use the latest square method for calculating the best direction.
+    use_least_square (bool, optional): Whether to use the least square method for calculating the best direction.
If set to False, we will directly calculate the pseudo-inverse of the gradient matrix. See `torch.linalg.pinv` and `torch.linalg.lstsq` for more details.
Recommended to set to True. Defaults to True.
@@ -213,12 +213,12 @@ def __init__(self,
weight_model: WeightModel = EqualWeight(),
length_model: LengthModel = ProjectionLength(),
allow_simplified_model: bool = True,
-    use_latest_square: bool = True):
+    use_least_square: bool = True):
super().__init__()
self.weight_model = weight_model
self.length_model = length_model
self.allow_simplified_model = allow_simplified_model
-    self.use_latest_square = use_latest_square
+    self.use_least_square = use_least_square

def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None)->torch.Tensor:
"""
@@ -245,7 +245,7 @@ def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]],
return ConFIG_update(grads,
weight_model=self.weight_model,
length_model=self.length_model,
-    use_latest_square=self.use_latest_square,
+    use_least_square=self.use_least_square,
losses=losses)

class PCGradOperator(GradientOperator):
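A hedged usage sketch of the renamed keyword at the operator level, also not part of this commit. The import path follows the file shown in this diff, and the stacked (num_losses, dim) gradient layout is inferred from grads.norm(dim=1) in ConFIG_update above; the commented output shape is an assumption rather than a documented guarantee.

import torch
from conflictfree.grad_operator import ConFIGOperator

# Construct the operator with the corrected keyword name.
operator = ConFIGOperator(use_least_square=True)

# Two toy per-loss gradients stacked row-wise into a (num_losses, dim) tensor.
grads = torch.stack([torch.randn(128), torch.randn(128)])
update = operator.calculate_gradient(grads)
print(update.shape)  # expected: a single combined update vector of dimension 128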
2 changes: 1 addition & 1 deletion setup.py
@@ -15,7 +15,7 @@ def get_install_requires():

setuptools.setup(
name="conflictfree",
version="0.1.3",
version="0.1.4",
author="Qiang Liu, Mengyu Chu, Nils Thuerey",
author_email="[email protected]",
description="Official implementation of Conflict-free Inverse Gradients method",
