为什么不要在Spark中执行这个操作,详解DataFrame collect源码流程

大家好,我是Tim。

相信很多Spark新手经常听到过这个劝告:”当你在编写Spark SQL DataFrame 时尽量不要使用collect()函数”。

因为有时可能会由于这个可有可无的语句,使得整个Spark程序跑着跑着挂掉或者执行超慢。

执行collect()导致程序变慢的原因可能是多个的,这里我们首先给出一些原因和解决办法:

  1. 数据量过大:如果 DataFrame 中的数据量非常大,collect 方法会将所有数据汇集到 Driver 程序中,这可能导致内存不足或网络传输延迟增加,从而导致慢速执行。建议使用其他方式处理大规模数据,如分布式计算或对数据进行采样。

  2. 内存不足:如果 Driver 程序的内存不足以容纳整个 DataFrame,collect 方法可能会导致 OutOfMemoryError。可以通过增加 Driver 程序的内存限制(例如,通过调整 spark.driver.memory 参数)来尝试解决此问题。

  3. 网络传输延迟:当 DataFrame 的数据分布在多个 Executor 节点上时,collect 方法需要将数据从各个节点传输到 Driver 程序,这可能会受到网络传输延迟的影响。可以考虑使用其他操作,如使用分布式计算框架进行并行处理,以减少网络传输开销。

  4. 过多的分区数:如果 DataFrame 的分区数过多,则 collect 方法需要在不同的分区之间进行数据传输,这可能导致较高的网络开销和延迟。可以尝试减少分区数,通过 repartition 或 coalesce 方法将分区数减少到合理的范围。

  5. 数据倾斜:如果 DataFrame 的数据分布不均匀,即某些分区中的数据量远远超过其他分区,这可能导致 collect 方法执行缓慢。可以尝试使用 Spark 的数据倾斜处理技术,如 repartitionByRange 或使用自定义的数据倾斜处理逻辑来解决此问题。

那么collect()函数的执行细节到底是如何呢,你可以简单的描述清楚吗?

接下来,我们从源码的解读进行分析DataFrame collect的执行流程。

Dataset.collect源码解析

// 初始化 Spark: SparkSession 
val Spark = SparkSession.builder().getOrCreate()
// 创建一个 DataFrame 或 Dataset[Row] 
val df = Spark.sql("select sex, count(1) as count from user_table group by sex")
df.collect()

如上所示,这是我们经常运行的SparkSQL程序,当我们执行df.collect()后,Spark程序是如何执行的呢?

当我们点开collect的源码,找到Dataset中collect函数,如下所示 。

def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan)

当我们执行collect函数时,其首先会执行withAction函数,并会分别传入"collect"和queryExecution和传入collectFromPlan进行执行,这是一个柯里化的函数,它会被分别传参并连续的执行。

withAction是一个Dataset的包装函数,在其内部依赖于SQLExecution.withNewExecutionId函数。

private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
    SQLExecution.withNewExecutionId(qe, Some(name)) {
      qe.executedPlan.resetMetrics()
      action(qe.executedPlan)
    }
  }

withNewExecutionId主要的作用是执行”queryExecution”的操作,并跟踪“body”中的所有Spark Job任务,并将他们与执行链接起来。

“body” 中的内容,就是函数主体中括号内的内容,就是将要执行的Spark Job。

可以看到这里只有两行:

  1. qe.executedPlan.resetMetrics() 重置执行计划的metrics信息

  2. action(qe.executedPlan) 将executedPlan传给action函数并执行,而action就是传入的第二个参数collectFromPlan。

那么下面我们就来看看Dataset.collectFromPlan的源码是怎样。

private def collectFromPlan(plan: SparkPlan): Array[T] = {
    val fromRow = resolvedEnc.createDeserializer()
    plan.executeCollect().map(fromRow)
  }

从源码可以看出,其首先执行反序列化程序将SparkSQL row转换,然后执行executeCollect()函数。

def executeCollect(): Array[InternalRow] = {
    val byteArrayRdd = getByteArrayRdd()

    val results = ArrayBuffer[InternalRow]()
    byteArrayRdd.collect().foreach { countAndBytes =>
      decodeUnsafeRows(countAndBytes._2).foreach(results.+=)
    }
    results.toArray
  }

