Skip to content

Commit

Permalink
Switch --smart-jwks to --smart-key and accept PEM keys
Browse files Browse the repository at this point in the history
Some servers (notably Epic) support using a private key as an
alternative to a JWK Set. So let's add support for that too.

It's a little fussier, but easy enough.
  • Loading branch information
mikix committed Jan 10, 2025
1 parent 9b6c8ec commit de656d6
Show file tree
Hide file tree
Showing 10 changed files with 307 additions and 66 deletions.
8 changes: 7 additions & 1 deletion cumulus_etl/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
def add_auth(parser: argparse.ArgumentParser, *, use_fhir_url: bool = True):
group = parser.add_argument_group("authentication")
group.add_argument("--smart-client-id", metavar="ID", help="client ID for SMART authentication")
group.add_argument("--smart-jwks", metavar="PATH", help="JWKS file for SMART authentication")
group.add_argument(
"--smart-key", metavar="PATH", help="JWKS or PEM file for SMART authentication"
)
group.add_argument("--basic-user", metavar="USER", help="username for Basic authentication")
group.add_argument(
"--basic-passwd", metavar="PATH", help="password file for Basic authentication"
Expand All @@ -31,6 +33,10 @@ def add_auth(parser: argparse.ArgumentParser, *, use_fhir_url: bool = True):
help="FHIR server base URL, only needed if you exported separately",
)

# --smart-jwks is a deprecated alias for --smart-key (as of Jan 2025)
# Keep it around for a bit, since it was in common use for a couple years.
group.add_argument("--smart-jwks", metavar="PATH", help=argparse.SUPPRESS)


def add_aws(parser: argparse.ArgumentParser) -> None:
group = parser.add_argument_group("AWS")
Expand Down
2 changes: 2 additions & 0 deletions cumulus_etl/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
COMPLETION_ARG_MISSING = 34
TASK_HELP = 35
MISSING_REQUESTED_RESOURCES = 36
TOO_MANY_SMART_CREDENTIALS = 37
BAD_SMART_CREDENTIAL = 38


class FatalError(Exception):
Expand Down
97 changes: 69 additions & 28 deletions cumulus_etl/fhir/fhir_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,17 @@ def sign_headers(self, headers: dict) -> dict:
return headers


class JwksAuth(Auth):
"""Authentication with a JWK Set (typical backend service profile)"""
class JwtAuth(Auth):
"""Authentication with a JWT (typical OAuth2 backend service profile)"""

def __init__(self, server_root: str, client_id: str, jwks: dict, resources: Iterable[str]):
def __init__(
self, server_root: str, client_id: str, jwks: dict, pem: dict, resources: Iterable[str]
):
super().__init__()
self._server_root = server_root
self._client_id = client_id
self._jwks = jwks
self._pem = pem
self._resources = list(resources)
self._token_endpoint = None
self._access_token = None
Expand Down Expand Up @@ -128,6 +131,42 @@ async def _get_token_endpoint(self, session: httpx.AsyncClient) -> str:

return config["token_endpoint"]

def _make_pem_jwk(self) -> tuple[str, jwk.JWK]:
try:
jwk_key = jwk.JWK.from_pem(self._pem.encode("utf8"))
except ValueError:
jwk_key = None # will fail below

if jwk_key and jwk_key.has_private:
# Unfortunately, we can't just ask jcrypto "hey what JWT alg value should I use here?".
# So instead, we check for a few common values.
if jwk_key.get("kty") == "RSA":
# Could pick any RS* value (like RS256, RS384, or RS512), assuming the server
# supports it. Since 384 is practically as secure as 512, but more common, we'll
# use that.
return "RS384", jwk_key
elif jwk_key.get("kty") == "EC" and jwk_key.get("crv") == "P-256":
return "ES256", jwk_key
elif jwk_key.get("kty") == "EC" and jwk_key.get("crv") == "P-384":
return "ES384", jwk_key
elif jwk_key.get("kty") == "EC" and jwk_key.get("crv") == "P-521":
# Yes, P-521 is not a typo, it's what the curve is called.
# The curve uses 521 bits, but it's hashed with SHA512, so it's called ES512.
return "ES512", jwk_key

errors.fatal(
"No supported private key found in the provided PEM file.", errors.BAD_SMART_CREDENTIAL
)

