From 570fe991c87b98574f537ab8063ba8e200aa8c7c Mon Sep 17 00:00:00 2001 From: Bill Menees Date: Sat, 8 Oct 2022 13:54:40 -0500 Subject: [PATCH] Move TryGetType to NodeSettings Add NodeSettings.RequireGetType, which handles mixed runtime type resolution (e.g., Framework to Core or different runtime versions). Embedded debug symbols to get better stack traces in production. --- src/Directory.Build.props | 3 +- src/Menees.Remoting/IServer.cs | 5 - src/Menees.Remoting/Node.cs | 50 +-------- src/Menees.Remoting/NodeSettings.cs | 104 ++++++++++++++++++ .../NodeSettingsTests.cs | 54 +++++++++ tests/Menees.Remoting.Tests/RmiServerTests.cs | 17 ++- 6 files changed, 173 insertions(+), 60 deletions(-) create mode 100644 tests/Menees.Remoting.Tests/NodeSettingsTests.cs diff --git a/src/Directory.Build.props b/src/Directory.Build.props index 008300f..1d377b2 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -26,10 +26,11 @@ enable true false + embedded - 0.7.0-beta + 0.8.0-beta diff --git a/src/Menees.Remoting/IServer.cs b/src/Menees.Remoting/IServer.cs index ba8d733..3e05ad3 100644 --- a/src/Menees.Remoting/IServer.cs +++ b/src/Menees.Remoting/IServer.cs @@ -22,11 +22,6 @@ public interface IServer : IDisposable /// Action? ReportUnhandledException { get; set; } - /// - /// See . - /// - Func TryGetType { get; set; } - #endregion #region Public Methods diff --git a/src/Menees.Remoting/Node.cs b/src/Menees.Remoting/Node.cs index 224cf8b..229838f 100644 --- a/src/Menees.Remoting/Node.cs +++ b/src/Menees.Remoting/Node.cs @@ -19,7 +19,7 @@ public abstract class Node : IDisposable private readonly string serverPath; private bool disposed; - private Func tryGetType = RequireGetType; + private Func tryGetType; private ISerializer? systemSerializer; private ISerializer? userSerializer; private Func? createLogger; @@ -35,52 +35,13 @@ public abstract class Node : IDisposable protected Node(NodeSettings settings) { this.serverPath = settings.ServerPath; + this.tryGetType = settings?.TryGetType ?? NodeSettings.RequireGetType; this.userSerializer = settings?.Serializer; this.createLogger = settings?.CreateLogger; } #endregion - #region Public Properties - - /// - /// Allows customization of how an assembly-qualified type name (serialized from - /// ) should be deserialized into a .NET - /// . - /// - /// - /// A secure system needs to support a known list of legal/safe/valid types that it - /// can load dynamically. It shouldn't just trust and load an arbitrary assembly and - /// then load an arbitrary type out of it. Doing that can execute malicious code - /// in the current process (e.g., via the Type's static constructor or the assembly's - /// module initializer). So a security best practice is to validate every assembly- - /// qualified type name before you load the type. - /// - /// However, this is a case where security is at odds with convenience. The default for - /// this property just calls to try to load the type, - /// and it throws an exception if the type can't be loaded. - /// - /// https://github.com/dotnet/runtime/issues/31567#issuecomment-558335944 - /// https://stackoverflow.com/a/66963611/1882616 - /// https://github.com/dotnet/runtime/issues/43482#issue-722814247 (related Exception comment) - /// - public Func TryGetType - { - get => this.tryGetType; - set - { - if (this.tryGetType != value) - { - this.tryGetType = value ?? RequireGetType; - - // On the next serialization, we need to create a new serializer instance using the new tryGetType lambda. - this.systemSerializer = null; - } - } - } - - #endregion - #region Internal Properties internal ISerializer SystemSerializer @@ -154,13 +115,6 @@ protected virtual void Dispose(bool disposing) #endregion - #region Private Methods - - private static Type? RequireGetType(string qualifiedTypeName) - => Type.GetType(qualifiedTypeName, throwOnError: true); - - #endregion - #region Private Types private sealed class ScopedLogger : ILogger diff --git a/src/Menees.Remoting/NodeSettings.cs b/src/Menees.Remoting/NodeSettings.cs index 3dd91fb..030788e 100644 --- a/src/Menees.Remoting/NodeSettings.cs +++ b/src/Menees.Remoting/NodeSettings.cs @@ -4,6 +4,8 @@ using Menees.Remoting.Security; using Microsoft.Extensions.Logging; +using System.Reflection; +using System.Runtime.InteropServices; #endregion @@ -12,6 +14,16 @@ /// public abstract class NodeSettings { + #region Private Data Members + + // I'm keeping this private for now (even though BaseTests duplicates it) because I may want to + // support other CLR scopes later (e.g., Framework, Core, Mono, Wasm, SQL CLR, Native). + private static readonly bool IsDotNetFramework = RuntimeInformation.FrameworkDescription.Contains("Framework"); + + private Func tryGetType = RequireGetType; + + #endregion + #region Constructors /// @@ -60,6 +72,98 @@ protected NodeSettings(string serverPath) /// public NodeSecurity? Security => this.GetSecurity(); + /// + /// Allows customization of how an assembly-qualified type name (serialized from + /// ) should be deserialized into a .NET + /// . + /// + /// + /// This is useful for type translation and security. It's for translation if you're supporting + /// calls between different runtimes (e.g., Framework and "Core") or versions + /// (e.g., .NET 6.0 and 7.0). When mixing runtimes, many types will be in different + /// assemblies (e.g., int, string, Uri, IPAddress, Stack<T>), so this handler needs + /// to deal with that for all your supported types. Even mixing versions of the same + /// runtime is complicated because strongly-named assemblies embed their version + /// in their AssemblyQualifiedName. + /// + /// A secure system needs to support a known list of legal/safe/valid types that it + /// can load dynamically. It shouldn't just trust and load an arbitrary assembly and + /// then load an arbitrary type out of it. Doing that can execute malicious code + /// in the current process (e.g., via the Type's static constructor or the assembly's + /// module initializer). So, a security best practice is to validate every assembly- + /// qualified type name before you load the type. + /// + /// However, this is a case where security is at odds with convenience. The default for + /// this property just calls to try to load the type, + /// and it throws an exception if the type can't be loaded. + /// + /// https://github.com/dotnet/runtime/issues/31567#issuecomment-558335944 + /// https://stackoverflow.com/a/66963611/1882616 + /// https://github.com/dotnet/runtime/issues/43482#issue-722814247 (related Exception comment) + /// + public Func TryGetType + { + get => this.tryGetType; + set => this.tryGetType = value ?? RequireGetType; + } + + #endregion + + #region Internal Methods + + /// + /// Loads a type given an assembly-qualified type name. + /// + /// An assembly-qualified type name. + /// The .NET Type associated with . + public static Type RequireGetType(string typeName) + { + Type? result = Type.GetType(typeName, throwOnError: false, ignoreCase: true); + if (result == null) + { + // Fallback to the Type.GetType overload where we can pass a custom assembly resolver. + // This is important since a single type name can contain multiple assembly references + // (e.g., Dictionary includes asm refs for Dictionary<>, string, and Uri). + Assembly[]? assemblies = null; + result = Type.GetType( + typeName, + assemblyName => + { + Assembly? assembly = null; + string simpleName = assemblyName.Name ?? string.Empty; + + // Try to translate the simple built-in scalar types correctly across different runtimes. + if ((IsDotNetFramework && simpleName.Equals("System.Private.CoreLib", StringComparison.OrdinalIgnoreCase)) + || (!IsDotNetFramework && simpleName.Equals("MsCorLib", StringComparison.OrdinalIgnoreCase))) + { + assembly = typeof(string).Assembly; + } + else + { + // See if any assembly is already loaded with the same simple name. + // This ignores versions and strong naming, so it's convenient but insecure. + // We'll allow a lower version to match in case a .NET 7.0 client needs to + // call into a .NET 6.0 server. + // https://github.com/dotnet/fsharp/issues/3408#issuecomment-319519926 + assemblies ??= AppDomain.CurrentDomain.GetAssemblies(); + assembly = assemblies.FirstOrDefault(asm => asm.GetName().Name?.Equals(simpleName, StringComparison.OrdinalIgnoreCase) ?? false); + } + + return assembly; + }, + typeResolver: null, + throwOnError: false, + ignoreCase: true); + } + + if (result == null) + { + throw new TypeLoadException($"Unable to load type \"{typeName}\"."); + } + + return result; + } + #endregion #region Private Protected Methods diff --git a/tests/Menees.Remoting.Tests/NodeSettingsTests.cs b/tests/Menees.Remoting.Tests/NodeSettingsTests.cs new file mode 100644 index 0000000..271871a --- /dev/null +++ b/tests/Menees.Remoting.Tests/NodeSettingsTests.cs @@ -0,0 +1,54 @@ +namespace Menees.Remoting; + +[TestClass] +public class NodeSettingsTests +{ + [TestMethod] + public void RequireGetType() + { + const string CoreStringTypeName = "System.String, System.Private.CoreLib, Version=6.0.0.0, Culture=neutral, PublicKeyToken=7cec85d7bea7798e"; + const string FrameworkStringTypeName = "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089"; + TestVersions(typeof(string), CoreStringTypeName, FrameworkStringTypeName); + + const string CoreDictionaryTypeName = "System.Collections.Generic.IReadOnlyDictionary`2[" + + "[System.String, System.Private.CoreLib, Version=6.0.0.0, Culture=neutral, PublicKeyToken=7cec85d7bea7798e]," + + "[System.Object, System.Private.CoreLib, Version=6.0.0.0, Culture=neutral, PublicKeyToken=7cec85d7bea7798e]" + + "], System.Private.CoreLib, Version=6.0.0.0, Culture=neutral, PublicKeyToken=7cec85d7bea7798e"; + const string FrameworkDictionaryTypeName = "System.Collections.Generic.IReadOnlyDictionary`2[" + + "[System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089]," + + "[System.Object, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089]" + + "], mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089"; + TestVersions(typeof(IReadOnlyDictionary), CoreDictionaryTypeName, FrameworkDictionaryTypeName); + + Should.Throw(() => NodeSettings.RequireGetType("This.Type.Does.Not.Exist")); + + static void TestVersions(Type expected, params string[] typeNames) + { + foreach (string typeName in typeNames) + { + // Make sure mixed runtimes work correctly for built-in types. + NodeSettings.RequireGetType(typeName).ShouldBe(expected); + + // Make sure older and newer versions will resolve to the current version. + NodeSettings.RequireGetType(AdjustVersion(typeName, 1)).ShouldBe(expected); + NodeSettings.RequireGetType(AdjustVersion(typeName, 999)).ShouldBe(expected); + } + } + + static string AdjustVersion(string typeName, uint majorVersion) + { + string result = typeName; + + const string Prefix = ", Version="; + int startIndex = 0; + while ((startIndex = result.IndexOf(Prefix, startIndex)) >= 0) + { + int endIndex = result.IndexOf('.', startIndex); + result = result.Substring(0, startIndex + Prefix.Length) + majorVersion + result.Substring(endIndex); + startIndex = endIndex; + } + + return result; + } + } +} \ No newline at end of file diff --git a/tests/Menees.Remoting.Tests/RmiServerTests.cs b/tests/Menees.Remoting.Tests/RmiServerTests.cs index 71fc4d7..90b8ca0 100644 --- a/tests/Menees.Remoting.Tests/RmiServerTests.cs +++ b/tests/Menees.Remoting.Tests/RmiServerTests.cs @@ -44,13 +44,18 @@ private interface IDiamond : ILevel1A, ILevel1B public void CloneString() { string serverPath = this.GenerateServerPath(); - string expected = Guid.NewGuid().ToString(); - using RmiServer server = new(expected, serverPath, loggerFactory: this.LoggerFactory); + ServerSettings settings = new(serverPath) + { + CreateLogger = this.LoggerFactory.CreateLogger, - // This is a super weak, insecure example since it just checks for the word "System". - server.TryGetType = typeName => typeName.Contains(nameof(System)) - ? Type.GetType(typeName, true) - : throw new ArgumentException("TryGetType disallowed " + typeName); + // This is a super weak, insecure example since it just checks for the word "System". + TryGetType = typeName => typeName.Contains(nameof(System)) + ? Type.GetType(typeName, true) + : throw new ArgumentException("TryGetType disallowed " + typeName), + }; + + string expected = Guid.NewGuid().ToString(); + using RmiServer server = new(expected, settings); server.ReportUnhandledException = WriteUnhandledServerException; server.Start();