Skip to content

Commit

Permalink
Add custom host support to mysql multiuser lambda
Browse files Browse the repository at this point in the history
  • Loading branch information
ivan-georgiev committed Apr 25, 2023
1 parent e3b3d6b commit 15e1c33
Showing 1 changed file with 90 additions and 19 deletions.
109 changes: 90 additions & 19 deletions SecretsManagerRDSMySQLRotationMultiUser/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,15 @@ def set_secret(service_client, arn, token):
if not conn:
logger.error("setSecret: Unable to log into database using current credentials for secret %s" % arn)
raise ValueError("Unable to log into database using current credentials for secret %s" % arn)
conn.close()
# get hostname of existing user
try:
with conn.cursor() as cur:
cur.execute("SELECT CURRENT_USER()")
current_user_fullname = cur.fetchone()[0]
user_hostname = current_user_fullname.split("@")[1]
logger.info("User hostname detected: [%s]", user_hostname)
finally:
conn.close()

# Use the master arn from the current secret to fetch master secret contents
master_arn = current_dict['masterarn']
Expand All @@ -200,41 +208,104 @@ def set_secret(service_client, arn, token):
# Now set the password to the pending password
try:
with conn.cursor() as cur:
cur.execute("SELECT User FROM mysql.user WHERE User = %s", pending_dict['username'])
cur.execute(
query="SELECT User FROM mysql.user WHERE User = %s AND Host = %s",
args=(pending_dict["username"], user_hostname),
)
# Create the user if it does not exist
if cur.rowcount == 0:
cur.execute("CREATE USER %s IDENTIFIED BY %s", (pending_dict['username'], pending_dict['password']))
cur.execute(
query="CREATE USER %s@%s IDENTIFIED BY %s",
args=(
pending_dict["username"],
user_hostname,
pending_dict["password"],
),
)

# Copy grants to the new user
cur.execute("SHOW GRANTS FOR %s", current_dict['username'])
cur.execute(
query="SHOW GRANTS FOR %s@%s",
args=(current_dict["username"], user_hostname),
)
for row in cur.fetchall():
grant = row[0].split(' TO ')
new_grant_escaped = grant[0].replace('%', '%%') # % is a special character in Python format strings.
cur.execute(new_grant_escaped + " TO %s", (pending_dict['username'],))
grant = row[0].split(" TO ")
new_grant_escaped = grant[0].replace(
"%", "%%"
) # % is a special character in Python format strings.
cur.execute(
query=new_grant_escaped + " TO %s@%s",
args=(pending_dict["username"], user_hostname),
)

# Get the version of MySQL
cur.execute("SELECT VERSION()")
ver = cur.fetchone()[0]

# Copy TLS options to the new user
escaped_encryption_statement = get_escaped_encryption_statement(ver)
cur.execute("SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject FROM mysql.user WHERE User = %s", current_dict['username'])
cur.execute(
query="SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject FROM mysql.user WHERE User = %s AND Host = %s",
args=(
current_dict["username"],
user_hostname,
),
)
tls_options = cur.fetchone()
ssl_type = tls_options[0]
if not ssl_type:
cur.execute(escaped_encryption_statement + " NONE", pending_dict['username'])
elif ssl_type == "ANY":
cur.execute(escaped_encryption_statement + " SSL", pending_dict['username'])
elif ssl_type == "X509":
cur.execute(escaped_encryption_statement + " X509", pending_dict['username'])
cur.execute(
query=escaped_encryption_statement + " NONE",
args=(
pending_dict["username"],
user_hostname,
),
)
elif "ANY" == ssl_type:
cur.execute(
query=escaped_encryption_statement + " SSL",
args=(
pending_dict["username"],
user_hostname,
),
)
elif "X509" == ssl_type:
cur.execute(
query=escaped_encryption_statement + " X509",
args=(
pending_dict["username"],
user_hostname,
),
)
else:
cur.execute(escaped_encryption_statement + " CIPHER %s AND ISSUER %s AND SUBJECT %s", (pending_dict['username'], tls_options[1], tls_options[2], tls_options[3]))
cur.execute(
query=escaped_encryption_statement
+ " CIPHER %s AND ISSUER %s AND SUBJECT %s",
args=(
pending_dict["username"],
user_hostname,
tls_options[1],
tls_options[2],
tls_options[3],
),
)

# Set the password for the user and commit
password_option = get_password_option(ver)
cur.execute("SET PASSWORD FOR %s = " + password_option, (pending_dict['username'], pending_dict['password']))
password_option = get_password_option(version=ver)
cur.execute(
query="SET PASSWORD FOR %s@%s = " + password_option,
args=(
pending_dict["username"],
user_hostname,
pending_dict["password"],
),
)
conn.commit()
logger.info("setSecret: Successfully set password for %s in MySQL DB for secret arn %s." % (pending_dict['username'], arn))
logger.info(
"setSecret: Successfully set password for %s in MySQL DB for secret arn %s.",
pending_dict["username"],
arn,
)
finally:
conn.close()

Expand Down Expand Up @@ -535,9 +606,9 @@ def get_escaped_encryption_statement(version):
"""
if version.startswith("5.6"):
return "GRANT USAGE ON *.* TO %s@'%%' REQUIRE"
return "GRANT USAGE ON *.* TO %s@%s REQUIRE"
else:
return "ALTER USER %s@'%%' REQUIRE"
return "ALTER USER %s@%s REQUIRE"


def is_rds_replica_database(replica_dict, master_dict):
Expand Down

0 comments on commit 15e1c33

Please sign in to comment.