Skip to content

Commit

Permalink
Fix aggregated queries
Browse files Browse the repository at this point in the history
  • Loading branch information
luciano-fiandesio committed Jan 20, 2025
1 parent c04265f commit 8cb7b23
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ public Set<String> getCteKeys() {
return cteDefinitions.keySet();
}

public Set<String> getCteKeys(String... exclude) {
public Set<String> getCteKeysExcluding(String... exclude) {
final List<String> toExclude = List.of(exclude);
return cteDefinitions.keySet().stream()
.filter(key -> !toExclude.contains(key))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
import static org.hisp.dhis.analytics.common.CteContext.ENROLLMENT_AGGR_BASE;
import static org.hisp.dhis.analytics.common.CteUtils.computeKey;
import static org.hisp.dhis.analytics.event.data.EnrollmentQueryHelper.getHeaderColumns;
import static org.hisp.dhis.analytics.event.data.EnrollmentQueryHelper.getOrgUnitLevelColumns;
import static org.hisp.dhis.analytics.event.data.EnrollmentQueryHelper.getPeriodColumns;
import static org.hisp.dhis.analytics.event.data.OrgUnitTableJoiner.joinOrgUnitTables;
import static org.hisp.dhis.analytics.util.AnalyticsUtils.withExceptionHandling;
import static org.hisp.dhis.common.DataDimensionType.ATTRIBUTE;
Expand Down Expand Up @@ -165,7 +167,7 @@ public void getEnrollments(EventQueryParams params, Grid grid, int maxLimit) {
? buildEnrollmentQueryWithCte(params)
: getAggregatedEnrollmentsSql(params, maxLimit);
}
System.out.println(sql);
System.out.println(sql); // FIXME remove
if (params.analyzeOnly()) {
withExceptionHandling(
() -> executionPlanStore.addExecutionPlan(params.getExplainOrderId(), sql));
Expand Down Expand Up @@ -1118,11 +1120,10 @@ private CteContext getCteDefinitions(EventQueryParams params) {
* @param headers the {@link GridHeader} list defining the query columns
* @return a {@link List} of column names included in the `enrollment_aggr_base` CTE
*/
private List<String> addBaseAggregationCte(
private void addBaseAggregationCte(
CteContext cteContext, EventQueryParams params, List<GridHeader> headers) {
// create base enrollment context
List<String> columns = new ArrayList<>();
List<String> rootColumns;
columns.add("enrollment");

addDimensionSelectColumns(columns, params, true);
Expand All @@ -1131,14 +1132,21 @@ private List<String> addBaseAggregationCte(
for (String column : Sets.newHashSet(columns)) {
sb.addColumn(SqlColumnParser.removeTableAlias(column));
}
List<String> programIndicators =
params.getItems().stream()
.filter(QueryItem::isProgramIndicator)
.map(QueryItem::getItemId)
.toList();
for (String column : headersCols) {
sb.addColumnIfNotExist(quote(SqlColumnParser.removeTableAlias(column)));
String colToAdd = SqlColumnParser.removeTableAlias(column);
if (!programIndicators.contains(colToAdd)) {
sb.addColumnIfNotExist(quote(colToAdd));
}
}

// return the name of the columns that are part
// of the original params, no need to return also
// the columns that are part of the filters
rootColumns = sb.getColumnNames();
sb.from(getFromClause(params));
sb.where(
Condition.and(
Expand All @@ -1155,8 +1163,6 @@ private List<String> addBaseAggregationCte(
// performance reasons
cteContext.addBaseAggregateCte(
sb.build(), SqlAliasReplacer.replaceTableAliases(sb.getWhereClause(), cols));

return rootColumns;
}

/**
Expand Down Expand Up @@ -1296,7 +1302,12 @@ private String buildAggregatedCteSql(
values.put("enrollmentAggrBase", ENROLLMENT_AGGR_BASE);
values.put("programStageUid", item.getProgramStage().getUid());
values.put(
"aggregateWhereClause", baseAggregatedCte.getAggregateWhereClause().replace("%s", "eb"));
"aggregateWhereClause",
baseAggregatedCte
.getAggregateWhereClause()
// Replace the "ax." alias (from subqueries) with empty string
.replace("ax.", "")
.replace("%s", "eb"));

return new StringSubstitutor(values).replace(template);
}
Expand Down Expand Up @@ -1450,7 +1461,7 @@ private String buildAggregatedEnrollmentQueryWithCte(

// add base aggregation CTE
// and retain the columns from the root query
List<String> rootQueryColumns = addBaseAggregationCte(cteContext, params, headers);
addBaseAggregationCte(cteContext, params, headers);

// Add CTE definitions for program indicators, program stages, etc.
getCteDefinitions(params, cteContext);
Expand All @@ -1460,43 +1471,81 @@ private String buildAggregatedEnrollmentQueryWithCte(
// add the CTE with clause based on the CTE definitions accumulated so far
addCteClause(sb, cteContext);

// add select clause
// SELECT columns in following order:
// 1) count(eb.enrollment) as value
// 2) org unit columns (orgColumns)
// 3) period columns (periodColumns)
// 4) header columns (headerColumns)
sb.addColumn("count(eb.enrollment) as value");

// add the columns from the root CTE query
// excluding the enrollment column
// note that the columns are also added to the group by clause
rootQueryColumns.stream()
.filter(col -> !col.equals("enrollment"))
.peek(sb::groupBy)
.forEach(sb::addColumn);

// add the columns from the CTE definitions
cteContext
.getCteKeys(ENROLLMENT_AGGR_BASE)
.forEach(
itemUid -> {
CteDefinition cteDef = cteContext.getDefinitionByItemUid(itemUid);
if (cteDef.isProgramStage()) {
String columnAlias = quote(cteDef.getProgramStageUid() + "." + cteDef.getItemId());
sb.addColumn(cteDef.getAlias() + ".value", "", columnAlias);
sb.groupBy(columnAlias);
}
});
addOrgUnitAggregateColumns(sb, params);
addPeriodAggregateColumns(params, sb);
addHeaderAggregateColumns(headers, cteContext, sb);

// add from
sb.from(ENROLLMENT_AGGR_BASE, "eb");

// Add join statements for each CTE definition
for (String itemUid : cteContext.getCteKeys(ENROLLMENT_AGGR_BASE)) {
for (String itemUid : cteContext.getCteKeysExcluding(ENROLLMENT_AGGR_BASE)) {
CteDefinition cteDef = cteContext.getDefinitionByItemUid(itemUid);
sb.leftJoin(
itemUid, cteDef.getAlias(), tableAlias -> tableAlias + ".enrollment = eb.enrollment");
}

return sb.build();
}

/**
 * Appends one select column and one matching group-by entry to the given {@link SelectBuilder}
 * for every column derived from the grid headers, following the iteration order of the set
 * returned by {@code getHeaderColumns}.
 *
 * <p>A header column whose un-aliased name is contained in the quoted key of a program indicator
 * or program stage CTE is selected as {@code <cteAlias>.value}, with that key passed as the third
 * {@code addColumn} argument (presumably the column alias) and used for the group-by. Any other
 * header column is selected and grouped by its quoted, un-aliased name.
 *
 * @param headers the {@link GridHeader} list the header columns are derived from
 * @param cteContext the {@link CteContext} holding the CTE definitions collected so far
 * @param sb the {@link SelectBuilder} receiving the select and group-by columns
 */
private void addHeaderAggregateColumns(
    List<GridHeader> headers, CteContext cteContext, SelectBuilder sb) {
  Set<String> columnsFromHeaders = getHeaderColumns(headers, "");

  // Index every program stage / program indicator CTE (the base aggregation CTE is left out)
  // by the quoted name under which it is exposed in the outer query.
  Map<String, CteDefinition> cteByQuotedName =
      cteContext.getCteKeysExcluding(ENROLLMENT_AGGR_BASE).stream()
          .map(cteContext::getDefinitionByItemUid)
          .filter(def -> def.isProgramStage() || def.isProgramIndicator())
          .collect(
              Collectors.toMap(
                  def -> {
                    if (def.isProgramIndicator()) {
                      return quote(def.getProgramIndicatorUid());
                    }
                    return quote(def.getProgramStageUid() + "." + def.getItemId());
                  },
                  def -> def));

  for (String headerColumn : columnsFromHeaders) {
    String bareColumn = SqlColumnParser.removeTableAlias(headerColumn);

    // Look for the first CTE whose quoted name contains the bare header column
    Map.Entry<String, CteDefinition> matchingCte = null;
    for (Map.Entry<String, CteDefinition> candidate : cteByQuotedName.entrySet()) {
      if (candidate.getKey().contains(bareColumn)) {
        matchingCte = candidate;
        break;
      }
    }

    if (matchingCte != null) {
      // The value comes from the CTE, exposed under the CTE's quoted name
      String quotedName = matchingCte.getKey();
      sb.addColumn(matchingCte.getValue().getAlias() + ".value", "", quotedName);
      sb.groupBy(quotedName);
    } else {
      // No CTE provides this column: take it as is from the base query
      String quotedColumn = quote(bareColumn);
      sb.addColumn(quotedColumn);
      sb.groupBy(quotedColumn);
    }
  }
}

private String buildEnrollmentQueryWithCte(EventQueryParams params) {

// 1. Create the CTE context (collect all CTE definitions for program indicators, program
Expand Down Expand Up @@ -1609,4 +1658,30 @@ protected String getSortClause(EventQueryParams params) {
}
return "";
}

/**
 * Adds the organisation unit columns to both the select list and the group-by clause of the
 * given builder.
 *
 * <p>The columns come from {@code getOrgUnitLevelColumns(params)} and are added trimmed, in the
 * iteration order of the returned set. When that set is empty, {@code ORGUNIT_DIM_ID} is added
 * instead so that an org unit column is always present.
 *
 * @param sb the {@link SelectBuilder} receiving the columns
 * @param params the {@link EventQueryParams} the org unit level columns are derived from
 */
private void addOrgUnitAggregateColumns(SelectBuilder sb, EventQueryParams params) {
  Set<String> orgUnitLevelColumns = getOrgUnitLevelColumns(params);

  if (orgUnitLevelColumns.isEmpty()) {
    // Fallback: no level columns available, use the plain org unit dimension
    sb.addColumn(ORGUNIT_DIM_ID);
    sb.groupBy(ORGUNIT_DIM_ID);
    return;
  }

  for (String rawColumn : orgUnitLevelColumns) {
    String trimmedColumn = rawColumn.trim();
    sb.addColumn(trimmedColumn);
    sb.groupBy(trimmedColumn);
  }
}

/**
 * Adds every period column returned by {@code getPeriodColumns(params)} to both the select list
 * and the group-by clause of the given builder. Each column is trimmed and stripped of its table
 * alias first; nothing is added when there are no period columns.
 *
 * @param params the {@link EventQueryParams} the period columns are derived from
 * @param sb the {@link SelectBuilder} receiving the columns
 */
private static void addPeriodAggregateColumns(EventQueryParams params, SelectBuilder sb) {
  getPeriodColumns(params).stream()
      .map(rawColumn -> SqlColumnParser.removeTableAlias(rawColumn.trim()))
      .forEach(
          bareColumn -> {
            sb.addColumn(bareColumn);
            sb.groupBy(bareColumn);
          });
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,47 @@
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
import net.sf.jsqlparser.statement.select.SubSelect;

/**
* Utility class for handling SQL where clause transformations by replacing table aliases with
* standardized placeholders.
*
* <p>The class handles various SQL constructs including:
*
* <ul>
* <li>Basic column references with and without table aliases
* <li>Quoted identifiers (using ", `, or ')
* <li>Complex SQL functions and expressions
* <li>Subqueries with correlated references
* <li>Case expressions and mathematical operations
* </ul>
*
* <p>Example usage:
*
* <pre>
* List&lt;String&gt; columns = Arrays.asList("employee", "country");
* String whereClause = "ax.employee = 10 AND by.country = 'US'";
* String result = SqlAliasReplacer.replaceTableAliases(whereClause, columns);
* // Result: "%s.employee = 10 AND %s.country = 'US'"
* </pre>
*/
@UtilityClass
public class SqlAliasReplacer {

/**
* Replaces table aliases in a SQL where clause with standardized placeholders for specified
* columns. The method performs case-insensitive matching of column names and preserves the
* original SQL structure. For regular column references, it uses the "%s" placeholder. For outer
* table references in subqueries, it uses the "%z" placeholder.
*
* @param whereClause the SQL where clause to process. Must not be null.
* @param columns a list of column names to process. Must not be null.
* @return the processed where clause with replaced table aliases
* @throws IllegalArgumentException if whereClause or columns is null
* @throws RuntimeException if there is an error parsing the SQL where clause
* <p>Example: for the where clause {@code ax.salary > 1000 AND (SELECT COUNT(*) FROM emp WHERE
* emp.dept = ax.dept) > 0} and the columns {@code ["salary", "dept"]}, the result is {@code
* %s.salary > 1000 AND (SELECT COUNT(*) FROM emp WHERE emp.dept = %z.dept) > 0}.
*/
public static String replaceTableAliases(String whereClause, List<String> columns) {
if (whereClause == null || columns == null) {
throw new IllegalArgumentException("Where clause and columns list cannot be null");
Expand All @@ -66,6 +104,8 @@ public static String replaceTableAliases(String whereClause, List<String> column
private static class ColumnReplacementVisitor extends ExpressionVisitorAdapter {
private final Set<String> columns;
private static final Table PLACEHOLDER_TABLE = new Table("%s");
private static final Table OUTER_REFERENCE_TABLE = new Table("%z"); // New constant

private boolean inSubQuery = false;

public ColumnReplacementVisitor(List<String> columns) {
Expand All @@ -80,6 +120,7 @@ public ColumnReplacementVisitor(List<String> columns) {
public void visit(Column column) {
String columnName = column.getColumnName();
String rawColumnName = stripQuotes(columnName);
Table currentTable = column.getTable();

if (columns.contains(rawColumnName.toLowerCase())) {
String quoteType = getQuoteType(columnName);
Expand All @@ -99,17 +140,14 @@ public void visit(SubSelect subSelect) {
boolean wasInSubQuery = inSubQuery;
inSubQuery = true;

if (subSelect.getSelectBody() instanceof PlainSelect) {
PlainSelect plainSelect = (PlainSelect) subSelect.getSelectBody();
if (subSelect.getSelectBody() instanceof PlainSelect plainSelect) {
plainSelect
.getSelectItems()
.forEach(
selectItem -> {
if (selectItem instanceof SelectExpressionItem) {
SelectExpressionItem sei = (SelectExpressionItem) selectItem;
if (selectItem instanceof SelectExpressionItem sei) {
Expression expression = sei.getExpression();
if (expression instanceof Function) {
Function function = (Function) expression;
if (expression instanceof Function function) {
function.accept(this);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,4 +248,8 @@ void testSubqueriesWithConditionalFunctions() {
"%s.salary > (SELECT AVG(CASE WHEN status = 'ACTIVE' THEN salary ELSE 0 END) FROM employees)";
assertEquals(expected, SqlAliasReplacer.replaceTableAliases(input, columns));
}

/** Collapses every run of whitespace in the given SQL into a single space and trims both ends. */
private String noEof(String sql) {
  String singleSpaced = sql.replaceAll("\\s+", " ");
  return singleSpaced.trim();
}
}

0 comments on commit 8cb7b23

Please sign in to comment.