Spark Source Learning: Built-in RPC Framework (3)

Posted May 25, 2020

RPC Client Factory TransportClientFactory

TransportClientFactory is the factory class that creates TransportClient instances. An instance of TransportClientFactory is created through TransportContext's createClientFactory method:

  /**
   * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning
   * a new Client. Bootstraps will be executed synchronously, and must run successfully in order
   * to create a Client.
   */
  public TransportClientFactory createClientFactory(List<TransportClientBootstrap> bootstraps) {
    return new TransportClientFactory(this, bootstraps);
  }

  public TransportClientFactory createClientFactory() {
    return createClientFactory(Lists.<TransportClientBootstrap>newArrayList());
  }

As you can see, TransportContext has two overloaded createClientFactory methods, and both end up passing two arguments to the TransportClientFactory constructor: the TransportContext and a list of TransportClientBootstrap. The implementation of the TransportClientFactory constructor is shown in the following code.

  public TransportClientFactory(
      TransportContext context,
      List<TransportClientBootstrap> clientBootstraps) {
    this.context = Preconditions.checkNotNull(context);
    this.conf = context.getConf();
    this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
    this.connectionPool = new ConcurrentHashMap<>();
    this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
    this.rand = new Random();

    IOMode ioMode = IOMode.valueOf(conf.ioMode());
    this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
    // TODO: Make thread pool name configurable.
    this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client");
    this.pooledAllocator = NettyUtils.createPooledByteBufAllocator(
      conf.preferDirectBufs(), false /* allowCache */, conf.clientThreads());
  }

The fields initialized in the TransportClientFactory constructor are as follows:

context: a reference to the TransportContext passed in as a parameter.

conf: a reference to TransportConf, obtained here by calling TransportContext's getConf method.

clientBootstraps: the list of TransportClientBootstrap passed in as a parameter.

connectionPool: a cache that maps each socket address to its connection pool, ClientPool.
The data structure of connectionPool is somewhat complicated, so to make it easier to follow it is illustrated in the figure.

[Figure: data structure of connectionPool]
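The structure the figure describes can also be approximated in code. The following is a minimal, self-contained sketch of the connectionPool layout, assuming the shape described in this section: a ConcurrentHashMap keyed by an unresolved InetSocketAddress, whose values are ClientPool objects holding a client array and a parallel lock array, each of length numConnectionsPerPeer. FakeClient and ConnectionPoolSketch are hypothetical names, and computeIfAbsent is used here in place of the putIfAbsent-then-get sequence that Spark's createClient uses.

```java
import java.net.InetSocketAddress;
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical stand-in for TransportClient, used only to show the layout. */
final class FakeClient {
  final String id;

  FakeClient(String id) { this.id = id; }
}

/** Sketch of the connectionPool structure: address -> ClientPool(clients[], locks[]). */
final class ConnectionPoolSketch {
  /** One ClientPool per remote address; clients[i] is guarded by locks[i]. */
  static final class ClientPool {
    final FakeClient[] clients;
    final Object[] locks;

    ClientPool(int size) {
      clients = new FakeClient[size];
      locks = new Object[size];
      for (int i = 0; i < size; i++) {
        locks[i] = new Object();
      }
    }
  }

  final ConcurrentHashMap<InetSocketAddress, ClientPool> connectionPool = new ConcurrentHashMap<>();
  final int numConnectionsPerPeer;

  ConnectionPoolSketch(int numConnectionsPerPeer) {
    this.numConnectionsPerPeer = numConnectionsPerPeer;
  }

  /** Returns the ClientPool for host:port, creating it on first use (no DNS lookup). */
  ClientPool poolFor(String host, int port) {
    InetSocketAddress key = InetSocketAddress.createUnresolved(host, port);
    return connectionPool.computeIfAbsent(key, k -> new ClientPool(numConnectionsPerPeer));
  }
}
```

Because unresolved addresses compare by host name and port, repeated calls for the same host:port reach the same ClientPool without triggering DNS resolution.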

numConnectionsPerPeer: the value of the property whose key is "spark." + module name + ".io.numConnectionsPerPeer", read from TransportConf. It specifies the number of connections between peer nodes. The module name here is the module field of TransportConf. Many Spark components are built on the RPC framework, and they are distinguished by module name. For example, the key for the RPC module is "spark.rpc.io.numConnectionsPerPeer".

// TransportConf's getConfKey method builds the full property key from a suffix
private String getConfKey(String suffix) {
  return "spark." + module + "." + suffix;
}
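As a small illustration of how the module name scopes each property, the sketch below reuses the same key composition and adds a hypothetical getInt helper that reads from a plain Map. The Map, the helper, and the ConfKeySketch class are assumptions for illustration only; in Spark the values come from TransportConf's own configuration source.

```java
import java.util.Map;

/** Hypothetical sketch of module-scoped keys, mirroring getConfKey above. */
final class ConfKeySketch {
  private final String module;

  ConfKeySketch(String module) {
    this.module = module;
  }

  // Same composition as TransportConf.getConfKey: "spark." + module + "." + suffix
  String getConfKey(String suffix) {
    return "spark." + module + "." + suffix;
  }

  // Hypothetical lookup: read an int property for this module, falling back to a default.
  int getInt(Map<String, String> props, String suffix, int defaultValue) {
    String value = props.get(getConfKey(suffix));
    return value == null ? defaultValue : Integer.parseInt(value);
  }
}
```

For example, new ConfKeySketch("rpc").getConfKey("io.numConnectionsPerPeer") returns "spark.rpc.io.numConnectionsPerPeer", while a "shuffle" module would read "spark.shuffle.io.numConnectionsPerPeer", so the same suffix can be tuned independently per module.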

rand: a random number generator used to pick one of the TransportClients cached in the ClientPool for a given socket address, which balances load across those connections.

ioMode: the IO mode, that is, the value of the property whose key is "spark." + module name + ".io.mode", read from TransportConf. The default is NIO; Spark also supports EPOLL.

socketChannelClass: the class used to create the client Channel, chosen according to ioMode. The default is NioSocketChannel; in EPOLL mode it is EpollSocketChannel.

workerGroup: in Netty, a client needs only a worker EventLoopGroup (there is no separate boss group as on the server side), so only a workerGroup is created here. In the default NIO mode its actual type is NioEventLoopGroup (EpollEventLoopGroup in EPOLL mode).

pooledAllocator: a pooled ByteBuf allocator with thread-local caching disabled (allowCache is false).
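The ioMode-driven choices above can be sketched as a simple mapping. This is not Spark's NettyUtils code: the IoModeSketch class is hypothetical, it returns Netty class names as strings so that it compiles without Netty on the classpath, and upper-casing the configured value before IOMode.valueOf is an assumption about how the string is normalized.

```java
import java.util.Locale;

/** Hypothetical sketch of how an io.mode value could select client-side Netty classes. */
final class IoModeSketch {
  enum IOMode { NIO, EPOLL }

  // Mirrors IOMode.valueOf(conf.ioMode()); upper-casing the input is an assumption.
  static IOMode parse(String ioMode) {
    return IOMode.valueOf(ioMode.toUpperCase(Locale.ROOT));
  }

  // Class names only, so the sketch compiles without Netty.
  static String clientChannelClass(IOMode mode) {
    switch (mode) {
      case NIO: return "io.netty.channel.socket.nio.NioSocketChannel";
      case EPOLL: return "io.netty.channel.epoll.EpollSocketChannel";
      default: throw new IllegalArgumentException("Unknown io mode: " + mode);
    }
  }

  // The event loop group for workerGroup must match the channel class chosen above.
  static String eventLoopGroupClass(IOMode mode) {
    switch (mode) {
      case NIO: return "io.netty.channel.nio.NioEventLoopGroup";
      case EPOLL: return "io.netty.channel.epoll.EpollEventLoopGroup";
      default: throw new IllegalArgumentException("Unknown io mode: " + mode);
    }
  }
}
```

Deriving both classes from a single IOMode keeps socketChannelClass and workerGroup consistent: a Netty Epoll channel cannot be registered on an NIO event loop, and vice versa.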

Client Bootstrap TransportClientBootstrap

The clientBootstraps field of TransportClientFactory is a list of TransportClientBootstrap. A TransportClientBootstrap is a client bootstrap that is executed on a TransportClient; it mainly performs initialization work (such as authentication and encryption) when a connection is established. The operations a TransportClientBootstrap performs are often expensive, but fortunately an established connection can be reused. The interface definition of TransportClientBootstrap is shown in Listing 3-10:

import io.netty.channel.Channel;

/**
 * A bootstrap which is executed on a TransportClient before it is returned to the user.
 * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per-
 * connection basis.
 *
 * Since connections (and TransportClients) are reused as much as possible, it is generally
 * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with
 * the JVM itself.
 */
public interface TransportClientBootstrap {
  /** Performs the bootstrapping operation, throwing an exception on failure. */
  void doBootstrap(TransportClient client, Channel channel) throws RuntimeException;
}

TransportClientBootstrap has two implementation classes: EncryptionDisablerBootstrap and SaslClientBootstrap.
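To show what a bootstrap does and how a failure is handled, here is a hypothetical, simplified sketch. StubClient, ClientBootstrap, and TokenBootstrap are invented stand-ins: the real interface also receives a Netty Channel, and real implementations such as SaslClientBootstrap perform a SASL exchange rather than sending a plain token. runBootstraps follows the pattern used when a new client is created: run each bootstrap in order, and on failure close the client and rethrow.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical, simplified stand-ins for TransportClient and TransportClientBootstrap. */
final class BootstrapSketch {
  /** Minimal client stub that records sent messages and whether it was closed. */
  static final class StubClient {
    final List<String> sent = new ArrayList<>();
    boolean closed;

    void sendSync(String message) { sent.add(message); }

    void close() { closed = true; }
  }

  /** Same shape as TransportClientBootstrap, minus the Netty Channel parameter. */
  interface ClientBootstrap {
    void doBootstrap(StubClient client) throws RuntimeException;
  }

  /** Hypothetical bootstrap that sends a one-time auth token on each new connection. */
  static final class TokenBootstrap implements ClientBootstrap {
    private final String token;

    TokenBootstrap(String token) { this.token = token; }

    @Override
    public void doBootstrap(StubClient client) {
      if (token == null || token.isEmpty()) {
        throw new IllegalStateException("missing auth token");
      }
      client.sendSync("AUTH " + token);
    }
  }

  /** Runs every bootstrap in order; on failure closes the client and rethrows. */
  static StubClient runBootstraps(StubClient client, List<ClientBootstrap> bootstraps) {
    try {
      for (ClientBootstrap bootstrap : bootstraps) {
        bootstrap.doBootstrap(client);
      }
    } catch (RuntimeException e) {
      client.close();
      throw e;
    }
    return client;
  }
}
```

Since the bootstrap runs once per connection and connections are cached, its cost is paid only when a new connection is created, not on every request.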

Creating the RPC Client TransportClient

With TransportClientFactory, Spark's modules can create RPC clients (TransportClient). Each TransportClient instance communicates with only one remote RPC service, so a component that needs to talk to multiple RPC services must hold multiple TransportClient instances. The method that creates a TransportClient is shown in the following code (in practice, it first tries to get a TransportClient from the cache).

  /**
   * Create a {@link TransportClient} connecting to the given remote host/port.
   *
   * We maintain an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer)
   * and randomly pick one to use. If no client was previously created in the randomly selected
   * spot, this function creates a new client and places it there.
   *
   * Prior to the creation of a new TransportClient, we will execute all
   * {@link TransportClientBootstrap}s that are registered with this factory.
   *
   * This blocks until a connection is successfully established and fully bootstrapped.
   *
   * Concurrency: This method is safe to call from multiple threads.
   */
  public TransportClient createClient(String remoteHost, int remotePort) throws IOException {
    // Get connection from the connection pool first.
    // If it is not found or not active, create a new one.
    // Use unresolved address here to avoid DNS resolution each time we create a client.
    final InetSocketAddress unresolvedAddress =
      InetSocketAddress.createUnresolved(remoteHost, remotePort);

    // Create the ClientPool if we don't have it yet.
    ClientPool clientPool = connectionPool.get(unresolvedAddress);
    if (clientPool == null) {
      connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer));
      clientPool = connectionPool.get(unresolvedAddress);
    }

    int clientIndex = rand.nextInt(numConnectionsPerPeer);
    TransportClient cachedClient = clientPool.clients[clientIndex];

    if (cachedClient != null && cachedClient.isActive()) {
      // Make sure that the channel will not timeout by updating the last use time of the
      // handler. Then check that the client is still alive, in case it timed out before
      // this code was able to update things.
      TransportChannelHandler handler = cachedClient.getChannel().pipeline()
        .get(TransportChannelHandler.class);
      synchronized (handler) {
        handler.getResponseHandler().updateTimeOfLastRequest();
      }

      if (cachedClient.isActive()) {
        logger.trace("Returning cached connection to {}: {}",
          cachedClient.getSocketAddress(), cachedClient);
        return cachedClient;
      }
    }

    // If we reach here, we don't have an existing connection open. Let's create a new one.
    // Multiple threads might race here to create new connections. Keep only one of them active.
    final long preResolveHost = System.nanoTime();
    final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort);
    final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000;
    if (hostResolveTimeMs > 2000) {
      logger.warn("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs);
    } else {
      logger.trace("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs);
    }

    synchronized (clientPool.locks[clientIndex]) {
      cachedClient = clientPool.clients[clientIndex];

      if (cachedClient != null) {
        if (cachedClient.isActive()) {
          logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient);
          return cachedClient;
        } else {
          logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress);
        }
      }
      clientPool.clients[clientIndex] = createClient(resolvedAddress);
      return clientPool.clients[clientIndex];
    }
  }

From the code, the steps to create TransportClient are as follows.

1) Call InetSocketAddress's static method createUnresolved to build an InetSocketAddress (creating it this way avoids unnecessary DNS resolution when a TransportClient is already cached). Then get the ClientPool for this address from connectionPool; if there is none, create a new ClientPool and put it into connectionPool.

2) Randomly pick a TransportClient from the ClientPool, using an index in the range [0, numConnectionsPerPeer), where numConnectionsPerPeer is configured by the property "spark." + module name + ".io.numConnectionsPerPeer".

3) If the slot at the randomly chosen index of ClientPool's clients array holds no TransportClient, or the TransportClient there is not active, go to step 5; otherwise, check the TransportClient in step 4.

4) Update the last-request time of the TransportChannelHandler in the TransportClient's channel pipeline so that the channel does not time out, then check again whether the TransportClient is still active (it may have timed out before the update). If it is, return it to the caller; otherwise, continue with step 5.

5) Since no usable TransportClient is cached, call the InetSocketAddress constructor to create an InetSocketAddress (creating it directly through the constructor performs DNS resolution; if resolution takes more than 2000 ms, a warning is logged). Multiple threads may race at this step: because nothing here is synchronized, several threads can reach this point at the same time, each find that no TransportClient is available in the cache, and each create an InetSocketAddress through the constructor.

6) If the race from step 5 were not handled, there would be a thread-safety problem, and this is where ClientPool's locks array comes in. The lock objects in the locks array correspond one-to-one with the TransportClient slots in the clients array, so each thread synchronizes on the lock at its randomly chosen index. Even though the threads raced earlier, only one of them can enter the critical section for a given slot at a time. Inside the critical section, the first thread calls the overloaded createClient method to create a TransportClient and stores it in ClientPool's clients array. Once that thread leaves the critical section, the other threads can enter; they find an active TransportClient already in the clients array, so they use it directly instead of creating another one.
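The caching logic in steps 1 through 6 can be reduced to the sketch below. It is a simplified model, not Spark's code: Conn and ClientCacheSketch are hypothetical, the address is a plain String, a connector function stands in for the overloaded createClient, and the handler timestamp update from step 4 and the DNS resolution from step 5 are omitted. What it keeps is the essential pattern: a random slot, an unlocked fast-path read, and a per-slot lock with a re-check, so that at most one thread creates a client for a given slot.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

/** Hypothetical connection type; only liveness matters for the caching logic. */
final class Conn {
  final int serial;
  volatile boolean active = true;

  Conn(int serial) { this.serial = serial; }

  boolean isActive() { return active; }
}

/** Simplified model of createClient's cache: random slot, fast path, per-slot lock. */
final class ClientCacheSketch {
  private static final class Pool {
    final Conn[] clients;
    final Object[] locks;

    Pool(int size) {
      clients = new Conn[size];
      locks = new Object[size];
      for (int i = 0; i < size; i++) {
        locks[i] = new Object();
      }
    }
  }

  private final ConcurrentHashMap<String, Pool> pools = new ConcurrentHashMap<>();
  private final int numConnectionsPerPeer;
  private final Function<String, Conn> connector; // stands in for createClient(resolvedAddress)
  final AtomicInteger created = new AtomicInteger();

  ClientCacheSketch(int numConnectionsPerPeer, Function<String, Conn> connector) {
    this.numConnectionsPerPeer = numConnectionsPerPeer;
    this.connector = connector;
  }

  Conn get(String address) {
    // Step 1: find or create the pool for this address.
    Pool pool = pools.computeIfAbsent(address, a -> new Pool(numConnectionsPerPeer));

    // Step 2: pick a random slot.
    int index = ThreadLocalRandom.current().nextInt(numConnectionsPerPeer);

    // Step 3: unlocked fast path; a stale read only sends us to the locked path below.
    Conn cached = pool.clients[index];
    if (cached != null && cached.isActive()) {
      return cached;
    }

    // Step 6: one lock per slot; re-check because another thread may have filled it.
    synchronized (pool.locks[index]) {
      cached = pool.clients[index];
      if (cached != null && cached.isActive()) {
        return cached;
      }
      Conn fresh = connector.apply(address);
      created.incrementAndGet();
      pool.clients[index] = fresh;
      return fresh;
    }
  }
}
```

The unlocked read may observe a stale value, but that is harmless here: a null or inactive client only sends the caller into the synchronized block, where the slot is re-read before anything is created. Locking per slot rather than per address also lets threads that picked different slots connect in parallel.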

This whole flow only solves the thread-safety problems of using the TransportClient cache in createClient; it does not involve how a TransportClient is actually built. The creation of a TransportClient is implemented in the overloaded createClient method:

  /** Create a completely new {@link TransportClient} to the remote address. */
  private TransportClient createClient(InetSocketAddress address) throws IOException {
    logger.debug("Creating new connection to {}", address);

    Bootstrap bootstrap = new Bootstrap();
    bootstrap.group(workerGroup)
      .channel(socketChannelClass)
      // Disable Nagle's Algorithm since we don't want packets to wait
      .option(ChannelOption.TCP_NODELAY, true)
      .option(ChannelOption.SO_KEEPALIVE, true)
      .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
      .option(ChannelOption.ALLOCATOR, pooledAllocator);

    final AtomicReference<TransportClient> clientRef = new AtomicReference<>();
    final AtomicReference<Channel> channelRef = new AtomicReference<>();

    bootstrap.handler(new ChannelInitializer<SocketChannel>() {
      @Override
      public void initChannel(SocketChannel ch) {
        TransportChannelHandler clientHandler = context.initializePipeline(ch);
        clientRef.set(clientHandler.getClient());
        channelRef.set(ch);
      }
    });

    // Connect to the remote server
    long preConnect = System.nanoTime();
    ChannelFuture cf = bootstrap.connect(address);
    if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
      throw new IOException(
        String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
    } else if (cf.cause() != null) {
      throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
    }

    TransportClient client = clientRef.get();
    Channel channel = channelRef.get();
    assert client != null : "Channel future completed successfully with null client";

    // Execute any client bootstraps synchronously before marking the Client as successful.
    long preBootstrap = System.nanoTime();
    logger.debug("Connection to {} successful, running bootstraps...", address);
    try {
      for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
        clientBootstrap.doBootstrap(client, channel);
      }
    } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
      long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
      logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e);
      client.close();
      throw Throwables.propagate(e);
    }
    long postBootstrap = System.nanoTime();

    logger.info("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
      address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000);

    return client;
  }

From the code, the overloaded createClient method builds a new TransportClient in the following steps.

1) Build and configure the root bootstrap, Netty's Bootstrap: set the worker group and the channel class, and the TCP_NODELAY, SO_KEEPALIVE, CONNECT_TIMEOUT_MILLIS, and ALLOCATOR options.

2) Register a pipeline initialization callback (a ChannelInitializer) on the Bootstrap. The callback calls TransportContext's initializePipeline method to initialize the Channel's pipeline.

3) Use the Bootstrap to connect to the remote server. When the connection succeeds and the pipeline is initialized, the callback stores the TransportClient and the Channel in the atomic references clientRef and channelRef, respectively. If the connection does not complete within the connection timeout, or fails, an IOException is thrown.

4) Run the client bootstraps on the TransportClient: call doBootstrap on each TransportClientBootstrap in TransportClientFactory's clientBootstraps list, synchronously. If any bootstrap throws, the client is closed and the exception is propagated.

5) Return this TransportClient object.
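The AtomicReference handoff in step 3 exists because the ChannelInitializer callback runs on a Netty event-loop thread, not on the thread that called createClient. The sketch below models that with a single-thread ExecutorService standing in for the event loop; ConnectHandoffSketch, the simulated connect delay, and the String "client" are all hypothetical. Unlike Netty's awaitUninterruptibly, Future.get reacts to interruption, so the sketch turns an interrupt into an IOException.

```java
import java.io.IOException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical model of the clientRef handoff between the event loop and the caller. */
final class ConnectHandoffSketch {
  static String connect(ExecutorService eventLoop, String address,
                        long connectTimeoutMs, long simulatedConnectMs) throws IOException {
    AtomicReference<String> clientRef = new AtomicReference<>();

    // Runs on the "event loop" thread, like ChannelInitializer.initChannel.
    Future<?> cf = eventLoop.submit(() -> {
      try {
        Thread.sleep(simulatedConnectMs); // pretend to establish the TCP connection
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        return; // cancelled: never publish a client
      }
      clientRef.set("client-for-" + address); // like clientRef.set(clientHandler.getClient())
    });

    // The caller blocks with a timeout, like cf.awaitUninterruptibly(connectionTimeoutMs).
    try {
      cf.get(connectTimeoutMs, TimeUnit.MILLISECONDS);
    } catch (TimeoutException e) {
      cf.cancel(true);
      throw new IOException(
        String.format("Connecting to %s timed out (%s ms)", address, connectTimeoutMs));
    } catch (ExecutionException e) {
      throw new IOException(String.format("Failed to connect to %s", address), e.getCause());
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IOException("Interrupted while connecting to " + address, e);
    }

    // Task completion happens-before get() returns, so the value set above is visible here.
    String client = clientRef.get();
    if (client == null) {
      throw new IOException("Connection completed with a null client");
    }
    return client;
  }
}
```

An anonymous class or lambda can only capture effectively final locals, so a final AtomicReference gives the callback a mutable holder it can write into, while also making the written value safely visible to the waiting thread.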

This blog post is based on the book "The Art of Spark Kernel Design: Architecture Design and Implementation".