Skip to content

Commit

Permalink
Lower the workgroup count for some shaders by providing a loop that p…
Browse files Browse the repository at this point in the history
…rocesses

four floats at a time.
  • Loading branch information
manyoso committed Oct 26, 2023
1 parent 9c43141 commit 130c909
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 21 deletions.
16 changes: 8 additions & 8 deletions ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1358,7 +1358,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
// src1 is a row
ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00);
} else {
ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst));
ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
}
} break;
case GGML_OP_MUL:
Expand All @@ -1367,7 +1367,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
// src1 is a row
ggml_vk_mulrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00);
} else {
ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst));
ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
}
} break;
case GGML_OP_SCALE:
Expand All @@ -1379,15 +1379,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
switch (ggml_get_unary_op(gf->nodes[i])) {
case GGML_UNARY_OP_SILU:
{
ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst));
ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
} break;
case GGML_UNARY_OP_RELU:
{
ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst));
ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
} break;
case GGML_UNARY_OP_GELU:
{
ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst));
ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
} break;
default:
{
Expand Down Expand Up @@ -1427,9 +1427,9 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
ggml_is_transposed(src1)) {
fprintf(stderr, "%s: %s: matmul on tranposed tensor not supported: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
goto not_implemented;
}
}

switch (src0t) {
switch (src0t) {
case GGML_TYPE_F32:
ggml_vk_mul_mat_mat_f32(seq,
id_src0, id_src1, id_dst,
Expand Down Expand Up @@ -1459,7 +1459,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
goto not_implemented;
}
}

} break;
case GGML_OP_GET_ROWS:
{
Expand Down
9 changes: 6 additions & 3 deletions kompute/op_add.comp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ layout(push_constant) uniform PushConstants {
} pcs;

void main() {
const uint i = gl_WorkGroupID.x;
const uint baseIndex = gl_WorkGroupID.x * 4;

out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i) + pcs.inBOff];
}
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[i + pcs.inBOff];
}
}
9 changes: 6 additions & 3 deletions kompute/op_gelu.comp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ layout(push_constant) uniform PushConstants {
} pcs;

void main() {
const uint i = gl_WorkGroupID.x;
const float x = in_[i + pcs.inOff];
const uint baseIndex = gl_WorkGroupID.x * 4;

out_[i + pcs.outOff] = 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x)));
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
const float y = in_[i + pcs.inOff];
out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y)));
}
}
7 changes: 5 additions & 2 deletions kompute/op_mul.comp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ layout(push_constant) uniform PushConstants {
} pcs;

void main() {
const uint i = gl_WorkGroupID.x;
const uint baseIndex = gl_WorkGroupID.x * 4;

out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i) + pcs.inBOff];
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i) + pcs.inBOff];
}
}
7 changes: 5 additions & 2 deletions kompute/op_relu.comp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ layout(push_constant) uniform PushConstants {
} pcs;

void main() {
const uint i = gl_WorkGroupID.x;
const uint baseIndex = gl_WorkGroupID.x * 4;

out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]);
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]);
}
}
10 changes: 7 additions & 3 deletions kompute/op_silu.comp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ layout(push_constant) uniform PushConstants {
uint outOff;
} pcs;
void main() {
const uint i = gl_WorkGroupID.x;
const float x = in_[i + pcs.inOff];

out_[i + pcs.outOff] = x / (1.0 + exp(-x));
const uint baseIndex = gl_WorkGroupID.x * 4;

for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
const float y = in_[i + pcs.inOff];
out_[i + pcs.outOff] = y / (1.0 + exp(-y));
}
}

0 comments on commit 130c909

Please sign in to comment.