Skip to content

Commit

Permalink
Merge branch 'main' into feature/pit
Browse files Browse the repository at this point in the history
  • Loading branch information
manasvinibs committed Aug 14, 2024
2 parents 368b6a8 + 4a735ea commit 2773b1b
Show file tree
Hide file tree
Showing 110 changed files with 3,328 additions and 370 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,5 @@ gen
.worktrees
http-client.env.json
/doctest/sql-cli/
/doctest/opensearch-job-scheduler/
.factorypath
17 changes: 11 additions & 6 deletions async-query-core/src/main/antlr/SqlBaseParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ compoundStatement
;

setStatementWithOptionalVarKeyword
: SET (VARIABLE | VAR)? assignmentList #setVariableWithOptionalKeyword
| SET (VARIABLE | VAR)? LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ
: SET variable? assignmentList #setVariableWithOptionalKeyword
| SET variable? LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ
LEFT_PAREN query RIGHT_PAREN #setVariableWithOptionalKeyword
;

Expand Down Expand Up @@ -215,9 +215,9 @@ statement
routineCharacteristics
RETURN (query | expression) #createUserDefinedFunction
| DROP TEMPORARY? FUNCTION (IF EXISTS)? identifierReference #dropFunction
| DECLARE (OR REPLACE)? VARIABLE?
| DECLARE (OR REPLACE)? variable?
identifierReference dataType? variableDefaultExpression? #createVariable
| DROP TEMPORARY VARIABLE (IF EXISTS)? identifierReference #dropVariable
| DROP TEMPORARY variable (IF EXISTS)? identifierReference #dropVariable
| EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)?
(statement|setResetStatement) #explain
| SHOW TABLES ((FROM | IN) identifierReference)?
Expand Down Expand Up @@ -272,8 +272,8 @@ setResetStatement
| SET TIME ZONE interval #setTimeZone
| SET TIME ZONE timezone #setTimeZone
| SET TIME ZONE .*? #setTimeZone
| SET (VARIABLE | VAR) assignmentList #setVariable
| SET (VARIABLE | VAR) LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ
| SET variable assignmentList #setVariable
| SET variable LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ
LEFT_PAREN query RIGHT_PAREN #setVariable
| SET configKey EQ configValue #setQuotedConfiguration
| SET configKey (EQ .*?)? #setConfiguration
Expand Down Expand Up @@ -438,6 +438,11 @@ namespaces
| SCHEMAS
;

variable
: VARIABLE
| VAR
;

describeFuncName
: identifierReference
| stringLit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ CreateAsyncQueryResponse createAsyncQuery(
* @param queryId queryId.
* @return {@link String} cancelledQueryId.
*/
String cancelQuery(String queryId);
String cancelQuery(String queryId, AsyncQueryRequestContext asyncQueryRequestContext);
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) {
}

@Override
public String cancelQuery(String queryId) {
public String cancelQuery(String queryId, AsyncQueryRequestContext asyncQueryRequestContext) {
Optional<AsyncQueryJobMetadata> asyncQueryJobMetadata =
asyncQueryJobMetadataStorageService.getJobMetadata(queryId);
if (asyncQueryJobMetadata.isPresent()) {
return sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata.get());
return sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata.get(), asyncQueryRequestContext);
}
throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.sql.spark.asyncquery.model;

import org.opensearch.sql.datasource.RequestContext;

/** Context interface to provide additional request related information */
public interface AsyncQueryRequestContext {
Object getAttribute(String name);
}
public interface AsyncQueryRequestContext extends RequestContext {}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import com.amazonaws.services.emrserverless.model.JobRunState;
import org.json.JSONObject;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
Expand Down Expand Up @@ -54,7 +55,9 @@ protected abstract JSONObject getResponseFromResultIndex(
protected abstract JSONObject getResponseFromExecutor(
AsyncQueryJobMetadata asyncQueryJobMetadata);

public abstract String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata);
public abstract String cancelJob(
AsyncQueryJobMetadata asyncQueryJobMetadata,
AsyncQueryRequestContext asyncQueryRequestContext);

public abstract DispatchQueryResponse submit(
DispatchQueryRequest request, DispatchQueryContext context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.json.JSONObject;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.StartJobRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
Expand Down Expand Up @@ -61,7 +62,9 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob
}

