Skip to content

Commit

Permalink
nly initialize all the variables once to speed up test ConsumeMessage…
Browse files Browse the repository at this point in the history
…ConcurrentlyServiceTest (#8436)
  • Loading branch information
TestBoost authored Aug 1, 2024
1 parent cd23d6d commit 2ed4ba2
Showing 1 changed file with 54 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,14 @@
import org.apache.rocketmq.remoting.exception.RemotingException;
import org.apache.rocketmq.remoting.protocol.body.ConsumeStatus;
import org.apache.rocketmq.remoting.protocol.header.PullMessageRequestHeader;
import org.junit.After;
import org.junit.Before;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
Expand All @@ -70,49 +69,54 @@
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
import org.mockito.Mockito;

@RunWith(MockitoJUnitRunner.class)
public class ConsumeMessageConcurrentlyServiceTest {
private String consumerGroup;
private String topic = "FooBar";
private String brokerName = "BrokerA";
private MQClientInstance mQClientFactory;

private static String consumerGroup;

private static String topic = "FooBar";

private static String brokerName = "BrokerA";

private static MQClientInstance mQClientFactory;

@Mock
private MQClientAPIImpl mQClientAPIImpl;
private PullAPIWrapper pullAPIWrapper;
private RebalancePushImpl rebalancePushImpl;
private DefaultMQPushConsumer pushConsumer;
private static MQClientAPIImpl mQClientAPIImpl;

private static PullAPIWrapper pullAPIWrapper;

private static RebalancePushImpl rebalancePushImpl;

private static DefaultMQPushConsumer pushConsumer;

@Before
public void init() throws Exception {
@BeforeClass
public static void init() throws Exception {
mQClientAPIImpl = Mockito.mock(MQClientAPIImpl.class);
ConcurrentMap<String, MQClientInstance> factoryTable = (ConcurrentMap<String, MQClientInstance>) FieldUtils.readDeclaredField(MQClientManager.getInstance(), "factoryTable", true);
Collection<MQClientInstance> instances = factoryTable.values();
for (MQClientInstance instance : instances) {
instance.shutdown();
}
factoryTable.clear();

consumerGroup = "FooBarGroup" + System.currentTimeMillis();
pushConsumer = new DefaultMQPushConsumer(consumerGroup);
pushConsumer.setNamesrvAddr("127.0.0.1:9876");
pushConsumer.setPullInterval(60 * 1000);

pushConsumer.registerMessageListener(new MessageListenerConcurrently() {

@Override
public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs,
ConsumeConcurrentlyContext context) {
public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs, ConsumeConcurrentlyContext context) {
return ConsumeConcurrentlyStatus.CONSUME_SUCCESS;
}
});

DefaultMQPushConsumerImpl pushConsumerImpl = pushConsumer.getDefaultMQPushConsumerImpl();
rebalancePushImpl = spy(new RebalancePushImpl(pushConsumer.getDefaultMQPushConsumerImpl()));
Field field = DefaultMQPushConsumerImpl.class.getDeclaredField("rebalanceImpl");
field.setAccessible(true);
field.set(pushConsumerImpl, rebalancePushImpl);
pushConsumer.subscribe(topic, "*");

// suppress updateTopicRouteInfoFromNameServer
pushConsumer.changeInstanceNameToPID();
mQClientFactory = MQClientManager.getInstance().getOrCreateMQClientInstance(pushConsumer, (RPCHook) FieldUtils.readDeclaredField(pushConsumerImpl, "rpcHook", true));
Expand All @@ -121,38 +125,32 @@ public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs,
field.setAccessible(true);
field.set(pushConsumerImpl, mQClientFactory);
factoryTable.put(pushConsumer.buildMQClientId(), mQClientFactory);

field = MQClientInstance.class.getDeclaredField("mQClientAPIImpl");
field.setAccessible(true);
field.set(mQClientFactory, mQClientAPIImpl);

pullAPIWrapper = spy(new PullAPIWrapper(mQClientFactory, consumerGroup, false));
field = DefaultMQPushConsumerImpl.class.getDeclaredField("pullAPIWrapper");
field.setAccessible(true);
field.set(pushConsumerImpl, pullAPIWrapper);

pushConsumer.getDefaultMQPushConsumerImpl().getRebalanceImpl().setmQClientFactory(mQClientFactory);
when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class), anyLong(), any(CommunicationMode.class), nullable(PullCallback.class))).thenAnswer(new Answer<PullResult>() {

when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class),
anyLong(), any(CommunicationMode.class), nullable(PullCallback.class)))
.thenAnswer(new Answer<PullResult>() {
@Override
public PullResult answer(InvocationOnMock mock) throws Throwable {
PullMessageRequestHeader requestHeader = mock.getArgument(1);
MessageClientExt messageClientExt = new MessageClientExt();
messageClientExt.setTopic(topic);
messageClientExt.setQueueId(0);
messageClientExt.setMsgId("123");
messageClientExt.setBody(new byte[] {'a'});
messageClientExt.setOffsetMsgId("234");
messageClientExt.setBornHost(new InetSocketAddress(8080));
messageClientExt.setStoreHost(new InetSocketAddress(8080));
PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.<MessageExt>singletonList(messageClientExt));
((PullCallback) mock.getArgument(4)).onSuccess(pullResult);
return pullResult;
}
});

