Go实现多线程分片下载文件

| |
[不指定 2024/10/21 16:39 | by 刘新修 ]

 我们在下载大文件时,通常会使用多线程下载的方式来加快下载速度。例如常用的多线程下载工具(Gopeed、Aria2、XDM等等),都是通过多线程下载技术充分利用了网络带宽,以提高下载速度。

那么多线程下载是怎么实现的呢?多个线程发送网络请求,是怎么做到同时下载一个文件呢?事实上,借助HTTP协议中的一些机制就可以实现了!

今天我们就通过使用Go语言为例,从了解HTTP请求相关的一些机制开始,实现一个多线程下载的示例。

1,多线程下载原理

事实上,多线程下载的原理很简单,主要的步骤如下:

  • 获取待下载文件大小
  • 每个线程下载文件的一部分
  • 全部下载完成后,拼接为完整文件

实现这些步骤,就涉及到HTTP协议的下列相关机制。

(1) HEAD请求 - 只获取请求头

我们通常发送HTTP请求大多数是GET或者POST类型,发送请求后我们会立即获取响应体,浏览器则会根据响应体的类型来处理内容,例如返回的是text/html就会作为网页显示,返回image/png就会解码为图片等等,响应体的类型由响应头Content-Type标识。当我们下载文件时,事实上也是发送HTTP请求,只不过服务器返回的响应体就是文件本身了!其类型则是application/octet-stream,浏览器也知道这是个文件需要下载。

当然,文件作为响应体通常比起网页、图片要大得多,在多线程下载时,我们就要先获取文件的大小,而不是立即获取文件本身,这时我们就可以向服务器发起HEAD请求而不是GET请求。

服务器收到HEAD请求后,就只会返回对应的响应头,而不会返回响应体,这样我们就可以在下载文件之前,读取响应头中的Content-Length来先获取待下载文件大小。

(2) Range请求头 - 只获取部分响应体

知道了文件大小,我们就需要让每个线程只下载一部分文件,借助HTTP的Range请求头,就可以实现只让服务端返回响应体内容的一部分,而不是返回完整的响应体。

这里我们先来借助书籍《图解HTTP》中对Range请求头的讲解,来学习一下:

XML/HTML代码
  1. Range: bytes=5001-10000  

那么服务端就只会返回响应体的第5001到第10000字节的内容部分,包含第5001和第10000字节,0表示响应体的第一个字节。

这样,在多个线程同时下载文件时,我们在每个线程的请求中使用Range请求头,就可以实现一个线程只下载文件的一部分了!

(3) 为什么多线程下载可以提升速度?

事实上,在我们客户端(下载文件的)和服务端双向网络通信情况都很好的情况下,使用单线程和多线程下载的速度是几乎没有差异的,也就是说能够跑满我们客户端的全部带宽,那么这种情况下我们使用单线程下载反而更能够节省硬件和网络资源。

但是在我们客户端和服务端之间网络波动较大的情况下,例如我们国内从Github下载文件的时候,就会发现多线程下载速度比单线程快得多,反之使用单线程完全无法充分利用我们的网络带宽。

这种现象事实上是因为TCP连接的慢启动机制导致的,众所周知HTTP是基于TCP的协议,每次我们建立HTTP连接时,包括下载文件,都是在传输层基于TCP协议进行传输。TCP慢启动机制是TCP 协议中一种拥塞控制的机制,目的是在开始数据传输时逐步探测网络的容量,避免瞬间发送大量数据而导致网络拥塞。慢启动不是字面意义上的“慢”,而是相对于立即使用最大带宽而言,它会逐渐增加传输速率。