def _make_jwks_jwk(self) -> tuple[str, jwk.JWK]:
# Find a usable signing JWK from JWKS
for key in self._jwks.get("keys", []):
if "sign" in key.get("key_ops", []) and key.get("kid") and key.get("alg"):
return key["alg"], jwk.JWK(**key)
errors.fatal(
"No valid private key found in the provided JWKS file.", errors.BAD_SMART_CREDENTIAL
)

def _make_signed_jwt(self) -> str:
"""
Creates a signed JWT for use in the client-confidential-asymmetric protocol.
Expand All @@ -136,19 +175,19 @@ def _make_signed_jwt(self) -> str:
:returns: a signed JWT string, ready for authentication with the FHIR server
"""
# Find a usable singing JWK from JWKS
for key in self._jwks.get("keys", []):
if "sign" in key.get("key_ops", []) and key.get("kid") and key.get("alg"):
break
else: # no valid private JWK found
raise errors.FatalError("No valid private key found in the provided JWKS file.")
if self._pem:
generator = self._make_pem_jwk
else:
generator = self._make_jwks_jwk
algorithm, jwk_key = generator()

# Now generate a signed JWT based off the given JWK
header = {
"alg": key["alg"],
"kid": key["kid"],
"alg": algorithm,
"typ": "JWT",
}
if "kid" in jwk_key:
header["kid"] = jwk_key["kid"]
claims = {
"iss": self._client_id,
"sub": self._client_id,
Expand All @@ -157,7 +196,7 @@ def _make_signed_jwt(self) -> str:
"jti": str(uuid.uuid4()),
}
token = jwt.JWT(header=header, claims=claims)
token.make_signed_token(key=jwk.JWK(**key))
token.make_signed_token(key=jwk_key)
return token.serialize()


Expand Down Expand Up @@ -202,41 +241,43 @@ def create_auth(
bearer_token: str | None,
smart_client_id: str | None,
smart_jwks: dict | None,
smart_pem: str | None,
) -> Auth:
"""Determine which auth method to use based on user provided arguments"""
valid_smart_jwks = smart_jwks is not None
valid_smart_pem = smart_pem is not None

