Skip to content

Commit ea1fcae

Browse files
author
liwang
committed
ZOOKEEPER-5015: Add admin server command to shed client connections by percentage
Author: Li Wang<liwang@apple.com>
1 parent dc2767f commit ea1fcae

File tree

6 files changed

+561
-3
lines changed

6 files changed

+561
-3
lines changed

zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerCnxn.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ public enum DisconnectReason {
9898
AUTH_PROVIDER_NOT_FOUND("auth provider not found"),
9999
FAILED_HANDSHAKE("Unsuccessful handshake"),
100100
CLIENT_RATE_LIMIT("Client hits rate limiting threshold"),
101-
CLIENT_CNX_LIMIT("Client hits connection limiting threshold");
101+
CLIENT_CNX_LIMIT("Client hits connection limiting threshold"),
102+
SHED_CONNECTIONS_COMMAND("shed_connections_command");
102103

103104
String disconnectReason;
104105

zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerCnxnFactory.java

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.util.Map;
2727
import java.util.Set;
2828
import java.util.concurrent.ConcurrentHashMap;
29+
import java.util.concurrent.ThreadLocalRandom;
2930
import java.util.function.Supplier;
3031
import javax.management.JMException;
3132
import javax.security.auth.callback.CallbackHandler;
@@ -160,15 +161,65 @@ public final void setZooKeeperServer(ZooKeeperServer zks) {
160161

161162
public abstract void closeAll(ServerCnxn.DisconnectReason reason);
162163

164+
/**
165+
* Attempts to shed approximately the specified percentage of connections.
166+
*
167+
* @param percentage [0-100] percentage of connections to shed
168+
* @return actual number of connections successfully closed (may vary due to randomness)
169+
* @throws IllegalArgumentException if percentage not in [0, 100]
170+
*/
171+
public int shedConnections(final int percentage) {
172+
if (percentage < 0 || percentage > 100) {
173+
throw new IllegalArgumentException("percentage must be between 0 and 100, got: " + percentage);
174+
}
175+
176+
final int totalConnections = cnxns.size();
177+
if (percentage == 0 || totalConnections == 0) {
178+
return 0;
179+
}
180+
181+
int actualShedCount = 0;
182+
// For 100%, close all connections deterministically
183+
if (percentage == 100) {
184+
for (final ServerCnxn cnxn : cnxns) {
185+
try {
186+
cnxn.close(ServerCnxn.DisconnectReason.SHED_CONNECTIONS_COMMAND);
187+
actualShedCount++;
188+
} catch (final Exception e) {
189+
LOG.warn("Failed to close connection for session 0x{}: {}",
190+
Long.toHexString(cnxn.getSessionId()), e.getMessage());
191+
}
192+
}
193+
} else {
194+
// For other percentages, use probabilistic approach
195+
final ThreadLocalRandom random = ThreadLocalRandom.current();
196+
final double probability = percentage / 100.0;
197+
198+
for (final ServerCnxn cnxn : cnxns) {
199+
if (random.nextDouble() < probability) {
200+
try {
201+
cnxn.close(ServerCnxn.DisconnectReason.SHED_CONNECTIONS_COMMAND);
202+
actualShedCount++;
203+
} catch (final Exception e) {
204+
LOG.warn("Failed to close connection for session 0x{}: {}",
205+
Long.toHexString(cnxn.getSessionId()), e.getMessage());
206+
}
207+
}
208+
}
209+
}
210+
LOG.info("Shed {} out of {} connections ({}%)", actualShedCount, totalConnections, percentage);
211+
return actualShedCount;
212+
}
213+
163214
public static ServerCnxnFactory createFactory() throws IOException {
164215
String serverCnxnFactoryName = System.getProperty(ZOOKEEPER_SERVER_CNXN_FACTORY);
165216
if (serverCnxnFactoryName == null) {
166217
serverCnxnFactoryName = NIOServerCnxnFactory.class.getName();
167218
}
168219
try {
169220
ServerCnxnFactory serverCnxnFactory = (ServerCnxnFactory) Class.forName(serverCnxnFactoryName)
170-
.getDeclaredConstructor()
171-
.newInstance();
221+
.getDeclaredConstructor()
222+
.newInstance();
172223
LOG.info("Using {} as server connection factory", serverCnxnFactoryName);
173224
return serverCnxnFactory;
174225
} catch (Exception e) {

zookeeper-server/src/main/java/org/apache/zookeeper/server/admin/Commands.java

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@
2020

2121
import static org.apache.zookeeper.server.persistence.FileSnap.SNAPSHOT_FILE_PREFIX;
2222
import com.fasterxml.jackson.annotation.JsonProperty;
23+
import com.fasterxml.jackson.databind.JsonNode;
24+
import com.fasterxml.jackson.databind.ObjectMapper;
2325
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
2426
import java.io.File;
2527
import java.io.FileInputStream;
28+
import java.io.IOException;
2629
import java.io.InputStream;
2730
import java.net.InetSocketAddress;
2831
import java.nio.charset.StandardCharsets;
@@ -292,6 +295,7 @@ public static Command getCommand(String cmdName) {
292295
registerCommand(new RestoreCommand());
293296
registerCommand(new RuokCommand());
294297
registerCommand(new SetTraceMaskCommand());
298+
registerCommand(new ShedConnectionsCommand());
295299
registerCommand(new SnapshotCommand());
296300
registerCommand(new SrvrCommand());
297301
registerCommand(new StatCommand());
@@ -863,6 +867,77 @@ public CommandResponse runGet(ZooKeeperServer zkServer, Map<String, String> kwar
863867

864868
}
865869

870+
/**
871+
* Attempts to shed approximately the specified percentage of connections.
872+
* <p>
873+
* Request: JSON input stream containing the following required field:
874+
* - "percentage": Integer [0-100] - percentage of connections to attempt shedding
875+
* value must be between 0 (no connections) and 100 (all connections).
876+
* <p>
877+
* Response: JSON output stream containing:
878+
* - "connections_shed": Integer - actual number of connections successfully closed
879+
* may vary due to randomness.
880+
* - "percentage_requested": Integer - the percentage that was requested
881+
*/
882+
public static class ShedConnectionsCommand extends PostCommand {
883+
private static final String FIELD_PERCENTAGE = "percentage";
884+
885+
public ShedConnectionsCommand() {
886+
super(Arrays.asList("shed_connections", "shed"), true, new AuthRequest(ZooDefs.Perms.ALL, ROOT_PATH));
887+
}
888+
889+
@Override
890+
public CommandResponse runPost(final ZooKeeperServer zkServer, final InputStream inputStream) {
891+
final CommandResponse response = initializeResponse();
892+
893+
if (inputStream == null) {
894+
response.setStatusCode(HttpServletResponse.SC_BAD_REQUEST);
895+
response.put("error", "Request body is required");
896+
return response;
897+
}
898+
899+
try {
900+
final ObjectMapper mapper = new ObjectMapper();
901+
final JsonNode jsonNode = mapper.readTree(inputStream);
902+
903+
if (!jsonNode.has(FIELD_PERCENTAGE)) {
904+
response.setStatusCode(HttpServletResponse.SC_BAD_REQUEST);
905+
response.put("error", "Missing required field: " + FIELD_PERCENTAGE);
906+
return response;
907+
}
908+
909+
final int percentage = jsonNode.get(FIELD_PERCENTAGE).asInt();
910+
if (percentage < 0 || percentage > 100) {
911+
response.setStatusCode(HttpServletResponse.SC_BAD_REQUEST);
912+
response.put("error", "Percentage must be between 0 and 100");
913+
return response;
914+
}
915+
916+
final ServerCnxnFactory factory = zkServer.getServerCnxnFactory();
917+
final ServerCnxnFactory secureFactory = zkServer.getSecureServerCnxnFactory();
918+
919+
int connectionsShed = 0;
920+
if (percentage > 0) {
921+
if (factory != null) {
922+
connectionsShed += factory.shedConnections(percentage);
923+
}
924+
if (secureFactory != null) {
925+
connectionsShed += secureFactory.shedConnections(percentage);
926+
}
927+
}
928+
929+
response.put("connections_shed", connectionsShed);
930+
response.put("percentage_requested", percentage);
931+
932+
LOG.info("Shed {} connections ({}%)", connectionsShed, percentage);
933+
} catch (final IOException e) {
934+
response.setStatusCode(HttpServletResponse.SC_BAD_REQUEST);
935+
response.put("error", "Invalid JSON or failed to read request body: " + e.getMessage());
936+
}
937+
return response;
938+
}
939+
}
940+
866941
/**
867942
* Same as SrvrCommand but has extra "connections" entry.
868943
*/
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.zookeeper.server;
20+
21+
import static org.junit.jupiter.api.Assertions.assertEquals;
22+
import static org.junit.jupiter.api.Assertions.assertThrows;
23+
import static org.mockito.Mockito.doThrow;
24+
import static org.mockito.Mockito.mock;
25+
import static org.mockito.Mockito.mockingDetails;
26+
import java.util.Arrays;
27+
import org.junit.jupiter.api.AfterEach;
28+
import org.junit.jupiter.params.ParameterizedTest;
29+
import org.junit.jupiter.params.provider.EnumSource;
30+
31+
public class ServerCnxnFactoryTest {
32+
public enum FactoryType {
33+
NIO, NETTY
34+
}
35+
36+
private ServerCnxnFactory factory;
37+
38+
@AfterEach
39+
public void tearDown() {
40+
if (factory != null) {
41+
try {
42+
factory.shutdown();
43+
} catch (Exception e) {
44+
// Ignore all shutdown exceptions in tests since factory may not be fully initialized
45+
}
46+
}
47+
}
48+
49+
@ParameterizedTest
50+
@EnumSource(FactoryType.class)
51+
public void testShedConnections_InvalidPercentage(final FactoryType factoryType) {
52+
factory = createFactory(factoryType);
53+
assertThrows(IllegalArgumentException.class, () -> factory.shedConnections(-1));
54+
assertThrows(IllegalArgumentException.class, () -> factory.shedConnections(101));
55+
}
56+
57+
@ParameterizedTest
58+
@EnumSource(FactoryType.class)
59+
public void testShedConnections_ValidPercentages(final FactoryType factoryType) {
60+
factory = createFactory(factoryType);
61+
62+
assertEquals(0, factory.shedConnections(0));
63+
assertEquals(0, factory.shedConnections(50));
64+
assertEquals(0, factory.shedConnections(100));
65+
}
66+
67+
@ParameterizedTest
68+
@EnumSource(FactoryType.class)
69+
public void testShedConnections_DeterministicBehavior(final FactoryType factoryType) {
70+
factory = createFactory(factoryType);
71+
72+
// Create 4 mock connections for testing deterministic edge cases
73+
final ServerCnxn[] mockCnxns = new ServerCnxn[4];
74+
for (int i = 0; i < 4; i++) {
75+
mockCnxns[i] = mock(ServerCnxn.class);
76+
factory.cnxns.add(mockCnxns[i]);
77+
}
78+
79+
// Test 0% shedding - should shed exactly 0 connections (deterministic)
80+
int shedCount = factory.shedConnections(0);
81+
assertEquals(0, shedCount, "0% shedding should shed exactly 0 connections");
82+
83+
// Verify no connections were actually closed
84+
int actualClosedCount = countConnectionsShed(mockCnxns);
85+
assertEquals(0, actualClosedCount, "No connections should be closed for 0% shedding");
86+
87+
// Test 100% shedding - should shed exactly all connections (deterministic)
88+
shedCount = factory.shedConnections(100);
89+
assertEquals(4, shedCount, "100% shedding should shed exactly all 4 connections");
90+
91+
// Verify all connections were actually closed with correct reason
92+
actualClosedCount = countConnectionsShed(mockCnxns);
93+
assertEquals(4, actualClosedCount, "All 4 connections should be closed for 100% shedding");
94+
}
95+
96+
@ParameterizedTest
97+
@EnumSource(FactoryType.class)
98+
public void testShedConnections_SmallPercentageRoundsToZero(final FactoryType factoryType) {
99+
factory = createFactory(factoryType);
100+
101+
// Add single mock connection
102+
final ServerCnxn mockCnxn = mock(ServerCnxn.class);
103+
factory.cnxns.add(mockCnxn);
104+
105+
// small percentage rounds to 0
106+
assertEquals(0, factory.shedConnections(1), "1% of 1 connection should round to 0");
107+
}
108+
109+
@ParameterizedTest
110+
@EnumSource(FactoryType.class)
111+
public void testShedConnections_ErrorHandling(final FactoryType factoryType) {
112+
factory = createFactory(factoryType);
113+
114+
// Create mock connections where one will fail to close
115+
final ServerCnxn[] mockCnxns = new ServerCnxn[4];
116+
for (int i = 0; i < 4; i++) {
117+
mockCnxns[i] = mock(ServerCnxn.class);
118+
factory.cnxns.add(mockCnxns[i]);
119+
}
120+
121+
// Make the second connection throw an exception when closed
122+
doThrow(new RuntimeException("Connection close failed"))
123+
.when(mockCnxns[1]).close(ServerCnxn.DisconnectReason.SHED_CONNECTIONS_COMMAND);
124+
125+
// Test 100% shedding to ensure error handling works deterministically
126+
final int shedCount = factory.shedConnections(100);
127+
128+
// Since one connection throws an exception, only 3 should be successfully closed
129+
assertEquals(3, shedCount, "Should successfully close 3 connections, 1 should fail");
130+
int actualClosedCount = countConnectionsShed(mockCnxns);
131+
assertEquals(4, actualClosedCount, "All 4 connections should have close() called, even if one throws exception");
132+
}
133+
134+
private ServerCnxnFactory createFactory(final FactoryType type) {
135+
switch (type) {
136+
case NIO:
137+
return new NIOServerCnxnFactory();
138+
case NETTY:
139+
return new NettyServerCnxnFactory();
140+
default:
141+
throw new IllegalArgumentException("Unknown factory type: " + type);
142+
}
143+
}
144+
145+
private int countConnectionsShed(final ServerCnxn[] connections) {
146+
return (int) Arrays.stream(connections)
147+
.filter(cnxn -> mockingDetails(cnxn).getInvocations().stream()
148+
.anyMatch(invocation ->
149+
invocation.getMethod().getName().equals("close")
150+
&& invocation.getArguments().length == 1
151+
&& invocation.getArguments()[0].equals(ServerCnxn.DisconnectReason.SHED_CONNECTIONS_COMMAND)
152+
))
153+
.count();
154+
}
155+
}
156+

zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/CommandsTest.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,16 @@ public void testStatCommandSecureOnly() {
359359
assertThat(response.toMap().containsKey("secure_connections"), is(true));
360360
}
361361

362+
@Test
363+
public void testShedConnections() throws IOException, InterruptedException {
364+
final Map<String, String> kwargs = new HashMap<>();
365+
final InputStream inputStream = new ByteArrayInputStream("{\"percentage\": 25}".getBytes());
366+
final String authInfo = CommandAuthTest.buildAuthorizationForDigest();
367+
testCommand("shed_connections", kwargs, inputStream, authInfo, new HashMap<>(), HttpServletResponse.SC_OK,
368+
new Field("percentage_requested", Integer.class),
369+
new Field("connections_shed", Integer.class));
370+
}
371+
362372
private void testSnapshot(final boolean streaming) throws IOException, InterruptedException {
363373
System.setProperty(ADMIN_SNAPSHOT_ENABLED, "true");
364374
System.setProperty(ADMIN_RATE_LIMITER_INTERVAL, "0");

0 commit comments

Comments
 (0)