Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: opt-in to expose successful connection logs #123

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 51 additions & 27 deletions functions/replace-route/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,34 @@
import time
import urllib
import socket
import structlog
import orjson

import botocore
import boto3


slogger = structlog.get_logger()
# use structlog's production-ready, performant example config
# ref: https://www.structlog.org/en/stable/performance.html#example
structlog.configure(
cache_logger_on_first_use=True,
wrapper_class=structlog.make_filtering_bound_logger(logging.INFO),
processors=[
structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level,
structlog.processors.format_exc_info,
structlog.processors.TimeStamper(fmt="iso", utc=True),
structlog.processors.EventRenamer("message"),
structlog.processors.JSONRenderer(serializer=orjson.dumps)
],
logger_factory=structlog.BytesLoggerFactory()
)

# logger is still needed to set the level for dependencies
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.getLogger('boto3').setLevel(logging.CRITICAL)
logging.getLogger('botocore').setLevel(logging.CRITICAL)


ec2_client = boto3.client("ec2")

LIFECYCLE_KEY = "LifecycleHookName"
Expand All @@ -33,6 +50,8 @@
# Whether or not use IPv6.
DEFAULT_HAS_IPV6 = True

# Whether or not to log successful connections.
DEFAULT_LOG_SUCCESSFUL_CONNECTIONS = False

# Overrides socket.getaddrinfo to perform IPv4 lookups
# See https://github.com/chime/terraform-aws-alternat/issues/87
Expand All @@ -51,18 +70,18 @@ def get_az_and_vpc_zone_identifier(auto_scaling_group):
try:
asg_objects = autoscaling.describe_auto_scaling_groups(AutoScalingGroupNames=[auto_scaling_group])
except botocore.exceptions.ClientError as error:
logger.error("Unable to describe autoscaling groups")
slogger.error("Unable to describe autoscaling groups")
raise error

if asg_objects["AutoScalingGroups"] and len(asg_objects["AutoScalingGroups"]) > 0:
asg = asg_objects["AutoScalingGroups"][0]
logger.debug("Auto Scaling Group: %s", asg)
slogger.debug("Auto Scaling Group: %s", asg)

availability_zone = asg["AvailabilityZones"][0]
logger.debug("Availability Zone: %s", availability_zone)
slogger.debug("Availability Zone: %s", availability_zone)

vpc_zone_identifier = asg["VPCZoneIdentifier"]
logger.debug("VPC zone identifier: %s", vpc_zone_identifier)
slogger.debug("VPC zone identifier: %s", vpc_zone_identifier)

return availability_zone, vpc_zone_identifier

Expand All @@ -73,18 +92,18 @@ def get_vpc_id(route_table):
try:
route_tables = ec2_client.describe_route_tables(RouteTableIds=[route_table])
except botocore.exceptions.ClientError as error:
logger.error("Unable to get vpc id")
slogger.error("Unable to get vpc id")
raise error
if "RouteTables" in route_tables and len(route_tables["RouteTables"]) == 1:
vpc_id = route_tables["RouteTables"][0]["VpcId"]
logger.debug("VPC ID: %s", vpc_id)
slogger.debug("VPC ID: %s", vpc_id)
return vpc_id


def get_nat_gateway_id(vpc_id, subnet_id):
nat_gateway_id = os.getenv("NAT_GATEWAY_ID")
if nat_gateway_id:
logger.info("Using NAT_GATEWAY_ID env. variable (%s)", nat_gateway_id)
slogger.info("Using NAT_GATEWAY_ID env. variable (%s)", nat_gateway_id)
return nat_gateway_id

try:
Expand All @@ -101,15 +120,15 @@ def get_nat_gateway_id(vpc_id, subnet_id):
]
)
except botocore.exceptions.ClientError as error:
logger.error("Unable to describe nat gateway")
slogger.error("Unable to describe nat gateway")
raise error

logger.debug("NAT Gateways: %s", nat_gateways)
slogger.debug("NAT Gateways: %s", nat_gateways)
if len(nat_gateways.get("NatGateways")) < 1:
raise MissingNatGatewayError(nat_gateways)

nat_gateway_id = nat_gateways['NatGateways'][0]["NatGatewayId"]
logger.debug("NAT Gateway ID: %s", nat_gateway_id)
slogger.debug("NAT Gateway ID: %s", nat_gateway_id)
return nat_gateway_id


Expand All @@ -120,10 +139,10 @@ def replace_route(route_table_id, nat_gateway_id):
"RouteTableId": route_table_id
}
try:
logger.info("Replacing existing route %s for route table %s", route_table_id, new_route_table)
slogger.info("Replacing existing route %s for route table %s", route_table_id, new_route_table)
ec2_client.replace_route(**new_route_table)
except botocore.exceptions.ClientError as error:
logger.error("Unable to replace route")
slogger.error("Unable to replace route")
raise error