从上面可以看出,在executeCollect中,首先通过getByteArrayRdd将RDD[UnSafeRow]转换为字节数组RDD,然后调用RDD.collect(),最后解析并返回结果。

下面我们再来分析下RDD.collect是如何执行的。

RDD.collect源码解析

def collect(): Array[T] = withScope {
    val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
    Array.concat(results: _*)
  }

RDD是Spark中最基础的抽象,Spark SQL程序最终都会转换为RDD进行执行。上述为RDD中的collect()函数。

在其中执行了sc.runJob,可见其是Action算子,并将合并多个分区的结果。

下面我们看下sc.runJob是如何执行的。

  1. SparkContext.runJob方法

def runJob[T, U: ClassTag](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      resultHandler: (Int, U) => Unit): Unit = {
    if (stopped.get()) {
      throw new IllegalStateException("SparkContext has been shutdown")
    }
    val callSite = getCallSite
    val cleanedFunc = clean(func)
    logInfo("Starting job: " + callSite.shortForm)
    if (conf.getBoolean("spark.logLineage"false)) {
      logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString)
    }
    dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get)
    progressBar.foreach(_.finishAll())
    rdd.doCheckpoint()
  }
  1. 准备 callSite,以便出问题知道是哪里代码出错了

  2. 通过 DAGScheduler.runJob提交作业

  3. progressBar: 命令行里 stage的进度条显示

  4. doCheckpoint 将 RDD的中间和最后结果缓存下来

真正的代码执行在dagScheduler.runJob中,下面我们展示下源码。

  1. DAGScheduler.runJob方法

def runJob[T, U](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      callSite: CallSite,
      resultHandler: (Int, U) => Unit,
      properties: Properties): Unit = {
    val start = System.nanoTime
    val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
    ThreadUtils.awaitReady(waiter.completionFuture, Duration.Inf)
    waiter.completionFuture.value.get match {
      case scala.util.Success(_) =>
        logInfo("Job %d finished: %s, took %f s".format
          (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
      case scala.util.Failure(exception) =>
        logInfo("Job %d failed: %s, took %f s".format
          (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
        // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
        val callerStackTrace = Thread.currentThread().getStackTrace.tail
        exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
        throw exception
    }
  }

对于 DAGScheduler 而言,Stage是其最小的调度单元,其主要功能为:

  • 给Job生成以Stage为调度单位的DAG图

  • 追踪RDD和Stage的输出状态,比如哪些已经被物化,并基于这些信息提供一个最优的调度方案

  • 提交Stage,以TaskSet的形式提交给 TasksetManager

submitJob方法中为其核心的实现方法:

def submitJob[T, U](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      callSite: CallSite,
      resultHandler: (Int, U) => Unit,
      properties: Properties): JobWaiter[U] = {
    // Check to make sure we are not launching a task on a partition that does not exist.
    val maxPartitions = rdd.partitions.length
    partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
      throw new IllegalArgumentException(
        "Attempting to access a non-existent partition: " + p + ". " +
          "Total number of partitions: " + maxPartitions)
    }

    val jobId = nextJobId.getAndIncrement()
    if (partitions.isEmpty) {
      val clonedProperties = Utils.cloneProperties(properties)
      if (sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) == null) {
        clonedProperties.setProperty(SparkContext.SPARK_JOB_DESCRIPTION, callSite.shortForm)
      }
      val time = clock.getTimeMillis()
      listenerBus.post(
        SparkListenerJobStart(jobId, time, Seq.empty, clonedProperties))
      listenerBus.post(
        SparkListenerJobEnd(jobId, time, JobSucceeded))
      // Return immediately if the job is running 0 tasks
      return new JobWaiter[U](this, jobId, 0, resultHandler)
    }

在上述源码中,我们可以看出DAGScheduler 的runJob的是围绕

DAGSchedulerEventProcessLoop 展开的。

这是一个经典的EventLoop使用场景。runJob 方法的执行流程如下:

  1. 提交任务本质上是向 EventLoop 发送一个 JobSubmitted 事件

  2. 通过一个JobWaiter对象等待结果

在 EventLoop 的另一端,onReceive 接收到 JobSubmitted事件,交给成员函数 handleJobSubmitted 处理该事件。

JobWaiter 内部有一个 Promise 对象,它会不停接收到 taskSucceeded,增加计数,知道成功task的数量等于task的总数量,将promise置为成功。

  1. DAGSchedulerEventProcessLoop.onReceive方法

在DAGSchedulerEventProcessLoop的onReceive方法负责接收各类事件,并分发给特定的 handler 函数处理。

private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
    case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) =>
      dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties)

当我们发送JobSubmitted任务时,在接受后会调用 handleJobSubmitted方法,这里就是启动任务的核心实现,限于篇幅我们简单进行介绍:

private[scheduler] def handleJobSubmitted(jobId: Int,
      finalRDD: RDD[_],
      func: (TaskContext, Iterator[_])
 
=> _,
      partitions: Array[Int],
      callSite: CallSite,
      listener: JobListener,
      properties: Properties): Unit = {
    var finalStage: ResultStage = null
    try {
      // 1. 创建Stage 
      finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite)
    } catch {
      ...
    }
    ...
    // 创建Job
    val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)
    // 注册job和stage
    jobIdToActiveJob(jobId) = job
    activeJobs += job
    finalStage.setActiveJob(job)
    val stageIds = jobIdToStageIds(jobId).toArray
    val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
    // 提交Job
    listenerBus.post(
      SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
    submitStage(finalStage)
  }

