diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 822bb67924..7f76d998d3 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -56,6 +56,11 @@ jobs: env: VERSION: ${{ matrix.version }} + # Run neural query integration tests separately as they use a significant amount of memory on their own + - run: "./build.sh integrate ${{ matrix.version }} neuralquery random:test_only_one --report" + name: Neural Query Integration Tests + working-directory: client + - name: Upload test report if: failure() uses: actions/upload-artifact@v3 diff --git a/CHANGELOG.md b/CHANGELOG.md index 003f1725ba..570db9772b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### Added - Added support for `MinScore` on `ScriptScoreQuery` ([#624](https://github.com/opensearch-project/opensearch-net/pull/624)) +- Added support for the `neural` query type and `text_embedding` ingest processor type ([#636](https://github.com/opensearch-project/opensearch-net/pull/636)) - Added support for the `Cat.PitSegments` and `Cat.SegmentReplication` APIs ([#527](https://github.com/opensearch-project/opensearch-net/pull/527)) - Added support for serializing the `DateOnly` and `TimeOnly` types ([#734](https://github.com/opensearch-project/opensearch-net/pull/734)) - Added support for the `Ext` parameter on `SearchRequest` ([#738](https://github.com/opensearch-project/opensearch-net/pull/738)) @@ -206,4 +207,4 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) [1.6.0]: https://github.com/opensearch-project/opensearch-net/compare/v1.5.0...v1.6.0 [1.5.0]: https://github.com/opensearch-project/opensearch-net/compare/v1.4.0...v1.5.0 [1.4.0]: https://github.com/opensearch-project/opensearch-net/compare/v1.3.0...v1.4.0 -[1.3.0]: https://github.com/opensearch-project/opensearch-net/compare/v1.2.0...v1.3.0 \ No newline at end of file +[1.3.0]: https://github.com/opensearch-project/opensearch-net/compare/v1.2.0...v1.3.0 diff --git a/abstractions/src/OpenSearch.OpenSearch.Ephemeral/EphemeralClusterConfiguration.cs b/abstractions/src/OpenSearch.OpenSearch.Ephemeral/EphemeralClusterConfiguration.cs index 644d1844c8..3e7cec92f8 100644 --- a/abstractions/src/OpenSearch.OpenSearch.Ephemeral/EphemeralClusterConfiguration.cs +++ b/abstractions/src/OpenSearch.OpenSearch.Ephemeral/EphemeralClusterConfiguration.cs @@ -75,7 +75,7 @@ public EphemeralClusterConfiguration(OpenSearchVersion version, ClusterFeatures /// This can be useful to fail early when subsequent operations are relying on installation /// succeeding. /// </summary> - public bool ValidatePluginsToInstall { get; } = true; + public bool ValidatePluginsToInstall { get; set; } = true; public bool EnableSsl => Features.HasFlag(ClusterFeatures.SSL); diff --git a/abstractions/src/OpenSearch.OpenSearch.Ephemeral/Tasks/InstallationTasks/InstallPlugins.cs b/abstractions/src/OpenSearch.OpenSearch.Ephemeral/Tasks/InstallationTasks/InstallPlugins.cs index 3225476353..57a6982541 100644 --- a/abstractions/src/OpenSearch.OpenSearch.Ephemeral/Tasks/InstallationTasks/InstallPlugins.cs +++ b/abstractions/src/OpenSearch.OpenSearch.Ephemeral/Tasks/InstallationTasks/InstallPlugins.cs @@ -61,30 +61,29 @@ public override void Run(IEphemeralCluster<EphemeralClusterConfiguration> cluste .Where(p => !p.IsValid(v)) .Select(p => p.SubProductName).ToList(); if (invalidPlugins.Any()) - throw new OpenSearchCleanExitException( - $"Can not install the following plugins for version {v}: {string.Join(", ", invalidPlugins)} "); - } + { + throw new OpenSearchCleanExitException( + $"Can not install the following plugins for version {v}: {string.Join(", ", invalidPlugins)} "); + } + } foreach (var plugin in requiredPlugins) { - var includedByDefault = plugin.IsIncludedOutOfTheBox(v); - if (includedByDefault) + if (plugin.IsIncludedOutOfTheBox(v)) { cluster.Writer?.WriteDiagnostic( $"{{{nameof(InstallPlugins)}}} SKIP plugin [{plugin.SubProductName}] shipped OOTB as of: {{{plugin.ShippedByDefaultAsOf}}}"); continue; } - var validForCurrentVersion = plugin.IsValid(v); - if (!validForCurrentVersion) + if (!plugin.IsValid(v)) { cluster.Writer?.WriteDiagnostic( $"{{{nameof(InstallPlugins)}}} SKIP plugin [{plugin.SubProductName}] not valid for version: {{{v}}}"); continue; } - var alreadyInstalled = AlreadyInstalled(fs, plugin.SubProductName); - if (alreadyInstalled) + if (AlreadyInstalled(fs, plugin.SubProductName)) { cluster.Writer?.WriteDiagnostic( $"{{{nameof(InstallPlugins)}}} SKIP plugin [{plugin.SubProductName}] already installed"); @@ -92,7 +91,7 @@ public override void Run(IEphemeralCluster<EphemeralClusterConfiguration> cluste } cluster.Writer?.WriteDiagnostic( - $"{{{nameof(InstallPlugins)}}} attempting install [{plugin.SubProductName}] as it's not OOTB: {{{plugin.ShippedByDefaultAsOf}}} and valid for {v}: {{{plugin.IsValid(v)}}}"); + $"{{{nameof(InstallPlugins)}}} attempting install [{plugin.SubProductName}] as it's not OOTB: {{{plugin.ShippedByDefaultAsOf}}} and valid for {v}"); var homeConfigPath = Path.Combine(fs.OpenSearchHome, "config"); diff --git a/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/IntegrationTestDiscoverer.cs b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/IntegrationTestDiscoverer.cs index 9e84b0ee01..e7ea9bdf87 100644 --- a/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/IntegrationTestDiscoverer.cs +++ b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/IntegrationTestDiscoverer.cs @@ -35,102 +35,109 @@ using Xunit; using Xunit.Abstractions; using Xunit.Sdk; -using Enumerable = System.Linq.Enumerable; -namespace OpenSearch.OpenSearch.Xunit.XunitPlumbing +namespace OpenSearch.OpenSearch.Xunit.XunitPlumbing; + +/// <summary> +/// A Xunit test that should be skipped, and a reason why. +/// </summary> +public abstract class SkipTestAttributeBase : Attribute { - /// <summary> - /// An Xunit test that should be skipped, and a reason why. - /// </summary> - public abstract class SkipTestAttributeBase : Attribute - { - /// <summary> - /// Whether the test should be skipped - /// </summary> - public abstract bool Skip { get; } + /// <summary> + /// Whether the test should be skipped + /// </summary> + public abstract bool Skip { get; } + + /// <summary> + /// The reason why the test should be skipped + /// </summary> + public abstract string Reason { get; } +} - /// <summary> - /// The reason why the test should be skipped - /// </summary> - public abstract string Reason { get; } - } +/// <summary> +/// An Xunit integration test +/// </summary> +[XunitTestCaseDiscoverer("OpenSearch.OpenSearch.Xunit.XunitPlumbing.IntegrationTestDiscoverer", + "OpenSearch.OpenSearch.Xunit")] +public class I : FactAttribute +{ +} - /// <summary> - /// An Xunit integration test - /// </summary> - [XunitTestCaseDiscoverer("OpenSearch.OpenSearch.Xunit.XunitPlumbing.IntegrationTestDiscoverer", - "OpenSearch.OpenSearch.Xunit")] - public class I : FactAttribute - { - } +/// <summary> +/// A test discoverer used to discover integration tests cases attached +/// to test methods that are attributed with <see cref="I" /> attribute +/// </summary> +public class IntegrationTestDiscoverer : OpenSearchTestCaseDiscoverer +{ + public IntegrationTestDiscoverer(IMessageSink diagnosticMessageSink) : base(diagnosticMessageSink) + { + } - /// <summary> - /// A test discoverer used to discover integration tests cases attached - /// to test methods that are attributed with <see cref="I" /> attribute - /// </summary> - public class IntegrationTestDiscoverer : OpenSearchTestCaseDiscoverer - { - public IntegrationTestDiscoverer(IMessageSink diagnosticMessageSink) : base(diagnosticMessageSink) - { - } + /// <inheritdoc /> + protected override bool SkipMethod(ITestFrameworkDiscoveryOptions discoveryOptions, ITestMethod testMethod, + out string skipReason) + { + skipReason = null; + var runIntegrationTests = + discoveryOptions.GetValue<bool>(nameof(OpenSearchXunitRunOptions.RunIntegrationTests)); + if (!runIntegrationTests) return true; - /// <inheritdoc /> - protected override bool SkipMethod(ITestFrameworkDiscoveryOptions discoveryOptions, ITestMethod testMethod, - out string skipReason) - { - skipReason = null; - var runIntegrationTests = - discoveryOptions.GetValue<bool>(nameof(OpenSearchXunitRunOptions.RunIntegrationTests)); - if (!runIntegrationTests) return true; + var cluster = TestAssemblyRunner.GetClusterForClass(testMethod.TestClass.Class); + if (cluster == null) + { + skipReason += + $"{testMethod.TestClass.Class.Name} does not define a cluster through IClusterFixture or {nameof(IntegrationTestClusterAttribute)}"; + return true; + } - var cluster = TestAssemblyRunner.GetClusterForClass(testMethod.TestClass.Class); - if (cluster == null) - { - skipReason += - $"{testMethod.TestClass.Class.Name} does not define a cluster through IClusterFixture or {nameof(IntegrationTestClusterAttribute)}"; - return true; - } + var openSearchVersion = + discoveryOptions.GetValue<OpenSearchVersion>(nameof(OpenSearchXunitRunOptions.Version)); - var openSearchVersion = - discoveryOptions.GetValue<OpenSearchVersion>(nameof(OpenSearchXunitRunOptions.Version)); + // Skip if the version we are testing against is attributed to be skipped do not run the test nameof(SkipVersionAttribute.Ranges) + var skipVersionAttribute = GetAttributes<SkipVersionAttribute>(testMethod).FirstOrDefault(); + if (skipVersionAttribute != null) + { + var skipVersionRanges = + skipVersionAttribute.GetNamedArgument<IList<Range>>(nameof(SkipVersionAttribute.Ranges)) ?? + new List<Range>(); + if (openSearchVersion == null && skipVersionRanges.Count > 0) + { + skipReason = $"{nameof(SkipVersionAttribute)} has ranges defined for this test but " + + $"no {nameof(OpenSearchXunitRunOptions.Version)} has been provided to {nameof(OpenSearchXunitRunOptions)}"; + return true; + } - // Skip if the version we are testing against is attributed to be skipped do not run the test nameof(SkipVersionAttribute.Ranges) - var skipVersionAttribute = Enumerable.FirstOrDefault(GetAttributes<SkipVersionAttribute>(testMethod)); - if (skipVersionAttribute != null) - { - var skipVersionRanges = - skipVersionAttribute.GetNamedArgument<IList<Range>>(nameof(SkipVersionAttribute.Ranges)) ?? - new List<Range>(); - if (openSearchVersion == null && skipVersionRanges.Count > 0) - { - skipReason = $"{nameof(SkipVersionAttribute)} has ranges defined for this test but " + - $"no {nameof(OpenSearchXunitRunOptions.Version)} has been provided to {nameof(OpenSearchXunitRunOptions)}"; - return true; - } + if (openSearchVersion != null) + { + var reason = skipVersionAttribute.GetNamedArgument<string>(nameof(SkipVersionAttribute.Reason)); + foreach (var range in skipVersionRanges) + { + // inrange takes prereleases into account + if (!openSearchVersion.InRange(range)) continue; + skipReason = + $"{nameof(SkipVersionAttribute)} has range {range} that {openSearchVersion} satisfies"; + if (!string.IsNullOrWhiteSpace(reason)) skipReason += $": {reason}"; + return true; + } + } + } - if (openSearchVersion != null) - { - var reason = skipVersionAttribute.GetNamedArgument<string>(nameof(SkipVersionAttribute.Reason)); - for (var index = 0; index < skipVersionRanges.Count; index++) - { - var range = skipVersionRanges[index]; - // inrange takes prereleases into account - if (!openSearchVersion.InRange(range)) continue; - skipReason = - $"{nameof(SkipVersionAttribute)} has range {range} that {openSearchVersion} satisfies"; - if (!string.IsNullOrWhiteSpace(reason)) skipReason += $": {reason}"; - return true; - } - } - } + // Skip if a prerelease version and has SkipPrereleaseVersionsAttribute + var skipPrerelease = GetAttributes<SkipPrereleaseVersionsAttribute>(testMethod).FirstOrDefault(); + if (openSearchVersion != null && openSearchVersion.IsPreRelease && skipPrerelease != null) + { + skipReason = $"{nameof(SkipPrereleaseVersionsAttribute)} has been applied to this test"; + var reason = skipPrerelease.GetNamedArgument<string>(nameof(SkipVersionAttribute.Reason)); + if (!string.IsNullOrWhiteSpace(reason)) skipReason += $": {reason}"; + return true; + } - var skipTests = GetAttributes<SkipTestAttributeBase>(testMethod) - .FirstOrDefault(a => a.GetNamedArgument<bool>(nameof(SkipTestAttributeBase.Skip))); + var skipTests = GetAttributes<SkipTestAttributeBase>(testMethod) + .FirstOrDefault(a => a.GetNamedArgument<bool>(nameof(SkipTestAttributeBase.Skip))); - if (skipTests == null) return false; + if (skipTests == null) return false; - skipReason = skipTests.GetNamedArgument<string>(nameof(SkipTestAttributeBase.Reason)); - return true; - } - } + skipReason = skipTests.GetNamedArgument<string>(nameof(SkipTestAttributeBase.Reason)); + return true; + } } diff --git a/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipPrereleaseVersionsAttribute.cs b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipPrereleaseVersionsAttribute.cs new file mode 100644 index 0000000000..dee8646e6f --- /dev/null +++ b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipPrereleaseVersionsAttribute.cs @@ -0,0 +1,23 @@ +/* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +using System; + +namespace OpenSearch.OpenSearch.Xunit.XunitPlumbing; + +/// <summary> +/// A Xunit test that should be skipped for prerelease OpenSearch versions, and a reason why. +/// </summary> +public class SkipPrereleaseVersionsAttribute : Attribute +{ + public SkipPrereleaseVersionsAttribute(string reason) => Reason = reason; + + /// <summary> + /// The reason why the test should be skipped + /// </summary> + public string Reason { get; } +} diff --git a/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipVersionAttribute.cs b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipVersionAttribute.cs index cfeec7b8da..e885718fe2 100644 --- a/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipVersionAttribute.cs +++ b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipVersionAttribute.cs @@ -31,35 +31,34 @@ using System.Linq; using SemanticVersioning; -namespace OpenSearch.OpenSearch.Xunit.XunitPlumbing +namespace OpenSearch.OpenSearch.Xunit.XunitPlumbing; + +/// <summary> +/// A Xunit test that should be skipped for given OpenSearch versions, and a reason why. +/// </summary> +public class SkipVersionAttribute : Attribute { - /// <summary> - /// An Xunit test that should be skipped for given OpenSearch versions, and a reason why. - /// </summary> - public class SkipVersionAttribute : Attribute - { - // ReSharper disable once UnusedParameter.Local - // reason is used to allow the test its used on to self document why its been put in place - public SkipVersionAttribute(string skipVersionRangesSeparatedByComma, string reason) - { - Reason = reason; - Ranges = string.IsNullOrEmpty(skipVersionRangesSeparatedByComma) - ? new List<Range>() - : skipVersionRangesSeparatedByComma.Split(',') - .Select(r => r.Trim()) - .Where(r => !string.IsNullOrWhiteSpace(r)) - .Select(r => new Range(r)) - .ToList(); - } + // ReSharper disable once UnusedParameter.Local + // reason is used to allow the test its used on to self document why its been put in place + public SkipVersionAttribute(string skipVersionRangesSeparatedByComma, string reason) + { + Reason = reason; + Ranges = string.IsNullOrEmpty(skipVersionRangesSeparatedByComma) + ? new List<Range>() + : skipVersionRangesSeparatedByComma.Split(',') + .Select(r => r.Trim()) + .Where(r => !string.IsNullOrWhiteSpace(r)) + .Select(r => new Range(r)) + .ToList(); + } - /// <summary> - /// The reason why the test should be skipped - /// </summary> - public string Reason { get; } + /// <summary> + /// The reason why the test should be skipped + /// </summary> + public string Reason { get; } - /// <summary> - /// The version ranges for which the test should be skipped - /// </summary> - public IList<Range> Ranges { get; } - } + /// <summary> + /// The version ranges for which the test should be skipped + /// </summary> + public IList<Range> Ranges { get; } } diff --git a/abstractions/src/OpenSearch.Stack.ArtifactsApi/Products/OpenSearchPlugin.cs b/abstractions/src/OpenSearch.Stack.ArtifactsApi/Products/OpenSearchPlugin.cs index 536cebb3ee..f2ac62ab0e 100644 --- a/abstractions/src/OpenSearch.Stack.ArtifactsApi/Products/OpenSearchPlugin.cs +++ b/abstractions/src/OpenSearch.Stack.ArtifactsApi/Products/OpenSearchPlugin.cs @@ -27,6 +27,7 @@ */ using System; +using Version = SemanticVersioning.Version; namespace OpenSearch.Stack.ArtifactsApi.Products { @@ -81,5 +82,9 @@ public OpenSearchPlugin(string plugin, Func<OpenSearchVersion, bool> isValid = n public static OpenSearchPlugin DeleteByQuery { get; } = new("delete-by-query", version => version < "1.0.0"); public static OpenSearchPlugin Knn { get; } = new("opensearch-knn"); - } + + public static OpenSearchPlugin MachineLearning { get; } = new("opensearch-ml", v => v.BaseVersion() >= new Version("1.3.0") && !v.IsPreRelease); + + public static OpenSearchPlugin NeuralSearch { get; } = new("opensearch-neural-search", v => v.BaseVersion() >= new Version("2.4.0") && !v.IsPreRelease); + } } diff --git a/samples/Samples/NeuralSearch/NeuralSearchSample.cs b/samples/Samples/NeuralSearch/NeuralSearchSample.cs index dc459d5418..aeb2c28a80 100644 --- a/samples/Samples/NeuralSearch/NeuralSearchSample.cs +++ b/samples/Samples/NeuralSearch/NeuralSearchSample.cs @@ -5,7 +5,6 @@ * compatible open source license. */ -using System.Diagnostics; using OpenSearch.Client; using OpenSearch.Net; @@ -46,7 +45,7 @@ protected override async Task Run(IOpenSearchClient client) .Add("plugins.ml_commons.only_run_on_ml_node", false) .Add("plugins.ml_commons.model_access_control_enabled", true) .Add("plugins.ml_commons.native_memory_threshold", 99))); - Debug.Assert(putSettingsResp.IsValid, putSettingsResp.DebugInformation); + Assert(putSettingsResp, r => r.IsValid); Console.WriteLine("Configured cluster to allow local execution of the ML model"); // Register an ML model group @@ -58,7 +57,7 @@ protected override async Task Run(IOpenSearchClient client) description = $"A model group for the opensearch-net {SampleName} sample", access_mode = "public" })); - Debug.Assert(registerModelGroupResp.Success && (string) registerModelGroupResp.Body.status == "CREATED", registerModelGroupResp.DebugInformation); + AssertCreatedStatus(registerModelGroupResp); Console.WriteLine($"Model group named {MlModelGroupName} {registerModelGroupResp.Body.status}: {registerModelGroupResp.Body.model_group_id}"); _modelGroupId = (string) registerModelGroupResp.Body.model_group_id; @@ -72,7 +71,7 @@ protected override async Task Run(IOpenSearchClient client) model_group_id = _modelGroupId, model_format = "TORCH_SCRIPT" })); - Debug.Assert(registerModelResp.Success && (string) registerModelResp.Body.status == "CREATED", registerModelResp.DebugInformation); + AssertCreatedStatus(registerModelResp); Console.WriteLine($"Model registration task {registerModelResp.Body.status}: {registerModelResp.Body.task_id}"); _modelRegistrationTaskId = (string) registerModelResp.Body.task_id; @@ -81,7 +80,7 @@ protected override async Task Run(IOpenSearchClient client) { var getTaskResp = await client.Http.GetAsync<DynamicResponse>($"/_plugins/_ml/tasks/{_modelRegistrationTaskId}"); Console.WriteLine($"Model registration: {getTaskResp.Body.state}"); - Debug.Assert(getTaskResp.Success && (string) getTaskResp.Body.state != "FAILED", getTaskResp.DebugInformation); + AssertNotFailedState(getTaskResp); if (((string)getTaskResp.Body.state).StartsWith("COMPLETED")) { _modelId = getTaskResp.Body.model_id; @@ -93,7 +92,7 @@ protected override async Task Run(IOpenSearchClient client) // Deploy the ML model var deployModelResp = await client.Http.PostAsync<DynamicResponse>($"/_plugins/_ml/models/{_modelId}/_deploy"); - Debug.Assert(deployModelResp.Success && (string) deployModelResp.Body.status == "CREATED", deployModelResp.DebugInformation); + AssertCreatedStatus(deployModelResp); Console.WriteLine($"Model deployment task {deployModelResp.Body.status}: {deployModelResp.Body.task_id}"); _modelDeployTaskId = (string) deployModelResp.Body.task_id; @@ -102,35 +101,21 @@ protected override async Task Run(IOpenSearchClient client) { var getTaskResp = await client.Http.GetAsync<DynamicResponse>($"/_plugins/_ml/tasks/{_modelDeployTaskId}"); Console.WriteLine($"Model deployment: {getTaskResp.Body.state}"); - Debug.Assert(getTaskResp.Success && (string) getTaskResp.Body.state != "FAILED", getTaskResp.DebugInformation); + AssertNotFailedState(getTaskResp); if (((string)getTaskResp.Body.state).StartsWith("COMPLETED")) break; await Task.Delay(10000); } Console.WriteLine($"Model deployed: {_modelId}"); // Create the text_embedding ingest pipeline - // TODO: Client does not yet contain typings for the text_embedding processor - var putIngestPipelineResp = await client.Http.PutAsync<PutPipelineResponse>( - $"/_ingest/pipeline/{IngestPipelineName}", - r => r.SerializableBody(new - { - description = $"A text_embedding ingest pipeline for the opensearch-net {SampleName} sample", - processors = new[] - { - new - { - text_embedding = new - { - model_id = _modelId, - field_map = new - { - text = "passage_embedding" - } - } - } - } - })); - Debug.Assert(putIngestPipelineResp.IsValid, putIngestPipelineResp.DebugInformation); + var putIngestPipelineResp = await client.Ingest.PutPipelineAsync(IngestPipelineName, p => p + .Description($"A text_embedding ingest pipeline for the opensearch-net {SampleName} sample") + .Processors(pp => pp + .TextEmbedding<NeuralSearchDoc>(te => te + .ModelId(_modelId) + .FieldMap(fm => fm + .Map(d => d.Text, d => d.PassageEmbedding))))); + AssertValid(putIngestPipelineResp); Console.WriteLine($"Put ingest pipeline {IngestPipelineName}: {putIngestPipelineResp.Acknowledged}"); _putIngestPipeline = true; @@ -152,7 +137,7 @@ protected override async Task Run(IOpenSearchClient client) .Engine("lucene") .SpaceType("l2") .Name("hnsw")))))); - Debug.Assert(createIndexResp.IsValid, createIndexResp.DebugInformation); + AssertValid(createIndexResp); Console.WriteLine($"Created index {IndexName}: {createIndexResp.Acknowledged}"); _createdIndex = true; @@ -169,31 +154,23 @@ protected override async Task Run(IOpenSearchClient client) .Index(IndexName) .IndexMany(documents) .Refresh(Refresh.WaitFor)); - Debug.Assert(bulkResp.IsValid, bulkResp.DebugInformation); + AssertValid(bulkResp); Console.WriteLine($"Indexed {documents.Length} documents"); // Perform the neural search - // TODO: Client does not yet contain typings for neural query type Console.WriteLine("Performing neural search for text 'wild west'"); - var searchResp = await client.Http.PostAsync<SearchResponse<NeuralSearchDoc>>( - $"/{IndexName}/_search", - r => r.SerializableBody(new - { - _source = new { excludes = new[] { "passage_embedding" } }, - query = new - { - neural = new - { - passage_embedding = new - { - query_text = "wild west", - model_id = _modelId, - k = 5 - } - } - } - })); - Debug.Assert(searchResp.IsValid, searchResp.DebugInformation); + var searchResp = await client.SearchAsync<NeuralSearchDoc>(s => s + .Index(IndexName) + .Source(sf => sf + .Excludes(f => f + .Field(d => d.PassageEmbedding))) + .Query(q => q + .Neural(n => n + .Field(f => f.PassageEmbedding) + .QueryText("wild west") + .ModelId(_modelId) + .K(5)))); + AssertValid(searchResp); Console.WriteLine($"Found {searchResp.Hits.Count} documents"); foreach (var hit in searchResp.Hits) Console.WriteLine($"- Document id: {hit.Source.Id}, score: {hit.Score}, text: {hit.Source.Text}"); } @@ -205,7 +182,7 @@ protected override async Task Cleanup(IOpenSearchClient client) { // Cleanup the index var deleteIndexResp = await client.Indices.DeleteAsync(IndexName); - Debug.Assert(deleteIndexResp.IsValid, deleteIndexResp.DebugInformation); + AssertValid(deleteIndexResp); Console.WriteLine($"Deleted index: {deleteIndexResp.Acknowledged}"); } @@ -213,7 +190,7 @@ protected override async Task Cleanup(IOpenSearchClient client) { // Cleanup the ingest pipeline var deleteIngestPipelineResp = await client.Ingest.DeletePipelineAsync(IngestPipelineName); - Debug.Assert(deleteIngestPipelineResp.IsValid, deleteIngestPipelineResp.DebugInformation); + AssertValid(deleteIngestPipelineResp); Console.WriteLine($"Deleted ingest pipeline: {deleteIngestPipelineResp.Acknowledged}"); } @@ -221,7 +198,7 @@ protected override async Task Cleanup(IOpenSearchClient client) { // Cleanup the model deployment task var deleteModelDeployTaskResp = await client.Http.DeleteAsync<DynamicResponse>($"/_plugins/_ml/tasks/{_modelDeployTaskId}"); - Debug.Assert(deleteModelDeployTaskResp.Success && (string) deleteModelDeployTaskResp.Body.result == "deleted", deleteModelDeployTaskResp.DebugInformation); + AssertDeletedResult(deleteModelDeployTaskResp); Console.WriteLine($"Deleted model deployment task: {deleteModelDeployTaskResp.Body.result}"); } @@ -237,11 +214,11 @@ protected override async Task Cleanup(IOpenSearchClient client) break; } - Debug.Assert(((string?)deleteModelResp.Body.error?.reason)?.Contains("Try undeploy") ?? false, deleteModelResp.DebugInformation); + Assert(deleteModelResp, r => ((string?) r.Body.error?.reason)?.Contains("Try undeploy") ?? false); // Undeploy the ML model var undeployModelResp = await client.Http.PostAsync<DynamicResponse>($"/_plugins/_ml/models/{_modelId}/_undeploy"); - Debug.Assert(undeployModelResp.Success, undeployModelResp.DebugInformation); + Assert(undeployModelResp, r => r.Success); Console.WriteLine("Undeployed model"); await Task.Delay(10000); } @@ -251,7 +228,7 @@ protected override async Task Cleanup(IOpenSearchClient client) { // Cleanup the model registration task var deleteModelRegistrationTaskResp = await client.Http.DeleteAsync<DynamicResponse>($"/_plugins/_ml/tasks/{_modelRegistrationTaskId}"); - Debug.Assert(deleteModelRegistrationTaskResp.Success && (string) deleteModelRegistrationTaskResp.Body.result == "deleted", deleteModelRegistrationTaskResp.DebugInformation); + AssertDeletedResult(deleteModelRegistrationTaskResp); Console.WriteLine($"Deleted model registration task: {deleteModelRegistrationTaskResp.Body.result}"); } @@ -259,8 +236,17 @@ protected override async Task Cleanup(IOpenSearchClient client) { // Cleanup the model group var deleteModelGroupResp = await client.Http.DeleteAsync<DynamicResponse>($"/_plugins/_ml/model_groups/{_modelGroupId}"); - Debug.Assert(deleteModelGroupResp.Success && (string) deleteModelGroupResp.Body.result == "deleted", deleteModelGroupResp.DebugInformation); + AssertDeletedResult(deleteModelGroupResp); Console.WriteLine($"Deleted model group: {deleteModelGroupResp.Body.result}"); } } + + private static void AssertCreatedStatus(DynamicResponse response) => + Assert(response, r => r.Success && (string)r.Body.status == "CREATED"); + + private static void AssertNotFailedState(DynamicResponse response) => + Assert(response, r => r.Success && (string) r.Body.state != "FAILED"); + + private static void AssertDeletedResult(DynamicResponse response) => + Assert(response, r => r.Success && (string) r.Body.result == "deleted"); } diff --git a/samples/Samples/Sample.cs b/samples/Samples/Sample.cs index e683b98935..4057486f11 100644 --- a/samples/Samples/Sample.cs +++ b/samples/Samples/Sample.cs @@ -8,6 +8,7 @@ using System.CommandLine; using System.CommandLine.Binding; using OpenSearch.Client; +using OpenSearch.Net; namespace Samples; @@ -58,4 +59,13 @@ public Command AsCommand(IValueDescriptor<IOpenSearchClient> clientDescriptor) protected abstract Task Run(IOpenSearchClient client); protected virtual Task Cleanup(IOpenSearchClient client) => Task.CompletedTask; + + protected static void Assert<T>(T response, Func<T, bool> condition) where T : IOpenSearchResponse + { + if (condition(response)) return; + + throw new Exception($"Assertion failed:\n{response.ApiCall?.DebugInformation}"); + } + + protected static void AssertValid(IResponse response) => Assert(response, r => r.IsValid); } diff --git a/src/OpenSearch.Client/CommonAbstractions/Infer/Indices/Indices.cs b/src/OpenSearch.Client/CommonAbstractions/Infer/Indices/Indices.cs index 714f43b474..a6b83805b1 100644 --- a/src/OpenSearch.Client/CommonAbstractions/Infer/Indices/Indices.cs +++ b/src/OpenSearch.Client/CommonAbstractions/Infer/Indices/Indices.cs @@ -46,12 +46,12 @@ internal Indices(ManyIndices indices) : base(indices) { } internal Indices(IEnumerable<IndexName> indices) : base(new ManyIndices(indices)) { } /// <summary>All indices. Represents _all</summary> - public static Indices All { get; } = new Indices(new AllIndicesMarker()); + public static Indices All { get; } = new(new AllIndicesMarker()); /// <inheritdoc cref="All" /> - public static Indices AllIndices { get; } = All; + public static Indices AllIndices => All; - private string DebugDisplay => Match( + private string DebugDisplay => Match( all => "_all", types => $"Count: {types.Indices.Count} [" + string.Join(",", types.Indices.Select((t, i) => $"({i + 1}: {t.DebugDisplay})")) + "]" ); @@ -62,11 +62,13 @@ string IUrlParameter.GetString(IConnectionConfigurationValues settings) => Match all => "_all", many => { - if (!(settings is IConnectionSettingsValues oscSettings)) - throw new Exception( - "Tried to pass index names on querysting but it could not be resolved because no OpenSearch.Client settings are available"); + if (settings is not IConnectionSettingsValues oscSettings) + { + throw new Exception( + "Tried to pass index names on querysting but it could not be resolved because no OpenSearch.Client settings are available"); + } - var infer = oscSettings.Inferrer; + var infer = oscSettings.Inferrer; var indices = many.Indices.Select(i => infer.IndexName(i)).Distinct(); return string.Join(",", indices); } diff --git a/src/OpenSearch.Client/Ingest/ProcessorFormatter.cs b/src/OpenSearch.Client/Ingest/ProcessorFormatter.cs index 7e260ff2c3..447d5042c8 100644 --- a/src/OpenSearch.Client/Ingest/ProcessorFormatter.cs +++ b/src/OpenSearch.Client/Ingest/ProcessorFormatter.cs @@ -70,7 +70,8 @@ internal class ProcessorFormatter : IJsonFormatter<IProcessor> { "uri_parts", 30 }, { "fingerprint", 31 }, { "community_id", 32 }, - { "network_direction", 33 } + { "network_direction", 33 }, + { "text_embedding", 34 } }; public IProcessor Deserialize(ref JsonReader reader, IJsonFormatterResolver formatterResolver) @@ -193,6 +194,9 @@ public IProcessor Deserialize(ref JsonReader reader, IJsonFormatterResolver form case 33: processor = Deserialize<NetworkDirectionProcessor>(ref reader, formatterResolver); break; + case 34: + processor = Deserialize<TextEmbeddingProcessor>(ref reader, formatterResolver); + break; } } else @@ -316,6 +320,9 @@ public void Serialize(ref JsonWriter writer, IProcessor value, IJsonFormatterRes case "network_direction": Serialize<INetworkDirectionProcessor>(ref writer, value, formatterResolver); break; + case "text_embedding": + Serialize<ITextEmbeddingProcessor>(ref writer, value, formatterResolver); + break; default: var formatter = DynamicObjectResolver.ExcludeNullCamelCase.GetFormatter<IProcessor>(); formatter.Serialize(ref writer, value, formatterResolver); diff --git a/src/OpenSearch.Client/Ingest/Processors/Plugins/NeuralSearch/InferenceProcessorBase.cs b/src/OpenSearch.Client/Ingest/Processors/Plugins/NeuralSearch/InferenceProcessorBase.cs new file mode 100644 index 0000000000..8384b34cb7 --- /dev/null +++ b/src/OpenSearch.Client/Ingest/Processors/Plugins/NeuralSearch/InferenceProcessorBase.cs @@ -0,0 +1,119 @@ +/* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Runtime.Serialization; +using OpenSearch.Net.Utf8Json; + +namespace OpenSearch.Client; + +[JsonFormatter(typeof(VerbatimDictionaryKeysFormatter<InferenceFieldMap, IInferenceFieldMap, Field, Field>))] +public interface IInferenceFieldMap : IIsADictionary<Field, Field> { } + +[InterfaceDataContract] +public interface IInferenceProcessor : IProcessor +{ + /// <summary> + /// The ID of the model that will be used to generate the embeddings. + /// The model must be deployed in OpenSearch before it can be used in neural search. + /// </summary> + /// <remarks> + /// For more information, + /// see <a href="https://opensearch.org/docs/latest/ml-commons-plugin/using-ml-models/">Using custom models within OpenSearch</a> + /// and <a href="https://opensearch.org/docs/latest/search-plugins/semantic-search/">Semantic search</a>. + /// </remarks> + [DataMember(Name = "model_id")] + string ModelId { get; set; } + + /// <summary> + /// Contains key-value pairs that specify the mapping of a text field to a vector field. + /// <ul> + /// <li><c>Key</c> being the name of the field from which to generate embeddings.</li> + /// <li><c>Value</c> being the name of the vector field in which to store the generated embeddings.</li> + /// </ul> + /// </summary> + [DataMember(Name = "field_map")] + IInferenceFieldMap FieldMap { get; set; } +} + +public class InferenceFieldMap : IsADictionaryBase<Field, Field>, IInferenceFieldMap +{ + public InferenceFieldMap() { } + public InferenceFieldMap(IDictionary<Field, Field> container) : base(container) { } + + public void Add(Field source, Field target) => BackingDictionary.Add(source, target); +} + +/// <inheritdoc cref="IInferenceProcessor"/> +public abstract class InferenceProcessorBase : ProcessorBase, IInferenceProcessor +{ + /// <inheritdoc /> + public string ModelId { get; set; } + /// <inheritdoc /> + public IInferenceFieldMap FieldMap { get; set; } +} + +public class InferenceFieldMapDescriptor<TDocument> + : IsADictionaryDescriptorBase<InferenceFieldMapDescriptor<TDocument>, InferenceFieldMap, Field, Field> + where TDocument : class +{ + public InferenceFieldMapDescriptor() : base(new InferenceFieldMap()) { } + + public InferenceFieldMapDescriptor<TDocument> Map(Field source, Field target) => + Assign(source, target); + + public InferenceFieldMapDescriptor<TDocument> Map<TSourceValue>( + Expression<Func<TDocument, TSourceValue>> source, + Field target + ) => + Assign(source, target); + + public InferenceFieldMapDescriptor<TDocument> Map<TTargetValue>( + Field source, + Expression<Func<TDocument, TTargetValue>> target + ) => + Assign(source, target); + + public InferenceFieldMapDescriptor<TDocument> Map<TSourceValue, TTargetValue>( + Expression<Func<TDocument, TSourceValue>> source, + Expression<Func<TDocument, TTargetValue>> target + ) => + Assign(source, target); +} + +/// <inheritdoc cref="IInferenceProcessor"/> +public abstract class InferenceProcessorDescriptorBase<T, TInferenceProcessorDescriptor, TInferenceProcessorInterface> + : ProcessorDescriptorBase<TInferenceProcessorDescriptor, TInferenceProcessorInterface>, IInferenceProcessor + where T : class + where TInferenceProcessorDescriptor : InferenceProcessorDescriptorBase<T, TInferenceProcessorDescriptor, TInferenceProcessorInterface>, TInferenceProcessorInterface + where TInferenceProcessorInterface : class, IInferenceProcessor +{ + string IInferenceProcessor.ModelId { get; set; } + IInferenceFieldMap IInferenceProcessor.FieldMap { get; set; } + + /// <inheritdoc cref="IInferenceProcessor.ModelId"/> + public TInferenceProcessorDescriptor ModelId(string modelId) => Assign(modelId, (a, v) => a.ModelId = v); + + /// <inheritdoc cref="IInferenceProcessor.FieldMap"/> + public TInferenceProcessorDescriptor FieldMap(IDictionary<Field, Field> fieldMap) => + Assign(fieldMap, (a, v) => a.FieldMap = v != null ? new InferenceFieldMap(v) : null); + + /// <inheritdoc cref="IInferenceProcessor.FieldMap"/> + public TInferenceProcessorDescriptor FieldMap(IInferenceFieldMap fieldMap) => + Assign(fieldMap, (a, v) => a.FieldMap = v); + + /// <inheritdoc cref="IInferenceProcessor.FieldMap"/> + public TInferenceProcessorDescriptor FieldMap(Func<InferenceFieldMapDescriptor<T>, IPromise<IInferenceFieldMap>> selector) => + Assign(selector, (a, v) => a.FieldMap = v?.Invoke(new InferenceFieldMapDescriptor<T>())?.Value); + + /// <inheritdoc cref="IInferenceProcessor.FieldMap"/> + public TInferenceProcessorDescriptor FieldMap<TDocument>(Func<InferenceFieldMapDescriptor<TDocument>, IPromise<IInferenceFieldMap>> selector) + where TDocument : class => + Assign(selector, (a, v) => a.FieldMap = v?.Invoke(new InferenceFieldMapDescriptor<TDocument>())?.Value); +} diff --git a/src/OpenSearch.Client/Ingest/Processors/Plugins/NeuralSearch/TextEmbeddingProcessor.cs b/src/OpenSearch.Client/Ingest/Processors/Plugins/NeuralSearch/TextEmbeddingProcessor.cs new file mode 100644 index 0000000000..5df6fc70ff --- /dev/null +++ b/src/OpenSearch.Client/Ingest/Processors/Plugins/NeuralSearch/TextEmbeddingProcessor.cs @@ -0,0 +1,32 @@ +/* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +using OpenSearch.Net.Utf8Json; + +namespace OpenSearch.Client; + +/// <summary> +/// The <c>text_embedding</c> processor is used to generate vector embeddings from text fields for <a href="https://opensearch.org/docs/latest/search-plugins/semantic-search/">semantic search</a>. +/// </summary> +[InterfaceDataContract] +public interface ITextEmbeddingProcessor : IInferenceProcessor +{ +} + +/// <inheritdoc cref="ITextEmbeddingProcessor"/> +public class TextEmbeddingProcessor : InferenceProcessorBase, ITextEmbeddingProcessor +{ + protected override string Name => "text_embedding"; +} + +/// <inheritdoc cref="ITextEmbeddingProcessor"/> +public class TextEmbeddingProcessorDescriptor<TDocument> + : InferenceProcessorDescriptorBase<TDocument, TextEmbeddingProcessorDescriptor<TDocument>, ITextEmbeddingProcessor>, ITextEmbeddingProcessor + where TDocument : class +{ + protected override string Name => "text_embedding"; +} diff --git a/src/OpenSearch.Client/Ingest/ProcessorsDescriptor.cs b/src/OpenSearch.Client/Ingest/ProcessorsDescriptor.cs index 7fab35d570..1fd0a29b23 100644 --- a/src/OpenSearch.Client/Ingest/ProcessorsDescriptor.cs +++ b/src/OpenSearch.Client/Ingest/ProcessorsDescriptor.cs @@ -205,5 +205,9 @@ public ProcessorsDescriptor NetworkCommunityId<T>(Func<NetworkCommunityIdProcess /// <inheritdoc cref="INetworkDirectionProcessor"/> public ProcessorsDescriptor NetworkDirection<T>(Func<NetworkDirectionProcessorDescriptor<T>, INetworkDirectionProcessor> selector) where T : class => Assign(selector, (a, v) => a.AddIfNotNull(v?.Invoke(new NetworkDirectionProcessorDescriptor<T>()))); - } + + /// <inheritdoc cref="ITextEmbeddingProcessor"/> + public ProcessorsDescriptor TextEmbedding<T>(Func<TextEmbeddingProcessorDescriptor<T>, ITextEmbeddingProcessor> selector) where T : class => + Assign(selector, (a, v) => a.AddIfNotNull(v?.Invoke(new TextEmbeddingProcessorDescriptor<T>()))); + } } diff --git a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/IQueryContainer.cs b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/IQueryContainer.cs index 3468f49ee4..e71268f7a7 100644 --- a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/IQueryContainer.cs +++ b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/IQueryContainer.cs @@ -203,6 +203,9 @@ public interface IQueryContainer [DataMember(Name = "knn")] IKnnQuery Knn { get; set; } + [DataMember(Name = "neural")] + INeuralQuery Neural { get; set; } + void Accept(IQueryVisitor visitor); } } diff --git a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainer-Assignments.cs b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainer-Assignments.cs index a7b9c79fdb..45cfb19d3a 100644 --- a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainer-Assignments.cs +++ b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainer-Assignments.cs @@ -61,6 +61,7 @@ public partial class QueryContainer : IQueryContainer, IDescriptor private IMoreLikeThisQuery _moreLikeThis; private IMultiMatchQuery _multiMatch; private INestedQuery _nested; + private INeuralQuery _neural; private IParentIdQuery _parentId; private IPercolateQuery _percolate; private IPrefixQuery _prefix; @@ -254,6 +255,12 @@ INestedQuery IQueryContainer.Nested set => _nested = Set(value); } + INeuralQuery IQueryContainer.Neural + { + get => _neural; + set => _neural = Set(value); + } + IParentIdQuery IQueryContainer.ParentId { get => _parentId; diff --git a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainerDescriptor.cs b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainerDescriptor.cs index 419e41d869..3ccb8529bb 100644 --- a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainerDescriptor.cs +++ b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainerDescriptor.cs @@ -32,480 +32,485 @@ namespace OpenSearch.Client { - [DataContract] - public class QueryContainerDescriptor<T> : QueryContainer where T : class - { - private QueryContainer WrapInContainer<TQuery, TQueryInterface>( - Func<TQuery, TQueryInterface> create, - Action<TQueryInterface, IQueryContainer> assign - ) - where TQuery : class, TQueryInterface, IQuery, new() - where TQueryInterface : class, IQuery - { - // Invoke the create delegate before assigning container; the create delegate - // may mutate the current QueryContainerDescriptor<T> instance such that it - // contains a query. See https://github.com/elastic/elasticsearch-net/issues/2875 - var query = create.InvokeOrDefault(new TQuery()); - - var container = ContainedQuery == null - ? this - : new QueryContainerDescriptor<T>(); - - IQueryContainer c = container; - c.IsVerbatim = query.IsVerbatim; - c.IsStrict = query.IsStrict; - assign(query, container); - container.ContainedQuery = query; - - //if query is writable (not conditionless or verbatim): return a container that holds the query - if (query.IsWritable) - return container; - - //query is conditionless but marked as strict, throw exception - if (query.IsStrict) - throw new ArgumentException("Query is conditionless but strict is turned on"); - - //query is conditionless return an empty container that can later be rewritten - return null; - } - - /// <summary> - /// A query defined using a raw json string. - /// <para>The query must be enclosed within '{' and '}'</para> - /// </summary> - /// <param name="rawJson">The query dsl json</param> - public QueryContainer Raw(string rawJson) => - WrapInContainer((RawQueryDescriptor descriptor) => descriptor.Raw(rawJson), (query, container) => container.RawQuery = query); - - /// <summary> - /// A query that uses a query parser in order to parse its content. - /// </summary> - public QueryContainer QueryString(Func<QueryStringQueryDescriptor<T>, IQueryStringQuery> selector) => - WrapInContainer(selector, (query, container) => container.QueryString = query); - - /// <summary> - /// A query that uses the SimpleQueryParser to parse its context. - /// Unlike the regular query_string query, the simple_query_string query will - /// never throw an exception, and discards invalid parts of the query. - /// </summary> - public QueryContainer SimpleQueryString(Func<SimpleQueryStringQueryDescriptor<T>, ISimpleQueryStringQuery> selector) => - WrapInContainer(selector, (query, container) => container.SimpleQueryString = query); - - /// <summary> - /// A query that match on any (configurable) of the provided terms. - /// This is a simpler syntax query for using a bool query with several term queries in the should clauses. - /// </summary> - public QueryContainer Terms(Func<TermsQueryDescriptor<T>, ITermsQuery> selector) => - WrapInContainer(selector, (query, container) => container.Terms = query); - - /// <summary> - /// A fuzzy based query that uses similarity based on Levenshtein (edit distance) algorithm. - /// Warning: this query is not very scalable with its default prefix length of 0. In this case, - /// every term will be enumerated and cause an edit score calculation or max_expansions is not set. - /// </summary> - public QueryContainer Fuzzy(Func<FuzzyQueryDescriptor<T>, IFuzzyQuery> selector) => - WrapInContainer(selector, (query, container) => container.Fuzzy = query); - - public QueryContainer FuzzyNumeric(Func<FuzzyNumericQueryDescriptor<T>, IFuzzyQuery> selector) => - WrapInContainer(selector, (query, container) => container.Fuzzy = query); - - public QueryContainer FuzzyDate(Func<FuzzyDateQueryDescriptor<T>, IFuzzyQuery> selector) => - WrapInContainer(selector, (query, container) => container.Fuzzy = query); - - /// <summary> - /// The default match query is of type boolean. It means that the text provided is analyzed and the analysis - /// process constructs a boolean query from the provided text. - /// </summary> - public QueryContainer Match(Func<MatchQueryDescriptor<T>, IMatchQuery> selector) => - WrapInContainer(selector, (query, container) => container.Match = query); - - /// <summary> - /// The match_phrase query analyzes the match and creates a phrase query out of the analyzed text. - /// </summary> - public QueryContainer MatchPhrase(Func<MatchPhraseQueryDescriptor<T>, IMatchPhraseQuery> selector) => - WrapInContainer(selector, (query, container) => container.MatchPhrase = query); - - /// <inheritdoc cref="IMatchBoolPrefixQuery"/> - public QueryContainer MatchBoolPrefix(Func<MatchBoolPrefixQueryDescriptor<T>, IMatchBoolPrefixQuery> selector) => - WrapInContainer(selector, (query, container) => container.MatchBoolPrefix = query); - - /// <summary> - /// The match_phrase_prefix is the same as match_phrase, expect it allows for prefix matches on the last term - /// in the text - /// </summary> - public QueryContainer MatchPhrasePrefix(Func<MatchPhrasePrefixQueryDescriptor<T>, IMatchPhrasePrefixQuery> selector) => - WrapInContainer(selector, (query, container) => container.MatchPhrasePrefix = query); - - /// <summary> - /// The multi_match query builds further on top of the match query by allowing multiple fields to be specified. - /// The idea here is to allow to more easily build a concise match type query over multiple fields instead of using a - /// relatively more expressive query by using multiple match queries within a bool query. - /// </summary> - public QueryContainer MultiMatch(Func<MultiMatchQueryDescriptor<T>, IMultiMatchQuery> selector) => - WrapInContainer(selector, (query, container) => container.MultiMatch = query); - - /// <summary> - /// Nested query allows to query nested objects / docs (see nested mapping). The query is executed against the - /// nested objects / docs as if they were indexed as separate docs (they are, internally) and resulting in the - /// root parent doc (or parent nested mapping). - /// </summary> - public QueryContainer Nested(Func<NestedQueryDescriptor<T>, INestedQuery> selector) => - WrapInContainer(selector, (query, container) => container.Nested = query); - - /// <summary> - /// A thin wrapper allowing fined grained control what should happen if a query is conditionless - /// if you need to fallback to something other than a match_all query - /// </summary> - public QueryContainer Conditionless(Func<ConditionlessQueryDescriptor<T>, IConditionlessQuery> selector) - { - var query = selector(new ConditionlessQueryDescriptor<T>()); - return query?.Query ?? query?.Fallback; - } - - /// <summary> - /// Matches documents with fields that have terms within a certain numeric range. - /// </summary> - public QueryContainer Range(Func<NumericRangeQueryDescriptor<T>, INumericRangeQuery> selector) => - WrapInContainer(selector, (query, container) => container.Range = query); - - public QueryContainer LongRange(Func<LongRangeQueryDescriptor<T>, ILongRangeQuery> selector) => - WrapInContainer(selector, (query, container) => container.Range = query); - - /// <summary> - /// Matches documents with fields that have terms within a certain date range. - /// </summary> - public QueryContainer DateRange(Func<DateRangeQueryDescriptor<T>, IDateRangeQuery> selector) => - WrapInContainer(selector, (query, container) => container.Range = query); - - /// <summary> - /// Matches documents with fields that have terms within a certain term range. - /// </summary> - public QueryContainer TermRange(Func<TermRangeQueryDescriptor<T>, ITermRangeQuery> selector) => - WrapInContainer(selector, (query, container) => container.Range = query); - - /// <summary> - /// More like this query find documents that are like the provided text by running it against one or more fields. - /// </summary> - public QueryContainer MoreLikeThis(Func<MoreLikeThisQueryDescriptor<T>, IMoreLikeThisQuery> selector) => - WrapInContainer(selector, (query, container) => container.MoreLikeThis = query); - - /// <summary> - /// A geo_shape query that finds documents - /// that have a geometry that matches for the given spatial relation and input shape - /// </summary> - public QueryContainer GeoShape(Func<GeoShapeQueryDescriptor<T>, IGeoShapeQuery> selector) => - WrapInContainer(selector, (query, container) => container.GeoShape = query); - - /// <summary> - /// Finds documents with shapes that either intersect, are within, or do not intersect a specified shape. - /// </summary> - public QueryContainer Shape(Func<ShapeQueryDescriptor<T>, IShapeQuery> selector) => - WrapInContainer(selector, (query, container) => container.Shape = query); - - /// <summary> - /// Matches documents with a geo_point type field that falls within a polygon of points - /// </summary> - public QueryContainer GeoPolygon(Func<GeoPolygonQueryDescriptor<T>, IGeoPolygonQuery> selector) => - WrapInContainer(selector, (query, container) => container.GeoPolygon = query); - - /// <summary> - /// Matches documents with a geo_point type field to include only those - /// that exist within a specific distance from a given geo_point - /// </summary> - public QueryContainer GeoDistance(Func<GeoDistanceQueryDescriptor<T>, IGeoDistanceQuery> selector) => - WrapInContainer(selector, (query, container) => container.GeoDistance = query); - - /// <summary> - /// Matches documents with a geo_point type field to include only those that exist within a bounding box - /// </summary> - public QueryContainer GeoBoundingBox(Func<GeoBoundingBoxQueryDescriptor<T>, IGeoBoundingBoxQuery> selector) => - WrapInContainer(selector, (query, container) => container.GeoBoundingBox = query); - - /// <summary> - /// The has_child query works the same as the has_child filter, by automatically wrapping the filter with a - /// constant_score. - /// </summary> - /// <typeparam name="TChild">Type of the child</typeparam> - public QueryContainer HasChild<TChild>(Func<HasChildQueryDescriptor<TChild>, IHasChildQuery> selector) where TChild : class => - WrapInContainer(selector, (query, container) => container.HasChild = query); - - /// <summary> - /// The has_parent query works the same as the has_parent filter, by automatically wrapping the filter with a - /// constant_score. - /// </summary> - /// <typeparam name="TParent">Type of the parent</typeparam> - public QueryContainer HasParent<TParent>(Func<HasParentQueryDescriptor<TParent>, IHasParentQuery> selector) where TParent : class => - WrapInContainer(selector, (query, container) => container.HasParent = query); - - public QueryContainer Knn(Func<KnnQueryDescriptor<T>, IKnnQuery> selector) => - WrapInContainer(selector, (query, container) => container.Knn = query); - - /// <summary> - /// A query that generates the union of documents produced by its subqueries, and that scores each document - /// with the maximum score for that document as produced by any subquery, plus a tie breaking increment for - /// any additional matching subqueries. - /// </summary> - public QueryContainer DisMax(Func<DisMaxQueryDescriptor<T>, IDisMaxQuery> selector) => - WrapInContainer(selector, (query, container) => container.DisMax = query); - - /// <inheritdoc cref="IDistanceFeatureQuery"/> - public QueryContainer DistanceFeature(Func<DistanceFeatureQueryDescriptor<T>, IDistanceFeatureQuery> selector) => - WrapInContainer(selector, (query, container) => container.DistanceFeature = query); - - /// <summary> - /// A query that wraps a filter or another query and simply returns a constant score equal to the query boost - /// for every document in the filter. Maps to Lucene ConstantScoreQuery. - /// </summary> - public QueryContainer ConstantScore(Func<ConstantScoreQueryDescriptor<T>, IConstantScoreQuery> selector) => - WrapInContainer(selector, (query, container) => container.ConstantScore = query); - - /// <summary> - /// A query that matches documents matching boolean combinations of other queries. The bool query maps to - /// Lucene BooleanQuery. - /// It is built using one or more boolean clauses, each clause with a typed occurrence - /// </summary> - public QueryContainer Bool(Func<BoolQueryDescriptor<T>, IBoolQuery> selector) => - WrapInContainer(selector, (query, container) => container.Bool = query); - - /// <summary> - /// A query that can be used to effectively demote results that match a given query. - /// Unlike the "must_not" clause in bool query, this still selects documents that contain - /// undesirable terms, but reduces their overall score. - /// </summary> - public QueryContainer Boosting(Func<BoostingQueryDescriptor<T>, IBoostingQuery> selector) => - WrapInContainer(selector, (query, container) => container.Boosting = query); - - /// <summary> - /// A query that matches all documents. Maps to Lucene MatchAllDocsQuery. - /// </summary> - public QueryContainer MatchAll(Func<MatchAllQueryDescriptor, IMatchAllQuery> selector = null) => - WrapInContainer(selector, (query, container) => container.MatchAll = query ?? new MatchAllQuery()); - - /// <summary> - /// A query that matches no documents. This is the inverse of the match_all query. - /// </summary> - public QueryContainer MatchNone(Func<MatchNoneQueryDescriptor, IMatchNoneQuery> selector = null) => - WrapInContainer(selector, (query, container) => container.MatchNone = query ?? new MatchNoneQuery()); - - /// <summary> - /// Matches documents that have fields that contain a term (not analyzed). - /// The term query maps to Lucene TermQuery. - /// </summary> - public QueryContainer Term<TValue>(Expression<Func<T, TValue>> field, object value, double? boost = null, string name = null) => - Term(t => t.Field(field).Value(value).Boost(boost).Name(name)); - - /// <summary> - /// Helper method to easily filter on join relations - /// </summary> - public QueryContainer HasRelationName(Expression<Func<T, JoinField>> field, RelationName value) => - Term(t => t.Field(field).Value(value)); - - /// <summary>Helper method to easily filter on join relations</summary> - public QueryContainer HasRelationName<TRelation>(Expression<Func<T, JoinField>> field) => - Term(t => t.Field(field).Value(Infer.Relation<TRelation>())); - - /// <summary> - /// Matches documents that have fields that contain a term (not analyzed). - /// The term query maps to Lucene TermQuery. - /// </summary> - public QueryContainer Term(Field field, object value, double? boost = null, string name = null) => - Term(t => t.Field(field).Value(value).Boost(boost).Name(name)); - - /// <summary> - /// Matches documents that have fields that contain a term (not analyzed). - /// The term query maps to Lucene TermQuery. - /// </summary> - public QueryContainer Term(Func<TermQueryDescriptor<T>, ITermQuery> selector) => - WrapInContainer(selector, (query, container) => container.Term = query); - - /// <summary> - /// Matches documents that have fields matching a wildcard expression (not analyzed). - /// Supported wildcards are *, which matches any character sequence (including the empty one), and ?, - /// which matches any single character. Note this query can be slow, as it needs to iterate - /// over many terms. In order to prevent extremely slow wildcard queries, a wildcard term should - /// not start with one of the wildcards * or ?. The wildcard query maps to Lucene WildcardQuery. - /// </summary> - public QueryContainer Wildcard<TValue>(Expression<Func<T, TValue>> field, string value, double? boost = null, MultiTermQueryRewrite rewrite = null, - string name = null - ) => - Wildcard(t => t.Field(field).Value(value).Rewrite(rewrite).Boost(boost).Name(name)); - - /// <summary> - /// Matches documents that have fields matching a wildcard expression (not analyzed). - /// Supported wildcards are *, which matches any character sequence (including the empty one), and ?, - /// which matches any single character. Note this query can be slow, as it needs to iterate over many terms. - /// In order to prevent extremely slow wildcard queries, a wildcard term should not start with - /// one of the wildcards * or ?. The wildcard query maps to Lucene WildcardQuery. - /// </summary> - public QueryContainer Wildcard(Field field, string value, double? boost = null, MultiTermQueryRewrite rewrite = null, string name = null) => - Wildcard(t => t.Field(field).Value(value).Rewrite(rewrite).Boost(boost).Name(name)); - - /// <summary> - /// Matches documents that have fields matching a wildcard expression (not analyzed). - /// Supported wildcards are *, which matches any character sequence (including the empty one), and ?, - /// which matches any single character. Note this query can be slow, as it needs to iterate over many terms. - /// In order to prevent extremely slow wildcard queries, a wildcard term should not start with - /// one of the wildcards * or ?. The wildcard query maps to Lucene WildcardQuery. - /// </summary> - public QueryContainer Wildcard(Func<WildcardQueryDescriptor<T>, IWildcardQuery> selector) => - WrapInContainer(selector, (query, container) => container.Wildcard = query); - - /// <summary> - /// Matches documents that have fields containing terms with a specified prefix (not analyzed). - /// The prefix query maps to Lucene PrefixQuery. - /// </summary> - public QueryContainer Prefix<TValue>(Expression<Func<T, TValue>> field, string value, double? boost = null, MultiTermQueryRewrite rewrite = null, - string name = null - ) => - Prefix(t => t.Field(field).Value(value).Boost(boost).Rewrite(rewrite).Name(name)); - - /// <summary> - /// Matches documents that have fields containing terms with a specified prefix (not analyzed). - /// The prefix query maps to Lucene PrefixQuery. - /// </summary> - public QueryContainer Prefix(Field field, string value, double? boost = null, MultiTermQueryRewrite rewrite = null, string name = null) => - Prefix(t => t.Field(field).Value(value).Boost(boost).Rewrite(rewrite).Name(name)); - - /// <summary> - /// Matches documents that have fields containing terms with a specified prefix (not analyzed). - /// The prefix query maps to Lucene PrefixQuery. - /// </summary> - public QueryContainer Prefix(Func<PrefixQueryDescriptor<T>, IPrefixQuery> selector) => - WrapInContainer(selector, (query, container) => container.Prefix = query); - - /// <summary> - /// Matches documents that only have the provided ids. - /// Note, this filter does not require the _id field to be indexed since - /// it works using the _uid field. - /// </summary> - public QueryContainer Ids(Func<IdsQueryDescriptor, IIdsQuery> selector) => - WrapInContainer(selector, (query, container) => container.Ids = query); - - /// <summary> - /// Allows fine-grained control over the order and proximity of matching terms. - /// Matching rules are constructed from a small set of definitions, - /// and the rules are then applied to terms from a particular field. - /// The definitions produce sequences of minimal intervals that span terms in a body of text. - /// These intervals can be further combined and filtered by parent sources. - /// </summary> - public QueryContainer Intervals(Func<IntervalsQueryDescriptor<T>, IIntervalsQuery> selector) => - WrapInContainer(selector, (query, container) => container.Intervals = query); - - /// <inheritdoc cref="IRankFeatureQuery"/> - public QueryContainer RankFeature(Func<RankFeatureQueryDescriptor<T>, IRankFeatureQuery> selector) => - WrapInContainer(selector, (query, container) => container.RankFeature = query); - - /// <summary> - /// Matches spans containing a term. The span term query maps to Lucene SpanTermQuery. - /// </summary> - public QueryContainer SpanTerm(Func<SpanTermQueryDescriptor<T>, ISpanTermQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanTerm = query); - - /// <summary> - /// Matches spans near the beginning of a field. The span first query maps to Lucene SpanFirstQuery. - /// </summary> - public QueryContainer SpanFirst(Func<SpanFirstQueryDescriptor<T>, ISpanFirstQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanFirst = query); - - /// <summary> - /// Matches spans which are near one another. One can specify slop, the maximum number of - /// intervening unmatched positions, as well as whether matches are required to be in-order. - /// The span near query maps to Lucene SpanNearQuery. - /// </summary> - public QueryContainer SpanNear(Func<SpanNearQueryDescriptor<T>, ISpanNearQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanNear = query); - - /// <summary> - /// Matches the union of its span clauses. - /// The span or query maps to Lucene SpanOrQuery. - /// </summary> - public QueryContainer SpanOr(Func<SpanOrQueryDescriptor<T>, ISpanOrQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanOr = query); - - /// <summary> - /// Removes matches which overlap with another span query. - /// The span not query maps to Lucene SpanNotQuery. - /// </summary> - public QueryContainer SpanNot(Func<SpanNotQueryDescriptor<T>, ISpanNotQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanNot = query); - - /// <summary> - /// Wrap a multi term query (one of fuzzy, prefix, term range or regexp query) - /// as a span query so it can be nested. - /// </summary> - public QueryContainer SpanMultiTerm(Func<SpanMultiTermQueryDescriptor<T>, ISpanMultiTermQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanMultiTerm = query); - - /// <summary> - /// Returns matches which enclose another span query. - /// The span containing query maps to Lucene SpanContainingQuery - /// </summary> - public QueryContainer SpanContaining(Func<SpanContainingQueryDescriptor<T>, ISpanContainingQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanContaining = query); - - /// <summary> - /// Returns Matches which are enclosed inside another span query. - /// The span within query maps to Lucene SpanWithinQuery - /// </summary> - public QueryContainer SpanWithin(Func<SpanWithinQueryDescriptor<T>, ISpanWithinQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanWithin = query); - - /// <summary> - /// Wraps span queries to allow them to participate in composite single-field Span queries by 'lying' about their search field. - /// That is, the masked span query will function as normal, but the field points back to the set field of the query. - /// This can be used to support queries like SpanNearQuery or SpanOrQuery across different fields, - /// which is not ordinarily permitted. - /// </summary> - public QueryContainer SpanFieldMasking(Func<SpanFieldMaskingQueryDescriptor<T>, ISpanFieldMaskingQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanFieldMasking = query); - - /// <summary> - /// Allows you to use regular expression term queries. - /// "term queries" means that OpenSearch will apply the regexp to the terms produced - /// by the tokenizer for that field, and not to the original text of the field. - /// </summary> - public QueryContainer Regexp(Func<RegexpQueryDescriptor<T>, IRegexpQuery> selector) => - WrapInContainer(selector, (query, container) => container.Regexp = query); - - /// <summary> - /// The function_score query allows you to modify the score of documents that are retrieved by a query. - /// This can be useful if, for example, a score function is computationally expensive and it is - /// sufficient to compute the score on a filtered set of documents. - /// </summary> - /// <returns></returns> - public QueryContainer FunctionScore(Func<FunctionScoreQueryDescriptor<T>, IFunctionScoreQuery> selector) => - WrapInContainer(selector, (query, container) => container.FunctionScore = query); - - public QueryContainer Script(Func<ScriptQueryDescriptor<T>, IScriptQuery> selector) => - WrapInContainer(selector, (query, container) => container.Script = query); - - public QueryContainer ScriptScore(Func<ScriptScoreQueryDescriptor<T>, IScriptScoreQuery> selector) => - WrapInContainer(selector, (query, container) => container.ScriptScore = query); - - public QueryContainer Exists(Func<ExistsQueryDescriptor<T>, IExistsQuery> selector) => - WrapInContainer(selector, (query, container) => container.Exists = query); - - /// <summary> - /// Used to match queries stored in an index. - /// The percolate query itself contains the document that will be used as query - /// to match with the stored queries. - /// </summary> - public QueryContainer Percolate(Func<PercolateQueryDescriptor<T>, IPercolateQuery> selector) => - WrapInContainer(selector, (query, container) => container.Percolate = query); - - /// <summary> - /// Used to find child documents which belong to a particular parent. - /// </summary> - public QueryContainer ParentId(Func<ParentIdQueryDescriptor<T>, IParentIdQuery> selector) => - WrapInContainer(selector, (query, container) => container.ParentId = query); - - /// <summary> - /// Returns any documents that match with at least one or more of the provided terms. - /// The terms are not analyzed and thus must match exactly. The number of terms that must match varies - /// per document and is either controlled by a minimum should match field or - /// computed per document in a minimum should match script. - /// </summary> - public QueryContainer TermsSet(Func<TermsSetQueryDescriptor<T>, ITermsSetQuery> selector) => - WrapInContainer(selector, (query, container) => container.TermsSet = query); - } + [DataContract] + public class QueryContainerDescriptor<T> : QueryContainer where T : class + { + private QueryContainer WrapInContainer<TQuery, TQueryInterface>( + Func<TQuery, TQueryInterface> create, + Action<TQueryInterface, IQueryContainer> assign + ) + where TQuery : class, TQueryInterface, IQuery, new() + where TQueryInterface : class, IQuery + { + // Invoke the create delegate before assigning container; the create delegate + // may mutate the current QueryContainerDescriptor<T> instance such that it + // contains a query. See https://github.com/elastic/elasticsearch-net/issues/2875 + var query = create.InvokeOrDefault(new TQuery()); + + var container = ContainedQuery == null + ? this + : new QueryContainerDescriptor<T>(); + + IQueryContainer c = container; + c.IsVerbatim = query.IsVerbatim; + c.IsStrict = query.IsStrict; + assign(query, container); + container.ContainedQuery = query; + + //if query is writable (not conditionless or verbatim): return a container that holds the query + if (query.IsWritable) + return container; + + //query is conditionless but marked as strict, throw exception + if (query.IsStrict) + throw new ArgumentException("Query is conditionless but strict is turned on"); + + //query is conditionless return an empty container that can later be rewritten + return null; + } + + /// <summary> + /// A query defined using a raw json string. + /// <para>The query must be enclosed within '{' and '}'</para> + /// </summary> + /// <param name="rawJson">The query dsl json</param> + public QueryContainer Raw(string rawJson) => + WrapInContainer((RawQueryDescriptor descriptor) => descriptor.Raw(rawJson), (query, container) => container.RawQuery = query); + + /// <summary> + /// A query that uses a query parser in order to parse its content. + /// </summary> + public QueryContainer QueryString(Func<QueryStringQueryDescriptor<T>, IQueryStringQuery> selector) => + WrapInContainer(selector, (query, container) => container.QueryString = query); + + /// <summary> + /// A query that uses the SimpleQueryParser to parse its context. + /// Unlike the regular query_string query, the simple_query_string query will + /// never throw an exception, and discards invalid parts of the query. + /// </summary> + public QueryContainer SimpleQueryString(Func<SimpleQueryStringQueryDescriptor<T>, ISimpleQueryStringQuery> selector) => + WrapInContainer(selector, (query, container) => container.SimpleQueryString = query); + + /// <summary> + /// A query that match on any (configurable) of the provided terms. + /// This is a simpler syntax query for using a bool query with several term queries in the should clauses. + /// </summary> + public QueryContainer Terms(Func<TermsQueryDescriptor<T>, ITermsQuery> selector) => + WrapInContainer(selector, (query, container) => container.Terms = query); + + /// <summary> + /// A fuzzy based query that uses similarity based on Levenshtein (edit distance) algorithm. + /// Warning: this query is not very scalable with its default prefix length of 0. In this case, + /// every term will be enumerated and cause an edit score calculation or max_expansions is not set. + /// </summary> + public QueryContainer Fuzzy(Func<FuzzyQueryDescriptor<T>, IFuzzyQuery> selector) => + WrapInContainer(selector, (query, container) => container.Fuzzy = query); + + public QueryContainer FuzzyNumeric(Func<FuzzyNumericQueryDescriptor<T>, IFuzzyQuery> selector) => + WrapInContainer(selector, (query, container) => container.Fuzzy = query); + + public QueryContainer FuzzyDate(Func<FuzzyDateQueryDescriptor<T>, IFuzzyQuery> selector) => + WrapInContainer(selector, (query, container) => container.Fuzzy = query); + + /// <summary> + /// The default match query is of type boolean. It means that the text provided is analyzed and the analysis + /// process constructs a boolean query from the provided text. + /// </summary> + public QueryContainer Match(Func<MatchQueryDescriptor<T>, IMatchQuery> selector) => + WrapInContainer(selector, (query, container) => container.Match = query); + + /// <summary> + /// The match_phrase query analyzes the match and creates a phrase query out of the analyzed text. + /// </summary> + public QueryContainer MatchPhrase(Func<MatchPhraseQueryDescriptor<T>, IMatchPhraseQuery> selector) => + WrapInContainer(selector, (query, container) => container.MatchPhrase = query); + + /// <inheritdoc cref="IMatchBoolPrefixQuery"/> + public QueryContainer MatchBoolPrefix(Func<MatchBoolPrefixQueryDescriptor<T>, IMatchBoolPrefixQuery> selector) => + WrapInContainer(selector, (query, container) => container.MatchBoolPrefix = query); + + /// <summary> + /// The match_phrase_prefix is the same as match_phrase, expect it allows for prefix matches on the last term + /// in the text + /// </summary> + public QueryContainer MatchPhrasePrefix(Func<MatchPhrasePrefixQueryDescriptor<T>, IMatchPhrasePrefixQuery> selector) => + WrapInContainer(selector, (query, container) => container.MatchPhrasePrefix = query); + + /// <summary> + /// The multi_match query builds further on top of the match query by allowing multiple fields to be specified. + /// The idea here is to allow to more easily build a concise match type query over multiple fields instead of using a + /// relatively more expressive query by using multiple match queries within a bool query. + /// </summary> + public QueryContainer MultiMatch(Func<MultiMatchQueryDescriptor<T>, IMultiMatchQuery> selector) => + WrapInContainer(selector, (query, container) => container.MultiMatch = query); + + /// <summary> + /// Nested query allows to query nested objects / docs (see nested mapping). The query is executed against the + /// nested objects / docs as if they were indexed as separate docs (they are, internally) and resulting in the + /// root parent doc (or parent nested mapping). + /// </summary> + public QueryContainer Nested(Func<NestedQueryDescriptor<T>, INestedQuery> selector) => + WrapInContainer(selector, (query, container) => container.Nested = query); + + /// <summary> + /// A thin wrapper allowing fined grained control what should happen if a query is conditionless + /// if you need to fallback to something other than a match_all query + /// </summary> + public QueryContainer Conditionless(Func<ConditionlessQueryDescriptor<T>, IConditionlessQuery> selector) + { + var query = selector(new ConditionlessQueryDescriptor<T>()); + return query?.Query ?? query?.Fallback; + } + + /// <summary> + /// Matches documents with fields that have terms within a certain numeric range. + /// </summary> + public QueryContainer Range(Func<NumericRangeQueryDescriptor<T>, INumericRangeQuery> selector) => + WrapInContainer(selector, (query, container) => container.Range = query); + + public QueryContainer LongRange(Func<LongRangeQueryDescriptor<T>, ILongRangeQuery> selector) => + WrapInContainer(selector, (query, container) => container.Range = query); + + /// <summary> + /// Matches documents with fields that have terms within a certain date range. + /// </summary> + public QueryContainer DateRange(Func<DateRangeQueryDescriptor<T>, IDateRangeQuery> selector) => + WrapInContainer(selector, (query, container) => container.Range = query); + + /// <summary> + /// Matches documents with fields that have terms within a certain term range. + /// </summary> + public QueryContainer TermRange(Func<TermRangeQueryDescriptor<T>, ITermRangeQuery> selector) => + WrapInContainer(selector, (query, container) => container.Range = query); + + /// <summary> + /// More like this query find documents that are like the provided text by running it against one or more fields. + /// </summary> + public QueryContainer MoreLikeThis(Func<MoreLikeThisQueryDescriptor<T>, IMoreLikeThisQuery> selector) => + WrapInContainer(selector, (query, container) => container.MoreLikeThis = query); + + /// <summary> + /// A geo_shape query that finds documents + /// that have a geometry that matches for the given spatial relation and input shape + /// </summary> + public QueryContainer GeoShape(Func<GeoShapeQueryDescriptor<T>, IGeoShapeQuery> selector) => + WrapInContainer(selector, (query, container) => container.GeoShape = query); + + /// <summary> + /// Finds documents with shapes that either intersect, are within, or do not intersect a specified shape. + /// </summary> + public QueryContainer Shape(Func<ShapeQueryDescriptor<T>, IShapeQuery> selector) => + WrapInContainer(selector, (query, container) => container.Shape = query); + + /// <summary> + /// Matches documents with a geo_point type field that falls within a polygon of points + /// </summary> + public QueryContainer GeoPolygon(Func<GeoPolygonQueryDescriptor<T>, IGeoPolygonQuery> selector) => + WrapInContainer(selector, (query, container) => container.GeoPolygon = query); + + /// <summary> + /// Matches documents with a geo_point type field to include only those + /// that exist within a specific distance from a given geo_point + /// </summary> + public QueryContainer GeoDistance(Func<GeoDistanceQueryDescriptor<T>, IGeoDistanceQuery> selector) => + WrapInContainer(selector, (query, container) => container.GeoDistance = query); + + /// <summary> + /// Matches documents with a geo_point type field to include only those that exist within a bounding box + /// </summary> + public QueryContainer GeoBoundingBox(Func<GeoBoundingBoxQueryDescriptor<T>, IGeoBoundingBoxQuery> selector) => + WrapInContainer(selector, (query, container) => container.GeoBoundingBox = query); + + /// <summary> + /// The has_child query works the same as the has_child filter, by automatically wrapping the filter with a + /// constant_score. + /// </summary> + /// <typeparam name="TChild">Type of the child</typeparam> + public QueryContainer HasChild<TChild>(Func<HasChildQueryDescriptor<TChild>, IHasChildQuery> selector) where TChild : class => + WrapInContainer(selector, (query, container) => container.HasChild = query); + + /// <summary> + /// The has_parent query works the same as the has_parent filter, by automatically wrapping the filter with a + /// constant_score. + /// </summary> + /// <typeparam name="TParent">Type of the parent</typeparam> + public QueryContainer HasParent<TParent>(Func<HasParentQueryDescriptor<TParent>, IHasParentQuery> selector) where TParent : class => + WrapInContainer(selector, (query, container) => container.HasParent = query); + + public QueryContainer Knn(Func<KnnQueryDescriptor<T>, IKnnQuery> selector) => + WrapInContainer(selector, (query, container) => container.Knn = query); + + /// <summary> + /// A query that generates the union of documents produced by its subqueries, and that scores each document + /// with the maximum score for that document as produced by any subquery, plus a tie breaking increment for + /// any additional matching subqueries. + /// </summary> + public QueryContainer DisMax(Func<DisMaxQueryDescriptor<T>, IDisMaxQuery> selector) => + WrapInContainer(selector, (query, container) => container.DisMax = query); + + /// <inheritdoc cref="IDistanceFeatureQuery"/> + public QueryContainer DistanceFeature(Func<DistanceFeatureQueryDescriptor<T>, IDistanceFeatureQuery> selector) => + WrapInContainer(selector, (query, container) => container.DistanceFeature = query); + + /// <summary> + /// A query that wraps a filter or another query and simply returns a constant score equal to the query boost + /// for every document in the filter. Maps to Lucene ConstantScoreQuery. + /// </summary> + public QueryContainer ConstantScore(Func<ConstantScoreQueryDescriptor<T>, IConstantScoreQuery> selector) => + WrapInContainer(selector, (query, container) => container.ConstantScore = query); + + /// <summary> + /// A query that matches documents matching boolean combinations of other queries. The bool query maps to + /// Lucene BooleanQuery. + /// It is built using one or more boolean clauses, each clause with a typed occurrence + /// </summary> + public QueryContainer Bool(Func<BoolQueryDescriptor<T>, IBoolQuery> selector) => + WrapInContainer(selector, (query, container) => container.Bool = query); + + /// <summary> + /// A query that can be used to effectively demote results that match a given query. + /// Unlike the "must_not" clause in bool query, this still selects documents that contain + /// undesirable terms, but reduces their overall score. + /// </summary> + public QueryContainer Boosting(Func<BoostingQueryDescriptor<T>, IBoostingQuery> selector) => + WrapInContainer(selector, (query, container) => container.Boosting = query); + + /// <summary> + /// A query that matches all documents. Maps to Lucene MatchAllDocsQuery. + /// </summary> + public QueryContainer MatchAll(Func<MatchAllQueryDescriptor, IMatchAllQuery> selector = null) => + WrapInContainer(selector, (query, container) => container.MatchAll = query ?? new MatchAllQuery()); + + /// <summary> + /// A query that matches no documents. This is the inverse of the match_all query. + /// </summary> + public QueryContainer MatchNone(Func<MatchNoneQueryDescriptor, IMatchNoneQuery> selector = null) => + WrapInContainer(selector, (query, container) => container.MatchNone = query ?? new MatchNoneQuery()); + + /// <summary> + /// Matches documents that have fields that contain a term (not analyzed). + /// The term query maps to Lucene TermQuery. + /// </summary> + public QueryContainer Term<TValue>(Expression<Func<T, TValue>> field, object value, double? boost = null, string name = null) => + Term(t => t.Field(field).Value(value).Boost(boost).Name(name)); + + /// <summary> + /// Helper method to easily filter on join relations + /// </summary> + public QueryContainer HasRelationName(Expression<Func<T, JoinField>> field, RelationName value) => + Term(t => t.Field(field).Value(value)); + + /// <summary>Helper method to easily filter on join relations</summary> + public QueryContainer HasRelationName<TRelation>(Expression<Func<T, JoinField>> field) => + Term(t => t.Field(field).Value(Infer.Relation<TRelation>())); + + /// <summary> + /// Matches documents that have fields that contain a term (not analyzed). + /// The term query maps to Lucene TermQuery. + /// </summary> + public QueryContainer Term(Field field, object value, double? boost = null, string name = null) => + Term(t => t.Field(field).Value(value).Boost(boost).Name(name)); + + /// <summary> + /// Matches documents that have fields that contain a term (not analyzed). + /// The term query maps to Lucene TermQuery. + /// </summary> + public QueryContainer Term(Func<TermQueryDescriptor<T>, ITermQuery> selector) => + WrapInContainer(selector, (query, container) => container.Term = query); + + /// <summary> + /// Matches documents that have fields matching a wildcard expression (not analyzed). + /// Supported wildcards are *, which matches any character sequence (including the empty one), and ?, + /// which matches any single character. Note this query can be slow, as it needs to iterate + /// over many terms. In order to prevent extremely slow wildcard queries, a wildcard term should + /// not start with one of the wildcards * or ?. The wildcard query maps to Lucene WildcardQuery. + /// </summary> + public QueryContainer Wildcard<TValue>(Expression<Func<T, TValue>> field, string value, double? boost = null, + MultiTermQueryRewrite rewrite = null, + string name = null + ) => + Wildcard(t => t.Field(field).Value(value).Rewrite(rewrite).Boost(boost).Name(name)); + + /// <summary> + /// Matches documents that have fields matching a wildcard expression (not analyzed). + /// Supported wildcards are *, which matches any character sequence (including the empty one), and ?, + /// which matches any single character. Note this query can be slow, as it needs to iterate over many terms. + /// In order to prevent extremely slow wildcard queries, a wildcard term should not start with + /// one of the wildcards * or ?. The wildcard query maps to Lucene WildcardQuery. + /// </summary> + public QueryContainer Wildcard(Field field, string value, double? boost = null, MultiTermQueryRewrite rewrite = null, string name = null) => + Wildcard(t => t.Field(field).Value(value).Rewrite(rewrite).Boost(boost).Name(name)); + + /// <summary> + /// Matches documents that have fields matching a wildcard expression (not analyzed). + /// Supported wildcards are *, which matches any character sequence (including the empty one), and ?, + /// which matches any single character. Note this query can be slow, as it needs to iterate over many terms. + /// In order to prevent extremely slow wildcard queries, a wildcard term should not start with + /// one of the wildcards * or ?. The wildcard query maps to Lucene WildcardQuery. + /// </summary> + public QueryContainer Wildcard(Func<WildcardQueryDescriptor<T>, IWildcardQuery> selector) => + WrapInContainer(selector, (query, container) => container.Wildcard = query); + + /// <summary> + /// Matches documents that have fields containing terms with a specified prefix (not analyzed). + /// The prefix query maps to Lucene PrefixQuery. + /// </summary> + public QueryContainer Prefix<TValue>(Expression<Func<T, TValue>> field, string value, double? boost = null, + MultiTermQueryRewrite rewrite = null, + string name = null + ) => + Prefix(t => t.Field(field).Value(value).Boost(boost).Rewrite(rewrite).Name(name)); + + /// <summary> + /// Matches documents that have fields containing terms with a specified prefix (not analyzed). + /// The prefix query maps to Lucene PrefixQuery. + /// </summary> + public QueryContainer Prefix(Field field, string value, double? boost = null, MultiTermQueryRewrite rewrite = null, string name = null) => + Prefix(t => t.Field(field).Value(value).Boost(boost).Rewrite(rewrite).Name(name)); + + /// <summary> + /// Matches documents that have fields containing terms with a specified prefix (not analyzed). + /// The prefix query maps to Lucene PrefixQuery. + /// </summary> + public QueryContainer Prefix(Func<PrefixQueryDescriptor<T>, IPrefixQuery> selector) => + WrapInContainer(selector, (query, container) => container.Prefix = query); + + /// <summary> + /// Matches documents that only have the provided ids. + /// Note, this filter does not require the _id field to be indexed since + /// it works using the _uid field. + /// </summary> + public QueryContainer Ids(Func<IdsQueryDescriptor, IIdsQuery> selector) => + WrapInContainer(selector, (query, container) => container.Ids = query); + + /// <summary> + /// Allows fine-grained control over the order and proximity of matching terms. + /// Matching rules are constructed from a small set of definitions, + /// and the rules are then applied to terms from a particular field. + /// The definitions produce sequences of minimal intervals that span terms in a body of text. + /// These intervals can be further combined and filtered by parent sources. + /// </summary> + public QueryContainer Intervals(Func<IntervalsQueryDescriptor<T>, IIntervalsQuery> selector) => + WrapInContainer(selector, (query, container) => container.Intervals = query); + + /// <inheritdoc cref="IRankFeatureQuery"/> + public QueryContainer RankFeature(Func<RankFeatureQueryDescriptor<T>, IRankFeatureQuery> selector) => + WrapInContainer(selector, (query, container) => container.RankFeature = query); + + /// <summary> + /// Matches spans containing a term. The span term query maps to Lucene SpanTermQuery. + /// </summary> + public QueryContainer SpanTerm(Func<SpanTermQueryDescriptor<T>, ISpanTermQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanTerm = query); + + /// <summary> + /// Matches spans near the beginning of a field. The span first query maps to Lucene SpanFirstQuery. + /// </summary> + public QueryContainer SpanFirst(Func<SpanFirstQueryDescriptor<T>, ISpanFirstQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanFirst = query); + + /// <summary> + /// Matches spans which are near one another. One can specify slop, the maximum number of + /// intervening unmatched positions, as well as whether matches are required to be in-order. + /// The span near query maps to Lucene SpanNearQuery. + /// </summary> + public QueryContainer SpanNear(Func<SpanNearQueryDescriptor<T>, ISpanNearQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanNear = query); + + /// <summary> + /// Matches the union of its span clauses. + /// The span or query maps to Lucene SpanOrQuery. + /// </summary> + public QueryContainer SpanOr(Func<SpanOrQueryDescriptor<T>, ISpanOrQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanOr = query); + + /// <summary> + /// Removes matches which overlap with another span query. + /// The span not query maps to Lucene SpanNotQuery. + /// </summary> + public QueryContainer SpanNot(Func<SpanNotQueryDescriptor<T>, ISpanNotQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanNot = query); + + /// <summary> + /// Wrap a multi term query (one of fuzzy, prefix, term range or regexp query) + /// as a span query so it can be nested. + /// </summary> + public QueryContainer SpanMultiTerm(Func<SpanMultiTermQueryDescriptor<T>, ISpanMultiTermQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanMultiTerm = query); + + /// <summary> + /// Returns matches which enclose another span query. + /// The span containing query maps to Lucene SpanContainingQuery + /// </summary> + public QueryContainer SpanContaining(Func<SpanContainingQueryDescriptor<T>, ISpanContainingQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanContaining = query); + + /// <summary> + /// Returns Matches which are enclosed inside another span query. + /// The span within query maps to Lucene SpanWithinQuery + /// </summary> + public QueryContainer SpanWithin(Func<SpanWithinQueryDescriptor<T>, ISpanWithinQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanWithin = query); + + /// <summary> + /// Wraps span queries to allow them to participate in composite single-field Span queries by 'lying' about their search field. + /// That is, the masked span query will function as normal, but the field points back to the set field of the query. + /// This can be used to support queries like SpanNearQuery or SpanOrQuery across different fields, + /// which is not ordinarily permitted. + /// </summary> + public QueryContainer SpanFieldMasking(Func<SpanFieldMaskingQueryDescriptor<T>, ISpanFieldMaskingQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanFieldMasking = query); + + /// <summary> + /// Allows you to use regular expression term queries. + /// "term queries" means that OpenSearch will apply the regexp to the terms produced + /// by the tokenizer for that field, and not to the original text of the field. + /// </summary> + public QueryContainer Regexp(Func<RegexpQueryDescriptor<T>, IRegexpQuery> selector) => + WrapInContainer(selector, (query, container) => container.Regexp = query); + + /// <summary> + /// The function_score query allows you to modify the score of documents that are retrieved by a query. + /// This can be useful if, for example, a score function is computationally expensive and it is + /// sufficient to compute the score on a filtered set of documents. + /// </summary> + /// <returns></returns> + public QueryContainer FunctionScore(Func<FunctionScoreQueryDescriptor<T>, IFunctionScoreQuery> selector) => + WrapInContainer(selector, (query, container) => container.FunctionScore = query); + + public QueryContainer Script(Func<ScriptQueryDescriptor<T>, IScriptQuery> selector) => + WrapInContainer(selector, (query, container) => container.Script = query); + + public QueryContainer ScriptScore(Func<ScriptScoreQueryDescriptor<T>, IScriptScoreQuery> selector) => + WrapInContainer(selector, (query, container) => container.ScriptScore = query); + + public QueryContainer Exists(Func<ExistsQueryDescriptor<T>, IExistsQuery> selector) => + WrapInContainer(selector, (query, container) => container.Exists = query); + + /// <summary> + /// Used to match queries stored in an index. + /// The percolate query itself contains the document that will be used as query + /// to match with the stored queries. + /// </summary> + public QueryContainer Percolate(Func<PercolateQueryDescriptor<T>, IPercolateQuery> selector) => + WrapInContainer(selector, (query, container) => container.Percolate = query); + + /// <summary> + /// Used to find child documents which belong to a particular parent. + /// </summary> + public QueryContainer ParentId(Func<ParentIdQueryDescriptor<T>, IParentIdQuery> selector) => + WrapInContainer(selector, (query, container) => container.ParentId = query); + + /// <summary> + /// Returns any documents that match with at least one or more of the provided terms. + /// The terms are not analyzed and thus must match exactly. The number of terms that must match varies + /// per document and is either controlled by a minimum should match field or + /// computed per document in a minimum should match script. + /// </summary> + public QueryContainer TermsSet(Func<TermsSetQueryDescriptor<T>, ITermsSetQuery> selector) => + WrapInContainer(selector, (query, container) => container.TermsSet = query); + + public QueryContainer Neural(Func<NeuralQueryDescriptor<T>, INeuralQuery> selector) => + WrapInContainer(selector, (query, container) => container.Neural = query); + } } diff --git a/src/OpenSearch.Client/QueryDsl/Query.cs b/src/OpenSearch.Client/QueryDsl/Query.cs index 84796d0636..67c023b5ae 100644 --- a/src/OpenSearch.Client/QueryDsl/Query.cs +++ b/src/OpenSearch.Client/QueryDsl/Query.cs @@ -123,6 +123,9 @@ public static QueryContainer MultiMatch(Func<MultiMatchQueryDescriptor<T>, IMult public static QueryContainer Nested(Func<NestedQueryDescriptor<T>, INestedQuery> selector) => new QueryContainerDescriptor<T>().Nested(selector); + public static QueryContainer Neural(Func<NeuralQueryDescriptor<T>, INeuralQuery> selector) => + new QueryContainerDescriptor<T>().Neural(selector); + public static QueryContainer ParentId(Func<ParentIdQueryDescriptor<T>, IParentIdQuery> selector) => new QueryContainerDescriptor<T>().ParentId(selector); diff --git a/src/OpenSearch.Client/QueryDsl/Specialized/Neural/NeuralQuery.cs b/src/OpenSearch.Client/QueryDsl/Specialized/Neural/NeuralQuery.cs new file mode 100644 index 0000000000..f97080d694 --- /dev/null +++ b/src/OpenSearch.Client/QueryDsl/Specialized/Neural/NeuralQuery.cs @@ -0,0 +1,75 @@ +/* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +using System.Runtime.Serialization; +using OpenSearch.Net.Utf8Json; + +namespace OpenSearch.Client; + +/// <summary> +/// A neural query. +/// </summary> +[InterfaceDataContract] +[JsonFormatter(typeof(FieldNameQueryFormatter<NeuralQuery, INeuralQuery>))] +public interface INeuralQuery : IFieldNameQuery +{ + /// <summary> + /// The query text from which to produce queries. + /// </summary> + [DataMember(Name = "query_text")] + string QueryText { get; set; } + + /// <summary> + /// The number of results the k-NN search returns. + /// </summary> + [DataMember(Name = "k")] + int? K { get; set; } + + /// <summary> + /// The ID of the model that will be used in the embedding interface. + /// The model must be indexed in OpenSearch before it can be used in Neural Search. + /// </summary> + [DataMember(Name = "model_id")] + string ModelId { get; set; } +} + +[DataContract] +public class NeuralQuery : FieldNameQueryBase, INeuralQuery +{ + /// <inheritdoc /> + public string QueryText { get; set; } + /// <inheritdoc /> + public int? K { get; set; } + /// <inheritdoc /> + public string ModelId { get; set; } + + protected override bool Conditionless => IsConditionless(this); + + internal override void InternalWrapInContainer(IQueryContainer container) => container.Neural = this; + + internal static bool IsConditionless(INeuralQuery q) => string.IsNullOrEmpty(q.QueryText) || q.K == null || q.K == 0 || string.IsNullOrEmpty(q.ModelId) || q.Field.IsConditionless(); +} + +public class NeuralQueryDescriptor<T> + : FieldNameQueryDescriptorBase<NeuralQueryDescriptor<T>, INeuralQuery, T>, + INeuralQuery + where T : class +{ + protected override bool Conditionless => NeuralQuery.IsConditionless(this); + string INeuralQuery.QueryText { get; set; } + int? INeuralQuery.K { get; set; } + string INeuralQuery.ModelId { get; set; } + + /// <inheritdoc cref="INeuralQuery.QueryText" /> + public NeuralQueryDescriptor<T> QueryText(string queryText) => Assign(queryText, (a, v) => a.QueryText = v); + + /// <inheritdoc cref="INeuralQuery.K" /> + public NeuralQueryDescriptor<T> K(int? k) => Assign(k, (a, v) => a.K = v); + + /// <inheritdoc cref="INeuralQuery.ModelId" /> + public NeuralQueryDescriptor<T> ModelId(string modelId) => Assign(modelId, (a, v) => a.ModelId = v); +} diff --git a/src/OpenSearch.Client/QueryDsl/Visitor/DslPrettyPrintVisitor.cs b/src/OpenSearch.Client/QueryDsl/Visitor/DslPrettyPrintVisitor.cs index 2608c09ac9..82a9700a61 100644 --- a/src/OpenSearch.Client/QueryDsl/Visitor/DslPrettyPrintVisitor.cs +++ b/src/OpenSearch.Client/QueryDsl/Visitor/DslPrettyPrintVisitor.cs @@ -179,6 +179,8 @@ private void WriteShape(IGeoShape shape, IFieldLookup indexedField, Field field, public virtual void Visit(INestedQuery query) => Write("nested"); + public virtual void Visit(INeuralQuery query) => Write("neural", query.Field); + public virtual void Visit(IPrefixQuery query) => Write("prefix"); public virtual void Visit(IQueryStringQuery query) => Write("query_string"); diff --git a/src/OpenSearch.Client/QueryDsl/Visitor/QueryVisitor.cs b/src/OpenSearch.Client/QueryDsl/Visitor/QueryVisitor.cs index 4440578ab7..58bbb302e7 100644 --- a/src/OpenSearch.Client/QueryDsl/Visitor/QueryVisitor.cs +++ b/src/OpenSearch.Client/QueryDsl/Visitor/QueryVisitor.cs @@ -100,6 +100,8 @@ public interface IQueryVisitor void Visit(INestedQuery query); + void Visit(INeuralQuery query); + void Visit(IPrefixQuery query); void Visit(IQueryStringQuery query); @@ -247,6 +249,8 @@ public virtual void Visit(IMultiMatchQuery query) { } public virtual void Visit(INestedQuery query) { } + public virtual void Visit(INeuralQuery query) { } + public virtual void Visit(IPrefixQuery query) { } public virtual void Visit(IQueryStringQuery query) { } diff --git a/src/OpenSearch.Client/QueryDsl/Visitor/QueryWalker.cs b/src/OpenSearch.Client/QueryDsl/Visitor/QueryWalker.cs index 2ff147331b..5a8697dd89 100644 --- a/src/OpenSearch.Client/QueryDsl/Visitor/QueryWalker.cs +++ b/src/OpenSearch.Client/QueryDsl/Visitor/QueryWalker.cs @@ -83,6 +83,7 @@ public void Walk(IQueryContainer qd, IQueryVisitor visitor) VisitQuery(qd.Percolate, visitor, (v, d) => v.Visit(d)); VisitQuery(qd.ParentId, visitor, (v, d) => v.Visit(d)); VisitQuery(qd.TermsSet, visitor, (v, d) => v.Visit(d)); + VisitQuery(qd.Neural, visitor, (v, d) => v.Visit(d)); VisitQuery(qd.Bool, visitor, (v, d) => { diff --git a/tests/Tests.Core/ManagedOpenSearch/Clusters/WritableCluster.cs b/tests/Tests.Core/ManagedOpenSearch/Clusters/WritableCluster.cs index 1f8457c82f..320aab24ea 100644 --- a/tests/Tests.Core/ManagedOpenSearch/Clusters/WritableCluster.cs +++ b/tests/Tests.Core/ManagedOpenSearch/Clusters/WritableCluster.cs @@ -47,10 +47,12 @@ private static ClientTestClusterConfiguration CreateConfiguration() => AnalysisIcu, AnalysisKuromoji, AnalysisNori, AnalysisPhonetic, IngestAttachment, IngestGeoIp, Knn, + MachineLearning, MapperMurmur3, Security) { - MaxConcurrency = 4 + MaxConcurrency = 4, + ValidatePluginsToInstall = false }; protected override void SeedNode() diff --git a/tests/Tests/Ingest/ProcessorAssertions.cs b/tests/Tests/Ingest/ProcessorAssertions.cs index 5e2b873a2a..9b70cb0fce 100644 --- a/tests/Tests/Ingest/ProcessorAssertions.cs +++ b/tests/Tests/Ingest/ProcessorAssertions.cs @@ -30,6 +30,7 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; +using JetBrains.Annotations; using OpenSearch.OpenSearch.Xunit.XunitPlumbing; using OpenSearch.Client; using Tests.Core.Client; @@ -62,11 +63,21 @@ public abstract class ProcessorAssertion : IProcessorAssertion public static class ProcessorAssertions { public static IEnumerable<IProcessorAssertion> All => - from t in typeof(ProcessorAssertions).GetNestedTypes() - where typeof(IProcessorAssertion).IsAssignableFrom(t) && t.IsClass - let a = t.GetCustomAttributes(typeof(SkipVersionAttribute)).FirstOrDefault() as SkipVersionAttribute - where a == null || !a.Ranges.Any(r => r.IsSatisfied(TestClient.Configuration.OpenSearchVersion)) - select (IProcessorAssertion)Activator.CreateInstance(t); + typeof(ProcessorAssertions).GetNestedTypes() + .Where(t => + { + if (!t.IsClass || !typeof(IProcessorAssertion).IsAssignableFrom(t)) return false; + + var skipVersion = t.GetCustomAttributes<SkipVersionAttribute>().FirstOrDefault(); + if (skipVersion != null && skipVersion.Ranges.Any(r => r.IsSatisfied(TestClient.Configuration.OpenSearchVersion))) + return false; + + var skipPrereleases = t.GetCustomAttributes<SkipPrereleaseVersionsAttribute>().FirstOrDefault(); + if (skipPrereleases != null && TestClient.Configuration.OpenSearchVersion.IsPreRelease) return false; + + return true; + }) + .Select(t => (IProcessorAssertion)Activator.CreateInstance(t)); public static IProcessor[] Initializers => All.Select(a => a.Initializer).ToArray(); @@ -592,5 +603,45 @@ public class Pipeline : ProcessorAssertion public override string Key => "pipeline"; } + + [SkipVersion("<2.4.0", "neural search plugin was released with v2.4.0")] + [SkipPrereleaseVersions("Prerelease versions of OpenSearch do not include the ML & Neural Search plugins")] + public class TextEmbedding : ProcessorAssertion + { + private class NeuralSearchDoc + { + [PropertyName("text")] + public string Text { get; set; } + + [PropertyName("passage_embedding")] + public float[] PassageEmbedding { get; set; } + } + + public override ProcFunc Fluent => d => d + .TextEmbedding<NeuralSearchDoc>(te => te + .ModelId("someModel-abcdef") + .FieldMap(f => f + .Map(doc => doc.Text, doc => doc.PassageEmbedding))); + + public override IProcessor Initializer => new TextEmbeddingProcessor + { + ModelId = "someModel-abcdef", + FieldMap = new InferenceFieldMap + { + {new Field((NeuralSearchDoc d) => d.Text), new Field((NeuralSearchDoc d) => d.PassageEmbedding)} + } + }; + + public override object Json => new + { + model_id = "someModel-abcdef", + field_map = new + { + text = "passage_embedding" + } + }; + + public override string Key => "text_embedding"; + } } } diff --git a/tests/Tests/Ingest/PutPipeline/PutPipelineApiTests.cs b/tests/Tests/Ingest/PutPipeline/PutPipelineApiTests.cs index 3d064a9d4e..53c06a61ed 100644 --- a/tests/Tests/Ingest/PutPipeline/PutPipelineApiTests.cs +++ b/tests/Tests/Ingest/PutPipeline/PutPipelineApiTests.cs @@ -55,7 +55,7 @@ public PutPipelineApiTests(WritableCluster cluster, EndpointUsage usage) : base( processors = ProcessorAssertions.AllAsJson }; -protected override int ExpectStatusCode => 200; + protected override int ExpectStatusCode => 200; protected override Func<PutPipelineDescriptor, IPutPipelineRequest> Fluent => d => d .Description("My test pipeline") diff --git a/tests/Tests/QueryDsl/QueryDslUsageTestsBase.cs b/tests/Tests/QueryDsl/QueryDslUsageTestsBase.cs index eaeaff5800..b2cd56a7d6 100644 --- a/tests/Tests/QueryDsl/QueryDslUsageTestsBase.cs +++ b/tests/Tests/QueryDsl/QueryDslUsageTestsBase.cs @@ -34,6 +34,7 @@ using FluentAssertions; using OpenSearch.Client; using Newtonsoft.Json; +using OpenSearch.OpenSearch.Ephemeral; using Tests.Core.Client; using Tests.Core.Extensions; using Tests.Core.ManagedOpenSearch.Clusters; @@ -41,105 +42,111 @@ using Tests.Framework.EndpointTests; using Tests.Framework.EndpointTests.TestState; -namespace Tests.QueryDsl +namespace Tests.QueryDsl; + +public abstract class QueryDslUsageTestsBase<TCluster, TDocument> + : ApiTestBase<TCluster, ISearchResponse<TDocument>, ISearchRequest, SearchDescriptor<TDocument>, SearchRequest<TDocument>> + where TCluster : IEphemeralCluster<EphemeralClusterConfiguration>, IOpenSearchClientTestCluster, new() + where TDocument : class { - public abstract class QueryDslUsageTestsBase - : ApiTestBase<ReadOnlyCluster, ISearchResponse<Project>, ISearchRequest, SearchDescriptor<Project>, SearchRequest<Project>> - { - protected readonly QueryContainer ConditionlessQuery = new QueryContainer(new TermQuery()); + protected QueryDslUsageTestsBase(TCluster cluster, EndpointUsage usage) : base(cluster, usage) { } + + protected abstract IndexName IndexName { get; } + protected abstract string ExpectedIndexString { get; } - protected readonly QueryContainer VerbatimQuery = new QueryContainer(new TermQuery { IsVerbatim = true }); + protected virtual ConditionlessWhen ConditionlessWhen => null; - protected byte[] ShortFormQuery => Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(new { description = "project description" })); + protected override object ExpectJson => new { query = QueryJson }; - protected QueryDslUsageTestsBase(ReadOnlyCluster cluster, EndpointUsage usage) : base(cluster, usage) { } + protected override Func<SearchDescriptor<TDocument>, ISearchRequest> Fluent => s => s + .Index(IndexName) + .Query(QueryFluent); - protected virtual ConditionlessWhen ConditionlessWhen => null; + protected override HttpMethod HttpMethod => HttpMethod.POST; - protected override object ExpectJson => new { query = QueryJson }; + protected override SearchRequest<TDocument> Initializer => + new(IndexName) + { + Query = QueryInitializer + }; - protected override Func<SearchDescriptor<Project>, ISearchRequest> Fluent => s => s - .Query(q => QueryFluent(q)); + protected virtual NotConditionlessWhen NotConditionlessWhen => null; - protected override HttpMethod HttpMethod => HttpMethod.POST; + protected abstract QueryContainer QueryInitializer { get; } - protected override SearchRequest<Project> Initializer => - new SearchRequest<Project> - { - Query = QueryInitializer - }; + protected abstract object QueryJson { get; } + protected override string UrlPath => $"/{ExpectedIndexString}/_search"; - protected virtual bool KnownParseException => false; + protected override LazyResponses ClientUsage() => Calls( + (client, f) => client.Search(f), + (client, f) => client.SearchAsync(f), + (client, r) => client.Search<TDocument>(r), + (client, r) => client.SearchAsync<TDocument>(r) + ); - protected virtual NotConditionlessWhen NotConditionlessWhen => null; + protected abstract QueryContainer QueryFluent(QueryContainerDescriptor<TDocument> q); - protected abstract QueryContainer QueryInitializer { get; } + [U] public void FluentIsNotConditionless() => + AssertIsNotConditionless(QueryFluent(new QueryContainerDescriptor<TDocument>())); - protected abstract object QueryJson { get; } - protected override string UrlPath => "/project/_search"; + [U] public void InitializerIsNotConditionless() => AssertIsNotConditionless(QueryInitializer); - protected override LazyResponses ClientUsage() => Calls( - (client, f) => client.Search(f), - (client, f) => client.SearchAsync(f), - (client, r) => client.Search<Project>(r), - (client, r) => client.SearchAsync<Project>(r) - ); + private void AssertIsNotConditionless(IQueryContainer c) + { + if (!c.IsVerbatim) + c.IsConditionless.Should().BeFalse(); + } - protected abstract QueryContainer QueryFluent(QueryContainerDescriptor<Project> q); + [U] public void SeenByVisitor() + { + var visitor = new DslPrettyPrintVisitor(TestClient.DefaultInMemoryClient.ConnectionSettings); + var query = QueryFluent(new QueryContainerDescriptor<TDocument>()); + query.Should().NotBeNull("query evaluated to null which implies it may be conditionless"); + query.Accept(visitor); + var pretty = visitor.PrettyPrint; + pretty.Should().NotBeNullOrWhiteSpace(); + } - [U] public void FluentIsNotConditionless() => - AssertIsNotConditionless(QueryFluent(new QueryContainerDescriptor<Project>())); + [U] public void ConditionlessWhenExpectedToBe() + { + if (ConditionlessWhen == null) return; - [U] public void InitializerIsNotConditionless() => AssertIsNotConditionless(QueryInitializer); + foreach (var when in ConditionlessWhen) + { + when(QueryFluent(new QueryContainerDescriptor<TDocument>())); + when(QueryInitializer); + } - private void AssertIsNotConditionless(IQueryContainer c) - { - if (!c.IsVerbatim) - c.IsConditionless.Should().BeFalse(); - } + ((IQueryContainer)QueryInitializer).IsConditionless.Should().BeFalse(); + } - [U] public void SeenByVisitor() - { - var visitor = new DslPrettyPrintVisitor(TestClient.DefaultInMemoryClient.ConnectionSettings); - var query = QueryFluent(new QueryContainerDescriptor<Project>()); - query.Should().NotBeNull("query evaluated to null which implies it may be conditionless"); - query.Accept(visitor); - var pretty = visitor.PrettyPrint; - pretty.Should().NotBeNullOrWhiteSpace(); - } + [U] public void NotConditionlessWhenExpectedToBe() + { + if (NotConditionlessWhen == null) return; - [U] public void ConditionlessWhenExpectedToBe() - { - if (ConditionlessWhen == null) return; + foreach (var when in NotConditionlessWhen) + { + when(QueryFluent(new QueryContainerDescriptor<TDocument>())); + when(QueryInitializer); + } + } - foreach (var when in ConditionlessWhen) - { - when(QueryFluent(new QueryContainerDescriptor<Project>())); - //this.JsonEquals(query, new { }); - when(QueryInitializer); - //this.JsonEquals(query, new { }); - } + [I] protected async Task AssertQueryResponse() => await AssertOnAllResponses(AssertQueryResponseValid); - ((IQueryContainer)QueryInitializer).IsConditionless.Should().BeFalse(); - } + protected virtual void AssertQueryResponseValid(ISearchResponse<TDocument> response) => response.ShouldBeValid(); +} + +public abstract class QueryDslUsageTestsBase + : QueryDslUsageTestsBase<ReadOnlyCluster, Project> +{ + protected static byte[] ShortFormQuery => Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(new { description = "project description" })); - [U] public void NotConditionlessWhenExpectedToBe() - { - if (NotConditionlessWhen == null) return; + protected static readonly QueryContainer ConditionlessQuery = new(new TermQuery()); - foreach (var when in NotConditionlessWhen) - { - var query = QueryFluent(new QueryContainerDescriptor<Project>()); - when(query); + protected static readonly QueryContainer VerbatimQuery = new(new TermQuery { IsVerbatim = true }); - query = QueryInitializer; - when(query); - } - } + protected QueryDslUsageTestsBase(ReadOnlyCluster cluster, EndpointUsage usage) : base(cluster, usage) { } - [I] protected async Task AssertQueryResponse() => await AssertOnAllResponses(r => - { - r.ShouldBeValid(); - }); - } + protected override IndexName IndexName => typeof(Project); + protected override string ExpectedIndexString => "project"; } diff --git a/tests/Tests/QueryDsl/Specialized/Knn/KnnQueryUsageTests.cs b/tests/Tests/QueryDsl/Specialized/Knn/KnnQueryUsageTests.cs index 8150004a20..98d20ac6e4 100644 --- a/tests/Tests/QueryDsl/Specialized/Knn/KnnQueryUsageTests.cs +++ b/tests/Tests/QueryDsl/Specialized/Knn/KnnQueryUsageTests.cs @@ -6,7 +6,6 @@ */ using System; -using System.Linq; using System.Threading.Tasks; using FluentAssertions; using OpenSearch.Client; diff --git a/tests/Tests/QueryDsl/Specialized/Neural/NeuralQueryUsageTests.cs b/tests/Tests/QueryDsl/Specialized/Neural/NeuralQueryUsageTests.cs new file mode 100644 index 0000000000..28745bc74c --- /dev/null +++ b/tests/Tests/QueryDsl/Specialized/Neural/NeuralQueryUsageTests.cs @@ -0,0 +1,295 @@ +/* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +using System; +using System.Linq; +using System.Threading; +using FluentAssertions; +using OpenSearch.Client; +using OpenSearch.Net; +using OpenSearch.OpenSearch.Xunit.XunitPlumbing; +using OpenSearch.Stack.ArtifactsApi.Products; +using Tests.Core.Extensions; +using Tests.Core.ManagedOpenSearch.Clusters; +using Tests.Framework.EndpointTests.TestState; +using Version = SemanticVersioning.Version; + +namespace Tests.QueryDsl.Specialized.Neural; + +public class NeuralQueryCluster : ClientTestClusterBase +{ + public NeuralQueryCluster() : base(CreateConfiguration()) { } + + private static ClientTestClusterConfiguration CreateConfiguration() + { + var config = new ClientTestClusterConfiguration( + OpenSearchPlugin.Knn, + OpenSearchPlugin.MachineLearning, + OpenSearchPlugin.NeuralSearch, + OpenSearchPlugin.Security) + { + MaxConcurrency = 4, + ValidatePluginsToInstall = false, + }; + + config.DefaultNodeSettings.Add("plugins.ml_commons.only_run_on_ml_node", "false"); + config.DefaultNodeSettings.Add("plugins.ml_commons.native_memory_threshold", "99"); + config.DefaultNodeSettings.Add("plugins.ml_commons.model_access_control_enabled", "true", ">=2.8.0"); + + return config; + } +} + +public class NeuralSearchDoc +{ + [PropertyName("id")] public string Id { get; set; } + [PropertyName("text")] public string Text { get; set; } + [PropertyName("passage_embedding")] public float[] PassageEmbedding { get; set; } +} + +[SkipVersion("<2.6.0", "Avoid the various early permutations of the ML APIs")] +public class NeuralQueryUsageTests + : QueryDslUsageTestsBase<NeuralQueryCluster, NeuralSearchDoc> +{ + private static readonly string TestName = nameof(NeuralQueryUsageTests).ToLowerInvariant(); + + private string _modelGroupId; + private string _modelId = "default-for-unit-tests"; + + public NeuralQueryUsageTests(NeuralQueryCluster cluster, EndpointUsage usage) : base(cluster, usage) { } + + protected override IndexName IndexName => TestName; + protected override string ExpectedIndexString => TestName; + + protected override QueryContainer QueryInitializer => new NeuralQuery + { + Field = Infer.Field<NeuralSearchDoc>(d => d.PassageEmbedding), + QueryText = "wild west", + K = 5, + ModelId = _modelId + }; + + protected override object QueryJson => new + { + neural = new + { + passage_embedding = new + { + query_text = "wild west", + k = 5, + model_id = _modelId + } + } + }; + + protected override QueryContainer QueryFluent(QueryContainerDescriptor<NeuralSearchDoc> q) => q + .Neural(n => n + .Field(f => f.PassageEmbedding) + .QueryText("wild west") + .K(5) + .ModelId(_modelId)); + + protected override ConditionlessWhen ConditionlessWhen => new ConditionlessWhen<INeuralQuery>(a => a.Neural) + { + q => + { + q.Field = null; + q.QueryText = "wild west"; + q.K = 5; + q.ModelId = "aFcV879"; + }, + q => + { + q.Field = "passage_embedding"; + q.QueryText = null; + q.K = 5; + q.ModelId = "aFcV879"; + }, + q => + { + q.Field = "passage_embedding"; + q.QueryText = ""; + q.K = 5; + q.ModelId = "aFcV879"; + }, + q => + { + q.Field = "passage_embedding"; + q.QueryText = "wild west"; + q.K = null; + q.ModelId = "aFcV879"; + }, + q => + { + q.Field = "passage_embedding"; + q.QueryText = "wild west"; + q.K = 0; + q.ModelId = "aFcV879"; + }, + q => + { + q.Field = "passage_embedding"; + q.QueryText = "wild west"; + q.K = 5; + q.ModelId = null; + }, + q => + { + q.Field = "passage_embedding"; + q.QueryText = "wild west"; + q.K = 5; + q.ModelId = ""; + } + }; + + protected override void IntegrationSetup(IOpenSearchClient client, CallUniqueValues values) + { + var baseVersion = Cluster.ClusterConfiguration.Version.BaseVersion(); + var renamedToRegisterDeploy = baseVersion >= new Version("2.7.0"); + var hasModelAccessControl = baseVersion >= new Version("2.8.0"); + + if (hasModelAccessControl) + { + var registerModelGroupResp = client.Http.Post<DynamicResponse>( + "/_plugins/_ml/model_groups/_register", + r => r.SerializableBody(new + { + name = TestName, + access_mode = "public", + model_access_mode = "public" + })); + registerModelGroupResp.ShouldBeCreated(); + _modelGroupId = (string)registerModelGroupResp.Body.model_group_id; + } + + var registerModelResp = client.Http.Post<DynamicResponse>( + $"/_plugins/_ml/models/{(renamedToRegisterDeploy ? "_register" : "_upload")}", + r => r.SerializableBody(new + { + name = "huggingface/sentence-transformers/msmarco-distilbert-base-tas-b", + version = "1.0.1", + model_group_id = _modelGroupId, + model_format = "TORCH_SCRIPT" + })); + registerModelResp.ShouldBeCreated(); + var modelRegistrationTaskId = (string) registerModelResp.Body.task_id; + + while (true) + { + var getTaskResp = client.Http.Get<DynamicResponse>($"/_plugins/_ml/tasks/{modelRegistrationTaskId}"); + getTaskResp.ShouldNotBeFailed(); + if (((string)getTaskResp.Body.state).StartsWith("COMPLETED")) + { + _modelId = getTaskResp.Body.model_id; + break; + } + Thread.Sleep(5000); + } + + var deployModelResp = client.Http.Post<DynamicResponse>($"/_plugins/_ml/models/{_modelId}/{(renamedToRegisterDeploy ? "_deploy" : "_load")}"); + deployModelResp.ShouldBeCreated(); + var modelDeployTaskId = (string) deployModelResp.Body.task_id; + + while (true) + { + var getTaskResp = client.Http.Get<DynamicResponse>($"/_plugins/_ml/tasks/{modelDeployTaskId}"); + getTaskResp.ShouldNotBeFailed(); + if (((string)getTaskResp.Body.state).StartsWith("COMPLETED")) break; + Thread.Sleep(5000); + } + + var putIngestPipelineResp = client.Ingest.PutPipeline(TestName, p => p + .Processors(pp => pp + .TextEmbedding<NeuralSearchDoc>(te => te + .ModelId(_modelId) + .FieldMap(fm => fm + .Map(d => d.Text, d => d.PassageEmbedding))))); + putIngestPipelineResp.ShouldBeValid(); + + var createIndexResp = client.Indices.Create( + IndexName, + i => i + .Settings(s => s + .Setting("index.knn", true) + .DefaultPipeline(TestName)) + .Map<NeuralSearchDoc>(m => m + .Properties(p => p + .Text(t => t.Name(d => d.Id)) + .Text(t => t.Name(d => d.Text)) + .KnnVector(k => k + .Name(d => d.PassageEmbedding) + .Dimension(768) + .Method(km => km + .Engine("lucene") + .SpaceType("l2") + .Name("hnsw")))))); + createIndexResp.ShouldBeValid(); + + var documents = new NeuralSearchDoc[] + { + new() { Id = "4319130149.jpg", Text = "A West Virginia university women 's basketball team , officials , and a small gathering of fans are in a West Virginia arena ." }, + new() { Id = "1775029934.jpg", Text = "A wild animal races across an uncut field with a minimal amount of trees ." }, + new() { Id = "2664027527.jpg", Text = "People line the stands which advertise Freemont 's orthopedics , a cowboy rides a light brown bucking bronco ." }, + new() { Id = "4427058951.jpg", Text = "A man who is riding a wild horse in the rodeo is very near to falling off ." }, + new() { Id = "2691147709.jpg", Text = "A rodeo cowboy , wearing a cowboy hat , is being thrown off of a wild white horse ." } + }; + var bulkResp = client.Bulk(b => b + .Index(IndexName) + .IndexMany(documents) + .Refresh(Refresh.WaitFor)); + bulkResp.ShouldBeValid(); + } + + protected override void AssertQueryResponseValid(ISearchResponse<NeuralSearchDoc> response) + { + base.AssertQueryResponseValid(response); + + response.Hits.Should().HaveCount(5); + var hit = response.Hits.First(); + + hit.Id.Should().Be("4427058951.jpg"); + hit.Score.Should().BeApproximately(0.01585195, 0.00000001); + hit.Source.Text.Should().Be("A man who is riding a wild horse in the rodeo is very near to falling off ."); + hit.Source.PassageEmbedding.Should().HaveCount(768); + } + + protected override void IntegrationTeardown(IOpenSearchClient client, CallUniqueValues values) + { + client.Indices.Delete(IndexName); + client.Ingest.DeletePipeline(TestName); + + if (_modelId != "default-for-unit-tests") + { + while (true) + { + var deleteModelResp = client.Http.Delete<DynamicResponse>($"/_plugins/_ml/models/{_modelId}"); + if (deleteModelResp.Success || !(((string)deleteModelResp.Body.error?.reason)?.Contains("Try undeploy") ?? false)) break; + + client.Http.Post<DynamicResponse>($"/_plugins/_ml/models/{_modelId}/_undeploy"); + Thread.Sleep(5000); + } + } + + if (_modelGroupId != null) + { + client.Http.Delete<DynamicResponse>($"/_plugins/_ml/model_groups/{_modelGroupId}"); + } + } +} + +internal static class Helpers +{ + public static void ShouldBeCreated(this DynamicResponse r) + { + if (!r.Success || r.Body.status != "CREATED") throw new Exception("Expected to be created, was: " + r.DebugInformation); + } + + public static void ShouldNotBeFailed(this DynamicResponse r) + { + if (!r.Success || r.Body.state == "FAILED") throw new Exception("Expected to not be failed, was: " + r.DebugInformation); + } +}