Expand All @@ -133,22 +152,27 @@ def check_connection(check_urls):
If all fail, replaces the route table to point at a standby NAT Gateway and
return failure.
"""
log_successful_connections = get_env_bool("LOG_SUCCESSFUL_CONNECTIONS", DEFAULT_LOG_SUCCESSFUL_CONNECTIONS)

for url in check_urls:
try:
req = urllib.request.Request(url)
req.add_header('User-Agent', 'alternat/1.0')
urllib.request.urlopen(req, timeout=REQUEST_TIMEOUT)
logger.debug("Successfully connected to %s", url)
if log_successful_connections:
slogger.info("Successfully connected to %s", url)
else:
slogger.debug("Successfully connected to %s", url)
return True
except urllib.error.HTTPError as error:
logger.warning("Response error from %s: %s, treating as success", url, error)
slogger.warning("Response error from %s: %s, treating as success", url, error)
return True
except urllib.error.URLError as error:
logger.error("error connecting to %s: %s", url, error)
slogger.error("error connecting to %s: %s", url, error)
except socket.timeout as error:
logger.error("timeout error connecting to %s: %s", url, error)
slogger.error("timeout error connecting to %s: %s", url, error)

logger.warning("Failed connectivity tests! Replacing route")
slogger.warning("Failed connectivity tests! Replacing route")

public_subnet_id = os.getenv("PUBLIC_SUBNET_ID")
if not public_subnet_id:
Expand All @@ -163,20 +187,20 @@ def check_connection(check_urls):

for rtb in route_tables:
replace_route(rtb, nat_gateway_id)
logger.info("Route replacement succeeded")
slogger.info("Route replacement succeeded")
return False


def connectivity_test_handler(event, context):
if not isinstance(event, dict):
logger.error(f"Unknown event: {event}")
slogger.error("Unknown event: %s", {event})
return

if event.get("source") != "aws.events":
logger.error(f"Unable to handle unknown event type: {json.dumps(event)}")
slogger.error("Unable to handle unknown event type: %s", json.dumps(event))
raise UnknownEventTypeError

logger.debug("Starting NAT instance connectivity test")
slogger.debug("Starting NAT instance connectivity test")

check_interval = int(os.getenv("CONNECTIVITY_CHECK_INTERVAL", DEFAULT_CONNECTIVITY_CHECK_INTERVAL))
check_urls = "CHECK_URLS" in os.environ and os.getenv("CHECK_URLS").split(",") or DEFAULT_CHECK_URLS
Expand Down Expand Up @@ -209,10 +233,10 @@ def handler(event, _):
if LIFECYCLE_KEY in message and ASG_KEY in message:
asg = message[ASG_KEY]
else:
logger.error("Failed to find lifecycle message to parse")
slogger.error("Failed to find lifecycle message to parse")
raise LifecycleMessageError
except Exception as error:
logger.error("Error: %s", error)
slogger.error(error)
raise error

availability_zone, vpc_zone_identifier = get_az_and_vpc_zone_identifier(asg)
Expand All @@ -227,7 +251,7 @@ def handler(event, _):

for rtb in route_tables:
replace_route(rtb, nat_gateway_id)
logger.info("Route replacement succeeded")
slogger.info("Route replacement succeeded")


class UnknownEventTypeError(Exception): pass
Expand Down
2 changes: 2 additions & 0 deletions functions/replace-route/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
boto3==1.34.90
structlog==24.4.0
orjson==3.10.14
2 changes: 2 additions & 0 deletions lambda.tf
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ locals {
: replace(upper(obj.az), "-", "_") => join(",", obj.route_table_ids)
}
has_ipv6_env_var = { "HAS_IPV6" = var.lambda_has_ipv6 }
log_successful_connections_env_var = { "LOG_SUCCESSFUL_CONNECTIONS" = var.log_successful_connections }
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not incredibly versed in base Terraform, being a CDKTF user myself. I am looking to make sure I am adding this variable right. Any and all feedback on how to achieve this is welcome!

lambda_runtime = "python3.12"
}

Expand Down Expand Up @@ -161,6 +162,7 @@ resource "aws_lambda_function" "alternat_connectivity_tester" {
NAT_GATEWAY_ID = var.nat_gateway_id,
},
local.has_ipv6_env_var,
local.log_successful_connections_env_var
var.lambda_environment_variables,
)
}
Expand Down
6 changes: 6 additions & 0 deletions variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ variable "lifecycle_heartbeat_timeout" {
default = 180
}

variable "log_successful_connections" {
description = "Logs successful connection events during connection checks. This will increase the number of logs produced by the Lambda function in proportion to the number of checks each time the Lambda runs."
type = bool
default = false
}

variable "max_instance_lifetime" {
description = "Max instance life in seconds. Defaults to 14 days. Set to 0 to disable."
type = number
Expand Down