慢启动机制的过程简要概括如下:

  • 一开始建立连接:当一个新的TCP连接建立后,发送方并不知道当前网络的拥塞情况。因此,发送方不会马上发送大量数据,而是会使用慢启动机制来逐步增加数据传输的速率,在TCP中使用阻塞窗口cwnd来限制发送的数据量,也就是说一开始cwnd是非常小的
  • 拥塞窗口增长:在建立连接后,每当接收到一个确认ACK包时,cwnd指数级增长,直到达到网络的带宽限制或者某个拥塞控制的阈值(称为慢启动阈值ssthresh),这个过程会一直持续,直到发送方探测到网络出现了拥塞(比如丢包或者确认延迟变长),或者cwnd达到了某个预定义的慢启动阈值ssthresh
  • 慢启动的终止:慢启动机制会在以下情况终止:
    • 达到慢启动阈值ssthresh:当拥塞窗口cwnd增长到慢启动阈值ssthresh时,慢启动机制停止,此时TCP会进入另一种拥塞控制机制,称为拥塞避免,这时cwnd增长变为线性而非指数级
    • 发生拥塞(如丢包或超时):如果发送方检测到数据包丢失(例如没有收到确认),它会认为网络已经出现拥塞,此时ssthresh会被调整为当前cwnd的一半,然后cwnd会重置为1 MSS,重新进入慢启动阶段

可见TCP连接使用cwnd限制两者发送的数据量的大小,并逐步“试探”两者传输数据速率的上限并增加传输的数据量。

在我们下载文件时,事实上是服务端在向我们发送文件,如果网络波动较大、不稳定,TCP连接机会一直将cwnd限制在一个较小的值,在单位时间内,服务端也无法向我们发送更大的数据量。

此时,如果我们使用多线程下载,和服务端建立多个TCP连接,这样即使每个TCP连接的cwnd较小,所有TCP连接加起来传输的数据量仍然可以占满我们的带宽。

2,Go代码实现

知道了HTTP的上述几个机制,相信大家就知道如何实现一个简单的多线程下载了!我们可以总结主要步骤如下:

  • 发送HEAD类型请求,通过Content-Length请求头获取待下载文件大小
  • 根据给定的线程数量,结合待下载文件大小,确定每个线程下载的范围部分,也就是每个线程的Range请求头字节范围
  • 启动所有线程,使得每个线程下载它们对应的部分文件,并等待全部线程下载完成
  • 合并每个线程下载的部分为最终文件
  • 清理每个线程下载的文件部分

这里分别设计下列类(结构体),用于存放多线程下载时的传入参数和状态量:

 

上述ShardTask类表示一个线程的下载任务,其中会完成一个分片(文件的一部分)的下载请求操作,它有如下作为参数的属性:

  • Url 下载的文件地址
  • Order 分片序号
  • ShardFilePath 这个分片文件的保存路径
  • RangeStartRangeEnd 下载的文件起始范围和结束范围,用于设定Range请求头

此外,还有作为下载状态的属性:

  • DownloadSize 下载任务进行时,这个线程已下载的文件部分大小
  • TaskDone 这个线程的下载任务是否完成

该类的成员方法如下:

  • DoShardGet 执行分片下载任务,在其中会根据RangeStartRangeEnd设定对应的HTTP请求头,发送请求并下载对应的文件部分

然后就是ParallelGetTask类,表示一整个多线程下载任务,其中包含了一个多线程下载任务的参数和状态量,并且实现了多线程下载的每个步骤,它有如下作为参数的属性:

  • Url 文件的下载链接
  • FilePath 文件下载完成后的保存位置
  • Concurrent 下载并发数,即同时下载的分片数量
  • TempFolder 临时分片文件的保存文件夹

此外还有作为状态的属性:

  • TotalSize 待下载文件的总大小
  • ShardTaskList 存储所有分片任务对象指针的列表

该类中的方法主要是分片下载的一些步骤如下:

  • getLength 发送HEAD请求获取Content-Length以获取文件大小,获取后将其设定到TotalSize属性
  • allocateTask 根据给定的线程数和获取到的文件大小,计算每个线程下载的文件内容范围,并创建对应的ShardTask结构体放入ShardTaskList
  • downloadShard 为每一个ShardTask对象创建一个线程(Goroutine)并在新的线程中调用ShardTask对象的下载分片方法,以启动所有线程的下载任务,并通过sync.WaitGroup来等待全部线程完成
  • mergeFile 下载完成后,合并每个分片为最终文件
  • cleanShard 合并完成后,清理下载的每个分片文件
  • printTotalProcess 这是一个附加的辅助方法,用于实时输出下载进度
  • Run 启动整个多线程下载任务,该函数是暴露的公开函数,其中对上述每个步骤函数进行了组织,按顺序调用执行

下面,我们来看一下它们的代码实现。

(1) ShardTask - 一个线程的下载任务

