diff --git a/OnnxSharp.sln b/OnnxSharp.sln index 86157b0..4062af0 100644 --- a/OnnxSharp.sln +++ b/OnnxSharp.sln @@ -22,6 +22,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "root", "root", "{604A7FA2-1 rename.ps1 = rename.ps1 restore.ps1 = restore.ps1 test.ps1 = test.ps1 + update-tool-from-build.ps1 = update-tool-from-build.ps1 update.ps1 = update.ps1 EndProjectSection EndProject @@ -38,6 +39,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OnnxSharp.Test", "src\OnnxS EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "dotnet-onnx", "src\dotnet-onnx\dotnet-onnx.csproj", "{DA6F8267-24F1-4104-AB36-10C97B899A8E}" EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "github-workflows", "github-workflows", "{1EA6AA06-C436-482C-A14C-1AC414D5552F}" + ProjectSection(SolutionItems) = preProject + .github\workflows\dotnet.yml = .github\workflows\dotnet.yml + EndProjectSection +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU diff --git a/README.md b/README.md index 774d886..1ee7e11 100644 --- a/README.md +++ b/README.md @@ -1,32 +1,64 @@ ![Build and test](https://github.com/nietras/OnnxSharp/workflows/.NET/badge.svg) -[![NuGet](https://img.shields.io/nuget/v/OnnxSharp)](https://www.nuget.org/packages/OnnxSharp/) -[![Downloads](https://img.shields.io/nuget/dt/OnnxSharp)](https://www.nuget.org/packages/OnnxSharp/) [![Stars](https://img.shields.io/github/stars/nietras/OnnxSharp)](https://github.com/nietras/OnnxSharp/stargazers) [![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE.md) -# OnnxSharp -ONNX format parsing and manipulation in C#. +|What |Links and Status| +|---------------|------| +|`OnnxSharp` |[![NuGet](https://img.shields.io/nuget/v/OnnxSharp)](https://www.nuget.org/packages/OnnxSharp/) [![Downloads](https://img.shields.io/nuget/dt/OnnxSharp)](https://www.nuget.org/packages/OnnxSharp/) | +|`dotnet-onnx`|[![NuGet](https://img.shields.io/nuget/v/dotnet-onnx)](https://www.nuget.org/packages/dotnet-onnx/) [![Downloads](https://img.shields.io/nuget/dt/dotnet-onnx)](https://www.nuget.org/packages/dotnet-onnx/) | -# Status -Pretty much just: +# `OnnxSharp` library and `dotnet-onnx` tool +ONNX format parsing and manipulation in C# and with command line .NET tool. + +# Quick Guide +Install latest version of .NET: +* PowerShell (Windows): [https://dot.net/v1/dotnet-install.ps1](https://dot.net/v1/dotnet-install.ps1) +* Bash (Linux/macOS): [https://dot.net/v1/dotnet-install.sh](https://dot.net/v1/dotnet-install.sh) + +#### Code +|What |How | +|--------------|---------------------------------------------------| +|Install |`dotnet add PROJECT.csproj package OnnxSharp`| +|Parse |`var model = ModelProto.Parser.ParseFromFile("mnist-8.onnx");`| +|Info |`var info = model.Graph.Info();`| +|Clean |`model.Graph.Clean();`| +|SetDim |`model.Graph.SetDim();`| +|Write |`model.WriteToFile("mnist-8-clean-dynamic.onnx");`| + +#### Tool +|What |How | +|--------------|----------------------------| +|Install |`dotnet tool install dotnet-onnx -g`| +|Info |`dotnet onnx info mnist-8.onnx`| +|Info |`dotnet onnx info mnist-8.onnx`| +|Clean |`dotnet onnx clean mnist-8.onnx mnist-8-clean.onnx`| +|SetDim |`dotnet onnx setdim mnist-8.onnx mnist-8-setdim.onnx`| + +# Source Code +Base functionality is based on: ``` .\protoc.exe .\onnx.proto3 --csharp_out=OnnxSharp ``` +Everything else written in beautiful C# 9.0 as extensions to this. # Example Code ```csharp -using System.IO; -using Google.Protobuf; +using Onnx; // Examples see https://github.com/onnx/models -var onnxFilePath = @"mnist-8.onnx"; +var onnxInputFilePath = @"mnist-8.onnx"; -using var fileStream = File.OpenRead(onnxFilePath); - -var model = Onnx.ModelProto.Parser.ParseFrom(fileStream); +var model = ModelProto.Parser.ParseFromFile(onnxInputFilePath); var graph = model.Graph; -var inputs = graph.Input; -var valueInfos = graph.ValueInfo; -var outputs = graph.Output; +// Clean graph e.g. remove initializers from inputs that may prevent constant folding +graph.Clean(); +// Set dimension in graph to enable dynamic batch size during inference +graph.SetDim(dimIndex: 0, DimParamOrValue.New("N")); +// Get summarized info about the graph +var info = graph.Info(); + +System.Console.WriteLine(info); + +model.WriteToFile(@"mnist-8-clean-dynamic-batch-size.onnx"); ``` \ No newline at end of file diff --git a/src/Directory.Build.props b/src/Directory.Build.props index 6e5e320..c01411a 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -4,8 +4,8 @@ nietras Copyright © nietras 2021 en - 0.1.0.0 - 0.1.1 + 0.2.0.0 + 0.2.0 $(FileVersion) $(InformationalVersion) diff --git a/src/OnnxSharp.Test/AssemblyResourceLoader.cs b/src/OnnxSharp.Test/AssemblyResourceLoader.cs index 6bb431a..b52ffed 100644 --- a/src/OnnxSharp.Test/AssemblyResourceLoader.cs +++ b/src/OnnxSharp.Test/AssemblyResourceLoader.cs @@ -7,14 +7,23 @@ namespace OnnxSharp.Test { public static class AssemblyResourceLoader { - public static readonly string ResourceNamespace = typeof(AssemblyResourceLoader).Assembly.GetName().Name; + public static readonly string ResourceNamespace = + typeof(AssemblyResourceLoader).Assembly.GetName().Name; public const string ResourceNamePrefix = ""; - public static string[] GetStringArray(string resourceName) + public static byte[] GetBytes(string resourceName) { - return GetString(resourceName).Split(new[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries); + using (var stream = GetStream(resourceName)) + using (var memoryStream = new MemoryStream()) + { + stream.CopyTo(memoryStream); + return memoryStream.ToArray(); + } } + public static string[] GetLines(string resourceName) => GetString(resourceName) + .Split(new[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries); + public static string GetString(string resourceName) { using (var stream = GetStream(resourceName)) @@ -24,10 +33,8 @@ public static string GetString(string resourceName) } } - public static string GetFullResourceName(string resourceName) - { - return ResourceNamePrefix + resourceName; - } + public static string GetFullResourceName(string resourceName) => + ResourceNamePrefix + resourceName; public static string FindResourceName(Func filter) { @@ -35,14 +42,18 @@ public static string FindResourceName(Func filter) if (names.Length == 0) { - throw new ArgumentException("Could not find any resource. The desired file might not have been defined as Embedded Resource."); + throw new ArgumentException("Could not find any resource. " + + "The desired file might not have been defined as Embedded Resource."); } else if (names.Length != 1) { - throw new ArgumentException($"Ambiguous name, cannot identify resource - found {names.Length} possible candidates."); + throw new ArgumentException($"Ambiguous name, cannot identify resource - " + + $"found {names.Length} possible candidates."); + } + else + { + return names[0]; } - - return names.Single(); } public static string[] FindResourceNames(Func filter) @@ -66,7 +77,8 @@ public static Stream GetStream(string resourceName) var stream = Assembly.GetExecutingAssembly().GetManifestResourceStream(fullResourceName); if (stream == null) { - throw new ArgumentException($"Could not find resource '{resourceName}'. The desired file might not have been defined as Embedded Resource."); + throw new ArgumentException($"Could not find resource '{resourceName}'. " + + $"The desired file might not have been defined as Embedded Resource."); } return stream; } diff --git a/src/OnnxSharp.Test/GraphExtensionsTest.cs b/src/OnnxSharp.Test/GraphExtensionsTest.cs new file mode 100644 index 0000000..17dd639 --- /dev/null +++ b/src/OnnxSharp.Test/GraphExtensionsTest.cs @@ -0,0 +1,171 @@ +using System; +using System.IO; +using Google.Protobuf; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Onnx; + +namespace OnnxSharp.Test +{ + [TestClass] + public class GraphExtensionsTest + { + readonly Func m_createStream = () => AssemblyResourceLoader.GetStream("mnist-8.onnx"); + + [TestMethod] + public void ParseFrom() + { + // Act + var model = ModelProto.Parser.ParseFrom(m_createStream); + + // Assert + var graph = model.Graph; + // 9 inputs since includes initializers + Assert.AreEqual(9, graph.Input.Count); + Assert.AreEqual(1, graph.Output.Count); + } + + [TestMethod] + public void Info() + { + // Arrange + var model = ModelProto.Parser.ParseFrom(m_createStream); + + // Act + var actual = model.Graph.Info(); + + // Assert + var expected = ExpectedInfo; + Assert.AreEqual(expected, actual); + } + + [TestMethod] + public void Clean() + { + // Arrange + var model = ModelProto.Parser.ParseFrom(m_createStream); + + // Act + model.Graph.Clean(); + + // Assert + var graph = model.Graph; + Assert.AreEqual(1, graph.Input.Count); + Assert.AreEqual(1, graph.Output.Count); + var expectedName = $"mnist-8-expected-{nameof(Clean)}.onnx"; + AssertModelBytesEqualToEmbeddedExpected(model, expectedName); + } + + [TestMethod] + public void RemoveInitializersFromInputs() + { + // Arrange + var model = ModelProto.Parser.ParseFrom(m_createStream); + + // Act + model.Graph.RemoveInitializersFromInputs(); + + // Assert + var graph = model.Graph; + Assert.AreEqual(1, graph.Input.Count); + Assert.AreEqual(1, graph.Output.Count); + var expectedName = $"mnist-8-expected-{nameof(RemoveInitializersFromInputs)}.onnx"; + AssertModelBytesEqualToEmbeddedExpected(model, expectedName); + } + + [TestMethod] + public void RemoveUnnecessaryInitializerReshapes() + { + // Arrange + var model = ModelProto.Parser.ParseFrom(m_createStream); + + // Act + model.Graph.RemoveUnnecessaryInitializerReshapes(); + + // Assert + var graph = model.Graph; + Assert.AreEqual(8, graph.Input.Count); + Assert.AreEqual(1, graph.Output.Count); + var expectedName = $"mnist-8-expected-{nameof(RemoveUnnecessaryInitializerReshapes)}.onnx"; + AssertModelBytesEqualToEmbeddedExpected(model, expectedName); + } + + [TestMethod] + public void SetDim() + { + // Arrange + var model = ModelProto.Parser.ParseFrom(m_createStream); + + // Act + model.Graph.SetDim(dimIndex: 0, DimParamOrValue.New("N")); + + // Assert + var graph = model.Graph; + Assert.AreEqual(9, graph.Input.Count); + Assert.AreEqual(1, graph.Output.Count); + var expectedName = $"mnist-8-expected-{nameof(SetDim)}.onnx"; + AssertModelBytesEqualToEmbeddedExpected(model, expectedName); + } + + static void AssertModelBytesEqualToEmbeddedExpected(ModelProto model, string expectedName) + { + var actualBytes = model.ToByteArray(); + //model.WriteToFile(expectedName); + var expectedBytes = AssemblyResourceLoader.GetBytes(expectedName); + CollectionAssert.AreEqual(expectedBytes, actualBytes); + } + + const string ExpectedInfo = @"## Inputs without Initializer +### Tensors +|Name |Type |ElemType|Shape |SizeInFile| +|:-----|:---------|:-------|--------:|---------:| +|Input3|TensorType|Float |1x1x28x28| 32| + +## Outputs +### Tensors +|Name |Type |ElemType|Shape|SizeInFile| +|:---------------|:---------|:-------|----:|---------:| +|Plus214_Output_0|TensorType|Float | 1x10| 34| + +## Inputs with Initializer +### Tensors +|Name |Type |ElemType|Shape |SizeInFile| +|:---------------------------------|:---------|:-------|--------:|---------:| +|Parameter5 |TensorType|Float | 8x1x5x5| 36| +|Parameter6 |TensorType|Float | 8x1x1| 32| +|Parameter87 |TensorType|Float | 16x8x5x5| 37| +|Parameter88 |TensorType|Float | 16x1x1| 33| +|Pooling160_Output_0_reshape0_shape|TensorType|Int64 | 2| 48| +|Parameter193 |TensorType|Float |16x4x4x10| 38| +|Parameter193_reshape1_shape |TensorType|Int64 | 2| 41| +|Parameter194 |TensorType|Float | 1x10| 30| + +## Initializers (Parameters etc.) +|Name |DataType|Dims |Π(Dims)|[v0,v1..vN] | (Min,Mean,Max) |SizeInFile| +|:---------------------------------|:-------|--------:|------:|-----------------------------------:|---------:| +|Parameter193 |Float |16x4x4x10| 2560|(-7.595E-001,-1.779E-003,1.186E+000)| 10265| +|Parameter87 |Float | 16x8x5x5| 3200|(-5.089E-001,-3.028E-002,5.647E-001)| 12824| +|Parameter5 |Float | 8x1x5x5| 200|(-9.727E-001,-7.360E-003,1.019E+000)| 823| +|Parameter6 |Float | 8x1x1| 8|(-4.338E-001,-1.023E-001,9.164E-002)| 53| +|Parameter88 |Float | 16x1x1| 16|(-4.147E-001,-1.554E-001,1.328E-002)| 86| +|Pooling160_Output_0_reshape0_shape|Int64 | 2| 2| [1,256]| 46| +|Parameter193_reshape1_shape |Int64 | 2| 2| [256,10]| 39| +|Parameter194 |Float | 1x10| 10|(-1.264E-001,-4.777E-006,1.402E-001)| 62| + +## Value Infos +### Tensors +|Name |Type |ElemType|Shape |SizeInFile| +|:---------------------------|:---------|:-------|---------:|---------:| +|Parameter193_reshape1 |TensorType|Float | 256x10| 40| +|Convolution28_Output_0 |TensorType|Float | 1x8x28x28| 48| +|Plus30_Output_0 |TensorType|Float | 1x8x28x28| 41| +|ReLU32_Output_0 |TensorType|Float | 1x8x28x28| 41| +|Pooling66_Output_0 |TensorType|Float | 1x8x14x14| 44| +|Convolution110_Output_0 |TensorType|Float |1x16x14x14| 49| +|Plus112_Output_0 |TensorType|Float |1x16x14x14| 42| +|ReLU114_Output_0 |TensorType|Float |1x16x14x14| 42| +|Pooling160_Output_0 |TensorType|Float | 1x16x4x4| 45| +|Pooling160_Output_0_reshape0|TensorType|Float | 1x256| 47| +|Times212_Output_0 |TensorType|Float | 1x10| 35| +"; + } +} diff --git a/src/OnnxSharp.Test/MnistTest.cs b/src/OnnxSharp.Test/MnistTest.cs deleted file mode 100644 index 818b551..0000000 --- a/src/OnnxSharp.Test/MnistTest.cs +++ /dev/null @@ -1,41 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using Onnx; - -namespace OnnxSharp.Test -{ - [TestClass] - public class MnistTest - { - [TestMethod] - public void ParseFrom() - { - // Arrange - using var stream = AssemblyResourceLoader.GetStream("mnist-8.onnx"); - - // Act - var model = ModelProto.Parser.ParseFrom(stream); - - // Assert - var graph = model.Graph; - // 9 inputs since includes initializers - Assert.AreEqual(9, graph.Input.Count); - Assert.AreEqual(1, graph.Output.Count); - } - - [TestMethod] - public void RemoveInitializersFromInputs() - { - // Arrange - using var stream = AssemblyResourceLoader.GetStream("mnist-8.onnx"); - var model = ModelProto.Parser.ParseFrom(stream); - - // Act - model.RemoveInitializersFromInputs(); - - // Assert - var graph = model.Graph; - Assert.AreEqual(1, graph.Input.Count); - Assert.AreEqual(1, graph.Output.Count); - } - } -} diff --git a/src/OnnxSharp.Test/ModelProtoTestExtensions.cs b/src/OnnxSharp.Test/ModelProtoTestExtensions.cs new file mode 100644 index 0000000..8a54169 --- /dev/null +++ b/src/OnnxSharp.Test/ModelProtoTestExtensions.cs @@ -0,0 +1,25 @@ +using System.IO; +using System.Text.Json; +using Google.Protobuf; +using Onnx; + +namespace OnnxSharp.Test +{ + public static class ModelProtoTestExtensions + { + public static void WriteIndentedJsonToFile(this ModelProto model, string filePath) + { + var jsonText = JsonFormatter.Default.Format(model); + var jsonElement = JsonSerializer.Deserialize(jsonText); + + var options = new JsonSerializerOptions() { WriteIndented = true }; + + var jsonTextPretty = JsonSerializer.Serialize(jsonElement, options); + File.WriteAllText(filePath, jsonTextPretty); + // Below does not indent + //using var stream = File.Open(filePath, FileMode.Create); + //using var writer = new Utf8JsonWriter(stream); + //JsonSerializer.Serialize(writer, jsonElement, options); + } + } +} diff --git a/src/OnnxSharp.Test/OnnxSharp.Test.csproj b/src/OnnxSharp.Test/OnnxSharp.Test.csproj index 2f5cac4..94a953e 100644 --- a/src/OnnxSharp.Test/OnnxSharp.Test.csproj +++ b/src/OnnxSharp.Test/OnnxSharp.Test.csproj @@ -7,12 +7,12 @@ - net461;net50 + net461;net5.0 - + @@ -25,6 +25,10 @@ + + + + diff --git a/src/OnnxSharp.Test/mnist-8-expected-Clean.onnx b/src/OnnxSharp.Test/mnist-8-expected-Clean.onnx new file mode 100644 index 0000000..d708e07 Binary files /dev/null and b/src/OnnxSharp.Test/mnist-8-expected-Clean.onnx differ diff --git a/src/OnnxSharp.Test/mnist-8-expected-RemoveInitializersFromInputs.onnx b/src/OnnxSharp.Test/mnist-8-expected-RemoveInitializersFromInputs.onnx new file mode 100644 index 0000000..0312478 Binary files /dev/null and b/src/OnnxSharp.Test/mnist-8-expected-RemoveInitializersFromInputs.onnx differ diff --git a/src/OnnxSharp.Test/mnist-8-expected-RemoveUnnecessaryInitializerReshapes.onnx b/src/OnnxSharp.Test/mnist-8-expected-RemoveUnnecessaryInitializerReshapes.onnx new file mode 100644 index 0000000..cd01ee7 Binary files /dev/null and b/src/OnnxSharp.Test/mnist-8-expected-RemoveUnnecessaryInitializerReshapes.onnx differ diff --git a/src/OnnxSharp.Test/mnist-8-expected-SetDim.onnx b/src/OnnxSharp.Test/mnist-8-expected-SetDim.onnx new file mode 100644 index 0000000..5d5e8fb Binary files /dev/null and b/src/OnnxSharp.Test/mnist-8-expected-SetDim.onnx differ diff --git a/src/OnnxSharp/Collections/ListExtensions.cs b/src/OnnxSharp/Collections/ListExtensions.cs new file mode 100644 index 0000000..5fe0d6f --- /dev/null +++ b/src/OnnxSharp/Collections/ListExtensions.cs @@ -0,0 +1,51 @@ +using System; +using System.Collections.Generic; + +namespace Onnx.Collections +{ + /// Convenience extension methods for . + internal static class ListExtensions + { + internal static bool TryRemove(this IList fields, Func select, Predicate predicate) + { + for (int i = 0; i < fields.Count; i++) + { + var field = fields[i]; + var value = select(field); + if (predicate(value)) + { + fields.RemoveAt(i); + return true; + } + } + return false; + } + + internal static bool TryRemove(this IList fields, Func select, TSelect valueToRemove) + where TSelect : IEquatable + { + var index = fields.IndexOf(select, valueToRemove); + if (index >= 0) + { + fields.RemoveAt(index); + return true; + } + return false; + } + + internal static int IndexOf(this IList fields, Func select, TSelect valueToFind) + where TSelect : IEquatable + { + for (int i = 0; i < fields.Count; i++) + { + var field = fields[i]; + var value = select(field); + if (value.Equals(valueToFind)) + { + return i; + } + } + return -1; + } + } +} diff --git a/src/OnnxSharp/Collections/ReadOnlyListExtensions.cs b/src/OnnxSharp/Collections/ReadOnlyListExtensions.cs new file mode 100644 index 0000000..b4f8612 --- /dev/null +++ b/src/OnnxSharp/Collections/ReadOnlyListExtensions.cs @@ -0,0 +1,35 @@ +using System; +using System.Collections.Generic; + +namespace Onnx.Collections +{ + /// Convenience extension methods for . + public static class ReadOnlyListExtensions + { + /// Compute the product of all values. + public static long Product(this IReadOnlyList values) + { + var product = 1L; + for (int i = 0; i < values.Count; i++) + { + product *= values[i]; + } + return product; + } + + internal static T Single(this IReadOnlyList fields, Func select, TSelect valueToFind) + where TSelect : IEquatable + { + for (int i = 0; i < fields.Count; i++) + { + var field = fields[i]; + var value = select(field); + if (value.Equals(valueToFind)) + { + return field; + } + } + throw new ArgumentException($"Could not find field with value '{valueToFind}'"); + } + } +} diff --git a/src/OnnxSharp/DimParamOrValue.cs b/src/OnnxSharp/DimParamOrValue.cs new file mode 100644 index 0000000..64b4569 --- /dev/null +++ b/src/OnnxSharp/DimParamOrValue.cs @@ -0,0 +1,66 @@ +using System; + +namespace Onnx +{ + /// + /// Dimension represented either a string 'Param' or an integer 'Value'. + /// + public readonly struct DimParamOrValue + { + readonly string _param; + readonly int _value; + + private DimParamOrValue(string param, int value) + { + _param = param; + _value = value; + } + + /// Create a new named dimension parameter. + public static DimParamOrValue New(string param) + { + if (!IsParamValid(param)) + { + throw new ArgumentException($"{nameof(param)} '{param}' must be a non-whitespace string like 'N'."); + } + return new DimParamOrValue(param, default); + } + + /// Create a new fixed size dimension. + public static DimParamOrValue New(int value) => + new DimParamOrValue(default, value); + + /// Get dimension as a named parameter string. + public string Param { get { CheckIsParam(); return _param; } } + /// Get dimension as an integer value. + public int Value { get { CheckIsValue(); return _value; } } + + /// Is the dimension a named parameter. + public bool IsParam => IsParamValid(_param); + /// Is the dimension a fixed sized integer. + public bool IsValue => !IsParam; + + /// Converts the dimension to its equivalent string representation. + public override string ToString() => IsParam ? Param : Value.ToString(); + /// Returns the hash code for this instance. + public override int GetHashCode() => IsParam ? Param.GetHashCode() : Value.GetHashCode(); + + void CheckIsParam() + { + if (IsValue) + { + throw new ArgumentException($"{nameof(DimParamOrValue)} is a value '{_value}' not a param."); + } + } + + void CheckIsValue() + { + if (IsParam) + { + throw new ArgumentException($"{nameof(DimParamOrValue)} is a param '{_param}' not a value."); + } + } + + static bool IsParamValid(string param) => !string.IsNullOrWhiteSpace(param); + } +} diff --git a/src/OnnxSharp/Formatting/Align.cs b/src/OnnxSharp/Formatting/Align.cs new file mode 100644 index 0000000..945be16 --- /dev/null +++ b/src/OnnxSharp/Formatting/Align.cs @@ -0,0 +1,8 @@ +namespace Onnx.Formatting +{ + internal enum Align + { + Left, + Right, + } +} \ No newline at end of file diff --git a/src/OnnxSharp/Formatting/ColumnSpec.cs b/src/OnnxSharp/Formatting/ColumnSpec.cs new file mode 100644 index 0000000..2d91ee4 --- /dev/null +++ b/src/OnnxSharp/Formatting/ColumnSpec.cs @@ -0,0 +1,14 @@ +using System; + +namespace Onnx.Formatting +{ + internal record ColumnSpec(string Name, Align Align); + internal record ColumnSpec(string Name, Align Align, Func Get) : ColumnSpec(Name, Align); +} + +// https://stackoverflow.com/questions/64749385/predefined-type-system-runtime-compilerservices-isexternalinit-is-not-defined +namespace System.Runtime.CompilerServices +{ + internal static class IsExternalInit { } +} + diff --git a/src/OnnxSharp/Formatting/ColumnSpecs.cs b/src/OnnxSharp/Formatting/ColumnSpecs.cs new file mode 100644 index 0000000..4f9dc5a --- /dev/null +++ b/src/OnnxSharp/Formatting/ColumnSpecs.cs @@ -0,0 +1,177 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using Google.Protobuf; +using Onnx.Collections; + +namespace Onnx.Formatting +{ + internal static partial class ColumnSpecs + { + internal static partial class ValueInfo + { + internal static readonly IReadOnlyList> Tensor = + new ColumnSpec[] + { + new ("Name", Align.Left, i => i.Name), + new ("Type", Align.Left, i => i.Type.ValueCase.ToString()), + new ("ElemType", Align.Left, i => i.Type.TensorType.ElemType().ToString()), + new ("Shape", Align.Right, i => FormatShape(i.Type.TensorType.Shape)), + new ("SizeInFile", Align.Right, i => i.CalculateSize().ToString()), + }; + + internal static readonly IReadOnlyList> Sequence = + new ColumnSpec[] + { + new ("Name", Align.Left, i => i.Name), + new ("Type", Align.Left, i => i.Type.ValueCase.ToString()), + new ("ElemType", Align.Left, i => i.Type.SequenceType.ElemType.ValueCase.ToString()), + new ("SizeInFile", Align.Left, i => i.CalculateSize().ToString()), + }; + + internal static readonly IReadOnlyList> Map = + new ColumnSpec[] + { + new ("Name", Align.Left, i => i.Name), + new ("Type", Align.Left, i => i.Type.ValueCase.ToString()), + new ("KeyType", Align.Left, i => i.Type.MapType.KeyType().ToString()), + new ("ValueType", Align.Left, i => i.Type.MapType.ValueType.ValueCase.ToString()), + new ("SizeInFile", Align.Left, i => i.CalculateSize().ToString()), + }; + + internal static readonly IReadOnlyList> None = + new ColumnSpec[] + { + new ("Name", Align.Left, i => i.Name), + new ("Type", Align.Left, i => i.Type.ValueCase.ToString()), + new ("SizeInFile", Align.Left, i => i.CalculateSize().ToString()), + }; + } + + internal static readonly IReadOnlyList> Tensor = + new ColumnSpec[] + { + new ("Name", Align.Left, t => t.Name), + new ("DataType", Align.Left, t => t.DataType().ToString()), + new ("Dims", Align.Right, t => string.Join("x", t.Dims)), + new ("Π(Dims)", Align.Right, t => t.Dims.Product().ToString()), + new ("[v0,v1..vN] | (Min,Mean,Max)", Align.Right, t => FormatValuesOrStats(t)), + new ("SizeInFile", Align.Right, t => t.CalculateSize().ToString()), + }; + + static string FormatShape(TensorShapeProto shape) + { + return string.Join("x", shape.Dim.Select(d => Format(d))); + } + + static string Format(TensorShapeProto.Types.Dimension d) => d.ValueCase switch + { + TensorShapeProto.Types.Dimension.ValueOneofCase.DimParam => d.DimParam, + TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue => d.DimValue.ToString(), + TensorShapeProto.Types.Dimension.ValueOneofCase.None => "?", + _ => throw new NotSupportedException(d.ValueCase.ToString()), + }; + + static unsafe string FormatValuesOrStats(TensorProto tensor) => tensor.DataType() switch + { + // NOTE: Long lines accepted below for structure + TensorProto.Types.DataType.Float => FormatValuesOrStats(tensor.FloatData, tensor.RawData, &Math.Min, (m, v) => m + v, (m, c) => m / c, &Math.Max), + TensorProto.Types.DataType.Double => FormatValuesOrStats(tensor.DoubleData, tensor.RawData, &Math.Min, (m, v) => m + v, (m, c) => m / c, &Math.Max), + TensorProto.Types.DataType.Int32 => FormatValuesOrStats(tensor.Int32Data, tensor.RawData, &Math.Min, (m, v) => m + v, (m, c) => m / c, &Math.Max), + TensorProto.Types.DataType.Int64 => FormatValuesOrStats(tensor.Int64Data, tensor.RawData, &Math.Min, (m, v) => m + v, (m, c) => m / c, &Math.Max), + // TODO: StringData + _ => "N/A", + }; + + // NOTE: Perf below is not great since function pointer and func calls cannot be inlined. + // If necessary refactor to use "value type functor"s. + static unsafe string FormatValuesOrStats( + IReadOnlyList values, + ByteString rawData, + delegate* min, + Func add, + Func divide, + delegate* max) + where T : struct + { + // Data may not be in typed part but in raw data + // Unfortunately there is no common and efficient "ground" for + // "IReadOnlyList values" and "ByteString rawData", + // so we have to go through hoops. + // RawData and talk about Constant nodes + // https://github.com/onnx/onnx/issues/2825#issuecomment-644334359 + + var useRawData = values.Count == 0 && rawData.Length > 0; + var rawValues = MemoryMarshal.Cast(rawData.Span); + var count = useRawData ? rawValues.Length : values.Count; + + const int MaxValueCountToShow = 4; + if (count <= MaxValueCountToShow) + { + return useRawData + ? FormatValues(rawValues.ToArray()) + : FormatValues(values); + } + else if (count > 0) + { + if (useRawData) { Thrower.EnsureLittleEndian(); } + var stats = useRawData + ? GetStats(rawValues, min, add, divide, max) + : GetStats(values, min, add, divide, max); + + return $"({stats.min:E3},{stats.mean:E3},{stats.max:E3})"; + } + else + { + return "[]"; + } + } + + static string FormatValues(IReadOnlyList values) => $"[{string.Join(",", values)}]"; + + static unsafe (T min, TMean mean, T max) GetStats( + ReadOnlySpan values, + delegate* min, + Func add, + Func divide, + delegate* max) + where T : struct + { + T minValue = values[0]; + T maxValue = values[0]; + TMean sum = add(default, values[0]); + for (int i = 1; i < values.Length; i++) + { + var value = values[i]; + minValue = min(minValue, value); + maxValue = max(maxValue, value); + sum = add(sum, value); + } + var mean = divide(sum, values.Length); + return (minValue, mean, maxValue); + } + + static unsafe (T min, TMean mean, T max) GetStats( + IReadOnlyList values, + delegate* min, + Func add, + Func divide, + delegate* max) + where T : struct + { + T minValue = values[0]; + T maxValue = values[0]; + TMean sum = add(default, values[0]); + for (int i = 1; i < values.Count; i++) + { + var value = values[i]; + minValue = min(minValue, value); + maxValue = max(maxValue, value); + sum = add(sum, value); + } + var mean = divide(sum, values.Count); + return (minValue, mean, maxValue); + } + } +} diff --git a/src/OnnxSharp/Formatting/MarkdownFormatter.cs b/src/OnnxSharp/Formatting/MarkdownFormatter.cs new file mode 100644 index 0000000..d78f9ba --- /dev/null +++ b/src/OnnxSharp/Formatting/MarkdownFormatter.cs @@ -0,0 +1,117 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; + +namespace Onnx.Formatting +{ + internal static class MarkdownFormatter + { + internal static void FormatAsTensors(this IReadOnlyList valueInfos, TextWriter writer) + { + Format(valueInfos, ColumnSpecs.ValueInfo.Tensor, writer); + } + + internal static void FormatAsSequences(this IReadOnlyList valueInfos, TextWriter writer) + { + Format(valueInfos, ColumnSpecs.ValueInfo.Sequence, writer); + } + + internal static void FormatAsMaps(this IReadOnlyList valueInfos, TextWriter writer) + { + Format(valueInfos, ColumnSpecs.ValueInfo.Map, writer); + } + + internal static void FormatAsNones(this IReadOnlyList valueInfos, TextWriter writer) + { + Format(valueInfos, ColumnSpecs.ValueInfo.None, writer); + } + + internal static void Format(this IReadOnlyList summaries, TextWriter writer) + { + Format(summaries, ColumnSpecs.Tensor, writer); + } + + internal static void Format( + IReadOnlyList values, + IReadOnlyList> columnSpecs, + TextWriter writer) + { + var maxColumnWidth = columnSpecs.Select(n => n.Name.Length).ToArray(); + + int rows = values.Count; + int cols = columnSpecs.Count; + + var table = new string[rows, cols]; + for (int row = 0; row < rows; row++) + { + var summary = values[row]; + + for (int col = 0; col < cols; col++) + { + var spec = columnSpecs[col]; + var text = spec.Get(summary); + table[row, col] = text; + maxColumnWidth[col] = Math.Max(maxColumnWidth[col], text.Length); + } + } + + Format(table, columnSpecs, maxColumnWidth, writer); + } + + internal static void Format( + string[,] table, + IReadOnlyList columnSpecs, + IReadOnlyList columnWidths, + TextWriter writer) + { + // TODO: Define constants below + + var rows = table.GetLength(0); + var cols = table.GetLength(1); + + // Column Names + for (int col = 0; col < cols; col++) + { + var columnName = columnSpecs[col].Name; + writer.Write('|'); + writer.Write(columnName); + writer.Write(' ', columnWidths[col] - columnName.Length); + } + writer.Write('|'); + writer.WriteLine(); + + // Separator and alignment + for (int col = 0; col < cols; col++) + { + writer.Write('|'); + var align = columnSpecs[col].Align; + if (align == Align.Left) + { + writer.Write(':'); + } + writer.Write('-', columnWidths[col] - 1); + if (align == Align.Right) + { + writer.Write(':'); + } + } + writer.Write('|'); + writer.WriteLine(); + + // Rows + for (int row = 0; row < rows; row++) + { + for (int col = 0; col < cols; col++) + { + var align = columnSpecs[col].Align; + var value = table[row, col]; + writer.Write('|'); + writer.WriteAligned(value, align, ' ', columnWidths[col]); + } + writer.Write('|'); + writer.WriteLine(); + } + } + } +} diff --git a/src/OnnxSharp/Formatting/TextWriterExtensions.cs b/src/OnnxSharp/Formatting/TextWriterExtensions.cs new file mode 100644 index 0000000..8991cbe --- /dev/null +++ b/src/OnnxSharp/Formatting/TextWriterExtensions.cs @@ -0,0 +1,30 @@ +using System.IO; + +namespace Onnx.Formatting +{ + internal static class TextWriterExtensions + { + internal static void WriteAligned(this TextWriter writer, + string columnName, Align alignment, char pad, int width) + { + var padCount = width - columnName.Length; + if (alignment == Align.Right) + { + writer.Write(pad, padCount); + } + writer.Write(columnName); + if (alignment == Align.Left) + { + writer.Write(pad, padCount); + } + } + + internal static void Write(this TextWriter writer, char value, int repeatCount) + { + for (int i = 0; i < repeatCount; i++) + { + writer.Write(value); + } + } + } +} diff --git a/src/OnnxSharp/GraphExtensions.Clean.cs b/src/OnnxSharp/GraphExtensions.Clean.cs new file mode 100644 index 0000000..9720f12 --- /dev/null +++ b/src/OnnxSharp/GraphExtensions.Clean.cs @@ -0,0 +1,130 @@ +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Google.Protobuf.Collections; +using Onnx.Collections; + +namespace Onnx +{ + /// Extension methods to ONNX graph. + public static partial class GraphExtensions + { + /// Clean graph for inference. + public static void Clean(this GraphProto graph) + { + graph.RemoveInitializersFromInputs(); + graph.RemoveUnnecessaryInitializerReshapes(); + } + + /// Remove initializers from inputs of graph. + // https://github.com/microsoft/onnxruntime/blob/master/tools/python/remove_initializer_from_input.py + public static void RemoveInitializersFromInputs(this GraphProto graph) + { + var inputs = graph.Input; + var nameToInput = inputs.ToDictionary(i => i.Name, i => i); + + foreach (var initializer in graph.Initializer) + { + if (nameToInput.TryGetValue(initializer.Name, out var input)) + { + // https://github.com/protocolbuffers/protobuf/blob/master/csharp/src/Google.Protobuf/Collections/RepeatedField.cs + var removed = inputs.Remove(input); + Trace.WriteLine($"{removed} {inputs.Count}"); + } + } + } + + /// Remove unnecessary initializer reshapes from graph. + // https://github.com/microsoft/onnxruntime/blob/master/tools/python/remove_initializer_from_input.py + public static void RemoveUnnecessaryInitializerReshapes(this GraphProto graph) + { + var nameToInitializer = graph.Initializer.ToDictionary(i => i.Name, i => i); + + var nodes = graph.Node; + var valueInfos = graph.ValueInfo; + + var nodesToRemove = new List(); + for (int nodeIndex = 0; nodeIndex < nodes.Count; nodeIndex++) + { + var node = nodes[nodeIndex]; + + var opSpec = Ops.Reshape.Spec; + if (node.OpType == opSpec.OpType) + { + var inputs = node.Input; + var outputs = node.Output; + + // Expected Reshape takes 2 inputs and has 1 output + if (inputs.Count == opSpec.Inputs && outputs.Count == opSpec.Outputs) + { + var dataName = inputs[0]; + var shapeName = inputs[1]; + var reshapeOutputName = outputs[0]; + + // Both inputs must be initializers ("static") + if (nameToInitializer.TryGetValue(dataName, out var dataInitializer) && + nameToInitializer.TryGetValue(shapeName, out var shapeInitializer)) + { + // TODO: Check initializer not used in other nodes + + var outputShapeValue = valueInfos.Single(v => v.Name, reshapeOutputName); + + var outputShapeDims = outputShapeValue.Type.TensorType.Shape.Dim; + var allValue = outputShapeDims.All(d => d.ValueCase == + TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue); + if (allValue) + { + var outputShape = outputShapeDims.Select(d => d.DimValue).ToArray(); + + var allPositive = outputShape.All(d => d > 0); + if (allPositive) + { + // Check shape compared to initializer shape + var dataShape = dataInitializer.Dims.ToArray(); + + var outputShapeProductSum = outputShape.Product(); + var dataShapeProductSum = dataShape.Product(); + + if (outputShapeProductSum == dataShapeProductSum) + { + // Change data shape to the reshape output shape directly + dataInitializer.Dims.Clear(); + dataInitializer.Dims.AddRange(outputShape); + + // Remove reshape data shape both as initializer and input + graph.Initializer.TryRemove(i => i.Name, shapeName); + graph.Input.TryRemove(i => i.Name, shapeName); + + nodesToRemove.Add(node); + + // Replace reshape output name with data name directly in all nodes + ReplaceInput(nodes, reshapeOutputName, dataName); + } + } + } + } + } + } + } + foreach (var node in nodesToRemove) + { + nodes.Remove(node); + } + } + + internal static void ReplaceInput(RepeatedField nodes, string oldValue, string newValue) + { + for (int nodeIndex = 0; nodeIndex < nodes.Count; nodeIndex++) + { + var updateNodeInputs = nodes[nodeIndex].Input; + for (int inputIndex = 0; inputIndex < updateNodeInputs.Count; inputIndex++) + { + if (updateNodeInputs[inputIndex] == oldValue) + { + updateNodeInputs[inputIndex] = newValue; + } + } + } + } + } +} diff --git a/src/OnnxSharp/GraphExtensions.Info.cs b/src/OnnxSharp/GraphExtensions.Info.cs new file mode 100644 index 0000000..2334250 --- /dev/null +++ b/src/OnnxSharp/GraphExtensions.Info.cs @@ -0,0 +1,71 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Onnx.Formatting; + +namespace Onnx +{ + public static partial class GraphExtensions + { + /// Summarize information about the . + public static string Info(this GraphProto graph) + { + var writer = new StringWriter(); + graph.Info(writer); + return writer.ToString(); + } + + /// Summarize information about the . + public static void Info(this GraphProto graph, TextWriter writer) + { + var initializerNameSet = new HashSet(graph.Initializer.Select(i => i.Name)); + var inferenceInputs = graph.Input.Where(i => !initializerNameSet.Contains(i.Name)).ToList(); + var initializerInputs = graph.Input.Where(i => initializerNameSet.Contains(i.Name)).ToList(); + + writer.WriteLine("## Inputs without Initializer"); + Info(inferenceInputs, writer); + + writer.WriteLine(); + writer.WriteLine("## Outputs"); + Info(graph.Output, writer); + + writer.WriteLine(); + writer.WriteLine("## Inputs with Initializer"); + Info(initializerInputs, writer); + + writer.WriteLine(); + writer.WriteLine("## Initializers (Parameters etc.)"); + MarkdownFormatter.Format(graph.Initializer, writer); + + writer.WriteLine(); + writer.WriteLine("## Value Infos"); + Info(graph.ValueInfo, writer); + } + + static void Info(IReadOnlyList valueInfos, TextWriter writer) + { + var tensorTypes = valueInfos.Where(i => i.Type.ValueCase == TypeProto.ValueOneofCase.TensorType).ToList(); + WriteInfoIfAny(tensorTypes, "Tensors", MarkdownFormatter.FormatAsTensors, writer); + + var sequenceTypes = valueInfos.Where(i => i.Type.ValueCase == TypeProto.ValueOneofCase.SequenceType).ToList(); + WriteInfoIfAny(sequenceTypes, "Sequences", MarkdownFormatter.FormatAsSequences, writer); + + var mapTypes = valueInfos.Where(i => i.Type.ValueCase == TypeProto.ValueOneofCase.MapType).ToList(); + WriteInfoIfAny(mapTypes, "Maps", MarkdownFormatter.FormatAsMaps, writer); + + var noneTypes = valueInfos.Where(i => i.Type.ValueCase == TypeProto.ValueOneofCase.None).ToList(); + WriteInfoIfAny(noneTypes, "Nones", MarkdownFormatter.FormatAsNones, writer); + } + + static void WriteInfoIfAny(IReadOnlyList values, string name, + Action, TextWriter> info, TextWriter writer) + { + if (values.Count > 0) + { + writer.WriteLine($"### {name}"); + info(values, writer); + } + } + } +} diff --git a/src/OnnxSharp/GraphExtensions.SetDim.cs b/src/OnnxSharp/GraphExtensions.SetDim.cs new file mode 100644 index 0000000..4ba2530 --- /dev/null +++ b/src/OnnxSharp/GraphExtensions.SetDim.cs @@ -0,0 +1,152 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Runtime.InteropServices; +using Google.Protobuf; +using Google.Protobuf.Collections; +using Onnx.Collections; + +namespace Onnx +{ + public static partial class GraphExtensions + { + /// + /// Set dimension of inputs, value infos, outputs and potential Reshape ops. + /// Default sets leading dimension to dynamic batch size 'N'. + /// + public static void SetDim(this GraphProto graph) => + graph.SetDim(dimIndex: 0, DimParamOrValue.New("N")); + + /// + /// Set dimension of inputs, value infos, outputs and potential Reshape ops. + /// Can be used to make models have dynamic batch size or different static batch sizes. + /// + public static void SetDim(this GraphProto graph, int dimIndex, DimParamOrValue dimParamOrValue) + { + // Reshape ops have their "new shape" defined as input to the reshape op. + // This input needs to be changed to reflect new dim e.g. be set -1 if dynamic. + var reshapeDimValue = dimParamOrValue.IsParam + ? Ops.Reshape.DynamicReshapeValue + : dimParamOrValue.Value; + SetDimInReshapes(graph, dimIndex, reshapeDimValue); + + // Should we set this based on nodes instead? Handling input, outputs based on that? + + // Shapes are defined in inputs, valueInfos and outputs + // + // Only real inputs should be changed, not "initializer" inputs + var initializserNames = new HashSet(graph.Initializer.Select(i => i.Name)); + var inferenceInputs = graph.Input.Where(i => !initializserNames.Contains(i.Name)); + foreach (var input in inferenceInputs) + { + SetDim(input, dimIndex, dimParamOrValue); + } + //SetDim(graph.Input, dimIndex, dimParam); + + SetDim(graph.ValueInfo, dimIndex, dimParamOrValue); + SetDim(graph.Output, dimIndex, dimParamOrValue); + } + + static void SetDimInReshapes(GraphProto graph, int dimIndex, int dimValue) + { + var nodes = graph.Node; + var initializers = graph.Initializer; + + // TODO: Only fix reshapes that have data input and with dynamic shape after + + var opSpec = Ops.Reshape.Spec; + foreach (var node in nodes) + { + if (node.OpType == opSpec.OpType) + { + var dataInputName = node.Input[Ops.Reshape.InputDataIndex]; + + // Check if data input is an initializer if so we should not change the reshape + // and hence skip this reshape node + var dataInitializerIndex = initializers.IndexOf(t => t.Name, dataInputName); + if (dataInitializerIndex >= 0) + { continue; } + + var shapeInputName = node.Input[Ops.Reshape.InputShapeIndex]; + + var shape = initializers.Single(tensor => tensor.Name, shapeInputName); + + SetDimInReshapeTensorShape(shape, dimIndex, dimValue); + } + } + } + + static void SetDimInReshapeTensorShape(TensorProto shape, int dimIndex, int dimValue) + { + Debug.Assert(shape.DataType == (int)TensorProto.Types.DataType.Int64); + var dims = shape.Dims; + if (dims.Count > 0 && dims[dimIndex] > 0) + { + // Data may be stored as Int64 or Raw (fixed-width, little-endian) + if (shape.Int64Data.Count > 0) + { + var int64Data = shape.Int64Data; + if (int64Data[dimIndex] == 1) // Dimension we replace + { + int64Data[dimIndex] = dimValue; + } + } + if (!shape.RawData.IsEmpty) + { + var rawData = shape.RawData; + var rawAsInt64Data = MemoryMarshal.Cast(rawData.Span); + Debug.Assert(rawAsInt64Data.Length == dims[dimIndex]); + if (rawAsInt64Data[dimIndex] == 1) // Dimension we replace + { + var newShape = rawAsInt64Data.ToArray(); + newShape[dimIndex] = dimValue; + var newShapeBytes = MemoryMarshal.Cast(newShape.AsSpan()); + shape.RawData = ByteString.CopyFrom(newShapeBytes); + } + } + } + } + + internal static void SetDim(RepeatedField valueInfos, + int dimIndex, DimParamOrValue dimParamOrValue) + { + for (int i = 0; i < valueInfos.Count; i++) + { + var valueInfo = valueInfos[i]; + SetDim(valueInfo, dimIndex, dimParamOrValue); + } + } + + internal static void SetDim(ValueInfoProto valueInfo, + int dimIndex, DimParamOrValue dimParamOrValue) + { + var shape = valueInfo.Type.TensorType.Shape; + var dims = shape.Dim; + var dim = dims[dimIndex]; + if (dim.ValueCase == TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue) + { + // TODO: Should perhaps be parameter that says + // bool shouldSetDimFor(dim) + if (dim.DimValue == 1) + { + SetDim(dim, dimParamOrValue); + } + } + } + + internal static void SetDim(TensorShapeProto.Types.Dimension dim, + DimParamOrValue dimParamOrValue) + { + dim.ClearValue(); + if (dimParamOrValue.IsParam) + { + dim.DimParam = dimParamOrValue.Param; + } + else + { + dim.DimValue = dimParamOrValue.Value; + } + } + } +} diff --git a/src/OnnxSharp/MessageExtensions.cs b/src/OnnxSharp/MessageExtensions.cs new file mode 100644 index 0000000..20928b7 --- /dev/null +++ b/src/OnnxSharp/MessageExtensions.cs @@ -0,0 +1,29 @@ +using Google.Protobuf; +using System.IO; + +namespace Onnx +{ + /// Convenience extension methods. + public static partial class MessageExtensions + { + /// + /// Writes the given data to the + /// given in protobuf encoding. + /// + public static void WriteToFile(this IMessage message, string filePath) + { + using var stream = File.Open(filePath, FileMode.Create); + message.WriteTo(stream); + } + + /// + /// Writes the given data to the + /// given in JSON encoding. + /// + public static void WriteJsonToFile(this IMessage message, string filePath) + { + using var writer = new StreamWriter(filePath); + JsonFormatter.Default.Format(message, writer); + } + } +} diff --git a/src/OnnxSharp/MessageParserExtensions.cs b/src/OnnxSharp/MessageParserExtensions.cs new file mode 100644 index 0000000..5673f24 --- /dev/null +++ b/src/OnnxSharp/MessageParserExtensions.cs @@ -0,0 +1,31 @@ +using System; +using System.IO; +using Google.Protobuf; + +namespace Onnx +{ + /// Convenience extension methods. + public static partial class MessageParserExtensions + { + /// + /// Parse from file via . + /// + public static T ParseFromFile(this MessageParser parser, string filePath) + where T : IMessage + { + using var stream = File.Open(filePath, FileMode.Open); + return parser.ParseFrom(stream); + } + + /// + /// Parse from file via + /// and disposes the created stream after parsing is done. + /// + public static T ParseFrom(this MessageParser parser, Func createStream) + where T : IMessage + { + using var stream = createStream(); + return parser.ParseFrom(stream); + } + } +} diff --git a/src/OnnxSharp/OnnxExtensions.cs b/src/OnnxSharp/OnnxExtensions.cs deleted file mode 100644 index f7a5af7..0000000 --- a/src/OnnxSharp/OnnxExtensions.cs +++ /dev/null @@ -1,42 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; - -namespace Onnx -{ - /// - /// Extensions to ONNX protobuf functionality. - /// - public static partial class OnnxExtensions - { - /// - /// Remove initializers from inputs of graph in model. - /// - public static ModelProto RemoveInitializersFromInputs(this ModelProto model) - { - RemoveInitializersFromInputs(model.Graph); - return model; - } - - /// - /// Remove initializers from inputs of graph. - /// - // https://github.com/microsoft/onnxruntime/blob/master/tools/python/remove_initializer_from_input.py - public static GraphProto RemoveInitializersFromInputs(GraphProto graph) - { - var inputs = graph.Input; - var nameToInput = inputs.ToDictionary(i => i.Name, i => i); - - foreach (var initializer in graph.Initializer) - { - if (nameToInput.TryGetValue(initializer.Name, out var input)) - { - inputs.Remove(input); - } - } - - return graph; - } - } -} diff --git a/src/OnnxSharp/OnnxSharp.csproj b/src/OnnxSharp/OnnxSharp.csproj index 7dc0532..31b4ac1 100644 --- a/src/OnnxSharp/OnnxSharp.csproj +++ b/src/OnnxSharp/OnnxSharp.csproj @@ -20,6 +20,8 @@ https://github.com/nietras/OnnxSharp true true + + true diff --git a/src/OnnxSharp/Ops.cs b/src/OnnxSharp/Ops.cs new file mode 100644 index 0000000..c405342 --- /dev/null +++ b/src/OnnxSharp/Ops.cs @@ -0,0 +1,33 @@ +using System; + +namespace Onnx +{ + internal static class Ops + { + internal static class Reshape + { + internal const int InputDataIndex = 0; + internal const int InputShapeIndex = 1; + + // Reshape op supports only one dimension in shape to be dynamic, + // which is defined as -1. + internal const int DynamicReshapeValue = -1; + + internal static readonly OpSpec Spec = new OpSpec(nameof(Reshape), 2, 1); + } + + internal readonly struct OpSpec + { + public OpSpec(string opType, int inputs, int outputs) + { + OpType = opType ?? throw new ArgumentNullException(nameof(opType)); + Inputs = inputs; + Outputs = outputs; + } + + public string OpType { get; } + public int Inputs { get; } + public int Outputs { get; } + } + } +} diff --git a/src/OnnxSharp/TensorProtoExtensions.cs b/src/OnnxSharp/TensorProtoExtensions.cs new file mode 100644 index 0000000..d025d59 --- /dev/null +++ b/src/OnnxSharp/TensorProtoExtensions.cs @@ -0,0 +1,10 @@ +namespace Onnx +{ + /// Convenience extension methods. + public static class TensorProtoExtensions + { + /// Get data type of as enum. + public static TensorProto.Types.DataType DataType(this TensorProto tensor) => + (TensorProto.Types.DataType)tensor.DataType; + } +} diff --git a/src/OnnxSharp/Thrower.cs b/src/OnnxSharp/Thrower.cs new file mode 100644 index 0000000..7b9121b --- /dev/null +++ b/src/OnnxSharp/Thrower.cs @@ -0,0 +1,18 @@ +using System; + +namespace Onnx +{ + internal static class Thrower + { + internal static void EnsureLittleEndian() + { + if (!BitConverter.IsLittleEndian) + { + var message = "Only little-endian systems are supported. " + + "This is due to raw data in onnx files being stored in little-endian order and " + + "conversion to big-endian has not implemented."; + throw new NotSupportedException(message); + } + } + } +} diff --git a/src/OnnxSharp/TypeProtoTypesMapExtensions.cs b/src/OnnxSharp/TypeProtoTypesMapExtensions.cs new file mode 100644 index 0000000..58e5add --- /dev/null +++ b/src/OnnxSharp/TypeProtoTypesMapExtensions.cs @@ -0,0 +1,10 @@ +namespace Onnx +{ + /// Convenience extension methods. + public static class TypeProtoTypesMapExtensions + { + /// Get key data type of as enum. + public static TensorProto.Types.DataType KeyType(this TypeProto.Types.Map map) => + (TensorProto.Types.DataType)map.KeyType; + } +} diff --git a/src/OnnxSharp/TypeProtoTypesSequenceExtensions.cs b/src/OnnxSharp/TypeProtoTypesSequenceExtensions.cs new file mode 100644 index 0000000..b0f7b2e --- /dev/null +++ b/src/OnnxSharp/TypeProtoTypesSequenceExtensions.cs @@ -0,0 +1,7 @@ +namespace Onnx +{ + /// Convenience extension methods. + public static class TypeProtoTypesSequenceExtensions + { + } +} diff --git a/src/OnnxSharp/TypeProtoTypesTensorExtensions.cs b/src/OnnxSharp/TypeProtoTypesTensorExtensions.cs new file mode 100644 index 0000000..6835069 --- /dev/null +++ b/src/OnnxSharp/TypeProtoTypesTensorExtensions.cs @@ -0,0 +1,10 @@ +namespace Onnx +{ + /// Convenience extension methods. + public static class TypeProtoTypesTensorExtensions + { + /// Get element data type of as enum. + public static TensorProto.Types.DataType ElemType(this TypeProto.Types.Tensor tensor) => + (TensorProto.Types.DataType)tensor.ElemType; + } +} diff --git a/src/OnnxSharp/ValueInfoProtoExtensions.cs b/src/OnnxSharp/ValueInfoProtoExtensions.cs new file mode 100644 index 0000000..0f66a60 --- /dev/null +++ b/src/OnnxSharp/ValueInfoProtoExtensions.cs @@ -0,0 +1,7 @@ +namespace Onnx +{ + /// Convenience extension methods. + public static class ValueInfoProtoExtensions + { + } +} diff --git a/src/OnnxSharpConsole/OnnxSharpConsole.csproj b/src/OnnxSharpConsole/OnnxSharpConsole.csproj index 4615a04..7969b60 100644 --- a/src/OnnxSharpConsole/OnnxSharpConsole.csproj +++ b/src/OnnxSharpConsole/OnnxSharpConsole.csproj @@ -8,7 +8,7 @@ Exe - net50 + net5.0 false diff --git a/src/OnnxSharpConsole/Program.cs b/src/OnnxSharpConsole/Program.cs index ed574a5..54a4fba 100644 --- a/src/OnnxSharpConsole/Program.cs +++ b/src/OnnxSharpConsole/Program.cs @@ -1,46 +1,18 @@ -using System.IO; -using System.Linq; -using Google.Protobuf; +using Onnx; // Examples see https://github.com/onnx/models var onnxInputFilePath = @"mnist-8.onnx"; -var onnxInputFileName = Path.GetFileNameWithoutExtension(onnxInputFilePath); -var outputDirectory = Path.GetDirectoryName(onnxInputFilePath); +var model = ModelProto.Parser.ParseFromFile(onnxInputFilePath); -using (var file = File.OpenRead(onnxInputFilePath)) -{ - var model = Onnx.ModelProto.Parser.ParseFrom(file); - var graph = model.Graph; - // Shapes are defined in inputs, values and outputs - var inputs = graph.Input; - var values = graph.ValueInfo; - var outputs = graph.Output; +var graph = model.Graph; +// Clean graph e.g. remove initializers from inputs that may prevent constant folding +graph.Clean(); +// Set dimension in graph to enable dynamic batch size during inference +graph.SetDim(dimIndex: 0, DimParamOrValue.New("N")); +// Get summarized info about the graph +var info = graph.Info(); - foreach (var value in inputs.Concat(values).Concat(outputs)) - { - var shape = value.Type.TensorType.Shape; - var dims = shape.Dim; - var dim = dims[0]; - //dim.DimValue = -1; - dim.ClearValue(); - dim.DimValue = -1; // Or don't set it - //dim.DimParam = "None"; // Or don't set it, unset dimension means dynamic - } +System.Console.WriteLine(info); - var fileNameSuffix = "-dynamic-leading-dimension"; - var outputFilePathPrefix = Path.Combine(outputDirectory, onnxInputFileName + fileNameSuffix); - - var onnxOutputFilePath = outputFilePathPrefix + ".onnx"; - using (var outputFile = File.Create(onnxOutputFilePath)) - { - model.WriteTo(outputFile); - } - - var jsonOnnxOutputFilePath = outputFilePathPrefix + ".json"; - using (var output = new StreamWriter(jsonOnnxOutputFilePath)) - { - var fmt = new JsonFormatter(JsonFormatter.Settings.Default); - fmt.Format(model, output); - } -} +model.WriteToFile(@"mnist-8-clean-dynamic-batch-size.onnx"); \ No newline at end of file diff --git a/src/dotnet-onnx/Commands/CleanCommand.cs b/src/dotnet-onnx/Commands/CleanCommand.cs new file mode 100644 index 0000000..5294c9c --- /dev/null +++ b/src/dotnet-onnx/Commands/CleanCommand.cs @@ -0,0 +1,14 @@ +using McMaster.Extensions.CommandLineUtils; +using Onnx; + +[Command("clean", Description = "Clean model for inference e.g. remove initializers from inputs")] +public class CleanCommand : InputOutputCommand +{ + public CleanCommand(IConsole console) : base(console) + { } + + protected override void Run(ModelProto model) + { + model.Graph.Clean(); + } +} \ No newline at end of file diff --git a/src/dotnet-onnx/Commands/Command.cs b/src/dotnet-onnx/Commands/Command.cs new file mode 100644 index 0000000..7e47230 --- /dev/null +++ b/src/dotnet-onnx/Commands/Command.cs @@ -0,0 +1,23 @@ +using System; +using System.Threading.Tasks; + +public abstract class Command +{ + public async Task OnExecuteAsync() + { + try + { + await Run(); + } + //catch (CliException e) + catch (Exception e) + { + Console.ForegroundColor = ConsoleColor.Red; + Console.Error.WriteLine(e.Message); + Console.ResetColor(); + //Environment.Exit(e.ExitCode); + } + } + + public abstract Task Run(); +} diff --git a/src/dotnet-onnx/Commands/InfoCommand.cs b/src/dotnet-onnx/Commands/InfoCommand.cs new file mode 100644 index 0000000..f3c497e --- /dev/null +++ b/src/dotnet-onnx/Commands/InfoCommand.cs @@ -0,0 +1,21 @@ +using McMaster.Extensions.CommandLineUtils; +using Onnx; + +[Command("info", Description = "Print information about a model e.g. inputs and outputs")] +public class InfoCommand : InputCommand +{ + public InfoCommand(IConsole console) : base(console) + { + LogInput = null; + } + + protected override void Run(ModelProto model) + { + var writer = _console.Out; + + writer.WriteLine($"# {Input}"); + writer.WriteLine(); + + model.Graph.Info(writer); + } +} \ No newline at end of file diff --git a/src/dotnet-onnx/Commands/InputCommand.cs b/src/dotnet-onnx/Commands/InputCommand.cs new file mode 100644 index 0000000..ccae195 --- /dev/null +++ b/src/dotnet-onnx/Commands/InputCommand.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel.DataAnnotations; +using System.Threading.Tasks; +using McMaster.Extensions.CommandLineUtils; +using Onnx; + +public abstract class InputCommand : Command +{ + protected readonly IConsole _console; + protected Action LogInput; + + public InputCommand(IConsole console) + { + _console = console; + LogInput = t => _console.WriteLine(t); + } + + [Argument(0, "input", Description = "Input file path")] + [Required] + public string Input { get; } + + public override Task Run() + { + var model = ModelProto.Parser.ParseFromFile(Input); + + LogInput?.Invoke($"Parsed input file '{Input}' of size {model.CalculateSize()}"); + + Run(model); + + return Task.CompletedTask; + } + + protected abstract void Run(ModelProto model); +} diff --git a/src/dotnet-onnx/Commands/InputOutputCommand.cs b/src/dotnet-onnx/Commands/InputOutputCommand.cs new file mode 100644 index 0000000..b226de8 --- /dev/null +++ b/src/dotnet-onnx/Commands/InputOutputCommand.cs @@ -0,0 +1,39 @@ +using System.ComponentModel.DataAnnotations; +using System.Threading.Tasks; +using McMaster.Extensions.CommandLineUtils; +using Onnx; + +public abstract class InputOutputCommand : Command +{ + protected readonly IConsole _console; + + public InputOutputCommand(IConsole console) + { + _console = console; + } + + [Argument(0, "input", Description = "Input file path")] + [Required] + public string Input { get; } + + [Argument(1, "output", Description = "Output file path")] + [Required] + public string Output { get; } + + public override Task Run() + { + var model = ModelProto.Parser.ParseFromFile(Input); + + _console.WriteLine($"Parsed input file '{Input}' of size {model.CalculateSize()}"); + + Run(model); + + model.WriteToFile(Output); + + _console.WriteLine($"Wrote output file '{Output}' of size {model.CalculateSize()}"); + + return Task.CompletedTask; + } + + protected abstract void Run(ModelProto model); +} diff --git a/src/dotnet-onnx/Commands/SetDimCommand.cs b/src/dotnet-onnx/Commands/SetDimCommand.cs new file mode 100644 index 0000000..de4f044 --- /dev/null +++ b/src/dotnet-onnx/Commands/SetDimCommand.cs @@ -0,0 +1,28 @@ +using McMaster.Extensions.CommandLineUtils; +using Onnx; + +[Command("setdim", Description = "Set dimension of reshapes, inputs and outputs of model e.g. set new dynamic or fixed batch size.")] +public class SetDimCommand : InputOutputCommand +{ + public SetDimCommand(IConsole console) : base(console) + { } + + [Option("-i|--index", Description = "Dimension index to set. Default = 0.")] + public int Index { get; } = 0; // Parametize defaults + + [Option("-d|--dim", Description = "Dimension to set. Default = N. Use string e.g. 'N' for dynamic batch size or integer e.g. '3' for fixed size")] + public string Dim { get; } = "N"; + + protected override void Run(ModelProto model) + { + // Should this not be before loading input? Is the abstract base really that good? + + var dimParamOrValue = int.TryParse(Dim, out var dimValue) + ? DimParamOrValue.New(dimValue) + : DimParamOrValue.New(Dim); + + _console.WriteLine($"Setting dimension at {Index} to '{dimParamOrValue}'"); + + model.Graph.SetDim(Index, dimParamOrValue); + } +} \ No newline at end of file diff --git a/src/dotnet-onnx/Program.cs b/src/dotnet-onnx/Program.cs index 9e31089..039c696 100644 --- a/src/dotnet-onnx/Program.cs +++ b/src/dotnet-onnx/Program.cs @@ -1,10 +1,6 @@ -using System; -using System.ComponentModel.DataAnnotations; -using System.IO; +using System.IO; using System.Threading.Tasks; -using Google.Protobuf; using McMaster.Extensions.CommandLineUtils; -using Onnx; // https://github.com/natemcmaster/CommandLineUtils // https://natemcmaster.github.io/CommandLineUtils/docs/advanced/dependency-injection.html @@ -13,11 +9,12 @@ // TODO: Handle multiple command names etc. // https://github.com/jonstodle/DotNetSdkHelpers/blob/master/src/DotNetSdkHelpers/Program.cs -[Command("dotnet-onnx", Description = "Inspect and manipulate ONNX files"), - Subcommand(typeof(Clean)), - //Subcommand(typeof(List)), - //Subcommand(typeof(Download)) - ] +// TODO: Switch from attributes to code instead +[Command("dotnet onnx", Description = "Inspect and manipulate ONNX files. Copyright nietras 2021."), + Subcommand(typeof(CleanCommand)), + Subcommand(typeof(SetDimCommand)), + Subcommand(typeof(InfoCommand)) +] class Program { static Task Main(string[] args) @@ -31,54 +28,11 @@ static Task Main(string[] args) return app.ExecuteAsync(args); } -} -public abstract class Command -{ - public async Task OnExecuteAsync() + public Task OnExecuteAsync(CommandLineApplication app) { - try - { - await Run(); - } - //catch (CliException e) - catch (Exception e) - { - Console.ForegroundColor = ConsoleColor.Red; - Console.Error.WriteLine(e.Message); - Console.ResetColor(); - //Environment.Exit(e.ExitCode); - } - } - - public abstract Task Run(); -} - -[Command("clean", Description = "Clean graph for inference e.g. remove initializers from inputs")] -public class Clean : Command -{ - [Argument(0, "input", Description = "Input file path")] - [Required] - public string Input { get; } + app.ShowHelp(); - [Argument(1, "output", Description = "Output file path")] - [Required] - public string Output { get; } - - public override Task Run() - { - using (var inputFile = File.OpenRead(Input)) - { - var model = ModelProto.Parser.ParseFrom(inputFile); - - model.RemoveInitializersFromInputs(); - - using (var outputFile = File.Create(Output)) - { - model.WriteTo(outputFile); - } - } - - return Task.CompletedTask; + return Task.FromResult(0); } -} \ No newline at end of file +} diff --git a/update-tool-from-build.ps1 b/update-tool-from-build.ps1 new file mode 100644 index 0000000..debde86 --- /dev/null +++ b/update-tool-from-build.ps1 @@ -0,0 +1,2 @@ +#!/usr/local/bin/powershell +dotnet tool update dotnet-onnx --add-source ./build/dotnet-onnx_AnyCPU_Release -g \ No newline at end of file