From d8c210d51af20482bf58293f8889fea87234651b Mon Sep 17 00:00:00 2001
From: Brian
Date: Mon, 13 Jan 2025 00:49:40 -0500
Subject: [PATCH] Add vLLM CUDA support

Signed-off-by: Brian
---
 ramalama/common.py | 43 ++++++++++++++++++++++++++++++-------------
 ramalama/model.py  | 22 ++++++++++++++++++----
 2 files changed, 48 insertions(+), 17 deletions(-)

diff --git a/ramalama/common.py b/ramalama/common.py
index cb41cdf9..1d0c33bd 100644
--- a/ramalama/common.py
+++ b/ramalama/common.py
@@ -193,6 +193,30 @@ def engine_version(engine):
 
 
 def get_gpu():
+
+    envs = get_env_vars()
+    # If env vars are already set, return
+    if envs:
+        return
+
+    # ASAHI CASE
+    if os.path.exists('/etc/os-release'):
+        with open('/etc/os-release', 'r') as file:
+            if "asahi" in file.read().lower():
+                # Set env var and return
+                os.environ["ASAHI_VISIBLE_DEVICES"] = "1"
+                return
+
+    # NVIDIA CASE
+    try:
+        command = ['nvidia-smi']
+        run_cmd(command).stdout.decode("utf-8")
+        os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+        return
+    except Exception:
+        pass
+
+    # ROCm/AMD CASE
     i = 0
     gpu_num = 0
     gpu_bytes = 0
@@ -205,24 +229,17 @@ def get_gpu():
 
         i += 1
 
-    if gpu_bytes:  # this is the ROCm/AMD case
-        return "HIP_VISIBLE_DEVICES", gpu_num
-
-    if os.path.exists('/etc/os-release'):
-        with open('/etc/os-release', 'r') as file:
-            content = file.read()
-            if "asahi" in content.lower():
-                return "ASAHI_VISIBLE_DEVICES", 1
-
-    return None, None
+    if gpu_bytes:
+        os.environ["HIP_VISIBLE_DEVICES"] = str(gpu_num)
+        return
 
 
 def get_env_vars():
     prefixes = ("ASAHI_", "CUDA_", "HIP_", "HSA_")
     env_vars = {k: v for k, v in os.environ.items() if k.startswith(prefixes)}
 
-    gpu_type, gpu_num = get_gpu()
-    if gpu_type not in env_vars and gpu_type in {"HIP_VISIBLE_DEVICES", "ASAHI_VISIBLE_DEVICES"}:
-        env_vars[gpu_type] = str(gpu_num)
+    # gpu_type, gpu_num = get_gpu()
+    # if gpu_type not in env_vars and gpu_type in {"HIP_VISIBLE_DEVICES", "ASAHI_VISIBLE_DEVICES"}:
+    #     env_vars[gpu_type] = str(gpu_num)
 
     return env_vars
diff --git a/ramalama/model.py b/ramalama/model.py
index 5dc45984..5acb0f0e 100644
--- a/ramalama/model.py
+++ b/ramalama/model.py
@@ -109,7 +109,13 @@ def _image(self, args):
         if args.image != default_image():
             return args.image
 
-        gpu_type, _ = get_gpu()
+        env_vars = get_env_vars()
+
+        if not env_vars:
+            return args.image
+
+        gpu_type, _ = next(iter(env_vars.items()))
+
         if args.runtime == "vllm":
             if gpu_type == "HIP_VISIBLE_DEVICES":
                 return "quay.io/modh/vllm:rhoai-2.17-rocm"
@@ -171,8 +177,11 @@ def setup_container(self, args):
             conman_args += ["--device", "/dev/kfd"]
 
         for k, v in get_env_vars().items():
-            conman_args += ["-e", f"{k}={v}"]
-
+            # Special case for CUDA
+            if k == "CUDA_VISIBLE_DEVICES":
+                conman_args += ["--device", "nvidia.com/gpu=all"]
+            else:
+                conman_args += ["-e", f"{k}={v}"]
         return conman_args
 
     def gpu_args(self, force=False, server=False):
@@ -259,6 +268,7 @@ def build_exec_args_run(self, args, model_path, prompt):
         if args.debug:
             exec_args += ["-v"]
 
+        get_gpu()
         gpu_args = self.gpu_args(force=args.gpu)
         if gpu_args is not None:
             exec_args.extend(gpu_args)
@@ -303,13 +313,17 @@ def build_exec_args_serve(self, args, exec_model_path):
         ]
         if args.seed:
             exec_args += ["--seed", args.seed]
+
         return exec_args
 
     def handle_runtime(self, args, exec_args, exec_model_path):
         if args.runtime == "vllm":
+            get_gpu()
             exec_model_path = os.path.dirname(exec_model_path)
-            exec_args = ["vllm", "serve", "--port", args.port, exec_model_path]
+            # "vllm serve" is omitted; the image entrypoint already runs it
+            exec_args = ["--port", args.port, "--model", MNT_FILE, "--max_model_len", "2048"]
         else:
+            get_gpu()
             gpu_args = self.gpu_args(force=args.gpu, server=True)
             if gpu_args is not None:
                 exec_args.extend(gpu_args)