C#代码
  1. package model  
  2.   
  3. import (  
  4.     "bufio"  
  5.     "fmt"  
  6.     "github.com/fatih/color"  
  7.     "io"  
  8.     "net/http"  
  9.     "os"  
  10.     "sync"  
  11. )  
  12.   
  13. // 全局HTTP客户端  
  14. var httpClient = http.Client{  
  15.     Transport: &http.Transport{  
  16.         // 关闭keep-alive确保一个线程就使用一个TCP连接  
  17.         DisableKeepAlives: true,  
  18.     },  
  19. }  
  20.   
  21. // ShardTask 单个分片下载任务的任务参数和状态量  
  22. type ShardTask struct {  
  23.     // 下载链接  
  24.     Url string  
  25.     // 分片序号,从1开始  
  26.     Order int  
  27.     // 这个分片文件的路径  
  28.     ShardFilePath string  
  29.     // 分片的起始范围(字节,包含)  
  30.     RangeStart int64  
  31.     // 分片的结束范围(字节,包含)  
  32.     RangeEnd int64  
  33.     // 已下载的部分(字节)  
  34.     DownloadSize int64  
  35.     // 该任务是否完成  
  36.     TaskDone bool  
  37. }  
  38.   
  39. // NewShardTask 构造函数  
  40. func NewShardTask(url string, order int, shardFilePath string, rangeStart int64, rangeEnd int64) *ShardTask {  
  41.     return &ShardTask{  
  42.         // 设定任务参数  
  43.         Url:           url,  
  44.         Order:         order,  
  45.         ShardFilePath: shardFilePath,  
  46.         RangeStart:    rangeStart,  
  47.         RangeEnd:      rangeEnd,  
  48.         // 初始化状态量  
  49.         DownloadSize: 0,  
  50.         TaskDone:     false,  
  51.     }  
  52. }  
  53.   
  54. // DoShardGet 开始下载这个分片(该方法在goroutine中执行)  
  55. func (task *ShardTask) DoShardGet(waitGroup *sync.WaitGroup) {  
  56.     // 创建文件  
  57.     file, e := os.OpenFile(task.ShardFilePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0755)  
  58.     if e != nil {  
  59.         color.Red("任务%d创建文件失败!", task.Order)  
  60.         color.HiRed("%s", e)  
  61.         return  
  62.     }  
  63.     // 准备请求  
  64.     request, e := http.NewRequest("GET", task.Url, nil)  
  65.     if e != nil {  
  66.         color.Red("任务%d创建请求出错!", task.Order)  
  67.         color.HiRed("%s", e)  
  68.         return  
  69.     }  
  70.     // 设定请求头  
  71.     request.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", task.RangeStart, task.RangeEnd))  
  72.     // 发送请求  
  73.     response, e := httpClient.Do(request)  
  74.     if e != nil {  
  75.         color.Red("任务%d发送下载请求出错!", task.Order)  
  76.         color.HiRed("%s", e)  
  77.         return  
  78.     }  
  79.     // 读取请求体  
  80.     body := response.Body  
  81.     // 读取缓冲区  
  82.     buffer := make([]byte, 8092)  
  83.     // 准备写入文件  
  84.     writer := bufio.NewWriter(file)  
  85.     for {  
  86.         // 读取一次内容至缓冲区  
  87.         readSize, readError := body.Read(buffer)  
  88.         if readError != nil {  
  89.             // 如果读取完毕则退出循环  
  90.             if readError == io.EOF {  
  91.                 break  
  92.             } else {  
  93.                 color.Red("任务%d读取响应错误!", task.Order)  
  94.                 color.HiRed("%s", readError)  
  95.                 return  
  96.             }  
  97.         }  
  98.         // 把缓冲区内容追加至文件  
  99.         _, writeError := writer.Write(buffer[0:readSize])  
  100.         if writeError != nil {  
  101.             color.Red("任务%d写入文件时出现错误!", task.Order)  
  102.             color.HiRed("%s", writeError)  
  103.             return  
  104.         }  
  105.         _ = writer.Flush()  
  106.         // 记录下载进度  
  107.         task.DownloadSize += int64(readSize)  
  108.     }  
  109.     // 关闭全部资源  
  110.     _ = body.Close()  
  111.     _ = file.Close()  
  112.     // 标记任务完成  
  113.     task.TaskDone = true  
  114.     // 使线程组中计数器-1  
  115.     waitGroup.Done()  
  116. }  