# Check if the user tried to specify multiple types of auth, and help them out
has_basic_args = bool(basic_user or basic_password)
has_bearer_args = bool(bearer_token)
has_smart_args = bool(valid_smart_jwks)
has_smart_args = bool(valid_smart_jwks or valid_smart_pem)
total_auth_types = has_basic_args + has_bearer_args + has_smart_args
if total_auth_types > 1:
print(
"Multiple authentication methods have been specified. Double check your arguments to Cumulus ETL.",
file=sys.stderr,
errors.fatal(
"Multiple authentication methods have been specified. "
"Double check your arguments to Cumulus ETL.",
errors.ARGS_CONFLICT,
)
raise SystemExit(errors.ARGS_CONFLICT)

if basic_user and basic_password:
return BasicAuth(basic_user, basic_password)
elif basic_user or basic_password:
print(
"You must provide both --basic-user and --basic-password to connect to a Basic auth server.",
file=sys.stderr,
errors.fatal(
"You must provide both --basic-user and --basic-password "
"to connect to a Basic auth server.",
errors.BASIC_CREDENTIALS_MISSING,
)
raise SystemExit(errors.BASIC_CREDENTIALS_MISSING)

if bearer_token:
return BearerAuth(bearer_token)

if smart_client_id and valid_smart_jwks:
return JwksAuth(server_root, smart_client_id, smart_jwks, resources)
elif smart_client_id or valid_smart_jwks:
print(
"You must provide both --smart-client-id and --smart-jwks to connect to a SMART FHIR server.",
file=sys.stderr,
if smart_client_id and has_smart_args:
return JwtAuth(server_root, smart_client_id, smart_jwks, smart_pem, resources)
elif smart_client_id or has_smart_args:
errors.fatal(
"You must provide both --smart-client-id and --smart-key "
"to connect to a SMART FHIR server.",
errors.SMART_CREDENTIALS_MISSING,
)
raise SystemExit(errors.SMART_CREDENTIALS_MISSING)

return Auth()
20 changes: 18 additions & 2 deletions cumulus_etl/fhir/fhir_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
bearer_token: str | None = None,
smart_client_id: str | None = None,
smart_jwks: dict | None = None,
smart_pem: str | None = None,
):
"""
Initialize and authorize a BackendServiceServer context manager.
Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(
bearer_token,
smart_client_id,
smart_jwks,
smart_pem,
)
self._session: httpx.AsyncClient | None = None
self._capabilities: dict = {}
Expand Down Expand Up @@ -275,12 +277,25 @@ def create_fhir_client_for_cli(
except FileNotFoundError:
smart_client_id = args.smart_client_id

# Check deprecated --smart-jwks argument first
smart_jwks = common.read_json(args.smart_jwks) if args.smart_jwks else None
smart_pem = None
if args.smart_key:
folded = args.smart_key.casefold()
if folded.endswith(".jwks"):
smart_jwks = common.read_json(args.smart_key)
elif folded.endswith(".pem"):
smart_pem = common.read_text(args.smart_key).strip()
else:
raise OSError(
f"Unrecognized private key file '{args.smart_key}'\n"
"(must end in .jwks or .pem)."
)

basic_password = common.read_text(args.basic_passwd).strip() if args.basic_passwd else None
bearer_token = common.read_text(args.bearer_token).strip() if args.bearer_token else None
except OSError as exc:
print(exc, file=sys.stderr)
raise SystemExit(errors.ARGS_INVALID) from exc
errors.fatal(str(exc), errors.ARGS_INVALID)

client_resources = set(resources)
if {"DiagnosticReport", "DocumentReference"} & client_resources:
Expand All @@ -296,4 +311,5 @@ def create_fhir_client_for_cli(
bearer_token=bearer_token,
smart_client_id=smart_client_id,
smart_jwks=smart_jwks,
smart_pem=smart_pem,
)
26 changes: 20 additions & 6 deletions docs/bulk-exports.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,13 @@ but they are also harmless and will be ignored.
## Registering an Export Client

On your server, you need to register a new "backend service" client.
You'll be asked to provide a JWKS (JWK Set) file.
You'll be asked to provide some sort of private/public key.
See below for generating that.
You'll also be asked for a client ID or the server may generate a client ID for you.

### Generating a JWKS
### Generating a JWK Set

A JWKS is just a file with some cryptographic keys,
A JWK Set (JWKS) is just a file with some cryptographic keys,
usually holding a public and private version of the same key.
FHIR servers use it to grant clients access.

Expand All @@ -159,15 +159,29 @@ jose jwk gen -s -i "{\"alg\":\"RS384\",\"kid\":\"`uuidgen`\"}" -o private.jwks
jose jwk pub -s -i private.jwks -o public.jwks
```

Then give `public.jwks` to your FHIR server and `private.jwks` to Cumulus ETL (details on that below).
After giving `public.jwks` to your FHIR server,
you can pass `private.jwks` to Cumulus ETL with `--smart-key` (example below).

### Generating a PEM key

A PEM key is just a file with a single private cryptographic key.
Some FHIR servers may use it to grant clients access.

If your FHIR server uses a PEM key,
it will provide instructions on the kind of key it expects and how to generate it.
See for example,
[Epic's documentation](https://vendorservices.epic.com/Article?docId=oauth2&section=Creating-Key-Pair).

After giving the public key to your FHIR server,
you can pass your `private.pem` file to Cumulus ETL with `--smart-key` (example below).

### SMART Arguments

You'll need to pass two new arguments to Cumulus ETL:
You'll need to pass two arguments to Cumulus ETL:

```sh
--smart-client-id=YOUR_CLIENT_ID
--smart-jwks=/path/to/private.jwks
--smart-key=/path/to/private.jwks
```

You can also give `--smart-client-id` a path to a file with your client ID,
Expand Down
4 changes: 2 additions & 2 deletions tests/export/test_export_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def test_arg_passthrough(self):
"--since=1920",
"--until=1923",
"--smart-client-id=ID",
"--smart-jwks=jwks.json",
"--smart-key=jwks.json",
"--basic-user=alice",
"--basic-passwd=passwd.txt",
"--bearer-token=token.txt",
Expand All @@ -74,7 +74,7 @@ async def test_arg_passthrough(self):
self.assertEqual("1923", self.loader_init_mock.call_args.kwargs["until"])
self.assertEqual("my-url", self.loader_init_mock.call_args.kwargs["resume"])
self.assertEqual("ID", self.client_mock.call_args.args[0].smart_client_id)
self.assertEqual("jwks.json", self.client_mock.call_args.args[0].smart_jwks)
self.assertEqual("jwks.json", self.client_mock.call_args.args[0].smart_key)
self.assertEqual("alice", self.client_mock.call_args.args[0].basic_user)
self.assertEqual("passwd.txt", self.client_mock.call_args.args[0].basic_passwd)
self.assertEqual("token.txt", self.client_mock.call_args.args[0].bearer_token)
Loading

0 comments on commit de656d6

Please sign in to comment.