(service/grpc) Reduce thread count

Netty and GRPC by default spawns an incredible number of threads on high-core CPUs, which amount to a fair bit of RAM usage.

Add custom executors that throttle this behavior.
This commit is contained in:
Viktor Lofgren 2024-02-27 15:04:34 +01:00
parent dbf64b0987
commit eaf836dc66
8 changed files with 62 additions and 36 deletions

View File

@ -10,7 +10,12 @@ import nu.marginalia.service.discovery.property.PartitionTraits;
import nu.marginalia.service.discovery.property.ServiceEndpoint.InstanceAddress; import nu.marginalia.service.discovery.property.ServiceEndpoint.InstanceAddress;
import nu.marginalia.service.discovery.property.ServiceKey; import nu.marginalia.service.discovery.property.ServiceKey;
import nu.marginalia.service.discovery.property.ServicePartition; import nu.marginalia.service.discovery.property.ServicePartition;
import org.jetbrains.annotations.NotNull;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function; import java.util.function.Function;
@Singleton @Singleton
@ -18,6 +23,16 @@ public class GrpcChannelPoolFactory {
private final NodeConfigurationWatcher nodeConfigurationWatcher; private final NodeConfigurationWatcher nodeConfigurationWatcher;
private final ServiceRegistryIf serviceRegistryIf; private final ServiceRegistryIf serviceRegistryIf;
private static final Executor executor = Executors.newFixedThreadPool(
Math.clamp(Runtime.getRuntime().availableProcessors() / 2, 2, 16), new ThreadFactory() {
static final AtomicInteger threadNumber = new AtomicInteger(1);
@Override
public Thread newThread(@NotNull Runnable r) {
var thread = new Thread(r, STR."gRPC-Channel-Pool[\{threadNumber.getAndIncrement()}]");
thread.setDaemon(true);
return thread;
}
});
@Inject @Inject
public GrpcChannelPoolFactory(NodeConfigurationWatcher nodeConfigurationWatcher, public GrpcChannelPoolFactory(NodeConfigurationWatcher nodeConfigurationWatcher,
@ -49,6 +64,7 @@ public class GrpcChannelPoolFactory {
var mc = ManagedChannelBuilder var mc = ManagedChannelBuilder
.forAddress(route.host(), route.port()) .forAddress(route.host(), route.port())
.executor(executor)
.usePlaintext() .usePlaintext()
.build(); .build();

View File

@ -2,6 +2,8 @@ package nu.marginalia.service.server;
import io.grpc.*; import io.grpc.*;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import io.grpc.netty.shaded.io.netty.channel.nio.NioEventLoopGroup;
import io.grpc.netty.shaded.io.netty.channel.socket.nio.NioServerSocketChannel;
import io.prometheus.client.Counter; import io.prometheus.client.Counter;
import lombok.SneakyThrows; import lombok.SneakyThrows;
import nu.marginalia.mq.inbox.*; import nu.marginalia.mq.inbox.*;
@ -19,6 +21,11 @@ import spark.Spark;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
public class Service { public class Service {
private final Logger logger = LoggerFactory.getLogger(getClass()); private final Logger logger = LoggerFactory.getLogger(getClass());
@ -121,8 +128,16 @@ public class Service {
int port = params.serviceRegistry.requestPort(config.externalAddress(), new ServiceKey.Grpc<>("-", partition)); int port = params.serviceRegistry.requestPort(config.externalAddress(), new ServiceKey.Grpc<>("-", partition));
int nThreads = Math.clamp(Runtime.getRuntime().availableProcessors() / 2, 2, 8);
// Start the gRPC server // Start the gRPC server
var grpcServerBuilder = NettyServerBuilder.forAddress(new InetSocketAddress(config.bindAddress(), port)); var grpcServerBuilder = NettyServerBuilder.forAddress(new InetSocketAddress(config.bindAddress(), port))
.executor(namedExecutor("nettyExecutor", nThreads))
.workerEventLoopGroup(new NioEventLoopGroup(nThreads, namedExecutor("Worker-ELG", nThreads)))
.bossEventLoopGroup(new NioEventLoopGroup(nThreads, namedExecutor("Boss-ELG", nThreads)))
.channelType(NioServerSocketChannel.class);
for (var grpcService : grpcServices) { for (var grpcService : grpcServices) {
var svc = grpcService.bindService(); var svc = grpcService.bindService();
@ -138,6 +153,20 @@ public class Service {
} }
} }
private ExecutorService namedExecutor(String name, int limit) {
return Executors.newFixedThreadPool(
limit,
new ThreadFactory() {
static final AtomicInteger threadNumber = new AtomicInteger(1);
@Override
public Thread newThread(Runnable r) {
var thread = new Thread(r, STR."\{name}[\{threadNumber.getAndIncrement()}]");
thread.setDaemon(true);
return thread;
}
});
}
public Service(BaseServiceParams params, public Service(BaseServiceParams params,
ServicePartition partition, ServicePartition partition,
List<BindableService> grpcServices) { List<BindableService> grpcServices) {

View File

@ -30,7 +30,7 @@ public class FeedlotClient {
this.gson = gson; this.gson = gson;
httpClient = HttpClient.newBuilder() httpClient = HttpClient.newBuilder()
.executor(Executors.newVirtualThreadPerTaskExecutor()) .executor(Executors.newCachedThreadPool())
.connectTimeout(connectTimeout) .connectTimeout(connectTimeout)
.build(); .build();
this.requestTimeout = requestTimeout; this.requestTimeout = requestTimeout;

View File

@ -19,7 +19,7 @@ public class DomainInfoClient {
private static final Logger logger = LoggerFactory.getLogger(DomainInfoClient.class); private static final Logger logger = LoggerFactory.getLogger(DomainInfoClient.class);
private final GrpcSingleNodeChannelPool<DomainInfoAPIGrpc.DomainInfoAPIBlockingStub> channelPool; private final GrpcSingleNodeChannelPool<DomainInfoAPIGrpc.DomainInfoAPIBlockingStub> channelPool;
private final ExecutorService virtualExecutorService = Executors.newVirtualThreadPerTaskExecutor(); private final ExecutorService executor = Executors.newWorkStealingPool(8);
@Inject @Inject
public DomainInfoClient(GrpcChannelPoolFactory factory) { public DomainInfoClient(GrpcChannelPoolFactory factory) {
@ -30,21 +30,21 @@ public class DomainInfoClient {
public Future<List<SimilarDomain>> similarDomains(int domainId, int count) { public Future<List<SimilarDomain>> similarDomains(int domainId, int count) {
return channelPool.call(DomainInfoAPIGrpc.DomainInfoAPIBlockingStub::getSimilarDomains) return channelPool.call(DomainInfoAPIGrpc.DomainInfoAPIBlockingStub::getSimilarDomains)
.async(virtualExecutorService) .async(executor)
.run(DomainsProtobufCodec.DomainQueries.createRequest(domainId, count)) .run(DomainsProtobufCodec.DomainQueries.createRequest(domainId, count))
.thenApply(DomainsProtobufCodec.DomainQueries::convertResponse); .thenApply(DomainsProtobufCodec.DomainQueries::convertResponse);
} }
public Future<List<SimilarDomain>> linkedDomains(int domainId, int count) { public Future<List<SimilarDomain>> linkedDomains(int domainId, int count) {
return channelPool.call(DomainInfoAPIGrpc.DomainInfoAPIBlockingStub::getLinkingDomains) return channelPool.call(DomainInfoAPIGrpc.DomainInfoAPIBlockingStub::getLinkingDomains)
.async(virtualExecutorService) .async(executor)
.run(DomainsProtobufCodec.DomainQueries.createRequest(domainId, count)) .run(DomainsProtobufCodec.DomainQueries.createRequest(domainId, count))
.thenApply(DomainsProtobufCodec.DomainQueries::convertResponse); .thenApply(DomainsProtobufCodec.DomainQueries::convertResponse);
} }
public Future<DomainInformation> domainInformation(int domainId) { public Future<DomainInformation> domainInformation(int domainId) {
return channelPool.call(DomainInfoAPIGrpc.DomainInfoAPIBlockingStub::getDomainInfo) return channelPool.call(DomainInfoAPIGrpc.DomainInfoAPIBlockingStub::getDomainInfo)
.async(virtualExecutorService) .async(executor)
.run(DomainsProtobufCodec.DomainInfo.createRequest(domainId)) .run(DomainsProtobufCodec.DomainInfo.createRequest(domainId))
.thenApply(DomainsProtobufCodec.DomainInfo::convertResponse); .thenApply(DomainsProtobufCodec.DomainInfo::convertResponse);
} }

View File

@ -24,7 +24,8 @@ public class MathClient {
private static final Logger logger = LoggerFactory.getLogger(MathClient.class); private static final Logger logger = LoggerFactory.getLogger(MathClient.class);
private final GrpcSingleNodeChannelPool<MathApiGrpc.MathApiBlockingStub> channelPool; private final GrpcSingleNodeChannelPool<MathApiGrpc.MathApiBlockingStub> channelPool;
private final ExecutorService virtualExecutorService = Executors.newVirtualThreadPerTaskExecutor(); private final ExecutorService executor = Executors.newWorkStealingPool(8);
@Inject @Inject
public MathClient(GrpcChannelPoolFactory factory) { public MathClient(GrpcChannelPoolFactory factory) {
this.channelPool = factory.createSingle( this.channelPool = factory.createSingle(
@ -35,7 +36,7 @@ public class MathClient {
public Future<DictionaryResponse> dictionaryLookup(String word) { public Future<DictionaryResponse> dictionaryLookup(String word) {
return channelPool.call(MathApiGrpc.MathApiBlockingStub::dictionaryLookup) return channelPool.call(MathApiGrpc.MathApiBlockingStub::dictionaryLookup)
.async(virtualExecutorService) .async(executor)
.run(DictionaryLookup.createRequest(word)) .run(DictionaryLookup.createRequest(word))
.thenApply(DictionaryLookup::convertResponse); .thenApply(DictionaryLookup::convertResponse);
} }
@ -43,7 +44,7 @@ public class MathClient {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public Future<List<String>> spellCheck(String word) { public Future<List<String>> spellCheck(String word) {
return channelPool.call(MathApiGrpc.MathApiBlockingStub::spellCheck) return channelPool.call(MathApiGrpc.MathApiBlockingStub::spellCheck)
.async(virtualExecutorService) .async(executor)
.run(SpellCheck.createRequest(word)) .run(SpellCheck.createRequest(word))
.thenApply(SpellCheck::convertResponse); .thenApply(SpellCheck::convertResponse);
} }
@ -52,7 +53,7 @@ public class MathClient {
List<RpcSpellCheckRequest> requests = words.stream().map(SpellCheck::createRequest).toList(); List<RpcSpellCheckRequest> requests = words.stream().map(SpellCheck::createRequest).toList();
var future = channelPool.call(MathApiGrpc.MathApiBlockingStub::spellCheck) var future = channelPool.call(MathApiGrpc.MathApiBlockingStub::spellCheck)
.async(virtualExecutorService) .async(executor)
.runFor(requests); .runFor(requests);
try { try {
@ -70,14 +71,14 @@ public class MathClient {
public Future<String> unitConversion(String value, String from, String to) { public Future<String> unitConversion(String value, String from, String to) {
return channelPool.call(MathApiGrpc.MathApiBlockingStub::unitConversion) return channelPool.call(MathApiGrpc.MathApiBlockingStub::unitConversion)
.async(virtualExecutorService) .async(executor)
.run(UnitConversion.createRequest(from, to, value)) .run(UnitConversion.createRequest(from, to, value))
.thenApply(UnitConversion::convertResponse); .thenApply(UnitConversion::convertResponse);
} }
public Future<String> evalMath(String expression) { public Future<String> evalMath(String expression) {
return channelPool.call(MathApiGrpc.MathApiBlockingStub::evalMath) return channelPool.call(MathApiGrpc.MathApiBlockingStub::evalMath)
.async(virtualExecutorService) .async(executor)
.run(EvalMath.createRequest(expression)) .run(EvalMath.createRequest(expression))
.thenApply(EvalMath::convertResponse); .thenApply(EvalMath::convertResponse);
} }

View File

@ -22,7 +22,7 @@ import java.util.concurrent.Executors;
public class IndexClient { public class IndexClient {
private static final Logger logger = LoggerFactory.getLogger(IndexClient.class); private static final Logger logger = LoggerFactory.getLogger(IndexClient.class);
private final GrpcMultiNodeChannelPool<IndexApiGrpc.IndexApiBlockingStub> channelPool; private final GrpcMultiNodeChannelPool<IndexApiGrpc.IndexApiBlockingStub> channelPool;
private static final ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor(); private static final ExecutorService executor = Executors.newFixedThreadPool(8);
@Inject @Inject
public IndexClient(GrpcChannelPoolFactory channelPoolFactory) { public IndexClient(GrpcChannelPoolFactory channelPoolFactory) {
this.channelPool = channelPoolFactory.createMulti( this.channelPool = channelPoolFactory.createMulti(

View File

@ -84,7 +84,7 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase {
private final String nodeName; private final String nodeName;
private final int indexValuationThreads = Integer.getInteger("index.valuationThreads", 8); private static final int indexValuationThreads = Integer.getInteger("index.valuationThreads", 8);
@Inject @Inject
public IndexGrpcService(ServiceConfiguration serviceConfiguration, public IndexGrpcService(ServiceConfiguration serviceConfiguration,
@ -227,7 +227,7 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase {
* and finally the best results are returned. * and finally the best results are returned.
*/ */
private class QueryExecution { private class QueryExecution {
private static final Executor workerPool = Executors.newCachedThreadPool(); private static final Executor workerPool = Executors.newWorkStealingPool(indexValuationThreads*4);
private final ArrayBlockingQueue<CombinedDocIdList> resultCandidateQueue private final ArrayBlockingQueue<CombinedDocIdList> resultCandidateQueue
= new ArrayBlockingQueue<>(8); = new ArrayBlockingQueue<>(8);

View File

@ -5,7 +5,7 @@ import gnu.trove.map.hash.TLongIntHashMap;
import nu.marginalia.api.searchquery.model.results.SearchResultItem; import nu.marginalia.api.searchquery.model.results.SearchResultItem;
public class IndexResultDomainDeduplicator { public class IndexResultDomainDeduplicator {
final TLongIntMap resultsByDomainId = CachedObjects.getMap(); final TLongIntMap resultsByDomainId = new TLongIntHashMap(2048, 0.5f, -1, 0);
final int limitByDomain; final int limitByDomain;
public IndexResultDomainDeduplicator(int limitByDomain) { public IndexResultDomainDeduplicator(int limitByDomain) {
@ -23,25 +23,5 @@ public class IndexResultDomainDeduplicator {
return resultsByDomainId.get(key); return resultsByDomainId.get(key);
} }
private static class CachedObjects {
private static final ThreadLocal<TLongIntHashMap> mapCache = ThreadLocal.withInitial(() ->
new TLongIntHashMap(2048, 0.5f, -1, 0)
);
private static TLongIntHashMap getMap() {
var ret = mapCache.get();
ret.clear();
return ret;
}
public static void clear() {
mapCache.remove();
}
}
static void clearCachedObjects() {
CachedObjects.clear();
}
} }