@Override
public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
public String cancelJob(
AsyncQueryJobMetadata asyncQueryJobMetadata,
AsyncQueryRequestContext asyncQueryRequestContext) {
emrServerlessClient.cancelJobRun(
asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId(), false);
return asyncQueryJobMetadata.getQueryId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@

package org.opensearch.sql.spark.dispatcher;

import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.utils.IDUtils;

/** Generates QueryId by embedding Datasource name and random UUID */
public class DatasourceEmbeddedQueryIdProvider implements QueryIdProvider {

@Override
public String getQueryId(DispatchQueryRequest dispatchQueryRequest) {
public String getQueryId(
DispatchQueryRequest dispatchQueryRequest,
AsyncQueryRequestContext asyncQueryRequestContext) {
return IDUtils.encode(dispatchQueryRequest.getDatasource());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,11 @@ public DispatchQueryResponse submit(
long startTime = System.currentTimeMillis();
try {
IndexQueryDetails indexDetails = context.getIndexQueryDetails();
FlintIndexMetadata indexMetadata = getFlintIndexMetadata(indexDetails);
FlintIndexMetadata indexMetadata =
getFlintIndexMetadata(indexDetails, context.getAsyncQueryRequestContext());

getIndexOp(dispatchQueryRequest, indexDetails).apply(indexMetadata);
getIndexOp(dispatchQueryRequest, indexDetails)
.apply(indexMetadata, context.getAsyncQueryRequestContext());

String asyncQueryId =
storeIndexDMLResult(
Expand Down Expand Up @@ -146,9 +148,11 @@ private FlintIndexOp getIndexOp(
}
}

private FlintIndexMetadata getFlintIndexMetadata(IndexQueryDetails indexDetails) {
private FlintIndexMetadata getFlintIndexMetadata(
IndexQueryDetails indexDetails, AsyncQueryRequestContext asyncQueryRequestContext) {
Map<String, FlintIndexMetadata> indexMetadataMap =
flintIndexMetadataService.getFlintIndexMetadata(indexDetails.openSearchIndexName());
flintIndexMetadataService.getFlintIndexMetadata(
indexDetails.openSearchIndexName(), asyncQueryRequestContext);
if (!indexMetadataMap.containsKey(indexDetails.openSearchIndexName())) {
throw new IllegalStateException(
String.format(
Expand All @@ -174,7 +178,9 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob
}

@Override
public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
public String cancelJob(
AsyncQueryJobMetadata asyncQueryJobMetadata,
AsyncQueryRequestContext asyncQueryRequestContext) {
throw new IllegalArgumentException("can't cancel index DML query");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.json.JSONObject;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
Expand Down Expand Up @@ -71,7 +72,9 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob
}

@Override
public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
public String cancelJob(
AsyncQueryJobMetadata asyncQueryJobMetadata,
AsyncQueryRequestContext asyncQueryRequestContext) {
String queryId = asyncQueryJobMetadata.getQueryId();
getStatementByQueryId(
asyncQueryJobMetadata.getSessionId(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

package org.opensearch.sql.spark.dispatcher;

import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;

/** Interface for extension point to specify queryId. Called when new query is executed. */
public interface QueryIdProvider {
String getQueryId(DispatchQueryRequest dispatchQueryRequest);
String getQueryId(
DispatchQueryRequest dispatchQueryRequest, AsyncQueryRequestContext asyncQueryRequestContext);
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.Map;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
Expand Down Expand Up @@ -51,18 +52,21 @@ public RefreshQueryHandler(
}

@Override
public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
public String cancelJob(
AsyncQueryJobMetadata asyncQueryJobMetadata,
AsyncQueryRequestContext asyncQueryRequestContext) {
String datasourceName = asyncQueryJobMetadata.getDatasourceName();
Map<String, FlintIndexMetadata> indexMetadataMap =
flintIndexMetadataService.getFlintIndexMetadata(asyncQueryJobMetadata.getIndexName());
flintIndexMetadataService.getFlintIndexMetadata(
asyncQueryJobMetadata.getIndexName(), asyncQueryRequestContext);
if (!indexMetadataMap.containsKey(asyncQueryJobMetadata.getIndexName())) {
throw new IllegalStateException(
String.format(
"Couldn't fetch flint index: %s details", asyncQueryJobMetadata.getIndexName()));
}
FlintIndexMetadata indexMetadata = indexMetadataMap.get(asyncQueryJobMetadata.getIndexName());
FlintIndexOp jobCancelOp = flintIndexOpFactory.getCancel(datasourceName);
jobCancelOp.apply(indexMetadata);
jobCancelOp.apply(indexMetadata, asyncQueryRequestContext);
return asyncQueryJobMetadata.getQueryId();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public DispatchQueryResponse dispatch(
AsyncQueryRequestContext asyncQueryRequestContext) {
DataSourceMetadata dataSourceMetadata =
this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
dispatchQueryRequest.getDatasource());
dispatchQueryRequest.getDatasource(), asyncQueryRequestContext);

if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) {
String query = dispatchQueryRequest.getQuery();
Expand All @@ -69,7 +69,8 @@ private DispatchQueryResponse handleFlintExtensionQuery(
DataSourceMetadata dataSourceMetadata) {
IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest);
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
getDefaultDispatchContextBuilder(
dispatchQueryRequest, dataSourceMetadata, asyncQueryRequestContext)
.indexQueryDetails(indexQueryDetails)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();
Expand All @@ -84,7 +85,8 @@ private DispatchQueryResponse handleDefaultQuery(
DataSourceMetadata dataSourceMetadata) {

DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
getDefaultDispatchContextBuilder(
dispatchQueryRequest, dataSourceMetadata, asyncQueryRequestContext)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();

Expand All @@ -93,11 +95,13 @@ private DispatchQueryResponse handleDefaultQuery(
}

private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder(
DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata) {
DispatchQueryRequest dispatchQueryRequest,
DataSourceMetadata dataSourceMetadata,
AsyncQueryRequestContext asyncQueryRequestContext) {
return DispatchQueryContext.builder()
.dataSourceMetadata(dataSourceMetadata)
.tags(getDefaultTagsForJobSubmission(dispatchQueryRequest))
.queryId(queryIdProvider.getQueryId(dispatchQueryRequest));
.queryId(queryIdProvider.getQueryId(dispatchQueryRequest, asyncQueryRequestContext));
}

private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery(
Expand Down Expand Up @@ -158,9 +162,11 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata)
.getQueryResponse(asyncQueryJobMetadata);
}

public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
public String cancelJob(
AsyncQueryJobMetadata asyncQueryJobMetadata,
AsyncQueryRequestContext asyncQueryRequestContext) {
return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata)
.cancelJob(asyncQueryJobMetadata);
.cancelJob(asyncQueryJobMetadata, asyncQueryRequestContext);
}

private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.util.Map;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.StartJobRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
Expand Down Expand Up @@ -46,7 +47,9 @@ public StreamingQueryHandler(
}

@Override
public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
public String cancelJob(
AsyncQueryJobMetadata asyncQueryJobMetadata,
AsyncQueryRequestContext asyncQueryRequestContext) {
throw new IllegalArgumentException(
"can't cancel index DML query, using ALTER auto_refresh=off statement to stop job, using"
+ " VACUUM statement to stop job and delete data");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.sql.spark.flint;

import java.util.Map;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions;

/** Interface for FlintIndexMetadataReader */
Expand All @@ -15,16 +16,22 @@ public interface FlintIndexMetadataService {
* Retrieves a map of {@link FlintIndexMetadata} instances matching the specified index pattern.
*
* @param indexPattern indexPattern.
* @param asyncQueryRequestContext request context passed to AsyncQueryExecutorService
* @return A map of {@link FlintIndexMetadata} instances against indexName, each providing
* metadata access for a matched index. Returns an empty list if no indices match the pattern.
*/
Map<String, FlintIndexMetadata> getFlintIndexMetadata(String indexPattern);
Map<String, FlintIndexMetadata> getFlintIndexMetadata(
String indexPattern, AsyncQueryRequestContext asyncQueryRequestContext);

/**
* Performs validation and updates flint index to manual refresh.
*
* @param indexName indexName.
* @param flintIndexOptions flintIndexOptions.
* @param asyncQueryRequestContext request context passed to AsyncQueryExecutorService
*/
void updateIndexToManualRefresh(String indexName, FlintIndexOptions flintIndexOptions);
void updateIndexToManualRefresh(
String indexName,
FlintIndexOptions flintIndexOptions,
AsyncQueryRequestContext asyncQueryRequestContext);
}
Loading

0 comments on commit 2773b1b

Please sign in to comment.