构造函数NewShardTask负责完成ShardTask的参数传入和状态量初始化,而DoShardGet方法实现了下载一个文件分片的完整步骤,从创建文件准备写入,到设定请求头,发出请求,最后读取响应体保存到文件。

此外,可见这里的http.Client对象中,我们将其DisableKeepAlives设为了true即关闭keep-alive,这是因为默认情况下Go语言的HTTP客户端会复用TCP连接,即使你多个线程发起请求,也会使用一个TCP连接进行

而多线程下载需要每个线程持有一个单独的TCP连接来达到突破cwnd的限制,因此这里关闭keep-alive实现每个线程发起请求时,使用单独的TCP连接。

(2) ParallelGetTask - 一整个多线程下载任务

C#代码
  1. package model  
  2.   
  3. import (  
  4.     "bufio"  
  5.     "fmt"  
  6.     "gitee.com/swsk33/shard-download-demo/util"  
  7.     "github.com/fatih/color"  
  8.     "io"  
  9.     "net/http"  
  10.     "os"  
  11.     "path/filepath"  
  12.     "strconv"  
  13.     "sync"  
  14.     "time"  
  15. )  
  16.   
  17. // ParallelGetTask 多线程下载任务类,存放一个多线程下载任务的参数和状态量  
  18. type ParallelGetTask struct {  
  19.     // 文件的下载链接  
  20.     Url string  
  21.     // 文件的最终保存位置  
  22.     FilePath string  
  23.     // 下载并发数  
  24.     Concurrent int  
  25.     // 下载的分片临时文件保存文件夹  
  26.     TempFolder string  
  27.     // 下载文件的总大小  
  28.     TotalSize int64  
  29.     // 全部的下载分片任务参数列表  
  30.     ShardTaskList []*ShardTask  
  31. }  
  32.   
  33. // NewParallelGetTask 构造函数  
  34. func NewParallelGetTask(url string, filePath string, concurrent int, tempFolder string) *ParallelGetTask {  
  35.     return &ParallelGetTask{  
  36.         // 参数赋值  
  37.         Url:        url,  
  38.         FilePath:   filePath,  
  39.         Concurrent: concurrent,  
  40.         TempFolder: tempFolder,  
  41.         // 初始化状态量  
  42.         TotalSize:     0,  
  43.         ShardTaskList: make([]*ShardTask, 0),  
  44.     }  
  45. }  
  46.   
  47. // 发送HEAD请求获取待下载文件的大小  
  48. func (task *ParallelGetTask) getLength() error {  
  49.     // 发送请求  
  50.     response, e := http.Head(task.Url)  
  51.     if e != nil {  
  52.         color.Red("发送HEAD请求出错!")  
  53.         return e  
  54.     }  
  55.     // 读取并设定长度  
  56.     task.TotalSize = response.ContentLength  
  57.     return nil  
  58. }  
  59.   
  60. // 根据待下载文件的大小和设定的并发数,创建每个分片任务对象  
  61. func (task *ParallelGetTask) allocateTask() {  
  62.     // 如果并发数大于总大小,则进行调整  
  63.     if int64(task.Concurrent) > task.TotalSize {  
  64.         task.Concurrent = int(task.TotalSize)  
  65.     }  
  66.     // 开始计算每个分片的下载范围  
  67.     eachSize := task.TotalSize / int64(task.Concurrent)  
  68.     // 创建任务对象  
  69.     for i := 0; i < task.Concurrent; i++ {  
  70.         task.ShardTaskList = append(task.ShardTaskList, NewShardTask(task.Url, i+1, filepath.Join(task.TempFolder, strconv.Itoa(i+1)), int64(i)*eachSize, int64(i+1)*eachSize-1))  
  71.     }  
  72.     // 处理末尾部分  
  73.     if task.TotalSize%int64(task.Concurrent) != 0 {  
  74.         task.ShardTaskList[task.Concurrent-1].RangeEnd = task.TotalSize - 1  
  75.     }  
  76. }  
  77.   
  78. // 根据任务列表进行多线程分片下载操作  
  79. func (task *ParallelGetTask) downloadShard() {  
  80.     // 创建线程组  
  81.     waitGroup := &sync.WaitGroup{}  
  82.     // 开始执行全部分片下载线程  
  83.     for _, task := range task.ShardTaskList {  
  84.         go task.DoShardGet(waitGroup)  
  85.         waitGroup.Add(1)  
  86.     }  
  87.     // 等待全部下载完成  
  88.     waitGroup.Wait()  
  89. }  
  90.   
  91. // 下载完成后,合并分片文件  
  92. func (task *ParallelGetTask) mergeFile() error {  
  93.     // 创建目的文件  
  94.     targetFile, e := os.OpenFile(task.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0755)  
  95.     if e != nil {  
  96.         color.Red("创建目标文件出错!")  
  97.         return e  
  98.     }  
  99.     // 创建写入器  
  100.     writer := bufio.NewWriter(targetFile)  
  101.     // 准备读取每个分片文件  
  102.     for _, shard := range task.ShardTaskList {  
  103.         shardFile, e := os.OpenFile(shard.ShardFilePath, os.O_RDONLY, 0755)  
  104.         if e != nil {  
  105.             color.Red("读取分片文件出错!")  
  106.             return e  
  107.         }  
  108.         reader := bufio.NewReader(shardFile)  
  109.         readBuffer := make([]byte, 1024*1024)  
  110.         for {  
  111.             // 读取每个分片文件,一次读取1KB  
  112.             readSize, readError := reader.Read(readBuffer)  
  113.             // 处理结束或错误  
  114.             if readError != nil {  
  115.                 if readError == io.EOF {  
  116.                     break  
  117.                 } else {  
  118.                     color.Red("读取分片文件出错!")  
  119.                     return readError  
  120.                 }  
  121.             }  
  122.             // 写入到最终合并的文件  
  123.             _, writeError := writer.Write(readBuffer[0:readSize])  
  124.             if writeError != nil {  
  125.                 color.Red("写入合并文件出错!")  
  126.                 return writeError  
  127.             }  
  128.             _ = writer.Flush()  
  129.         }  
  130.         // 关闭分片文件资源  
  131.         _ = shardFile.Close()  
  132.     }  
  133.     // 关闭目的文件资源  
  134.     _ = targetFile.Close()  
  135.     return nil  
  136. }  
  137.   
  138. // 删除分片临时文件  
  139. func (task *ParallelGetTask) cleanShard() error {  
  140.     for _, shard := range task.ShardTaskList {  
  141.         e := os.Remove(shard.ShardFilePath)  
  142.         if e != nil {  
  143.             color.Red("删除分片临时文件%s出错!", shard.ShardFilePath)  
  144.             return e  
  145.         }  
  146.     }  
  147.     return nil  
  148. }  
  149.   
  150. // 在一个新线程中,实时输出每个分片的下载进度和总进度  
  151. func (task *ParallelGetTask) printTotalProcess() {  
  152.     go func() {  
  153.         // 上一次统计时的已下载大小,用于计算速度  
  154.         var lastDownloadSize int64 = 0  
  155.         for {  
  156.             // 如果全部任务完成则结束输出,并统计并发数  
  157.             allDone := true  
  158.             // 当前并发数  
  159.             currentTaskCount := 0  
  160.             for _, shardTask := range task.ShardTaskList {  
  161.                 if !shardTask.TaskDone {  
  162.                     allDone = false  
  163.                     currentTaskCount += 1  
  164.                 }  
  165.             }  
  166.             if allDone {  
  167.                 break  
  168.             }  
  169.             // 统计所有分片已下载大小之和  
  170.             var totalDownloadSize int64 = 0  
  171.             for _, shardTask := range task.ShardTaskList {  
  172.                 totalDownloadSize += shardTask.DownloadSize  
  173.             }  
  174.             // 计算速度  
  175.             currentDownload := totalDownloadSize - lastDownloadSize  
  176.             lastDownloadSize = totalDownloadSize  
  177.             speedString := util.ComputeSpeed(currentDownload, 300)  
  178.             // 输出到控制台  
  179.             fmt.Printf("\r当前并发数:%3d 速度:%s 总进度:%3.2f%%", currentTaskCount, speedString, float32(totalDownloadSize)/float32(task.TotalSize)*100)  
  180.             // 等待300ms  
  181.             time.Sleep(300 * time.Millisecond)  
  182.         }  
  183.     }()  
  184. }  
  185.   
  186. // Run 开始执行整个分片多线程下载任务  
  187. func (task *ParallelGetTask) Run() error {  
  188.     // 获取文件大小  
  189.     e := task.getLength()  
  190.     if e != nil {  
  191.         color.Red("%s", e)  
  192.         return e  
  193.     }  
  194.     color.HiYellow("已获取到下载文件大小:%d字节", task.TotalSize)  
  195.     // 分配任务  
  196.     task.allocateTask()  
  197.     color.HiYellow("已完成分片任务分配,共计%d个任务", len(task.ShardTaskList))  
  198.     // 开启进度输出  
  199.     task.printTotalProcess()  
  200.     // 开始下载分片  
  201.     task.downloadShard()  
  202.     color.HiYellow("\n所有分片已下载完成!")  
  203.     // 开始合并文件  
  204.     e = task.mergeFile()  
  205.     if e != nil {  
  206.         color.Red("%s", e)  
  207.         return e  
  208.     }  
  209.     color.HiYellow("合并分片完成!")  
  210.     // 清理临时分片文件  
  211.     e = task.cleanShard()  
  212.     if e != nil {  
  213.         color.Red("%s", e)  
  214.         return e  
  215.     }  
  216.     color.HiYellow("清理分片临时文件完成!")  
  217.     color.Green("分片下载任务完成!")  
  218.     return nil  
  219. }  

