diff --git a/OnnxSharp.sln b/OnnxSharp.sln
index 86157b0..4062af0 100644
--- a/OnnxSharp.sln
+++ b/OnnxSharp.sln
@@ -22,6 +22,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "root", "root", "{604A7FA2-1
rename.ps1 = rename.ps1
restore.ps1 = restore.ps1
test.ps1 = test.ps1
+ update-tool-from-build.ps1 = update-tool-from-build.ps1
update.ps1 = update.ps1
EndProjectSection
EndProject
@@ -38,6 +39,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OnnxSharp.Test", "src\OnnxS
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "dotnet-onnx", "src\dotnet-onnx\dotnet-onnx.csproj", "{DA6F8267-24F1-4104-AB36-10C97B899A8E}"
EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "github-workflows", "github-workflows", "{1EA6AA06-C436-482C-A14C-1AC414D5552F}"
+ ProjectSection(SolutionItems) = preProject
+ .github\workflows\dotnet.yml = .github\workflows\dotnet.yml
+ EndProjectSection
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
diff --git a/README.md b/README.md
index 774d886..1ee7e11 100644
--- a/README.md
+++ b/README.md
@@ -1,32 +1,64 @@
![Build and test](https://github.com/nietras/OnnxSharp/workflows/.NET/badge.svg)
-[![NuGet](https://img.shields.io/nuget/v/OnnxSharp)](https://www.nuget.org/packages/OnnxSharp/)
-[![Downloads](https://img.shields.io/nuget/dt/OnnxSharp)](https://www.nuget.org/packages/OnnxSharp/)
[![Stars](https://img.shields.io/github/stars/nietras/OnnxSharp)](https://github.com/nietras/OnnxSharp/stargazers)
[![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE.md)
-# OnnxSharp
-ONNX format parsing and manipulation in C#.
+|What |Links and Status|
+|---------------|------|
+|`OnnxSharp` |[![NuGet](https://img.shields.io/nuget/v/OnnxSharp)](https://www.nuget.org/packages/OnnxSharp/) [![Downloads](https://img.shields.io/nuget/dt/OnnxSharp)](https://www.nuget.org/packages/OnnxSharp/) |
+|`dotnet-onnx`|[![NuGet](https://img.shields.io/nuget/v/dotnet-onnx)](https://www.nuget.org/packages/dotnet-onnx/) [![Downloads](https://img.shields.io/nuget/dt/dotnet-onnx)](https://www.nuget.org/packages/dotnet-onnx/) |
-# Status
-Pretty much just:
+# `OnnxSharp` library and `dotnet-onnx` tool
+ONNX format parsing and manipulation in C# and with command line .NET tool.
+
+# Quick Guide
+Install latest version of .NET:
+* PowerShell (Windows): [https://dot.net/v1/dotnet-install.ps1](https://dot.net/v1/dotnet-install.ps1)
+* Bash (Linux/macOS): [https://dot.net/v1/dotnet-install.sh](https://dot.net/v1/dotnet-install.sh)
+
+#### Code
+|What |How |
+|--------------|---------------------------------------------------|
+|Install |`dotnet add PROJECT.csproj package OnnxSharp`|
+|Parse |`var model = ModelProto.Parser.ParseFromFile("mnist-8.onnx");`|
+|Info |`var info = model.Graph.Info();`|
+|Clean |`model.Graph.Clean();`|
+|SetDim |`model.Graph.SetDim();`|
+|Write |`model.WriteToFile("mnist-8-clean-dynamic.onnx");`|
+
+#### Tool
+|What |How |
+|--------------|----------------------------|
+|Install |`dotnet tool install dotnet-onnx -g`|
+|Info |`dotnet onnx info mnist-8.onnx`|
+|Info |`dotnet onnx info mnist-8.onnx`|
+|Clean |`dotnet onnx clean mnist-8.onnx mnist-8-clean.onnx`|
+|SetDim |`dotnet onnx setdim mnist-8.onnx mnist-8-setdim.onnx`|
+
+# Source Code
+Base functionality is based on:
```
.\protoc.exe .\onnx.proto3 --csharp_out=OnnxSharp
```
+Everything else written in beautiful C# 9.0 as extensions to this.
# Example Code
```csharp
-using System.IO;
-using Google.Protobuf;
+using Onnx;
// Examples see https://github.com/onnx/models
-var onnxFilePath = @"mnist-8.onnx";
+var onnxInputFilePath = @"mnist-8.onnx";
-using var fileStream = File.OpenRead(onnxFilePath);
-
-var model = Onnx.ModelProto.Parser.ParseFrom(fileStream);
+var model = ModelProto.Parser.ParseFromFile(onnxInputFilePath);
var graph = model.Graph;
-var inputs = graph.Input;
-var valueInfos = graph.ValueInfo;
-var outputs = graph.Output;
+// Clean graph e.g. remove initializers from inputs that may prevent constant folding
+graph.Clean();
+// Set dimension in graph to enable dynamic batch size during inference
+graph.SetDim(dimIndex: 0, DimParamOrValue.New("N"));
+// Get summarized info about the graph
+var info = graph.Info();
+
+System.Console.WriteLine(info);
+
+model.WriteToFile(@"mnist-8-clean-dynamic-batch-size.onnx");
```
\ No newline at end of file
diff --git a/src/Directory.Build.props b/src/Directory.Build.props
index 6e5e320..c01411a 100644
--- a/src/Directory.Build.props
+++ b/src/Directory.Build.props
@@ -4,8 +4,8 @@
nietras
Copyright © nietras 2021
en
- 0.1.0.0
- 0.1.1
+ 0.2.0.0
+ 0.2.0
$(FileVersion)
$(InformationalVersion)
diff --git a/src/OnnxSharp.Test/AssemblyResourceLoader.cs b/src/OnnxSharp.Test/AssemblyResourceLoader.cs
index 6bb431a..b52ffed 100644
--- a/src/OnnxSharp.Test/AssemblyResourceLoader.cs
+++ b/src/OnnxSharp.Test/AssemblyResourceLoader.cs
@@ -7,14 +7,23 @@ namespace OnnxSharp.Test
{
public static class AssemblyResourceLoader
{
- public static readonly string ResourceNamespace = typeof(AssemblyResourceLoader).Assembly.GetName().Name;
+ public static readonly string ResourceNamespace =
+ typeof(AssemblyResourceLoader).Assembly.GetName().Name;
public const string ResourceNamePrefix = "";
- public static string[] GetStringArray(string resourceName)
+ public static byte[] GetBytes(string resourceName)
{
- return GetString(resourceName).Split(new[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries);
+ using (var stream = GetStream(resourceName))
+ using (var memoryStream = new MemoryStream())
+ {
+ stream.CopyTo(memoryStream);
+ return memoryStream.ToArray();
+ }
}
+ public static string[] GetLines(string resourceName) => GetString(resourceName)
+ .Split(new[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries);
+
public static string GetString(string resourceName)
{
using (var stream = GetStream(resourceName))
@@ -24,10 +33,8 @@ public static string GetString(string resourceName)
}
}
- public static string GetFullResourceName(string resourceName)
- {
- return ResourceNamePrefix + resourceName;
- }
+ public static string GetFullResourceName(string resourceName) =>
+ ResourceNamePrefix + resourceName;
public static string FindResourceName(Func filter)
{
@@ -35,14 +42,18 @@ public static string FindResourceName(Func filter)
if (names.Length == 0)
{
- throw new ArgumentException("Could not find any resource. The desired file might not have been defined as Embedded Resource.");
+ throw new ArgumentException("Could not find any resource. " +
+ "The desired file might not have been defined as Embedded Resource.");
}
else if (names.Length != 1)
{
- throw new ArgumentException($"Ambiguous name, cannot identify resource - found {names.Length} possible candidates.");
+ throw new ArgumentException($"Ambiguous name, cannot identify resource - " +
+ $"found {names.Length} possible candidates.");
+ }
+ else
+ {
+ return names[0];
}
-
- return names.Single();
}
public static string[] FindResourceNames(Func filter)
@@ -66,7 +77,8 @@ public static Stream GetStream(string resourceName)
var stream = Assembly.GetExecutingAssembly().GetManifestResourceStream(fullResourceName);
if (stream == null)
{
- throw new ArgumentException($"Could not find resource '{resourceName}'. The desired file might not have been defined as Embedded Resource.");
+ throw new ArgumentException($"Could not find resource '{resourceName}'. " +
+ $"The desired file might not have been defined as Embedded Resource.");
}
return stream;
}
diff --git a/src/OnnxSharp.Test/GraphExtensionsTest.cs b/src/OnnxSharp.Test/GraphExtensionsTest.cs
new file mode 100644
index 0000000..17dd639
--- /dev/null
+++ b/src/OnnxSharp.Test/GraphExtensionsTest.cs
@@ -0,0 +1,171 @@
+using System;
+using System.IO;
+using Google.Protobuf;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Onnx;
+
+namespace OnnxSharp.Test
+{
+ [TestClass]
+ public class GraphExtensionsTest
+ {
+ readonly Func m_createStream = () => AssemblyResourceLoader.GetStream("mnist-8.onnx");
+
+ [TestMethod]
+ public void ParseFrom()
+ {
+ // Act
+ var model = ModelProto.Parser.ParseFrom(m_createStream);
+
+ // Assert
+ var graph = model.Graph;
+ // 9 inputs since includes initializers
+ Assert.AreEqual(9, graph.Input.Count);
+ Assert.AreEqual(1, graph.Output.Count);
+ }
+
+ [TestMethod]
+ public void Info()
+ {
+ // Arrange
+ var model = ModelProto.Parser.ParseFrom(m_createStream);
+
+ // Act
+ var actual = model.Graph.Info();
+
+ // Assert
+ var expected = ExpectedInfo;
+ Assert.AreEqual(expected, actual);
+ }
+
+ [TestMethod]
+ public void Clean()
+ {
+ // Arrange
+ var model = ModelProto.Parser.ParseFrom(m_createStream);
+
+ // Act
+ model.Graph.Clean();
+
+ // Assert
+ var graph = model.Graph;
+ Assert.AreEqual(1, graph.Input.Count);
+ Assert.AreEqual(1, graph.Output.Count);
+ var expectedName = $"mnist-8-expected-{nameof(Clean)}.onnx";
+ AssertModelBytesEqualToEmbeddedExpected(model, expectedName);
+ }
+
+ [TestMethod]
+ public void RemoveInitializersFromInputs()
+ {
+ // Arrange
+ var model = ModelProto.Parser.ParseFrom(m_createStream);
+
+ // Act
+ model.Graph.RemoveInitializersFromInputs();
+
+ // Assert
+ var graph = model.Graph;
+ Assert.AreEqual(1, graph.Input.Count);
+ Assert.AreEqual(1, graph.Output.Count);
+ var expectedName = $"mnist-8-expected-{nameof(RemoveInitializersFromInputs)}.onnx";
+ AssertModelBytesEqualToEmbeddedExpected(model, expectedName);
+ }
+
+ [TestMethod]
+ public void RemoveUnnecessaryInitializerReshapes()
+ {
+ // Arrange
+ var model = ModelProto.Parser.ParseFrom(m_createStream);
+
+ // Act
+ model.Graph.RemoveUnnecessaryInitializerReshapes();
+
+ // Assert
+ var graph = model.Graph;
+ Assert.AreEqual(8, graph.Input.Count);
+ Assert.AreEqual(1, graph.Output.Count);
+ var expectedName = $"mnist-8-expected-{nameof(RemoveUnnecessaryInitializerReshapes)}.onnx";
+ AssertModelBytesEqualToEmbeddedExpected(model, expectedName);
+ }
+
+ [TestMethod]
+ public void SetDim()
+ {
+ // Arrange
+ var model = ModelProto.Parser.ParseFrom(m_createStream);
+
+ // Act
+ model.Graph.SetDim(dimIndex: 0, DimParamOrValue.New("N"));
+
+ // Assert
+ var graph = model.Graph;
+ Assert.AreEqual(9, graph.Input.Count);
+ Assert.AreEqual(1, graph.Output.Count);
+ var expectedName = $"mnist-8-expected-{nameof(SetDim)}.onnx";
+ AssertModelBytesEqualToEmbeddedExpected(model, expectedName);
+ }
+
+ static void AssertModelBytesEqualToEmbeddedExpected(ModelProto model, string expectedName)
+ {
+ var actualBytes = model.ToByteArray();
+ //model.WriteToFile(expectedName);
+ var expectedBytes = AssemblyResourceLoader.GetBytes(expectedName);
+ CollectionAssert.AreEqual(expectedBytes, actualBytes);
+ }
+
+ const string ExpectedInfo = @"## Inputs without Initializer
+### Tensors
+|Name |Type |ElemType|Shape |SizeInFile|
+|:-----|:---------|:-------|--------:|---------:|
+|Input3|TensorType|Float |1x1x28x28| 32|
+
+## Outputs
+### Tensors
+|Name |Type |ElemType|Shape|SizeInFile|
+|:---------------|:---------|:-------|----:|---------:|
+|Plus214_Output_0|TensorType|Float | 1x10| 34|
+
+## Inputs with Initializer
+### Tensors
+|Name |Type |ElemType|Shape |SizeInFile|
+|:---------------------------------|:---------|:-------|--------:|---------:|
+|Parameter5 |TensorType|Float | 8x1x5x5| 36|
+|Parameter6 |TensorType|Float | 8x1x1| 32|
+|Parameter87 |TensorType|Float | 16x8x5x5| 37|
+|Parameter88 |TensorType|Float | 16x1x1| 33|
+|Pooling160_Output_0_reshape0_shape|TensorType|Int64 | 2| 48|
+|Parameter193 |TensorType|Float |16x4x4x10| 38|
+|Parameter193_reshape1_shape |TensorType|Int64 | 2| 41|
+|Parameter194 |TensorType|Float | 1x10| 30|
+
+## Initializers (Parameters etc.)
+|Name |DataType|Dims |Π(Dims)|[v0,v1..vN] | (Min,Mean,Max) |SizeInFile|
+|:---------------------------------|:-------|--------:|------:|-----------------------------------:|---------:|
+|Parameter193 |Float |16x4x4x10| 2560|(-7.595E-001,-1.779E-003,1.186E+000)| 10265|
+|Parameter87 |Float | 16x8x5x5| 3200|(-5.089E-001,-3.028E-002,5.647E-001)| 12824|
+|Parameter5 |Float | 8x1x5x5| 200|(-9.727E-001,-7.360E-003,1.019E+000)| 823|
+|Parameter6 |Float | 8x1x1| 8|(-4.338E-001,-1.023E-001,9.164E-002)| 53|
+|Parameter88 |Float | 16x1x1| 16|(-4.147E-001,-1.554E-001,1.328E-002)| 86|
+|Pooling160_Output_0_reshape0_shape|Int64 | 2| 2| [1,256]| 46|
+|Parameter193_reshape1_shape |Int64 | 2| 2| [256,10]| 39|
+|Parameter194 |Float | 1x10| 10|(-1.264E-001,-4.777E-006,1.402E-001)| 62|
+
+## Value Infos
+### Tensors
+|Name |Type |ElemType|Shape |SizeInFile|
+|:---------------------------|:---------|:-------|---------:|---------:|
+|Parameter193_reshape1 |TensorType|Float | 256x10| 40|
+|Convolution28_Output_0 |TensorType|Float | 1x8x28x28| 48|
+|Plus30_Output_0 |TensorType|Float | 1x8x28x28| 41|
+|ReLU32_Output_0 |TensorType|Float | 1x8x28x28| 41|
+|Pooling66_Output_0 |TensorType|Float | 1x8x14x14| 44|
+|Convolution110_Output_0 |TensorType|Float |1x16x14x14| 49|
+|Plus112_Output_0 |TensorType|Float |1x16x14x14| 42|
+|ReLU114_Output_0 |TensorType|Float |1x16x14x14| 42|
+|Pooling160_Output_0 |TensorType|Float | 1x16x4x4| 45|
+|Pooling160_Output_0_reshape0|TensorType|Float | 1x256| 47|
+|Times212_Output_0 |TensorType|Float | 1x10| 35|
+";
+ }
+}
diff --git a/src/OnnxSharp.Test/MnistTest.cs b/src/OnnxSharp.Test/MnistTest.cs
deleted file mode 100644
index 818b551..0000000
--- a/src/OnnxSharp.Test/MnistTest.cs
+++ /dev/null
@@ -1,41 +0,0 @@
-using Microsoft.VisualStudio.TestTools.UnitTesting;
-using Onnx;
-
-namespace OnnxSharp.Test
-{
- [TestClass]
- public class MnistTest
- {
- [TestMethod]
- public void ParseFrom()
- {
- // Arrange
- using var stream = AssemblyResourceLoader.GetStream("mnist-8.onnx");
-
- // Act
- var model = ModelProto.Parser.ParseFrom(stream);
-
- // Assert
- var graph = model.Graph;
- // 9 inputs since includes initializers
- Assert.AreEqual(9, graph.Input.Count);
- Assert.AreEqual(1, graph.Output.Count);
- }
-
- [TestMethod]
- public void RemoveInitializersFromInputs()
- {
- // Arrange
- using var stream = AssemblyResourceLoader.GetStream("mnist-8.onnx");
- var model = ModelProto.Parser.ParseFrom(stream);
-
- // Act
- model.RemoveInitializersFromInputs();
-
- // Assert
- var graph = model.Graph;
- Assert.AreEqual(1, graph.Input.Count);
- Assert.AreEqual(1, graph.Output.Count);
- }
- }
-}
diff --git a/src/OnnxSharp.Test/ModelProtoTestExtensions.cs b/src/OnnxSharp.Test/ModelProtoTestExtensions.cs
new file mode 100644
index 0000000..8a54169
--- /dev/null
+++ b/src/OnnxSharp.Test/ModelProtoTestExtensions.cs
@@ -0,0 +1,25 @@
+using System.IO;
+using System.Text.Json;
+using Google.Protobuf;
+using Onnx;
+
+namespace OnnxSharp.Test
+{
+ public static class ModelProtoTestExtensions
+ {
+ public static void WriteIndentedJsonToFile(this ModelProto model, string filePath)
+ {
+ var jsonText = JsonFormatter.Default.Format(model);
+ var jsonElement = JsonSerializer.Deserialize(jsonText);
+
+ var options = new JsonSerializerOptions() { WriteIndented = true };
+
+ var jsonTextPretty = JsonSerializer.Serialize(jsonElement, options);
+ File.WriteAllText(filePath, jsonTextPretty);
+ // Below does not indent
+ //using var stream = File.Open(filePath, FileMode.Create);
+ //using var writer = new Utf8JsonWriter(stream);
+ //JsonSerializer.Serialize(writer, jsonElement, options);
+ }
+ }
+}
diff --git a/src/OnnxSharp.Test/OnnxSharp.Test.csproj b/src/OnnxSharp.Test/OnnxSharp.Test.csproj
index 2f5cac4..94a953e 100644
--- a/src/OnnxSharp.Test/OnnxSharp.Test.csproj
+++ b/src/OnnxSharp.Test/OnnxSharp.Test.csproj
@@ -7,12 +7,12 @@
- net461;net50
+ net461;net5.0
-
+
@@ -25,6 +25,10 @@
+
+
+
+
diff --git a/src/OnnxSharp.Test/mnist-8-expected-Clean.onnx b/src/OnnxSharp.Test/mnist-8-expected-Clean.onnx
new file mode 100644
index 0000000..d708e07
Binary files /dev/null and b/src/OnnxSharp.Test/mnist-8-expected-Clean.onnx differ
diff --git a/src/OnnxSharp.Test/mnist-8-expected-RemoveInitializersFromInputs.onnx b/src/OnnxSharp.Test/mnist-8-expected-RemoveInitializersFromInputs.onnx
new file mode 100644
index 0000000..0312478
Binary files /dev/null and b/src/OnnxSharp.Test/mnist-8-expected-RemoveInitializersFromInputs.onnx differ
diff --git a/src/OnnxSharp.Test/mnist-8-expected-RemoveUnnecessaryInitializerReshapes.onnx b/src/OnnxSharp.Test/mnist-8-expected-RemoveUnnecessaryInitializerReshapes.onnx
new file mode 100644
index 0000000..cd01ee7
Binary files /dev/null and b/src/OnnxSharp.Test/mnist-8-expected-RemoveUnnecessaryInitializerReshapes.onnx differ
diff --git a/src/OnnxSharp.Test/mnist-8-expected-SetDim.onnx b/src/OnnxSharp.Test/mnist-8-expected-SetDim.onnx
new file mode 100644
index 0000000..5d5e8fb
Binary files /dev/null and b/src/OnnxSharp.Test/mnist-8-expected-SetDim.onnx differ
diff --git a/src/OnnxSharp/Collections/ListExtensions.cs b/src/OnnxSharp/Collections/ListExtensions.cs
new file mode 100644
index 0000000..5fe0d6f
--- /dev/null
+++ b/src/OnnxSharp/Collections/ListExtensions.cs
@@ -0,0 +1,51 @@
+using System;
+using System.Collections.Generic;
+
+namespace Onnx.Collections
+{
+ /// Convenience extension methods for .
+ internal static class ListExtensions
+ {
+ internal static bool TryRemove(this IList fields, Func select, Predicate predicate)
+ {
+ for (int i = 0; i < fields.Count; i++)
+ {
+ var field = fields[i];
+ var value = select(field);
+ if (predicate(value))
+ {
+ fields.RemoveAt(i);
+ return true;
+ }
+ }
+ return false;
+ }
+
+ internal static bool TryRemove(this IList fields, Func select, TSelect valueToRemove)
+ where TSelect : IEquatable
+ {
+ var index = fields.IndexOf(select, valueToRemove);
+ if (index >= 0)
+ {
+ fields.RemoveAt(index);
+ return true;
+ }
+ return false;
+ }
+
+ internal static int IndexOf(this IList fields, Func select, TSelect valueToFind)
+ where TSelect : IEquatable
+ {
+ for (int i = 0; i < fields.Count; i++)
+ {
+ var field = fields[i];
+ var value = select(field);
+ if (value.Equals(valueToFind))
+ {
+ return i;
+ }
+ }
+ return -1;
+ }
+ }
+}
diff --git a/src/OnnxSharp/Collections/ReadOnlyListExtensions.cs b/src/OnnxSharp/Collections/ReadOnlyListExtensions.cs
new file mode 100644
index 0000000..b4f8612
--- /dev/null
+++ b/src/OnnxSharp/Collections/ReadOnlyListExtensions.cs
@@ -0,0 +1,35 @@
+using System;
+using System.Collections.Generic;
+
+namespace Onnx.Collections
+{
+ /// Convenience extension methods for .
+ public static class ReadOnlyListExtensions
+ {
+ /// Compute the product of all values.
+ public static long Product(this IReadOnlyList values)
+ {
+ var product = 1L;
+ for (int i = 0; i < values.Count; i++)
+ {
+ product *= values[i];
+ }
+ return product;
+ }
+
+ internal static T Single(this IReadOnlyList fields, Func select, TSelect valueToFind)
+ where TSelect : IEquatable
+ {
+ for (int i = 0; i < fields.Count; i++)
+ {
+ var field = fields[i];
+ var value = select(field);
+ if (value.Equals(valueToFind))
+ {
+ return field;
+ }
+ }
+ throw new ArgumentException($"Could not find field with value '{valueToFind}'");
+ }
+ }
+}
diff --git a/src/OnnxSharp/DimParamOrValue.cs b/src/OnnxSharp/DimParamOrValue.cs
new file mode 100644
index 0000000..64b4569
--- /dev/null
+++ b/src/OnnxSharp/DimParamOrValue.cs
@@ -0,0 +1,66 @@
+using System;
+
+namespace Onnx
+{
+ ///
+ /// Dimension represented either a string 'Param' or an integer 'Value'.
+ ///
+ public readonly struct DimParamOrValue
+ {
+ readonly string _param;
+ readonly int _value;
+
+ private DimParamOrValue(string param, int value)
+ {
+ _param = param;
+ _value = value;
+ }
+
+ /// Create a new named dimension parameter.
+ public static DimParamOrValue New(string param)
+ {
+ if (!IsParamValid(param))
+ {
+ throw new ArgumentException($"{nameof(param)} '{param}' must be a non-whitespace string like 'N'.");
+ }
+ return new DimParamOrValue(param, default);
+ }
+
+ /// Create a new fixed size dimension.
+ public static DimParamOrValue New(int value) =>
+ new DimParamOrValue(default, value);
+
+ /// Get dimension as a named parameter string.
+ public string Param { get { CheckIsParam(); return _param; } }
+ /// Get dimension as an integer value.
+ public int Value { get { CheckIsValue(); return _value; } }
+
+ /// Is the dimension a named parameter.
+ public bool IsParam => IsParamValid(_param);
+ /// Is the dimension a fixed sized integer.
+ public bool IsValue => !IsParam;
+
+ /// Converts the dimension to its equivalent string representation.
+ public override string ToString() => IsParam ? Param : Value.ToString();
+ /// Returns the hash code for this instance.
+ public override int GetHashCode() => IsParam ? Param.GetHashCode() : Value.GetHashCode();
+
+ void CheckIsParam()
+ {
+ if (IsValue)
+ {
+ throw new ArgumentException($"{nameof(DimParamOrValue)} is a value '{_value}' not a param.");
+ }
+ }
+
+ void CheckIsValue()
+ {
+ if (IsParam)
+ {
+ throw new ArgumentException($"{nameof(DimParamOrValue)} is a param '{_param}' not a value.");
+ }
+ }
+
+ static bool IsParamValid(string param) => !string.IsNullOrWhiteSpace(param);
+ }
+}
diff --git a/src/OnnxSharp/Formatting/Align.cs b/src/OnnxSharp/Formatting/Align.cs
new file mode 100644
index 0000000..945be16
--- /dev/null
+++ b/src/OnnxSharp/Formatting/Align.cs
@@ -0,0 +1,8 @@
+namespace Onnx.Formatting
+{
+ internal enum Align
+ {
+ Left,
+ Right,
+ }
+}
\ No newline at end of file
diff --git a/src/OnnxSharp/Formatting/ColumnSpec.cs b/src/OnnxSharp/Formatting/ColumnSpec.cs
new file mode 100644
index 0000000..2d91ee4
--- /dev/null
+++ b/src/OnnxSharp/Formatting/ColumnSpec.cs
@@ -0,0 +1,14 @@
+using System;
+
+namespace Onnx.Formatting
+{
+ internal record ColumnSpec(string Name, Align Align);
+ internal record ColumnSpec(string Name, Align Align, Func Get) : ColumnSpec(Name, Align);
+}
+
+// https://stackoverflow.com/questions/64749385/predefined-type-system-runtime-compilerservices-isexternalinit-is-not-defined
+namespace System.Runtime.CompilerServices
+{
+ internal static class IsExternalInit { }
+}
+
diff --git a/src/OnnxSharp/Formatting/ColumnSpecs.cs b/src/OnnxSharp/Formatting/ColumnSpecs.cs
new file mode 100644
index 0000000..4f9dc5a
--- /dev/null
+++ b/src/OnnxSharp/Formatting/ColumnSpecs.cs
@@ -0,0 +1,177 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.InteropServices;
+using Google.Protobuf;
+using Onnx.Collections;
+
+namespace Onnx.Formatting
+{
+ internal static partial class ColumnSpecs
+ {
+ internal static partial class ValueInfo
+ {
+ internal static readonly IReadOnlyList> Tensor =
+ new ColumnSpec[]
+ {
+ new ("Name", Align.Left, i => i.Name),
+ new ("Type", Align.Left, i => i.Type.ValueCase.ToString()),
+ new ("ElemType", Align.Left, i => i.Type.TensorType.ElemType().ToString()),
+ new ("Shape", Align.Right, i => FormatShape(i.Type.TensorType.Shape)),
+ new ("SizeInFile", Align.Right, i => i.CalculateSize().ToString()),
+ };
+
+ internal static readonly IReadOnlyList> Sequence =
+ new ColumnSpec[]
+ {
+ new ("Name", Align.Left, i => i.Name),
+ new ("Type", Align.Left, i => i.Type.ValueCase.ToString()),
+ new ("ElemType", Align.Left, i => i.Type.SequenceType.ElemType.ValueCase.ToString()),
+ new ("SizeInFile", Align.Left, i => i.CalculateSize().ToString()),
+ };
+
+ internal static readonly IReadOnlyList> Map =
+ new ColumnSpec[]
+ {
+ new ("Name", Align.Left, i => i.Name),
+ new ("Type", Align.Left, i => i.Type.ValueCase.ToString()),
+ new ("KeyType", Align.Left, i => i.Type.MapType.KeyType().ToString()),
+ new ("ValueType", Align.Left, i => i.Type.MapType.ValueType.ValueCase.ToString()),
+ new ("SizeInFile", Align.Left, i => i.CalculateSize().ToString()),
+ };
+
+ internal static readonly IReadOnlyList> None =
+ new ColumnSpec[]
+ {
+ new ("Name", Align.Left, i => i.Name),
+ new ("Type", Align.Left, i => i.Type.ValueCase.ToString()),
+ new ("SizeInFile", Align.Left, i => i.CalculateSize().ToString()),
+ };
+ }
+
+ internal static readonly IReadOnlyList> Tensor =
+ new ColumnSpec[]
+ {
+ new ("Name", Align.Left, t => t.Name),
+ new ("DataType", Align.Left, t => t.DataType().ToString()),
+ new ("Dims", Align.Right, t => string.Join("x", t.Dims)),
+ new ("Π(Dims)", Align.Right, t => t.Dims.Product().ToString()),
+ new ("[v0,v1..vN] | (Min,Mean,Max)", Align.Right, t => FormatValuesOrStats(t)),
+ new ("SizeInFile", Align.Right, t => t.CalculateSize().ToString()),
+ };
+
+ static string FormatShape(TensorShapeProto shape)
+ {
+ return string.Join("x", shape.Dim.Select(d => Format(d)));
+ }
+
+ static string Format(TensorShapeProto.Types.Dimension d) => d.ValueCase switch
+ {
+ TensorShapeProto.Types.Dimension.ValueOneofCase.DimParam => d.DimParam,
+ TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue => d.DimValue.ToString(),
+ TensorShapeProto.Types.Dimension.ValueOneofCase.None => "?",
+ _ => throw new NotSupportedException(d.ValueCase.ToString()),
+ };
+
+ static unsafe string FormatValuesOrStats(TensorProto tensor) => tensor.DataType() switch
+ {
+ // NOTE: Long lines accepted below for structure
+ TensorProto.Types.DataType.Float => FormatValuesOrStats(tensor.FloatData, tensor.RawData, &Math.Min, (m, v) => m + v, (m, c) => m / c, &Math.Max),
+ TensorProto.Types.DataType.Double => FormatValuesOrStats(tensor.DoubleData, tensor.RawData, &Math.Min, (m, v) => m + v, (m, c) => m / c, &Math.Max),
+ TensorProto.Types.DataType.Int32 => FormatValuesOrStats(tensor.Int32Data, tensor.RawData, &Math.Min, (m, v) => m + v, (m, c) => m / c, &Math.Max),
+ TensorProto.Types.DataType.Int64 => FormatValuesOrStats(tensor.Int64Data, tensor.RawData, &Math.Min, (m, v) => m + v, (m, c) => m / c, &Math.Max),
+ // TODO: StringData
+ _ => "N/A",
+ };
+
+ // NOTE: Perf below is not great since function pointer and func calls cannot be inlined.
+ // If necessary refactor to use "value type functor"s.
+ static unsafe string FormatValuesOrStats(
+ IReadOnlyList values,
+ ByteString rawData,
+ delegate* min,
+ Func add,
+ Func divide,
+ delegate* max)
+ where T : struct
+ {
+ // Data may not be in typed part but in raw data
+ // Unfortunately there is no common and efficient "ground" for
+ // "IReadOnlyList values" and "ByteString rawData",
+ // so we have to go through hoops.
+ // RawData and talk about Constant nodes
+ // https://github.com/onnx/onnx/issues/2825#issuecomment-644334359
+
+ var useRawData = values.Count == 0 && rawData.Length > 0;
+ var rawValues = MemoryMarshal.Cast(rawData.Span);
+ var count = useRawData ? rawValues.Length : values.Count;
+
+ const int MaxValueCountToShow = 4;
+ if (count <= MaxValueCountToShow)
+ {
+ return useRawData
+ ? FormatValues(rawValues.ToArray())
+ : FormatValues(values);
+ }
+ else if (count > 0)
+ {
+ if (useRawData) { Thrower.EnsureLittleEndian(); }
+ var stats = useRawData
+ ? GetStats(rawValues, min, add, divide, max)
+ : GetStats(values, min, add, divide, max);
+
+ return $"({stats.min:E3},{stats.mean:E3},{stats.max:E3})";
+ }
+ else
+ {
+ return "[]";
+ }
+ }
+
+ static string FormatValues(IReadOnlyList values) => $"[{string.Join(",", values)}]";
+
+ static unsafe (T min, TMean mean, T max) GetStats(
+ ReadOnlySpan values,
+ delegate* min,
+ Func add,
+ Func divide,
+ delegate* max)
+ where T : struct
+ {
+ T minValue = values[0];
+ T maxValue = values[0];
+ TMean sum = add(default, values[0]);
+ for (int i = 1; i < values.Length; i++)
+ {
+ var value = values[i];
+ minValue = min(minValue, value);
+ maxValue = max(maxValue, value);
+ sum = add(sum, value);
+ }
+ var mean = divide(sum, values.Length);
+ return (minValue, mean, maxValue);
+ }
+
+ static unsafe (T min, TMean mean, T max) GetStats(
+ IReadOnlyList values,
+ delegate* min,
+ Func add,
+ Func divide,
+ delegate* max)
+ where T : struct
+ {
+ T minValue = values[0];
+ T maxValue = values[0];
+ TMean sum = add(default, values[0]);
+ for (int i = 1; i < values.Count; i++)
+ {
+ var value = values[i];
+ minValue = min(minValue, value);
+ maxValue = max(maxValue, value);
+ sum = add(sum, value);
+ }
+ var mean = divide(sum, values.Count);
+ return (minValue, mean, maxValue);
+ }
+ }
+}
diff --git a/src/OnnxSharp/Formatting/MarkdownFormatter.cs b/src/OnnxSharp/Formatting/MarkdownFormatter.cs
new file mode 100644
index 0000000..d78f9ba
--- /dev/null
+++ b/src/OnnxSharp/Formatting/MarkdownFormatter.cs
@@ -0,0 +1,117 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+
+namespace Onnx.Formatting
+{
+ internal static class MarkdownFormatter
+ {
+ internal static void FormatAsTensors(this IReadOnlyList valueInfos, TextWriter writer)
+ {
+ Format(valueInfos, ColumnSpecs.ValueInfo.Tensor, writer);
+ }
+
+ internal static void FormatAsSequences(this IReadOnlyList valueInfos, TextWriter writer)
+ {
+ Format(valueInfos, ColumnSpecs.ValueInfo.Sequence, writer);
+ }
+
+ internal static void FormatAsMaps(this IReadOnlyList valueInfos, TextWriter writer)
+ {
+ Format(valueInfos, ColumnSpecs.ValueInfo.Map, writer);
+ }
+
+ internal static void FormatAsNones(this IReadOnlyList valueInfos, TextWriter writer)
+ {
+ Format(valueInfos, ColumnSpecs.ValueInfo.None, writer);
+ }
+
+ internal static void Format(this IReadOnlyList summaries, TextWriter writer)
+ {
+ Format(summaries, ColumnSpecs.Tensor, writer);
+ }
+
+ internal static void Format(
+ IReadOnlyList values,
+ IReadOnlyList> columnSpecs,
+ TextWriter writer)
+ {
+ var maxColumnWidth = columnSpecs.Select(n => n.Name.Length).ToArray();
+
+ int rows = values.Count;
+ int cols = columnSpecs.Count;
+
+ var table = new string[rows, cols];
+ for (int row = 0; row < rows; row++)
+ {
+ var summary = values[row];
+
+ for (int col = 0; col < cols; col++)
+ {
+ var spec = columnSpecs[col];
+ var text = spec.Get(summary);
+ table[row, col] = text;
+ maxColumnWidth[col] = Math.Max(maxColumnWidth[col], text.Length);
+ }
+ }
+
+ Format(table, columnSpecs, maxColumnWidth, writer);
+ }
+
+ internal static void Format(
+ string[,] table,
+ IReadOnlyList columnSpecs,
+ IReadOnlyList columnWidths,
+ TextWriter writer)
+ {
+ // TODO: Define constants below
+
+ var rows = table.GetLength(0);
+ var cols = table.GetLength(1);
+
+ // Column Names
+ for (int col = 0; col < cols; col++)
+ {
+ var columnName = columnSpecs[col].Name;
+ writer.Write('|');
+ writer.Write(columnName);
+ writer.Write(' ', columnWidths[col] - columnName.Length);
+ }
+ writer.Write('|');
+ writer.WriteLine();
+
+ // Separator and alignment
+ for (int col = 0; col < cols; col++)
+ {
+ writer.Write('|');
+ var align = columnSpecs[col].Align;
+ if (align == Align.Left)
+ {
+ writer.Write(':');
+ }
+ writer.Write('-', columnWidths[col] - 1);
+ if (align == Align.Right)
+ {
+ writer.Write(':');
+ }
+ }
+ writer.Write('|');
+ writer.WriteLine();
+
+ // Rows
+ for (int row = 0; row < rows; row++)
+ {
+ for (int col = 0; col < cols; col++)
+ {
+ var align = columnSpecs[col].Align;
+ var value = table[row, col];
+ writer.Write('|');
+ writer.WriteAligned(value, align, ' ', columnWidths[col]);
+ }
+ writer.Write('|');
+ writer.WriteLine();
+ }
+ }
+ }
+}
diff --git a/src/OnnxSharp/Formatting/TextWriterExtensions.cs b/src/OnnxSharp/Formatting/TextWriterExtensions.cs
new file mode 100644
index 0000000..8991cbe
--- /dev/null
+++ b/src/OnnxSharp/Formatting/TextWriterExtensions.cs
@@ -0,0 +1,30 @@
+using System.IO;
+
+namespace Onnx.Formatting
+{
+ internal static class TextWriterExtensions
+ {
+ internal static void WriteAligned(this TextWriter writer,
+ string columnName, Align alignment, char pad, int width)
+ {
+ var padCount = width - columnName.Length;
+ if (alignment == Align.Right)
+ {
+ writer.Write(pad, padCount);
+ }
+ writer.Write(columnName);
+ if (alignment == Align.Left)
+ {
+ writer.Write(pad, padCount);
+ }
+ }
+
+ internal static void Write(this TextWriter writer, char value, int repeatCount)
+ {
+ for (int i = 0; i < repeatCount; i++)
+ {
+ writer.Write(value);
+ }
+ }
+ }
+}
diff --git a/src/OnnxSharp/GraphExtensions.Clean.cs b/src/OnnxSharp/GraphExtensions.Clean.cs
new file mode 100644
index 0000000..9720f12
--- /dev/null
+++ b/src/OnnxSharp/GraphExtensions.Clean.cs
@@ -0,0 +1,130 @@
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using Google.Protobuf.Collections;
+using Onnx.Collections;
+
+namespace Onnx
+{
+ /// Extension methods to ONNX graph.
+ public static partial class GraphExtensions
+ {
+ /// Clean graph for inference.
+ public static void Clean(this GraphProto graph)
+ {
+ graph.RemoveInitializersFromInputs();
+ graph.RemoveUnnecessaryInitializerReshapes();
+ }
+
+ /// Remove initializers from inputs of graph.
+ // https://github.com/microsoft/onnxruntime/blob/master/tools/python/remove_initializer_from_input.py
+ public static void RemoveInitializersFromInputs(this GraphProto graph)
+ {
+ var inputs = graph.Input;
+ var nameToInput = inputs.ToDictionary(i => i.Name, i => i);
+
+ foreach (var initializer in graph.Initializer)
+ {
+ if (nameToInput.TryGetValue(initializer.Name, out var input))
+ {
+ // https://github.com/protocolbuffers/protobuf/blob/master/csharp/src/Google.Protobuf/Collections/RepeatedField.cs
+ var removed = inputs.Remove(input);
+ Trace.WriteLine($"{removed} {inputs.Count}");
+ }
+ }
+ }
+
+ /// Remove unnecessary initializer reshapes from graph.
+ // https://github.com/microsoft/onnxruntime/blob/master/tools/python/remove_initializer_from_input.py
+ public static void RemoveUnnecessaryInitializerReshapes(this GraphProto graph)
+ {
+ var nameToInitializer = graph.Initializer.ToDictionary(i => i.Name, i => i);
+
+ var nodes = graph.Node;
+ var valueInfos = graph.ValueInfo;
+
+ var nodesToRemove = new List();
+ for (int nodeIndex = 0; nodeIndex < nodes.Count; nodeIndex++)
+ {
+ var node = nodes[nodeIndex];
+
+ var opSpec = Ops.Reshape.Spec;
+ if (node.OpType == opSpec.OpType)
+ {
+ var inputs = node.Input;
+ var outputs = node.Output;
+
+ // Expected Reshape takes 2 inputs and has 1 output
+ if (inputs.Count == opSpec.Inputs && outputs.Count == opSpec.Outputs)
+ {
+ var dataName = inputs[0];
+ var shapeName = inputs[1];
+ var reshapeOutputName = outputs[0];
+
+ // Both inputs must be initializers ("static")
+ if (nameToInitializer.TryGetValue(dataName, out var dataInitializer) &&
+ nameToInitializer.TryGetValue(shapeName, out var shapeInitializer))
+ {
+ // TODO: Check initializer not used in other nodes
+
+ var outputShapeValue = valueInfos.Single(v => v.Name, reshapeOutputName);
+
+ var outputShapeDims = outputShapeValue.Type.TensorType.Shape.Dim;
+ var allValue = outputShapeDims.All(d => d.ValueCase ==
+ TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue);
+ if (allValue)
+ {
+ var outputShape = outputShapeDims.Select(d => d.DimValue).ToArray();
+
+ var allPositive = outputShape.All(d => d > 0);
+ if (allPositive)
+ {
+ // Check shape compared to initializer shape
+ var dataShape = dataInitializer.Dims.ToArray();
+
+ var outputShapeProductSum = outputShape.Product();
+ var dataShapeProductSum = dataShape.Product();
+
+ if (outputShapeProductSum == dataShapeProductSum)
+ {
+ // Change data shape to the reshape output shape directly
+ dataInitializer.Dims.Clear();
+ dataInitializer.Dims.AddRange(outputShape);
+
+ // Remove reshape data shape both as initializer and input
+ graph.Initializer.TryRemove(i => i.Name, shapeName);
+ graph.Input.TryRemove(i => i.Name, shapeName);
+
+ nodesToRemove.Add(node);
+
+ // Replace reshape output name with data name directly in all nodes
+ ReplaceInput(nodes, reshapeOutputName, dataName);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ foreach (var node in nodesToRemove)
+ {
+ nodes.Remove(node);
+ }
+ }
+
+ internal static void ReplaceInput(RepeatedField nodes, string oldValue, string newValue)
+ {
+ for (int nodeIndex = 0; nodeIndex < nodes.Count; nodeIndex++)
+ {
+ var updateNodeInputs = nodes[nodeIndex].Input;
+ for (int inputIndex = 0; inputIndex < updateNodeInputs.Count; inputIndex++)
+ {
+ if (updateNodeInputs[inputIndex] == oldValue)
+ {
+ updateNodeInputs[inputIndex] = newValue;
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/src/OnnxSharp/GraphExtensions.Info.cs b/src/OnnxSharp/GraphExtensions.Info.cs
new file mode 100644
index 0000000..2334250
--- /dev/null
+++ b/src/OnnxSharp/GraphExtensions.Info.cs
@@ -0,0 +1,71 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using Onnx.Formatting;
+
+namespace Onnx
+{
+ public static partial class GraphExtensions
+ {
+ /// Summarize information about the .
+ public static string Info(this GraphProto graph)
+ {
+ var writer = new StringWriter();
+ graph.Info(writer);
+ return writer.ToString();
+ }
+
+ /// Summarize information about the .
+ public static void Info(this GraphProto graph, TextWriter writer)
+ {
+ var initializerNameSet = new HashSet(graph.Initializer.Select(i => i.Name));
+ var inferenceInputs = graph.Input.Where(i => !initializerNameSet.Contains(i.Name)).ToList();
+ var initializerInputs = graph.Input.Where(i => initializerNameSet.Contains(i.Name)).ToList();
+
+ writer.WriteLine("## Inputs without Initializer");
+ Info(inferenceInputs, writer);
+
+ writer.WriteLine();
+ writer.WriteLine("## Outputs");
+ Info(graph.Output, writer);
+
+ writer.WriteLine();
+ writer.WriteLine("## Inputs with Initializer");
+ Info(initializerInputs, writer);
+
+ writer.WriteLine();
+ writer.WriteLine("## Initializers (Parameters etc.)");
+ MarkdownFormatter.Format(graph.Initializer, writer);
+
+ writer.WriteLine();
+ writer.WriteLine("## Value Infos");
+ Info(graph.ValueInfo, writer);
+ }
+
+ static void Info(IReadOnlyList valueInfos, TextWriter writer)
+ {
+ var tensorTypes = valueInfos.Where(i => i.Type.ValueCase == TypeProto.ValueOneofCase.TensorType).ToList();
+ WriteInfoIfAny(tensorTypes, "Tensors", MarkdownFormatter.FormatAsTensors, writer);
+
+ var sequenceTypes = valueInfos.Where(i => i.Type.ValueCase == TypeProto.ValueOneofCase.SequenceType).ToList();
+ WriteInfoIfAny(sequenceTypes, "Sequences", MarkdownFormatter.FormatAsSequences, writer);
+
+ var mapTypes = valueInfos.Where(i => i.Type.ValueCase == TypeProto.ValueOneofCase.MapType).ToList();
+ WriteInfoIfAny(mapTypes, "Maps", MarkdownFormatter.FormatAsMaps, writer);
+
+ var noneTypes = valueInfos.Where(i => i.Type.ValueCase == TypeProto.ValueOneofCase.None).ToList();
+ WriteInfoIfAny(noneTypes, "Nones", MarkdownFormatter.FormatAsNones, writer);
+ }
+
+ static void WriteInfoIfAny(IReadOnlyList values, string name,
+ Action, TextWriter> info, TextWriter writer)
+ {
+ if (values.Count > 0)
+ {
+ writer.WriteLine($"### {name}");
+ info(values, writer);
+ }
+ }
+ }
+}
diff --git a/src/OnnxSharp/GraphExtensions.SetDim.cs b/src/OnnxSharp/GraphExtensions.SetDim.cs
new file mode 100644
index 0000000..4ba2530
--- /dev/null
+++ b/src/OnnxSharp/GraphExtensions.SetDim.cs
@@ -0,0 +1,152 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Runtime.InteropServices;
+using Google.Protobuf;
+using Google.Protobuf.Collections;
+using Onnx.Collections;
+
+namespace Onnx
+{
+ public static partial class GraphExtensions
+ {
+ ///
+ /// Set dimension of inputs, value infos, outputs and potential Reshape ops.
+ /// Default sets leading dimension to dynamic batch size 'N'.
+ ///
+ public static void SetDim(this GraphProto graph) =>
+ graph.SetDim(dimIndex: 0, DimParamOrValue.New("N"));
+
+ ///
+ /// Set dimension of inputs, value infos, outputs and potential Reshape ops.
+ /// Can be used to make models have dynamic batch size or different static batch sizes.
+ ///
+ public static void SetDim(this GraphProto graph, int dimIndex, DimParamOrValue dimParamOrValue)
+ {
+ // Reshape ops have their "new shape" defined as input to the reshape op.
+ // This input needs to be changed to reflect new dim e.g. be set -1 if dynamic.
+ var reshapeDimValue = dimParamOrValue.IsParam
+ ? Ops.Reshape.DynamicReshapeValue
+ : dimParamOrValue.Value;
+ SetDimInReshapes(graph, dimIndex, reshapeDimValue);
+
+ // Should we set this based on nodes instead? Handling input, outputs based on that?
+
+ // Shapes are defined in inputs, valueInfos and outputs
+ //
+ // Only real inputs should be changed, not "initializer" inputs
+ var initializserNames = new HashSet(graph.Initializer.Select(i => i.Name));
+ var inferenceInputs = graph.Input.Where(i => !initializserNames.Contains(i.Name));
+ foreach (var input in inferenceInputs)
+ {
+ SetDim(input, dimIndex, dimParamOrValue);
+ }
+ //SetDim(graph.Input, dimIndex, dimParam);
+
+ SetDim(graph.ValueInfo, dimIndex, dimParamOrValue);
+ SetDim(graph.Output, dimIndex, dimParamOrValue);
+ }
+
+ static void SetDimInReshapes(GraphProto graph, int dimIndex, int dimValue)
+ {
+ var nodes = graph.Node;
+ var initializers = graph.Initializer;
+
+ // TODO: Only fix reshapes that have data input and with dynamic shape after
+
+ var opSpec = Ops.Reshape.Spec;
+ foreach (var node in nodes)
+ {
+ if (node.OpType == opSpec.OpType)
+ {
+ var dataInputName = node.Input[Ops.Reshape.InputDataIndex];
+
+ // Check if data input is an initializer if so we should not change the reshape
+ // and hence skip this reshape node
+ var dataInitializerIndex = initializers.IndexOf(t => t.Name, dataInputName);
+ if (dataInitializerIndex >= 0)
+ { continue; }
+
+ var shapeInputName = node.Input[Ops.Reshape.InputShapeIndex];
+
+ var shape = initializers.Single(tensor => tensor.Name, shapeInputName);
+
+ SetDimInReshapeTensorShape(shape, dimIndex, dimValue);
+ }
+ }
+ }
+
+ static void SetDimInReshapeTensorShape(TensorProto shape, int dimIndex, int dimValue)
+ {
+ Debug.Assert(shape.DataType == (int)TensorProto.Types.DataType.Int64);
+ var dims = shape.Dims;
+ if (dims.Count > 0 && dims[dimIndex] > 0)
+ {
+ // Data may be stored as Int64 or Raw (fixed-width, little-endian)
+ if (shape.Int64Data.Count > 0)
+ {
+ var int64Data = shape.Int64Data;
+ if (int64Data[dimIndex] == 1) // Dimension we replace
+ {
+ int64Data[dimIndex] = dimValue;
+ }
+ }
+ if (!shape.RawData.IsEmpty)
+ {
+ var rawData = shape.RawData;
+ var rawAsInt64Data = MemoryMarshal.Cast(rawData.Span);
+ Debug.Assert(rawAsInt64Data.Length == dims[dimIndex]);
+ if (rawAsInt64Data[dimIndex] == 1) // Dimension we replace
+ {
+ var newShape = rawAsInt64Data.ToArray();
+ newShape[dimIndex] = dimValue;
+ var newShapeBytes = MemoryMarshal.Cast(newShape.AsSpan());
+ shape.RawData = ByteString.CopyFrom(newShapeBytes);
+ }
+ }
+ }
+ }
+
+ internal static void SetDim(RepeatedField valueInfos,
+ int dimIndex, DimParamOrValue dimParamOrValue)
+ {
+ for (int i = 0; i < valueInfos.Count; i++)
+ {
+ var valueInfo = valueInfos[i];
+ SetDim(valueInfo, dimIndex, dimParamOrValue);
+ }
+ }
+
+ internal static void SetDim(ValueInfoProto valueInfo,
+ int dimIndex, DimParamOrValue dimParamOrValue)
+ {
+ var shape = valueInfo.Type.TensorType.Shape;
+ var dims = shape.Dim;
+ var dim = dims[dimIndex];
+ if (dim.ValueCase == TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue)
+ {
+ // TODO: Should perhaps be parameter that says
+ // bool shouldSetDimFor(dim)
+ if (dim.DimValue == 1)
+ {
+ SetDim(dim, dimParamOrValue);
+ }
+ }
+ }
+
+ internal static void SetDim(TensorShapeProto.Types.Dimension dim,
+ DimParamOrValue dimParamOrValue)
+ {
+ dim.ClearValue();
+ if (dimParamOrValue.IsParam)
+ {
+ dim.DimParam = dimParamOrValue.Param;
+ }
+ else
+ {
+ dim.DimValue = dimParamOrValue.Value;
+ }
+ }
+ }
+}
diff --git a/src/OnnxSharp/MessageExtensions.cs b/src/OnnxSharp/MessageExtensions.cs
new file mode 100644
index 0000000..20928b7
--- /dev/null
+++ b/src/OnnxSharp/MessageExtensions.cs
@@ -0,0 +1,29 @@
+using Google.Protobuf;
+using System.IO;
+
+namespace Onnx
+{
+ /// Convenience extension methods.
+ public static partial class MessageExtensions
+ {
+ ///
+ /// Writes the given data to the
+ /// given in protobuf encoding.
+ ///
+ public static void WriteToFile(this IMessage message, string filePath)
+ {
+ using var stream = File.Open(filePath, FileMode.Create);
+ message.WriteTo(stream);
+ }
+
+ ///
+ /// Writes the given data to the
+ /// given in JSON encoding.
+ ///
+ public static void WriteJsonToFile(this IMessage message, string filePath)
+ {
+ using var writer = new StreamWriter(filePath);
+ JsonFormatter.Default.Format(message, writer);
+ }
+ }
+}
diff --git a/src/OnnxSharp/MessageParserExtensions.cs b/src/OnnxSharp/MessageParserExtensions.cs
new file mode 100644
index 0000000..5673f24
--- /dev/null
+++ b/src/OnnxSharp/MessageParserExtensions.cs
@@ -0,0 +1,31 @@
+using System;
+using System.IO;
+using Google.Protobuf;
+
+namespace Onnx
+{
+ /// Convenience extension methods.
+ public static partial class MessageParserExtensions
+ {
+ ///
+ /// Parse from file via .
+ ///
+ public static T ParseFromFile(this MessageParser parser, string filePath)
+ where T : IMessage
+ {
+ using var stream = File.Open(filePath, FileMode.Open);
+ return parser.ParseFrom(stream);
+ }
+
+ ///
+ /// Parse from file via
+ /// and disposes the created stream after parsing is done.
+ ///
+ public static T ParseFrom(this MessageParser parser, Func createStream)
+ where T : IMessage
+ {
+ using var stream = createStream();
+ return parser.ParseFrom(stream);
+ }
+ }
+}
diff --git a/src/OnnxSharp/OnnxExtensions.cs b/src/OnnxSharp/OnnxExtensions.cs
deleted file mode 100644
index f7a5af7..0000000
--- a/src/OnnxSharp/OnnxExtensions.cs
+++ /dev/null
@@ -1,42 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-
-namespace Onnx
-{
- ///
- /// Extensions to ONNX protobuf functionality.
- ///
- public static partial class OnnxExtensions
- {
- ///
- /// Remove initializers from inputs of graph in model.
- ///
- public static ModelProto RemoveInitializersFromInputs(this ModelProto model)
- {
- RemoveInitializersFromInputs(model.Graph);
- return model;
- }
-
- ///
- /// Remove initializers from inputs of graph.
- ///
- // https://github.com/microsoft/onnxruntime/blob/master/tools/python/remove_initializer_from_input.py
- public static GraphProto RemoveInitializersFromInputs(GraphProto graph)
- {
- var inputs = graph.Input;
- var nameToInput = inputs.ToDictionary(i => i.Name, i => i);
-
- foreach (var initializer in graph.Initializer)
- {
- if (nameToInput.TryGetValue(initializer.Name, out var input))
- {
- inputs.Remove(input);
- }
- }
-
- return graph;
- }
- }
-}
diff --git a/src/OnnxSharp/OnnxSharp.csproj b/src/OnnxSharp/OnnxSharp.csproj
index 7dc0532..31b4ac1 100644
--- a/src/OnnxSharp/OnnxSharp.csproj
+++ b/src/OnnxSharp/OnnxSharp.csproj
@@ -20,6 +20,8 @@
https://github.com/nietras/OnnxSharp
true
true
+
+ true
diff --git a/src/OnnxSharp/Ops.cs b/src/OnnxSharp/Ops.cs
new file mode 100644
index 0000000..c405342
--- /dev/null
+++ b/src/OnnxSharp/Ops.cs
@@ -0,0 +1,33 @@
+using System;
+
+namespace Onnx
+{
+ internal static class Ops
+ {
+ internal static class Reshape
+ {
+ internal const int InputDataIndex = 0;
+ internal const int InputShapeIndex = 1;
+
+ // Reshape op supports only one dimension in shape to be dynamic,
+ // which is defined as -1.
+ internal const int DynamicReshapeValue = -1;
+
+ internal static readonly OpSpec Spec = new OpSpec(nameof(Reshape), 2, 1);
+ }
+
+ internal readonly struct OpSpec
+ {
+ public OpSpec(string opType, int inputs, int outputs)
+ {
+ OpType = opType ?? throw new ArgumentNullException(nameof(opType));
+ Inputs = inputs;
+ Outputs = outputs;
+ }
+
+ public string OpType { get; }
+ public int Inputs { get; }
+ public int Outputs { get; }
+ }
+ }
+}
diff --git a/src/OnnxSharp/TensorProtoExtensions.cs b/src/OnnxSharp/TensorProtoExtensions.cs
new file mode 100644
index 0000000..d025d59
--- /dev/null
+++ b/src/OnnxSharp/TensorProtoExtensions.cs
@@ -0,0 +1,10 @@
+namespace Onnx
+{
+ /// Convenience extension methods.
+ public static class TensorProtoExtensions
+ {
+ /// Get data type of as enum.
+ public static TensorProto.Types.DataType DataType(this TensorProto tensor) =>
+ (TensorProto.Types.DataType)tensor.DataType;
+ }
+}
diff --git a/src/OnnxSharp/Thrower.cs b/src/OnnxSharp/Thrower.cs
new file mode 100644
index 0000000..7b9121b
--- /dev/null
+++ b/src/OnnxSharp/Thrower.cs
@@ -0,0 +1,18 @@
+using System;
+
+namespace Onnx
+{
+ internal static class Thrower
+ {
+ internal static void EnsureLittleEndian()
+ {
+ if (!BitConverter.IsLittleEndian)
+ {
+ var message = "Only little-endian systems are supported. " +
+ "This is due to raw data in onnx files being stored in little-endian order and " +
+ "conversion to big-endian has not implemented.";
+ throw new NotSupportedException(message);
+ }
+ }
+ }
+}
diff --git a/src/OnnxSharp/TypeProtoTypesMapExtensions.cs b/src/OnnxSharp/TypeProtoTypesMapExtensions.cs
new file mode 100644
index 0000000..58e5add
--- /dev/null
+++ b/src/OnnxSharp/TypeProtoTypesMapExtensions.cs
@@ -0,0 +1,10 @@
+namespace Onnx
+{
+ /// Convenience extension methods.
+ public static class TypeProtoTypesMapExtensions
+ {
+ /// Get key data type of as enum.
+ public static TensorProto.Types.DataType KeyType(this TypeProto.Types.Map map) =>
+ (TensorProto.Types.DataType)map.KeyType;
+ }
+}
diff --git a/src/OnnxSharp/TypeProtoTypesSequenceExtensions.cs b/src/OnnxSharp/TypeProtoTypesSequenceExtensions.cs
new file mode 100644
index 0000000..b0f7b2e
--- /dev/null
+++ b/src/OnnxSharp/TypeProtoTypesSequenceExtensions.cs
@@ -0,0 +1,7 @@
+namespace Onnx
+{
+ /// Convenience extension methods.
+ public static class TypeProtoTypesSequenceExtensions
+ {
+ }
+}
diff --git a/src/OnnxSharp/TypeProtoTypesTensorExtensions.cs b/src/OnnxSharp/TypeProtoTypesTensorExtensions.cs
new file mode 100644
index 0000000..6835069
--- /dev/null
+++ b/src/OnnxSharp/TypeProtoTypesTensorExtensions.cs
@@ -0,0 +1,10 @@
+namespace Onnx
+{
+ /// Convenience extension methods.
+ public static class TypeProtoTypesTensorExtensions
+ {
+ /// Get element data type of as enum.
+ public static TensorProto.Types.DataType ElemType(this TypeProto.Types.Tensor tensor) =>
+ (TensorProto.Types.DataType)tensor.ElemType;
+ }
+}
diff --git a/src/OnnxSharp/ValueInfoProtoExtensions.cs b/src/OnnxSharp/ValueInfoProtoExtensions.cs
new file mode 100644
index 0000000..0f66a60
--- /dev/null
+++ b/src/OnnxSharp/ValueInfoProtoExtensions.cs
@@ -0,0 +1,7 @@
+namespace Onnx
+{
+ /// Convenience extension methods.
+ public static class ValueInfoProtoExtensions
+ {
+ }
+}
diff --git a/src/OnnxSharpConsole/OnnxSharpConsole.csproj b/src/OnnxSharpConsole/OnnxSharpConsole.csproj
index 4615a04..7969b60 100644
--- a/src/OnnxSharpConsole/OnnxSharpConsole.csproj
+++ b/src/OnnxSharpConsole/OnnxSharpConsole.csproj
@@ -8,7 +8,7 @@
Exe
- net50
+ net5.0
false
diff --git a/src/OnnxSharpConsole/Program.cs b/src/OnnxSharpConsole/Program.cs
index ed574a5..54a4fba 100644
--- a/src/OnnxSharpConsole/Program.cs
+++ b/src/OnnxSharpConsole/Program.cs
@@ -1,46 +1,18 @@
-using System.IO;
-using System.Linq;
-using Google.Protobuf;
+using Onnx;
// Examples see https://github.com/onnx/models
var onnxInputFilePath = @"mnist-8.onnx";
-var onnxInputFileName = Path.GetFileNameWithoutExtension(onnxInputFilePath);
-var outputDirectory = Path.GetDirectoryName(onnxInputFilePath);
+var model = ModelProto.Parser.ParseFromFile(onnxInputFilePath);
-using (var file = File.OpenRead(onnxInputFilePath))
-{
- var model = Onnx.ModelProto.Parser.ParseFrom(file);
- var graph = model.Graph;
- // Shapes are defined in inputs, values and outputs
- var inputs = graph.Input;
- var values = graph.ValueInfo;
- var outputs = graph.Output;
+var graph = model.Graph;
+// Clean graph e.g. remove initializers from inputs that may prevent constant folding
+graph.Clean();
+// Set dimension in graph to enable dynamic batch size during inference
+graph.SetDim(dimIndex: 0, DimParamOrValue.New("N"));
+// Get summarized info about the graph
+var info = graph.Info();
- foreach (var value in inputs.Concat(values).Concat(outputs))
- {
- var shape = value.Type.TensorType.Shape;
- var dims = shape.Dim;
- var dim = dims[0];
- //dim.DimValue = -1;
- dim.ClearValue();
- dim.DimValue = -1; // Or don't set it
- //dim.DimParam = "None"; // Or don't set it, unset dimension means dynamic
- }
+System.Console.WriteLine(info);
- var fileNameSuffix = "-dynamic-leading-dimension";
- var outputFilePathPrefix = Path.Combine(outputDirectory, onnxInputFileName + fileNameSuffix);
-
- var onnxOutputFilePath = outputFilePathPrefix + ".onnx";
- using (var outputFile = File.Create(onnxOutputFilePath))
- {
- model.WriteTo(outputFile);
- }
-
- var jsonOnnxOutputFilePath = outputFilePathPrefix + ".json";
- using (var output = new StreamWriter(jsonOnnxOutputFilePath))
- {
- var fmt = new JsonFormatter(JsonFormatter.Settings.Default);
- fmt.Format(model, output);
- }
-}
+model.WriteToFile(@"mnist-8-clean-dynamic-batch-size.onnx");
\ No newline at end of file
diff --git a/src/dotnet-onnx/Commands/CleanCommand.cs b/src/dotnet-onnx/Commands/CleanCommand.cs
new file mode 100644
index 0000000..5294c9c
--- /dev/null
+++ b/src/dotnet-onnx/Commands/CleanCommand.cs
@@ -0,0 +1,14 @@
+using McMaster.Extensions.CommandLineUtils;
+using Onnx;
+
+[Command("clean", Description = "Clean model for inference e.g. remove initializers from inputs")]
+public class CleanCommand : InputOutputCommand
+{
+ public CleanCommand(IConsole console) : base(console)
+ { }
+
+ protected override void Run(ModelProto model)
+ {
+ model.Graph.Clean();
+ }
+}
\ No newline at end of file
diff --git a/src/dotnet-onnx/Commands/Command.cs b/src/dotnet-onnx/Commands/Command.cs
new file mode 100644
index 0000000..7e47230
--- /dev/null
+++ b/src/dotnet-onnx/Commands/Command.cs
@@ -0,0 +1,23 @@
+using System;
+using System.Threading.Tasks;
+
+public abstract class Command
+{
+ public async Task OnExecuteAsync()
+ {
+ try
+ {
+ await Run();
+ }
+ //catch (CliException e)
+ catch (Exception e)
+ {
+ Console.ForegroundColor = ConsoleColor.Red;
+ Console.Error.WriteLine(e.Message);
+ Console.ResetColor();
+ //Environment.Exit(e.ExitCode);
+ }
+ }
+
+ public abstract Task Run();
+}
diff --git a/src/dotnet-onnx/Commands/InfoCommand.cs b/src/dotnet-onnx/Commands/InfoCommand.cs
new file mode 100644
index 0000000..f3c497e
--- /dev/null
+++ b/src/dotnet-onnx/Commands/InfoCommand.cs
@@ -0,0 +1,21 @@
+using McMaster.Extensions.CommandLineUtils;
+using Onnx;
+
+[Command("info", Description = "Print information about a model e.g. inputs and outputs")]
+public class InfoCommand : InputCommand
+{
+ public InfoCommand(IConsole console) : base(console)
+ {
+ LogInput = null;
+ }
+
+ protected override void Run(ModelProto model)
+ {
+ var writer = _console.Out;
+
+ writer.WriteLine($"# {Input}");
+ writer.WriteLine();
+
+ model.Graph.Info(writer);
+ }
+}
\ No newline at end of file
diff --git a/src/dotnet-onnx/Commands/InputCommand.cs b/src/dotnet-onnx/Commands/InputCommand.cs
new file mode 100644
index 0000000..ccae195
--- /dev/null
+++ b/src/dotnet-onnx/Commands/InputCommand.cs
@@ -0,0 +1,34 @@
+using System;
+using System.ComponentModel.DataAnnotations;
+using System.Threading.Tasks;
+using McMaster.Extensions.CommandLineUtils;
+using Onnx;
+
+public abstract class InputCommand : Command
+{
+ protected readonly IConsole _console;
+ protected Action LogInput;
+
+ public InputCommand(IConsole console)
+ {
+ _console = console;
+ LogInput = t => _console.WriteLine(t);
+ }
+
+ [Argument(0, "input", Description = "Input file path")]
+ [Required]
+ public string Input { get; }
+
+ public override Task Run()
+ {
+ var model = ModelProto.Parser.ParseFromFile(Input);
+
+ LogInput?.Invoke($"Parsed input file '{Input}' of size {model.CalculateSize()}");
+
+ Run(model);
+
+ return Task.CompletedTask;
+ }
+
+ protected abstract void Run(ModelProto model);
+}
diff --git a/src/dotnet-onnx/Commands/InputOutputCommand.cs b/src/dotnet-onnx/Commands/InputOutputCommand.cs
new file mode 100644
index 0000000..b226de8
--- /dev/null
+++ b/src/dotnet-onnx/Commands/InputOutputCommand.cs
@@ -0,0 +1,39 @@
+using System.ComponentModel.DataAnnotations;
+using System.Threading.Tasks;
+using McMaster.Extensions.CommandLineUtils;
+using Onnx;
+
+public abstract class InputOutputCommand : Command
+{
+ protected readonly IConsole _console;
+
+ public InputOutputCommand(IConsole console)
+ {
+ _console = console;
+ }
+
+ [Argument(0, "input", Description = "Input file path")]
+ [Required]
+ public string Input { get; }
+
+ [Argument(1, "output", Description = "Output file path")]
+ [Required]
+ public string Output { get; }
+
+ public override Task Run()
+ {
+ var model = ModelProto.Parser.ParseFromFile(Input);
+
+ _console.WriteLine($"Parsed input file '{Input}' of size {model.CalculateSize()}");
+
+ Run(model);
+
+ model.WriteToFile(Output);
+
+ _console.WriteLine($"Wrote output file '{Output}' of size {model.CalculateSize()}");
+
+ return Task.CompletedTask;
+ }
+
+ protected abstract void Run(ModelProto model);
+}
diff --git a/src/dotnet-onnx/Commands/SetDimCommand.cs b/src/dotnet-onnx/Commands/SetDimCommand.cs
new file mode 100644
index 0000000..de4f044
--- /dev/null
+++ b/src/dotnet-onnx/Commands/SetDimCommand.cs
@@ -0,0 +1,28 @@
+using McMaster.Extensions.CommandLineUtils;
+using Onnx;
+
+[Command("setdim", Description = "Set dimension of reshapes, inputs and outputs of model e.g. set new dynamic or fixed batch size.")]
+public class SetDimCommand : InputOutputCommand
+{
+ public SetDimCommand(IConsole console) : base(console)
+ { }
+
+ [Option("-i|--index", Description = "Dimension index to set. Default = 0.")]
+ public int Index { get; } = 0; // Parametize defaults
+
+ [Option("-d|--dim", Description = "Dimension to set. Default = N. Use string e.g. 'N' for dynamic batch size or integer e.g. '3' for fixed size")]
+ public string Dim { get; } = "N";
+
+ protected override void Run(ModelProto model)
+ {
+ // Should this not be before loading input? Is the abstract base really that good?
+
+ var dimParamOrValue = int.TryParse(Dim, out var dimValue)
+ ? DimParamOrValue.New(dimValue)
+ : DimParamOrValue.New(Dim);
+
+ _console.WriteLine($"Setting dimension at {Index} to '{dimParamOrValue}'");
+
+ model.Graph.SetDim(Index, dimParamOrValue);
+ }
+}
\ No newline at end of file
diff --git a/src/dotnet-onnx/Program.cs b/src/dotnet-onnx/Program.cs
index 9e31089..039c696 100644
--- a/src/dotnet-onnx/Program.cs
+++ b/src/dotnet-onnx/Program.cs
@@ -1,10 +1,6 @@
-using System;
-using System.ComponentModel.DataAnnotations;
-using System.IO;
+using System.IO;
using System.Threading.Tasks;
-using Google.Protobuf;
using McMaster.Extensions.CommandLineUtils;
-using Onnx;
// https://github.com/natemcmaster/CommandLineUtils
// https://natemcmaster.github.io/CommandLineUtils/docs/advanced/dependency-injection.html
@@ -13,11 +9,12 @@
// TODO: Handle multiple command names etc.
// https://github.com/jonstodle/DotNetSdkHelpers/blob/master/src/DotNetSdkHelpers/Program.cs
-[Command("dotnet-onnx", Description = "Inspect and manipulate ONNX files"),
- Subcommand(typeof(Clean)),
- //Subcommand(typeof(List)),
- //Subcommand(typeof(Download))
- ]
+// TODO: Switch from attributes to code instead
+[Command("dotnet onnx", Description = "Inspect and manipulate ONNX files. Copyright nietras 2021."),
+ Subcommand(typeof(CleanCommand)),
+ Subcommand(typeof(SetDimCommand)),
+ Subcommand(typeof(InfoCommand))
+]
class Program
{
static Task Main(string[] args)
@@ -31,54 +28,11 @@ static Task Main(string[] args)
return app.ExecuteAsync(args);
}
-}
-public abstract class Command
-{
- public async Task OnExecuteAsync()
+ public Task OnExecuteAsync(CommandLineApplication app)
{
- try
- {
- await Run();
- }
- //catch (CliException e)
- catch (Exception e)
- {
- Console.ForegroundColor = ConsoleColor.Red;
- Console.Error.WriteLine(e.Message);
- Console.ResetColor();
- //Environment.Exit(e.ExitCode);
- }
- }
-
- public abstract Task Run();
-}
-
-[Command("clean", Description = "Clean graph for inference e.g. remove initializers from inputs")]
-public class Clean : Command
-{
- [Argument(0, "input", Description = "Input file path")]
- [Required]
- public string Input { get; }
+ app.ShowHelp();
- [Argument(1, "output", Description = "Output file path")]
- [Required]
- public string Output { get; }
-
- public override Task Run()
- {
- using (var inputFile = File.OpenRead(Input))
- {
- var model = ModelProto.Parser.ParseFrom(inputFile);
-
- model.RemoveInitializersFromInputs();
-
- using (var outputFile = File.Create(Output))
- {
- model.WriteTo(outputFile);
- }
- }
-
- return Task.CompletedTask;
+ return Task.FromResult(0);
}
-}
\ No newline at end of file
+}
diff --git a/update-tool-from-build.ps1 b/update-tool-from-build.ps1
new file mode 100644
index 0000000..debde86
--- /dev/null
+++ b/update-tool-from-build.ps1
@@ -0,0 +1,2 @@
+#!/usr/local/bin/powershell
+dotnet tool update dotnet-onnx --add-source ./build/dotnet-onnx_AnyCPU_Release -g
\ No newline at end of file