@Override
public PullResult answer(InvocationOnMock mock) throws Throwable {
PullMessageRequestHeader requestHeader = mock.getArgument(1);
MessageClientExt messageClientExt = new MessageClientExt();
messageClientExt.setTopic(topic);
messageClientExt.setQueueId(0);
messageClientExt.setMsgId("123");
messageClientExt.setBody(new byte[] { 'a' });
messageClientExt.setOffsetMsgId("234");
messageClientExt.setBornHost(new InetSocketAddress(8080));
messageClientExt.setStoreHost(new InetSocketAddress(8080));
PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.<MessageExt>singletonList(messageClientExt));
((PullCallback) mock.getArgument(4)).onSuccess(pullResult);
return pullResult;
}
});
doReturn(new FindBrokerResult("127.0.0.1:10912", false)).when(mQClientFactory).findBrokerAddressInSubscribe(anyString(), anyLong(), anyBoolean());
doReturn(false).when(mQClientFactory).updateTopicRouteInfoFromNameServer(anyString());
Set<MessageQueue> messageQueueSet = new HashSet<>();
Expand All @@ -162,54 +160,45 @@ public PullResult answer(InvocationOnMock mock) throws Throwable {
}

@Test
public void testPullMessage_ConsumeSuccess() throws InterruptedException, RemotingException, MQBrokerException, NoSuchFieldException,Exception {
public void testPullMessage_ConsumeSuccess() throws InterruptedException, RemotingException, MQBrokerException, NoSuchFieldException, Exception {
final CountDownLatch countDownLatch = new CountDownLatch(1);
final AtomicReference<MessageExt> messageAtomic = new AtomicReference<>();
ConsumeMessageConcurrentlyService normalServie = new ConsumeMessageConcurrentlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerConcurrently() {

ConsumeMessageConcurrentlyService normalServie = new ConsumeMessageConcurrentlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerConcurrently() {
@Override
public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs,
ConsumeConcurrentlyContext context) {
public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs, ConsumeConcurrentlyContext context) {
messageAtomic.set(msgs.get(0));
countDownLatch.countDown();
return ConsumeConcurrentlyStatus.CONSUME_SUCCESS;
}
});
pushConsumer.getDefaultMQPushConsumerImpl().setConsumeMessageService(normalServie);

PullMessageService pullMessageService = mQClientFactory.getPullMessageService();
pullMessageService.executePullRequestImmediately(createPullRequest());
countDownLatch.await();

Thread.sleep(1000);

ConsumeStatus stats = normalServie.getConsumerStatsManager().consumeStatus(pushConsumer.getDefaultMQPushConsumerImpl().groupName(),topic);

ConsumerStatsManager mgr = normalServie.getConsumerStatsManager();

ConsumeStatus stats = normalServie.getConsumerStatsManager().consumeStatus(pushConsumer.getDefaultMQPushConsumerImpl().groupName(), topic);
ConsumerStatsManager mgr = normalServie.getConsumerStatsManager();
Field statItmeSetField = mgr.getClass().getDeclaredField("topicAndGroupConsumeOKTPS");
statItmeSetField.setAccessible(true);

StatsItemSet itemSet = (StatsItemSet)statItmeSetField.get(mgr);
StatsItemSet itemSet = (StatsItemSet) statItmeSetField.get(mgr);
StatsItem item = itemSet.getAndCreateStatsItem(topic + "@" + pushConsumer.getDefaultMQPushConsumerImpl().groupName());

assertThat(item.getValue().sum()).isGreaterThan(0L);
MessageExt msg = messageAtomic.get();
assertThat(msg).isNotNull();
assertThat(msg.getTopic()).isEqualTo(topic);
assertThat(msg.getBody()).isEqualTo(new byte[] {'a'});
assertThat(msg.getBody()).isEqualTo(new byte[] { 'a' });
}

