
Commit

Added vllm cuda support
Signed-off-by: Brian <[email protected]>
bmahabirbu committed Jan 14, 2025
1 parent 7b5e3ce commit 18745d0
Showing 3 changed files with 50 additions and 19 deletions.
43 changes: 30 additions & 13 deletions ramalama/common.py
@@ -193,6 +193,30 @@ def engine_version(engine):


 def get_gpu():
+
+    envs = get_env_vars()
+    # If env vars already set return
+    if envs:
+        return
+
+    # ASAHI CASE
+    if os.path.exists('/etc/os-release'):
+        with open('/etc/os-release', 'r') as file:
+            if "asahi" in file.read().lower():
+                # Set Env Var and break
+                os.environ["ASAHI_VISIBLE_DEVICES"] = "1"
+                return
+
+    # NVIDIA CASE
+    try:
+        command = ['nvidia-smi']
+        run_cmd(command).stdout.decode("utf-8")
+        os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+        return
+    except Exception:
+        pass
+
+    # ROCm/AMD CASE
     i = 0
     gpu_num = 0
     gpu_bytes = 0
@@ -205,24 +229,17 @@ def get_gpu():

         i += 1

-    if gpu_bytes: # this is the ROCm/AMD case
-        return "HIP_VISIBLE_DEVICES", gpu_num
-
-    if os.path.exists('/etc/os-release'):
-        with open('/etc/os-release', 'r') as file:
-            content = file.read()
-            if "asahi" in content.lower():
-                return "ASAHI_VISIBLE_DEVICES", 1
-
-    return None, None
+    if gpu_bytes:
+        os.environ["HIP_VISIBLE_DEVICES"] = gpu_num
+        return


 def get_env_vars():
     prefixes = ("ASAHI_", "CUDA_", "HIP_", "HSA_")
     env_vars = {k: v for k, v in os.environ.items() if k.startswith(prefixes)}

-    gpu_type, gpu_num = get_gpu()
-    if gpu_type not in env_vars and gpu_type in {"HIP_VISIBLE_DEVICES", "ASAHI_VISIBLE_DEVICES"}:
-        env_vars[gpu_type] = str(gpu_num)
+    # gpu_type, gpu_num = get_gpu()
+    # if gpu_type not in env_vars and gpu_type in {"HIP_VISIBLE_DEVICES", "ASAHI_VISIBLE_DEVICES"}:
+    #     env_vars[gpu_type] = str(gpu_num)

     return env_vars
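In short, the refactored get_gpu() above now publishes its detection result by setting an environment variable (ASAHI_VISIBLE_DEVICES, CUDA_VISIBLE_DEVICES, or HIP_VISIBLE_DEVICES) instead of returning a (name, value) pair, and get_env_vars() simply reports whatever GPU-related variables are already present. A minimal standalone sketch of that contract follows; it is not ramalama code, and the shutil.which("nvidia-smi") probe stands in for the run_cmd(['nvidia-smi']) call used above.

import os
import shutil


def detect_gpu():
    # Mirror the detection order above by writing to os.environ.
    if gpu_env_vars():
        return  # respect anything the user already exported
    if shutil.which("nvidia-smi"):  # stand-in for run_cmd(['nvidia-smi'])
        os.environ["CUDA_VISIBLE_DEVICES"] = "1"


def gpu_env_vars():
    # Collect GPU-related variables the same way get_env_vars() does.
    prefixes = ("ASAHI_", "CUDA_", "HIP_", "HSA_")
    return {k: v for k, v in os.environ.items() if k.startswith(prefixes)}


detect_gpu()
print(gpu_env_vars())  # e.g. {'CUDA_VISIBLE_DEVICES': '1'} on an NVIDIA host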
22 changes: 18 additions & 4 deletions ramalama/model.py
@@ -109,7 +109,13 @@ def _image(self, args):
         if args.image != default_image():
             return args.image

-        gpu_type, _ = get_gpu()
+        env_vars = get_env_vars()
+
+        if not env_vars:
+            return args.image
+
+        gpu_type, _ = next(iter(env_vars.items()))
+
         if args.runtime == "vllm":
             if gpu_type == "HIP_VISIBLE_DEVICES":
                 return "quay.io/modh/vllm:rhoai-2.17-rocm"
@@ -171,8 +177,11 @@ def setup_container(self, args):
             conman_args += ["--device", "/dev/kfd"]

         for k, v in get_env_vars().items():
-            conman_args += ["-e", f"{k}={v}"]
-
+            # Special case for Cuda
+            if k == "CUDA_VISIBLE_DEVICES":
+                conman_args += ["--device", "nvidia.com/gpu=all"]
+            else:
+                conman_args += ["-e", f"{k}={v}"]
         return conman_args

     def gpu_args(self, force=False, server=False):
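The loop above now treats NVIDIA differently from the other back ends: instead of forwarding CUDA_VISIBLE_DEVICES into the container with -e, it requests the GPU itself with --device nvidia.com/gpu=all, which is how podman exposes NVIDIA devices through CDI. A small self-contained sketch of that mapping, with a made-up env_vars dict purely for illustration:

# Hypothetical values, only to show the flag mapping used in setup_container() above.
env_vars = {"CUDA_VISIBLE_DEVICES": "1", "HIP_VISIBLE_DEVICES": "0"}

conman_args = ["podman", "run", "--rm", "-i"]
for k, v in env_vars.items():
    if k == "CUDA_VISIBLE_DEVICES":
        # NVIDIA GPUs are requested as a CDI device rather than via an env var
        conman_args += ["--device", "nvidia.com/gpu=all"]
    else:
        conman_args += ["-e", f"{k}={v}"]

print(" ".join(conman_args))
# podman run --rm -i --device nvidia.com/gpu=all -e HIP_VISIBLE_DEVICES=0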
@@ -259,6 +268,7 @@ def build_exec_args_run(self, args, model_path, prompt):
         if args.debug:
             exec_args += ["-v"]

+        get_gpu()
         gpu_args = self.gpu_args(force=args.gpu)
         if gpu_args is not None:
             exec_args.extend(gpu_args)
@@ -303,13 +313,17 @@ def build_exec_args_serve(self, args, exec_model_path):
         ]
         if args.seed:
             exec_args += ["--seed", args.seed]
+
         return exec_args

     def handle_runtime(self, args, exec_args, exec_model_path):
         if args.runtime == "vllm":
+            get_gpu()
             exec_model_path = os.path.dirname(exec_model_path)
-            exec_args = ["vllm", "serve", "--port", args.port, exec_model_path]
+            # Left out "vllm", "serve" the image entrypoint already starts it
+            exec_args = ["--port", args.port, "--model", MNT_FILE, "--max_model_len", "2048"]
         else:
+            get_gpu()
             gpu_args = self.gpu_args(force=args.gpu, server=True)
             if gpu_args is not None:
                 exec_args.extend(gpu_args)
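With this change the vllm branch no longer builds a ["vllm", "serve", ...] command itself; per the comment above, the image entrypoint is expected to start vllm serve, so ramalama only supplies the arguments and points --model at the mounted model file (MNT_FILE). A rough sketch of what the resulting command line amounts to, assuming port 1234 and the /mnt/models/model.file mount that the test below checks for:

# Assumed values for illustration; MNT_FILE is ramalama's in-container model path.
MNT_FILE = "/mnt/models/model.file"
port = "1234"

exec_args = ["--port", port, "--model", MNT_FILE, "--max_model_len", "2048"]

# With the entrypoint prepending "vllm serve", the container effectively runs:
print("vllm serve " + " ".join(exec_args))
# vllm serve --port 1234 --model /mnt/models/model.file --max_model_len 2048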
4 changes: 2 additions & 2 deletions test/system/040-serve.bats
Expand Up @@ -219,8 +219,8 @@ verify_begin=".*run --rm -i --label RAMALAMA --security-opt=label=disable --name


run cat $name.yaml
is "$output" ".*command: \[\"vllm\"\]" "command is correct"
is "$output" ".*args: \['serve', '--port', '1234', '/mnt/models'\]" "args is correct"
is "$output" ".*command: \[\"--port\"\]" "command is correct"
is "$output" ".*args: \['1234', '--model', '/mnt/models/model.file', '--max_model_len', '2048'\]" "args are correct"

is "$output" ".*image: quay.io/ramalama/ramalama" "image is correct"
is "$output" ".*reference: ${ociimage}" "AI image should be created"
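The updated assertions reflect how the generated kube YAML splits the new exec_args: the first element becomes the container command and the remainder becomes args. A tiny sketch of that split, using the values the test expects (the split itself is an assumption inferred from the assertions, not code from this commit):

exec_args = ["--port", "1234", "--model", "/mnt/models/model.file", "--max_model_len", "2048"]
command, args = [exec_args[0]], exec_args[1:]
print(command)  # ['--port']
print(args)     # ['1234', '--model', '/mnt/models/model.file', '--max_model_len', '2048']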
