From ba82e1255b301d92eee9e1ad36e44e07afdb3839 Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Tue, 30 Jul 2024 13:51:35 -0700 Subject: [PATCH 01/12] Add RequestContext parameter to verifyDataSourceAccessAndGetRawMetada method (#2866) * Add RequestContext parameter to verifyDataSourceAccessAndGetRawMetadata method Signed-off-by: Tomoyuki Morita * Add comments Signed-off-by: Tomoyuki Morita * Fix style Signed-off-by: Tomoyuki Morita --------- Signed-off-by: Tomoyuki Morita --- .../model/AsyncQueryRequestContext.java | 6 +- .../dispatcher/SparkQueryDispatcher.java | 2 +- .../asyncquery/AsyncQueryCoreIntegTest.java | 3 +- .../dispatcher/SparkQueryDispatcherTest.java | 72 ++++++++++++------- .../sql/datasource/DataSourceService.java | 5 +- .../sql/datasource/RequestContext.java | 15 ++++ .../sql/analysis/AnalyzerTestBase.java | 4 +- .../service/DataSourceServiceImpl.java | 4 +- .../service/DataSourceServiceImplTest.java | 8 ++- 9 files changed, 84 insertions(+), 35 deletions(-) create mode 100644 core/src/main/java/org/opensearch/sql/datasource/RequestContext.java diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryRequestContext.java b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryRequestContext.java index 56176faefb..d5a478d592 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryRequestContext.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryRequestContext.java @@ -5,7 +5,7 @@ package org.opensearch.sql.spark.asyncquery.model; +import org.opensearch.sql.datasource.RequestContext; + /** Context interface to provide additional request related information */ -public interface AsyncQueryRequestContext { - Object getAttribute(String name); -} +public interface AsyncQueryRequestContext extends RequestContext {} diff --git 
a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 0e871f9ddc..0061ea7179 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -44,7 +44,7 @@ public DispatchQueryResponse dispatch( AsyncQueryRequestContext asyncQueryRequestContext) { DataSourceMetadata dataSourceMetadata = this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata( - dispatchQueryRequest.getDatasource()); + dispatchQueryRequest.getDatasource(), asyncQueryRequestContext); if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) { String query = dispatchQueryRequest.getQuery(); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index 99d4cc722e..34ededc74d 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -512,7 +512,8 @@ private void givenFlintIndexMetadataExists(String indexName) { } private void givenValidDataSourceMetadataExist() { - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(DATASOURCE_NAME)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + DATASOURCE_NAME, asyncQueryRequestContext)) .thenReturn( new DataSourceMetadata.Builder() .setName(DATASOURCE_NAME) diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index f9a83ef9f6..a7a79c758e 100644 --- 
a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -180,7 +180,8 @@ void testDispatchSelectQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -223,7 +224,8 @@ void testDispatchSelectQueryWithLakeFormation() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithLakeFormation(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -255,7 +257,8 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithBasicAuth(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -278,7 +281,8 @@ void testDispatchSelectQueryCreateNewSession() { doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any(), any()); when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID); 
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -304,7 +308,8 @@ void testDispatchSelectQueryReuseSession() { when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID); when(session.isOperationalForDataSource(any())).thenReturn(true); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -324,7 +329,8 @@ void testDispatchSelectQueryFailedCreateSession() { doReturn(true).when(sessionManager).isEnabled(); doThrow(RuntimeException.class).when(sessionManager).createSession(any(), any()); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); Assertions.assertThrows( @@ -358,7 +364,8 @@ void testDispatchCreateAutoRefreshIndexQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -393,7 +400,8 @@ void testDispatchCreateManualRefreshIndexQuery() { 
"query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_glue", asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -426,7 +434,8 @@ void testDispatchWithPPLQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -450,7 +459,8 @@ void testDispatchWithSparkUDFQuery() { "CREATE TEMPORARY FUNCTION square AS 'org.apache.spark.sql.functions.expr(\"num * num\")'"); for (String query : udfQueries) { DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); IllegalArgumentException illegalArgumentException = @@ -489,7 +499,8 @@ void testInvalidSQLQueryDispatchToSpark() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -532,7 +543,8 @@ void 
testDispatchQueryWithoutATableAndDataSourceName() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -568,7 +580,8 @@ void testDispatchIndexQueryWithoutADatasourceName() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -589,8 +602,7 @@ void testDispatchMaterializedViewQuery() { tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText()); String query = - "CREATE MATERIALIZED VIEW mv_1 AS query=select * from my_glue.default.logs WITH" - + " (auto_refresh = true)"; + "CREATE MATERIALIZED VIEW mv_1 AS select * from logs WITH" + " (auto_refresh = true)"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming"); StartJobRequest expected = new StartJobRequest( @@ -604,7 +616,8 @@ void testDispatchMaterializedViewQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) 
.thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -637,7 +650,8 @@ void testDispatchShowMVQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -670,7 +684,8 @@ void testRefreshIndexQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -703,7 +718,8 @@ void testDispatchDescribeIndexQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -739,7 +755,8 @@ void testDispatchAlterToAutoRefreshIndexQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_glue", 
asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -762,7 +779,8 @@ void testDispatchAlterToManualRefreshIndexQuery() { "ALTER INDEX elb_and_requestUri ON my_glue.default.http_logs WITH" + " (auto_refresh = false)"; DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_glue", asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); when(queryHandlerFactory.getIndexDMLHandler()) .thenReturn( @@ -785,7 +803,8 @@ void testDispatchDropIndexQuery() { String query = "DROP INDEX elb_and_requestUri ON my_glue.default.http_logs"; DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_glue", asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); when(queryHandlerFactory.getIndexDMLHandler()) .thenReturn( @@ -808,7 +827,8 @@ void testDispatchVacuumIndexQuery() { String query = "VACUUM INDEX elb_and_requestUri ON my_glue.default.http_logs"; DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_glue", asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); when(queryHandlerFactory.getIndexDMLHandler()) .thenReturn( @@ -824,7 +844,8 @@ void testDispatchVacuumIndexQuery() { @Test void testDispatchWithUnSupportedDataSourceType() { - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_prometheus")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_prometheus", asyncQueryRequestContext)) .thenReturn(constructPrometheusDataSourceType()); String query = "select * from 
my_prometheus.default.http_logs"; @@ -1018,7 +1039,8 @@ void testGetQueryResponseWithSuccess() { void testDispatchQueryWithExtraSparkSubmitParameters() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); String extraParameters = "--conf spark.dynamicAllocation.enabled=false"; diff --git a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java index 6af5d19e5c..a8caa4719a 100644 --- a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java +++ b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java @@ -82,6 +82,9 @@ public interface DataSourceService { * Specifically for addressing use cases in SparkQueryDispatcher. * * @param dataSourceName of the {@link DataSource} + * @param context request context used by the implementation. It is passed by async-query-core. + * refer {@link RequestContext} */ - DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName); + DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata( + String dataSourceName, RequestContext context); } diff --git a/core/src/main/java/org/opensearch/sql/datasource/RequestContext.java b/core/src/main/java/org/opensearch/sql/datasource/RequestContext.java new file mode 100644 index 0000000000..199930d340 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/datasource/RequestContext.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.datasource; + +/** + * Context interface to provide additional request related information. 
It is introduced to allow + * async-query-core library user to pass request context information to implementations of data + * accessors. + */ +public interface RequestContext { + Object getAttribute(String name); +} diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java index b35cfbb5e1..0bf959a1b7 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java @@ -28,6 +28,7 @@ import org.opensearch.sql.config.TestConfig; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.RequestContext; import org.opensearch.sql.datasource.model.DataSource; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; @@ -236,7 +237,8 @@ public Boolean dataSourceExists(String dataSourceName) { } @Override - public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName) { + public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata( + String dataSourceName, RequestContext requestContext) { return null; } } diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java index 61f3c8cd5d..81b6432891 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java @@ -11,6 +11,7 @@ import java.util.*; import java.util.stream.Collectors; import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.RequestContext; import org.opensearch.sql.datasource.model.DataSource; import 
org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceStatus; @@ -122,7 +123,8 @@ public Boolean dataSourceExists(String dataSourceName) { } @Override - public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName) { + public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata( + String dataSourceName, RequestContext requestContext) { DataSourceMetadata dataSourceMetadata = getRawDataSourceMetadata(dataSourceName); verifyDataSourceAccess(dataSourceMetadata); return dataSourceMetadata; diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java index 5a94945e5b..9a1022706f 100644 --- a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java +++ b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java @@ -36,6 +36,7 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.RequestContext; import org.opensearch.sql.datasource.model.DataSource; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceStatus; @@ -52,6 +53,7 @@ class DataSourceServiceImplTest { @Mock private DataSourceFactory dataSourceFactory; @Mock private StorageEngine storageEngine; @Mock private DataSourceMetadataStorage dataSourceMetadataStorage; + @Mock private RequestContext requestContext; @Mock private DataSourceUserAuthorizationHelper dataSourceUserAuthorizationHelper; @@ -461,7 +463,9 @@ void testVerifyDataSourceAccessAndGetRawDataSourceMetadataWithDisabledData() { DatasourceDisabledException datasourceDisabledException = Assertions.assertThrows( DatasourceDisabledException.class, - () -> 
dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS")); + () -> + dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "testDS", requestContext)); Assertions.assertEquals( "Datasource testDS is disabled.", datasourceDisabledException.getMessage()); } @@ -484,7 +488,7 @@ void testVerifyDataSourceAccessAndGetRawDataSourceMetadata() { when(dataSourceMetadataStorage.getDataSourceMetadata("testDS")) .thenReturn(Optional.of(dataSourceMetadata)); DataSourceMetadata dataSourceMetadata1 = - dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS"); + dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS", requestContext); assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.uri")); assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.type")); assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.username")); From 103c4160ae2a129284ff5be70bac40951e0c6a18 Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Tue, 30 Jul 2024 15:20:19 -0700 Subject: [PATCH 02/12] Fixed 2.16 integ test failures (#2871) Signed-off-by: Vamsi Manohar --- .../opensearch/sql/datasource/DataSourceEnabledIT.java | 10 ++++++++++ .../sql/legacy/OpenSearchSQLRestTestCase.java | 4 +++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java index 9c522134a4..b0bc87a0c6 100644 --- a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java @@ -38,6 +38,7 @@ public void testDataSourceCreationWithDefaultSettings() { assertDataSourceCount(1); assertSelectFromDataSourceReturnsSuccess(); assertSelectFromDummyIndexInValidDataSourceDataSourceReturnsDoesNotExist(); + deleteSelfDataSourceCreated(); } @Test @@ -52,6 +53,8 @@ public void 
testAfterPreviousEnable() { assertDataSourceCount(0); assertSelectFromDataSourceReturnsDoesNotExist(); assertAsyncQueryApiDisabled(); + setDataSourcesEnabled("transient", true); + deleteSelfDataSourceCreated(); } @SneakyThrows @@ -142,4 +145,11 @@ private Response performRequest(Request request) { return e.getResponse(); } } + + @SneakyThrows + private void deleteSelfDataSourceCreated() { + Request deleteRequest = getDeleteDataSourceRequest("self"); + Response deleteResponse = client().performRequest(deleteRequest); + Assert.assertEquals(204, deleteResponse.getStatusLine().getStatusCode()); + } } diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/OpenSearchSQLRestTestCase.java b/integ-test/src/test/java/org/opensearch/sql/legacy/OpenSearchSQLRestTestCase.java index d73e3468d4..ced69d54a0 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/OpenSearchSQLRestTestCase.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/OpenSearchSQLRestTestCase.java @@ -195,7 +195,9 @@ protected static void wipeAllOpenSearchIndices(RestClient client) throws IOExcep try { // System index, mostly named .opensearch-xxx or .opendistro-xxx, are not allowed to // delete - if (!indexName.startsWith(".opensearch") && !indexName.startsWith(".opendistro")) { + if (!indexName.startsWith(".opensearch") + && !indexName.startsWith(".opendistro") + && !indexName.startsWith(".ql")) { client.performRequest(new Request("DELETE", "/" + indexName)); } } catch (Exception e) { From aa7a6902a8d03647eecf45564c488c936c53ee3f Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Wed, 31 Jul 2024 17:05:28 +0800 Subject: [PATCH 03/12] Change the default value of plugins.query.size_limit to MAX_RESULT_WINDOW (10000) (#2860) * Change the default value of plugins.query.size_limit to MAX_RESULT_WINDOW (10000) Signed-off-by: Lantao Jin * fix ut Signed-off-by: Lantao Jin * fix spotless Signed-off-by: Lantao Jin --------- Signed-off-by: Lantao Jin --- docs/user/admin/settings.rst | 2 +- 
docs/user/optimization/optimization.rst | 22 +++++++++---------- docs/user/ppl/admin/settings.rst | 4 ++-- docs/user/ppl/interfaces/endpoint.rst | 2 +- .../org/opensearch/sql/legacy/ExplainIT.java | 2 +- .../setting/OpenSearchSettings.java | 3 ++- .../setting/OpenSearchSettingsTest.java | 6 ++--- 7 files changed, 21 insertions(+), 20 deletions(-) diff --git a/docs/user/admin/settings.rst b/docs/user/admin/settings.rst index 662d882745..6b24e41f87 100644 --- a/docs/user/admin/settings.rst +++ b/docs/user/admin/settings.rst @@ -202,7 +202,7 @@ plugins.query.size_limit Description ----------- -The new engine fetches a default size of index from OpenSearch set by this setting, the default value is 200. You can change the value to any value not greater than the max result window value in index level (10000 by default), here is an example:: +The new engine fetches a default size of index from OpenSearch set by this setting, the default value equals to max result window in index level (10000 by default). 
You can change the value to any value not greater than the max result window value in index level (`index.max_result_window`), here is an example:: >> curl -H 'Content-Type: application/json' -X PUT localhost:9200/_plugins/_query/settings -d '{ "transient" : { diff --git a/docs/user/optimization/optimization.rst b/docs/user/optimization/optimization.rst index 8ab998309d..835fe96eba 100644 --- a/docs/user/optimization/optimization.rst +++ b/docs/user/optimization/optimization.rst @@ -44,7 +44,7 @@ The consecutive Filter operator will be merged as one Filter operator:: { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":200,\"timeout\":\"1m\",\"query\":{\"bool\":{\"filter\":[{\"range\":{\"age\":{\"from\":null,\"to\":20,\"include_lower\":true,\"include_upper\":false,\"boost\":1.0}}},{\"range\":{\"age\":{\"from\":10,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}},\"_source\":{\"includes\":[\"age\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}]}, searchDone=false)" + "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"bool\":{\"filter\":[{\"range\":{\"age\":{\"from\":null,\"to\":20,\"include_lower\":true,\"include_upper\":false,\"boost\":1.0}}},{\"range\":{\"age\":{\"from\":10,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}},\"_source\":{\"includes\":[\"age\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}]}, searchDone=false)" }, "children": [] } @@ -71,7 +71,7 @@ The Filter operator should be push down under Sort operator:: { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName=accounts, 
sourceBuilder={\"from\":0,\"size\":200,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":null,\"to\":20,\"include_lower\":true,\"include_upper\":false,\"boost\":1.0}}},\"_source\":{\"includes\":[\"age\"],\"excludes\":[]},\"sort\":[{\"age\":{\"order\":\"asc\",\"missing\":\"_first\"}}]}, searchDone=false)" + "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":null,\"to\":20,\"include_lower\":true,\"include_upper\":false,\"boost\":1.0}}},\"_source\":{\"includes\":[\"age\"],\"excludes\":[]},\"sort\":[{\"age\":{\"order\":\"asc\",\"missing\":\"_first\"}}]}, searchDone=false)" }, "children": [] } @@ -102,7 +102,7 @@ The Project list will push down to Query DSL to `filter the source `_. +Without sort push down optimization, the sort operator will sort the result from child operator. By default, only 10000 docs will extracted from the source index, `you can change this value by using size_limit setting <../admin/settings.rst#opensearch-query-size-limit>`_. diff --git a/docs/user/ppl/admin/settings.rst b/docs/user/ppl/admin/settings.rst index ad56408693..28e6897d3d 100644 --- a/docs/user/ppl/admin/settings.rst +++ b/docs/user/ppl/admin/settings.rst @@ -125,9 +125,9 @@ plugins.query.size_limit Description ----------- -The size configure the maximum amount of documents to be pull from OpenSearch. The default value is: 200 +The size configure the maximum amount of documents to be pull from OpenSearch. The default value is: 10000 -Notes: This setting will impact the correctness of the aggregation operation, for example, there are 1000 docs in the index, by default, only 200 docs will be extract from index and do aggregation. +Notes: This setting will impact the correctness of the aggregation operation, for example, there are 1000 docs in the index, if you change the value to 200, only 200 docs will be extract from index and do aggregation. 
Example ------- diff --git a/docs/user/ppl/interfaces/endpoint.rst b/docs/user/ppl/interfaces/endpoint.rst index 793b94eb8d..fb931fb0ba 100644 --- a/docs/user/ppl/interfaces/endpoint.rst +++ b/docs/user/ppl/interfaces/endpoint.rst @@ -91,7 +91,7 @@ The following PPL query demonstrated that where and stats command were pushed do { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":200,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":10,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}, searchDone=false)" + "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":10,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}, searchDone=false)" }, "children": [] } diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/ExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/legacy/ExplainIT.java index b42e9f84f4..27f8eca3ef 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/ExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/ExplainIT.java @@ -185,7 +185,7 @@ public void orderByOnNestedFieldTest() throws Exception { Assert.assertThat( result.replaceAll("\\s+", ""), equalTo( - "{\"from\":0,\"size\":200,\"sort\":[{\"message.info\":" + "{\"from\":0,\"size\":10000,\"sort\":[{\"message.info\":" + "{\"order\":\"asc\",\"nested\":{\"path\":\"message\"}}}]}")); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java index b4ce82a828..475a584623 100644 --- 
a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java @@ -28,6 +28,7 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.unit.MemorySizeValue; import org.opensearch.common.unit.TimeValue; +import org.opensearch.index.IndexSettings; import org.opensearch.sql.common.setting.LegacySettings; import org.opensearch.sql.common.setting.Settings; @@ -90,7 +91,7 @@ public class OpenSearchSettings extends Settings { public static final Setting QUERY_SIZE_LIMIT_SETTING = Setting.intSetting( Key.QUERY_SIZE_LIMIT.getKeyValue(), - LegacyOpenDistroSettings.QUERY_SIZE_LIMIT_SETTING, + IndexSettings.MAX_RESULT_WINDOW_SETTING, 0, Setting.Property.NodeScope, Setting.Property.Dynamic); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/setting/OpenSearchSettingsTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/setting/OpenSearchSettingsTest.java index e99e5b360a..84fb705ae0 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/setting/OpenSearchSettingsTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/setting/OpenSearchSettingsTest.java @@ -34,6 +34,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.index.IndexSettings; import org.opensearch.monitor.jvm.JvmInfo; import org.opensearch.sql.common.setting.LegacySettings; import org.opensearch.sql.common.setting.Settings; @@ -132,8 +133,7 @@ void settingsFallback() { org.opensearch.common.settings.Settings.EMPTY)); assertEquals( settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT), - LegacyOpenDistroSettings.QUERY_SIZE_LIMIT_SETTING.get( - org.opensearch.common.settings.Settings.EMPTY)); + IndexSettings.MAX_RESULT_WINDOW_SETTING.get(org.opensearch.common.settings.Settings.EMPTY)); 
assertEquals( settings.getSettingValue(Settings.Key.METRICS_ROLLING_WINDOW), LegacyOpenDistroSettings.METRICS_ROLLING_WINDOW_SETTING.get( @@ -165,7 +165,7 @@ public void updateLegacySettingsFallback() { assertEquals( QUERY_MEMORY_LIMIT_SETTING.get(settings), new ByteSizeValue((int) (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() * 0.2))); - assertEquals(QUERY_SIZE_LIMIT_SETTING.get(settings), 100); + assertEquals(QUERY_SIZE_LIMIT_SETTING.get(settings), 10000); assertEquals(METRICS_ROLLING_WINDOW_SETTING.get(settings), 2000L); assertEquals(METRICS_ROLLING_INTERVAL_SETTING.get(settings), 100L); } From 53bfeba8ffa0a79027c06fbb6157fa740333d5df Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Wed, 31 Jul 2024 15:08:35 -0700 Subject: [PATCH 04/12] Add AsyncQueryRequestContext to QueryIdProvider parameter (#2870) Signed-off-by: Tomoyuki Morita --- .../DatasourceEmbeddedQueryIdProvider.java | 5 ++- .../sql/spark/dispatcher/QueryIdProvider.java | 4 ++- .../dispatcher/SparkQueryDispatcher.java | 12 ++++--- .../asyncquery/AsyncQueryCoreIntegTest.java | 17 ++++----- ...DatasourceEmbeddedQueryIdProviderTest.java | 35 +++++++++++++++++++ 5 files changed, 59 insertions(+), 14 deletions(-) create mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProviderTest.java diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java index c170040718..3564fa9552 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java @@ -5,6 +5,7 @@ package org.opensearch.sql.spark.dispatcher; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import 
org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.utils.IDUtils; @@ -12,7 +13,9 @@ public class DatasourceEmbeddedQueryIdProvider implements QueryIdProvider { @Override - public String getQueryId(DispatchQueryRequest dispatchQueryRequest) { + public String getQueryId( + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext asyncQueryRequestContext) { return IDUtils.encode(dispatchQueryRequest.getDatasource()); } } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java index 2167eb6b7a..a108ca1209 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java @@ -5,9 +5,11 @@ package org.opensearch.sql.spark.dispatcher; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; /** Interface for extension point to specify queryId. Called when new query is executed. 
*/ public interface QueryIdProvider { - String getQueryId(DispatchQueryRequest dispatchQueryRequest); + String getQueryId( + DispatchQueryRequest dispatchQueryRequest, AsyncQueryRequestContext asyncQueryRequestContext); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 0061ea7179..a424db4c34 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -69,7 +69,8 @@ private DispatchQueryResponse handleFlintExtensionQuery( DataSourceMetadata dataSourceMetadata) { IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest); DispatchQueryContext context = - getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) + getDefaultDispatchContextBuilder( + dispatchQueryRequest, dataSourceMetadata, asyncQueryRequestContext) .indexQueryDetails(indexQueryDetails) .asyncQueryRequestContext(asyncQueryRequestContext) .build(); @@ -84,7 +85,8 @@ private DispatchQueryResponse handleDefaultQuery( DataSourceMetadata dataSourceMetadata) { DispatchQueryContext context = - getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) + getDefaultDispatchContextBuilder( + dispatchQueryRequest, dataSourceMetadata, asyncQueryRequestContext) .asyncQueryRequestContext(asyncQueryRequestContext) .build(); @@ -93,11 +95,13 @@ private DispatchQueryResponse handleDefaultQuery( } private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder( - DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata) { + DispatchQueryRequest dispatchQueryRequest, + DataSourceMetadata dataSourceMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { return DispatchQueryContext.builder() 
.dataSourceMetadata(dataSourceMetadata) .tags(getDefaultTagsForJobSubmission(dispatchQueryRequest)) - .queryId(queryIdProvider.getQueryId(dispatchQueryRequest)); + .queryId(queryIdProvider.getQueryId(dispatchQueryRequest, asyncQueryRequestContext)); } private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery( diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index 34ededc74d..d82d3bdab7 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -185,7 +185,7 @@ public void setUp() { public void createDropIndexQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); String indexName = "flint_datasource_name_table_name_index_name_index"; givenFlintIndexMetadataExists(indexName); givenCancelJobRunSucceed(); @@ -209,7 +209,7 @@ public void createDropIndexQuery() { public void createVacuumIndexQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); String indexName = "flint_datasource_name_table_name_index_name_index"; givenFlintIndexMetadataExists(indexName); @@ -231,7 +231,7 @@ public void createVacuumIndexQuery() { public void createAlterIndexQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), 
eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); String indexName = "flint_datasource_name_table_name_index_name_index"; givenFlintIndexMetadataExists(indexName); givenCancelJobRunSucceed(); @@ -261,7 +261,7 @@ public void createAlterIndexQuery() { public void createStreamingQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); when(awsemrServerless.startJobRun(any())) .thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID)); @@ -297,7 +297,7 @@ private void verifyStartJobRunCalled() { public void createCreateIndexQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); when(awsemrServerless.startJobRun(any())) .thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID)); @@ -321,7 +321,7 @@ public void createCreateIndexQuery() { public void createRefreshQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); when(awsemrServerless.startJobRun(any())) .thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID)); @@ -344,7 +344,7 @@ public void createInteractiveQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); givenSessionExists(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); when(sessionIdProvider.getSessionId(any())).thenReturn(SESSION_ID); 
givenSessionExists(); // called twice when(awsemrServerless.startJobRun(any())) @@ -538,7 +538,8 @@ private void givenGetJobRunReturnJobRunWithState(String state) { } private void verifyGetQueryIdCalled() { - verify(queryIdProvider).getQueryId(dispatchQueryRequestArgumentCaptor.capture()); + verify(queryIdProvider) + .getQueryId(dispatchQueryRequestArgumentCaptor.capture(), eq(asyncQueryRequestContext)); DispatchQueryRequest dispatchQueryRequest = dispatchQueryRequestArgumentCaptor.getValue(); assertEquals(ACCOUNT_ID, dispatchQueryRequest.getAccountId()); assertEquals(APPLICATION_ID, dispatchQueryRequest.getApplicationId()); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProviderTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProviderTest.java new file mode 100644 index 0000000000..7f1c92dff3 --- /dev/null +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProviderTest.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.mockito.Mockito.verifyNoInteractions; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +@ExtendWith(MockitoExtension.class) +class DatasourceEmbeddedQueryIdProviderTest { + @Mock AsyncQueryRequestContext asyncQueryRequestContext; + + DatasourceEmbeddedQueryIdProvider datasourceEmbeddedQueryIdProvider = + new DatasourceEmbeddedQueryIdProvider(); + + @Test + public void test() { + String queryId = + datasourceEmbeddedQueryIdProvider.getQueryId( 
+ DispatchQueryRequest.builder().datasource("DATASOURCE").build(), + asyncQueryRequestContext); + + assertNotNull(queryId); + verifyNoInteractions(asyncQueryRequestContext); + } +} From 1b17520d79ee6e90e3994298e3998adc263b02b0 Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Wed, 31 Jul 2024 15:59:37 -0700 Subject: [PATCH 05/12] Fixed integ test delete myindex issue and wipe All indices with security enabled domain (#2878) Signed-off-by: Vamsi Manohar --- .../sql/datasource/DataSourceEnabledIT.java | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java index b0bc87a0c6..a53c04d871 100644 --- a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java @@ -7,8 +7,10 @@ import static org.opensearch.sql.legacy.TestUtils.getResponseBody; +import java.io.IOException; import lombok.SneakyThrows; import org.json.JSONObject; +import org.junit.After; import org.junit.Assert; import org.junit.Test; import org.opensearch.client.Request; @@ -18,9 +20,9 @@ public class DataSourceEnabledIT extends PPLIntegTestCase { - @Override - protected boolean preserveClusterUponCompletion() { - return false; + @After + public void cleanUp() throws IOException { + wipeAllClusterSettings(); } @Test @@ -39,6 +41,7 @@ public void testDataSourceCreationWithDefaultSettings() { assertSelectFromDataSourceReturnsSuccess(); assertSelectFromDummyIndexInValidDataSourceDataSourceReturnsDoesNotExist(); deleteSelfDataSourceCreated(); + deleteIndex(); } @Test @@ -55,6 +58,7 @@ public void testAfterPreviousEnable() { assertAsyncQueryApiDisabled(); setDataSourcesEnabled("transient", true); deleteSelfDataSourceCreated(); + deleteIndex(); } @SneakyThrows @@ -98,6 +102,12 @@ private void createIndex() { 
Assert.assertEquals(200, response.getStatusLine().getStatusCode()); } + private void deleteIndex() { + Request request = new Request("DELETE", "/myindex"); + Response response = performRequest(request); + Assert.assertEquals(200, response.getStatusLine().getStatusCode()); + } + private void createOpenSearchDataSource() { Request request = new Request("POST", "/_plugins/_query/_datasources"); request.setJsonEntity( From 3daf64fbce5a8d29e846669689b3a7b12c5c7f07 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Wed, 31 Jul 2024 16:10:13 -0700 Subject: [PATCH 06/12] [Feature] Flint query scheduler part1 - integrate job scheduler plugin (#2834) * [Feature] Flint query scheduler part1 - integrate job scheduler plugin Signed-off-by: Louis Chu * Add comments Signed-off-by: Louis Chu * Add unit test Signed-off-by: Louis Chu * Remove test rest API Signed-off-by: Louis Chu * Fix doc test Signed-off-by: Louis Chu * Add more tests Signed-off-by: Louis Chu * Fix IT Signed-off-by: Louis Chu * Fix IT with security Signed-off-by: Louis Chu * Improve test coverage Signed-off-by: Louis Chu * Fix integTest cluster Signed-off-by: Louis Chu * Fix UT Signed-off-by: Louis Chu * Update UT Signed-off-by: Louis Chu * Fix bwc test Signed-off-by: Louis Chu * Resolve comments Signed-off-by: Louis Chu * Fix bwc test Signed-off-by: Louis Chu * clean up doc test Signed-off-by: Louis Chu * Resolve comments Signed-off-by: Louis Chu * Fix UT Signed-off-by: Louis Chu --------- Signed-off-by: Louis Chu --- .gitignore | 1 + .../src/main/antlr/SqlBaseParser.g4 | 17 +- async-query/build.gradle | 3 + .../OpenSearchAsyncQueryScheduler.java | 197 ++++++++ ...penSearchRefreshIndexJobRequestParser.java | 71 +++ .../job/OpenSearchRefreshIndexJob.java | 93 ++++ .../OpenSearchRefreshIndexJobRequest.java | 108 +++++ .../async-query-scheduler-index-mapping.yml | 41 ++ .../async-query-scheduler-index-settings.yml | 11 + .../OpenSearchAsyncQuerySchedulerTest.java | 434 ++++++++++++++++++ 
.../job/OpenSearchRefreshIndexJobTest.java | 145 ++++++ .../OpenSearchRefreshIndexJobRequestTest.java | 81 ++++ build.gradle | 3 +- common/build.gradle | 4 +- core/build.gradle | 2 +- doctest/build.gradle | 53 +++ integ-test/build.gradle | 65 ++- legacy/build.gradle | 2 +- plugin/build.gradle | 11 +- .../org/opensearch/sql/plugin/SQLPlugin.java | 32 +- ...rch.jobscheduler.spi.JobSchedulerExtension | 6 + ppl/build.gradle | 2 +- protocol/build.gradle | 2 +- sql/build.gradle | 2 +- 24 files changed, 1357 insertions(+), 29 deletions(-) create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQueryScheduler.java create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchRefreshIndexJobRequestParser.java create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJob.java create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequest.java create mode 100644 async-query/src/main/resources/async-query-scheduler-index-mapping.yml create mode 100644 async-query/src/main/resources/async-query-scheduler-index-settings.yml create mode 100644 async-query/src/test/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQuerySchedulerTest.java create mode 100644 async-query/src/test/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJobTest.java create mode 100644 async-query/src/test/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequestTest.java create mode 100644 plugin/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension diff --git a/.gitignore b/.gitignore index 1b892036dd..b9775dea04 100644 --- a/.gitignore +++ b/.gitignore @@ -49,4 +49,5 @@ gen .worktrees http-client.env.json /doctest/sql-cli/ +/doctest/opensearch-job-scheduler/ .factorypath diff --git a/async-query-core/src/main/antlr/SqlBaseParser.g4 
b/async-query-core/src/main/antlr/SqlBaseParser.g4 index a50051715e..c7aa56cf92 100644 --- a/async-query-core/src/main/antlr/SqlBaseParser.g4 +++ b/async-query-core/src/main/antlr/SqlBaseParser.g4 @@ -66,8 +66,8 @@ compoundStatement ; setStatementWithOptionalVarKeyword - : SET (VARIABLE | VAR)? assignmentList #setVariableWithOptionalKeyword - | SET (VARIABLE | VAR)? LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ + : SET variable? assignmentList #setVariableWithOptionalKeyword + | SET variable? LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ LEFT_PAREN query RIGHT_PAREN #setVariableWithOptionalKeyword ; @@ -215,9 +215,9 @@ statement routineCharacteristics RETURN (query | expression) #createUserDefinedFunction | DROP TEMPORARY? FUNCTION (IF EXISTS)? identifierReference #dropFunction - | DECLARE (OR REPLACE)? VARIABLE? + | DECLARE (OR REPLACE)? variable? identifierReference dataType? variableDefaultExpression? #createVariable - | DROP TEMPORARY VARIABLE (IF EXISTS)? identifierReference #dropVariable + | DROP TEMPORARY variable (IF EXISTS)? identifierReference #dropVariable | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)? (statement|setResetStatement) #explain | SHOW TABLES ((FROM | IN) identifierReference)? @@ -272,8 +272,8 @@ setResetStatement | SET TIME ZONE interval #setTimeZone | SET TIME ZONE timezone #setTimeZone | SET TIME ZONE .*? #setTimeZone - | SET (VARIABLE | VAR) assignmentList #setVariable - | SET (VARIABLE | VAR) LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ + | SET variable assignmentList #setVariable + | SET variable LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ LEFT_PAREN query RIGHT_PAREN #setVariable | SET configKey EQ configValue #setQuotedConfiguration | SET configKey (EQ .*?)? 
#setConfiguration @@ -438,6 +438,11 @@ namespaces | SCHEMAS ; +variable + : VARIABLE + | VAR + ; + describeFuncName : identifierReference | stringLit diff --git a/async-query/build.gradle b/async-query/build.gradle index 5a4a0d729d..abda6161d3 100644 --- a/async-query/build.gradle +++ b/async-query/build.gradle @@ -16,6 +16,8 @@ repositories { dependencies { + implementation "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}" + api project(':core') api project(':async-query-core') implementation project(':protocol') @@ -97,6 +99,7 @@ jacocoTestCoverageVerification { // ignore because XContext IOException 'org.opensearch.sql.spark.execution.statestore.StateStore', 'org.opensearch.sql.spark.rest.*', + 'org.opensearch.sql.spark.scheduler.OpenSearchRefreshIndexJobRequestParser', 'org.opensearch.sql.spark.transport.model.*' ] limit { diff --git a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQueryScheduler.java b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQueryScheduler.java new file mode 100644 index 0000000000..c7a66fc6be --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQueryScheduler.java @@ -0,0 +1,197 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler; + +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import com.google.common.annotations.VisibleForTesting; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import org.apache.commons.io.IOUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import 
org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.index.engine.DocumentMissingException; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; +import org.opensearch.sql.spark.scheduler.job.OpenSearchRefreshIndexJob; +import org.opensearch.sql.spark.scheduler.model.OpenSearchRefreshIndexJobRequest; +import org.opensearch.threadpool.ThreadPool; + +/** Scheduler class for managing asynchronous query jobs. */ +public class OpenSearchAsyncQueryScheduler { + public static final String SCHEDULER_INDEX_NAME = ".async-query-scheduler"; + public static final String SCHEDULER_PLUGIN_JOB_TYPE = "async-query-scheduler"; + private static final String SCHEDULER_INDEX_MAPPING_FILE_NAME = + "async-query-scheduler-index-mapping.yml"; + private static final String SCHEDULER_INDEX_SETTINGS_FILE_NAME = + "async-query-scheduler-index-settings.yml"; + private static final Logger LOG = LogManager.getLogger(); + + private Client client; + private ClusterService clusterService; + + /** Loads job resources, setting up required services and job runner instance. 
*/ + public void loadJobResource(Client client, ClusterService clusterService, ThreadPool threadPool) { + this.client = client; + this.clusterService = clusterService; + OpenSearchRefreshIndexJob openSearchRefreshIndexJob = + OpenSearchRefreshIndexJob.getJobRunnerInstance(); + openSearchRefreshIndexJob.setClusterService(clusterService); + openSearchRefreshIndexJob.setThreadPool(threadPool); + openSearchRefreshIndexJob.setClient(client); + } + + /** Schedules a new job by indexing it into the job index. */ + public void scheduleJob(OpenSearchRefreshIndexJobRequest request) { + if (!this.clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)) { + createAsyncQuerySchedulerIndex(); + } + IndexRequest indexRequest = new IndexRequest(SCHEDULER_INDEX_NAME); + indexRequest.id(request.getName()); + indexRequest.opType(DocWriteRequest.OpType.CREATE); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + IndexResponse indexResponse; + try { + indexRequest.source(request.toXContent(JsonXContent.contentBuilder(), EMPTY_PARAMS)); + ActionFuture indexResponseActionFuture = client.index(indexRequest); + indexResponse = indexResponseActionFuture.actionGet(); + } catch (VersionConflictEngineException exception) { + throw new IllegalArgumentException("A job already exists with name: " + request.getName()); + } catch (Throwable e) { + LOG.error("Failed to schedule job : {}", request.getName(), e); + throw new RuntimeException(e); + } + + if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { + LOG.debug("Job : {} successfully created", request.getName()); + } else { + throw new RuntimeException( + "Schedule job failed with result : " + indexResponse.getResult().getLowercase()); + } + } + + /** Unschedules a job by marking it as disabled and updating its last update time. 
*/ + public void unscheduleJob(String jobId) throws IOException { + assertIndexExists(); + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(jobId) + .enabled(false) + .lastUpdateTime(Instant.now()) + .build(); + updateJob(request); + } + + /** Updates an existing job with new parameters. */ + public void updateJob(OpenSearchRefreshIndexJobRequest request) throws IOException { + assertIndexExists(); + UpdateRequest updateRequest = new UpdateRequest(SCHEDULER_INDEX_NAME, request.getName()); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + updateRequest.doc(request.toXContent(JsonXContent.contentBuilder(), EMPTY_PARAMS)); + UpdateResponse updateResponse; + try { + ActionFuture updateResponseActionFuture = client.update(updateRequest); + updateResponse = updateResponseActionFuture.actionGet(); + } catch (DocumentMissingException exception) { + throw new IllegalArgumentException("Job: " + request.getName() + " doesn't exist"); + } catch (Throwable e) { + LOG.error("Failed to update job : {}", request.getName(), e); + throw new RuntimeException(e); + } + + if (updateResponse.getResult().equals(DocWriteResponse.Result.UPDATED) + || updateResponse.getResult().equals(DocWriteResponse.Result.NOOP)) { + LOG.debug("Job : {} successfully updated", request.getName()); + } else { + throw new RuntimeException( + "Update job failed with result : " + updateResponse.getResult().getLowercase()); + } + } + + /** Removes a job by deleting its document from the index. 
*/ + public void removeJob(String jobId) { + assertIndexExists(); + DeleteRequest deleteRequest = new DeleteRequest(SCHEDULER_INDEX_NAME, jobId); + deleteRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + ActionFuture deleteResponseActionFuture = client.delete(deleteRequest); + DeleteResponse deleteResponse = deleteResponseActionFuture.actionGet(); + + if (deleteResponse.getResult().equals(DocWriteResponse.Result.DELETED)) { + LOG.debug("Job : {} successfully deleted", jobId); + } else if (deleteResponse.getResult().equals(DocWriteResponse.Result.NOT_FOUND)) { + throw new IllegalArgumentException("Job : " + jobId + " doesn't exist"); + } else { + throw new RuntimeException( + "Remove job failed with result : " + deleteResponse.getResult().getLowercase()); + } + } + + /** Creates the async query scheduler index with specified mappings and settings. */ + @VisibleForTesting + void createAsyncQuerySchedulerIndex() { + try { + InputStream mappingFileStream = + OpenSearchAsyncQueryScheduler.class + .getClassLoader() + .getResourceAsStream(SCHEDULER_INDEX_MAPPING_FILE_NAME); + InputStream settingsFileStream = + OpenSearchAsyncQueryScheduler.class + .getClassLoader() + .getResourceAsStream(SCHEDULER_INDEX_SETTINGS_FILE_NAME); + CreateIndexRequest createIndexRequest = new CreateIndexRequest(SCHEDULER_INDEX_NAME); + createIndexRequest.mapping( + IOUtils.toString(mappingFileStream, StandardCharsets.UTF_8), XContentType.YAML); + createIndexRequest.settings( + IOUtils.toString(settingsFileStream, StandardCharsets.UTF_8), XContentType.YAML); + ActionFuture createIndexResponseActionFuture = + client.admin().indices().create(createIndexRequest); + CreateIndexResponse createIndexResponse = createIndexResponseActionFuture.actionGet(); + + if (createIndexResponse.isAcknowledged()) { + LOG.debug("Index: {} creation Acknowledged", SCHEDULER_INDEX_NAME); + } else { + throw new RuntimeException("Index creation is not acknowledged."); + } + } catch (Throwable e) { + 
LOG.error("Error creating index: {}", SCHEDULER_INDEX_NAME, e); + throw new RuntimeException( + "Internal server error while creating " + + SCHEDULER_INDEX_NAME + + " index: " + + e.getMessage(), + e); + } + } + + private void assertIndexExists() { + if (!this.clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)) { + throw new IllegalStateException("Job index does not exist."); + } + } + + /** Returns the job runner instance for the scheduler. */ + public static ScheduledJobRunner getJobRunner() { + return OpenSearchRefreshIndexJob.getJobRunnerInstance(); + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchRefreshIndexJobRequestParser.java b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchRefreshIndexJobRequestParser.java new file mode 100644 index 0000000000..0422e7c015 --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchRefreshIndexJobRequestParser.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler; + +import java.io.IOException; +import java.time.Instant; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; +import org.opensearch.jobscheduler.spi.ScheduledJobParser; +import org.opensearch.jobscheduler.spi.schedule.ScheduleParser; +import org.opensearch.sql.spark.scheduler.model.OpenSearchRefreshIndexJobRequest; + +public class OpenSearchRefreshIndexJobRequestParser { + + private static Instant parseInstantValue(XContentParser parser) throws IOException { + if (XContentParser.Token.VALUE_NULL.equals(parser.currentToken())) { + return null; + } + if (parser.currentToken().isValue()) { + return Instant.ofEpochMilli(parser.longValue()); + } + XContentParserUtils.throwUnknownToken(parser.currentToken(), parser.getTokenLocation()); + return null; + } + + public static ScheduledJobParser 
getJobParser() { + return (parser, id, jobDocVersion) -> { + OpenSearchRefreshIndexJobRequest.OpenSearchRefreshIndexJobRequestBuilder builder = + OpenSearchRefreshIndexJobRequest.builder(); + XContentParserUtils.ensureExpectedToken( + XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + while (!parser.nextToken().equals(XContentParser.Token.END_OBJECT)) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case OpenSearchRefreshIndexJobRequest.JOB_NAME_FIELD: + builder.jobName(parser.text()); + break; + case OpenSearchRefreshIndexJobRequest.JOB_TYPE_FIELD: + builder.jobType(parser.text()); + break; + case OpenSearchRefreshIndexJobRequest.ENABLED_FIELD: + builder.enabled(parser.booleanValue()); + break; + case OpenSearchRefreshIndexJobRequest.ENABLED_TIME_FIELD: + builder.enabledTime(parseInstantValue(parser)); + break; + case OpenSearchRefreshIndexJobRequest.LAST_UPDATE_TIME_FIELD: + builder.lastUpdateTime(parseInstantValue(parser)); + break; + case OpenSearchRefreshIndexJobRequest.SCHEDULE_FIELD: + builder.schedule(ScheduleParser.parse(parser)); + break; + case OpenSearchRefreshIndexJobRequest.LOCK_DURATION_SECONDS: + builder.lockDurationSeconds(parser.longValue()); + break; + case OpenSearchRefreshIndexJobRequest.JITTER: + builder.jitter(parser.doubleValue()); + break; + default: + XContentParserUtils.throwUnknownToken(parser.currentToken(), parser.getTokenLocation()); + } + } + return builder.build(); + }; + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJob.java b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJob.java new file mode 100644 index 0000000000..e465a8790f --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJob.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.opensearch.sql.spark.scheduler.job; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; +import org.opensearch.plugins.Plugin; +import org.opensearch.sql.spark.scheduler.model.OpenSearchRefreshIndexJobRequest; +import org.opensearch.threadpool.ThreadPool; + +/** + * The job runner class for scheduling refresh index query. + * + *

The job runner should be a singleton class if it uses the OpenSearch client or other objects + * passed from OpenSearch, because when the job runner is registered with the JobScheduler plugin, + * OpenSearch has not yet invoked the plugin's createComponents() method. That is, the plugin is not + * completely initialized, and the OpenSearch {@link org.opensearch.client.Client}, {@link + * ClusterService} and other objects are not yet available to the plugin or to this job runner. + * + *

So we have to move this job runner initialization to {@link Plugin} createComponents() method, + * and using singleton job runner to ensure we register a usable job runner instance to JobScheduler + * plugin. + */ +public class OpenSearchRefreshIndexJob implements ScheduledJobRunner { + + private static final Logger log = LogManager.getLogger(OpenSearchRefreshIndexJob.class); + + public static OpenSearchRefreshIndexJob INSTANCE = new OpenSearchRefreshIndexJob(); + + public static OpenSearchRefreshIndexJob getJobRunnerInstance() { + return INSTANCE; + } + + private ClusterService clusterService; + private ThreadPool threadPool; + private Client client; + + private OpenSearchRefreshIndexJob() { + // Singleton class, use getJobRunnerInstance method instead of constructor + } + + public void setClusterService(ClusterService clusterService) { + this.clusterService = clusterService; + } + + public void setThreadPool(ThreadPool threadPool) { + this.threadPool = threadPool; + } + + public void setClient(Client client) { + this.client = client; + } + + @Override + public void runJob(ScheduledJobParameter jobParameter, JobExecutionContext context) { + if (!(jobParameter instanceof OpenSearchRefreshIndexJobRequest)) { + throw new IllegalStateException( + "Job parameter is not instance of OpenSearchRefreshIndexJobRequest, type: " + + jobParameter.getClass().getCanonicalName()); + } + + if (this.clusterService == null) { + throw new IllegalStateException("ClusterService is not initialized."); + } + + if (this.threadPool == null) { + throw new IllegalStateException("ThreadPool is not initialized."); + } + + if (this.client == null) { + throw new IllegalStateException("Client is not initialized."); + } + + Runnable runnable = + () -> { + doRefresh(jobParameter.getName()); + }; + threadPool.generic().submit(runnable); + } + + void doRefresh(String refreshIndex) { + // TODO: add logic to refresh index + log.info("Scheduled refresh index job on : " + refreshIndex); + } +} diff --git 
a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequest.java b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequest.java new file mode 100644 index 0000000000..7eaa4e2d29 --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequest.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler.model; + +import java.io.IOException; +import java.time.Instant; +import lombok.Builder; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.schedule.Schedule; + +/** Represents a job request to refresh index. */ +@Builder +public class OpenSearchRefreshIndexJobRequest implements ScheduledJobParameter { + // Constant fields for JSON serialization + public static final String JOB_NAME_FIELD = "jobName"; + public static final String JOB_TYPE_FIELD = "jobType"; + public static final String LAST_UPDATE_TIME_FIELD = "lastUpdateTime"; + public static final String LAST_UPDATE_TIME_FIELD_READABLE = "last_update_time_field"; + public static final String SCHEDULE_FIELD = "schedule"; + public static final String ENABLED_TIME_FIELD = "enabledTime"; + public static final String ENABLED_TIME_FIELD_READABLE = "enabled_time_field"; + public static final String LOCK_DURATION_SECONDS = "lockDurationSeconds"; + public static final String JITTER = "jitter"; + public static final String ENABLED_FIELD = "enabled"; + + // name is doc id + private final String jobName; + private final String jobType; + private final Schedule schedule; + private final boolean enabled; + private final Instant lastUpdateTime; + private final Instant enabledTime; + private final Long lockDurationSeconds; + private final Double 
jitter; + + @Override + public String getName() { + return jobName; + } + + public String getJobType() { + return jobType; + } + + @Override + public Schedule getSchedule() { + return schedule; + } + + @Override + public boolean isEnabled() { + return enabled; + } + + @Override + public Instant getLastUpdateTime() { + return lastUpdateTime; + } + + @Override + public Instant getEnabledTime() { + return enabledTime; + } + + @Override + public Long getLockDurationSeconds() { + return lockDurationSeconds; + } + + @Override + public Double getJitter() { + return jitter; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) + throws IOException { + builder.startObject(); + builder.field(JOB_NAME_FIELD, getName()).field(ENABLED_FIELD, isEnabled()); + if (getSchedule() != null) { + builder.field(SCHEDULE_FIELD, getSchedule()); + } + if (getJobType() != null) { + builder.field(JOB_TYPE_FIELD, getJobType()); + } + if (getEnabledTime() != null) { + builder.timeField( + ENABLED_TIME_FIELD, ENABLED_TIME_FIELD_READABLE, getEnabledTime().toEpochMilli()); + } + builder.timeField( + LAST_UPDATE_TIME_FIELD, + LAST_UPDATE_TIME_FIELD_READABLE, + getLastUpdateTime().toEpochMilli()); + if (this.lockDurationSeconds != null) { + builder.field(LOCK_DURATION_SECONDS, this.lockDurationSeconds); + } + if (this.jitter != null) { + builder.field(JITTER, this.jitter); + } + builder.endObject(); + return builder; + } +} diff --git a/async-query/src/main/resources/async-query-scheduler-index-mapping.yml b/async-query/src/main/resources/async-query-scheduler-index-mapping.yml new file mode 100644 index 0000000000..36bd1b873e --- /dev/null +++ b/async-query/src/main/resources/async-query-scheduler-index-mapping.yml @@ -0,0 +1,41 @@ +--- +## +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +## + +# Schema file for the .async-query-scheduler index +# Also "dynamic" is set to "false" so that other fields cannot be added. 
+dynamic: false +properties: + name: + type: keyword + jobType: + type: keyword + lastUpdateTime: + type: date + format: epoch_millis + enabledTime: + type: date + format: epoch_millis + schedule: + properties: + initialDelay: + type: long + interval: + properties: + start_time: + type: date + format: "strict_date_time||epoch_millis" + period: + type: integer + unit: + type: keyword + enabled: + type: boolean + lockDurationSeconds: + type: long + null_value: -1 + jitter: + type: double + null_value: 0.0 \ No newline at end of file diff --git a/async-query/src/main/resources/async-query-scheduler-index-settings.yml b/async-query/src/main/resources/async-query-scheduler-index-settings.yml new file mode 100644 index 0000000000..386f1f4f34 --- /dev/null +++ b/async-query/src/main/resources/async-query-scheduler-index-settings.yml @@ -0,0 +1,11 @@ +--- +## +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +## + +# Settings file for the .async-query-scheduler index +index: + number_of_shards: "1" + auto_expand_replicas: "0-2" + number_of_replicas: "0" \ No newline at end of file diff --git a/async-query/src/test/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQuerySchedulerTest.java b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQuerySchedulerTest.java new file mode 100644 index 0000000000..de86f111f3 --- /dev/null +++ b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQuerySchedulerTest.java @@ -0,0 +1,434 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static 
org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler.SCHEDULER_INDEX_NAME; + +import java.io.IOException; +import java.time.Instant; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.index.engine.DocumentMissingException; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; +import org.opensearch.sql.spark.scheduler.model.OpenSearchRefreshIndexJobRequest; +import org.opensearch.threadpool.ThreadPool; + +public class OpenSearchAsyncQuerySchedulerTest { + + private static final String TEST_SCHEDULER_INDEX_NAME = "testQS"; + + private static final String TEST_JOB_ID = "testJob"; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private Client client; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private ClusterService clusterService; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private ThreadPool threadPool; + + @Mock 
private ActionFuture indexResponseActionFuture; + + @Mock private ActionFuture updateResponseActionFuture; + + @Mock private ActionFuture deleteResponseActionFuture; + + @Mock private ActionFuture createIndexResponseActionFuture; + + @Mock private IndexResponse indexResponse; + + @Mock private UpdateResponse updateResponse; + + private OpenSearchAsyncQueryScheduler scheduler; + + @BeforeEach + public void setup() { + MockitoAnnotations.openMocks(this); + scheduler = new OpenSearchAsyncQueryScheduler(); + scheduler.loadJobResource(client, clusterService, threadPool); + } + + @Test + public void testScheduleJob() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)) + .thenReturn(Boolean.FALSE); + when(client.admin().indices().create(any(CreateIndexRequest.class))) + .thenReturn(createIndexResponseActionFuture); + when(createIndexResponseActionFuture.actionGet()) + .thenReturn(new CreateIndexResponse(true, true, TEST_SCHEDULER_INDEX_NAME)); + when(client.index(any(IndexRequest.class))).thenReturn(indexResponseActionFuture); + + // Test the if case + when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); + when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED); + + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(TEST_JOB_ID) + .lastUpdateTime(Instant.now()) + .build(); + + scheduler.scheduleJob(request); + + // Verify index created + verify(client.admin().indices(), times(1)).create(ArgumentMatchers.any()); + + // Verify doc indexed + ArgumentCaptor captor = ArgumentCaptor.forClass(IndexRequest.class); + verify(client, times(1)).index(captor.capture()); + IndexRequest capturedRequest = captor.getValue(); + assertEquals(request.getName(), capturedRequest.id()); + assertEquals(WriteRequest.RefreshPolicy.IMMEDIATE, capturedRequest.getRefreshPolicy()); + } + + @Test + public void testScheduleJobWithExistingJob() { + 
when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)) + .thenReturn(Boolean.TRUE); + + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(TEST_JOB_ID) + .lastUpdateTime(Instant.now()) + .build(); + + when(client.index(any(IndexRequest.class))).thenThrow(VersionConflictEngineException.class); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> { + scheduler.scheduleJob(request); + }); + + verify(client, times(1)).index(ArgumentCaptor.forClass(IndexRequest.class).capture()); + assertEquals("A job already exists with name: testJob", exception.getMessage()); + } + + @Test + public void testScheduleJobWithExceptions() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)) + .thenReturn(Boolean.FALSE); + when(client.admin().indices().create(any(CreateIndexRequest.class))) + .thenReturn(createIndexResponseActionFuture); + when(createIndexResponseActionFuture.actionGet()) + .thenReturn(new CreateIndexResponse(true, true, TEST_SCHEDULER_INDEX_NAME)); + when(client.index(any(IndexRequest.class))).thenThrow(new RuntimeException("Test exception")); + + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(TEST_JOB_ID) + .lastUpdateTime(Instant.now()) + .build(); + + assertThrows(RuntimeException.class, () -> scheduler.scheduleJob(request)); + + when(client.index(any(IndexRequest.class))).thenReturn(indexResponseActionFuture); + when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); + when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND); + + RuntimeException exception = + assertThrows(RuntimeException.class, () -> scheduler.scheduleJob(request)); + assertEquals("Schedule job failed with result : not_found", exception.getMessage()); + } + + @Test + public void testUnscheduleJob() throws IOException { + 
when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(true); + + when(updateResponseActionFuture.actionGet()).thenReturn(updateResponse); + when(updateResponse.getResult()).thenReturn(DocWriteResponse.Result.UPDATED); + + when(client.update(any(UpdateRequest.class))).thenReturn(updateResponseActionFuture); + + scheduler.unscheduleJob(TEST_JOB_ID); + + ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateRequest.class); + verify(client).update(captor.capture()); + + UpdateRequest capturedRequest = captor.getValue(); + assertEquals(TEST_JOB_ID, capturedRequest.id()); + assertEquals(WriteRequest.RefreshPolicy.IMMEDIATE, capturedRequest.getRefreshPolicy()); + + // Reset the captor for the next verification + captor = ArgumentCaptor.forClass(UpdateRequest.class); + + when(updateResponse.getResult()).thenReturn(DocWriteResponse.Result.NOOP); + scheduler.unscheduleJob(TEST_JOB_ID); + + verify(client, times(2)).update(captor.capture()); + capturedRequest = captor.getValue(); + assertEquals(TEST_JOB_ID, capturedRequest.id()); + assertEquals(WriteRequest.RefreshPolicy.IMMEDIATE, capturedRequest.getRefreshPolicy()); + } + + @Test + public void testUnscheduleJobWithIndexNotFound() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(false); + + assertThrows(IllegalStateException.class, () -> scheduler.unscheduleJob(TEST_JOB_ID)); + } + + @Test + public void testUpdateJob() throws IOException { + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(TEST_JOB_ID) + .lastUpdateTime(Instant.now()) + .build(); + + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(true); + + when(updateResponseActionFuture.actionGet()).thenReturn(updateResponse); + when(updateResponse.getResult()).thenReturn(DocWriteResponse.Result.UPDATED); + + when(client.update(any(UpdateRequest.class))).thenReturn(updateResponseActionFuture); + + 
scheduler.updateJob(request); + + ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateRequest.class); + verify(client).update(captor.capture()); + + UpdateRequest capturedRequest = captor.getValue(); + assertEquals(request.getName(), capturedRequest.id()); + assertEquals(WriteRequest.RefreshPolicy.IMMEDIATE, capturedRequest.getRefreshPolicy()); + } + + @Test + public void testUpdateJobWithIndexNotFound() { + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(TEST_JOB_ID) + .lastUpdateTime(Instant.now()) + .build(); + + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(false); + + assertThrows(IllegalStateException.class, () -> scheduler.updateJob(request)); + } + + @Test + public void testUpdateJobWithExceptions() { + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(TEST_JOB_ID) + .lastUpdateTime(Instant.now()) + .build(); + + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(true); + when(client.update(any(UpdateRequest.class))) + .thenThrow(new DocumentMissingException(null, null)); + + IllegalArgumentException exception1 = + assertThrows( + IllegalArgumentException.class, + () -> { + scheduler.updateJob(request); + }); + + assertEquals("Job: testJob doesn't exist", exception1.getMessage()); + + when(client.update(any(UpdateRequest.class))).thenThrow(new RuntimeException("Test exception")); + + RuntimeException exception2 = + assertThrows( + RuntimeException.class, + () -> { + scheduler.updateJob(request); + }); + + assertEquals("java.lang.RuntimeException: Test exception", exception2.getMessage()); + + when(client.update(any(UpdateRequest.class))).thenReturn(updateResponseActionFuture); + when(updateResponseActionFuture.actionGet()).thenReturn(updateResponse); + when(updateResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND); + + RuntimeException exception = + 
assertThrows(RuntimeException.class, () -> scheduler.updateJob(request)); + assertEquals("Update job failed with result : not_found", exception.getMessage()); + } + + @Test + public void testRemoveJob() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(true); + + DeleteResponse deleteResponse = mock(DeleteResponse.class); + when(deleteResponseActionFuture.actionGet()).thenReturn(deleteResponse); + when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.DELETED); + + when(client.delete(any(DeleteRequest.class))).thenReturn(deleteResponseActionFuture); + + scheduler.removeJob(TEST_JOB_ID); + + ArgumentCaptor captor = ArgumentCaptor.forClass(DeleteRequest.class); + verify(client).delete(captor.capture()); + + DeleteRequest capturedRequest = captor.getValue(); + assertEquals(TEST_JOB_ID, capturedRequest.id()); + assertEquals(WriteRequest.RefreshPolicy.IMMEDIATE, capturedRequest.getRefreshPolicy()); + } + + @Test + public void testRemoveJobWithIndexNotFound() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(false); + + assertThrows(IllegalStateException.class, () -> scheduler.removeJob(TEST_JOB_ID)); + } + + @Test + public void testCreateAsyncQuerySchedulerIndex() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(false); + + CreateIndexResponse createIndexResponse = mock(CreateIndexResponse.class); + when(createIndexResponseActionFuture.actionGet()).thenReturn(createIndexResponse); + when(createIndexResponse.isAcknowledged()).thenReturn(true); + + when(client.admin().indices().create(any(CreateIndexRequest.class))) + .thenReturn(createIndexResponseActionFuture); + + scheduler.createAsyncQuerySchedulerIndex(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(CreateIndexRequest.class); + verify(client.admin().indices()).create(captor.capture()); + + CreateIndexRequest capturedRequest = captor.getValue(); + 
assertEquals(SCHEDULER_INDEX_NAME, capturedRequest.index()); + } + + @Test + public void testCreateAsyncQuerySchedulerIndexFailure() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(false); + + when(client.admin().indices().create(any(CreateIndexRequest.class))) + .thenThrow(new RuntimeException("Error creating index")); + + RuntimeException exception = + assertThrows( + RuntimeException.class, + () -> { + scheduler.createAsyncQuerySchedulerIndex(); + }); + + assertEquals( + "Internal server error while creating .async-query-scheduler index: Error creating index", + exception.getMessage()); + + when(client.admin().indices().create(any(CreateIndexRequest.class))) + .thenReturn(createIndexResponseActionFuture); + Mockito.when(createIndexResponseActionFuture.actionGet()) + .thenReturn(new CreateIndexResponse(false, false, SCHEDULER_INDEX_NAME)); + + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(TEST_JOB_ID) + .lastUpdateTime(Instant.now()) + .build(); + + RuntimeException runtimeException = + Assertions.assertThrows(RuntimeException.class, () -> scheduler.scheduleJob(request)); + Assertions.assertEquals( + "Internal server error while creating .async-query-scheduler index: Index creation is not" + + " acknowledged.", + runtimeException.getMessage()); + } + + @Test + public void testUpdateJobNotFound() { + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(TEST_JOB_ID) + .lastUpdateTime(Instant.now()) + .build(); + + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(true); + + when(client.update(any(UpdateRequest.class))) + .thenThrow(new DocumentMissingException(null, null)); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> { + scheduler.updateJob(request); + }); + + assertEquals("Job: testJob doesn't exist", exception.getMessage()); + } + + 
@Test + public void testRemoveJobNotFound() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(true); + + DeleteResponse deleteResponse = mock(DeleteResponse.class); + when(deleteResponseActionFuture.actionGet()).thenReturn(deleteResponse); + when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND); + + when(client.delete(any(DeleteRequest.class))).thenReturn(deleteResponseActionFuture); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> { + scheduler.removeJob(TEST_JOB_ID); + }); + + assertEquals("Job : testJob doesn't exist", exception.getMessage()); + } + + @Test + public void testRemoveJobWithExceptions() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(true); + + when(client.delete(any(DeleteRequest.class))).thenThrow(new RuntimeException("Test exception")); + + assertThrows(RuntimeException.class, () -> scheduler.removeJob(TEST_JOB_ID)); + + DeleteResponse deleteResponse = mock(DeleteResponse.class); + when(client.delete(any(DeleteRequest.class))).thenReturn(deleteResponseActionFuture); + when(deleteResponseActionFuture.actionGet()).thenReturn(deleteResponse); + when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.NOOP); + + RuntimeException runtimeException = + Assertions.assertThrows(RuntimeException.class, () -> scheduler.removeJob(TEST_JOB_ID)); + Assertions.assertEquals("Remove job failed with result : noop", runtimeException.getMessage()); + } + + @Test + public void testGetJobRunner() { + ScheduledJobRunner jobRunner = OpenSearchAsyncQueryScheduler.getJobRunner(); + assertNotNull(jobRunner); + } +} diff --git a/async-query/src/test/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJobTest.java b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJobTest.java new file mode 100644 index 0000000000..cbf137997e --- /dev/null +++ 
b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJobTest.java @@ -0,0 +1,145 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler.job; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +import java.time.Instant; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.sql.spark.scheduler.model.OpenSearchRefreshIndexJobRequest; +import org.opensearch.threadpool.ThreadPool; + +public class OpenSearchRefreshIndexJobTest { + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private ClusterService clusterService; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private ThreadPool threadPool; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private Client client; + + @Mock private JobExecutionContext context; + + private OpenSearchRefreshIndexJob jobRunner; + + private OpenSearchRefreshIndexJob spyJobRunner; + + @BeforeEach + public void setup() { + MockitoAnnotations.openMocks(this); + jobRunner = OpenSearchRefreshIndexJob.getJobRunnerInstance(); + jobRunner.setClient(null); + jobRunner.setClusterService(null); + jobRunner.setThreadPool(null); + } + + @Test + public void testRunJobWithCorrectParameter() { + spyJobRunner = spy(jobRunner); + 
spyJobRunner.setClusterService(clusterService); + spyJobRunner.setThreadPool(threadPool); + spyJobRunner.setClient(client); + + OpenSearchRefreshIndexJobRequest jobParameter = + OpenSearchRefreshIndexJobRequest.builder() + .jobName("testJob") + .lastUpdateTime(Instant.now()) + .lockDurationSeconds(10L) + .build(); + + spyJobRunner.runJob(jobParameter, context); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); + verify(threadPool.generic()).submit(captor.capture()); + + Runnable runnable = captor.getValue(); + runnable.run(); + + verify(spyJobRunner).doRefresh(eq(jobParameter.getName())); + } + + @Test + public void testRunJobWithIncorrectParameter() { + jobRunner = OpenSearchRefreshIndexJob.getJobRunnerInstance(); + jobRunner.setClusterService(clusterService); + jobRunner.setThreadPool(threadPool); + jobRunner.setClient(client); + + ScheduledJobParameter wrongParameter = mock(ScheduledJobParameter.class); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> jobRunner.runJob(wrongParameter, context), + "Expected IllegalStateException but no exception was thrown"); + + assertEquals( + "Job parameter is not instance of OpenSearchRefreshIndexJobRequest, type: " + + wrongParameter.getClass().getCanonicalName(), + exception.getMessage()); + } + + @Test + public void testRunJobWithUninitializedServices() { + OpenSearchRefreshIndexJobRequest jobParameter = + OpenSearchRefreshIndexJobRequest.builder() + .jobName("testJob") + .lastUpdateTime(Instant.now()) + .build(); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> jobRunner.runJob(jobParameter, context), + "Expected IllegalStateException but no exception was thrown"); + assertEquals("ClusterService is not initialized.", exception.getMessage()); + + jobRunner.setClusterService(clusterService); + + exception = + assertThrows( + IllegalStateException.class, + () -> jobRunner.runJob(jobParameter, context), + "Expected 
IllegalStateException but no exception was thrown"); + assertEquals("ThreadPool is not initialized.", exception.getMessage()); + + jobRunner.setThreadPool(threadPool); + + exception = + assertThrows( + IllegalStateException.class, + () -> jobRunner.runJob(jobParameter, context), + "Expected IllegalStateException but no exception was thrown"); + assertEquals("Client is not initialized.", exception.getMessage()); + } + + @Test + public void testGetJobRunnerInstanceMultipleCalls() { + OpenSearchRefreshIndexJob instance1 = OpenSearchRefreshIndexJob.getJobRunnerInstance(); + OpenSearchRefreshIndexJob instance2 = OpenSearchRefreshIndexJob.getJobRunnerInstance(); + OpenSearchRefreshIndexJob instance3 = OpenSearchRefreshIndexJob.getJobRunnerInstance(); + + assertSame(instance1, instance2); + assertSame(instance2, instance3); + } +} diff --git a/async-query/src/test/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequestTest.java b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequestTest.java new file mode 100644 index 0000000000..108f1acfd5 --- /dev/null +++ b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequestTest.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler.model; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import org.junit.jupiter.api.Test; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; + +public class OpenSearchRefreshIndexJobRequestTest { + + @Test + public void 
testBuilderAndGetterMethods() { + Instant now = Instant.now(); + IntervalSchedule schedule = new IntervalSchedule(now, 1, ChronoUnit.MINUTES); + + OpenSearchRefreshIndexJobRequest jobRequest = + OpenSearchRefreshIndexJobRequest.builder() + .jobName("testJob") + .jobType("testType") + .schedule(schedule) + .enabled(true) + .lastUpdateTime(now) + .enabledTime(now) + .lockDurationSeconds(60L) + .jitter(0.1) + .build(); + + assertEquals("testJob", jobRequest.getName()); + assertEquals("testType", jobRequest.getJobType()); + assertEquals(schedule, jobRequest.getSchedule()); + assertTrue(jobRequest.isEnabled()); + assertEquals(now, jobRequest.getLastUpdateTime()); + assertEquals(now, jobRequest.getEnabledTime()); + assertEquals(60L, jobRequest.getLockDurationSeconds()); + assertEquals(0.1, jobRequest.getJitter()); + } + + @Test + public void testToXContent() throws IOException { + Instant now = Instant.now(); + IntervalSchedule schedule = new IntervalSchedule(now, 1, ChronoUnit.MINUTES); + + OpenSearchRefreshIndexJobRequest jobRequest = + OpenSearchRefreshIndexJobRequest.builder() + .jobName("testJob") + .jobType("testType") + .schedule(schedule) + .enabled(true) + .lastUpdateTime(now) + .enabledTime(now) + .lockDurationSeconds(60L) + .jitter(0.1) + .build(); + + XContentBuilder builder = XContentFactory.jsonBuilder().prettyPrint(); + jobRequest.toXContent(builder, EMPTY_PARAMS); + String jsonString = builder.toString(); + + assertTrue(jsonString.contains("\"jobName\" : \"testJob\"")); + assertTrue(jsonString.contains("\"jobType\" : \"testType\"")); + assertTrue(jsonString.contains("\"start_time\" : " + now.toEpochMilli())); + assertTrue(jsonString.contains("\"period\" : 1")); + assertTrue(jsonString.contains("\"unit\" : \"Minutes\"")); + assertTrue(jsonString.contains("\"enabled\" : true")); + assertTrue(jsonString.contains("\"lastUpdateTime\" : " + now.toEpochMilli())); + assertTrue(jsonString.contains("\"enabledTime\" : " + now.toEpochMilli())); + 
assertTrue(jsonString.contains("\"lockDurationSeconds\" : 60")); + assertTrue(jsonString.contains("\"jitter\" : 0.1")); + } +} diff --git a/build.gradle b/build.gradle index b3e09d7b50..702d6f478a 100644 --- a/build.gradle +++ b/build.gradle @@ -50,6 +50,7 @@ buildscript { return "https://github.com/prometheus/prometheus/releases/download/v${prometheus_binary_version}/prometheus-${prometheus_binary_version}."+ getOSFamilyType() + "-" + getArchType() + ".tar.gz" } aws_java_sdk_version = "1.12.651" + guava_version = "32.1.3-jre" } repositories { @@ -192,7 +193,7 @@ configurations.all { exclude group: "commons-logging", module: "commons-logging" // enforce 1.1.3, https://www.whitesourcesoftware.com/vulnerability-database/WS-2019-0379 resolutionStrategy.force 'commons-codec:commons-codec:1.13' - resolutionStrategy.force 'com.google.guava:guava:32.0.1-jre' + resolutionStrategy.force "com.google.guava:guava:${guava_version}" } // updateVersion: Task to auto increment to the next development iteration diff --git a/common/build.gradle b/common/build.gradle index b4ee98a5b7..15c48dd6b3 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -34,7 +34,7 @@ repositories { dependencies { api "org.antlr:antlr4-runtime:4.7.1" - api group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' + api group: 'com.google.guava', name: 'guava', version: "${guava_version}" api group: 'org.apache.logging.log4j', name: 'log4j-core', version:"${versions.log4j}" api group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0' api group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' @@ -46,7 +46,7 @@ dependencies { testImplementation group: 'junit', name: 'junit', version: '4.13.2' testImplementation group: 'org.assertj', name: 'assertj-core', version: '3.9.1' - testImplementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' + testImplementation group: 'com.google.guava', name: 'guava', version: "${guava_version}" testImplementation 
group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1' testImplementation('org.junit.jupiter:junit-jupiter:5.9.3') testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0' diff --git a/core/build.gradle b/core/build.gradle index 655e7d92c2..f36777030c 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -46,7 +46,7 @@ pitest { } dependencies { - api group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' + api group: 'com.google.guava', name: 'guava', version: "${guava_version}" api group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0' api group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' api group: 'com.facebook.presto', name: 'presto-matching', version: '0.240' diff --git a/doctest/build.gradle b/doctest/build.gradle index ec5a26b52b..a125a4f336 100644 --- a/doctest/build.gradle +++ b/doctest/build.gradle @@ -5,6 +5,8 @@ import org.opensearch.gradle.testclusters.RunTask +import java.util.concurrent.Callable + plugins { id 'base' id 'com.wiredforcode.spawn' @@ -109,6 +111,10 @@ if (version_tokens.length > 1) { String mlCommonsRemoteFile = 'https://ci.opensearch.org/ci/dbc/distribution-build-opensearch/' + opensearch_no_snapshot + '/latest/linux/x64/tar/builds/opensearch/plugins/opensearch-ml-' + opensearch_build + '.zip' String mlCommonsPlugin = 'opensearch-ml' +String bwcOpenSearchJSDownload = 'https://ci.opensearch.org/ci/dbc/distribution-build-opensearch/' + opensearch_no_snapshot + '/latest/linux/x64/tar/builds/' + + 'opensearch/plugins/opensearch-job-scheduler-' + opensearch_build + '.zip' +String jsPlugin = 'opensearch-job-scheduler' + testClusters { docTestCluster { // Disable loading of `ML-commons` plugin, because it might be unavailable (not released yet). 
@@ -133,6 +139,7 @@ testClusters { } })) */ + plugin(getJobSchedulerPlugin(jsPlugin, bwcOpenSearchJSDownload)) plugin ':opensearch-sql-plugin' testDistribution = 'archive' } @@ -159,3 +166,49 @@ spotless { googleJavaFormat('1.17.0').reflowLongStrings().groupArtifact('com.google.googlejavaformat:google-java-format') } } + +def getJobSchedulerPlugin(String jsPlugin, String bwcOpenSearchJSDownload) { + return provider(new Callable() { + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + // Use absolute paths + String basePath = new File('.').getCanonicalPath() + File dir = new File(basePath + File.separator + 'doctest' + File.separator + jsPlugin) + + // Log the directory path for debugging + println("Creating directory: " + dir.getAbsolutePath()) + + // Create directory if it doesn't exist + if (!dir.exists()) { + if (!dir.mkdirs()) { + throw new IOException("Failed to create directory: " + dir.getAbsolutePath()) + } + } + + // Define the file path + File f = new File(dir, jsPlugin + '-' + opensearch_build + '.zip') + + // Download file if it doesn't exist + if (!f.exists()) { + println("Downloading file from: " + bwcOpenSearchJSDownload) + println("Saving to file: " + f.getAbsolutePath()) + + new URL(bwcOpenSearchJSDownload).withInputStream { ins -> + f.withOutputStream { it << ins } + } + } + + // Check if the file was created successfully + if (!f.exists()) { + throw new FileNotFoundException("File was not created: " + f.getAbsolutePath()) + } + + return fileTree(f.getParent()).matching { include f.getName() }.singleFile + } + } + } + }) +} diff --git a/integ-test/build.gradle b/integ-test/build.gradle index 93153cf737..1acacdb4a5 100644 --- a/integ-test/build.gradle +++ b/integ-test/build.gradle @@ -80,7 +80,6 @@ ext { var projectAbsPath = projectDir.getAbsolutePath() File downloadedSecurityPlugin = Paths.get(projectAbsPath, 'bin', 'opensearch-security-snapshot.zip').toFile() - configureSecurityPlugin 
= { OpenSearchCluster cluster -> cluster.getNodes().forEach { node -> @@ -138,6 +137,10 @@ ext { cluster.plugin provider((Callable) (() -> (RegularFile) (() -> downloadedSecurityPlugin))) } + + bwcOpenSearchJSDownload = 'https://ci.opensearch.org/ci/dbc/distribution-build-opensearch/' + baseVersion + '/latest/linux/x64/tar/builds/' + + 'opensearch/plugins/opensearch-job-scheduler-' + bwcVersion + '.zip' + bwcJobSchedulerPath = bwcFilePath + "job-scheduler/" } tasks.withType(licenseHeaders.class) { @@ -153,7 +156,6 @@ configurations.all { resolutionStrategy.force "commons-logging:commons-logging:1.2" // enforce 1.1.3, https://www.whitesourcesoftware.com/vulnerability-database/WS-2019-0379 resolutionStrategy.force 'commons-codec:commons-codec:1.13' - resolutionStrategy.force 'com.google.guava:guava:32.0.1-jre' resolutionStrategy.force "com.fasterxml.jackson.core:jackson-core:${versions.jackson}" resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:${versions.jackson}" resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-smile:${versions.jackson}" @@ -166,6 +168,7 @@ configurations.all { resolutionStrategy.force "joda-time:joda-time:2.10.12" resolutionStrategy.force "org.slf4j:slf4j-api:1.7.36" resolutionStrategy.force "com.amazonaws:aws-java-sdk-core:${aws_java_sdk_version}" + resolutionStrategy.force "com.google.guava:guava:${guava_version}" } configurations { @@ -191,6 +194,7 @@ dependencies { testCompileOnly 'org.apiguardian:apiguardian-api:1.1.2' // Needed for BWC tests + zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}" zipArchive group: 'org.opensearch.plugin', name:'opensearch-sql-plugin', version: "${bwcVersion}-SNAPSHOT" } @@ -219,22 +223,42 @@ testClusters.all { } } +def getJobSchedulerPlugin() { + provider(new Callable() { + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + return 
configurations.zipArchive.asFileTree.matching { + include '**/opensearch-job-scheduler*' + }.singleFile + } + } + } + }) +} + testClusters { integTest { testDistribution = 'archive' + plugin(getJobSchedulerPlugin()) plugin ":opensearch-sql-plugin" setting "plugins.query.datasources.encryption.masterkey", "1234567812345678" } remoteCluster { testDistribution = 'archive' + plugin(getJobSchedulerPlugin()) plugin ":opensearch-sql-plugin" } integTestWithSecurity { testDistribution = 'archive' + plugin(getJobSchedulerPlugin()) plugin ":opensearch-sql-plugin" } remoteIntegTestWithSecurity { testDistribution = 'archive' + plugin(getJobSchedulerPlugin()) plugin ":opensearch-sql-plugin" } } @@ -502,6 +526,24 @@ task comparisonTest(type: RestIntegTestTask) { testDistribution = "ARCHIVE" versions = [baseVersion, opensearch_version] numberOfNodes = 3 + plugin(provider(new Callable(){ + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + if (new File("$project.rootDir/$bwcFilePath/job-scheduler/$bwcVersion").exists()) { + project.delete(files("$project.rootDir/$bwcFilePath/job-scheduler/$bwcVersion")) + } + project.mkdir bwcJobSchedulerPath + bwcVersion + ant.get(src: bwcOpenSearchJSDownload, + dest: bwcJobSchedulerPath + bwcVersion, + httpusecaches: false) + return fileTree(bwcJobSchedulerPath + bwcVersion).getSingleFile() + } + } + } + })) plugin(provider(new Callable(){ @Override RegularFile call() throws Exception { @@ -522,17 +564,18 @@ task comparisonTest(type: RestIntegTestTask) { } List> plugins = [ - provider(new Callable() { - @Override - RegularFile call() throws Exception { - return new RegularFile() { - @Override - File getAsFile() { - return fileTree(bwcFilePath + project.version).getSingleFile() + getJobSchedulerPlugin(), + provider(new Callable() { + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + return fileTree(bwcFilePath + 
project.version).getSingleFile() + } } } - } - }) + }) ] // Creates 2 test clusters with 3 nodes of the old version. diff --git a/legacy/build.gradle b/legacy/build.gradle index 0467db183d..e3ddf27066 100644 --- a/legacy/build.gradle +++ b/legacy/build.gradle @@ -107,7 +107,7 @@ dependencies { because 'https://www.whitesourcesoftware.com/vulnerability-database/WS-2019-0379' } } - implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' + implementation group: 'com.google.guava', name: 'guava', version: "${guava_version}" implementation group: 'org.json', name: 'json', version:'20231013' implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0' implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' diff --git a/plugin/build.gradle b/plugin/build.gradle index 710d81ed0a..7ebd0ad2d9 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -48,6 +48,7 @@ opensearchplugin { name 'opensearch-sql' description 'OpenSearch SQL' classname 'org.opensearch.sql.plugin.SQLPlugin' + extendedPlugins = ['opensearch-job-scheduler'] licenseFile rootProject.file("LICENSE.txt") noticeFile rootProject.file("NOTICE") } @@ -98,7 +99,8 @@ configurations.all { resolutionStrategy.force "com.fasterxml.jackson.core:jackson-core:${versions.jackson}" // enforce 1.1.3, https://www.whitesourcesoftware.com/vulnerability-database/WS-2019-0379 resolutionStrategy.force 'commons-codec:commons-codec:1.13' - resolutionStrategy.force 'com.google.guava:guava:32.0.1-jre' + resolutionStrategy.force "com.google.guava:guava:${guava_version}" + resolutionStrategy.force 'com.google.guava:failureaccess:1.0.2' resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:${versions.jackson}" resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-smile:${versions.jackson}" resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:${versions.jackson}" @@ -139,6 
+141,10 @@ spotless { } dependencies { + compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}" + compileOnly "com.google.guava:guava:${guava_version}" + compileOnly 'com.google.guava:failureaccess:1.0.2' + api "com.fasterxml.jackson.core:jackson-core:${versions.jackson}" api "com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}" api "com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}" @@ -204,11 +210,10 @@ dependencyLicenses.enabled = false // enable testingConventions check will cause errors like: "Classes ending with [Tests] must subclass [LuceneTestCase]" testingConventions.enabled = false -// TODO: need to verify the thirdPartyAudi +// TODO: need to verify the thirdPartyAudit // currently it complains missing classes like ibatis, mysql etc, should not be a problem thirdPartyAudit.enabled = false - apply plugin: 'com.netflix.nebula.ospackage' validateNebulaPom.enabled = false diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index b86ab9218a..a1b1e32955 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -42,6 +42,9 @@ import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; import org.opensearch.indices.SystemIndexDescriptor; +import org.opensearch.jobscheduler.spi.JobSchedulerExtension; +import org.opensearch.jobscheduler.spi.ScheduledJobParser; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.plugins.ScriptPlugin; @@ -91,6 +94,9 @@ import org.opensearch.sql.spark.flint.FlintIndexMetadataServiceImpl; import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction; +import 
org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler; +import org.opensearch.sql.spark.scheduler.OpenSearchRefreshIndexJobRequestParser; +import org.opensearch.sql.spark.scheduler.job.OpenSearchRefreshIndexJob; import org.opensearch.sql.spark.storage.SparkStorageFactory; import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction; import org.opensearch.sql.spark.transport.TransportCreateAsyncQueryRequestAction; @@ -105,7 +111,8 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; -public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin, SystemIndexPlugin { +public class SQLPlugin extends Plugin + implements ActionPlugin, ScriptPlugin, SystemIndexPlugin, JobSchedulerExtension { private static final Logger LOGGER = LogManager.getLogger(SQLPlugin.class); @@ -116,6 +123,7 @@ public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin, Sys private NodeClient client; private DataSourceServiceImpl dataSourceService; + private OpenSearchAsyncQueryScheduler asyncQueryScheduler; private Injector injector; public String name() { @@ -208,6 +216,8 @@ public Collection createComponents( this.client = (NodeClient) client; this.dataSourceService = createDataSourceService(); dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata()); + this.asyncQueryScheduler = new OpenSearchAsyncQueryScheduler(); + this.asyncQueryScheduler.loadJobResource(client, clusterService, threadPool); LocalClusterState.state().setClusterService(clusterService); LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings); LocalClusterState.state().setClient(client); @@ -243,6 +253,26 @@ public Collection createComponents( pluginSettings); } + @Override + public String getJobType() { + return OpenSearchAsyncQueryScheduler.SCHEDULER_PLUGIN_JOB_TYPE; + } + + @Override + public String getJobIndex() { + return OpenSearchAsyncQueryScheduler.SCHEDULER_INDEX_NAME; 
+ } + + @Override + public ScheduledJobRunner getJobRunner() { + return OpenSearchRefreshIndexJob.getJobRunnerInstance(); + } + + @Override + public ScheduledJobParser getJobParser() { + return OpenSearchRefreshIndexJobRequestParser.getJobParser(); + } + @Override public List> getExecutorBuilders(Settings settings) { return singletonList( diff --git a/plugin/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension b/plugin/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension new file mode 100644 index 0000000000..5337857c15 --- /dev/null +++ b/plugin/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension @@ -0,0 +1,6 @@ +# +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +# + +org.opensearch.sql.plugin.SQLPlugin \ No newline at end of file diff --git a/ppl/build.gradle b/ppl/build.gradle index d58882d5e8..2a3d6bdbf9 100644 --- a/ppl/build.gradle +++ b/ppl/build.gradle @@ -48,7 +48,7 @@ dependencies { runtimeOnly group: 'org.reflections', name: 'reflections', version: '0.9.12' implementation "org.antlr:antlr4-runtime:4.7.1" - implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' + implementation group: 'com.google.guava', name: 'guava', version: "${guava_version}" api group: 'org.json', name: 'json', version: '20231013' implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version:"${versions.log4j}" api project(':common') diff --git a/protocol/build.gradle b/protocol/build.gradle index 5bbff68e51..b5d7929041 100644 --- a/protocol/build.gradle +++ b/protocol/build.gradle @@ -30,7 +30,7 @@ plugins { } dependencies { - implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' + implementation group: 'com.google.guava', name: 'guava', version: "${guava_version}" implementation group: 'com.fasterxml.jackson.core', name: 'jackson-core', version: "${versions.jackson}" 
implementation group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: "${versions.jackson_databind}" implementation group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-cbor', version: "${versions.jackson}" diff --git a/sql/build.gradle b/sql/build.gradle index 81872e6035..10bb4b24bb 100644 --- a/sql/build.gradle +++ b/sql/build.gradle @@ -46,7 +46,7 @@ dependencies { antlr "org.antlr:antlr4:4.7.1" implementation "org.antlr:antlr4-runtime:4.7.1" - implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' + implementation group: 'com.google.guava', name: 'guava', version: "${guava_version}" implementation group: 'org.json', name: 'json', version:'20231013' implementation project(':common') implementation project(':core') From 82ef68e2b25c7c10740e74968bbe960c000c1cee Mon Sep 17 00:00:00 2001 From: panguixin Date: Thu, 1 Aug 2024 23:10:13 +0800 Subject: [PATCH 07/12] Support common format geo point (#2801) --------- Signed-off-by: panguixin --- .../sql/legacy/SQLIntegTestCase.java | 8 +- .../org/opensearch/sql/legacy/TestUtils.java | 5 ++ .../opensearch/sql/legacy/TestsConstants.java | 1 + .../opensearch/sql/sql/GeopointFormatsIT.java | 60 +++++++++++++ integ-test/src/test/resources/geopoints.json | 12 +++ .../geopoint_index_mapping.json | 9 ++ .../data/utils/OpenSearchJsonContent.java | 50 ++++------- .../value/OpenSearchExprValueFactory.java | 59 ++++++++++-- .../data/utils/OpenSearchJsonContentTest.java | 31 +++++++ .../value/OpenSearchExprValueFactoryTest.java | 89 +++++++++++++------ 10 files changed, 256 insertions(+), 68 deletions(-) create mode 100644 integ-test/src/test/java/org/opensearch/sql/sql/GeopointFormatsIT.java create mode 100644 integ-test/src/test/resources/geopoints.json create mode 100644 integ-test/src/test/resources/indexDefinitions/geopoint_index_mapping.json create mode 100644 opensearch/src/test/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContentTest.java diff --git 
a/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java b/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java index 63c44bf831..c6d15a305d 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java @@ -20,6 +20,7 @@ import static org.opensearch.sql.legacy.TestUtils.getDogs3IndexMapping; import static org.opensearch.sql.legacy.TestUtils.getEmployeeNestedTypeIndexMapping; import static org.opensearch.sql.legacy.TestUtils.getGameOfThronesIndexMapping; +import static org.opensearch.sql.legacy.TestUtils.getGeopointIndexMapping; import static org.opensearch.sql.legacy.TestUtils.getJoinTypeIndexMapping; import static org.opensearch.sql.legacy.TestUtils.getLocationIndexMapping; import static org.opensearch.sql.legacy.TestUtils.getMappingFile; @@ -724,7 +725,12 @@ public enum Index { TestsConstants.TEST_INDEX_NESTED_WITH_NULLS, "multi_nested", getNestedTypeIndexMapping(), - "src/test/resources/nested_with_nulls.json"); + "src/test/resources/nested_with_nulls.json"), + GEOPOINTS( + TestsConstants.TEST_INDEX_GEOPOINT, + "dates", + getGeopointIndexMapping(), + "src/test/resources/geopoints.json"); private final String name; private final String type; diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/TestUtils.java b/integ-test/src/test/java/org/opensearch/sql/legacy/TestUtils.java index 65cacf16d2..195dda0cbd 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/TestUtils.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/TestUtils.java @@ -245,6 +245,11 @@ public static String getDataTypeNonnumericIndexMapping() { return getMappingFile(mappingFile); } + public static String getGeopointIndexMapping() { + String mappingFile = "geopoint_index_mapping.json"; + return getMappingFile(mappingFile); + } + public static void loadBulk(Client client, String jsonPath, String defaultIndex) throws Exception { 
System.out.println(String.format("Loading file %s into opensearch cluster", jsonPath)); diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/TestsConstants.java b/integ-test/src/test/java/org/opensearch/sql/legacy/TestsConstants.java index 29bc9813fa..73838feb4f 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/TestsConstants.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/TestsConstants.java @@ -57,6 +57,7 @@ public class TestsConstants { public static final String TEST_INDEX_WILDCARD = TEST_INDEX + "_wildcard"; public static final String TEST_INDEX_MULTI_NESTED_TYPE = TEST_INDEX + "_multi_nested"; public static final String TEST_INDEX_NESTED_WITH_NULLS = TEST_INDEX + "_nested_with_nulls"; + public static final String TEST_INDEX_GEOPOINT = TEST_INDEX + "_geopoint"; public static final String DATASOURCES = ".ql-datasources"; public static final String DATE_FORMAT = "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"; diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/GeopointFormatsIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/GeopointFormatsIT.java new file mode 100644 index 0000000000..f25eeec241 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/GeopointFormatsIT.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; + +import java.io.IOException; +import java.util.Map; +import org.apache.commons.lang3.tuple.Pair; +import org.json.JSONArray; +import org.json.JSONObject; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.legacy.SQLIntegTestCase; + +public class GeopointFormatsIT extends SQLIntegTestCase { + + @Override + public void init() throws Exception { 
+ loadIndex(Index.GEOPOINTS); + } + + @Test + public void testReadingGeopoints() throws IOException { + String query = String.format("SELECT point FROM %s LIMIT 5", Index.GEOPOINTS.getName()); + JSONObject result = executeJdbcRequest(query); + verifySchema(result, schema("point", null, "geo_point")); + verifyDataRows( + result, + rows(Map.of("lon", 74, "lat", 40.71)), + rows(Map.of("lon", 74, "lat", 40.71)), + rows(Map.of("lon", 74, "lat", 40.71)), + rows(Map.of("lon", 74, "lat", 40.71)), + rows(Map.of("lon", 74, "lat", 40.71))); + } + + private static final double TOLERANCE = 1E-5; + + public void testReadingGeoHash() throws IOException { + String query = String.format("SELECT point FROM %s WHERE _id='6'", Index.GEOPOINTS.getName()); + JSONObject result = executeJdbcRequest(query); + verifySchema(result, schema("point", null, "geo_point")); + Pair point = getGeoValue(result); + assertEquals(40.71, point.getLeft(), TOLERANCE); + assertEquals(74, point.getRight(), TOLERANCE); + } + + private Pair getGeoValue(JSONObject result) { + JSONObject geoRaw = + (JSONObject) ((JSONArray) ((JSONArray) result.get("datarows")).get(0)).get(0); + double lat = geoRaw.getDouble("lat"); + double lon = geoRaw.getDouble("lon"); + return Pair.of(lat, lon); + } +} diff --git a/integ-test/src/test/resources/geopoints.json b/integ-test/src/test/resources/geopoints.json new file mode 100644 index 0000000000..95900fe811 --- /dev/null +++ b/integ-test/src/test/resources/geopoints.json @@ -0,0 +1,12 @@ +{"index": {"_id": "1"}} +{"point": {"lat": 40.71, "lon": 74.00}} +{"index": {"_id": "2"}} +{"point": "40.71,74.00"} +{"index": {"_id": "3"}} +{"point": [74.00, 40.71]} +{"index": {"_id": "4"}} +{"point": "POINT (74.00 40.71)"} +{"index": {"_id": "5"}} +{"point": {"type": "Point", "coordinates": [74.00, 40.71]}} +{"index": {"_id": "6"}} +{"point": "txhxegj0uyp3"} diff --git a/integ-test/src/test/resources/indexDefinitions/geopoint_index_mapping.json 
b/integ-test/src/test/resources/indexDefinitions/geopoint_index_mapping.json new file mode 100644 index 0000000000..61340530d8 --- /dev/null +++ b/integ-test/src/test/resources/indexDefinitions/geopoint_index_mapping.json @@ -0,0 +1,9 @@ +{ + "mappings": { + "properties": { + "point": { + "type": "geo_point" + } + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContent.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContent.java index bdb15428e1..4446c1f979 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContent.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContent.java @@ -7,11 +7,19 @@ import com.fasterxml.jackson.databind.JsonNode; import com.google.common.collect.Iterators; +import java.io.IOException; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.Map; import lombok.RequiredArgsConstructor; import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.OpenSearchParseException; +import org.opensearch.common.geo.GeoPoint; +import org.opensearch.common.geo.GeoUtils; +import org.opensearch.common.xcontent.json.JsonXContentParser; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; /** The Implementation of Content to represent {@link JsonNode}. 
*/ @RequiredArgsConstructor @@ -122,25 +130,17 @@ public Object objectValue() { @Override public Pair geoValue() { final JsonNode value = value(); - if (value.has("lat") && value.has("lon")) { - Double lat = 0d; - Double lon = 0d; - try { - lat = extractDoubleValue(value.get("lat")); - } catch (Exception exception) { - throw new IllegalStateException( - "latitude must be number value, but got value: " + value.get("lat")); - } - try { - lon = extractDoubleValue(value.get("lon")); - } catch (Exception exception) { - throw new IllegalStateException( - "longitude must be number value, but got value: " + value.get("lon")); - } - return Pair.of(lat, lon); - } else { - throw new IllegalStateException( - "geo point must in format of {\"lat\": number, \"lon\": number}"); + try (XContentParser parser = + new JsonXContentParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + value.traverse())) { + parser.nextToken(); + GeoPoint point = new GeoPoint(); + GeoUtils.parseGeoPoint(parser, point, true); + return Pair.of(point.getLat(), point.getLon()); + } catch (IOException ex) { + throw new OpenSearchParseException("error parsing geo point", ex); } } @@ -148,16 +148,4 @@ public Pair geoValue() { private JsonNode value() { return value; } - - /** Get doubleValue from JsonNode if possible. 
*/ - private Double extractDoubleValue(JsonNode node) { - if (node.isTextual()) { - return Double.valueOf(node.textValue()); - } - if (node.isNumber()) { - return node.doubleValue(); - } else { - throw new IllegalStateException("node must be a number"); - } - } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java index 3cb182de5b..417aaddaee 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java @@ -39,6 +39,7 @@ import java.util.function.BiFunction; import lombok.Getter; import lombok.Setter; +import org.opensearch.OpenSearchParseException; import org.opensearch.common.time.DateFormatter; import org.opensearch.common.time.DateFormatters; import org.opensearch.common.time.FormatNames; @@ -62,7 +63,6 @@ import org.opensearch.sql.opensearch.data.type.OpenSearchBinaryType; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.data.type.OpenSearchDateType; -import org.opensearch.sql.opensearch.data.type.OpenSearchGeoPointType; import org.opensearch.sql.opensearch.data.type.OpenSearchIpType; import org.opensearch.sql.opensearch.data.utils.Content; import org.opensearch.sql.opensearch.data.utils.ObjectContent; @@ -134,10 +134,6 @@ public void extendTypeMapping(Map typeMapping) { .put( OpenSearchDataType.of(OpenSearchDataType.MappingType.Ip), (c, dt) -> new OpenSearchExprIpValue(c.stringValue())) - .put( - OpenSearchDataType.of(OpenSearchDataType.MappingType.GeoPoint), - (c, dt) -> - new OpenSearchExprGeoPointValue(c.geoValue().getLeft(), c.geoValue().getRight())) .put( OpenSearchDataType.of(OpenSearchDataType.MappingType.Binary), (c, dt) -> new OpenSearchExprBinaryValue(c.stringValue())) @@ -193,8 +189,11 @@ 
private ExprValue parse( return ExprNullValue.of(); } - ExprType type = fieldType.get(); - if (type.equals(OpenSearchDataType.of(OpenSearchDataType.MappingType.Nested)) + final ExprType type = fieldType.get(); + + if (type.equals(OpenSearchDataType.of(OpenSearchDataType.MappingType.GeoPoint))) { + return parseGeoPoint(content, supportArrays); + } else if (type.equals(OpenSearchDataType.of(OpenSearchDataType.MappingType.Nested)) || content.isArray()) { return parseArray(content, field, type, supportArrays); } else if (type.equals(OpenSearchDataType.of(OpenSearchDataType.MappingType.Object)) @@ -362,6 +361,49 @@ private ExprValue parseArray( return new ExprCollectionValue(result); } + /** + * Parse geo point content. + * + * @param content Content to parse. + * @param supportArrays Parsing the whole array or not + * @return Geo point value parsed from content. + */ + private ExprValue parseGeoPoint(Content content, boolean supportArrays) { + // there is only one point in doc. + if (content.isArray() == false) { + final var pair = content.geoValue(); + return new OpenSearchExprGeoPointValue(pair.getLeft(), pair.getRight()); + } + + var elements = content.array(); + var first = elements.next(); + // an array in the [longitude, latitude] format. 
+ if (first.isNumber()) { + double lon = first.doubleValue(); + var second = elements.next(); + if (second.isNumber() == false) { + throw new OpenSearchParseException("lat must be a number, got " + second.objectValue()); + } + return new OpenSearchExprGeoPointValue(second.doubleValue(), lon); + } + + // there are multi points in doc + var pair = first.geoValue(); + var firstPoint = new OpenSearchExprGeoPointValue(pair.getLeft(), pair.getRight()); + if (supportArrays) { + List result = new ArrayList<>(); + result.add(firstPoint); + elements.forEachRemaining( + e -> { + var p = e.geoValue(); + result.add(new OpenSearchExprGeoPointValue(p.getLeft(), p.getRight())); + }); + return new ExprCollectionValue(result); + } else { + return firstPoint; + } + } + /** * Parse inner array value. Can be object type and recurse continues. * @@ -375,8 +417,7 @@ private ExprValue parseInnerArrayValue( Content content, String prefix, ExprType type, boolean supportArrays) { if (type instanceof OpenSearchIpType || type instanceof OpenSearchBinaryType - || type instanceof OpenSearchDateType - || type instanceof OpenSearchGeoPointType) { + || type instanceof OpenSearchDateType) { return parse(content, prefix, Optional.of(type), supportArrays); } else if (content.isString()) { return parse(content, prefix, Optional.of(OpenSearchDataType.of(STRING)), supportArrays); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContentTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContentTest.java new file mode 100644 index 0000000000..c2cf0328bd --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContentTest.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.data.utils; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static 
org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.JsonNode; +import java.io.IOException; +import org.junit.jupiter.api.Test; +import org.opensearch.OpenSearchParseException; + +public class OpenSearchJsonContentTest { + @Test + public void testGetValueWithIOException() throws IOException { + JsonNode jsonNode = mock(JsonNode.class); + JsonParser jsonParser = mock(JsonParser.class); + when(jsonNode.traverse()).thenReturn(jsonParser); + when(jsonParser.nextToken()).thenThrow(new IOException()); + OpenSearchJsonContent content = new OpenSearchJsonContent(jsonNode); + OpenSearchParseException exception = + assertThrows(OpenSearchParseException.class, content::geoValue); + assertTrue(exception.getMessage().contains("error parsing geo point")); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java index 83e26f85e4..6b4d825ab1 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java @@ -47,6 +47,8 @@ import lombok.EqualsAndHashCode; import lombok.ToString; import org.junit.jupiter.api.Test; +import org.opensearch.OpenSearchParseException; +import org.opensearch.geometry.utils.Geohash; import org.opensearch.sql.data.model.ExprCollectionValue; import org.opensearch.sql.data.model.ExprDateValue; import org.opensearch.sql.data.model.ExprTimeValue; @@ -597,6 +599,18 @@ public void constructArrayOfGeoPoints() { .get("geoV")); } + @Test + public void constructArrayOfGeoPointsReturnsFirstIndex() { + assertEquals( + new OpenSearchExprGeoPointValue(42.60355556, -97.25263889), + 
tupleValue( + "{\"geoV\":[" + + "{\"lat\":42.60355556,\"lon\":-97.25263889}," + + "{\"lat\":-33.6123556,\"lon\":66.287449}" + + "]}") + .get("geoV")); + } + @Test public void constructArrayOfIPsReturnsFirstIndex() { assertEquals( @@ -671,14 +685,50 @@ public void constructIP() { tupleValue("{\"ipV\":\"192.168.0.1\"}").get("ipV")); } + private static final double TOLERANCE = 1E-5; + @Test public void constructGeoPoint() { + final double lat = 42.60355556; + final double lon = -97.25263889; + final var expectedGeoPointValue = new OpenSearchExprGeoPointValue(lat, lon); + // An object with a latitude and longitude. assertEquals( - new OpenSearchExprGeoPointValue(42.60355556, -97.25263889), - tupleValue("{\"geoV\":{\"lat\":42.60355556,\"lon\":-97.25263889}}").get("geoV")); + expectedGeoPointValue, + tupleValue(String.format("{\"geoV\":{\"lat\":%.8f,\"lon\":%.8f}}", lat, lon)).get("geoV")); + + // A string in the “latitude,longitude” format. assertEquals( - new OpenSearchExprGeoPointValue(42.60355556, -97.25263889), - tupleValue("{\"geoV\":{\"lat\":\"42.60355556\",\"lon\":\"-97.25263889\"}}").get("geoV")); + expectedGeoPointValue, + tupleValue(String.format("{\"geoV\":\"%.8f,%.8f\"}", lat, lon)).get("geoV")); + + // A geohash. + var point = + (OpenSearchExprGeoPointValue.GeoPoint) + tupleValue(String.format("{\"geoV\":\"%s\"}", Geohash.stringEncode(lon, lat))) + .get("geoV") + .value(); + assertEquals(lat, point.getLat(), TOLERANCE); + assertEquals(lon, point.getLon(), TOLERANCE); + + // An array in the [longitude, latitude] format. + assertEquals( + expectedGeoPointValue, + tupleValue(String.format("{\"geoV\":[%.8f, %.8f]}", lon, lat)).get("geoV")); + + // A Well-Known Text POINT in the “POINT(longitude latitude)” format. 
+ assertEquals( + expectedGeoPointValue, + tupleValue(String.format("{\"geoV\":\"POINT (%.8f %.8f)\"}", lon, lat)).get("geoV")); + + // GeoJSON format, where the coordinates are in the [longitude, latitude] format + assertEquals( + expectedGeoPointValue, + tupleValue( + String.format( + "{\"geoV\":{\"type\":\"Point\",\"coordinates\":[%.8f,%.8f]}}", lon, lat)) + .get("geoV")); + assertEquals( new OpenSearchExprGeoPointValue(42.60355556, -97.25263889), constructFromObject("geoV", "42.60355556,-97.25263889")); @@ -686,38 +736,23 @@ public void constructGeoPoint() { @Test public void constructGeoPointFromUnsupportedFormatShouldThrowException() { - IllegalStateException exception = + OpenSearchParseException exception = assertThrows( - IllegalStateException.class, - () -> tupleValue("{\"geoV\":[42.60355556,-97.25263889]}").get("geoV")); - assertEquals( - "geo point must in format of {\"lat\": number, \"lon\": number}", exception.getMessage()); + OpenSearchParseException.class, + () -> tupleValue("{\"geoV\": [42.60355556, false]}").get("geoV")); + assertEquals("lat must be a number, got false", exception.getMessage()); exception = assertThrows( - IllegalStateException.class, + OpenSearchParseException.class, () -> tupleValue("{\"geoV\":{\"lon\":-97.25263889}}").get("geoV")); - assertEquals( - "geo point must in format of {\"lat\": number, \"lon\": number}", exception.getMessage()); - - exception = - assertThrows( - IllegalStateException.class, - () -> tupleValue("{\"geoV\":{\"lat\":-97.25263889}}").get("geoV")); - assertEquals( - "geo point must in format of {\"lat\": number, \"lon\": number}", exception.getMessage()); + assertEquals("field [lat] missing", exception.getMessage()); exception = assertThrows( - IllegalStateException.class, + OpenSearchParseException.class, () -> tupleValue("{\"geoV\":{\"lat\":true,\"lon\":-97.25263889}}").get("geoV")); - assertEquals("latitude must be number value, but got value: true", exception.getMessage()); - - exception = - 
assertThrows( - IllegalStateException.class, - () -> tupleValue("{\"geoV\":{\"lat\":42.60355556,\"lon\":false}}").get("geoV")); - assertEquals("longitude must be number value, but got value: false", exception.getMessage()); + assertEquals("lat must be a number", exception.getMessage()); } @Test From 14a80a95fb5fa36781b46f28ba52d406927e21c0 Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Thu, 1 Aug 2024 09:48:33 -0700 Subject: [PATCH 08/12] Add AsyncQueryRequestContext to FlintIndexMetadataService/FlintIndexStateModelService (#2879) Signed-off-by: Tomoyuki Morita --- .../asyncquery/AsyncQueryExecutorService.java | 2 +- .../AsyncQueryExecutorServiceImpl.java | 4 +- .../spark/dispatcher/AsyncQueryHandler.java | 5 +- .../spark/dispatcher/BatchQueryHandler.java | 5 +- .../sql/spark/dispatcher/IndexDMLHandler.java | 16 +++- .../dispatcher/InteractiveQueryHandler.java | 5 +- .../spark/dispatcher/RefreshQueryHandler.java | 10 +- .../dispatcher/SparkQueryDispatcher.java | 6 +- .../dispatcher/StreamingQueryHandler.java | 5 +- .../flint/FlintIndexMetadataService.java | 11 ++- .../flint/FlintIndexStateModelService.java | 46 ++++++++- .../spark/flint/operation/FlintIndexOp.java | 47 ++++++---- .../flint/operation/FlintIndexOpAlter.java | 8 +- .../flint/operation/FlintIndexOpCancel.java | 6 +- .../flint/operation/FlintIndexOpDrop.java | 6 +- .../flint/operation/FlintIndexOpVacuum.java | 6 +- .../asyncquery/AsyncQueryCoreIntegTest.java | 19 ++-- .../AsyncQueryExecutorServiceImplTest.java | 8 +- .../spark/dispatcher/IndexDMLHandlerTest.java | 11 ++- .../dispatcher/SparkQueryDispatcherTest.java | 15 ++- .../flint/operation/FlintIndexOpTest.java | 37 ++++++-- .../operation/FlintIndexOpVacuumTest.java | 94 ++++++++++++++----- .../FlintStreamingJobHouseKeeperTask.java | 13 ++- .../flint/FlintIndexMetadataServiceImpl.java | 9 +- ...OpenSearchFlintIndexStateModelService.java | 13 ++- ...ransportCancelAsyncQueryRequestAction.java | 5 +- ...AsyncQueryExecutorServiceImplSpecTest.java 
| 6 +- .../spark/asyncquery/IndexQuerySpecTest.java | 23 +++-- .../asyncquery/model/MockFlintSparkJob.java | 11 ++- .../FlintStreamingJobHouseKeeperTaskTest.java | 8 +- .../FlintIndexMetadataServiceImplTest.java | 27 +++++- ...SearchFlintIndexStateModelServiceTest.java | 13 ++- ...portCancelAsyncQueryRequestActionTest.java | 16 +++- 33 files changed, 384 insertions(+), 132 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java index d38c8554ae..b0c339e93d 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java @@ -39,5 +39,5 @@ CreateAsyncQueryResponse createAsyncQuery( * @param queryId queryId. * @return {@link String} cancelledQueryId. */ - String cancelQuery(String queryId); + String cancelQuery(String queryId, AsyncQueryRequestContext asyncQueryRequestContext); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index 6d3d5b6765..d304766465 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -106,11 +106,11 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { } @Override - public String cancelQuery(String queryId) { + public String cancelQuery(String queryId, AsyncQueryRequestContext asyncQueryRequestContext) { Optional asyncQueryJobMetadata = asyncQueryJobMetadataStorageService.getJobMetadata(queryId); if (asyncQueryJobMetadata.isPresent()) { - return 
sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata.get()); + return sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata.get(), asyncQueryRequestContext); } throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java index d61ac17aa3..2bafd88b85 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java @@ -12,6 +12,7 @@ import com.amazonaws.services.emrserverless.model.JobRunState; import org.json.JSONObject; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -54,7 +55,9 @@ protected abstract JSONObject getResponseFromResultIndex( protected abstract JSONObject getResponseFromExecutor( AsyncQueryJobMetadata asyncQueryJobMetadata); - public abstract String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata); + public abstract String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext); public abstract DispatchQueryResponse submit( DispatchQueryRequest request, DispatchQueryContext context); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index 2654f83aad..661ebe27fc 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ 
b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -16,6 +16,7 @@ import org.json.JSONObject; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; @@ -61,7 +62,9 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob } @Override - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { emrServerlessClient.cancelJobRun( asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId(), false); return asyncQueryJobMetadata.getQueryId(); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java index e8413f469c..f8217142c3 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java @@ -62,9 +62,11 @@ public DispatchQueryResponse submit( long startTime = System.currentTimeMillis(); try { IndexQueryDetails indexDetails = context.getIndexQueryDetails(); - FlintIndexMetadata indexMetadata = getFlintIndexMetadata(indexDetails); + FlintIndexMetadata indexMetadata = + getFlintIndexMetadata(indexDetails, context.getAsyncQueryRequestContext()); - getIndexOp(dispatchQueryRequest, indexDetails).apply(indexMetadata); + getIndexOp(dispatchQueryRequest, indexDetails) + .apply(indexMetadata, context.getAsyncQueryRequestContext()); String asyncQueryId = 
storeIndexDMLResult( @@ -146,9 +148,11 @@ private FlintIndexOp getIndexOp( } } - private FlintIndexMetadata getFlintIndexMetadata(IndexQueryDetails indexDetails) { + private FlintIndexMetadata getFlintIndexMetadata( + IndexQueryDetails indexDetails, AsyncQueryRequestContext asyncQueryRequestContext) { Map indexMetadataMap = - flintIndexMetadataService.getFlintIndexMetadata(indexDetails.openSearchIndexName()); + flintIndexMetadataService.getFlintIndexMetadata( + indexDetails.openSearchIndexName(), asyncQueryRequestContext); if (!indexMetadataMap.containsKey(indexDetails.openSearchIndexName())) { throw new IllegalStateException( String.format( @@ -174,7 +178,9 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob } @Override - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { throw new IllegalArgumentException("can't cancel index DML query"); } } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java index ec43bccf11..9a9baedde2 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -16,6 +16,7 @@ import org.json.JSONObject; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -71,7 +72,9 @@ protected JSONObject 
getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob } @Override - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { String queryId = asyncQueryJobMetadata.getQueryId(); getStatementByQueryId( asyncQueryJobMetadata.getSessionId(), diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java index 99984ecc46..38145a143e 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java @@ -8,6 +8,7 @@ import java.util.Map; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; @@ -51,10 +52,13 @@ public RefreshQueryHandler( } @Override - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { String datasourceName = asyncQueryJobMetadata.getDatasourceName(); Map indexMetadataMap = - flintIndexMetadataService.getFlintIndexMetadata(asyncQueryJobMetadata.getIndexName()); + flintIndexMetadataService.getFlintIndexMetadata( + asyncQueryJobMetadata.getIndexName(), asyncQueryRequestContext); if (!indexMetadataMap.containsKey(asyncQueryJobMetadata.getIndexName())) { throw new IllegalStateException( String.format( @@ -62,7 +66,7 @@ public String 
cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { } FlintIndexMetadata indexMetadata = indexMetadataMap.get(asyncQueryJobMetadata.getIndexName()); FlintIndexOp jobCancelOp = flintIndexOpFactory.getCancel(datasourceName); - jobCancelOp.apply(indexMetadata); + jobCancelOp.apply(indexMetadata, asyncQueryRequestContext); return asyncQueryJobMetadata.getQueryId(); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index a424db4c34..a6fdd3f102 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -162,9 +162,11 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) .getQueryResponse(asyncQueryJobMetadata); } - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata) - .cancelJob(asyncQueryJobMetadata); + .cancelJob(asyncQueryJobMetadata, asyncQueryRequestContext); } private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery( diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java index 2fbf2466da..80d4be27cf 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java @@ -12,6 +12,7 @@ import java.util.Map; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; 
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; @@ -46,7 +47,9 @@ public StreamingQueryHandler( } @Override - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { throw new IllegalArgumentException( "can't cancel index DML query, using ALTER auto_refresh=off statement to stop job, using" + " VACUUM statement to stop job and delete data"); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataService.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataService.java index ad274e429e..ece14c2a7b 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataService.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataService.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.flint; import java.util.Map; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; /** Interface for FlintIndexMetadataReader */ @@ -15,16 +16,22 @@ public interface FlintIndexMetadataService { * Retrieves a map of {@link FlintIndexMetadata} instances matching the specified index pattern. * * @param indexPattern indexPattern. + * @param asyncQueryRequestContext request context passed to AsyncQueryExecutorService * @return A map of {@link FlintIndexMetadata} instances against indexName, each providing * metadata access for a matched index. Returns an empty list if no indices match the pattern. 
*/ - Map getFlintIndexMetadata(String indexPattern); + Map getFlintIndexMetadata( + String indexPattern, AsyncQueryRequestContext asyncQueryRequestContext); /** * Performs validation and updates flint index to manual refresh. * * @param indexName indexName. * @param flintIndexOptions flintIndexOptions. + * @param asyncQueryRequestContext request context passed to AsyncQueryExecutorService */ - void updateIndexToManualRefresh(String indexName, FlintIndexOptions flintIndexOptions); + void updateIndexToManualRefresh( + String indexName, + FlintIndexOptions flintIndexOptions, + AsyncQueryRequestContext asyncQueryRequestContext); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java index 94647f4e07..3872f2d5a0 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java @@ -6,20 +6,58 @@ package org.opensearch.sql.spark.flint; import java.util.Optional; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; /** * Abstraction over flint index state storage. Flint index state will maintain the status of each * flint index. 
*/ public interface FlintIndexStateModelService { - FlintIndexStateModel createFlintIndexStateModel(FlintIndexStateModel flintIndexStateModel); - Optional getFlintIndexStateModel(String id, String datasourceName); + /** + * Create Flint index state record + * + * @param flintIndexStateModel the model to be saved + * @param asyncQueryRequestContext the request context passed to AsyncQueryExecutorService + * @return saved model + */ + FlintIndexStateModel createFlintIndexStateModel( + FlintIndexStateModel flintIndexStateModel, AsyncQueryRequestContext asyncQueryRequestContext); + /** + * Get Flint index state record + * + * @param id ID(latestId) of the Flint index state record + * @param datasourceName datasource name + * @param asyncQueryRequestContext the request context passed to AsyncQueryExecutorService + * @return retrieved model + */ + Optional getFlintIndexStateModel( + String id, String datasourceName, AsyncQueryRequestContext asyncQueryRequestContext); + + /** + * Update Flint index state record + * + * @param flintIndexStateModel the model to be updated + * @param flintIndexState new state + * @param datasourceName Datasource name + * @param asyncQueryRequestContext the request context passed to AsyncQueryExecutorService + * @return Updated model + */ FlintIndexStateModel updateFlintIndexState( FlintIndexStateModel flintIndexStateModel, FlintIndexState flintIndexState, - String datasourceName); + String datasourceName, + AsyncQueryRequestContext asyncQueryRequestContext); - boolean deleteFlintIndexStateModel(String id, String datasourceName); + /** + * Delete Flint index state record + * + * @param id ID(latestId) of the Flint index state record + * @param datasourceName datasource name + * @param asyncQueryRequestContext the request context passed to AsyncQueryExecutorService + * @return true if deleted, otherwise false + */ + boolean deleteFlintIndexStateModel( + String id, String datasourceName, AsyncQueryRequestContext asyncQueryRequestContext); } 
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java index 244f4aee11..78d217b8dc 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java @@ -16,6 +16,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.jetbrains.annotations.NotNull; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.flint.FlintIndexMetadata; @@ -33,30 +34,33 @@ public abstract class FlintIndexOp { private final EMRServerlessClientFactory emrServerlessClientFactory; /** Apply operation on {@link FlintIndexMetadata} */ - public void apply(FlintIndexMetadata metadata) { + public void apply( + FlintIndexMetadata metadata, AsyncQueryRequestContext asyncQueryRequestContext) { // todo, remove this logic after IndexState feature is enabled in Flint. Optional latestId = metadata.getLatestId(); if (latestId.isEmpty()) { - takeActionWithoutOCC(metadata); + takeActionWithoutOCC(metadata, asyncQueryRequestContext); } else { - FlintIndexStateModel initialFlintIndexStateModel = getFlintIndexStateModel(latestId.get()); + FlintIndexStateModel initialFlintIndexStateModel = + getFlintIndexStateModel(latestId.get(), asyncQueryRequestContext); // 1.validate state. 
validFlintIndexInitialState(initialFlintIndexStateModel); // 2.begin, move to transitioning state FlintIndexStateModel transitionedFlintIndexStateModel = - moveToTransitioningState(initialFlintIndexStateModel); + moveToTransitioningState(initialFlintIndexStateModel, asyncQueryRequestContext); // 3.runOp try { - runOp(metadata, transitionedFlintIndexStateModel); - commit(transitionedFlintIndexStateModel); + runOp(metadata, transitionedFlintIndexStateModel, asyncQueryRequestContext); + commit(transitionedFlintIndexStateModel, asyncQueryRequestContext); } catch (Throwable e) { LOG.error("Rolling back transient log due to transaction operation failure", e); try { flintIndexStateModelService.updateFlintIndexState( transitionedFlintIndexStateModel, initialFlintIndexStateModel.getIndexState(), - datasourceName); + datasourceName, + asyncQueryRequestContext); } catch (Exception ex) { LOG.error("Failed to rollback transient log", ex); } @@ -66,9 +70,11 @@ public void apply(FlintIndexMetadata metadata) { } @NotNull - private FlintIndexStateModel getFlintIndexStateModel(String latestId) { + private FlintIndexStateModel getFlintIndexStateModel( + String latestId, AsyncQueryRequestContext asyncQueryRequestContext) { Optional flintIndexOptional = - flintIndexStateModelService.getFlintIndexStateModel(latestId, datasourceName); + flintIndexStateModelService.getFlintIndexStateModel( + latestId, datasourceName, asyncQueryRequestContext); if (flintIndexOptional.isEmpty()) { String errorMsg = String.format(Locale.ROOT, "no state found. docId: %s", latestId); LOG.error(errorMsg); @@ -77,7 +83,8 @@ private FlintIndexStateModel getFlintIndexStateModel(String latestId) { return flintIndexOptional.get(); } - private void takeActionWithoutOCC(FlintIndexMetadata metadata) { + private void takeActionWithoutOCC( + FlintIndexMetadata metadata, AsyncQueryRequestContext asyncQueryRequestContext) { // take action without occ. 
FlintIndexStateModel fakeModel = FlintIndexStateModel.builder() @@ -89,7 +96,7 @@ private void takeActionWithoutOCC(FlintIndexMetadata metadata) { .lastUpdateTime(System.currentTimeMillis()) .error("") .build(); - runOp(metadata, fakeModel); + runOp(metadata, fakeModel, asyncQueryRequestContext); } private void validFlintIndexInitialState(FlintIndexStateModel flintIndex) { @@ -103,13 +110,14 @@ private void validFlintIndexInitialState(FlintIndexStateModel flintIndex) { } } - private FlintIndexStateModel moveToTransitioningState(FlintIndexStateModel flintIndex) { + private FlintIndexStateModel moveToTransitioningState( + FlintIndexStateModel flintIndex, AsyncQueryRequestContext asyncQueryRequestContext) { LOG.debug("Moving to transitioning state before committing."); FlintIndexState transitioningState = transitioningState(); try { flintIndex = flintIndexStateModelService.updateFlintIndexState( - flintIndex, transitioningState(), datasourceName); + flintIndex, transitioningState(), datasourceName, asyncQueryRequestContext); } catch (Exception e) { String errorMsg = String.format(Locale.ROOT, "Moving to transition state:%s failed.", transitioningState); @@ -119,16 +127,18 @@ private FlintIndexStateModel moveToTransitioningState(FlintIndexStateModel flint return flintIndex; } - private void commit(FlintIndexStateModel flintIndex) { + private void commit( + FlintIndexStateModel flintIndex, AsyncQueryRequestContext asyncQueryRequestContext) { LOG.debug("Committing the transaction and moving to stable state."); FlintIndexState stableState = stableState(); try { if (stableState == FlintIndexState.NONE) { LOG.info("Deleting index state with docId: " + flintIndex.getLatestId()); flintIndexStateModelService.deleteFlintIndexStateModel( - flintIndex.getLatestId(), datasourceName); + flintIndex.getLatestId(), datasourceName, asyncQueryRequestContext); } else { - flintIndexStateModelService.updateFlintIndexState(flintIndex, stableState, datasourceName); + 
flintIndexStateModelService.updateFlintIndexState( + flintIndex, stableState, datasourceName, asyncQueryRequestContext); } } catch (Exception e) { String errorMsg = @@ -192,7 +202,10 @@ public void cancelStreamingJob(FlintIndexStateModel flintIndexStateModel) /** get transitioningState */ abstract FlintIndexState transitioningState(); - abstract void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndex); + abstract void runOp( + FlintIndexMetadata flintIndexMetadata, + FlintIndexStateModel flintIndex, + AsyncQueryRequestContext asyncQueryRequestContext); /** get stableState */ abstract FlintIndexState stableState(); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java index 9955320253..4a00195ebf 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java @@ -8,6 +8,7 @@ import lombok.SneakyThrows; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; import org.opensearch.sql.spark.flint.FlintIndexMetadata; @@ -48,11 +49,14 @@ FlintIndexState transitioningState() { @SneakyThrows @Override - void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndexStateModel) { + void runOp( + FlintIndexMetadata flintIndexMetadata, + FlintIndexStateModel flintIndexStateModel, + AsyncQueryRequestContext asyncQueryRequestContext) { LOG.debug( "Running alter index operation for index: {}", flintIndexMetadata.getOpensearchIndexName()); this.flintIndexMetadataService.updateIndexToManualRefresh( - 
flintIndexMetadata.getOpensearchIndexName(), flintIndexOptions); + flintIndexMetadata.getOpensearchIndexName(), flintIndexOptions, asyncQueryRequestContext); cancelStreamingJob(flintIndexStateModel); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java index 02c8e39c66..504a8f93c9 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java @@ -8,6 +8,7 @@ import lombok.SneakyThrows; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; @@ -38,7 +39,10 @@ FlintIndexState transitioningState() { /** cancel EMR-S job, wait cancelled state upto 15s. 
*/ @SneakyThrows @Override - void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndexStateModel) { + void runOp( + FlintIndexMetadata flintIndexMetadata, + FlintIndexStateModel flintIndexStateModel, + AsyncQueryRequestContext asyncQueryRequestContext) { LOG.debug( "Performing drop index operation for index: {}", flintIndexMetadata.getOpensearchIndexName()); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java index 6613c29870..fc9b644fc7 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java @@ -8,6 +8,7 @@ import lombok.SneakyThrows; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; @@ -40,7 +41,10 @@ FlintIndexState transitioningState() { /** cancel EMR-S job, wait cancelled state upto 15s. 
*/ @SneakyThrows @Override - void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndexStateModel) { + void runOp( + FlintIndexMetadata flintIndexMetadata, + FlintIndexStateModel flintIndexStateModel, + AsyncQueryRequestContext asyncQueryRequestContext) { LOG.debug( "Performing drop index operation for index: {}", flintIndexMetadata.getOpensearchIndexName()); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java index a0ef955adf..06aaf8ef9f 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java @@ -7,6 +7,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.flint.FlintIndexClient; import org.opensearch.sql.spark.flint.FlintIndexMetadata; @@ -42,7 +43,10 @@ FlintIndexState transitioningState() { } @Override - public void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndex) { + public void runOp( + FlintIndexMetadata flintIndexMetadata, + FlintIndexStateModel flintIndex, + AsyncQueryRequestContext asyncQueryRequestContext) { LOG.info("Vacuuming Flint index {}", flintIndexMetadata.getOpensearchIndexName()); flintIndexClient.deleteIndex(flintIndexMetadata.getOpensearchIndexName()); } diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index d82d3bdab7..ff92762a7c 100644 --- 
a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -249,7 +249,8 @@ public void createAlterIndexQuery() { assertNull(response.getSessionId()); verifyGetQueryIdCalled(); verify(flintIndexMetadataService) - .updateIndexToManualRefresh(eq(indexName), flintIndexOptionsArgumentCaptor.capture()); + .updateIndexToManualRefresh( + eq(indexName), flintIndexOptionsArgumentCaptor.capture(), eq(asyncQueryRequestContext)); FlintIndexOptions flintIndexOptions = flintIndexOptionsArgumentCaptor.getValue(); assertFalse(flintIndexOptions.autoRefresh()); verifyCancelJobRunCalled(); @@ -430,7 +431,7 @@ public void cancelInteractiveQuery() { when(statementStorageService.updateStatementState(statementModel, StatementState.CANCELLED)) .thenReturn(canceledStatementModel); - String result = asyncQueryExecutorService.cancelQuery(QUERY_ID); + String result = asyncQueryExecutorService.cancelQuery(QUERY_ID, asyncQueryRequestContext); assertEquals(QUERY_ID, result); verify(statementStorageService).updateStatementState(statementModel, StatementState.CANCELLED); @@ -441,14 +442,15 @@ public void cancelIndexDMLQuery() { givenJobMetadataExists(getBaseAsyncQueryJobMetadataBuilder().jobId(DROP_INDEX_JOB_ID)); assertThrows( - IllegalArgumentException.class, () -> asyncQueryExecutorService.cancelQuery(QUERY_ID)); + IllegalArgumentException.class, + () -> asyncQueryExecutorService.cancelQuery(QUERY_ID, asyncQueryRequestContext)); } @Test public void cancelRefreshQuery() { givenJobMetadataExists( getBaseAsyncQueryJobMetadataBuilder().jobType(JobType.BATCH).indexName(INDEX_NAME)); - when(flintIndexMetadataService.getFlintIndexMetadata(INDEX_NAME)) + when(flintIndexMetadataService.getFlintIndexMetadata(INDEX_NAME, asyncQueryRequestContext)) .thenReturn( ImmutableMap.of( INDEX_NAME, @@ -463,7 +465,7 @@ public void cancelRefreshQuery() { new GetJobRunResult() 
.withJobRun(new JobRun().withJobRunId(JOB_ID).withState("Cancelled"))); - String result = asyncQueryExecutorService.cancelQuery(QUERY_ID); + String result = asyncQueryExecutorService.cancelQuery(QUERY_ID, asyncQueryRequestContext); assertEquals(QUERY_ID, result); verifyCancelJobRunCalled(); @@ -475,7 +477,8 @@ public void cancelStreamingQuery() { givenJobMetadataExists(getBaseAsyncQueryJobMetadataBuilder().jobType(JobType.STREAMING)); assertThrows( - IllegalArgumentException.class, () -> asyncQueryExecutorService.cancelQuery(QUERY_ID)); + IllegalArgumentException.class, + () -> asyncQueryExecutorService.cancelQuery(QUERY_ID, asyncQueryRequestContext)); } @Test @@ -483,7 +486,7 @@ public void cancelBatchQuery() { givenJobMetadataExists(getBaseAsyncQueryJobMetadataBuilder().jobId(JOB_ID)); givenCancelJobRunSucceed(); - String result = asyncQueryExecutorService.cancelQuery(QUERY_ID); + String result = asyncQueryExecutorService.cancelQuery(QUERY_ID, asyncQueryRequestContext); assertEquals(QUERY_ID, result); verifyCancelJobRunCalled(); @@ -500,7 +503,7 @@ private void givenSparkExecutionEngineConfigIsSupplied() { } private void givenFlintIndexMetadataExists(String indexName) { - when(flintIndexMetadataService.getFlintIndexMetadata(indexName)) + when(flintIndexMetadataService.getFlintIndexMetadata(indexName, asyncQueryRequestContext)) .thenReturn( ImmutableMap.of( indexName, diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index dbc51bb0ad..5d8d9a3b63 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -206,7 +206,8 @@ void testCancelJobWithJobNotFound() { AsyncQueryNotFoundException asyncQueryNotFoundException = 
Assertions.assertThrows( - AsyncQueryNotFoundException.class, () -> jobExecutorService.cancelQuery(EMR_JOB_ID)); + AsyncQueryNotFoundException.class, + () -> jobExecutorService.cancelQuery(EMR_JOB_ID, asyncQueryRequestContext)); Assertions.assertEquals( "QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage()); @@ -218,9 +219,10 @@ void testCancelJobWithJobNotFound() { void testCancelJob() { when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) .thenReturn(Optional.of(getAsyncQueryJobMetadata())); - when(sparkQueryDispatcher.cancelJob(getAsyncQueryJobMetadata())).thenReturn(EMR_JOB_ID); + when(sparkQueryDispatcher.cancelJob(getAsyncQueryJobMetadata(), asyncQueryRequestContext)) + .thenReturn(EMR_JOB_ID); - String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID); + String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID, asyncQueryRequestContext); Assertions.assertEquals(EMR_JOB_ID, jobId); verifyNoInteractions(sparkExecutionEngineConfigSupplier); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java index 877d6ec32b..9a3c4e663e 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java @@ -7,6 +7,7 @@ import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.sql.datasource.model.DataSourceStatus.ACTIVE; @@ -27,6 +28,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; +import 
org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.config.SparkSubmitParameterModifier; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; @@ -50,6 +52,7 @@ class IndexDMLHandlerTest { @Mock private IndexDMLResultStorageService indexDMLResultStorageService; @Mock private FlintIndexOpFactory flintIndexOpFactory; @Mock private SparkSubmitParameterModifier sparkSubmitParameterModifier; + @Mock private AsyncQueryRequestContext asyncQueryRequestContext; @InjectMocks IndexDMLHandler indexDMLHandler; @@ -82,8 +85,10 @@ public void testWhenIndexDetailsAreNotFound() { .queryId(QUERY_ID) .dataSourceMetadata(metadata) .indexQueryDetails(indexQueryDetails) + .asyncQueryRequestContext(asyncQueryRequestContext) .build(); - Mockito.when(flintIndexMetadataService.getFlintIndexMetadata(any())) + Mockito.when( + flintIndexMetadataService.getFlintIndexMetadata(any(), eq(asyncQueryRequestContext))) .thenReturn(new HashMap<>()); DispatchQueryResponse dispatchQueryResponse = @@ -107,10 +112,12 @@ public void testWhenIndexDetailsWithInvalidQueryActionType() { .queryId(QUERY_ID) .dataSourceMetadata(metadata) .indexQueryDetails(indexQueryDetails) + .asyncQueryRequestContext(asyncQueryRequestContext) .build(); HashMap flintMetadataMap = new HashMap<>(); flintMetadataMap.put(indexQueryDetails.openSearchIndexName(), flintIndexMetadata); - when(flintIndexMetadataService.getFlintIndexMetadata(indexQueryDetails.openSearchIndexName())) + when(flintIndexMetadataService.getFlintIndexMetadata( + indexQueryDetails.openSearchIndexName(), asyncQueryRequestContext)) .thenReturn(flintMetadataMap); indexDMLHandler.submit(dispatchQueryRequest, dispatchQueryContext); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java 
b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index a7a79c758e..592309cb75 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -871,7 +871,8 @@ void testCancelJob() { .withJobRunId(EMR_JOB_ID) .withApplicationId(EMRS_APPLICATION_ID)); - String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata()); + String queryId = + sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext); Assertions.assertEquals(QUERY_ID, queryId); } @@ -884,7 +885,8 @@ void testCancelQueryWithSession() { String queryId = sparkQueryDispatcher.cancelJob( - asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID), + asyncQueryRequestContext); verifyNoInteractions(emrServerlessClient); verify(statement, times(1)).cancel(); @@ -900,7 +902,8 @@ void testCancelQueryWithInvalidSession() { IllegalArgumentException.class, () -> sparkQueryDispatcher.cancelJob( - asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, "invalid"))); + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, "invalid"), + asyncQueryRequestContext)); verifyNoInteractions(emrServerlessClient); verifyNoInteractions(session); @@ -916,7 +919,8 @@ void testCancelQueryWithInvalidStatementId() { IllegalArgumentException.class, () -> sparkQueryDispatcher.cancelJob( - asyncQueryJobMetadataWithSessionId("invalid", MOCK_SESSION_ID))); + asyncQueryJobMetadataWithSessionId("invalid", MOCK_SESSION_ID), + asyncQueryRequestContext)); verifyNoInteractions(emrServerlessClient); verifyNoInteractions(statement); @@ -933,7 +937,8 @@ void testCancelQueryWithNoSessionId() { .withJobRunId(EMR_JOB_ID) .withApplicationId(EMRS_APPLICATION_ID)); - String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata()); + 
String queryId = + sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext); Assertions.assertEquals(QUERY_ID, queryId); } diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java index 0c82733ae6..8105629822 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java @@ -16,6 +16,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.xcontent.XContentSerializerUtil; import org.opensearch.sql.spark.flint.FlintIndexMetadata; @@ -28,21 +29,26 @@ public class FlintIndexOpTest { @Mock private FlintIndexStateModelService flintIndexStateModelService; @Mock private EMRServerlessClientFactory mockEmrServerlessClientFactory; + @Mock private AsyncQueryRequestContext asyncQueryRequestContext; @Test public void testApplyWithTransitioningStateFailure() { FlintIndexMetadata metadata = mock(FlintIndexMetadata.class); when(metadata.getLatestId()).thenReturn(Optional.of("latestId")); FlintIndexStateModel fakeModel = getFlintIndexStateModel(metadata); - when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), any())) + when(flintIndexStateModelService.getFlintIndexStateModel( + eq("latestId"), any(), eq(asyncQueryRequestContext))) .thenReturn(Optional.of(fakeModel)); - when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any())) + when(flintIndexStateModelService.updateFlintIndexState( + any(), any(), any(), eq(asyncQueryRequestContext))) .thenThrow(new 
RuntimeException("Transitioning state failed")); FlintIndexOp flintIndexOp = new TestFlintIndexOp(flintIndexStateModelService, "myS3", mockEmrServerlessClientFactory); IllegalStateException illegalStateException = - Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata)); + Assertions.assertThrows( + IllegalStateException.class, + () -> flintIndexOp.apply(metadata, asyncQueryRequestContext)); Assertions.assertEquals( "Moving to transition state:DELETING failed.", illegalStateException.getMessage()); @@ -53,9 +59,11 @@ public void testApplyWithCommitFailure() { FlintIndexMetadata metadata = mock(FlintIndexMetadata.class); when(metadata.getLatestId()).thenReturn(Optional.of("latestId")); FlintIndexStateModel fakeModel = getFlintIndexStateModel(metadata); - when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), any())) + when(flintIndexStateModelService.getFlintIndexStateModel( + eq("latestId"), any(), eq(asyncQueryRequestContext))) .thenReturn(Optional.of(fakeModel)); - when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any())) + when(flintIndexStateModelService.updateFlintIndexState( + any(), any(), any(), eq(asyncQueryRequestContext))) .thenReturn( FlintIndexStateModel.copy(fakeModel, XContentSerializerUtil.buildMetadata(1, 2))) .thenThrow(new RuntimeException("Commit state failed")) @@ -65,7 +73,9 @@ public void testApplyWithCommitFailure() { new TestFlintIndexOp(flintIndexStateModelService, "myS3", mockEmrServerlessClientFactory); IllegalStateException illegalStateException = - Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata)); + Assertions.assertThrows( + IllegalStateException.class, + () -> flintIndexOp.apply(metadata, asyncQueryRequestContext)); Assertions.assertEquals( "commit failed. 
target stable state: [DELETED]", illegalStateException.getMessage()); @@ -76,9 +86,11 @@ public void testApplyWithRollBackFailure() { FlintIndexMetadata metadata = mock(FlintIndexMetadata.class); when(metadata.getLatestId()).thenReturn(Optional.of("latestId")); FlintIndexStateModel fakeModel = getFlintIndexStateModel(metadata); - when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), any())) + when(flintIndexStateModelService.getFlintIndexStateModel( + eq("latestId"), any(), eq(asyncQueryRequestContext))) .thenReturn(Optional.of(fakeModel)); - when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any())) + when(flintIndexStateModelService.updateFlintIndexState( + any(), any(), any(), eq(asyncQueryRequestContext))) .thenReturn( FlintIndexStateModel.copy(fakeModel, XContentSerializerUtil.buildMetadata(1, 2))) .thenThrow(new RuntimeException("Commit state failed")) @@ -87,7 +99,9 @@ public void testApplyWithRollBackFailure() { new TestFlintIndexOp(flintIndexStateModelService, "myS3", mockEmrServerlessClientFactory); IllegalStateException illegalStateException = - Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata)); + Assertions.assertThrows( + IllegalStateException.class, + () -> flintIndexOp.apply(metadata, asyncQueryRequestContext)); Assertions.assertEquals( "commit failed. 
target stable state: [DELETED]", illegalStateException.getMessage()); @@ -125,7 +139,10 @@ FlintIndexState transitioningState() { } @Override - void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndex) {} + void runOp( + FlintIndexMetadata flintIndexMetadata, + FlintIndexStateModel flintIndex, + AsyncQueryRequestContext asyncQueryRequestContext) {} @Override FlintIndexState stableState() { diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java index 60fa13dc93..26858c18fe 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java @@ -16,6 +16,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.flint.FlintIndexClient; import org.opensearch.sql.spark.flint.FlintIndexMetadata; @@ -38,6 +39,7 @@ class FlintIndexOpVacuumTest { @Mock EMRServerlessClientFactory emrServerlessClientFactory; @Mock FlintIndexStateModel flintIndexStateModel; @Mock FlintIndexStateModel transitionedFlintIndexStateModel; + @Mock AsyncQueryRequestContext asyncQueryRequestContext; RuntimeException testException = new RuntimeException("Test Exception"); @@ -55,110 +57,154 @@ public void setUp() { @Test public void testApplyWithEmptyLatestId() { - flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITHOUT_LATEST_ID); + flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITHOUT_LATEST_ID, asyncQueryRequestContext); verify(flintIndexClient).deleteIndex(INDEX_NAME); } @Test public void testApplyWithFlintIndexStateNotFound() 
{ - when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME)) + when(flintIndexStateModelService.getFlintIndexStateModel( + LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext)) .thenReturn(Optional.empty()); assertThrows( IllegalStateException.class, - () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID)); + () -> + flintIndexOpVacuum.apply( + FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext)); } @Test public void testApplyWithNotDeletedState() { - when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME)) + when(flintIndexStateModelService.getFlintIndexStateModel( + LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext)) .thenReturn(Optional.of(flintIndexStateModel)); when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.ACTIVE); assertThrows( IllegalStateException.class, - () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID)); + () -> + flintIndexOpVacuum.apply( + FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext)); } @Test public void testApplyWithUpdateFlintIndexStateThrow() { - when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME)) + when(flintIndexStateModelService.getFlintIndexStateModel( + LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext)) .thenReturn(Optional.of(flintIndexStateModel)); when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED); when(flintIndexStateModelService.updateFlintIndexState( - flintIndexStateModel, FlintIndexState.VACUUMING, DATASOURCE_NAME)) + flintIndexStateModel, + FlintIndexState.VACUUMING, + DATASOURCE_NAME, + asyncQueryRequestContext)) .thenThrow(testException); assertThrows( IllegalStateException.class, - () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID)); + () -> + flintIndexOpVacuum.apply( + FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext)); } @Test public void testApplyWithRunOpThrow() { - 
when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME)) + when(flintIndexStateModelService.getFlintIndexStateModel( + LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext)) .thenReturn(Optional.of(flintIndexStateModel)); when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED); when(flintIndexStateModelService.updateFlintIndexState( - flintIndexStateModel, FlintIndexState.VACUUMING, DATASOURCE_NAME)) + flintIndexStateModel, + FlintIndexState.VACUUMING, + DATASOURCE_NAME, + asyncQueryRequestContext)) .thenReturn(transitionedFlintIndexStateModel); doThrow(testException).when(flintIndexClient).deleteIndex(INDEX_NAME); assertThrows( - Exception.class, () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID)); + Exception.class, + () -> + flintIndexOpVacuum.apply( + FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext)); verify(flintIndexStateModelService) .updateFlintIndexState( - transitionedFlintIndexStateModel, FlintIndexState.DELETED, DATASOURCE_NAME); + transitionedFlintIndexStateModel, + FlintIndexState.DELETED, + DATASOURCE_NAME, + asyncQueryRequestContext); } @Test public void testApplyWithRunOpThrowAndRollbackThrow() { - when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME)) + when(flintIndexStateModelService.getFlintIndexStateModel( + LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext)) .thenReturn(Optional.of(flintIndexStateModel)); when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED); when(flintIndexStateModelService.updateFlintIndexState( - flintIndexStateModel, FlintIndexState.VACUUMING, DATASOURCE_NAME)) + flintIndexStateModel, + FlintIndexState.VACUUMING, + DATASOURCE_NAME, + asyncQueryRequestContext)) .thenReturn(transitionedFlintIndexStateModel); doThrow(testException).when(flintIndexClient).deleteIndex(INDEX_NAME); when(flintIndexStateModelService.updateFlintIndexState( - transitionedFlintIndexStateModel, 
FlintIndexState.DELETED, DATASOURCE_NAME)) + transitionedFlintIndexStateModel, + FlintIndexState.DELETED, + DATASOURCE_NAME, + asyncQueryRequestContext)) .thenThrow(testException); assertThrows( - Exception.class, () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID)); + Exception.class, + () -> + flintIndexOpVacuum.apply( + FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext)); } @Test public void testApplyWithDeleteFlintIndexStateModelThrow() { - when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME)) + when(flintIndexStateModelService.getFlintIndexStateModel( + LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext)) .thenReturn(Optional.of(flintIndexStateModel)); when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED); when(flintIndexStateModelService.updateFlintIndexState( - flintIndexStateModel, FlintIndexState.VACUUMING, DATASOURCE_NAME)) + flintIndexStateModel, + FlintIndexState.VACUUMING, + DATASOURCE_NAME, + asyncQueryRequestContext)) .thenReturn(transitionedFlintIndexStateModel); - when(flintIndexStateModelService.deleteFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME)) + when(flintIndexStateModelService.deleteFlintIndexStateModel( + LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext)) .thenThrow(testException); assertThrows( IllegalStateException.class, - () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID)); + () -> + flintIndexOpVacuum.apply( + FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext)); } @Test public void testApplyHappyPath() { - when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME)) + when(flintIndexStateModelService.getFlintIndexStateModel( + LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext)) .thenReturn(Optional.of(flintIndexStateModel)); when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED); when(flintIndexStateModelService.updateFlintIndexState( - flintIndexStateModel, 
FlintIndexState.VACUUMING, DATASOURCE_NAME)) + flintIndexStateModel, + FlintIndexState.VACUUMING, + DATASOURCE_NAME, + asyncQueryRequestContext)) .thenReturn(transitionedFlintIndexStateModel); when(transitionedFlintIndexStateModel.getLatestId()).thenReturn(LATEST_ID); - flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID); + flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext); - verify(flintIndexStateModelService).deleteFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME); + verify(flintIndexStateModelService) + .deleteFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext); verify(flintIndexClient).deleteIndex(INDEX_NAME); } } diff --git a/async-query/src/main/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTask.java b/async-query/src/main/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTask.java index 31b1ecb49c..2dd0a4a7cf 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTask.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTask.java @@ -17,6 +17,7 @@ import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException; import org.opensearch.sql.legacy.metrics.MetricName; import org.opensearch.sql.legacy.metrics.Metrics; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; @@ -29,6 +30,8 @@ public class FlintStreamingJobHouseKeeperTask implements Runnable { private final DataSourceService dataSourceService; private final FlintIndexMetadataService flintIndexMetadataService; private final FlintIndexOpFactory flintIndexOpFactory; + private final NullAsyncQueryRequestContext nullAsyncQueryRequestContext = + new NullAsyncQueryRequestContext(); private 
static final Logger LOGGER = LogManager.getLogger(FlintStreamingJobHouseKeeperTask.class); protected static final AtomicBoolean isRunning = new AtomicBoolean(false); @@ -91,7 +94,9 @@ private void dropAutoRefreshIndex( String autoRefreshIndex, FlintIndexMetadata flintIndexMetadata, String datasourceName) { // When the datasource is deleted. Possibly Replace with VACUUM Operation. LOGGER.info("Attempting to drop auto refresh index: {}", autoRefreshIndex); - flintIndexOpFactory.getDrop(datasourceName).apply(flintIndexMetadata); + flintIndexOpFactory + .getDrop(datasourceName) + .apply(flintIndexMetadata, nullAsyncQueryRequestContext); LOGGER.info("Successfully dropped index: {}", autoRefreshIndex); } @@ -100,7 +105,9 @@ private void alterAutoRefreshIndex( LOGGER.info("Attempting to alter index: {}", autoRefreshIndex); FlintIndexOptions flintIndexOptions = new FlintIndexOptions(); flintIndexOptions.setOption(FlintIndexOptions.AUTO_REFRESH, "false"); - flintIndexOpFactory.getAlter(flintIndexOptions, datasourceName).apply(flintIndexMetadata); + flintIndexOpFactory + .getAlter(flintIndexOptions, datasourceName) + .apply(flintIndexMetadata, nullAsyncQueryRequestContext); LOGGER.info("Successfully altered index: {}", autoRefreshIndex); } @@ -119,7 +126,7 @@ private String getDataSourceName(FlintIndexMetadata flintIndexMetadata) { private Map getAllAutoRefreshIndices() { Map flintIndexMetadataHashMap = - flintIndexMetadataService.getFlintIndexMetadata("flint_*"); + flintIndexMetadataService.getFlintIndexMetadata("flint_*", nullAsyncQueryRequestContext); return flintIndexMetadataHashMap.entrySet().stream() .filter(entry -> entry.getValue().getFlintIndexOptions().autoRefresh()) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); diff --git a/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java b/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java index 893b33b39d..b8352d15b2 100644 
--- a/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java @@ -33,6 +33,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; import org.opensearch.client.Client; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; /** Implementation of {@link FlintIndexMetadataService} */ @@ -49,7 +50,8 @@ public class FlintIndexMetadataServiceImpl implements FlintIndexMetadataService Arrays.asList(AUTO_REFRESH, INCREMENTAL_REFRESH, WATERMARK_DELAY, CHECKPOINT_LOCATION)); @Override - public Map getFlintIndexMetadata(String indexPattern) { + public Map getFlintIndexMetadata( + String indexPattern, AsyncQueryRequestContext asyncQueryRequestContext) { GetMappingsResponse mappingsResponse = client.admin().indices().prepareGetMappings().setIndices(indexPattern).get(); Map indexMetadataMap = new HashMap<>(); @@ -73,7 +75,10 @@ public Map getFlintIndexMetadata(String indexPattern } @Override - public void updateIndexToManualRefresh(String indexName, FlintIndexOptions flintIndexOptions) { + public void updateIndexToManualRefresh( + String indexName, + FlintIndexOptions flintIndexOptions, + AsyncQueryRequestContext asyncQueryRequestContext) { GetMappingsResponse mappingsResponse = client.admin().indices().prepareGetMappings().setIndices(indexName).get(); Map flintMetadataMap = diff --git a/async-query/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java b/async-query/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java index 5781c3e44b..eba338e912 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java +++ 
b/async-query/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java @@ -7,6 +7,7 @@ import java.util.Optional; import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.execution.xcontent.FlintIndexStateModelXContentSerializer; @@ -20,7 +21,8 @@ public class OpenSearchFlintIndexStateModelService implements FlintIndexStateMod public FlintIndexStateModel updateFlintIndexState( FlintIndexStateModel flintIndexStateModel, FlintIndexState flintIndexState, - String datasourceName) { + String datasourceName, + AsyncQueryRequestContext asyncQueryRequestContext) { return stateStore.updateState( flintIndexStateModel, flintIndexState, @@ -29,14 +31,16 @@ public FlintIndexStateModel updateFlintIndexState( } @Override - public Optional getFlintIndexStateModel(String id, String datasourceName) { + public Optional getFlintIndexStateModel( + String id, String datasourceName, AsyncQueryRequestContext asyncQueryRequestContext) { return stateStore.get( id, serializer::fromXContent, OpenSearchStateStoreUtil.getIndexName(datasourceName)); } @Override public FlintIndexStateModel createFlintIndexStateModel( - FlintIndexStateModel flintIndexStateModel) { + FlintIndexStateModel flintIndexStateModel, + AsyncQueryRequestContext asyncQueryRequestContext) { return stateStore.create( flintIndexStateModel.getId(), flintIndexStateModel, @@ -45,7 +49,8 @@ public FlintIndexStateModel createFlintIndexStateModel( } @Override - public boolean deleteFlintIndexStateModel(String id, String datasourceName) { + public boolean deleteFlintIndexStateModel( + String id, String datasourceName, AsyncQueryRequestContext asyncQueryRequestContext) { return stateStore.delete(id, OpenSearchStateStoreUtil.getIndexName(datasourceName)); } } diff --git 
a/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java b/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java index 232a280db5..ce80351f70 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java @@ -13,6 +13,7 @@ import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; import org.opensearch.tasks.Task; @@ -41,7 +42,9 @@ protected void doExecute( CancelAsyncQueryActionRequest request, ActionListener listener) { try { - String jobId = asyncQueryExecutorService.cancelQuery(request.getQueryId()); + String jobId = + asyncQueryExecutorService.cancelQuery( + request.getQueryId(), new NullAsyncQueryRequestContext()); listener.onResponse( new CancelAsyncQueryActionResponse( String.format("Deleted async query with id: %s", jobId))); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index 3ff806bf50..ede8a348b4 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -71,7 +71,8 @@ public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { emrsClient.getJobRunResultCalled(1); // 3. cancel async query. 
- String cancelQueryId = asyncQueryExecutorService.cancelQuery(response.getQueryId()); + String cancelQueryId = + asyncQueryExecutorService.cancelQuery(response.getQueryId(), asyncQueryRequestContext); assertEquals(response.getQueryId(), cancelQueryId); emrsClient.cancelJobRunCalled(1); } @@ -163,7 +164,8 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() { assertEquals(StatementState.WAITING.getState(), asyncQueryResults.getStatus()); // 3. cancel async query. - String cancelQueryId = asyncQueryExecutorService.cancelQuery(response.getQueryId()); + String cancelQueryId = + asyncQueryExecutorService.cancelQuery(response.getQueryId(), asyncQueryRequestContext); assertEquals(response.getQueryId(), cancelQueryId); } diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java index 2eed7b13a0..29c42446b3 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java @@ -152,7 +152,9 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { IllegalArgumentException exception = assertThrows( IllegalArgumentException.class, - () -> asyncQueryExecutorService.cancelQuery(response.getQueryId())); + () -> + asyncQueryExecutorService.cancelQuery( + response.getQueryId(), asyncQueryRequestContext)); assertEquals("can't cancel index DML query", exception.getMessage()); }); } @@ -326,7 +328,9 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { IllegalArgumentException exception = assertThrows( IllegalArgumentException.class, - () -> asyncQueryExecutorService.cancelQuery(response.getQueryId())); + () -> + asyncQueryExecutorService.cancelQuery( + response.getQueryId(), asyncQueryRequestContext)); assertEquals("can't cancel index DML query", exception.getMessage()); 
}); } @@ -901,7 +905,9 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { IllegalArgumentException exception = assertThrows( IllegalArgumentException.class, - () -> asyncQueryExecutorService.cancelQuery(response.getQueryId())); + () -> + asyncQueryExecutorService.cancelQuery( + response.getQueryId(), asyncQueryRequestContext)); assertEquals( "can't cancel index DML query, using ALTER auto_refresh=off statement to stop" + " job, using VACUUM statement to stop job and delete data", @@ -944,7 +950,9 @@ public GetJobRunResult getJobRunResult( flintIndexJob.refreshing(); // 2. Cancel query - String cancelResponse = asyncQueryExecutorService.cancelQuery(response.getQueryId()); + String cancelResponse = + asyncQueryExecutorService.cancelQuery( + response.getQueryId(), asyncQueryRequestContext); assertNotNull(cancelResponse); assertTrue(clusterService.state().routingTable().hasIndex(mockDS.indexName)); @@ -992,7 +1000,9 @@ public GetJobRunResult getJobRunResult( IllegalStateException illegalStateException = Assertions.assertThrows( IllegalStateException.class, - () -> asyncQueryExecutorService.cancelQuery(response.getQueryId())); + () -> + asyncQueryExecutorService.cancelQuery( + response.getQueryId(), asyncQueryRequestContext)); Assertions.assertEquals( "Transaction failed as flint index is not in a valid state.", illegalStateException.getMessage()); @@ -1038,6 +1048,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2. 
Cancel query Assertions.assertThrows( IllegalStateException.class, - () -> asyncQueryExecutorService.cancelQuery(response.getQueryId())); + () -> + asyncQueryExecutorService.cancelQuery(response.getQueryId(), asyncQueryRequestContext)); } } diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java index 6c82188ee6..0dc8f02820 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java @@ -18,6 +18,7 @@ public class MockFlintSparkJob { private FlintIndexStateModel stateModel; private FlintIndexStateModelService flintIndexStateModelService; private String datasource; + private AsyncQueryRequestContext asyncQueryRequestContext = new NullAsyncQueryRequestContext(); public MockFlintSparkJob( FlintIndexStateModelService flintIndexStateModelService, String latestId, String datasource) { @@ -34,12 +35,15 @@ public MockFlintSparkJob( .lastUpdateTime(System.currentTimeMillis()) .error("") .build(); - stateModel = flintIndexStateModelService.createFlintIndexStateModel(stateModel); + stateModel = + flintIndexStateModelService.createFlintIndexStateModel( + stateModel, asyncQueryRequestContext); } public void transition(FlintIndexState newState) { stateModel = - flintIndexStateModelService.updateFlintIndexState(stateModel, newState, datasource); + flintIndexStateModelService.updateFlintIndexState( + stateModel, newState, datasource, asyncQueryRequestContext); } public void refreshing() { @@ -68,7 +72,8 @@ public void deleted() { public void assertState(FlintIndexState expected) { Optional stateModelOpt = - flintIndexStateModelService.getFlintIndexStateModel(stateModel.getId(), datasource); + flintIndexStateModelService.getFlintIndexStateModel( + stateModel.getId(), datasource, asyncQueryRequestContext); 
assertTrue(stateModelOpt.isPresent()); assertEquals(expected, stateModelOpt.get().getIndexState()); } diff --git a/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java b/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java index c5964a61e3..0a3a180932 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java @@ -20,6 +20,7 @@ import org.opensearch.sql.legacy.metrics.MetricName; import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceSpec; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.asyncquery.model.MockFlintIndex; import org.opensearch.sql.spark.asyncquery.model.MockFlintSparkJob; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; @@ -393,13 +394,16 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataService() { @Override - public Map getFlintIndexMetadata(String indexPattern) { + public Map getFlintIndexMetadata( + String indexPattern, AsyncQueryRequestContext asyncQueryRequestContext) { throw new RuntimeException("Couldn't fetch details from ElasticSearch"); } @Override public void updateIndexToManualRefresh( - String indexName, FlintIndexOptions flintIndexOptions) {} + String indexName, + FlintIndexOptions flintIndexOptions, + AsyncQueryRequestContext asyncQueryRequestContext) {} }; FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( diff --git a/async-query/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImplTest.java 
b/async-query/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImplTest.java index f6baa82dd2..b1321cc132 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImplTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImplTest.java @@ -29,6 +29,7 @@ import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType; @@ -39,6 +40,8 @@ public class FlintIndexMetadataServiceImplTest { @Mock(answer = RETURNS_DEEP_STUBS) private Client client; + @Mock private AsyncQueryRequestContext asyncQueryRequestContext; + @SneakyThrows @Test void testGetJobIdFromFlintSkippingIndexMetadata() { @@ -56,8 +59,11 @@ void testGetJobIdFromFlintSkippingIndexMetadata() { .indexQueryActionType(IndexQueryActionType.DROP) .indexType(FlintIndexType.SKIPPING) .build(); + Map indexMetadataMap = - flintIndexMetadataService.getFlintIndexMetadata(indexQueryDetails.openSearchIndexName()); + flintIndexMetadataService.getFlintIndexMetadata( + indexQueryDetails.openSearchIndexName(), asyncQueryRequestContext); + Assertions.assertEquals( "00fhelvq7peuao0", indexMetadataMap.get(indexQueryDetails.openSearchIndexName()).getJobId()); @@ -80,8 +86,11 @@ void testGetJobIdFromFlintSkippingIndexMetadataWithIndexState() { .indexQueryActionType(IndexQueryActionType.DROP) .indexType(FlintIndexType.SKIPPING) .build(); + Map indexMetadataMap = - flintIndexMetadataService.getFlintIndexMetadata(indexQueryDetails.openSearchIndexName()); + flintIndexMetadataService.getFlintIndexMetadata( + indexQueryDetails.openSearchIndexName(), 
asyncQueryRequestContext); + FlintIndexMetadata metadata = indexMetadataMap.get(indexQueryDetails.openSearchIndexName()); Assertions.assertEquals("00fhelvq7peuao0", metadata.getJobId()); } @@ -103,8 +112,11 @@ void testGetJobIdFromFlintCoveringIndexMetadata() { .indexType(FlintIndexType.COVERING) .build(); FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); + Map indexMetadataMap = - flintIndexMetadataService.getFlintIndexMetadata(indexQueryDetails.openSearchIndexName()); + flintIndexMetadataService.getFlintIndexMetadata( + indexQueryDetails.openSearchIndexName(), asyncQueryRequestContext); + Assertions.assertEquals( "00fdmvv9hp8u0o0q", indexMetadataMap.get(indexQueryDetails.openSearchIndexName()).getJobId()); @@ -126,8 +138,11 @@ void testGetJobIDWithNPEException() { .indexQueryActionType(IndexQueryActionType.DROP) .indexType(FlintIndexType.COVERING) .build(); + Map flintIndexMetadataMap = - flintIndexMetadataService.getFlintIndexMetadata(indexQueryDetails.openSearchIndexName()); + flintIndexMetadataService.getFlintIndexMetadata( + indexQueryDetails.openSearchIndexName(), asyncQueryRequestContext); + Assertions.assertFalse( flintIndexMetadataMap.containsKey("flint_mys3_default_http_logs_cv1_index")); } @@ -148,8 +163,10 @@ void testGetJobIDWithNPEExceptionForMultipleIndices() { indexMappingsMap.put(indexName, mappings); mockNodeClientIndicesMappings("flint_mys3*", indexMappingsMap); FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); + Map flintIndexMetadataMap = - flintIndexMetadataService.getFlintIndexMetadata("flint_mys3*"); + flintIndexMetadataService.getFlintIndexMetadata("flint_mys3*", asyncQueryRequestContext); + Assertions.assertFalse( flintIndexMetadataMap.containsKey("flint_mys3_default_http_logs_cv1_index")); Assertions.assertTrue( diff --git a/async-query/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java 
b/async-query/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java index 977f77b397..4faff41fe6 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java @@ -16,6 +16,7 @@ import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.execution.xcontent.FlintIndexStateModelXContentSerializer; @@ -30,6 +31,7 @@ public class OpenSearchFlintIndexStateModelServiceTest { @Mock FlintIndexState flintIndexState; @Mock FlintIndexStateModel responseFlintIndexStateModel; @Mock FlintIndexStateModelXContentSerializer flintIndexStateModelXContentSerializer; + @Mock AsyncQueryRequestContext asyncQueryRequestContext; @InjectMocks OpenSearchFlintIndexStateModelService openSearchFlintIndexStateModelService; @@ -40,7 +42,7 @@ void updateFlintIndexState() { FlintIndexStateModel result = openSearchFlintIndexStateModelService.updateFlintIndexState( - flintIndexStateModel, flintIndexState, DATASOURCE); + flintIndexStateModel, flintIndexState, DATASOURCE, asyncQueryRequestContext); assertEquals(responseFlintIndexStateModel, result); } @@ -51,7 +53,8 @@ void getFlintIndexStateModel() { .thenReturn(Optional.of(responseFlintIndexStateModel)); Optional result = - openSearchFlintIndexStateModelService.getFlintIndexStateModel("ID", DATASOURCE); + openSearchFlintIndexStateModelService.getFlintIndexStateModel( + "ID", DATASOURCE, asyncQueryRequestContext); assertEquals(responseFlintIndexStateModel, result.get()); } @@ -63,7 +66,8 @@ void createFlintIndexStateModel() { when(flintIndexStateModel.getDatasourceName()).thenReturn(DATASOURCE); FlintIndexStateModel result = - 
openSearchFlintIndexStateModelService.createFlintIndexStateModel(flintIndexStateModel); + openSearchFlintIndexStateModelService.createFlintIndexStateModel( + flintIndexStateModel, asyncQueryRequestContext); assertEquals(responseFlintIndexStateModel, result); } @@ -73,7 +77,8 @@ void deleteFlintIndexStateModel() { when(mockStateStore.delete(any(), any())).thenReturn(true); boolean result = - openSearchFlintIndexStateModelService.deleteFlintIndexStateModel(ID, DATASOURCE); + openSearchFlintIndexStateModelService.deleteFlintIndexStateModel( + ID, DATASOURCE, asyncQueryRequestContext); assertTrue(result); } diff --git a/async-query/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java b/async-query/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java index 2ff76b9b57..a2581fdea2 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java @@ -7,6 +7,8 @@ package org.opensearch.sql.spark.transport; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.when; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; @@ -24,6 +26,7 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.core.action.ActionListener; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; import org.opensearch.tasks.Task; @@ -36,7 +39,6 @@ public class TransportCancelAsyncQueryRequestActionTest { @Mock private 
TransportCancelAsyncQueryRequestAction action; @Mock private Task task; @Mock private ActionListener actionListener; - @Mock private AsyncQueryExecutorServiceImpl asyncQueryExecutorService; @Captor @@ -54,8 +56,12 @@ public void setUp() { @Test public void testDoExecute() { CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest(EMR_JOB_ID); - when(asyncQueryExecutorService.cancelQuery(EMR_JOB_ID)).thenReturn(EMR_JOB_ID); + when(asyncQueryExecutorService.cancelQuery( + eq(EMR_JOB_ID), any(NullAsyncQueryRequestContext.class))) + .thenReturn(EMR_JOB_ID); + action.doExecute(task, request, actionListener); + Mockito.verify(actionListener).onResponse(deleteJobActionResponseArgumentCaptor.capture()); CancelAsyncQueryActionResponse cancelAsyncQueryActionResponse = deleteJobActionResponseArgumentCaptor.getValue(); @@ -66,8 +72,12 @@ public void testDoExecute() { @Test public void testDoExecuteWithException() { CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest(EMR_JOB_ID); - doThrow(new RuntimeException("Error")).when(asyncQueryExecutorService).cancelQuery(EMR_JOB_ID); + doThrow(new RuntimeException("Error")) + .when(asyncQueryExecutorService) + .cancelQuery(eq(EMR_JOB_ID), any(NullAsyncQueryRequestContext.class)); + action.doExecute(task, request, actionListener); + Mockito.verify(actionListener).onFailure(exceptionArgumentCaptor.capture()); Exception exception = exceptionArgumentCaptor.getValue(); Assertions.assertTrue(exception instanceof RuntimeException); From 0e70a502bbc64779c428dd0dcc2d9d65f6cfb591 Mon Sep 17 00:00:00 2001 From: qianheng Date: Mon, 5 Aug 2024 23:25:47 +0800 Subject: [PATCH 09/12] add TakeOrderedOperator (#2863) --------- Signed-off-by: Heng Qian --- .../org/opensearch/sql/executor/Explain.java | 14 + .../sql/planner/DefaultImplementor.java | 9 +- .../sql/planner/physical/PhysicalPlanDSL.java | 5 + .../physical/PhysicalPlanNodeVisitor.java | 4 + .../sql/planner/physical/SortHelper.java | 70 ++ 
.../sql/planner/physical/SortOperator.java | 43 +- .../planner/physical/TakeOrderedOperator.java | 88 +++ .../opensearch/sql/executor/ExplainTest.java | 21 + .../sql/planner/DefaultImplementorTest.java | 26 + .../physical/PhysicalPlanNodeVisitorTest.java | 4 + .../physical/TakeOrderedOperatorTest.java | 607 ++++++++++++++++++ docs/user/optimization/optimization.rst | 29 +- .../OpenSearchExecutionProtector.java | 12 + .../OpenSearchExecutionProtectorTest.java | 11 + 14 files changed, 883 insertions(+), 60 deletions(-) create mode 100644 core/src/main/java/org/opensearch/sql/planner/physical/SortHelper.java create mode 100644 core/src/main/java/org/opensearch/sql/planner/physical/TakeOrderedOperator.java create mode 100644 core/src/test/java/org/opensearch/sql/planner/physical/TakeOrderedOperatorTest.java diff --git a/core/src/main/java/org/opensearch/sql/executor/Explain.java b/core/src/main/java/org/opensearch/sql/executor/Explain.java index 0f05b99383..fffbe6f693 100644 --- a/core/src/main/java/org/opensearch/sql/executor/Explain.java +++ b/core/src/main/java/org/opensearch/sql/executor/Explain.java @@ -30,6 +30,7 @@ import org.opensearch.sql.planner.physical.RemoveOperator; import org.opensearch.sql.planner.physical.RenameOperator; import org.opensearch.sql.planner.physical.SortOperator; +import org.opensearch.sql.planner.physical.TakeOrderedOperator; import org.opensearch.sql.planner.physical.ValuesOperator; import org.opensearch.sql.planner.physical.WindowOperator; import org.opensearch.sql.storage.TableScanOperator; @@ -73,6 +74,19 @@ public ExplainResponseNode visitSort(SortOperator node, Object context) { ImmutableMap.of("sortList", describeSortList(node.getSortList())))); } + @Override + public ExplainResponseNode visitTakeOrdered(TakeOrderedOperator node, Object context) { + return explain( + node, + context, + explainNode -> + explainNode.setDescription( + ImmutableMap.of( + "limit", node.getLimit(), + "offset", node.getOffset(), + "sortList", 
describeSortList(node.getSortList())))); + } + @Override public ExplainResponseNode visitTableScan(TableScanOperator node, Object context) { return explain( diff --git a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java index b53d17b38f..f962c3e4bf 100644 --- a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java +++ b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java @@ -38,6 +38,7 @@ import org.opensearch.sql.planner.physical.RemoveOperator; import org.opensearch.sql.planner.physical.RenameOperator; import org.opensearch.sql.planner.physical.SortOperator; +import org.opensearch.sql.planner.physical.TakeOrderedOperator; import org.opensearch.sql.planner.physical.ValuesOperator; import org.opensearch.sql.planner.physical.WindowOperator; import org.opensearch.sql.storage.read.TableScanBuilder; @@ -129,7 +130,13 @@ public PhysicalPlan visitValues(LogicalValues node, C context) { @Override public PhysicalPlan visitLimit(LogicalLimit node, C context) { - return new LimitOperator(visitChild(node, context), node.getLimit(), node.getOffset()); + PhysicalPlan child = visitChild(node, context); + // Optimize sort + limit to take ordered operator + if (child instanceof SortOperator sortChild) { + return new TakeOrderedOperator( + sortChild.getInput(), node.getLimit(), node.getOffset(), sortChild.getSortList()); + } + return new LimitOperator(child, node.getLimit(), node.getOffset()); } @Override diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanDSL.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanDSL.java index 147f0e08dc..0c2764112d 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanDSL.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanDSL.java @@ -64,6 +64,11 @@ public static SortOperator sort(PhysicalPlan input, Pair return new 
SortOperator(input, Arrays.asList(sorts)); } + public static TakeOrderedOperator takeOrdered( + PhysicalPlan input, Integer limit, Integer offset, Pair... sorts) { + return new TakeOrderedOperator(input, limit, offset, Arrays.asList(sorts)); + } + public static DedupeOperator dedupe(PhysicalPlan input, Expression... expressions) { return new DedupeOperator(input, Arrays.asList(expressions)); } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java index 99b5cc8020..67d7a05135 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java @@ -72,6 +72,10 @@ public R visitSort(SortOperator node, C context) { return visitNode(node, context); } + public R visitTakeOrdered(TakeOrderedOperator node, C context) { + return visitNode(node, context); + } + public R visitRareTopN(RareTopNOperator node, C context) { return visitNode(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/SortHelper.java b/core/src/main/java/org/opensearch/sql/planner/physical/SortHelper.java new file mode 100644 index 0000000000..ea117ee6df --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/physical/SortHelper.java @@ -0,0 +1,70 @@ +package org.opensearch.sql.planner.physical; + +import static org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_FIRST; +import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC; + +import com.google.common.collect.Ordering; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.ast.tree.Sort.SortOption; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.utils.ExprValueOrdering; +import org.opensearch.sql.expression.Expression; + +public 
interface SortHelper { + + /** + * Construct an expr comparator for sorting on ExprValue. + * + * @param sortList list of sort fields and their related sort options. + * @return A comparator for ExprValue + */ + static Comparator constructExprComparator( + List> sortList) { + return (o1, o2) -> compareWithExpressions(o1, o2, constructComparator(sortList)); + } + + /** + * Construct an expr ordering for efficiently taking the top-k elements on ExprValue. + * + * @param sortList list of sort fields and their related sort options. + * @return An guava ordering for ExprValue + */ + static Ordering constructExprOrdering(List> sortList) { + return Ordering.from(constructExprComparator(sortList)); + } + + private static List>> constructComparator( + List> sortList) { + List>> comparators = new ArrayList<>(); + for (Pair pair : sortList) { + SortOption option = pair.getLeft(); + ExprValueOrdering ordering = + ASC.equals(option.getSortOrder()) + ? ExprValueOrdering.natural() + : ExprValueOrdering.natural().reverse(); + ordering = + NULL_FIRST.equals(option.getNullOrder()) ? 
ordering.nullsFirst() : ordering.nullsLast(); + comparators.add(Pair.of(pair.getRight(), ordering)); + } + return comparators; + } + + private static int compareWithExpressions( + ExprValue o1, ExprValue o2, List>> comparators) { + for (Pair> comparator : comparators) { + Expression expression = comparator.getKey(); + int result = + comparator + .getValue() + .compare( + expression.valueOf(o1.bindingTuples()), expression.valueOf(o2.bindingTuples())); + if (result != 0) { + return result; + } + } + return 0; + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/SortOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/SortOperator.java index e3116baedf..b635f01d18 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/SortOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/SortOperator.java @@ -5,25 +5,18 @@ package org.opensearch.sql.planner.physical; -import static org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_FIRST; -import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC; - import java.util.Collections; import java.util.Comparator; import java.util.Iterator; import java.util.List; import java.util.PriorityQueue; -import lombok.Builder; import lombok.EqualsAndHashCode; import lombok.Getter; -import lombok.Singular; import lombok.ToString; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.data.utils.ExprValueOrdering; import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.planner.physical.SortOperator.Sorter.SorterBuilder; /** * Sort Operator.The input data is sorted by the sort fields in the {@link SortOperator#sortList}. 
@@ -36,7 +29,7 @@ public class SortOperator extends PhysicalPlan { @Getter private final PhysicalPlan input; @Getter private final List> sortList; - @EqualsAndHashCode.Exclude private final Sorter sorter; + @EqualsAndHashCode.Exclude private final Comparator sorter; @EqualsAndHashCode.Exclude private Iterator iterator; /** @@ -49,18 +42,7 @@ public class SortOperator extends PhysicalPlan { public SortOperator(PhysicalPlan input, List> sortList) { this.input = input; this.sortList = sortList; - SorterBuilder sorterBuilder = Sorter.builder(); - for (Pair pair : sortList) { - SortOption option = pair.getLeft(); - ExprValueOrdering ordering = - ASC.equals(option.getSortOrder()) - ? ExprValueOrdering.natural() - : ExprValueOrdering.natural().reverse(); - ordering = - NULL_FIRST.equals(option.getNullOrder()) ? ordering.nullsFirst() : ordering.nullsLast(); - sorterBuilder.comparator(Pair.of(pair.getRight(), ordering)); - } - this.sorter = sorterBuilder.build(); + this.sorter = SortHelper.constructExprComparator(sortList); } @Override @@ -94,27 +76,6 @@ public ExprValue next() { return iterator.next(); } - @Builder - public static class Sorter implements Comparator { - @Singular private final List>> comparators; - - @Override - public int compare(ExprValue o1, ExprValue o2) { - for (Pair> comparator : comparators) { - Expression expression = comparator.getKey(); - int result = - comparator - .getValue() - .compare( - expression.valueOf(o1.bindingTuples()), expression.valueOf(o2.bindingTuples())); - if (result != 0) { - return result; - } - } - return 0; - } - } - private Iterator iterator(PriorityQueue result) { return new Iterator() { @Override diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TakeOrderedOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TakeOrderedOperator.java new file mode 100644 index 0000000000..a6e0f968e6 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TakeOrderedOperator.java @@ -0,0 
+1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical; + +import com.google.common.collect.Ordering; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.ast.tree.Sort.SortOption; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.expression.Expression; + +/** + * TakeOrdered Operator. This operator will sort input data as the order of {@link this#sortList} + * specifies and return {@link this#limit} rows from the {@link this#offset} index. + * + *

<p>Functionally, this operator is a combination of {@link SortOperator} and {@link + * LimitOperator}. But it can reduce the time complexity from O(nlogn) to O(n), and memory from O(n) + * to O(k) by using Guava's {@link com.google.common.collect.Ordering}. + * + *

Overall, it's an optimization to replace `Limit(Sort)` in physical plan level since it's all + * about execution. Because most execution engine may not support this operator, it doesn't have a + * related logical operator. + */ +@ToString +@EqualsAndHashCode(callSuper = false) +public class TakeOrderedOperator extends PhysicalPlan { + @Getter private final PhysicalPlan input; + + @Getter private final List> sortList; + @Getter private final Integer limit; + @Getter private final Integer offset; + @EqualsAndHashCode.Exclude private final Ordering ordering; + @EqualsAndHashCode.Exclude private Iterator iterator; + + /** + * TakeOrdered Operator Constructor. + * + * @param input input {@link PhysicalPlan} + * @param limit the limit value from LimitOperator + * @param offset the offset value from LimitOperator + * @param sortList list of sort field from SortOperator + */ + public TakeOrderedOperator( + PhysicalPlan input, + Integer limit, + Integer offset, + List> sortList) { + this.input = input; + this.sortList = sortList; + this.limit = limit; + this.offset = offset; + this.ordering = SortHelper.constructExprOrdering(sortList); + } + + @Override + public R accept(PhysicalPlanNodeVisitor visitor, C context) { + return visitor.visitTakeOrdered(this, context); + } + + @Override + public void open() { + super.open(); + iterator = ordering.leastOf(input, offset + limit).stream().skip(offset).iterator(); + } + + @Override + public List getChild() { + return Collections.singletonList(input); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public ExprValue next() { + return iterator.next(); + } +} diff --git a/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java b/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java index 897347f22d..eaeae07242 100644 --- a/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java @@ -27,6 +27,7 
@@ import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.remove; import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.rename; import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.sort; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.takeOrdered; import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.values; import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.window; @@ -220,6 +221,26 @@ void can_explain_limit() { explain.apply(plan)); } + @Test + void can_explain_takeOrdered() { + Pair sort = + ImmutablePair.of(Sort.SortOption.DEFAULT_ASC, ref("a", INTEGER)); + PhysicalPlan plan = takeOrdered(tableScan, 10, 5, sort); + assertEquals( + new ExplainResponse( + new ExplainResponseNode( + "TakeOrderedOperator", + Map.of( + "limit", + 10, + "offset", + 5, + "sortList", + Map.of("a", Map.of("sortOrder", "ASC", "nullOrder", "NULL_FIRST"))), + singletonList(tableScan.explainNode()))), + explain.apply(plan)); + } + @Test void can_explain_nested() { Set nestedOperatorArgs = Set.of("message.info", "message"); diff --git a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java index 45d8f6c03c..8e71fc2bec 100644 --- a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java @@ -278,4 +278,30 @@ public void visitPaginate_should_remove_it_from_tree() { new ProjectOperator(new ValuesOperator(List.of(List.of())), List.of(), List.of()); assertEquals(physicalPlanTree, logicalPlanTree.accept(implementor, null)); } + + @Test + public void visitLimit_support_return_takeOrdered() { + // replace SortOperator + LimitOperator with TakeOrderedOperator + Pair sort = + ImmutablePair.of(Sort.SortOption.DEFAULT_ASC, ref("a", INTEGER)); + var logicalValues = values(emptyList()); + var logicalSort = sort(logicalValues, sort); + 
var logicalLimit = limit(logicalSort, 10, 5); + PhysicalPlan physicalPlanTree = + PhysicalPlanDSL.takeOrdered(PhysicalPlanDSL.values(emptyList()), 10, 5, sort); + assertEquals(physicalPlanTree, logicalLimit.accept(implementor, null)); + + // don't replace if LimitOperator's child is not SortOperator + Pair newEvalField = + ImmutablePair.of(ref("name1", STRING), ref("name", STRING)); + var logicalEval = eval(logicalSort, newEvalField); + logicalLimit = limit(logicalEval, 10, 5); + physicalPlanTree = + PhysicalPlanDSL.limit( + PhysicalPlanDSL.eval( + PhysicalPlanDSL.sort(PhysicalPlanDSL.values(emptyList()), sort), newEvalField), + 10, + 5); + assertEquals(physicalPlanTree, logicalLimit.accept(implementor, null)); + } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java index c91ae8787c..17fb128ace 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java @@ -22,6 +22,7 @@ import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.remove; import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.rename; import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.sort; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.takeOrdered; import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.values; import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.window; @@ -117,6 +118,8 @@ public static Stream getPhysicalPlanForTest() { PhysicalPlan sort = sort(plan, Pair.of(SortOption.DEFAULT_ASC, ref)); + PhysicalPlan takeOrdered = takeOrdered(plan, 1, 1, Pair.of(SortOption.DEFAULT_ASC, ref)); + PhysicalPlan dedupe = dedupe(plan, ref); PhysicalPlan values = values(emptyList()); @@ -140,6 +143,7 @@ public static Stream getPhysicalPlanForTest() { 
Arguments.of(remove, "remove"), Arguments.of(eval, "eval"), Arguments.of(sort, "sort"), + Arguments.of(takeOrdered, "takeOrdered"), Arguments.of(dedupe, "dedupe"), Arguments.of(values, "values"), Arguments.of(rareTopN, "rareTopN"), diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TakeOrderedOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TakeOrderedOperatorTest.java new file mode 100644 index 0000000000..f2fcb84910 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TakeOrderedOperatorTest.java @@ -0,0 +1,607 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.model.ExprValueUtils.tupleValue; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.expression.DSL.ref; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.limit; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.sort; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.takeOrdered; + +import com.google.common.collect.ImmutableMap; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import lombok.Getter; +import lombok.Setter; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.tree.Sort.SortOption; +import 
org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.expression.Expression; + +/** + * To make sure {@link TakeOrderedOperator} can replace {@link SortOperator} + {@link + * LimitOperator}, this UT will replica all tests in {@link SortOperatorTest} and add more test + * cases on different limit and offset. + */ +@ExtendWith(MockitoExtension.class) +class TakeOrderedOperatorTest extends PhysicalPlanTestBase { + private static PhysicalPlan inputPlan; + + @Getter + @Setter + private static class Wrapper { + Iterator iterator = Collections.emptyIterator(); + } + + private static final Wrapper wrapper = new Wrapper(); + + @BeforeAll + public static void setUp() { + inputPlan = Mockito.mock(PhysicalPlan.class); + when(inputPlan.hasNext()) + .thenAnswer((InvocationOnMock invocation) -> wrapper.iterator.hasNext()); + when(inputPlan.next()).thenAnswer((InvocationOnMock invocation) -> wrapper.iterator.next()); + } + + /** + * construct the map which contain null value, because {@link ImmutableMap} doesn't support null + * value. 
+ */ + private static final Map NULL_MAP = + new HashMap<>() { + { + put("size", 399); + put("response", null); + } + }; + + @Test + public void sort_one_field_asc() { + List inputList = + Arrays.asList( + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503))); + + List> sortList = + List.of(Pair.of(SortOption.DEFAULT_ASC, ref("response", INTEGER))); + + test_takeOrdered_with_sort_limit( + inputList, + 3, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 399, "response", 503))); + + test_takeOrdered_with_sort_limit( + inputList, + 2, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 499, "response", 404))); + + test_takeOrdered_with_sort_limit( + inputList, + 2, + 1, + sortList, + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 399, "response", 503))); + + test_takeOrdered_with_sort_limit(inputList, 0, 1, sortList); + } + + @Test + public void sort_one_field_with_duplication() { + List inputList = + Arrays.asList( + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 404)), + tupleValue(ImmutableMap.of("size", 399, "response", 503))); + + List> sortList = + List.of(Pair.of(SortOption.DEFAULT_ASC, ref("response", INTEGER))); + + test_takeOrdered_with_sort_limit( + inputList, + 3, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 404)), + tupleValue(ImmutableMap.of("size", 399, "response", 503))); + + test_takeOrdered_with_sort_limit( + inputList, + 2, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 
320, "response", 404))); + + test_takeOrdered_with_sort_limit( + inputList, + 2, + 1, + sortList, + tupleValue(ImmutableMap.of("size", 320, "response", 404)), + tupleValue(ImmutableMap.of("size", 399, "response", 503))); + + test_takeOrdered_with_sort_limit(inputList, 0, 1, sortList); + } + + @Test + public void sort_one_field_asc_with_null_value() { + List inputList = + Arrays.asList( + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(NULL_MAP)); + + List> sortList = + List.of(Pair.of(SortOption.DEFAULT_ASC, ref("response", INTEGER))); + + test_takeOrdered_with_sort_limit( + inputList, + 4, + 0, + sortList, + tupleValue(NULL_MAP), + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 399, "response", 503))); + + test_takeOrdered_with_sort_limit( + inputList, + 3, + 0, + sortList, + tupleValue(NULL_MAP), + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 499, "response", 404))); + + test_takeOrdered_with_sort_limit( + inputList, + 3, + 1, + sortList, + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 399, "response", 503))); + + test_takeOrdered_with_sort_limit(inputList, 0, 1, sortList); + } + + @Test + public void sort_one_field_asc_with_missing_value() { + List inputList = + Arrays.asList( + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(ImmutableMap.of("size", 399))); + + List> sortList = + List.of(Pair.of(SortOption.DEFAULT_ASC, ref("response", INTEGER))); + test_takeOrdered_with_sort_limit( + inputList, + 4, + 0, + 
sortList, + tupleValue(ImmutableMap.of("size", 399)), + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 399, "response", 503))); + + test_takeOrdered_with_sort_limit( + inputList, + 3, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 399)), + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 499, "response", 404))); + + test_takeOrdered_with_sort_limit( + inputList, + 3, + 1, + sortList, + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 399, "response", 503))); + + test_takeOrdered_with_sort_limit(inputList, 0, 1, sortList); + } + + @Test + public void sort_one_field_desc() { + List inputList = + Arrays.asList( + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503))); + + List> sortList = + List.of(Pair.of(SortOption.DEFAULT_DESC, ref("response", INTEGER))); + + test_takeOrdered_with_sort_limit( + inputList, + 3, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 200))); + + test_takeOrdered_with_sort_limit( + inputList, + 2, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(ImmutableMap.of("size", 499, "response", 404))); + + test_takeOrdered_with_sort_limit( + inputList, + 2, + 1, + sortList, + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 200))); + + test_takeOrdered_with_sort_limit(inputList, 0, 1, sortList); + } + + @Test + public void sort_one_field_desc_with_null_value() { + List inputList = + Arrays.asList( + 
tupleValue(NULL_MAP), + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503))); + + List> sortList = + List.of(Pair.of(SortOption.DEFAULT_DESC, ref("response", INTEGER))); + + test_takeOrdered_with_sort_limit( + inputList, + 4, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(NULL_MAP)); + + test_takeOrdered_with_sort_limit( + inputList, + 3, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 200))); + + test_takeOrdered_with_sort_limit( + inputList, + 3, + 1, + sortList, + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(NULL_MAP)); + + test_takeOrdered_with_sort_limit(inputList, 0, 1, sortList); + } + + @Test + public void sort_one_field_with_duplicate_value() { + List inputList = + Arrays.asList( + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 399, "response", 503))); + + List> sortList = + List.of(Pair.of(SortOption.DEFAULT_ASC, ref("response", INTEGER))); + + test_takeOrdered_with_sort_limit( + inputList, + 4, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 399, "response", 503))); + + test_takeOrdered_with_sort_limit( + inputList, + 3, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 320, 
"response", 200)), + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 499, "response", 404))); + + test_takeOrdered_with_sort_limit( + inputList, + 3, + 1, + sortList, + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 399, "response", 503))); + + test_takeOrdered_with_sort_limit(inputList, 0, 1, sortList); + } + + @Test + public void sort_two_fields_both_asc() { + List inputList = + Arrays.asList( + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(NULL_MAP)); + + List> sortList = + List.of( + Pair.of(SortOption.DEFAULT_ASC, ref("size", INTEGER)), + Pair.of(SortOption.DEFAULT_ASC, ref("response", INTEGER))); + + test_takeOrdered_with_sort_limit( + inputList, + 5, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(NULL_MAP), + tupleValue(ImmutableMap.of("size", 399, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(ImmutableMap.of("size", 499, "response", 404))); + + test_takeOrdered_with_sort_limit( + inputList, + 4, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(NULL_MAP), + tupleValue(ImmutableMap.of("size", 399, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503))); + + test_takeOrdered_with_sort_limit( + inputList, + 4, + 1, + sortList, + tupleValue(NULL_MAP), + tupleValue(ImmutableMap.of("size", 399, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(ImmutableMap.of("size", 499, "response", 404))); + + test_takeOrdered_with_sort_limit(inputList, 0, 1, sortList); + } + + @Test + public void 
sort_two_fields_both_desc() { + List inputList = + Arrays.asList( + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(NULL_MAP)); + + List> sortList = + List.of( + Pair.of(SortOption.DEFAULT_DESC, ref("size", INTEGER)), + Pair.of(SortOption.DEFAULT_DESC, ref("response", INTEGER))); + + test_takeOrdered_with_sort_limit( + inputList, + 5, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(ImmutableMap.of("size", 399, "response", 200)), + tupleValue(NULL_MAP), + tupleValue(ImmutableMap.of("size", 320, "response", 200))); + + test_takeOrdered_with_sort_limit( + inputList, + 4, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(ImmutableMap.of("size", 399, "response", 200)), + tupleValue(NULL_MAP)); + + test_takeOrdered_with_sort_limit( + inputList, + 4, + 1, + sortList, + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(ImmutableMap.of("size", 399, "response", 200)), + tupleValue(NULL_MAP), + tupleValue(ImmutableMap.of("size", 320, "response", 200))); + + test_takeOrdered_with_sort_limit(inputList, 0, 1, sortList); + } + + @Test + public void sort_two_fields_asc_and_desc() { + List inputList = + Arrays.asList( + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(NULL_MAP)); + + List> sortList = + List.of( + Pair.of(SortOption.DEFAULT_ASC, ref("size", INTEGER)), + Pair.of(SortOption.DEFAULT_DESC, ref("response", INTEGER))); + + 
test_takeOrdered_with_sort_limit( + inputList, + 5, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(ImmutableMap.of("size", 399, "response", 200)), + tupleValue(NULL_MAP), + tupleValue(ImmutableMap.of("size", 499, "response", 404))); + + test_takeOrdered_with_sort_limit( + inputList, + 4, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(ImmutableMap.of("size", 399, "response", 200)), + tupleValue(NULL_MAP)); + + test_takeOrdered_with_sort_limit( + inputList, + 4, + 1, + sortList, + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(ImmutableMap.of("size", 399, "response", 200)), + tupleValue(NULL_MAP), + tupleValue(ImmutableMap.of("size", 499, "response", 404))); + + test_takeOrdered_with_sort_limit(inputList, 0, 1, sortList); + } + + @Test + public void sort_two_fields_desc_and_asc() { + List inputList = + Arrays.asList( + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(NULL_MAP)); + + List> sortList = + List.of( + Pair.of(SortOption.DEFAULT_DESC, ref("size", INTEGER)), + Pair.of(SortOption.DEFAULT_ASC, ref("response", INTEGER))); + + test_takeOrdered_with_sort_limit( + inputList, + 5, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(NULL_MAP), + tupleValue(ImmutableMap.of("size", 399, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(ImmutableMap.of("size", 320, "response", 200))); + + test_takeOrdered_with_sort_limit( + inputList, + 4, + 0, + sortList, + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(NULL_MAP), + 
tupleValue(ImmutableMap.of("size", 399, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503))); + + test_takeOrdered_with_sort_limit( + inputList, + 4, + 1, + sortList, + tupleValue(NULL_MAP), + tupleValue(ImmutableMap.of("size", 399, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(ImmutableMap.of("size", 320, "response", 200))); + + test_takeOrdered_with_sort_limit(inputList, 0, 1, sortList); + } + + @Test + public void sort_one_field_without_input() { + wrapper.setIterator(Collections.emptyIterator()); + assertEquals( + 0, + execute( + takeOrdered( + inputPlan, 1, 0, Pair.of(SortOption.DEFAULT_ASC, ref("response", INTEGER)))) + .size()); + } + + @Test + public void offset_exceeds_row_number() { + List inputList = + Arrays.asList( + tupleValue(ImmutableMap.of("size", 499, "response", 404)), + tupleValue(ImmutableMap.of("size", 320, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 200)), + tupleValue(ImmutableMap.of("size", 399, "response", 503)), + tupleValue(NULL_MAP)); + + wrapper.setIterator(inputList.iterator()); + PhysicalPlan plan = + takeOrdered(inputPlan, 1, 6, Pair.of(SortOption.DEFAULT_ASC, ref("response", INTEGER))); + List result = execute(plan); + assertEquals(0, result.size()); + } + + private void test_takeOrdered_with_sort_limit( + List inputList, + int limit, + int offset, + List> sortList, + ExprValue... 
expected) { + wrapper.setIterator(inputList.iterator()); + List compareResult = + execute(limit(sort(inputPlan, sortList.toArray(Pair[]::new)), limit, offset)); + wrapper.setIterator(inputList.iterator()); + List testResult = + execute(takeOrdered(inputPlan, limit, offset, sortList.toArray(Pair[]::new))); + assertEquals(compareResult, testResult); + if (expected.length == 0) { + assertEquals(0, testResult.size()); + } else { + assertThat(testResult, contains(expected)); + } + } +} diff --git a/docs/user/optimization/optimization.rst b/docs/user/optimization/optimization.rst index 835fe96eba..454c9ec066 100644 --- a/docs/user/optimization/optimization.rst +++ b/docs/user/optimization/optimization.rst @@ -237,31 +237,24 @@ If sort that includes expression, which cannot be merged into query DSL, also ex }, "children": [ { - "name": "LimitOperator", + "name": "TakeOrderedOperator", "description": { "limit": 10, - "offset": 0 + "offset": 0, + "sortList": { + "abs(age)": { + "sortOrder": "ASC", + "nullOrder": "NULL_FIRST" + } + } }, "children": [ { - "name": "SortOperator", + "name": "OpenSearchIndexScan", "description": { - "sortList": { - "abs(age)": { - "sortOrder": "ASC", - "nullOrder": "NULL_FIRST" - } - } + "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\"}, searchDone=false)" }, - "children": [ - { - "name": "OpenSearchIndexScan", - "description": { - "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\"}, searchDone=false)" - }, - "children": [] - } - ] + "children": [] } ] } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java index 0905c2f4b4..28827b0a54 100644 --- 
a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java @@ -23,6 +23,7 @@ import org.opensearch.sql.planner.physical.RemoveOperator; import org.opensearch.sql.planner.physical.RenameOperator; import org.opensearch.sql.planner.physical.SortOperator; +import org.opensearch.sql.planner.physical.TakeOrderedOperator; import org.opensearch.sql.planner.physical.ValuesOperator; import org.opensearch.sql.planner.physical.WindowOperator; import org.opensearch.sql.storage.TableScanOperator; @@ -130,6 +131,17 @@ public PhysicalPlan visitSort(SortOperator node, Object context) { return doProtect(new SortOperator(visitInput(node.getInput(), context), node.getSortList())); } + /** Decorate with {@link ResourceMonitorPlan}. */ + @Override + public PhysicalPlan visitTakeOrdered(TakeOrderedOperator node, Object context) { + return doProtect( + new TakeOrderedOperator( + visitInput(node.getInput(), context), + node.getLimit(), + node.getOffset(), + node.getSortList())); + } + /** * Values are a sequence of rows of literal value in memory which doesn't need memory protection. 
*/ diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index b2dc042110..5cd11c6cd4 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -68,6 +68,7 @@ import org.opensearch.sql.planner.physical.NestedOperator; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; +import org.opensearch.sql.planner.physical.TakeOrderedOperator; @ExtendWith(MockitoExtension.class) @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @@ -306,6 +307,16 @@ void do_nothing_with_CursorCloseOperator_and_children() { verify(child, never()).accept(executionProtector, null); } + @Test + public void test_visitTakeOrdered() { + Pair sort = + ImmutablePair.of(Sort.SortOption.DEFAULT_ASC, ref("a", INTEGER)); + TakeOrderedOperator takeOrdered = + PhysicalPlanDSL.takeOrdered(PhysicalPlanDSL.values(emptyList()), 10, 5, sort); + assertEquals( + resourceMonitor(takeOrdered), executionProtector.visitTakeOrdered(takeOrdered, null)); + } + PhysicalPlan resourceMonitor(PhysicalPlan input) { return new ResourceMonitorPlan(input, resourceMonitor); } From 7e73f124e2dc63a69750c8034558d6870c46ccf9 Mon Sep 17 00:00:00 2001 From: Manasvini B Suryanarayana Date: Mon, 5 Aug 2024 11:42:42 -0700 Subject: [PATCH 10/12] Test utils update to fix IT tests for serverless (#2869) Signed-off-by: Manasvini B S --- .../org/opensearch/sql/sql/AggregationIT.java | 9 ++- .../org/opensearch/sql/sql/ScoreQueryIT.java | 8 +-- .../org/opensearch/sql/util/MatcherUtils.java | 31 ++++++++++ .../org/opensearch/sql/util/TestUtils.java | 62 +++++++++++++++++++ 4 files changed, 104 
insertions(+), 6 deletions(-) diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java index 29358bd1c3..901c2a41e4 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java @@ -9,14 +9,17 @@ import static org.opensearch.sql.legacy.plugin.RestSqlAction.QUERY_API_ENDPOINT; import static org.opensearch.sql.util.MatcherUtils.rows; import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verify; import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; import static org.opensearch.sql.util.MatcherUtils.verifySchema; import static org.opensearch.sql.util.MatcherUtils.verifySome; import static org.opensearch.sql.util.TestUtils.getResponseBody; +import static org.opensearch.sql.util.TestUtils.roundOfResponse; import java.io.IOException; import java.util.List; import java.util.Locale; +import org.json.JSONArray; import org.json.JSONObject; import org.junit.jupiter.api.Test; import org.opensearch.client.Request; @@ -396,8 +399,9 @@ public void testMaxDoublePushedDown() throws IOException { @Test public void testAvgDoublePushedDown() throws IOException { var response = executeQuery(String.format("SELECT avg(num3)" + " from %s", TEST_INDEX_CALCS)); + JSONArray responseJSON = roundOfResponse(response.getJSONArray("datarows")); verifySchema(response, schema("avg(num3)", null, "double")); - verifyDataRows(response, rows(-6.12D)); + verify(responseJSON, rows(-6.12D)); } @Test @@ -456,8 +460,9 @@ public void testAvgDoubleInMemory() throws IOException { executeQuery( String.format( "SELECT avg(num3)" + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + JSONArray roundOfResponse = roundOfResponse(response.getJSONArray("datarows")); verifySchema(response, schema("avg(num3) OVER(PARTITION BY datetime1)", null, "double")); - 
verifySome(response.getJSONArray("datarows"), rows(-6.12D)); + verifySome(roundOfResponse, rows(-6.12D)); } @Test diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/ScoreQueryIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/ScoreQueryIT.java index 6616746d99..a1f71dcf6c 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/ScoreQueryIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/ScoreQueryIT.java @@ -8,7 +8,7 @@ import static org.hamcrest.Matchers.containsString; import static org.opensearch.sql.util.MatcherUtils.rows; import static org.opensearch.sql.util.MatcherUtils.schema; -import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifyDataAddressRows; import static org.opensearch.sql.util.MatcherUtils.verifySchema; import java.io.IOException; @@ -123,8 +123,7 @@ public void scoreQueryTest() throws IOException { TestsConstants.TEST_INDEX_ACCOUNT), "jdbc")); verifySchema(result, schema("address", null, "text"), schema("_score", null, "float")); - verifyDataRows( - result, rows("154 Douglass Street", 650.1515), rows("565 Hall Street", 3.2507575)); + verifyDataAddressRows(result, rows("154 Douglass Street"), rows("565 Hall Street")); } @Test @@ -154,7 +153,8 @@ public void scoreQueryDefaultBoostQueryTest() throws IOException { + "where score(matchQuery(address, 'Powell')) order by _score desc limit 2", TestsConstants.TEST_INDEX_ACCOUNT), "jdbc")); + verifySchema(result, schema("address", null, "text"), schema("_score", null, "float")); - verifyDataRows(result, rows("305 Powell Street", 6.501515)); + verifyDataAddressRows(result, rows("305 Powell Street")); } } diff --git a/integ-test/src/test/java/org/opensearch/sql/util/MatcherUtils.java b/integ-test/src/test/java/org/opensearch/sql/util/MatcherUtils.java index 26a60cb4e5..d4db502407 100644 --- a/integ-test/src/test/java/org/opensearch/sql/util/MatcherUtils.java +++ 
b/integ-test/src/test/java/org/opensearch/sql/util/MatcherUtils.java @@ -159,6 +159,11 @@ public static void verifyDataRows(JSONObject response, Matcher... mat verify(response.getJSONArray("datarows"), matchers); } + @SafeVarargs + public static void verifyDataAddressRows(JSONObject response, Matcher... matchers) { + verifyAddressRow(response.getJSONArray("datarows"), matchers); + } + @SafeVarargs public static void verifyColumn(JSONObject response, Matcher... matchers) { verify(response.getJSONArray("schema"), matchers); @@ -183,6 +188,32 @@ public static void verify(JSONArray array, Matcher... matchers) { assertThat(objects, containsInAnyOrder(matchers)); } + // TODO: this is temporary fix for fixing serverless tests to pass as it creates multiple shards + // leading to score differences. + public static void verifyAddressRow(JSONArray array, Matcher... matchers) { + // List to store the processed elements from the JSONArray + List objects = new ArrayList<>(); + + // Iterate through each element in the JSONArray + array + .iterator() + .forEachRemaining( + o -> { + // Check if o is a JSONArray with exactly 2 elements + if (o instanceof JSONArray && ((JSONArray) o).length() == 2) { + // Check if the second element is a BigDecimal/_score value + if (((JSONArray) o).get(1) instanceof BigDecimal) { + // Remove the _score element from response data rows to skip the assertion as it + // will be different when compared against multiple shards + ((JSONArray) o).remove(1); + } + } + objects.add((T) o); + }); + assertEquals(matchers.length, objects.size()); + assertThat(objects, containsInAnyOrder(matchers)); + } + @SafeVarargs @SuppressWarnings("unchecked") public static void verifyInOrder(JSONArray array, Matcher... 
matchers) { diff --git a/integ-test/src/test/java/org/opensearch/sql/util/TestUtils.java b/integ-test/src/test/java/org/opensearch/sql/util/TestUtils.java index 589fb1f9ae..bce83e7ccb 100644 --- a/integ-test/src/test/java/org/opensearch/sql/util/TestUtils.java +++ b/integ-test/src/test/java/org/opensearch/sql/util/TestUtils.java @@ -17,6 +17,8 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.io.Reader; +import java.math.BigDecimal; +import java.math.RoundingMode; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; @@ -27,6 +29,7 @@ import java.util.List; import java.util.Locale; import java.util.stream.Collectors; +import org.json.JSONArray; import org.json.JSONObject; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; @@ -34,6 +37,7 @@ import org.opensearch.client.Client; import org.opensearch.client.Request; import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; import org.opensearch.client.RestClient; import org.opensearch.common.xcontent.XContentType; import org.opensearch.sql.legacy.cursor.CursorType; @@ -123,10 +127,45 @@ public static Response performRequest(RestClient client, Request request) { } return response; } catch (IOException e) { + if (isRefreshPolicyError(e)) { + try { + return retryWithoutRefreshPolicy(request, client); + } catch (IOException ex) { + throw new IllegalStateException("Failed to perform request without refresh policy.", ex); + } + } throw new IllegalStateException("Failed to perform request", e); } } + /** + * Checks if the IOException is due to an unsupported refresh policy. + * + * @param e The IOException to check. + * @return true if the exception is due to a refresh policy error, false otherwise. 
+ */ + private static boolean isRefreshPolicyError(IOException e) { + return e instanceof ResponseException + && ((ResponseException) e).getResponse().getStatusLine().getStatusCode() == 400 + && e.getMessage().contains("true refresh policy is not supported."); + } + + /** + * Attempts to perform the request without the refresh policy. + * + * @param request The original request. + * @param client client connection + * @return The response after retrying the request. + * @throws IOException If the request fails. + */ + private static Response retryWithoutRefreshPolicy(Request request, RestClient client) + throws IOException { + Request req = + new Request(request.getMethod(), request.getEndpoint().replaceAll("refresh=true", "")); + req.setEntity(request.getEntity()); + return client.performRequest(req); + } + public static String getAccountIndexMapping() { return "{ \"mappings\": {" + " \"properties\": {\n" @@ -772,6 +811,29 @@ public static String getResponseBody(Response response, boolean retainNewLines) return sb.toString(); } + // TODO: this is temporary fix for fixing serverless tests to pass with 2 digit precision value + public static JSONArray roundOfResponse(JSONArray array) { + JSONArray responseJSON = new JSONArray(); + array + .iterator() + .forEachRemaining( + o -> { + JSONArray jsonArray = new JSONArray(); + ((JSONArray) o) + .iterator() + .forEachRemaining( + i -> { + if (i instanceof BigDecimal) { + jsonArray.put(((BigDecimal) i).setScale(2, RoundingMode.HALF_UP)); + } else { + jsonArray.put(i); + } + }); + responseJSON.put(jsonArray); + }); + return responseJSON; + } + public static String fileToString( final String filePathFromProjectRoot, final boolean removeNewLines) throws IOException { From 7022a0946015c41cbd02888e3c712c6b67123982 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Wed, 7 Aug 2024 10:50:00 +0800 Subject: [PATCH 11/12] Correct regular expression range (#2836) Signed-off-by: Lantao Jin --- 
.../java/org/opensearch/sql/common/grok/GrokCompiler.java | 2 +- .../main/java/org/opensearch/sql/common/grok/GrokUtils.java | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/common/src/main/java/org/opensearch/sql/common/grok/GrokCompiler.java b/common/src/main/java/org/opensearch/sql/common/grok/GrokCompiler.java index aba96ad4cb..05fdbd57ed 100644 --- a/common/src/main/java/org/opensearch/sql/common/grok/GrokCompiler.java +++ b/common/src/main/java/org/opensearch/sql/common/grok/GrokCompiler.java @@ -29,7 +29,7 @@ public class GrokCompiler implements Serializable { // We don't want \n and commented line - private static final Pattern patternLinePattern = Pattern.compile("^([A-z0-9_]+)\\s+(.*)$"); + private static final Pattern patternLinePattern = Pattern.compile("^([a-zA-Z0-9_]+)\\s+(.*)$"); /** {@code Grok} patterns definitions. */ private final Map grokPatternDefinitions = new HashMap<>(); diff --git a/common/src/main/java/org/opensearch/sql/common/grok/GrokUtils.java b/common/src/main/java/org/opensearch/sql/common/grok/GrokUtils.java index 4b145bbbe8..2a309bba8f 100644 --- a/common/src/main/java/org/opensearch/sql/common/grok/GrokUtils.java +++ b/common/src/main/java/org/opensearch/sql/common/grok/GrokUtils.java @@ -24,8 +24,8 @@ public class GrokUtils { Pattern.compile( "%\\{" + "(?" - + "(?[A-z0-9]+)" - + "(?::(?[A-z0-9_:;,\\-\\/\\s\\.']+))?" + + "(?[a-zA-Z0-9_]+)" + + "(?::(?[a-zA-Z0-9_:;,\\-\\/\\s\\.']+))?" + ")" + "(?:=(?" 
+ "(?:" From 4a735ea9bca6313a06616a99a944b8e512baeb66 Mon Sep 17 00:00:00 2001 From: qianheng Date: Fri, 9 Aug 2024 12:08:48 +0800 Subject: [PATCH 12/12] Push down limit through eval (#2876) --- .../optimizer/LogicalPlanOptimizer.java | 2 + .../planner/optimizer/pattern/Patterns.java | 5 ++ .../planner/optimizer/rule/EvalPushDown.java | 82 +++++++++++++++++++ .../optimizer/LogicalPlanOptimizerTest.java | 23 ++++++ .../org/opensearch/sql/ppl/ExplainIT.java | 13 +++ .../ppl/explain_limit_push.json | 27 ++++++ 6 files changed, 152 insertions(+) create mode 100644 core/src/main/java/org/opensearch/sql/planner/optimizer/rule/EvalPushDown.java create mode 100644 integ-test/src/test/resources/expectedOutput/ppl/explain_limit_push.json diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java index 5c115f0db8..e805b0dea5 100644 --- a/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java @@ -12,6 +12,7 @@ import java.util.List; import java.util.stream.Collectors; import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.optimizer.rule.EvalPushDown; import org.opensearch.sql.planner.optimizer.rule.MergeFilterAndFilter; import org.opensearch.sql.planner.optimizer.rule.PushFilterUnderSort; import org.opensearch.sql.planner.optimizer.rule.read.CreateTableScanBuilder; @@ -46,6 +47,7 @@ public static LogicalPlanOptimizer create() { */ new MergeFilterAndFilter(), new PushFilterUnderSort(), + EvalPushDown.PUSH_DOWN_LIMIT, /* * Phase 2: Transformations that rely on data source push down capability */ diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/pattern/Patterns.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/pattern/Patterns.java index ee4e9a20cc..ef2607e018 100644 --- 
a/core/src/main/java/org/opensearch/sql/planner/optimizer/pattern/Patterns.java +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/pattern/Patterns.java @@ -12,6 +12,7 @@ import java.util.Optional; import lombok.experimental.UtilityClass; import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalEval; import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalHighlight; import org.opensearch.sql.planner.logical.LogicalLimit; @@ -63,6 +64,10 @@ public static Pattern project(Pattern return Pattern.typeOf(LogicalProject.class).with(source(pattern)); } + public static Pattern evalCapture() { + return Pattern.typeOf(LogicalEval.class).capturedAs(Capture.newCapture()); + } + /** Pattern for {@link TableScanBuilder} and capture it meanwhile. */ public static Pattern scanBuilder() { return Pattern.typeOf(TableScanBuilder.class).capturedAs(Capture.newCapture()); diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/EvalPushDown.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/EvalPushDown.java new file mode 100644 index 0000000000..17eaed0e8c --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/EvalPushDown.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.optimizer.rule; + +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.evalCapture; +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.limit; +import static org.opensearch.sql.planner.optimizer.rule.EvalPushDown.EvalPushDownBuilder.match; + +import com.facebook.presto.matching.Capture; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.matching.pattern.CapturePattern; +import com.facebook.presto.matching.pattern.WithPattern; +import java.util.List; 
+import java.util.function.BiFunction; +import lombok.Getter; +import lombok.experimental.Accessors; +import org.opensearch.sql.planner.logical.LogicalEval; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.optimizer.Rule; + +/** + * Rule template for all rules related to push down logical plans under eval, so these plans can + * avoid blocking by eval and may have chances to be pushed down into table scan by rules in {@link + * org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown}. + */ +public class EvalPushDown implements Rule { + + // TODO: Add more rules to push down sort and project + /** Push down optimize rule for limit operator. Transform `limit -> eval` to `eval -> limit` */ + public static final Rule PUSH_DOWN_LIMIT = + match(limit(evalCapture())) + .apply( + (limit, logicalEval) -> { + List child = logicalEval.getChild(); + limit.replaceChildPlans(child); + logicalEval.replaceChildPlans(List.of(limit)); + return logicalEval; + }); + + private final Capture capture; + + @Accessors(fluent = true) + @Getter + private final Pattern pattern; + + private final BiFunction pushDownFunction; + + @SuppressWarnings("unchecked") + public EvalPushDown( + WithPattern pattern, BiFunction pushDownFunction) { + this.pattern = pattern; + this.capture = ((CapturePattern) pattern.getPattern()).capture(); + this.pushDownFunction = pushDownFunction; + } + + @Override + public LogicalPlan apply(T plan, Captures captures) { + LogicalEval logicalEval = captures.get(capture); + return pushDownFunction.apply(plan, logicalEval); + } + + static class EvalPushDownBuilder { + + private WithPattern pattern; + + public static EvalPushDown.EvalPushDownBuilder match( + Pattern pattern) { + EvalPushDown.EvalPushDownBuilder builder = new EvalPushDown.EvalPushDownBuilder<>(); + builder.pattern = (WithPattern) pattern; + return builder; + } + + public EvalPushDown apply(BiFunction 
pushDownFunction) { + return new EvalPushDown<>(pattern, pushDownFunction); + } + } +} diff --git a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java index c25e415cfa..20996503b4 100644 --- a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java @@ -15,6 +15,7 @@ import static org.opensearch.sql.data.model.ExprValueUtils.longValue; import static org.opensearch.sql.data.type.ExprCoreType.*; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.aggregation; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.eval; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.filter; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.highlight; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.limit; @@ -43,6 +44,7 @@ import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.planner.logical.LogicalPaginate; @@ -345,6 +347,27 @@ void table_scan_builder_support_offset_push_down_can_apply_its_rule() { assertEquals(project(tableScanBuilder), optimized); } + /** Limit - Eval --> Eval - Limit. */ + @Test + void push_limit_under_eval() { + Pair evalExpr = + Pair.of(DSL.ref("name1", STRING), DSL.ref("name", STRING)); + assertEquals( + eval(limit(tableScanBuilder, 10, 5), evalExpr), + optimize(limit(eval(relation("schema", table), evalExpr), 10, 5))); + } + + /** Limit - Eval - Scan --> Eval - Scan. 
*/ + @Test + void push_limit_through_eval_into_scan() { + when(tableScanBuilder.pushDownLimit(any())).thenReturn(true); + Pair evalExpr = + Pair.of(DSL.ref("name1", STRING), DSL.ref("name", STRING)); + assertEquals( + eval(tableScanBuilder, evalExpr), + optimize(limit(eval(relation("schema", table), evalExpr), 10, 5))); + } + private LogicalPlan optimize(LogicalPlan plan) { final LogicalPlanOptimizer optimizer = LogicalPlanOptimizer.create(); return optimizer.optimize(plan); diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java index fce975ef92..c6b21e1605 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java @@ -76,6 +76,19 @@ public void testSortPushDownExplain() throws Exception { + "| fields age")); } + @Test + public void testLimitPushDownExplain() throws Exception { + String expected = loadFromFile("expectedOutput/ppl/explain_limit_push.json"); + + assertJsonEquals( + expected, + explainQueryToString( + "source=opensearch-sql_test_index_account" + + "| eval ageMinus = age - 30 " + + "| head 5 " + + "| fields ageMinus")); + } + String loadFromFile(String filename) throws Exception { URI uri = Resources.getResource(filename).toURI(); return new String(Files.readAllBytes(Paths.get(uri))); diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_limit_push.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_limit_push.json new file mode 100644 index 0000000000..51a627ea4d --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_limit_push.json @@ -0,0 +1,27 @@ +{ + "root": { + "name": "ProjectOperator", + "description": { + "fields": "[ageMinus]" + }, + "children": [ + { + "name": "EvalOperator", + "description": { + "expressions": { + "ageMinus": "-(age, 30)" + } + }, + "children": [ + { + "name": "OpenSearchIndexScan", + "description": { + 
"request": "OpenSearchQueryRequest(indexName=opensearch-sql_test_index_account, sourceBuilder={\"from\":0,\"size\":5,\"timeout\":\"1m\"}, searchDone=false)" + }, + "children": [] + } + ] + } + ] + } +}