上述printTotalProcess函数中,util.ComputeSpeed函数用于计算下载速度并自动转换为可读单位,代码如下:

C#代码
  1. package util  
  2.   
  3. import (  
  4.     "fmt"  
  5.     "math"  
  6. )  
  7.   
  8. // 关于单位的实用工具函数  
  9.   
  10. // ComputeSpeed 计算网络速度  
  11. // size 一段时间内下载的数据大小,单位字节  
  12. // timeElapsed 经过的时间长度,单位毫秒  
  13. // 返回计算得到的网速,会自动换算单位  
  14. func ComputeSpeed(size int64, timeElapsed intstring {  
  15.     bytePerSecond := size / int64(timeElapsed) * 1000  
  16.     if 0 <= bytePerSecond && bytePerSecond <= 1024 {  
  17.         return fmt.Sprintf("%4d Byte/s", bytePerSecond)  
  18.     }  
  19.     if bytePerSecond > 1024 && bytePerSecond <= int64(math.Pow(1024, 2)) {  
  20.         return fmt.Sprintf("%6.2f KB/s", float64(bytePerSecond)/1024)  
  21.     }  
  22.     if bytePerSecond > 1024*1024 && bytePerSecond <= int64(math.Pow(1024, 3)) {  
  23.         return fmt.Sprintf("%6.2f MB/s", float64(bytePerSecond)/math.Pow(1024, 2))  
  24.     }  
  25.     return fmt.Sprintf("%6.2f GB/s", float64(bytePerSecond)/math.Pow(1024, 3))  
  26. }  

 

可见通过构造函数NewParallelGetTask完成参数传递和状态量设定后,其它每个私有函数都对应我们多线程下载中的一个步骤,最后由公开函数Run统筹组织起所有的步骤,完成整个多线程下载任务。

3,实现效果

现在我们在main函数中创建一个ParallelGetTask对象,设定好参数后调用其Run方法即可开始多线程下载文件的任务:

C#代码
  1. package main  
  2.   
  3. import (  
  4.     "gitee.com/swsk33/shard-download-demo/model"  
  5. )  
  6.   
  7. func main() {  
  8.     // 创建任务  
  9.     task := model.NewParallelGetTask(  
  10.         "https://github.com/jgraph/drawio-desktop/releases/download/v24.7.17/draw.io-24.7.17-windows-installer.exe",  
  11.         "C:\\Users\\swsk33\\Downloads\\draw.io.exe",  
  12.         64,  
  13.         "C:\\Users\\swsk33\\Downloads\\temp",  
  14.     )  
  15.     // 执行任务  
  16.     _ = task.Run()  
  17. }  
 
PHP/Java | 评论(0) | 引用(0) | 阅读(12)