Skip to content

Commit

Permalink
read config and safetensor
Browse files Browse the repository at this point in the history
  • Loading branch information
curioyang committed Jan 17, 2025
1 parent 4b6966b commit 31e7fde
Show file tree
Hide file tree
Showing 14 changed files with 381 additions and 10 deletions.
1 change: 1 addition & 0 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
<PackageVersion Include="Avalonia.Themes.Fluent" Version="11.0.2" />
<PackageVersion Include="Avalonia.Fonts.Inter" Version="11.0.2" />
<PackageVersion Include="Avalonia.ReactiveUI" Version="11.0.2" />
<PackageVersion Include="Clawfoot.Extensions.Newtonsoft" Version="0.1.0" />
<PackageVersion Include="MessageBox.Avalonia" Version="3.1.2" />
<PackageVersion Include="CommunityToolkit.Mvvm" Version="8.2.1" />
<PackageVersion Include="Google.OrTools" Version="9.4.1874" />
Expand Down
10 changes: 10 additions & 0 deletions src/Nncase.Cli/packages.lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,7 @@
"nncase.importer": {
"type": "Project",
"dependencies": {
"Clawfoot.Extensions.Newtonsoft": "[0.1.0, )",
"LanguageExt.Core": "[4.4.9, )",
"Nncase.Core": "[1.0.0, )",
"Onnx.Protobuf": "[1.0.0, )",
Expand Down Expand Up @@ -814,6 +815,15 @@
"Nncase.FlatBuffers": "[2.0.0, )"
}
},
"Clawfoot.Extensions.Newtonsoft": {
"type": "CentralTransitive",
"requested": "[0.1.0, )",
"resolved": "0.1.0",
"contentHash": "A8p8THcOiOoexdYUUHgEVeW2BgsFLRqm65+4WuE3Te0XyZdiq+3Alu0D8ktMFUU+0eFeXr6sYcNsDTD3OsVJ4w==",
"dependencies": {
"Newtonsoft.Json": "12.0.2"
}
},
"CommunityToolkit.HighPerformance": {
"type": "CentralTransitive",
"requested": "[8.2.2, )",
Expand Down
10 changes: 10 additions & 0 deletions src/Nncase.Compiler/packages.lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,7 @@
"nncase.importer": {
"type": "Project",
"dependencies": {
"Clawfoot.Extensions.Newtonsoft": "[0.1.0, )",
"LanguageExt.Core": "[4.4.9, )",
"Nncase.Core": "[1.0.0, )",
"Onnx.Protobuf": "[1.0.0, )",
Expand Down Expand Up @@ -787,6 +788,15 @@
"Nncase.FlatBuffers": "[2.0.0, )"
}
},
"Clawfoot.Extensions.Newtonsoft": {
"type": "CentralTransitive",
"requested": "[0.1.0, )",
"resolved": "0.1.0",
"contentHash": "A8p8THcOiOoexdYUUHgEVeW2BgsFLRqm65+4WuE3Te0XyZdiq+3Alu0D8ktMFUU+0eFeXr6sYcNsDTD3OsVJ4w==",
"dependencies": {
"Newtonsoft.Json": "12.0.2"
}
},
"CommunityToolkit.HighPerformance": {
"type": "CentralTransitive",
"requested": "[8.2.2, )",
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Core/IR/Shape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ private Shape(ShapeKind kind, IEnumerable<Dimension> dimensions)
public int Rank => _dimensions.Length;

/// <summary>
/// Gets get Total Elements.
/// Gets Total Elements.
/// </summary>
public int Size => Enumerable.Range(0, Rank).Aggregate(1, (size, i) => size * _dimensions[i].FixedValue);

Expand Down
114 changes: 114 additions & 0 deletions src/Nncase.Importer/HuggingFace/HuggingFaceImporter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.IO;
using System.Linq;
using Google.Protobuf;
using Google.Protobuf.Collections;
using LanguageExt;
using NetFabric.Hyperlinq;
using Newtonsoft.Json.Linq;
using Nncase.IR;
using Nncase.IR.Tensors;
using Tuple = Nncase.IR.Tuple;
using TorchSharp;

Check failure on line 18 in src/Nncase.Importer/HuggingFace/HuggingFaceImporter.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-linux

The type or namespace name 'TorchSharp' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 18 in src/Nncase.Importer/HuggingFace/HuggingFaceImporter.cs

View workflow job for this annotation

GitHub Actions / build-aarch64-macos

The type or namespace name 'TorchSharp' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 18 in src/Nncase.Importer/HuggingFace/HuggingFaceImporter.cs

View workflow job for this annotation

GitHub Actions / build-aarch64-macos

The type or namespace name 'TorchSharp' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 18 in src/Nncase.Importer/HuggingFace/HuggingFaceImporter.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-linux

The type or namespace name 'TorchSharp' could not be found (are you missing a using directive or an assembly reference?)
using static TorchSharp.torch.nn;

Check failure on line 19 in src/Nncase.Importer/HuggingFace/HuggingFaceImporter.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-linux

The type or namespace name 'TorchSharp' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 19 in src/Nncase.Importer/HuggingFace/HuggingFaceImporter.cs

View workflow job for this annotation

GitHub Actions / build-aarch64-macos

The type or namespace name 'TorchSharp' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 19 in src/Nncase.Importer/HuggingFace/HuggingFaceImporter.cs

View workflow job for this annotation

GitHub Actions / build-aarch64-macos

The type or namespace name 'TorchSharp' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 19 in src/Nncase.Importer/HuggingFace/HuggingFaceImporter.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-linux

The type or namespace name 'TorchSharp' could not be found (are you missing a using directive or an assembly reference?)

namespace Nncase.Importer;

public sealed partial class HuggingFaceImporter : BaseImporter
{
private string _modelDir;
private Dictionary<string, object> _config;
private List<string> _modelArchitectures;

private Dictionary<string, Nncase.Tensor>? _constTensors;

private Dictionary<string, Var> _dynVarMap = new();
private Dictionary<string, int> _fixVarMap = new();

public HuggingFaceImporter(string huggingFaceDir, CompileSession compileSession)
: base(compileSession)
{
_modelDir = huggingFaceDir;

// 读取 config.json 文件
getConfigInfo(Path.Combine(huggingFaceDir, "config.json"));
getAllWeights(Path.Combine(huggingFaceDir, "model.safetensors"));


if (String.Equals(_config["architectures"], "Qwen2ForCausalLM"))
{
_modelArchitectures = new List<string>() {"Qwen2Model", "Linear"};
//{ "Embedding", "Qwen2DecoderLayer", "Qwen2MLP", "Qwen2RMSNorm", "Qwen2RotaryEmbedding" };
}
}

protected override (IEnumerable<Var> Inputs, Dictionary<Var, Expr[]> VarMap) CreateInputs()
{
throw new NotImplementedException();
}

protected override void ConvertOp()
{
foreach (var architecture in _modelArchitectures)
{
Visit(architecture);
}
}

protected override Expr CreateOutputs()
{
throw new NotImplementedException();
}

private void Visit(string op)
{
switch (op)
{
case "Qwen2Model":
VisitQwen2Model(_config, _constTensors);
break;
}
}

private void getConfigInfo(string path)
{
if (File.Exists(path))
{
var configJson = File.ReadAllText(path);
_config = Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, object>>(configJson);
foreach (var key in _config.Keys.ToList())
{
if (_config[key] is JArray jArray)
{
_config[key] = string.Join(", ", jArray.Select(token => token.ToString()));
}
}
}
else
{
throw new FileNotFoundException($"{_config?["architectures"]}'s config.json not found in the specified directory.", path);
}
}

private void getAllWeights(string path)
{
var constTensor = HuggingFaceUtils.LoadStateDict(path);
foreach (var item in constTensor)
{
Console.WriteLine($"{item.Key}");
if (item.Value is Tensor tensor)
{
_constTensors ??= new();
_constTensors.Add(item.Key, tensor.CastTo(DataTypes.Float32));
}
}
}


}
147 changes: 147 additions & 0 deletions src/Nncase.Importer/HuggingFace/HuggingFaceUtils.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using NetFabric.Hyperlinq;
using Nncase;
using Nncase.IR;
using TorchSharp;

Check failure on line 11 in src/Nncase.Importer/HuggingFace/HuggingFaceUtils.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-linux

The type or namespace name 'TorchSharp' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 11 in src/Nncase.Importer/HuggingFace/HuggingFaceUtils.cs

View workflow job for this annotation

GitHub Actions / build-aarch64-macos

The type or namespace name 'TorchSharp' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 11 in src/Nncase.Importer/HuggingFace/HuggingFaceUtils.cs

View workflow job for this annotation

GitHub Actions / build-aarch64-macos

The type or namespace name 'TorchSharp' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 11 in src/Nncase.Importer/HuggingFace/HuggingFaceUtils.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-linux

The type or namespace name 'TorchSharp' could not be found (are you missing a using directive or an assembly reference?)

internal class SafetensorsEntry
{
[JsonPropertyName("dtype")]
public string DataType { get; init; }

[JsonPropertyName("shape")]
public long[] Shape { get; init; }

[JsonPropertyName("data_offsets")]
public long[] Offsets { get; init; }
}

static class HuggingFaceUtils
{


public static byte[] ReadBytes(this Stream stream, int count)
{
byte[] buffer = new byte[count];
stream.Read(buffer, 0, count);
return buffer;
}

internal static Dictionary<string, SafetensorsEntry> LoadIndex(Stream stream)
{
ulong uint64 = BitConverter.ToUInt64((ReadOnlySpan<byte>) stream.ReadBytes(8));
if (uint64 > (ulong) int.MaxValue)
throw new ArgumentOutOfRangeException("length", "Length of JSON exceeded int.MaxValue, not supported yet");
return JsonSerializer.Deserialize<Dictionary<string, SafetensorsEntry>>(Encoding.UTF8.GetString(stream.ReadBytes((int) uint64))) ?? throw new NotImplementedException("Loaded header string failed to deserialize into the correct format.");
}

public static Dictionary<string, Tensor> LoadStateDict(
string path,
List<string>? keysToKeep = null)
{
using (FileStream fileStream = File.OpenRead(path))
return LoadStateDict((Stream) fileStream, keysToKeep: keysToKeep);
}

public static Dictionary<string, Tensor> LoadStateDict(
Stream stream,
bool leaveOpen = false,
List<string>? keysToKeep = null)
{
Dictionary<string, SafetensorsEntry> dictionary1 = HuggingFaceUtils.LoadIndex(stream);
long position = stream.Position;
Dictionary<string, Tensor> dictionary2 = new Dictionary<string, Tensor>();
foreach (KeyValuePair<string, SafetensorsEntry> keyValuePair in dictionary1)
{
if (!(keyValuePair.Key == "__metadata__") && (keysToKeep == null || keysToKeep.Contains(keyValuePair.Key)))
{
var datatype = ConvertToDataDType(keyValuePair.Value.DataType);
// var tensor = new Tensor(datatype, new Shape(keyValuePair.Value.Shape));
var shape = new Shape(keyValuePair.Value.Shape);
if (keyValuePair.Value.Offsets[1] - keyValuePair.Value.Offsets[0] != datatype.SizeInBytes * shape.Size)
throw new NotImplementedException("Error when loading tensor " + keyValuePair.Key + " - mismatched # of elements");
stream.Position = position + keyValuePair.Value.Offsets[0];
var tensor = Tensor.FromStream(datatype, stream, shape);
dictionary2.Add(keyValuePair.Key, tensor);
}
}
if (!leaveOpen)
stream.Close();
return dictionary2;
}

private static DataType ConvertToDataDType(string dataType)
{
if (dataType != null)
{
switch (dataType.Length)
{
case 2:
switch (dataType[0])
{
case 'I':
if (dataType == "I8")
return DataTypes.Int8;
break;
case 'U':
if (dataType == "U8")
return DataTypes.UInt8;
break;
}
break;
case 3:
switch (dataType[1])
{
case '1':
switch (dataType)
{
case "F16":
return DataTypes.Float16;
case "I16":
return DataTypes.Int16;
}
break;
case '3':
switch (dataType)
{
case "F32":
return DataTypes.Float32;
case "I32":
return DataTypes.Int32;
}
break;
case '6':
switch (dataType)
{
case "F64":
return DataTypes.Float64;
case "I64":
return DataTypes.Int64;
}
break;
}
break;
case 4:
switch (dataType[1])
{
case 'F':
if (dataType == "BF16")
return DataTypes.BFloat16;
break;
case 'O':
if (dataType == "BOOL")
return DataTypes.Boolean;
break;
}
break;
}
}
throw new NotImplementedException("Unrecognized data type listed: " + dataType);
}
}
29 changes: 29 additions & 0 deletions src/Nncase.Importer/HuggingFace/Qwen2Model.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System.Collections.Generic;
using Nncase.IR;
using Nncase.IR.Math;
using TorchSharp;

Check failure on line 7 in src/Nncase.Importer/HuggingFace/Qwen2Model.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-linux

The type or namespace name 'TorchSharp' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 7 in src/Nncase.Importer/HuggingFace/Qwen2Model.cs

View workflow job for this annotation

GitHub Actions / build-aarch64-macos

The type or namespace name 'TorchSharp' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 7 in src/Nncase.Importer/HuggingFace/Qwen2Model.cs

View workflow job for this annotation

GitHub Actions / build-aarch64-macos

The type or namespace name 'TorchSharp' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 7 in src/Nncase.Importer/HuggingFace/Qwen2Model.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-linux

The type or namespace name 'TorchSharp' could not be found (are you missing a using directive or an assembly reference?)
using F = Nncase.IR.F;

namespace Nncase.Importer
{
public partial class HuggingFaceImporter
{
private Expr VisitQwen2Model(Dictionary<string, object> modelConfig, Dictionary<string, Tensor> constTensors)
{
// var (lhs, rhs) = GetInputExprs(op, 0, 1);
// if (binaryOp == BinaryOp.Pow && lhs.CheckedDataType != rhs.CheckedDataType)
// {
// return F.Math.Binary(binaryOp, lhs, IR.F.Tensors.Cast(rhs, lhs.CheckedDataType)).With(metadata: new IRMetadata() { OutputNames = op.Output });
// }

// return F.Math.Binary(binaryOp, lhs, rhs).With(metadata: new IRMetadata() { OutputNames = op.Output });

var input_ids = new Var();
var position_ids = new Var();

}
}
}
13 changes: 13 additions & 0 deletions src/Nncase.Importer/Importers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,17 @@ public static IRModule ImportNcnn(Stream ncnnParam, Stream ncnnBin, CompileSessi
var importer = new NcnnImporter(ncnnParam, ncnnBin, compileSession);
return importer.Import();
}

/// <summary>
/// Import huggingface model.
/// </summary>
/// <param name="hfModelDir">Huggingface model directory.</param>
/// <param name="compileSession">compile session.</param>
/// <returns>Imported IR module.</returns>
public static IRModule ImportHuggingFace(string hfModelDir, CompileSession compileSession)
{
compileSession.CompileOptions.ModelLayout = "NCHW";
var importer = new HuggingFaceImporter(hfModelDir, compileSession);
return importer.Import();
}
}
1 change: 1 addition & 0 deletions src/Nncase.Importer/Nncase.Importer.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Clawfoot.Extensions.Newtonsoft" />
<PackageReference Include="LanguageExt.Core" PrivateAssets="compile" />
</ItemGroup>

Expand Down
Loading

0 comments on commit 31e7fde

Please sign in to comment.