diff --git a/code/api/executor-api/src/main/java/nu/marginalia/executor/client/ExecutorClient.java b/code/api/executor-api/src/main/java/nu/marginalia/executor/client/ExecutorClient.java index 7f301261..658f6b37 100644 --- a/code/api/executor-api/src/main/java/nu/marginalia/executor/client/ExecutorClient.java +++ b/code/api/executor-api/src/main/java/nu/marginalia/executor/client/ExecutorClient.java @@ -4,7 +4,7 @@ import com.google.inject.Inject; import com.google.inject.Singleton; import nu.marginalia.client.AbstractDynamicClient; import nu.marginalia.client.Context; -import nu.marginalia.client.grpc.GrpcStubPool; +import nu.marginalia.client.grpc.GrpcChannelPool; import nu.marginalia.executor.api.*; import nu.marginalia.executor.api.ExecutorApiGrpc.ExecutorApiBlockingStub; import nu.marginalia.executor.model.ActorRunState; @@ -35,14 +35,14 @@ import java.util.concurrent.TimeUnit; @Singleton public class ExecutorClient extends AbstractDynamicClient { - private final GrpcStubPool stubPool; + private final GrpcChannelPool channelPool; private static final Logger logger = LoggerFactory.getLogger(ExecutorClient.class); @Inject public ExecutorClient(ServiceDescriptors descriptors, NodeConfigurationService nodeConfigurationService) { super(descriptors.forId(ServiceId.Executor), GsonFactory::get); - stubPool = new GrpcStubPool<>(ServiceId.Executor) { + channelPool = new GrpcChannelPool<>(ServiceId.Executor) { @Override public ExecutorApiBlockingStub createStub(ManagedChannel channel) { return ExecutorApiGrpc.newBlockingStub(channel); @@ -59,7 +59,7 @@ public class ExecutorClient extends AbstractDynamicClient { } public void startFsm(int node, String actorName) { - stubPool.apiForNode(node).startFsm( + channelPool.apiForNode(node).startFsm( RpcFsmName.newBuilder() .setActorName(actorName) .build() @@ -67,7 +67,7 @@ public class ExecutorClient extends AbstractDynamicClient { } public void stopFsm(int node, String actorName) { - stubPool.apiForNode(node).stopFsm( + channelPool.apiForNode(node).stopFsm( RpcFsmName.newBuilder() .setActorName(actorName) .build() @@ -75,7 +75,7 @@ public class ExecutorClient extends AbstractDynamicClient { } public void stopProcess(int node, String id) { - stubPool.apiForNode(node).stopProcess( + channelPool.apiForNode(node).stopProcess( RpcProcessId.newBuilder() .setProcessId(id) .build() @@ -83,7 +83,7 @@ public class ExecutorClient extends AbstractDynamicClient { } public void triggerCrawl(int node, FileStorageId fid) { - stubPool.apiForNode(node).triggerCrawl( + channelPool.apiForNode(node).triggerCrawl( RpcFileStorageId.newBuilder() .setFileStorageId(fid.id()) .build() @@ -91,7 +91,7 @@ public class ExecutorClient extends AbstractDynamicClient { } public void triggerRecrawl(int node, FileStorageId fid) { - stubPool.apiForNode(node).triggerRecrawl( + channelPool.apiForNode(node).triggerRecrawl( RpcFileStorageId.newBuilder() .setFileStorageId(fid.id()) .build() @@ -99,7 +99,7 @@ public class ExecutorClient extends AbstractDynamicClient { } public void triggerConvert(int node, FileStorageId fid) { - stubPool.apiForNode(node).triggerConvert( + channelPool.apiForNode(node).triggerConvert( RpcFileStorageId.newBuilder() .setFileStorageId(fid.id()) .build() @@ -107,7 +107,7 @@ public class ExecutorClient extends AbstractDynamicClient { } public void triggerConvertAndLoad(int node, FileStorageId fid) { - stubPool.apiForNode(node).triggerConvertAndLoad( + channelPool.apiForNode(node).triggerConvertAndLoad( RpcFileStorageId.newBuilder() .setFileStorageId(fid.id()) .build() @@ -115,7 +115,7 @@ public class ExecutorClient extends AbstractDynamicClient { } public void loadProcessedData(int node, List ids) { - stubPool.apiForNode(node).loadProcessedData( + channelPool.apiForNode(node).loadProcessedData( RpcFileStorageIds.newBuilder() .addAllFileStorageIds(ids.stream().map(FileStorageId::id).toList()) .build() @@ -123,11 +123,11 @@ public class ExecutorClient extends AbstractDynamicClient { } public void calculateAdjacencies(int node) { - stubPool.apiForNode(node).calculateAdjacencies(Empty.getDefaultInstance()); + channelPool.apiForNode(node).calculateAdjacencies(Empty.getDefaultInstance()); } public void sideloadEncyclopedia(int node, Path sourcePath, String baseUrl) { - stubPool.apiForNode(node).sideloadEncyclopedia( + channelPool.apiForNode(node).sideloadEncyclopedia( RpcSideloadEncyclopedia.newBuilder() .setBaseUrl(baseUrl) .setSourcePath(sourcePath.toString()) @@ -136,21 +136,21 @@ public class ExecutorClient extends AbstractDynamicClient { } public void sideloadDirtree(int node, Path sourcePath) { - stubPool.apiForNode(node).sideloadDirtree( + channelPool.apiForNode(node).sideloadDirtree( RpcSideloadDirtree.newBuilder() .setSourcePath(sourcePath.toString()) .build() ); } public void sideloadReddit(int node, Path sourcePath) { - stubPool.apiForNode(node).sideloadReddit( + channelPool.apiForNode(node).sideloadReddit( RpcSideloadReddit.newBuilder() .setSourcePath(sourcePath.toString()) .build() ); } public void sideloadWarc(int node, Path sourcePath) { - stubPool.apiForNode(node).sideloadWarc( + channelPool.apiForNode(node).sideloadWarc( RpcSideloadWarc.newBuilder() .setSourcePath(sourcePath.toString()) .build() @@ -158,7 +158,7 @@ public class ExecutorClient extends AbstractDynamicClient { } public void sideloadStackexchange(int node, Path sourcePath) { - stubPool.apiForNode(node).sideloadStackexchange( + channelPool.apiForNode(node).sideloadStackexchange( RpcSideloadStackexchange.newBuilder() .setSourcePath(sourcePath.toString()) .build() @@ -166,7 +166,7 @@ public class ExecutorClient extends AbstractDynamicClient { } public void createCrawlSpecFromDownload(int node, String description, String url) { - stubPool.apiForNode(node).createCrawlSpecFromDownload( + channelPool.apiForNode(node).createCrawlSpecFromDownload( RpcCrawlSpecFromDownload.newBuilder() .setDescription(description) .setUrl(url) @@ -175,14 +175,14 @@ public class ExecutorClient extends AbstractDynamicClient { } public void exportAtags(int node, FileStorageId fid) { - stubPool.apiForNode(node).exportAtags( + channelPool.apiForNode(node).exportAtags( RpcFileStorageId.newBuilder() .setFileStorageId(fid.id()) .build() ); } public void exportSampleData(int node, FileStorageId fid, int size, String name) { - stubPool.apiForNode(node).exportSampleData( + channelPool.apiForNode(node).exportSampleData( RpcExportSampleData.newBuilder() .setFileStorageId(fid.id()) .setSize(size) @@ -192,14 +192,14 @@ public class ExecutorClient extends AbstractDynamicClient { } public void exportRssFeeds(int node, FileStorageId fid) { - stubPool.apiForNode(node).exportRssFeeds( + channelPool.apiForNode(node).exportRssFeeds( RpcFileStorageId.newBuilder() .setFileStorageId(fid.id()) .build() ); } public void exportTermFrequencies(int node, FileStorageId fid) { - stubPool.apiForNode(node).exportTermFrequencies( + channelPool.apiForNode(node).exportTermFrequencies( RpcFileStorageId.newBuilder() .setFileStorageId(fid.id()) .build() @@ -207,7 +207,7 @@ public class ExecutorClient extends AbstractDynamicClient { } public void downloadSampleData(int node, String sampleSet) { - stubPool.apiForNode(node).downloadSampleData( + channelPool.apiForNode(node).downloadSampleData( RpcDownloadSampleData.newBuilder() .setSampleSet(sampleSet) .build() @@ -215,11 +215,11 @@ public class ExecutorClient extends AbstractDynamicClient { } public void exportData(int node) { - stubPool.apiForNode(node).exportData(Empty.getDefaultInstance()); + channelPool.apiForNode(node).exportData(Empty.getDefaultInstance()); } public void restoreBackup(int node, FileStorageId fid) { - stubPool.apiForNode(node).restoreBackup( + channelPool.apiForNode(node).restoreBackup( RpcFileStorageId.newBuilder() .setFileStorageId(fid.id()) .build() @@ -228,7 +228,7 @@ public class ExecutorClient extends AbstractDynamicClient { public ActorRunStates getActorStates(int node) { try { - var rs = stubPool.apiForNode(node).getActorStates(Empty.getDefaultInstance()); + var rs = channelPool.apiForNode(node).getActorStates(Empty.getDefaultInstance()); var states = rs.getActorRunStatesList().stream() .map(r -> new ActorRunState( r.getActorName(), @@ -252,7 +252,7 @@ public class ExecutorClient extends AbstractDynamicClient { public UploadDirContents listSideloadDir(int node) { try { - var rs = stubPool.apiForNode(node).listSideloadDir(Empty.getDefaultInstance()); + var rs = channelPool.apiForNode(node).listSideloadDir(Empty.getDefaultInstance()); var items = rs.getEntriesList().stream() .map(i -> new UploadDirItem(i.getName(), i.getLastModifiedTime(), i.getIsDirectory(), i.getSize())) .toList(); @@ -268,7 +268,7 @@ public class ExecutorClient extends AbstractDynamicClient { public FileStorageContent listFileStorage(int node, FileStorageId fileId) { try { - var rs = stubPool.apiForNode(node).listFileStorage( + var rs = channelPool.apiForNode(node).listFileStorage( RpcFileStorageId.newBuilder() .setFileStorageId(fileId.id()) .build() diff --git a/code/common/service-client/src/main/java/nu/marginalia/client/grpc/GrpcStubPool.java b/code/common/service-client/src/main/java/nu/marginalia/client/grpc/GrpcChannelPool.java similarity index 76% rename from code/common/service-client/src/main/java/nu/marginalia/client/grpc/GrpcStubPool.java rename to code/common/service-client/src/main/java/nu/marginalia/client/grpc/GrpcChannelPool.java index 0692c002..d2508dfa 100644 --- a/code/common/service-client/src/main/java/nu/marginalia/client/grpc/GrpcStubPool.java +++ b/code/common/service-client/src/main/java/nu/marginalia/client/grpc/GrpcChannelPool.java @@ -10,11 +10,12 @@ import java.util.concurrent.*; import java.util.function.Function; import java.util.stream.Stream; +import static io.grpc.ConnectivityState.SHUTDOWN; + /** A pool of gRPC stubs for a service, with a separate stub for each node. * Manages broadcast-style request. */ -public abstract class GrpcStubPool { - public GrpcStubPool(String serviceName) { - +public abstract class GrpcChannelPool { + public GrpcChannelPool(String serviceName) { this.serviceName = serviceName; } @@ -25,25 +26,34 @@ public abstract class GrpcStubPool { } private final Map channels = new ConcurrentHashMap<>(); - private final Map apis = new ConcurrentHashMap<>(); private final ExecutorService virtualExecutorService = Executors.newVirtualThreadPerTaskExecutor(); private final String serviceName; - public GrpcStubPool(ServiceId serviceId) { + public GrpcChannelPool(ServiceId serviceId) { this.serviceName = serviceId.serviceName; } /** Get an API stub for the given node */ public STUB apiForNode(int node) { - var san = new ServiceAndNode(serviceName, node); - return apis.computeIfAbsent(san, n -> - createStub(channels.computeIfAbsent(san, this::createChannel)) + return createStub( + channels.compute( + new ServiceAndNode(serviceName, node), + this::refreshChannel) ); } + private ManagedChannel refreshChannel(ServiceAndNode serviceAndNode, ManagedChannel old) { + if (old == null || old.getState(true) != SHUTDOWN) { + return createChannel(serviceAndNode); + } + return old; + } + protected ManagedChannel createChannel(ServiceAndNode serviceAndNode) { - return ManagedChannelBuilder.forAddress(serviceAndNode.getHostName(), 81).usePlaintext().build(); + return ManagedChannelBuilder.forAddress(serviceAndNode.getHostName(), 81) + .usePlaintext() + .build(); } /** Invoke a function on each node, returning a list of futures in a terminal state, as per diff --git a/code/services-core/query-service/src/main/java/nu/marginalia/query/QueryGRPCDomainLinksService.java b/code/services-core/query-service/src/main/java/nu/marginalia/query/QueryGRPCDomainLinksService.java index 2eb4c01c..a2b6b780 100644 --- a/code/services-core/query-service/src/main/java/nu/marginalia/query/QueryGRPCDomainLinksService.java +++ b/code/services-core/query-service/src/main/java/nu/marginalia/query/QueryGRPCDomainLinksService.java @@ -3,7 +3,7 @@ package nu.marginalia.query; import com.google.inject.Inject; import io.grpc.ManagedChannel; import io.grpc.stub.StreamObserver; -import nu.marginalia.client.grpc.GrpcStubPool; +import nu.marginalia.client.grpc.GrpcChannelPool; import nu.marginalia.index.api.IndexDomainLinksApiGrpc; import nu.marginalia.index.api.RpcDomainIdCount; import nu.marginalia.index.api.RpcDomainIdList; @@ -17,11 +17,11 @@ import java.util.List; public class QueryGRPCDomainLinksService extends IndexDomainLinksApiGrpc.IndexDomainLinksApiImplBase { private static final Logger logger = LoggerFactory.getLogger(QueryGRPCDomainLinksService.class); - private final GrpcStubPool stubPool; + private final GrpcChannelPool channelPool; @Inject public QueryGRPCDomainLinksService(NodeConfigurationWatcher nodeConfigurationWatcher) { - stubPool = new GrpcStubPool<>(ServiceId.Index) { + channelPool = new GrpcChannelPool<>(ServiceId.Index) { @Override public IndexDomainLinksApiGrpc.IndexDomainLinksApiBlockingStub createStub(ManagedChannel channel) { return IndexDomainLinksApiGrpc.newBlockingStub(channel); @@ -37,7 +37,7 @@ public class QueryGRPCDomainLinksService extends IndexDomainLinksApiGrpc.IndexDo @Override public void getAllLinks(nu.marginalia.index.api.Empty request, StreamObserver responseObserver) { - stubPool.callEachSequential(stub -> stub.getAllLinks(request)) + channelPool.callEachSequential(stub -> stub.getAllLinks(request)) .forEach( iter -> iter.forEachRemaining(responseObserver::onNext) ); @@ -50,7 +50,7 @@ public class QueryGRPCDomainLinksService extends IndexDomainLinksApiGrpc.IndexDo StreamObserver responseObserver) { var rspBuilder = RpcDomainIdList.newBuilder(); - stubPool.callEachSequential(stub -> stub.getLinksFromDomain(request)) + channelPool.callEachSequential(stub -> stub.getLinksFromDomain(request)) .map(RpcDomainIdList::getDomainIdList) .forEach(rspBuilder::addAllDomainId); @@ -63,7 +63,7 @@ public class QueryGRPCDomainLinksService extends IndexDomainLinksApiGrpc.IndexDo StreamObserver responseObserver) { var rspBuilder = RpcDomainIdList.newBuilder(); - stubPool.callEachSequential(stub -> stub.getLinksToDomain(request)) + channelPool.callEachSequential(stub -> stub.getLinksToDomain(request)) .map(RpcDomainIdList::getDomainIdList) .forEach(rspBuilder::addAllDomainId); @@ -75,7 +75,7 @@ public class QueryGRPCDomainLinksService extends IndexDomainLinksApiGrpc.IndexDo public void countLinksFromDomain(nu.marginalia.index.api.RpcDomainId request, StreamObserver responseObserver) { - int sum = stubPool.callEachSequential(stub -> stub.countLinksFromDomain(request)) + int sum = channelPool.callEachSequential(stub -> stub.countLinksFromDomain(request)) .mapToInt(RpcDomainIdCount::getIdCount) .sum(); @@ -89,7 +89,7 @@ public class QueryGRPCDomainLinksService extends IndexDomainLinksApiGrpc.IndexDo public void countLinksToDomain(nu.marginalia.index.api.RpcDomainId request, io.grpc.stub.StreamObserver responseObserver) { - int sum = stubPool.callEachSequential(stub -> stub.countLinksToDomain(request)) + int sum = channelPool.callEachSequential(stub -> stub.countLinksToDomain(request)) .mapToInt(RpcDomainIdCount::getIdCount) .sum(); diff --git a/code/services-core/query-service/src/main/java/nu/marginalia/query/QueryGRPCService.java b/code/services-core/query-service/src/main/java/nu/marginalia/query/QueryGRPCService.java index ef253253..bd2c0452 100644 --- a/code/services-core/query-service/src/main/java/nu/marginalia/query/QueryGRPCService.java +++ b/code/services-core/query-service/src/main/java/nu/marginalia/query/QueryGRPCService.java @@ -4,7 +4,7 @@ import com.google.inject.Inject; import io.grpc.ManagedChannel; import io.prometheus.client.Histogram; import lombok.SneakyThrows; -import nu.marginalia.client.grpc.GrpcStubPool; +import nu.marginalia.client.grpc.GrpcChannelPool; import nu.marginalia.db.DomainBlacklist; import nu.marginalia.index.api.*; import nu.marginalia.model.id.UrlIdCodec; @@ -27,7 +27,7 @@ public class QueryGRPCService extends QueryApiGrpc.QueryApiImplBase { .help("QS-side query time (GRPC endpoint)") .register(); - private final GrpcStubPool stubPool; + private final GrpcChannelPool channelPool; private final QueryFactory queryFactory; private final DomainBlacklist blacklist; @@ -40,7 +40,7 @@ public class QueryGRPCService extends QueryApiGrpc.QueryApiImplBase { this.queryFactory = queryFactory; this.blacklist = blacklist; - stubPool = new GrpcStubPool<>(ServiceId.Index) { + channelPool = new GrpcChannelPool<>(ServiceId.Index) { @Override public IndexApiGrpc.IndexApiBlockingStub createStub(ManagedChannel channel) { return IndexApiGrpc.newBlockingStub(channel); @@ -89,7 +89,7 @@ public class QueryGRPCService extends QueryApiGrpc.QueryApiImplBase { @SneakyThrows private List executeQueries(RpcIndexQuery indexRequest, int totalSize) { - return stubPool.invokeAll(stub -> new QueryTask(stub, indexRequest)) + return channelPool.invokeAll(stub -> new QueryTask(stub, indexRequest)) .stream() .filter(f -> f.state() == Future.State.SUCCESS) .map(Future::resultNow)