@After
public void terminate() {
@AfterClass
public static void terminate() {
pushConsumer.shutdown();
}

private PullRequest createPullRequest() {
private static PullRequest createPullRequest() {
PullRequest pullRequest = new PullRequest();
pullRequest.setConsumerGroup(consumerGroup);
pullRequest.setNextOffset(1024);

MessageQueue messageQueue = new MessageQueue();
messageQueue.setBrokerName(brokerName);
messageQueue.setQueueId(0);
Expand All @@ -219,12 +208,10 @@ private PullRequest createPullRequest() {
processQueue.setLocked(true);
processQueue.setLastLockTimestamp(System.currentTimeMillis());
pullRequest.setProcessQueue(processQueue);

return pullRequest;
}

private PullResultExt createPullResult(PullMessageRequestHeader requestHeader, PullStatus pullStatus,
List<MessageExt> messageExtList) throws Exception {
private static PullResultExt createPullResult(PullMessageRequestHeader requestHeader, PullStatus pullStatus, List<MessageExt> messageExtList) throws Exception {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
for (MessageExt messageExt : messageExtList) {
outputStream.write(MessageDecoder.encode(messageExt, false));
Expand All @@ -236,23 +223,21 @@ private PullResultExt createPullResult(PullMessageRequestHeader requestHeader, P
public void testConsumeThreadName() throws Exception {
final CountDownLatch countDownLatch = new CountDownLatch(1);
final AtomicReference<String> consumeThreadName = new AtomicReference<>();

StringBuilder consumeGroup2 = new StringBuilder();
for (int i = 0; i < 101; i++) {
consumeGroup2.append(i).append("#");
}
pushConsumer.setConsumerGroup(consumeGroup2.toString());
ConsumeMessageConcurrentlyService normalServie = new ConsumeMessageConcurrentlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerConcurrently() {
ConsumeMessageConcurrentlyService normalServie = new ConsumeMessageConcurrentlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerConcurrently() {

@Override
public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs,
ConsumeConcurrentlyContext context) {
public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs, ConsumeConcurrentlyContext context) {
consumeThreadName.set(Thread.currentThread().getName());
countDownLatch.countDown();
return ConsumeConcurrentlyStatus.CONSUME_SUCCESS;
}
});
pushConsumer.getDefaultMQPushConsumerImpl().setConsumeMessageService(normalServie);

PullMessageService pullMessageService = mQClientFactory.getPullMessageService();
pullMessageService.executePullRequestImmediately(createPullRequest());
countDownLatch.await();
Expand Down

0 comments on commit 2ed4ba2

Please sign in to comment.