Skip to content

Commit

Permalink
Add pit to default cursor
Browse files Browse the repository at this point in the history
Signed-off-by: Rupal Mahajan <[email protected]>
  • Loading branch information
rupal-bq authored and manasvinibs committed Aug 14, 2024
1 parent 0b095f9 commit 682d1b7
Show file tree
Hide file tree
Showing 8 changed files with 310 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@

package org.opensearch.sql.legacy.cursor;

import static org.opensearch.core.xcontent.DeprecationHandler.IGNORE_DEPRECATIONS;
import static org.opensearch.sql.common.setting.Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Strings;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
Expand All @@ -16,8 +26,19 @@
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.Setter;
import lombok.SneakyThrows;
import org.json.JSONArray;
import org.json.JSONObject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.search.SearchModule;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.sql.legacy.esdomain.LocalClusterState;
import org.opensearch.sql.legacy.executor.format.Schema;

/**
Expand All @@ -40,6 +61,10 @@ public class DefaultCursor implements Cursor {
private static final String SCROLL_ID = "s";
private static final String SCHEMA_COLUMNS = "c";
private static final String FIELD_ALIAS_MAP = "a";
private static final String PIT_ID = "p";
private static final String SEARCH_REQUEST = "r";
private static final String SORT_FIELDS = "h";
private static final ObjectMapper objectMapper = new ObjectMapper();

/**
* To get mappings for index to check if type is date needed for
Expand Down Expand Up @@ -70,42 +95,105 @@ public class DefaultCursor implements Cursor {
/** To get next batch of result */
private String scrollId;

/** To get Point In Time */
private String pitId;

/** To get next batch of result with search after api */
public SearchSourceBuilder searchSourceBuilder;

/** To get last sort values * */
private Object[] sortFields;

/** To reduce the number of rows left by fetchSize */
@NonNull private Integer fetchSize;

private Integer limit;

/**
* {@link NamedXContentRegistry} from {@link SearchModule} used for construct {@link QueryBuilder}
* from DSL query string.
*/
private static final NamedXContentRegistry xContentRegistry =
new NamedXContentRegistry(
new SearchModule(Settings.builder().build(), new ArrayList<>()).getNamedXContents());

/** Identifies this cursor implementation to the dispatcher; always the default cursor type. */
@Override
public CursorType getType() {
  return this.type;
}

