diff --git a/ggml-backend.c b/ggml-backend.c
index f5bdcf07838aa4..749c618948b9f2 100644
--- a/ggml-backend.c
+++ b/ggml-backend.c
@@ -1041,6 +1041,7 @@ struct ggml_backend_sched {
     bool is_alloc;
 
     int n_backends;
+    bool skip_cpu;
 
     ggml_backend_t backends[GGML_SCHED_MAX_BACKENDS];
     ggml_backend_buffer_type_t bufts[GGML_SCHED_MAX_BACKENDS];
@@ -1638,6 +1639,10 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
         int split_backend_id = split->backend_id;
         ggml_backend_t split_backend = sched->backends[split_backend_id];
 
+        if (sched->skip_cpu && ggml_backend_is_cpu(split_backend)) {
+            continue;
+        }
+
         // copy the input tensors to the split backend
         for (int j = 0; j < split->n_inputs; j++) {
             ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]);
@@ -1782,6 +1787,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
     free(sched);
 }
 
+void ggml_backend_sched_set_skip_cpu(ggml_backend_sched_t sched, bool value) {
+    sched->skip_cpu = value;
+}
+
 void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
     // reset state for the next run
     if (!sched->is_reset) {
diff --git a/ggml-backend.h b/ggml-backend.h
index d1becc840f0aac..49632ac844eff9 100644
--- a/ggml-backend.h
+++ b/ggml-backend.h
@@ -182,6 +182,8 @@ extern "C" {
     GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
     GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
 
+    GGML_API void ggml_backend_sched_set_skip_cpu(ggml_backend_sched_t sched, bool value);
+
     // Initialize backend buffers from a measure graph
     GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
 
diff --git a/llama.cpp b/llama.cpp
index 32653d665d60d9..a9f3a0e2cd065d 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -18026,6 +18026,10 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
             1.0e6 * ctx->n_sample / ctx->t_sample_us);
 }
 
+void llama_set_skip_cpu(struct llama_context * ctx, bool value) {
+    ggml_backend_sched_set_skip_cpu(ctx->sched, value);
+}
+
 // For internal test use
 const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
     struct llama_context * ctx
diff --git a/llama.h b/llama.h
index 5d7f33e3aa8775..98cbf2639d6c38 100644
--- a/llama.h
+++ b/llama.h
@@ -1104,6 +1104,8 @@ extern "C" {
     LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
 
+    LLAMA_API void llama_set_skip_cpu(struct llama_context * ctx, bool value);
+
 #ifdef __cplusplus
 }
 #endif
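
The patch itself only adds the flag and the two setters; how a caller would use them is not shown. Below is a minimal usage sketch under the assumption that a `llama_context` has already been created through the usual llama.h calls and that the 4-argument `llama_batch_get_one()` from this era of the API is available. The function name `decode_skipping_cpu` and the warm-up/GPU-only-timing scenario are hypothetical illustrations, not part of this change.

```c
// Hypothetical usage sketch for the new llama_set_skip_cpu() API.
#include "llama.h"

// Run one decode with CPU-assigned splits skipped, then restore normal scheduling.
static int decode_skipping_cpu(struct llama_context * ctx, llama_token token) {
    // While skip_cpu is set, ggml_backend_sched_compute_splits() skips any
    // split that was assigned to the CPU backend.
    llama_set_skip_cpu(ctx, true);

    struct llama_batch batch = llama_batch_get_one(&token, 1, 0, 0);
    int ret = llama_decode(ctx, batch);

    // Re-enable CPU splits so subsequent decodes compute the full graph.
    llama_set_skip_cpu(ctx, false);
    return ret;
}
```

Note that `skip_cpu` is a field on the scheduler, so once set it persists across decodes until the caller clears it again with `llama_set_skip_cpu(ctx, false)`.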