Skip to content

Commit

Permalink
fix: don't use getAllPrincipals (#19707)
Browse files Browse the repository at this point in the history
* fix: don't use getAllPrincipals

Signed-off-by: Morten Svanaes <[email protected]>

* fix: don't use getAllPrincipals

Signed-off-by: Morten Svanaes <[email protected]>

---------

Signed-off-by: Morten Svanaes <[email protected]>
  • Loading branch information
netroms authored Jan 17, 2025
1 parent 5e8867f commit a106bbc
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,11 @@
import static com.google.common.base.Preconditions.checkNotNull;

import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.hisp.dhis.cache.Cache;
import org.hisp.dhis.cache.CacheProvider;
import org.hisp.dhis.organisationunit.OrganisationUnit;
import org.springframework.context.annotation.Lazy;
import org.springframework.security.core.session.SessionInformation;
import org.springframework.security.core.session.SessionRegistry;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

Expand All @@ -54,15 +51,11 @@ public class CurrentUserService {

private final Cache<CurrentUserGroupInfo> currentUserGroupInfoCache;

private final SessionRegistry sessionRegistry;

public CurrentUserService(
@Lazy UserStore userStore, CacheProvider cacheProvider, SessionRegistry sessionRegistry) {
public CurrentUserService(@Lazy UserStore userStore, CacheProvider cacheProvider) {
checkNotNull(userStore);

this.userStore = userStore;
this.currentUserGroupInfoCache = cacheProvider.createCurrentUserGroupInfoCache();
this.sessionRegistry = sessionRegistry;
}

/**
Expand Down Expand Up @@ -120,20 +113,4 @@ public void invalidateUserGroupCache(String userUID) {
// Ignore if key doesn't exist
}
}

public CurrentUserDetailsImpl getCurrentUserPrincipal(String uid) {
return sessionRegistry.getAllPrincipals().stream()
.map(CurrentUserDetailsImpl.class::cast)
.filter(principal -> principal.getUid().equals(uid))
.findFirst()
.orElse(null);
}

public void invalidateUserSessions(String uid) {
CurrentUserDetailsImpl principal = getCurrentUserPrincipal(uid);
if (principal != null) {
List<SessionInformation> allSessions = sessionRegistry.getAllSessions(principal, false);
allSessions.forEach(SessionInformation::expireNow);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -613,4 +613,11 @@ boolean canCurrentUserCanModify(
* @param activeUsername the username of the user to set as active
*/
void setActiveLinkedAccounts(@Nonnull String actingUser, @Nonnull String activeUsername);

/**
* Invalidate all sessions for the given user.
*
* @param userUid the user uid of the user account.
*/
void invalidateUserSessions(String userUid);
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@
import org.hisp.dhis.util.ObjectUtils;
import org.jboss.aerogear.security.otp.api.Base32;
import org.springframework.context.annotation.Lazy;
import org.springframework.security.core.session.SessionInformation;
import org.springframework.security.core.session.SessionRegistry;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

Expand Down Expand Up @@ -114,6 +116,8 @@ public class DefaultUserService implements UserService {

private final Cache<Integer> twoFaDisableFailedAttemptCache;

private final SessionRegistry sessionRegistry;

public DefaultUserService(
UserStore userStore,
UserGroupService userGroupService,
Expand All @@ -124,7 +128,8 @@ public DefaultUserService(
@Lazy PasswordManager passwordManager,
@Lazy SecurityService securityService,
AclService aclService,
@Lazy OrganisationUnitService organisationUnitService) {
@Lazy OrganisationUnitService organisationUnitService,
SessionRegistry sessionRegistry) {
checkNotNull(userStore);
checkNotNull(userGroupService);
checkNotNull(userRoleStore);
Expand All @@ -133,6 +138,7 @@ public DefaultUserService(
checkNotNull(securityService);
checkNotNull(aclService);
checkNotNull(organisationUnitService);
checkNotNull(sessionRegistry);

this.userStore = userStore;
this.userGroupService = userGroupService;
Expand All @@ -145,6 +151,7 @@ public DefaultUserService(
this.aclService = aclService;
this.organisationUnitService = organisationUnitService;
this.twoFaDisableFailedAttemptCache = cacheProvider.createDisable2FAFailedAttemptCache(0);
this.sessionRegistry = sessionRegistry;
}

@Override
Expand Down Expand Up @@ -817,7 +824,7 @@ public void privilegedTwoFactorDisable(

@Override
public void expireActiveSessions(User user) {
currentUserService.invalidateUserSessions(user.getUid());
invalidateUserSessions(user.getUid());
}

@Override
Expand Down Expand Up @@ -1017,4 +1024,14 @@ public List<User> getUsersWithOrgUnit(
public void setActiveLinkedAccounts(@Nonnull String actingUser, @Nonnull String activeUsername) {
userStore.setActiveLinkedAccounts(actingUser, activeUsername);
}

@Override
public void invalidateUserSessions(String userUid) {
User user = userStore.getByUid(userUid);
CurrentUserDetailsImpl userDetails = createUserDetails(user, true, true);
if (userDetails != null) {
List<SessionInformation> allSessions = sessionRegistry.getAllSessions(userDetails, false);
allSessions.forEach(SessionInformation::expireNow);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ public void postUpdate(User persistedUser, ObjectBundle bundle) {
userSettingService.saveUserSettings(persistedUser.getSettings(), persistedUser);

if (Boolean.TRUE.equals(invalidateSessions)) {
currentUserService.invalidateUserSessions(persistedUser.getUid());
userService.invalidateUserSessions(persistedUser.getUid());
}

bundle.removeExtras(persistedUser, PRE_UPDATE_USER_KEY);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.hisp.dhis.dxf2.metadata.objectbundle.ObjectBundle;
import org.hisp.dhis.user.CurrentUserService;
import org.hisp.dhis.user.User;
import org.hisp.dhis.user.UserRole;
import org.hisp.dhis.user.UserService;
import org.springframework.stereotype.Component;

/**
Expand All @@ -47,7 +47,7 @@ public class UserRoleBundleHook extends AbstractObjectBundleHook<UserRole> {

public static final String INVALIDATE_SESSION_KEY = "shouldInvalidateUserSessions";

private final CurrentUserService currentUserService;
private final UserService userService;

@Override
public void preUpdate(UserRole update, UserRole existing, ObjectBundle bundle) {
Expand All @@ -68,7 +68,7 @@ public void postUpdate(UserRole updatedUserRole, ObjectBundle bundle) {

if (Boolean.TRUE.equals(invalidateSessions)) {
for (User user : updatedUserRole.getUsers()) {
currentUserService.invalidateUserSessions(user.getUid());
userService.invalidateUserSessions(user.getUid());
}
}

Expand Down

0 comments on commit a106bbc

Please sign in to comment.