从上述的源码可以看出,其主要做了以下几个事情:

  1. 创建Stage:递归式地创建,先创建parent stage, 并注册

  2. 创建Job,并注册

  3. 按拓扑顺序提交Stage

由于Stage Graph是有向无环图,因此Stage的创建和提交应遵循拓扑顺序。为了证明这一点,我们将对DAGSCheduler.createResultStage做一些解释。

  1. DAGScheduler.createResultStage方

在 SparkPlan 对象调用 execute 时,会递归地生成 RDD,从而构成了 RDD Lineage Graph,它是一个有向无环图。

那么在 RDD Lineage 上如何切分 stage 呢?

有两种类型的阶段:

  • ShuffleMapStage:执行 DAG 中为 shuffle 生成数据的中间阶段

  • ResultStage:在 RDD 的某些分区上应用函数来计算操作的结果

因此,一个 ResultStage 依赖于一个或多个 ShuffleMapStage,而一个 ShuffleMapStage 依赖于任意数量的 ShuffleMapStages 或 None。

依赖关系主要有两类:

  • ShuffleDependency:发生在JOIN、GROUP-BY、REPARTITION等中

  • NarrowDependency:出现在 SELECT、WHERE、COALESCE 等中

ShuffleDependency 将两个连续的 RDD 分成两个阶段。

总结

  1. 当调用collect()函数时,首先会执行DataFrame的collect函数,其内部首先调用executeCollect()函数,该函数将DataFrame转换为RDD,并将其结果收集到Driver端。

  2. 在RDD的collect()函数中,实际上是通过sc.runJob()方法来执行作业的。sc.runJob()方法会将作业提交给DAGScheduler,并最终由DAGScheduler负责作业的调度和执行。

  3. DAGScheduler会根据RDD Lineage图构建作业的Stage Graph,并按照拓扑顺序依次提交各个Stage。在每个Stage中,会将作业分成多个TaskSet,并提交给TaskScheduler进行具体的任务调度和执行。

  4. 在整个过程中,会不断地监控作业的执行情况,并将相关的事件发送给SparkListener进行监听和处理。一旦作业执行完成,会将执行结果返回给collect()函数。

总的来说,collect()函数的执行流程涉及到DataFrame、RDD、DAGScheduler、TaskScheduler等多个组件的协同工作,其中涉及到了作业的划分、调度和执行过程,基本上包含了Spark SQL执行算子的基本过程。



如果觉得这篇文章对你有所帮助,
请点一下或者在看,是对我的肯定和支持~


请使用浏览器的分享功能分享到微信等