diff --git a/conflictfree/grad_operator.py b/conflictfree/grad_operator.py index 898bec1..acec228 100644 --- a/conflictfree/grad_operator.py +++ b/conflictfree/grad_operator.py @@ -69,7 +69,7 @@ def ConFIG_update( grads:Union[torch.Tensor,Sequence[torch.Tensor]], weight_model:WeightModel=EqualWeight(), length_model:LengthModel=ProjectionLength(), - use_latest_square:bool=True, + use_least_square:bool=True, losses:Optional[Sequence]=None)-> torch.Tensor: """ Performs the standard ConFIG update step. @@ -81,7 +81,7 @@ def ConFIG_update( Defaults to EqualWeight(), which will make the final update gradient not biased towards any gradient. length_model (LengthModel, optional): The length model for rescaling the length of the final gradient. Defaults to ProjectionLength(), which will project each gradient vector onto the final gradient vector to get the final length. - use_latest_square (bool, optional): Whether to use the latest square method for calculating the best direction. + use_least_square (bool, optional): Whether to use the least square method for calculating the best direction. If set to False, we will directly calculate the pseudo-inverse of the gradient matrix. See `torch.linalg.pinv` and `torch.linalg.lstsq` for more details. Recommended to set to True. Defaults to True. losses (Optional[Sequence], optional): The losses associated with the gradients. @@ -113,7 +113,7 @@ def ConFIG_update( with torch.no_grad(): weights=weight_model.get_weights(gradients=grads,losses=losses,device=grads.device) units=torch.nan_to_num((grads/(grads.norm(dim=1)).unsqueeze(1)),0) - if use_latest_square: + if use_least_square: best_direction=torch.linalg.lstsq(units, weights).solution else: best_direction=torch.linalg.pinv(units)@weights @@ -185,7 +185,7 @@ class ConFIGOperator(GradientOperator): Defaults to ProjectionLength(), which will project each gradient vector onto the final gradient vector to get the final length. allow_simplified_model (bool, optional): Whether to allow simplified model for calculating the gradient. If set to True, will use simplified form of ConFIG method when there are only two losses (ConFIG_update_double). Defaults to True. - use_latest_square (bool, optional): Whether to use the latest square method for calculating the best direction. + use_least_square (bool, optional): Whether to use the least square method for calculating the best direction. If set to False, we will directly calculate the pseudo-inverse of the gradient matrix. See `torch.linalg.pinv` and `torch.linalg.lstsq` for more details. Recommended to set to True. Defaults to True. @@ -213,12 +213,12 @@ def __init__(self, weight_model: WeightModel = EqualWeight(), length_model: LengthModel = ProjectionLength(), allow_simplified_model: bool = True, - use_latest_square: bool = True): + use_least_square: bool = True): super().__init__() self.weight_model = weight_model self.length_model = length_model self.allow_simplified_model = allow_simplified_model - self.use_latest_square = use_latest_square + self.use_least_square = use_least_square def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None)->torch.Tensor: """ @@ -245,7 +245,7 @@ def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], return ConFIG_update(grads, weight_model=self.weight_model, length_model=self.length_model, - use_latest_square=self.use_latest_square, + use_least_square=self.use_least_square, losses=losses) class PCGradOperator(GradientOperator): diff --git a/setup.py b/setup.py index 09ee752..f9dbdf8 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ def get_install_requires(): setuptools.setup( name="conflictfree", - version="0.1.3", + version="0.1.4", author="Qiang Liu, Mengyu Chu, Nils Thuerey", author_email="qiangliu.7@outlook.com", description="Official implementation of Conflict-free Inverse Gradients method",