-
Notifications
You must be signed in to change notification settings - Fork 185
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
// Copyright (c) Canaan Inc. All rights reserved. | ||
// Licensed under the Apache license. See LICENSE file in the project root for full license information. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Collections.Immutable; | ||
using System.Diagnostics; | ||
using System.IO; | ||
using System.Linq; | ||
using Google.Protobuf; | ||
using Google.Protobuf.Collections; | ||
using LanguageExt; | ||
using NetFabric.Hyperlinq; | ||
using Newtonsoft.Json.Linq; | ||
using Nncase.IR; | ||
using Nncase.IR.Tensors; | ||
using Tuple = Nncase.IR.Tuple; | ||
using TorchSharp; | ||
Check failure on line 18 in src/Nncase.Importer/HuggingFace/HuggingFaceImporter.cs GitHub Actions / build-x86_64-linux
Check failure on line 18 in src/Nncase.Importer/HuggingFace/HuggingFaceImporter.cs GitHub Actions / build-aarch64-macos
Check failure on line 18 in src/Nncase.Importer/HuggingFace/HuggingFaceImporter.cs GitHub Actions / build-aarch64-macos
|
||
using static TorchSharp.torch.nn; | ||
Check failure on line 19 in src/Nncase.Importer/HuggingFace/HuggingFaceImporter.cs GitHub Actions / build-x86_64-linux
Check failure on line 19 in src/Nncase.Importer/HuggingFace/HuggingFaceImporter.cs GitHub Actions / build-aarch64-macos
Check failure on line 19 in src/Nncase.Importer/HuggingFace/HuggingFaceImporter.cs GitHub Actions / build-aarch64-macos
|
||
|
||
namespace Nncase.Importer; | ||
|
||
public sealed partial class HuggingFaceImporter : BaseImporter | ||
{ | ||
private string _modelDir; | ||
private Dictionary<string, object> _config; | ||
private List<string> _modelArchitectures; | ||
|
||
private Dictionary<string, Nncase.Tensor>? _constTensors; | ||
|
||
private Dictionary<string, Var> _dynVarMap = new(); | ||
private Dictionary<string, int> _fixVarMap = new(); | ||
|
||
public HuggingFaceImporter(string huggingFaceDir, CompileSession compileSession) | ||
: base(compileSession) | ||
{ | ||
_modelDir = huggingFaceDir; | ||
|
||
// 读取 config.json 文件 | ||
getConfigInfo(Path.Combine(huggingFaceDir, "config.json")); | ||
getAllWeights(Path.Combine(huggingFaceDir, "model.safetensors")); | ||
|
||
|
||
if (String.Equals(_config["architectures"], "Qwen2ForCausalLM")) | ||
{ | ||
_modelArchitectures = new List<string>() {"Qwen2Model", "Linear"}; | ||
//{ "Embedding", "Qwen2DecoderLayer", "Qwen2MLP", "Qwen2RMSNorm", "Qwen2RotaryEmbedding" }; | ||
} | ||
} | ||
|
||
protected override (IEnumerable<Var> Inputs, Dictionary<Var, Expr[]> VarMap) CreateInputs() | ||
{ | ||
throw new NotImplementedException(); | ||
} | ||
|
||
protected override void ConvertOp() | ||
{ | ||
foreach (var architecture in _modelArchitectures) | ||
{ | ||
Visit(architecture); | ||
} | ||
} | ||
|
||
protected override Expr CreateOutputs() | ||
{ | ||
throw new NotImplementedException(); | ||
} | ||
|
||
private void Visit(string op) | ||
{ | ||
switch (op) | ||
{ | ||
case "Qwen2Model": | ||
VisitQwen2Model(_config, _constTensors); | ||
break; | ||
} | ||
} | ||
|
||
private void getConfigInfo(string path) | ||
{ | ||
if (File.Exists(path)) | ||
{ | ||
var configJson = File.ReadAllText(path); | ||
_config = Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, object>>(configJson); | ||
foreach (var key in _config.Keys.ToList()) | ||
{ | ||
if (_config[key] is JArray jArray) | ||
{ | ||
_config[key] = string.Join(", ", jArray.Select(token => token.ToString())); | ||
} | ||
} | ||
} | ||
else | ||
{ | ||
throw new FileNotFoundException($"{_config?["architectures"]}'s config.json not found in the specified directory.", path); | ||
} | ||
} | ||
|
||
private void getAllWeights(string path) | ||
{ | ||
var constTensor = HuggingFaceUtils.LoadStateDict(path); | ||
foreach (var item in constTensor) | ||
{ | ||
Console.WriteLine($"{item.Key}"); | ||
if (item.Value is Tensor tensor) | ||
{ | ||
_constTensors ??= new(); | ||
_constTensors.Add(item.Key, tensor.CastTo(DataTypes.Float32)); | ||
} | ||
} | ||
} | ||
|
||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.IO; | ||
using System.Text; | ||
using System.Text.Json; | ||
using System.Text.Json.Serialization; | ||
using System.Threading; | ||
using NetFabric.Hyperlinq; | ||
using Nncase; | ||
using Nncase.IR; | ||
using TorchSharp; | ||
Check failure on line 11 in src/Nncase.Importer/HuggingFace/HuggingFaceUtils.cs GitHub Actions / build-x86_64-linux
Check failure on line 11 in src/Nncase.Importer/HuggingFace/HuggingFaceUtils.cs GitHub Actions / build-aarch64-macos
Check failure on line 11 in src/Nncase.Importer/HuggingFace/HuggingFaceUtils.cs GitHub Actions / build-aarch64-macos
|
||
|
||
internal class SafetensorsEntry | ||
{ | ||
[JsonPropertyName("dtype")] | ||
public string DataType { get; init; } | ||
|
||
[JsonPropertyName("shape")] | ||
public long[] Shape { get; init; } | ||
|
||
[JsonPropertyName("data_offsets")] | ||
public long[] Offsets { get; init; } | ||
} | ||
|
||
static class HuggingFaceUtils | ||
{ | ||
|
||
|
||
public static byte[] ReadBytes(this Stream stream, int count) | ||
{ | ||
byte[] buffer = new byte[count]; | ||
stream.Read(buffer, 0, count); | ||
return buffer; | ||
} | ||
|
||
internal static Dictionary<string, SafetensorsEntry> LoadIndex(Stream stream) | ||
{ | ||
ulong uint64 = BitConverter.ToUInt64((ReadOnlySpan<byte>) stream.ReadBytes(8)); | ||
if (uint64 > (ulong) int.MaxValue) | ||
throw new ArgumentOutOfRangeException("length", "Length of JSON exceeded int.MaxValue, not supported yet"); | ||
return JsonSerializer.Deserialize<Dictionary<string, SafetensorsEntry>>(Encoding.UTF8.GetString(stream.ReadBytes((int) uint64))) ?? throw new NotImplementedException("Loaded header string failed to deserialize into the correct format."); | ||
} | ||
|
||
public static Dictionary<string, Tensor> LoadStateDict( | ||
string path, | ||
List<string>? keysToKeep = null) | ||
{ | ||
using (FileStream fileStream = File.OpenRead(path)) | ||
return LoadStateDict((Stream) fileStream, keysToKeep: keysToKeep); | ||
} | ||
|
||
public static Dictionary<string, Tensor> LoadStateDict( | ||
Stream stream, | ||
bool leaveOpen = false, | ||
List<string>? keysToKeep = null) | ||
{ | ||
Dictionary<string, SafetensorsEntry> dictionary1 = HuggingFaceUtils.LoadIndex(stream); | ||
long position = stream.Position; | ||
Dictionary<string, Tensor> dictionary2 = new Dictionary<string, Tensor>(); | ||
foreach (KeyValuePair<string, SafetensorsEntry> keyValuePair in dictionary1) | ||
{ | ||
if (!(keyValuePair.Key == "__metadata__") && (keysToKeep == null || keysToKeep.Contains(keyValuePair.Key))) | ||
{ | ||
var datatype = ConvertToDataDType(keyValuePair.Value.DataType); | ||
// var tensor = new Tensor(datatype, new Shape(keyValuePair.Value.Shape)); | ||
var shape = new Shape(keyValuePair.Value.Shape); | ||
if (keyValuePair.Value.Offsets[1] - keyValuePair.Value.Offsets[0] != datatype.SizeInBytes * shape.Size) | ||
throw new NotImplementedException("Error when loading tensor " + keyValuePair.Key + " - mismatched # of elements"); | ||
stream.Position = position + keyValuePair.Value.Offsets[0]; | ||
var tensor = Tensor.FromStream(datatype, stream, shape); | ||
dictionary2.Add(keyValuePair.Key, tensor); | ||
} | ||
} | ||
if (!leaveOpen) | ||
stream.Close(); | ||
return dictionary2; | ||
} | ||
|
||
private static DataType ConvertToDataDType(string dataType) | ||
{ | ||
if (dataType != null) | ||
{ | ||
switch (dataType.Length) | ||
{ | ||
case 2: | ||
switch (dataType[0]) | ||
{ | ||
case 'I': | ||
if (dataType == "I8") | ||
return DataTypes.Int8; | ||
break; | ||
case 'U': | ||
if (dataType == "U8") | ||
return DataTypes.UInt8; | ||
break; | ||
} | ||
break; | ||
case 3: | ||
switch (dataType[1]) | ||
{ | ||
case '1': | ||
switch (dataType) | ||
{ | ||
case "F16": | ||
return DataTypes.Float16; | ||
case "I16": | ||
return DataTypes.Int16; | ||
} | ||
break; | ||
case '3': | ||
switch (dataType) | ||
{ | ||
case "F32": | ||
return DataTypes.Float32; | ||
case "I32": | ||
return DataTypes.Int32; | ||
} | ||
break; | ||
case '6': | ||
switch (dataType) | ||
{ | ||
case "F64": | ||
return DataTypes.Float64; | ||
case "I64": | ||
return DataTypes.Int64; | ||
} | ||
break; | ||
} | ||
break; | ||
case 4: | ||
switch (dataType[1]) | ||
{ | ||
case 'F': | ||
if (dataType == "BF16") | ||
return DataTypes.BFloat16; | ||
break; | ||
case 'O': | ||
if (dataType == "BOOL") | ||
return DataTypes.Boolean; | ||
break; | ||
} | ||
break; | ||
} | ||
} | ||
throw new NotImplementedException("Unrecognized data type listed: " + dataType); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// Copyright (c) Canaan Inc. All rights reserved. | ||
// Licensed under the Apache license. See LICENSE file in the project root for full license information. | ||
|
||
using System.Collections.Generic; | ||
using Nncase.IR; | ||
using Nncase.IR.Math; | ||
using TorchSharp; | ||
Check failure on line 7 in src/Nncase.Importer/HuggingFace/Qwen2Model.cs GitHub Actions / build-x86_64-linux
Check failure on line 7 in src/Nncase.Importer/HuggingFace/Qwen2Model.cs GitHub Actions / build-aarch64-macos
Check failure on line 7 in src/Nncase.Importer/HuggingFace/Qwen2Model.cs GitHub Actions / build-aarch64-macos
|
||
using F = Nncase.IR.F; | ||
|
||
namespace Nncase.Importer | ||
{ | ||
public partial class HuggingFaceImporter | ||
{ | ||
private Expr VisitQwen2Model(Dictionary<string, object> modelConfig, Dictionary<string, Tensor> constTensors) | ||
{ | ||
// var (lhs, rhs) = GetInputExprs(op, 0, 1); | ||
// if (binaryOp == BinaryOp.Pow && lhs.CheckedDataType != rhs.CheckedDataType) | ||
// { | ||
// return F.Math.Binary(binaryOp, lhs, IR.F.Tensors.Cast(rhs, lhs.CheckedDataType)).With(metadata: new IRMetadata() { OutputNames = op.Output }); | ||
// } | ||
|
||
// return F.Math.Binary(binaryOp, lhs, rhs).With(metadata: new IRMetadata() { OutputNames = op.Output }); | ||
|
||
var input_ids = new Var(); | ||
var position_ids = new Var(); | ||
|
||
} | ||
} | ||
} |