Skip to content

Commit

Permalink
Merge pull request Netflix#457 from Netflix/fix-websocket-subscriptions
Browse files Browse the repository at this point in the history
Fix handling of null variables for subscriptions and properly cleanup websocket subscriptions.
  • Loading branch information
srinivasankavitha authored Jul 1, 2021
2 parents 382cb44 + 2d4c770 commit cc3abc5
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class DgsWebSocketHandler(private val dgsQueryExecutor: DgsQueryExecutor) : Text
}
GQL_STOP -> {
subscriptions[session.id]?.get(id)?.cancel()
subscriptions.remove(id)
subscriptions[session.id]?.remove(id)
}
GQL_CONNECTION_TERMINATE -> {
logger.info("Terminated session " + session.id)
Expand All @@ -86,6 +86,7 @@ class DgsWebSocketHandler(private val dgsQueryExecutor: DgsQueryExecutor) : Text
private fun cleanupSubscriptionsForSession(session: WebSocketSession) {
logger.info("Cleaning up for session {}", session.id)
subscriptions[session.id]?.values?.forEach { it.cancel() }
subscriptions.remove(session.id)
sessions.remove(session)
}

Expand All @@ -96,7 +97,8 @@ class DgsWebSocketHandler(private val dgsQueryExecutor: DgsQueryExecutor) : Text
subscriptionStream.subscribe(object : Subscriber<ExecutionResult> {
override fun onSubscribe(s: Subscription) {
logger.info("Subscription started for {}", id)
subscriptions[session.id] = mutableMapOf(Pair(id, s))
subscriptions.putIfAbsent(session.id, mutableMapOf())
subscriptions[session.id]?.set(id, s)

s.request(1)
}
Expand Down Expand Up @@ -132,7 +134,7 @@ class DgsWebSocketHandler(private val dgsQueryExecutor: DgsQueryExecutor) : Text
session.sendMessage(jsonMessage)
}

subscriptions.remove(id)
subscriptions[session.id]?.remove(id)
}
})
}
Expand All @@ -149,4 +151,4 @@ const val GQL_CONNECTION_TERMINATE = "connection_terminate"

data class DataPayload(val data: Any?, val errors: List<Any>? = emptyList())
data class OperationMessage(@JsonProperty("type") val type: String, @JsonProperty("payload") val payload: Any? = null, @JsonProperty("id", required = false) val id: String? = "")
data class QueryPayload(@JsonProperty("variables") val variables: Map<String, Any> = emptyMap(), @JsonProperty("extensions") val extensions: Map<String, Any> = emptyMap(), @JsonProperty("operationName") val operationName: String?, @JsonProperty("query") val query: String)
data class QueryPayload(@JsonProperty("variables") val variables: Map<String, Any>?, @JsonProperty("extensions") val extensions: Map<String, Any> = emptyMap(), @JsonProperty("operationName") val operationName: String?, @JsonProperty("query") val query: String)
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ class DgsWebsocketHandlerTest {
"query": "{ hello }",
"variables": {},
"extensions": {}
}
},
"id": "123"
}
""".trimIndent()
)
Expand All @@ -77,7 +78,21 @@ class DgsWebsocketHandlerTest {
"query": "query HELLO(${'$'}name: String){ hello(name:${'$'}name) }",
"variables": {"name": "Stranger"},
"extensions": {}
}
},
"id": "222"
}
""".trimIndent()
)

private val queryMessageWithNullVariable = TextMessage(
"""{
"type": "$GQL_START",
"payload": {
"query": "query HELLO(${'$'}name: String){ hello(name:${'$'}name) }",
"variables": null,
"extensions": {}
},
"id": "123"
}
""".trimIndent()
)
Expand All @@ -103,6 +118,18 @@ class DgsWebsocketHandlerTest {
}
}

@Test
fun testWithMultipleSubscriptionsPerSession() {
connect(session1)
start(session1, 1)
startWithVariable(session1, 1)

assertThat(dgsWebsocketHandler.sessions.size).isEqualTo(1)
assertThat(dgsWebsocketHandler.subscriptions.size).isEqualTo(1)
disconnect(session1)
assertThat(dgsWebsocketHandler.sessions.size).isEqualTo(0)
}

@Test
fun testWithQueryVariables() {
connect(session1)
Expand All @@ -116,6 +143,19 @@ class DgsWebsocketHandlerTest {
}
}

@Test
fun testWithNullQueryVariables() {
connect(session1)
startWithNullVariable(session1, 1)

disconnect(session1)

// ACK, DATA, COMPLETE
verify(exactly = 3) {
session1.sendMessage(any())
}
}

@Test
fun testWithError() {
connect(session1)
Expand All @@ -128,6 +168,19 @@ class DgsWebsocketHandlerTest {
}
}

@Test
fun testWithStop() {
connect(session1)
start(session1, 1)
stop(session1)
disconnect(session1)

// ACK, DATA, COMPLETE
verify(exactly = 3) {
session1.sendMessage(any())
}
}

private fun connect(webSocketSession: WebSocketSession) {
val currentNrOfSessions = dgsWebsocketHandler.sessions.size

Expand Down Expand Up @@ -202,11 +255,46 @@ class DgsWebsocketHandlerTest {
dgsWebsocketHandler.handleTextMessage(webSocketSession, queryMessageWithVariable)
}

private fun startWithNullVariable(webSocketSession: WebSocketSession, nrOfResults: Int) {

every { webSocketSession.isOpen } returns true

val results = (1..nrOfResults).map {
val result1 = mockkClass(ExecutionResult::class)
every { result1.getData<Any>() } returns it
result1
}

every { executionResult.getData<Publisher<ExecutionResult>>() } returns Mono.just(results).flatMapMany { Flux.fromIterable(results) }

every { dgsQueryExecutor.execute("query HELLO(\$name: String){ hello(name:\$name) }", null) } returns executionResult

dgsWebsocketHandler.handleTextMessage(webSocketSession, queryMessageWithNullVariable)
}

private fun startWithError(webSocketSession: WebSocketSession) {
every { webSocketSession.isOpen } returns true
every { executionResult.getData<Publisher<ExecutionResult>>() } returns Mono.error(RuntimeException("That's wrong!"))
every { dgsQueryExecutor.execute("{ hello }", emptyMap()) } returns executionResult

dgsWebsocketHandler.handleTextMessage(webSocketSession, queryMessage)
}

private fun stop(webSocketSession: WebSocketSession) {
val currentNrOfSessions = dgsWebsocketHandler.sessions.size
every { webSocketSession.close() } just Runs

val textMessage = TextMessage(
"""{
"type": "$GQL_STOP",
"id": "123"
}
""".trimIndent()
)

dgsWebsocketHandler.handleTextMessage(webSocketSession, textMessage)

assertThat(dgsWebsocketHandler.sessions.size).isEqualTo(currentNrOfSessions)
assertThat(dgsWebsocketHandler.subscriptions[webSocketSession.id]?.get("123")).isNull()
}
}

0 comments on commit cc3abc5

Please sign in to comment.