Skip to content

Commit

Permalink
Overhaul memory management with ad-hoc Tensor(s) (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
nietras authored Jul 7, 2024
1 parent eb389b9 commit 25b1bb2
Show file tree
Hide file tree
Showing 11 changed files with 826 additions and 681 deletions.
132 changes: 94 additions & 38 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,46 +34,102 @@
Face. This means there is no need to run any Python here to get data or
similar. Clone and run ✅
* Output should then be something like:
```powershell
```text
ProcessorCount: 32
[GPT-2]
max_seq_len: 1024
vocab_size: 50257
num_layers: 12
num_heads: 12
channels: 768
num_parameters: 124439808
MaxTokenCount: 1024
VocabularySize: 50257
LayerCount: 12
HeadCount: 12
ChannelCount: 768
ParameterCount: 124439808
[State]
batch_size: 4
seq_len: 64
activationCount: 73323776
Logits TENSOR OK
dwte TENSOR OK
dwpe TENSOR OK
dln1w TENSOR OK
dln1b TENSOR OK
dqkvw TENSOR OK
dqkvb TENSOR OK
dattprojw TENSOR OK
dattprojb TENSOR OK
dln2w TENSOR OK
dln2b TENSOR OK
dfcw TENSOR OK
dfcb TENSOR OK
dfcprojw TENSOR OK
dfcprojb TENSOR OK
dlnfw TENSOR OK
dlnfb TENSOR OK
step 0: loss 5.269890 expected loss 5.270007 OK (took 4219 ms)
step 1: loss 4.059388 expected loss 4.059707 OK (took 4099 ms)
step 2: loss 3.374212 expected loss 3.375123 OK (took 4050 ms)
step 3: loss 2.800128 expected loss 2.800783 OK (took 4073 ms)
step 4: loss 2.315312 expected loss 2.315382 OK (took 4089 ms)
step 5: loss 1.849347 expected loss 1.849029 OK (took 4052 ms)
step 6: loss 1.395217 expected loss 1.394656 OK (took 4071 ms)
step 7: loss 0.998616 expected loss 0.999147 OK (took 4057 ms)
step 8: loss 0.625540 expected loss 0.624080 OK (took 4073 ms)
step 9: loss 0.378012 expected loss 0.376511 OK (took 4059 ms)
overall okay: True
BatchSize: 4
TokenCount: 64
OutputCount: 73323776
Logits TENSOR OK MaxAbsDiff 0.000534
δTokenEmbeddings TENSOR OK MaxAbsDiff 0.001185
δPositionEmbeddings TENSOR OK MaxAbsDiff 0.000037
δLayerNorm1Weights TENSOR OK MaxAbsDiff 0.003039
δLayerNorm1Bias TENSOR OK MaxAbsDiff 0.001283
δQKVWeights TENSOR OK MaxAbsDiff 0.000474
δQKVBias TENSOR OK MaxAbsDiff 0.000257
δAttentionProjectionWeights TENSOR OK MaxAbsDiff 0.000200
δAttentionProjectionBias TENSOR OK MaxAbsDiff 0.000179
δLayerNorm2Weights TENSOR OK MaxAbsDiff 0.009708
δLayerNorm2Bias TENSOR OK MaxAbsDiff 0.000819
δFullConnectWeights TENSOR OK MaxAbsDiff 0.000794
δFullConnectBias TENSOR OK MaxAbsDiff 0.000193
δFullConnectProjectionWeights TENSOR OK MaxAbsDiff 0.000385
δFullConnectProjectionBias TENSOR OK MaxAbsDiff 0.000118
δLayerNormFinalWeights TENSOR OK MaxAbsDiff 0.000362
δLayerNormFinalBias TENSOR OK MaxAbsDiff 0.000066
0: loss 5.269892 exp. 5.270007 OK ( 1386 ms = Forward 490 ms ZeroGrad 90 ms Backward 632 ms Update 174 ms) JIT/WARMUP
1: loss 4.059388 exp. 4.059707 OK ( 875 ms = Forward 279 ms ZeroGrad 28 ms Backward 463 ms Update 106 ms) JIT/WARMUP
2: loss 3.374209 exp. 3.375123 OK ( 1005 ms = Forward 407 ms ZeroGrad 28 ms Backward 459 ms Update 110 ms) JIT/WARMUP
3: loss 2.800130 exp. 2.800783 OK ( 867 ms = Forward 266 ms ZeroGrad 28 ms Backward 474 ms Update 99 ms)
4: loss 2.315308 exp. 2.315382 OK ( 847 ms = Forward 238 ms ZeroGrad 28 ms Backward 477 ms Update 103 ms)
5: loss 1.849346 exp. 1.849029 OK ( 884 ms = Forward 234 ms ZeroGrad 28 ms Backward 516 ms Update 106 ms)
6: loss 1.395217 exp. 1.394656 OK ( 884 ms = Forward 282 ms ZeroGrad 28 ms Backward 468 ms Update 106 ms)
7: loss 0.998617 exp. 0.999147 OK ( 839 ms = Forward 231 ms ZeroGrad 28 ms Backward 474 ms Update 106 ms)
8: loss 0.625541 exp. 0.624080 OK ( 887 ms = Forward 309 ms ZeroGrad 28 ms Backward 449 ms Update 102 ms)
9: loss 0.378010 exp. 0.376511 OK ( 915 ms = Forward 311 ms ZeroGrad 28 ms Backward 485 ms Update 91 ms)
All okay: True
0.Forward 00 EmbedForward 0% count: 7 sum: 1.2 min: 0.2 mean: 0.2 max: 0.2 [ms]
0.Forward 01 LayerNormForward 0% count: 84 sum: 7.2 min: 0.1 mean: 0.1 max: 0.1 [ms]
0.Forward 02 MatMulForward 3% count: 84 sum: 197.5 min: 1.9 mean: 2.4 max: 4.4 [ms]
0.Forward 03 AttentionForward 1% count: 84 sum: 67.0 min: 0.7 mean: 0.8 max: 1.6 [ms]
0.Forward 04 MatMulForward 1% count: 84 sum: 69.5 min: 0.7 mean: 0.8 max: 1.4 [ms]
0.Forward 05 ResidualForward 0% count: 84 sum: 7.5 min: 0.1 mean: 0.1 max: 0.3 [ms]
0.Forward 06 LayerNormForward 0% count: 84 sum: 6.8 min: 0.1 mean: 0.1 max: 0.1 [ms]
0.Forward 07 MatMulForward 4% count: 84 sum: 252.7 min: 2.3 mean: 3.0 max: 5.2 [ms]
0.Forward 08 GeLUForward 1% count: 84 sum: 74.9 min: 0.7 mean: 0.9 max: 1.7 [ms]
0.Forward 09 MatMulForward 4% count: 84 sum: 250.0 min: 2.4 mean: 3.0 max: 5.9 [ms]
0.Forward 10 ResidualForward 0% count: 84 sum: 7.2 min: 0.1 mean: 0.1 max: 0.1 [ms]
0.Forward 11 LayerNormForward 0% count: 7 sum: 0.5 min: 0.1 mean: 0.1 max: 0.1 [ms]
0.Forward 12 MatMulForward 15% count: 7 sum: 887.3 min: 93.7 mean: 126.8 max: 175.8 [ms]
0.Forward 13 SoftmaxForward 1% count: 7 sum: 38.6 min: 3.8 mean: 5.5 max: 9.4 [ms]
0.Forward 14 CrossEntropyForward 0% count: 7 sum: 0.2 min: 0.0 mean: 0.0 max: 0.0 [ms]
1.ZeroGrad 00 Zero 2% count: 7 sum: 123.0 min: 17.5 mean: 17.6 max: 17.6 [ms]
1.ZeroGrad 01 Zero 1% count: 7 sum: 72.2 min: 10.3 mean: 10.3 max: 10.4 [ms]
2.Backward 00 CrossEntropySoftmaxBackward 1% count: 7 sum: 45.4 min: 6.4 mean: 6.5 max: 6.6 [ms]
2.Backward 01 MatMulBackward 18% count: 7 sum: 1116.7 min: 139.0 mean: 159.5 max: 173.6 [ms]
2.Backward 02 LayerNormBackward 0% count: 7 sum: 4.1 min: 0.5 mean: 0.6 max: 0.6 [ms]
2.Backward 03 ResidualBackward 0% count: 84 sum: 14.1 min: 0.1 mean: 0.2 max: 0.3 [ms]
2.Backward 04 MatMulBackward 11% count: 84 sum: 679.4 min: 6.8 mean: 8.1 max: 14.3 [ms]
2.Backward 05 GeLUBackward 2% count: 84 sum: 124.7 min: 1.3 mean: 1.5 max: 2.0 [ms]
2.Backward 06 MatMulBackward 8% count: 84 sum: 519.8 min: 5.7 mean: 6.2 max: 10.2 [ms]
2.Backward 07 LayerNormBackward 1% count: 84 sum: 39.3 min: 0.4 mean: 0.5 max: 1.1 [ms]
2.Backward 08 ResidualBackward 0% count: 84 sum: 12.2 min: 0.1 mean: 0.1 max: 0.3 [ms]
2.Backward 09 MatMulBackward 3% count: 84 sum: 158.9 min: 1.6 mean: 1.9 max: 7.1 [ms]
2.Backward 10 AttentionBackward 3% count: 84 sum: 159.1 min: 1.4 mean: 1.9 max: 4.6 [ms]
2.Backward 11 MatMulBackward 7% count: 84 sum: 423.6 min: 4.3 mean: 5.0 max: 8.1 [ms]
2.Backward 12 LayerNormBackward 1% count: 84 sum: 42.5 min: 0.4 mean: 0.5 max: 0.8 [ms]
2.Backward 13 EmbedBackward 0% count: 7 sum: 1.3 min: 0.2 mean: 0.2 max: 0.2 [ms]
3.Update 00 AdamW 12% count: 7 sum: 713.1 min: 91.1 mean: 101.9 max: 106.3 [ms]
MatMulBackward 47% sum: 2898 [ms] per step: 414 [ms]
MatMulForward 27% sum: 1657 [ms] per step: 237 [ms]
AdamW 12% sum: 713 [ms] per step: 102 [ms]
Zero 3% sum: 195 [ms] per step: 28 [ms]
AttentionBackward 3% sum: 159 [ms] per step: 23 [ms]
GeLUBackward 2% sum: 125 [ms] per step: 18 [ms]
LayerNormBackward 1% sum: 86 [ms] per step: 12 [ms]
GeLUForward 1% sum: 75 [ms] per step: 11 [ms]
AttentionForward 1% sum: 67 [ms] per step: 10 [ms]
CrossEntropySoftmaxBackward 1% sum: 45 [ms] per step: 6 [ms]
SoftmaxForward 1% sum: 39 [ms] per step: 6 [ms]
ResidualBackward 0% sum: 26 [ms] per step: 4 [ms]
ResidualForward 0% sum: 15 [ms] per step: 2 [ms]
LayerNormForward 0% sum: 14 [ms] per step: 2 [ms]
EmbedBackward 0% sum: 1 [ms] per step: 0 [ms]
EmbedForward 0% sum: 1 [ms] per step: 0 [ms]
CrossEntropyForward 0% sum: 0 [ms] per step: 0 [ms]
Total 100% sum: 6117 [ms] per step: 874 [ms]
```

## Example
Expand Down
28 changes: 14 additions & 14 deletions src/Llm.Benchmarks/Gpt2Bench.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ namespace nietras.LargeLanguageModel.Benchmarks;
public class Gpt2Bench
{
const string DataDirectory = "../../../";
GPT2 _model = new();
ExpectedTensors _expectedInputsOutputs;
ParameterTensors _expectedGrads;
Model _model = null!;
ExpectedTokenTensors _expectedTokens = null!;
ExpectedOutputTensors _expectedOutputs = null!;
ParameterTensors _expectedGrads = null!;
TimeLlm? _llm;
int _step;

Expand All @@ -35,9 +36,10 @@ public void GlobalSetup()
DataDirectory, t => Trace.WriteLine(t));

// build the GPT-2 model from a checkpoint
BuildFromCheckpoint(ref _model, DataDirectory + ModelBinaryFileName);
_model = ModelFromCheckpoint(DataDirectory + ModelBinaryFileName);

(_expectedInputsOutputs, _expectedGrads) = ReadExpectedState(_model, DataDirectory);
(_expectedTokens, _expectedOutputs, _expectedGrads) =
ReadExpectedState(_model, DataDirectory);

var llm = LlmFactory.NameToCreate[Name]();
_llm = new TimeLlm(llm);
Expand All @@ -47,9 +49,9 @@ public void GlobalSetup()
[Benchmark]
public unsafe float Train()
{
var (loss, t) = TrainStep(ref _model,
_expectedInputsOutputs.InputTokenIndices, _expectedInputsOutputs.OutputTokenIndices,
_expectedInputsOutputs.BatchSize, _expectedInputsOutputs.TokenCount,
var (loss, t) = TrainStep(_model,
_expectedTokens.InputTokenIndices, _expectedTokens.OutputTokenIndices,
_expectedTokens.BatchSize, _expectedTokens.TokenCount,
_llm!, _step);
++_step;
return loss;
Expand All @@ -58,11 +60,9 @@ public unsafe float Train()
[GlobalCleanup]
public unsafe void GlobalCleanup()
{
free(_expectedInputsOutputs.InputTokenIndices);
free(_expectedInputsOutputs.OutputTokenIndices);
free(_expectedInputsOutputs.ExpectedLogits);
free(_expectedInputsOutputs.ExpectedLoss);
free(_expectedGrads.MemoryPtr);
Free(ref _model);
_expectedTokens.Dispose();
_expectedOutputs.Dispose();
_expectedGrads.Dispose();
_model.Dispose();
}
}
14 changes: 14 additions & 0 deletions src/Llm.Test/DevTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace nietras.LargeLanguageModel.Test;

[TestClass]
public class DevTest
{
    [TestMethod]
    public void DevTest_()
    {
        // Smoke test: parameter tensors for a tiny GPT-2 configuration can
        // be created and disposed without error.
        var config = new Gpt2.Config
        {
            VocabularySize = 16,
            MaxTokenCount = 32,
            LayerCount = 4,
            ChannelCount = 8,
            HeadCount = 16,
        };
        using var parameters = Gpt2.ParameterTensors.Create(config);
    }
}
46 changes: 46 additions & 0 deletions src/Llm/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,50 @@ public static unsafe void ReadExactlyUnmanaged<T>(this FileStream file, T* value
totalReadCount += countToRead;
}
}

// Product of all elements of the array; delegates to the span overload.
public static nint Product(this nint[] values)
{
    ReadOnlySpan<nint> span = values;
    return span.Product();
}

// Product of all elements in the span. An empty span yields 0 by
// convention (this mirrors how a zero-dimensional shape has no elements).
public static nint Product(this ReadOnlySpan<nint> values)
{
    if (values.IsEmpty) { return 0; }
    var result = values[0];
    foreach (var value in values[1..])
    {
        result *= value;
    }
    return result;
}

// Row-major strides for the given dimension lengths; delegates to the
// span overload.
public static nint[] CalculateStrides(this nint[] lengths)
{
    ReadOnlySpan<nint> span = lengths;
    return span.CalculateStrides();
}

/// <summary>
/// Calculates row-major strides for the given dimension lengths: the last
/// dimension is contiguous (stride 1) and each earlier stride is the
/// product of all later lengths.
/// </summary>
/// <param name="lengths">Length of each dimension.</param>
/// <returns>Strides with the same rank as <paramref name="lengths"/>.</returns>
public static nint[] CalculateStrides(this ReadOnlySpan<nint> lengths)
{
    var strides = new nint[lengths.Length];
    // Rank 0: nothing to compute. (The previous version fell into a branch
    // that wrote strides[0] here, throwing IndexOutOfRangeException for an
    // empty span.)
    if (lengths.Length == 0)
    {
        return strides;
    }
    // A single zero-length dimension gets a stride of zero.
    if (lengths.Length == 1 && lengths[0] == 0)
    {
        strides[0] = 0;
        return strides;
    }
    // Walk backwards so each stride is the product of all later lengths.
    nint stride = 1;
    for (var i = strides.Length - 1; i >= 0; i--)
    {
        strides[i] = stride;
        stride *= lengths[i];
    }
    return strides;
}

// Formats the span as a shape string like "[2, 3, 4]" (empty span => "[]").
public static string ToShapeText(this ReadOnlySpan<nint> values)
{
    var text = new System.Text.StringBuilder("[");
    for (var i = 0; i < values.Length; i++)
    {
        if (i > 0)
        {
            text.Append(", ");
        }
        // "D" (decimal) matches the original interpolated format.
        text.Append(values[i].ToString("D"));
    }
    return text.Append(']').ToString();
}
}
Loading

0 comments on commit 25b1bb2

Please sign in to comment.