Skip to content

Commit

Permalink
Backed up socket fixes (#254)
Browse files Browse the repository at this point in the history
Adjust reconnect to abort if in the middle of a disconnect. Track the last outbound socket write and adjust ping time accordingly. Both changes only apply to MQTT311.
  • Loading branch information
TwistedTwigleg authored Apr 28, 2023
1 parent 51e35a9 commit 9b9eb47
Show file tree
Hide file tree
Showing 7 changed files with 405 additions and 16 deletions.
9 changes: 9 additions & 0 deletions include/aws/mqtt/private/client_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ struct aws_mqtt_client_connection {
struct aws_byte_buf client_id;
bool clean_session;
uint16_t keep_alive_time_secs;
uint64_t keep_alive_time_ns;
uint64_t ping_timeout_ns;
uint64_t operation_timeout_ns;
struct aws_string *username;
Expand Down Expand Up @@ -309,6 +310,14 @@ struct aws_mqtt_client_connection {
struct aws_http_message *handshake_request;
} websocket;

/**
* The time that the next ping task should execute at. Note that this does not mean that
* this IS when the ping task will execute, but rather that this is when the next ping
* SHOULD execute. There may be an already scheduled PING task that will elapse sooner
* than this time that has to be rescheduled.
*/
uint64_t next_ping_time;

/**
* Statistics tracking operational state
*/
Expand Down
60 changes: 53 additions & 7 deletions source/client.c
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ static void s_mqtt_client_init(
mqtt_connection_unlock_synced_data(connection);
} /* END CRITICAL SECTION */

/* intall the slot and handler */
/* install the slot and handler */
if (failed_create_slot) {

AWS_LOGF_ERROR(
Expand Down Expand Up @@ -688,11 +688,55 @@ static void s_attempt_reconnect(struct aws_task *task, void *userdata, enum aws_
struct aws_mqtt_reconnect_task *reconnect = userdata;
struct aws_mqtt_client_connection *connection = aws_atomic_load_ptr(&reconnect->connection_ptr);

/* If the task is not cancelled and a connection has not succeeded, attempt reconnect */
if (status == AWS_TASK_STATUS_RUN_READY && connection) {
/* If the task is not cancelled and a connection has not succeeded, attempt reconnect */

mqtt_connection_lock_synced_data(connection);

/**
* Check the state and if we are disconnecting (AWS_MQTT_CLIENT_STATE_DISCONNECTING) then we want to skip it
* and abort the reconnect task (or rather, just do not try to reconnect)
*/
if (connection->synced_data.state == AWS_MQTT_CLIENT_STATE_DISCONNECTING) {
AWS_LOGF_TRACE(
AWS_LS_MQTT_CLIENT, "id=%p: Skipping reconnect: Client is trying to disconnect", (void *)connection);

/**
* There is the nasty world where the disconnect task/function is called right when we are "reconnecting" as
* our state but we have not reconnected. When this happens, the disconnect function doesn't do anything
* beyond setting the state to AWS_MQTT_CLIENT_STATE_DISCONNECTING (aws_mqtt_client_connection_disconnect),
* meaning the disconnect callback will NOT be called nor will we release memory.
* For this reason, we have to do the callback and release of the connection here otherwise the code
* will DEADLOCK forever and that is bad.
*/
bool perform_full_destroy = false;
if (!connection->slot) {
AWS_LOGF_TRACE(
AWS_LS_MQTT_CLIENT,
"id=%p: Reconnect task called but client is disconnecting and has no slot. Finishing disconnect",
(void *)connection);
mqtt_connection_set_state(connection, AWS_MQTT_CLIENT_STATE_DISCONNECTED);
perform_full_destroy = true;
}

aws_mem_release(reconnect->allocator, reconnect);
connection->reconnect_task = NULL;

/* Unlock the synced data, then potentially call the disconnect callback and release the connection */
mqtt_connection_unlock_synced_data(connection);
if (perform_full_destroy) {
MQTT_CLIENT_CALL_CALLBACK(connection, on_disconnect);
MQTT_CLIENT_CALL_CALLBACK_ARGS(connection, on_closed, NULL);
aws_mqtt_client_connection_release(connection);
}
return;
}

AWS_LOGF_TRACE(
AWS_LS_MQTT_CLIENT,
"id=%p: Attempting reconnect, if it fails next attempt will be in %" PRIu64 " seconds",
(void *)connection,
connection->reconnect_timeouts.current_sec);

/* Check before multiplying to avoid potential overflow */
if (connection->reconnect_timeouts.current_sec > connection->reconnect_timeouts.max_sec / 2) {
connection->reconnect_timeouts.current_sec = connection->reconnect_timeouts.max_sec;
Expand Down Expand Up @@ -1508,6 +1552,9 @@ int aws_mqtt_client_connection_connect(
if (!connection->keep_alive_time_secs) {
connection->keep_alive_time_secs = s_default_keep_alive_sec;
}
connection->keep_alive_time_ns =
aws_timestamp_convert(connection->keep_alive_time_secs, AWS_TIMESTAMP_SECS, AWS_TIMESTAMP_NANOS, NULL);

if (!connection_options->protocol_operation_timeout_ms) {
connection->operation_timeout_ns = UINT64_MAX;
} else {
Expand All @@ -1526,16 +1573,15 @@ int aws_mqtt_client_connection_connect(
}

/* Keep alive time should always be greater than the timeouts. */
if (AWS_UNLIKELY(connection->keep_alive_time_secs * (uint64_t)AWS_TIMESTAMP_NANOS <= connection->ping_timeout_ns)) {
if (AWS_UNLIKELY(connection->keep_alive_time_ns <= connection->ping_timeout_ns)) {
AWS_LOGF_FATAL(
AWS_LS_MQTT_CLIENT,
"id=%p: Illegal configuration, Connection keep alive %" PRIu64
"ns must be greater than the request timeouts %" PRIu64 "ns.",
(void *)connection,
(uint64_t)connection->keep_alive_time_secs * (uint64_t)AWS_TIMESTAMP_NANOS,
connection->keep_alive_time_ns,
connection->ping_timeout_ns);
AWS_FATAL_ASSERT(
connection->keep_alive_time_secs * (uint64_t)AWS_TIMESTAMP_NANOS > connection->ping_timeout_ns);
AWS_FATAL_ASSERT(connection->keep_alive_time_ns > connection->ping_timeout_ns);
}

AWS_LOGF_INFO(
Expand Down
54 changes: 45 additions & 9 deletions source/client_channel_handler.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@
# pragma warning(disable : 4204)
#endif

/*******************************************************************************
* Static Helper functions
******************************************************************************/

/* Caches the socket write time for ping scheduling purposes */
static void s_update_next_ping_time(struct aws_mqtt_client_connection *connection) {
if (connection->slot != NULL && connection->slot->channel != NULL) {
aws_channel_current_clock_time(connection->slot->channel, &connection->next_ping_time);
aws_add_u64_checked(connection->next_ping_time, connection->keep_alive_time_ns, &connection->next_ping_time);
}
}

/*******************************************************************************
* Packet State Machine
******************************************************************************/
Expand All @@ -42,27 +54,42 @@ static void s_schedule_ping(struct aws_mqtt_client_connection *connection) {

uint64_t now = 0;
aws_channel_current_clock_time(connection->slot->channel, &now);
AWS_LOGF_TRACE(
AWS_LS_MQTT_CLIENT, "id=%p: Scheduling PING. current timestamp is %" PRIu64, (void *)connection, now);

uint64_t schedule_time =
now + aws_timestamp_convert(connection->keep_alive_time_secs, AWS_TIMESTAMP_SECS, AWS_TIMESTAMP_NANOS, NULL);
AWS_LOGF_TRACE(
AWS_LS_MQTT_CLIENT, "id=%p: Scheduling PING task. current timestamp is %" PRIu64, (void *)connection, now);

AWS_LOGF_TRACE(
AWS_LS_MQTT_CLIENT,
"id=%p: The next ping will be run at timestamp %" PRIu64,
"id=%p: The next PING task will be run at timestamp %" PRIu64,
(void *)connection,
schedule_time);
aws_channel_schedule_task_future(connection->slot->channel, &connection->ping_task, schedule_time);
connection->next_ping_time);

aws_channel_schedule_task_future(connection->slot->channel, &connection->ping_task, connection->next_ping_time);
}

static void s_on_time_to_ping(struct aws_channel_task *channel_task, void *arg, enum aws_task_status status) {
(void)channel_task;

if (status == AWS_TASK_STATUS_RUN_READY) {
struct aws_mqtt_client_connection *connection = arg;
AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Sending PING", (void *)connection);
aws_mqtt_client_connection_ping(connection);

uint64_t now = 0;
aws_channel_current_clock_time(connection->slot->channel, &now);
if (now >= connection->next_ping_time) {
s_update_next_ping_time(connection);
AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Sending PING", (void *)connection);
aws_mqtt_client_connection_ping(connection);
} else {

AWS_LOGF_TRACE(
AWS_LS_MQTT_CLIENT,
"id=%p: Skipped sending PING because scheduled ping time %" PRIu64
" has not elapsed yet. Current time is %" PRIu64
". Rescheduling ping to run at the scheduled ping time...",
(void *)connection,
connection->next_ping_time,
now);
}
s_schedule_ping(connection);
}
}
Expand Down Expand Up @@ -175,6 +202,7 @@ static int s_packet_handler_connack(

AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: connection callback completed", (void *)connection);

s_update_next_ping_time(connection);
s_schedule_ping(connection);
return AWS_OP_SUCCESS;
}
Expand Down Expand Up @@ -793,6 +821,9 @@ static void s_request_outgoing_task(struct aws_channel_task *task, void *arg, en
aws_mqtt_connection_statistics_change_operation_statistic_state(
request->connection, request, AWS_MQTT_OSS_NONE);

/* Since a request has complete, update the next ping time */
s_update_next_ping_time(connection);

aws_hash_table_remove(
&connection->synced_data.outstanding_requests_table, &request->packet_id, NULL, NULL);
aws_memory_pool_release(&connection->synced_data.requests_pool, request);
Expand All @@ -814,6 +845,9 @@ static void s_request_outgoing_task(struct aws_channel_task *task, void *arg, en
aws_mqtt_connection_statistics_change_operation_statistic_state(
request->connection, request, AWS_MQTT_OSS_INCOMPLETE | AWS_MQTT_OSS_UNACKED);

/* Since a request has complete, update the next ping time */
s_update_next_ping_time(connection);

mqtt_connection_unlock_synced_data(connection);
} /* END CRITICAL SECTION */

Expand Down Expand Up @@ -1057,5 +1091,7 @@ void mqtt_disconnect_impl(struct aws_mqtt_client_connection *connection, int err
shutdown_task->error_code = error_code;
aws_channel_task_init(&shutdown_task->task, s_mqtt_disconnect_task, connection, "mqtt_disconnect");
aws_channel_schedule_task_now(connection->slot->channel, &shutdown_task->task);
} else {
AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Client currently has no slot to disconnect", (void *)connection);
}
}
4 changes: 4 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ add_test_case(mqtt_clean_session_keep_next_session)
add_test_case(mqtt_connection_publish_QoS1_timeout)
add_test_case(mqtt_connection_unsub_timeout)
add_test_case(mqtt_connection_publish_QoS1_timeout_connection_lost_reset_time)
add_test_case(mqtt_connection_ping_norm)
add_test_case(mqtt_connection_ping_no)
add_test_case(mqtt_connection_ping_basic_scenario)
add_test_case(mqtt_connection_ping_double_scenario)
add_test_case(mqtt_connection_close_callback_simple)
add_test_case(mqtt_connection_close_callback_interrupted)
add_test_case(mqtt_connection_close_callback_multi)
Expand Down
Loading

0 comments on commit 9b9eb47

Please sign in to comment.