@Override
public String generateCursorId() {
if (rowsLeft <= 0 || Strings.isNullOrEmpty(scrollId)) {
boolean isCursorValid =
LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)
? Strings.isNullOrEmpty(pitId)
: Strings.isNullOrEmpty(scrollId);
if (rowsLeft <= 0 || isCursorValid) {
return null;
}
JSONObject json = new JSONObject();
json.put(FETCH_SIZE, fetchSize);
json.put(ROWS_LEFT, rowsLeft);
json.put(INDEX_PATTERN, indexPattern);
json.put(SCROLL_ID, scrollId);
json.put(SCHEMA_COLUMNS, getSchemaAsJson());
json.put(FIELD_ALIAS_MAP, fieldAliasMap);
return String.format("%s:%s", type.getId(), encodeCursor(json));
if (LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) {
json.put(PIT_ID, pitId);
String sortFieldValue =
AccessController.doPrivileged(
(PrivilegedAction<String>)
() -> {
try {
return objectMapper.writeValueAsString(sortFields);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
});
json.put(SORT_FIELDS, sortFieldValue);
} else {
json.put(SCROLL_ID, scrollId);
}
return String.format("%s:%s", type.getId(), encodeCursor(json, searchSourceBuilder));
}

@SneakyThrows
public static DefaultCursor from(String cursorId) {
/**
* It is assumed that cursorId here is the second part of the original cursor passed by the
* client after removing first part which identifies cursor type
*/
JSONObject json = decodeCursor(cursorId);
String[] parts = cursorId.split(":::");
JSONObject json = decodeCursor(parts[0]);
DefaultCursor cursor = new DefaultCursor();
cursor.setFetchSize(json.getInt(FETCH_SIZE));
cursor.setRowsLeft(json.getLong(ROWS_LEFT));
cursor.setIndexPattern(json.getString(INDEX_PATTERN));
cursor.setScrollId(json.getString(SCROLL_ID));
if (LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) {
cursor.setPitId(json.getString(PIT_ID));

Object[] sortFieldValue =
AccessController.doPrivileged(
(PrivilegedAction<Object[]>)
() -> {
try {
return objectMapper.readValue(json.getString(SORT_FIELDS), Object[].class);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
});
cursor.setSortFields(sortFieldValue);

byte[] bytes = Base64.getDecoder().decode(parts[1]);
ByteArrayInputStream streamInput = new ByteArrayInputStream(bytes);
XContentParser parser =
XContentType.JSON
.xContent()
.createParser(xContentRegistry, IGNORE_DEPRECATIONS, streamInput);
SearchSourceBuilder sourceBuilder = SearchSourceBuilder.fromXContent(parser);
cursor.searchSourceBuilder = sourceBuilder;
} else {
cursor.setScrollId(json.getString(SCROLL_ID));
}
cursor.setColumns(getColumnsFromSchema(json.getJSONArray(SCHEMA_COLUMNS)));
cursor.setFieldAliasMap(fieldAliasMap(json.getJSONObject(FIELD_ALIAS_MAP)));

Expand All @@ -132,8 +220,18 @@ private JSONObject schemaEntry(String name, String alias, String type) {
return entry;
}

private static String encodeCursor(JSONObject cursorJson) {
return Base64.getEncoder().encodeToString(cursorJson.toString().getBytes());
@SneakyThrows
private static String encodeCursor(JSONObject cursorJson, SearchSourceBuilder sourceBuilder) {
String jsonBase64 = Base64.getEncoder().encodeToString(cursorJson.toString().getBytes());

ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
XContentBuilder builder = XContentFactory.jsonBuilder(outputStream);
sourceBuilder.toXContent(builder, null);
builder.close();

String searchRequestBase64 = Base64.getEncoder().encodeToString(outputStream.toByteArray());

return jsonBase64 + ":::" + searchRequestBase64;
}

private static JSONObject decodeCursor(String cursorId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.sql.legacy.executor.cursor;

import static org.opensearch.core.rest.RestStatus.OK;
import static org.opensearch.sql.common.setting.Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER;

import java.util.Map;
import org.apache.logging.log4j.LogManager;
Expand All @@ -18,8 +19,11 @@
import org.opensearch.rest.RestChannel;
import org.opensearch.sql.legacy.cursor.CursorType;
import org.opensearch.sql.legacy.cursor.DefaultCursor;
import org.opensearch.sql.legacy.esdomain.LocalClusterState;
import org.opensearch.sql.legacy.metrics.MetricName;
import org.opensearch.sql.legacy.metrics.Metrics;
import org.opensearch.sql.legacy.pit.PointInTimeHandler;
import org.opensearch.sql.legacy.pit.PointInTimeHandlerImpl;
import org.opensearch.sql.legacy.rewriter.matchtoterm.VerificationException;

public class CursorCloseExecutor implements CursorRestExecutor {
Expand Down Expand Up @@ -79,14 +83,25 @@ public String execute(Client client, Map<String, String> params) throws Exceptio
}

private String handleDefaultCursorCloseRequest(Client client, DefaultCursor cursor) {
String scrollId = cursor.getScrollId();
ClearScrollResponse clearScrollResponse =
client.prepareClearScroll().addScrollId(scrollId).get();
if (clearScrollResponse.isSucceeded()) {
return SUCCEEDED_TRUE;
if (LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) {
String pitId = cursor.getPitId();
PointInTimeHandler pit = new PointInTimeHandlerImpl(client, pitId);
if (pit.delete()) {
return SUCCEEDED_TRUE;
} else {
Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment();
return SUCCEEDED_FALSE;
}
} else {
Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment();
return SUCCEEDED_FALSE;
String scrollId = cursor.getScrollId();
ClearScrollResponse clearScrollResponse =
client.prepareClearScroll().addScrollId(scrollId).get();
if (clearScrollResponse.isSucceeded()) {
return SUCCEEDED_TRUE;
} else {
Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment();
return SUCCEEDED_FALSE;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.sql.legacy.executor.cursor;

import static org.opensearch.core.rest.RestStatus.OK;
import static org.opensearch.sql.common.setting.Settings.Key.SQL_CURSOR_KEEP_ALIVE;
import static org.opensearch.sql.common.setting.Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER;

import java.util.Arrays;
import java.util.Map;
Expand All @@ -14,21 +16,25 @@
import org.json.JSONException;
import org.opensearch.OpenSearchException;
import org.opensearch.action.search.ClearScrollResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.search.builder.PointInTimeBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.sql.legacy.cursor.CursorType;
import org.opensearch.sql.legacy.cursor.DefaultCursor;
import org.opensearch.sql.legacy.esdomain.LocalClusterState;
import org.opensearch.sql.legacy.executor.Format;
import org.opensearch.sql.legacy.executor.format.Protocol;
import org.opensearch.sql.legacy.metrics.MetricName;
import org.opensearch.sql.legacy.metrics.Metrics;
import org.opensearch.sql.legacy.pit.PointInTimeHandler;
import org.opensearch.sql.legacy.pit.PointInTimeHandlerImpl;
import org.opensearch.sql.legacy.rewriter.matchtoterm.VerificationException;

public class CursorResultExecutor implements CursorRestExecutor {
Expand Down Expand Up @@ -91,14 +97,27 @@ public String execute(Client client, Map<String, String> params) throws Exceptio
}

private String handleDefaultCursorRequest(Client client, DefaultCursor cursor) {
String previousScrollId = cursor.getScrollId();
LocalClusterState clusterState = LocalClusterState.state();
TimeValue scrollTimeout = clusterState.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE);
SearchResponse scrollResponse =
client.prepareSearchScroll(previousScrollId).setScroll(scrollTimeout).get();
TimeValue paginationTimeout = clusterState.getSettingValue(SQL_CURSOR_KEEP_ALIVE);

SearchResponse scrollResponse = null;
if (clusterState.getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) {
String pitId = cursor.getPitId();
SearchSourceBuilder source = cursor.getSearchSourceBuilder();
source.searchAfter(cursor.getSortFields());
source.pointInTimeBuilder(new PointInTimeBuilder(pitId));
SearchRequest searchRequest = new SearchRequest();
searchRequest.source(source);
scrollResponse = client.search(searchRequest).actionGet();
} else {
String previousScrollId = cursor.getScrollId();
scrollResponse =
client.prepareSearchScroll(previousScrollId).setScroll(paginationTimeout).get();
}
SearchHits searchHits = scrollResponse.getHits();
SearchHit[] searchHitArray = searchHits.getHits();
String newScrollId = scrollResponse.getScrollId();
String newPitId = scrollResponse.pointInTimeId();

int rowsLeft = (int) cursor.getRowsLeft();
int fetch = cursor.getFetchSize();
Expand All @@ -124,16 +143,35 @@ private String handleDefaultCursorRequest(Client client, DefaultCursor cursor) {

if (rowsLeft <= 0) {
/** Clear the scroll context on last page */
ClearScrollResponse clearScrollResponse =
client.prepareClearScroll().addScrollId(newScrollId).get();
if (!clearScrollResponse.isSucceeded()) {
Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment();
LOG.info("Error closing the cursor context {} ", newScrollId);
if (newScrollId != null) {
ClearScrollResponse clearScrollResponse =
client.prepareClearScroll().addScrollId(newScrollId).get();
if (!clearScrollResponse.isSucceeded()) {
Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment();
LOG.info("Error closing the cursor context {} ", newScrollId);
}
}
if (newPitId != null) {
PointInTimeHandler pit = new PointInTimeHandlerImpl(client, newPitId);
if (!pit.delete()) {
Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment();
LOG.info("Error deleting point in time {} ", newPitId);
}
}
}

cursor.setRowsLeft(rowsLeft);
cursor.setScrollId(newScrollId);
if (clusterState.getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) {
cursor.setPitId(newPitId);
cursor.setSearchSourceBuilder(cursor.getSearchSourceBuilder());
cursor.setSortFields(
scrollResponse
.getHits()
.getAt(scrollResponse.getHits().getHits().length - 1)
.getSortValues());
} else {
cursor.setScrollId(newScrollId);
}
Protocol protocol = new Protocol(client, searchHits, format.name().toLowerCase(), cursor);
return protocol.cursorFormat();
}
Expand Down
Loading

0 comments on commit 682d1b7

Please sign in to comment.