diff --git a/dhis-2/dhis-services/dhis-service-analytics/src/main/java/org/hisp/dhis/analytics/common/CteContext.java b/dhis-2/dhis-services/dhis-service-analytics/src/main/java/org/hisp/dhis/analytics/common/CteContext.java index c882de5ea8a..2de517c59ac 100644 --- a/dhis-2/dhis-services/dhis-service-analytics/src/main/java/org/hisp/dhis/analytics/common/CteContext.java +++ b/dhis-2/dhis-services/dhis-service-analytics/src/main/java/org/hisp/dhis/analytics/common/CteContext.java @@ -134,7 +134,7 @@ public Set getCteKeys() { return cteDefinitions.keySet(); } - public Set getCteKeys(String... exclude) { + public Set getCteKeysExcluding(String... exclude) { final List toExclude = List.of(exclude); return cteDefinitions.keySet().stream() .filter(key -> !toExclude.contains(key)) diff --git a/dhis-2/dhis-services/dhis-service-analytics/src/main/java/org/hisp/dhis/analytics/event/data/JdbcEnrollmentAnalyticsManager.java b/dhis-2/dhis-services/dhis-service-analytics/src/main/java/org/hisp/dhis/analytics/event/data/JdbcEnrollmentAnalyticsManager.java index a695084990b..ddee5aadbc5 100644 --- a/dhis-2/dhis-services/dhis-service-analytics/src/main/java/org/hisp/dhis/analytics/event/data/JdbcEnrollmentAnalyticsManager.java +++ b/dhis-2/dhis-services/dhis-service-analytics/src/main/java/org/hisp/dhis/analytics/event/data/JdbcEnrollmentAnalyticsManager.java @@ -35,6 +35,8 @@ import static org.hisp.dhis.analytics.common.CteContext.ENROLLMENT_AGGR_BASE; import static org.hisp.dhis.analytics.common.CteUtils.computeKey; import static org.hisp.dhis.analytics.event.data.EnrollmentQueryHelper.getHeaderColumns; +import static org.hisp.dhis.analytics.event.data.EnrollmentQueryHelper.getOrgUnitLevelColumns; +import static org.hisp.dhis.analytics.event.data.EnrollmentQueryHelper.getPeriodColumns; import static org.hisp.dhis.analytics.event.data.OrgUnitTableJoiner.joinOrgUnitTables; import static org.hisp.dhis.analytics.util.AnalyticsUtils.withExceptionHandling; import static org.hisp.dhis.common.DataDimensionType.ATTRIBUTE; @@ -165,7 +167,7 @@ public void getEnrollments(EventQueryParams params, Grid grid, int maxLimit) { ? buildEnrollmentQueryWithCte(params) : getAggregatedEnrollmentsSql(params, maxLimit); } - System.out.println(sql); + System.out.println(sql); // FIXME remove if (params.analyzeOnly()) { withExceptionHandling( () -> executionPlanStore.addExecutionPlan(params.getExplainOrderId(), sql)); @@ -1118,11 +1120,10 @@ private CteContext getCteDefinitions(EventQueryParams params) { * @param headers the {@link GridHeader} list defining the query columns * @return a {@link List} of column names included in the `enrollment_aggr_base` CTE */ - private List addBaseAggregationCte( + private void addBaseAggregationCte( CteContext cteContext, EventQueryParams params, List headers) { // create base enrollment context List columns = new ArrayList<>(); - List rootColumns; columns.add("enrollment"); addDimensionSelectColumns(columns, params, true); @@ -1131,14 +1132,21 @@ private List addBaseAggregationCte( for (String column : Sets.newHashSet(columns)) { sb.addColumn(SqlColumnParser.removeTableAlias(column)); } + List programIndicators = + params.getItems().stream() + .filter(QueryItem::isProgramIndicator) + .map(QueryItem::getItemId) + .toList(); for (String column : headersCols) { - sb.addColumnIfNotExist(quote(SqlColumnParser.removeTableAlias(column))); + String colToAdd = SqlColumnParser.removeTableAlias(column); + if (!programIndicators.contains(colToAdd)) { + sb.addColumnIfNotExist(quote(colToAdd)); + } } // return the name of the columns that are part // of the original params, no need to return also // the columns that are part of the filters - rootColumns = sb.getColumnNames(); sb.from(getFromClause(params)); sb.where( Condition.and( @@ -1155,8 +1163,6 @@ private List addBaseAggregationCte( // performance reasons cteContext.addBaseAggregateCte( sb.build(), SqlAliasReplacer.replaceTableAliases(sb.getWhereClause(), cols)); - - return rootColumns; } /** @@ -1296,7 +1302,12 @@ private String buildAggregatedCteSql( values.put("enrollmentAggrBase", ENROLLMENT_AGGR_BASE); values.put("programStageUid", item.getProgramStage().getUid()); values.put( - "aggregateWhereClause", baseAggregatedCte.getAggregateWhereClause().replace("%s", "eb")); + "aggregateWhereClause", + baseAggregatedCte + .getAggregateWhereClause() + // Replace the "ax." alias (from subqueries) with empty string + .replace("ax.", "") + .replace("%s", "eb")); return new StringSubstitutor(values).replace(template); } @@ -1450,7 +1461,7 @@ private String buildAggregatedEnrollmentQueryWithCte( // add base aggregation CTE // and retain the columns from the root query - List rootQueryColumns = addBaseAggregationCte(cteContext, params, headers); + addBaseAggregationCte(cteContext, params, headers); // Add CTE definitions for program indicators, program stages, etc. getCteDefinitions(params, cteContext); @@ -1460,43 +1471,81 @@ private String buildAggregatedEnrollmentQueryWithCte( // add the CTE with clause based on the CTE definitions accumulated so far addCteClause(sb, cteContext); - // add select clause + // SELECT columns in following order: + // 1) count(eb.enrollment) as value + // 2) org unit columns (orgColumns) + // 3) period columns (periodColumns) + // 4) header columns (headerColumns) sb.addColumn("count(eb.enrollment) as value"); - - // add the columns from the root CTE query - // excluding the enrollment column - // note that the columns are also added to the group by clause - rootQueryColumns.stream() - .filter(col -> !col.equals("enrollment")) - .peek(sb::groupBy) - .forEach(sb::addColumn); - - // add the columns from the CTE definitions - cteContext - .getCteKeys(ENROLLMENT_AGGR_BASE) - .forEach( - itemUid -> { - CteDefinition cteDef = cteContext.getDefinitionByItemUid(itemUid); - if (cteDef.isProgramStage()) { - String columnAlias = quote(cteDef.getProgramStageUid() + "." + cteDef.getItemId()); - sb.addColumn(cteDef.getAlias() + ".value", "", columnAlias); - sb.groupBy(columnAlias); - } - }); + addOrgUnitAggregateColumns(sb, params); + addPeriodAggregateColumns(params, sb); + addHeaderAggregateColumns(headers, cteContext, sb); // add from sb.from(ENROLLMENT_AGGR_BASE, "eb"); // Add join statements for each CTE definition - for (String itemUid : cteContext.getCteKeys(ENROLLMENT_AGGR_BASE)) { + for (String itemUid : cteContext.getCteKeysExcluding(ENROLLMENT_AGGR_BASE)) { CteDefinition cteDef = cteContext.getDefinitionByItemUid(itemUid); sb.leftJoin( itemUid, cteDef.getAlias(), tableAlias -> tableAlias + ".enrollment = eb.enrollment"); } - return sb.build(); } + /** + * Add the columns specified in the headers to the SelectBuilder. The columns are added in the + * order specified in the headers and are based on existing CTE definitions. + * + * @param headers List of GridHeader objects + * @param cteContext CteContext object containing all CTE definitions + * @param sb SelectBuilder object to which the columns are added + */ + private void addHeaderAggregateColumns( + List headers, CteContext cteContext, SelectBuilder sb) { + // Collect all columns from the headers + Set headerColumns = getHeaderColumns(headers, ""); + // Collect all CTE definitions for program indicators and program stages + Map cteDefinitionMap = + cteContext.getCteKeysExcluding(ENROLLMENT_AGGR_BASE).stream() + .map(cteContext::getDefinitionByItemUid) + .filter(def -> def.isProgramStage() || def.isProgramIndicator()) + .collect( + Collectors.toMap( + cteDef -> + quote( + cteDef.isProgramIndicator() + ? cteDef.getProgramIndicatorUid() + : cteDef.getProgramStageUid() + "." + cteDef.getItemId()), + cteDef -> cteDef)); + + // Iterate over headerColumns and add the columns to SelectBuilder based on the order specified + // in the original GridHeader list + headerColumns.forEach( + headerColumn -> { + boolean foundMatch = false; + String columnWithoutAlias = SqlColumnParser.removeTableAlias(headerColumn); + + // First, check if there's any match in the CTE definitions + // If there is a match, the column is added with the alias from the CTE definition + for (Map.Entry entry : cteDefinitionMap.entrySet()) { + if (entry.getKey().contains(columnWithoutAlias)) { + CteDefinition cteDef = entry.getValue(); + sb.addColumn(cteDef.getAlias() + ".value", "", entry.getKey()); + sb.groupBy(entry.getKey()); + foundMatch = true; + break; + } + } + + if (!foundMatch) { + // Otherwise, add the column as is + sb.addColumn(quote(columnWithoutAlias)); + sb.groupBy(quote(columnWithoutAlias)); + } + }); + } + private String buildEnrollmentQueryWithCte(EventQueryParams params) { // 1. Create the CTE context (collect all CTE definitions for program indicators, program @@ -1609,4 +1658,30 @@ protected String getSortClause(EventQueryParams params) { } return ""; } + + private void addOrgUnitAggregateColumns(SelectBuilder sb, EventQueryParams params) { + Set orgColumns = getOrgUnitLevelColumns(params); + if (!orgColumns.isEmpty()) { + // Add them *exactly in the old order*, then group by them + for (String orgColumn : orgColumns) { + sb.addColumn(orgColumn.trim()); + sb.groupBy(orgColumn.trim()); + } + } else { + // The old code always ensures we at least include ORGUNIT_DIM_ID if orgColumns is blank + sb.addColumn(ORGUNIT_DIM_ID); + sb.groupBy(ORGUNIT_DIM_ID); + } + } + + private static void addPeriodAggregateColumns(EventQueryParams params, SelectBuilder sb) { + Set periodColumns = getPeriodColumns(params); + if (!periodColumns.isEmpty()) { + for (String periodColumn : periodColumns) { + var col = SqlColumnParser.removeTableAlias(periodColumn.trim()); + sb.addColumn(col); + sb.groupBy(col); + } + } + } } diff --git a/dhis-2/dhis-services/dhis-service-analytics/src/main/java/org/hisp/dhis/analytics/util/sql/SqlAliasReplacer.java b/dhis-2/dhis-services/dhis-service-analytics/src/main/java/org/hisp/dhis/analytics/util/sql/SqlAliasReplacer.java index 5069f06cbbe..ff00a111778 100644 --- a/dhis-2/dhis-services/dhis-service-analytics/src/main/java/org/hisp/dhis/analytics/util/sql/SqlAliasReplacer.java +++ b/dhis-2/dhis-services/dhis-service-analytics/src/main/java/org/hisp/dhis/analytics/util/sql/SqlAliasReplacer.java @@ -41,9 +41,47 @@ import net.sf.jsqlparser.statement.select.SelectExpressionItem; import net.sf.jsqlparser.statement.select.SubSelect; +/** + * Utility class for handling SQL where clause transformations by replacing table aliases with + * standardized placeholders. + * + *

The class handles various SQL constructs including: + * + *

    + *
  • Basic column references with and without table aliases + *
  • Quoted identifiers (using ", `, or ') + *
  • Complex SQL functions and expressions + *
  • Subqueries with correlated references + *
  • Case expressions and mathematical operations + *
+ * + *

Example usage: + * + *

+ * List columns = Arrays.asList("employee", "country");
+ * String whereClause = "ax.employee = 10 AND by.country = 'US'";
+ * String result = SqlAliasReplacer.replaceTableAliases(whereClause, columns);
+ * // Result: "%s.employee = 10 AND %s.country = 'US'"
+ * 
+ */ @UtilityClass public class SqlAliasReplacer { + /** + * Replaces table aliases in a SQL where clause with standardized placeholders for specified + * columns. The method performs case-insensitive matching of column names and preserves the + * original SQL structure. For regular column references, it uses the "%s" placeholder. For outer + * table references in subqueries, it uses the "%z" placeholder. + * + * @param whereClause the SQL where clause to process. Must not be null. + * @param columns a list of column names to process. Must not be null. + * @return the processed where clause with replaced table aliases + * @throws IllegalArgumentException if whereClause or columns is null + * @throws RuntimeException if there is an error parsing the SQL where clause + *

Example Input: whereClause: "ax.salary > 1000 AND (SELECT COUNT(*) FROM emp WHERE + * emp.dept = ax.dept) > 0" columns: ["salary", "dept"] Output: "%s.salary > 1000 AND (SELECT + * COUNT(*) FROM emp WHERE emp.dept = %z.dept) > 0" + */ public static String replaceTableAliases(String whereClause, List columns) { if (whereClause == null || columns == null) { throw new IllegalArgumentException("Where clause and columns list cannot be null"); @@ -66,6 +104,8 @@ public static String replaceTableAliases(String whereClause, List column private static class ColumnReplacementVisitor extends ExpressionVisitorAdapter { private final Set columns; private static final Table PLACEHOLDER_TABLE = new Table("%s"); + private static final Table OUTER_REFERENCE_TABLE = new Table("%z"); // New constant + private boolean inSubQuery = false; public ColumnReplacementVisitor(List columns) { @@ -80,6 +120,7 @@ public ColumnReplacementVisitor(List columns) { public void visit(Column column) { String columnName = column.getColumnName(); String rawColumnName = stripQuotes(columnName); + Table currentTable = column.getTable(); if (columns.contains(rawColumnName.toLowerCase())) { String quoteType = getQuoteType(columnName); @@ -99,17 +140,14 @@ public void visit(SubSelect subSelect) { boolean wasInSubQuery = inSubQuery; inSubQuery = true; - if (subSelect.getSelectBody() instanceof PlainSelect) { - PlainSelect plainSelect = (PlainSelect) subSelect.getSelectBody(); + if (subSelect.getSelectBody() instanceof PlainSelect plainSelect) { plainSelect .getSelectItems() .forEach( selectItem -> { - if (selectItem instanceof SelectExpressionItem) { - SelectExpressionItem sei = (SelectExpressionItem) selectItem; + if (selectItem instanceof SelectExpressionItem sei) { Expression expression = sei.getExpression(); - if (expression instanceof Function) { - Function function = (Function) expression; + if (expression instanceof Function function) { function.accept(this); } } diff --git a/dhis-2/dhis-services/dhis-service-analytics/src/test/java/org/hisp/dhis/analytics/util/sql/SqlAliasReplacerTest.java b/dhis-2/dhis-services/dhis-service-analytics/src/test/java/org/hisp/dhis/analytics/util/sql/SqlAliasReplacerTest.java index 2af5bf53fca..12bf9d9d530 100644 --- a/dhis-2/dhis-services/dhis-service-analytics/src/test/java/org/hisp/dhis/analytics/util/sql/SqlAliasReplacerTest.java +++ b/dhis-2/dhis-services/dhis-service-analytics/src/test/java/org/hisp/dhis/analytics/util/sql/SqlAliasReplacerTest.java @@ -248,4 +248,8 @@ void testSubqueriesWithConditionalFunctions() { "%s.salary > (SELECT AVG(CASE WHEN status = 'ACTIVE' THEN salary ELSE 0 END) FROM employees)"; assertEquals(expected, SqlAliasReplacer.replaceTableAliases(input, columns)); } + + private String noEof(String sql) { + return sql.replaceAll("\\s+", " ").trim(); + } }