Skip to content

Commit a81ac2c

Browse files
Add ratelimiter framework for grpc server interceptor
1 parent 65e4f61 commit a81ac2c

File tree

10 files changed

+523
-0
lines changed

10 files changed

+523
-0
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
plugins {
2+
`java-library`
3+
jacoco
4+
id("org.hypertrace.publish-plugin")
5+
id("org.hypertrace.jacoco-report-plugin")
6+
}
7+
8+
dependencies {
9+
10+
api(platform("io.grpc:grpc-bom:1.68.3"))
11+
api("io.grpc:grpc-api")
12+
api(project(":grpc-context-utils"))
13+
14+
implementation("org.slf4j:slf4j-api:1.7.36")
15+
implementation("com.google.guava:guava:32.0.1-jre")
16+
implementation("com.bucket4j:bucket4j-core:8.7.0")
17+
18+
annotationProcessor("org.projectlombok:lombok:1.18.24")
19+
compileOnly("org.projectlombok:lombok:1.18.24")
20+
21+
testImplementation("org.junit.jupiter:junit-jupiter:5.8.2")
22+
testImplementation("org.mockito:mockito-core:5.8.0")
23+
testImplementation("org.mockito:mockito-junit-jupiter:5.8.0")
24+
}
25+
26+
tasks.test {
27+
useJUnitPlatform()
28+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package org.hypertrace.ratelimiter.grpcutils;
2+
3+
public interface RateLimiter {
4+
default boolean tryAcquire(String key, RateLimiterConfiguration.RateLimit rateLimit) {
5+
return tryAcquire(key, 1, rateLimit);
6+
} // default single token
7+
8+
boolean tryAcquire(
9+
String key, int permits, RateLimiterConfiguration.RateLimit rateLimit); // new: batch tokens
10+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package org.hypertrace.ratelimiter.grpcutils;
2+
3+
import java.util.Map;
4+
import java.util.function.BiFunction;
5+
import lombok.Builder;
6+
import lombok.Value;
7+
import org.hypertrace.core.grpcutils.context.RequestContext;
8+
9+
@Value
10+
@Builder
11+
public class RateLimiterConfiguration {
12+
boolean enabled;
13+
String method;
14+
// Attributes to match like tenant_id -> traceable
15+
Map<String, String> matchAttributes;
16+
17+
// Extract attributes from gRPC request
18+
BiFunction<RequestContext, Object, Map<String, String>> attributeExtractor;
19+
20+
// Token cost evaluator (can be static 1 or dynamic based on message)
21+
@Builder.Default BiFunction<RequestContext, Object, Integer> tokenCostFunction = (ctx, req) -> 1;
22+
RateLimit rateLimit;
23+
24+
@Value
25+
@Builder
26+
public static class RateLimit {
27+
int tokens;
28+
int refreshPeriodSeconds;
29+
}
30+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package org.hypertrace.ratelimiter.grpcutils;
2+
3+
public interface RateLimiterFactory {
4+
RateLimiter getRateLimiter(RateLimiterConfiguration rateLimiterConfiguration);
5+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package org.hypertrace.ratelimiter.grpcutils;
2+
3+
import org.hypertrace.ratelimiter.grpcutils.bucket4j.Bucket4jRateLimiterFactory;
4+
5+
public final class RateLimiterFactoryProvider {
6+
7+
private RateLimiterFactoryProvider() {
8+
// Prevent instantiation
9+
}
10+
11+
public static Bucket4jRateLimiterFactory bucket4j() {
12+
return new Bucket4jRateLimiterFactory();
13+
}
14+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package org.hypertrace.ratelimiter.grpcutils;
2+
3+
import io.grpc.ForwardingServerCallListener;
4+
import io.grpc.Metadata;
5+
import io.grpc.ServerCall;
6+
import io.grpc.ServerCallHandler;
7+
import io.grpc.ServerInterceptor;
8+
import io.grpc.Status;
9+
import java.util.List;
10+
import java.util.Map;
11+
import java.util.Objects;
12+
import java.util.stream.Collectors;
13+
import org.hypertrace.core.grpcutils.context.RequestContext;
14+
15+
public class RateLimiterInterceptor implements ServerInterceptor {
16+
17+
private final List<RateLimiterConfiguration>
18+
rateLimitConfigs; // Provided via config or dynamic update
19+
private final RateLimiterFactory rateLimiterFactory;
20+
21+
public RateLimiterInterceptor(
22+
List<RateLimiterConfiguration> rateLimitConfigs, RateLimiterFactory factory) {
23+
this.rateLimitConfigs = rateLimitConfigs;
24+
this.rateLimiterFactory = factory;
25+
}
26+
27+
@Override
28+
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
29+
ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
30+
31+
String method = call.getMethodDescriptor().getFullMethodName();
32+
33+
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<>(
34+
next.startCall(call, headers)) {
35+
@Override
36+
public void onMessage(ReqT message) {
37+
RequestContext requestContext = RequestContext.fromMetadata(headers);
38+
for (RateLimiterConfiguration config : rateLimitConfigs) {
39+
if (!config.getMethod().equals(method)) continue;
40+
41+
Map<String, String> attributes =
42+
config.getAttributeExtractor().apply(requestContext, message);
43+
44+
if (!matches(config.getMatchAttributes(), attributes)) continue;
45+
int tokens = config.getTokenCostFunction().apply(requestContext, message);
46+
String key = buildRateLimitKey(method, config.getMatchAttributes(), attributes);
47+
boolean allowed =
48+
rateLimiterFactory
49+
.getRateLimiter(config)
50+
.tryAcquire(key, tokens, config.getRateLimit());
51+
if (!allowed) {
52+
call.close(Status.RESOURCE_EXHAUSTED.withDescription("Rate limit exceeded"), headers);
53+
return;
54+
}
55+
}
56+
super.onMessage(message);
57+
}
58+
};
59+
}
60+
61+
private boolean matches(Map<String, String> match, Map<String, String> actual) {
62+
return match.entrySet().stream()
63+
.allMatch(e -> Objects.equals(actual.get(e.getKey()), e.getValue()));
64+
}
65+
66+
private String buildRateLimitKey(
67+
String method, Map<String, String> keys, Map<String, String> attrs) {
68+
return method
69+
+ "::"
70+
+ keys.keySet().stream()
71+
.map(k -> attrs.getOrDefault(k, "null"))
72+
.collect(Collectors.joining(":"));
73+
}
74+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package org.hypertrace.ratelimiter.grpcutils.bucket4j;
2+
3+
import com.google.common.cache.Cache;
4+
import com.google.common.cache.CacheBuilder;
5+
import io.github.bucket4j.Bucket;
6+
import io.grpc.Status;
7+
import java.time.Duration;
8+
import java.util.concurrent.ExecutionException;
9+
import org.hypertrace.ratelimiter.grpcutils.RateLimiter;
10+
import org.hypertrace.ratelimiter.grpcutils.RateLimiterConfiguration;
11+
import org.hypertrace.ratelimiter.grpcutils.RateLimiterFactory;
12+
13+
public class Bucket4jRateLimiterFactory implements RateLimiterFactory {
14+
15+
private final Cache<String, Bucket> limiterCache =
16+
CacheBuilder.newBuilder().maximumSize(10_000).build();
17+
18+
@Override
19+
public RateLimiter getRateLimiter(RateLimiterConfiguration rule) {
20+
return (key, tokens, limit) -> {
21+
try {
22+
Bucket bucket = limiterCache.get(key, () -> createBucket(limit));
23+
return bucket.tryConsume(tokens);
24+
} catch (ExecutionException e) {
25+
throw Status.INTERNAL
26+
.withDescription("Failed to create rate limiter bucket for key: " + key)
27+
.withCause(e)
28+
.asRuntimeException();
29+
}
30+
};
31+
}
32+
33+
private Bucket createBucket(RateLimiterConfiguration.RateLimit limit) {
34+
return Bucket.builder()
35+
.addLimit(
36+
bandwidth ->
37+
bandwidth
38+
.capacity(limit.getTokens())
39+
.refillGreedy(
40+
limit.getTokens(), Duration.ofSeconds(limit.getRefreshPeriodSeconds())))
41+
.build();
42+
}
43+
}
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package org.hypertrace.ratelimiter.grpcutils;
2+
3+
import static org.junit.jupiter.api.Assertions.*;
4+
5+
import java.util.Map;
6+
import java.util.function.BiFunction;
7+
import org.hypertrace.core.grpcutils.context.RequestContext;
8+
import org.junit.jupiter.api.Test;
9+
10+
class RateLimiterConfigurationTest {
11+
12+
@Test
13+
void testDefaultValues() {
14+
// Build a minimal RateLimiterConfiguration using only required fields
15+
RateLimiterConfiguration configuration =
16+
RateLimiterConfiguration.builder()
17+
.method("testMethod")
18+
.rateLimit(
19+
RateLimiterConfiguration.RateLimit.builder()
20+
.tokens(10)
21+
.refreshPeriodSeconds(30)
22+
.build())
23+
.build();
24+
25+
// Assertions for default values
26+
assertFalse(configuration.isEnabled()); // `enabled` defaults to true
27+
assertEquals("testMethod", configuration.getMethod());
28+
assertNull(configuration.getMatchAttributes()); // `matchAttributes` defaults to null
29+
assertNotNull(configuration.getTokenCostFunction()); // Ensure tokenCostFunction is initialized
30+
assertEquals(
31+
1, configuration.getTokenCostFunction().apply(null, null)); // Default token cost value
32+
}
33+
34+
@Test
35+
void testCustomMatchAttributes() {
36+
// Create matchAttributes map and build configuration
37+
Map<String, String> matchAttributes = Map.of("tenant_id", "traceable");
38+
RateLimiterConfiguration configuration =
39+
RateLimiterConfiguration.builder()
40+
.method("testMethod")
41+
.matchAttributes(matchAttributes)
42+
.rateLimit(
43+
RateLimiterConfiguration.RateLimit.builder()
44+
.tokens(100)
45+
.refreshPeriodSeconds(60)
46+
.build())
47+
.build();
48+
49+
// Verify matchAttributes are correctly set
50+
assertEquals(matchAttributes, configuration.getMatchAttributes());
51+
}
52+
53+
@Test
54+
void testCustomTokenCostFunction() {
55+
// Define a custom tokenCostFunction
56+
BiFunction<RequestContext, Object, Integer> customTokenCostFunction =
57+
(ctx, req) -> (req instanceof Integer) ? (Integer) req : 5;
58+
59+
// Build a configuration using the custom token cost function
60+
RateLimiterConfiguration configuration =
61+
RateLimiterConfiguration.builder()
62+
.method("computeTokenCost")
63+
.tokenCostFunction(customTokenCostFunction)
64+
.rateLimit(
65+
RateLimiterConfiguration.RateLimit.builder()
66+
.tokens(50)
67+
.refreshPeriodSeconds(15)
68+
.build())
69+
.build();
70+
71+
// Verify behavior of the custom token cost function
72+
assertEquals(10, configuration.getTokenCostFunction().apply(null, 10)); // Dynamic cost
73+
assertEquals(
74+
5, configuration.getTokenCostFunction().apply(null, "randomObject")); // Default cost
75+
}
76+
77+
@Test
78+
void testAttributeExtractor() {
79+
// Define an attributeExtractor that extracts specific attributes from the request
80+
BiFunction<RequestContext, Object, Map<String, String>> customAttributeExtractor =
81+
(ctx, req) -> Map.of("attributeKey", "attributeValue");
82+
83+
// Build the configuration with the custom attribute extractor
84+
RateLimiterConfiguration configuration =
85+
RateLimiterConfiguration.builder()
86+
.method("extractAttributes")
87+
.attributeExtractor(customAttributeExtractor)
88+
.rateLimit(
89+
RateLimiterConfiguration.RateLimit.builder()
90+
.tokens(20)
91+
.refreshPeriodSeconds(45)
92+
.build())
93+
.build();
94+
95+
// Verify the custom attribute extractor
96+
assertEquals(
97+
Map.of("attributeKey", "attributeValue"),
98+
configuration.getAttributeExtractor().apply(null, null));
99+
}
100+
101+
@Test
102+
void testRateLimitConfiguration() {
103+
// Build a simple RateLimit configuration
104+
RateLimiterConfiguration.RateLimit rateLimit =
105+
RateLimiterConfiguration.RateLimit.builder().tokens(500).refreshPeriodSeconds(300).build();
106+
107+
// Verify RateLimit configuration values
108+
assertEquals(500, rateLimit.getTokens());
109+
assertEquals(300, rateLimit.getRefreshPeriodSeconds());
110+
}
111+
112+
@Test
113+
void testFullCustomConfiguration() {
114+
// Define custom token cost function and attribute extractor
115+
BiFunction<RequestContext, Object, Integer> tokenCostFunction = (ctx, req) -> 2;
116+
BiFunction<RequestContext, Object, Map<String, String>> attributeExtractor =
117+
(ctx, req) -> Map.of("tenant_id", "12345");
118+
119+
// Build a complete custom configuration
120+
RateLimiterConfiguration configuration =
121+
RateLimiterConfiguration.builder()
122+
.method("fullCustomMethod")
123+
.enabled(false)
124+
.matchAttributes(Map.of("region", "us-west"))
125+
.attributeExtractor(attributeExtractor)
126+
.tokenCostFunction(tokenCostFunction)
127+
.rateLimit(
128+
RateLimiterConfiguration.RateLimit.builder()
129+
.tokens(1000)
130+
.refreshPeriodSeconds(60)
131+
.build())
132+
.build();
133+
134+
// Verify all custom configurations
135+
assertFalse(configuration.isEnabled()); // Verify `enabled` value
136+
assertEquals("fullCustomMethod", configuration.getMethod()); // Verify method
137+
assertEquals(
138+
Map.of("region", "us-west"), configuration.getMatchAttributes()); // Verify matchAttributes
139+
assertEquals(
140+
Map.of("tenant_id", "12345"),
141+
configuration.getAttributeExtractor().apply(null, null)); // Custom extractor
142+
assertEquals(2, configuration.getTokenCostFunction().apply(null, null)); // Custom token cost
143+
assertEquals(1000, configuration.getRateLimit().getTokens()); // RateLimit tokens
144+
assertEquals(
145+
60, configuration.getRateLimit().getRefreshPeriodSeconds()); // RateLimit refresh period
146+
}
147+
}

0 commit comments

Comments
 (0)