python多线程方法详解

处理多个数据和多文件时,使用for循环的速度非常慢,此时需要用多线程来加速运行进度,常用的模块为multiprocess和joblib,下面对两种包我常用的方法进行说明。

1、模块安装

pip install multiprocessing
pip install joblib

2、以分块计算NDVI为例

首先导入需要的包

import numpy as np
from osgeo import gdal
import time
from multiprocessing import cpu_count
from multiprocessing import Pool
from joblib import Parallel, delayed

定义GdalUtil类,以读取遥感数据

class GdalUtil:
    def __init__(self):
        pass
    @staticmethod
    def read_file(raster_file, read_band=None):
        """读取栅格数据"""
        # 注册栅格驱动
        gdal.AllRegister()
        gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')
        # 打开输入图像
        dataset = gdal.Open(raster_file, gdal.GA_ReadOnly)
        if dataset == None:
            print('打开图像{0} 失败.\n', raster_file)
        # 列
        raster_width = dataset.RasterXSize
        # 行
        raster_height = dataset.RasterYSize
        # 读取数据
        if read_band == None:
            data_array = dataset.ReadAsArray(0, 0, raster_width, raster_height)
        else:
            band = dataset.GetRasterBand(read_band)
            data_array = band.ReadAsArray(0, 0, raster_width, raster_height)
        return data_array

    @staticmethod
    def read_block_data(dataset, band_num, cols_read, rows_read, start_col=0, start_row=0):
        band = dataset.GetRasterBand(band_num)
        res_data = band.ReadAsArray(start_col, start_row, cols_read, rows_read)
        return res_data

    @staticmethod
    def get_raster_band(raster_path):
        # 注册栅格驱动
        gdal.AllRegister()
        gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')
        # 打开输入图像
        dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
        if dataset == None:
            print('打开图像{0} 失败.\n', raster_path)
        raster_band = dataset.RasterCount
        return raster_band

    @staticmethod
    def get_file_size(raster_path):
        """获取栅格仿射变换参数"""
        # 注册栅格驱动
        gdal.AllRegister()
        gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')

        # 打开输入图像
        dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
        if dataset == None:
            print('打开图像{0} 失败.\n', raster_path)
        # 列
        raster_width = dataset.RasterXSize
        # 行
        raster_height = dataset.RasterYSize
        return raster_width, raster_height

    @staticmethod
    def get_file_geotransform(raster_path):
        """获取栅格仿射变换参数"""
        # 注册栅格驱动
        gdal.AllRegister()
        gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')

        # 打开输入图像
        dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
        if dataset == None:
            print('打开图像{0} 失败.\n', raster_path)

        # 获取输入图像仿射变换参数
        input_geotransform = dataset.GetGeoTransform()
        return input_geotransform

    @staticmethod
    def get_file_proj(raster_path):
        """获取栅格图像空间参考"""
        # 注册栅格驱动
        gdal.AllRegister()
        gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')

        # 打开输入图像
        dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
        if dataset == None:
            print('打开图像{0} 失败.\n', raster_path)

        # 获取输入图像空间参考
        input_project = dataset.GetProjection()
        return input_project

    @staticmethod
    def write_file(dataset, geotransform, project, output_path, out_format='GTiff', eType=gdal.GDT_Float32):
        """写入栅格"""
        if np.ndim(dataset) == 3:
            out_band, out_rows, out_cols = dataset.shape
        else:
            out_band = 1
            out_rows, out_cols = dataset.shape

        # 创建指定输出格式的驱动
        out_driver = gdal.GetDriverByName(out_format)
        if out_driver == None:
            print('格式%s 不支持Creat()方法.\n', out_format)
            return

        out_dataset = out_driver.Create(output_path, xsize=out_cols,
                                        ysize=out_rows, bands=out_band,
                                        eType=eType)
        # 设置输出图像的仿射参数
        out_dataset.SetGeoTransform(geotransform)

        # 设置输出图像的投影参数
        out_dataset.SetProjection(project)

        # 写出数据
        if out_band == 1:
            out_dataset.GetRasterBand(1).WriteArray(dataset)
        else:
            for i in range(out_band):
                out_dataset.GetRasterBand(i + 1).WriteArray(dataset[i])
        del out_dataset

定义计算NDVI的函数

def cal_ndvi(multi):
    '''
    计算高分NDVI
    :param multi:格式为列表,依次包含[遥感文件路径,开始行号,开始列号,待读的行数,待读的列数]
    :return: NDVI数组
    '''
    input_file, start_col, start_row, cols_step, rows_step = multi
    dataset = gdal.Open(input_file, gdal.GA_ReadOnly)
    nir_data = GdalUtil.read_block_data(dataset, 4, cols_step, rows_step, start_col=start_col, start_row=start_row)
    red_data = GdalUtil.read_block_data(dataset, 3, cols_step, rows_step, start_col=start_col, start_row=start_row)
    ndvi = (nir_data - red_data) / (nir_data + red_data)
    ndvi[(ndvi > 1.5) | (ndvi < -1)] = 0
    return ndvi
定义主函数
if __name__ == "__main__":
    input_file = r'D:\originalData\GF1\namucuo2021.tif'
    output_file = r'D:\originalData\GF1\namucuo2021_ndvi.tif'
    method = 'joblib'
    # method = 'multiprocessing'
    # 获取文件主要信息
    raster_cols, raster_rows = GdalUtil.get_file_size(input_file)
    geotransform = GdalUtil.get_file_geotransform(input_file)
    project = GdalUtil.get_file_proj(input_file)
    # 定义分块大小
    rows_block_size = 50
    cols_block_size = 50
    multi = []
    for j in range(0, raster_rows, rows_block_size):
        for i in range(0, raster_cols, cols_block_size):
            if j + rows_block_size < raster_rows:
                rows_step = rows_block_size
            else:
                rows_step = raster_rows - j
            # 数据横向步长
            if i + cols_block_size < raster_cols:
                cols_step = cols_block_size
            else:
                cols_step = raster_cols - i
            temp_multi = [input_file, i, j, cols_step, rows_step]
            multi.append(temp_multi)

    t1 = time.time()
    if method == 'multiprocessing':
        # multiprocessing方法
        pool = Pool(processes=cpu_count()-1)
        # 注意map函数中传入的参数应该是可迭代对象,如list;返回值为list
        res = pool.map(cal_ndvi, multi)
        pool.close()
        pool.join()
    else:
        # joblib方法
        res = Parallel(n_jobs=-1)(delayed(cal_ndvi)(input_list) for input_list in multi)

    t2 = time.time()
    print("Total time:" + (t2 - t1).__str__())

    # 将multiprocessing中的结果提取出来,放回对应的矩阵位置中
    out_data = np.zeros([raster_rows, raster_cols], dtype='float')
    for result, input_multi in zip(res, multi):
        start_col = input_multi[1]
        start_row = input_multi[2]
        cols_step = input_multi[3]
        rows_step = input_multi[4]
        out_data[start_row:start_row + rows_step, start_col:start_col + cols_step] = result

    GdalUtil.write_file(out_data, geotransform, project, output_file)

双重for循环时,两层for循环都使用multiprocessing时会报错,这时可以外层for循环使用joblib方法,内层for循环改为multiprocessing方法,不会报错

到此这篇关于python多线程方法详解的文章就介绍到这了,更多相关python多线程内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • python多线程中的定时器你了解吗

    定时器 定时器:每隔一段时间启动一次线程 threading.Timer创建的是一个线程!定时器基本上都是在线程中执行 创建定时器: threading.Timer(interval, function, args=None, kwargs=None) interval — 定时器间隔,间隔多少秒之后启动定时器任务(单位:秒):function — 线程函数:args — 线程参数,可以传递元组类型数据,默认为空(缺省参数): kwargs — 线程参数,可以传递字典类型数据,默认为空(缺省参数

  • 深入了解Python的多线程基础

    目录 线程 多线程 Python多线程 创建线程 GIL锁 线程池 总结 线程 线程(Thread),有时也被称为轻量级进程(Lightweight Process,LWP),是操作系统独⽴调度和分派的基本单位,本质上就是一串指令的集合. ⼀个标准的线程由线程id.当前指令指针(PC),寄存器集合和堆栈组成,它是进程中的⼀个实体,线程本身不拥有系统资源,只拥有⼀点⼉在运⾏中必不可少的资源(如程序计数器.寄存器.栈),但它可与同属⼀个进程的其它线程共享进程所拥有的全部资源.线程不能够独⽴执⾏,必须

  • python的多线程原来可以这样解

    目录 多线程 创建线程 总结 多线程 多线程类似于同时执行多个不同程序,多线程运行有如下优点: 使用线程可以把占据长时间的程序中的任务放到后台去处理.用户界面可以更加吸引人,比如用户点击了一个按钮去触发某些事件的处理,可以弹出一个进度条来显示处理的进度.程序的运行速度可能加快.在一些等待的任务实现上如用户输入.文件读写和网络收发数据等,线程就比较有用了.在这种情况下我们可以释放一些珍贵的资源如内存占用等等. 创建线程 一个进程里面必然有一个主线程. 创建线程的两种方法 继承Thread类,并重写

  • Python线程之多线程展示详解

    目录 什么多线程? 获取活跃线程相关数据 总结 什么多线程? 多线程,就是多个独立的运行单位,同时执行同样的事情. 想想一下,文章发布后同时被很多读者阅读,这些读者在做的事情'阅读'就是一个一个的线程. 多线程就是多个读者同时阅读这篇文章.重点是:同时有多个读者在做阅读这件事情. 如果是多个读者,分时间阅读,最后任意时刻只有一个读者在阅读,虽然是多个读者,但还是单线程. 我们再拿前面分享的代码:关注和点赞. def dianzan_guanzhu(): now = datetime.dateti

  • Python多线程即相关理念详解

    目录 一.什么是线程? 二.开启线程的两种方式 1.方式1 2.方式2 三.线程对象的jion方法() 四. 补充小案例 五.守护线程 六.线程互斥锁 七.GTL-全局解释器 八.验证多线程与多线程运用场景 总结: 一.什么是线程? 线程顾名思义,就是一条流水线工作的过程,一条流水线必须属于一个车间,一个车间的工作过程是一个进程.车间负责把资源整合到一起,是一个资源单位,而一个车间内至少有一个流水线.所以,进程只是用来把资源集中到一起(进程只是一个资源单位,或者说资源集合),而线程才是cpu上的

  • python多线程方法详解

    处理多个数据和多文件时,使用for循环的速度非常慢,此时需要用多线程来加速运行进度,常用的模块为multiprocess和joblib,下面对两种包我常用的方法进行说明. 1.模块安装 pip install multiprocessing pip install joblib 2.以分块计算NDVI为例 首先导入需要的包 import numpy as np from osgeo import gdal import time from multiprocessing import cpu_c

  • Python 多线程实例详解

    Python 多线程实例详解 多线程通常是新开一个后台线程去处理比较耗时的操作,Python做后台线程处理也是很简单的,今天从官方文档中找到了一个Demo. 实例代码: import threading, zipfile class AsyncZip(threading.Thread): def __init__(self, infile, outfile): threading.Thread.__init__(self) self.infile = infile self.outfile =

  • TF-IDF算法解析与Python实现方法详解

    TF-IDF(term frequency–inverse document frequency)是一种用于信息检索(information retrieval)与文本挖掘(text mining)的常用加权技术.比较容易理解的一个应用场景是当我们手头有一些文章时,我们希望计算机能够自动地进行关键词提取.而TF-IDF就是可以帮我们完成这项任务的一种统计方法.它能够用于评估一个词语对于一个文集或一个语料库中的其中一份文档的重要程度. 在一份给定的文件里,词频 (term frequency, T

  • json跨域调用python的方法详解

    本文实例讲述了json跨域调用python的方法.分享给大家供大家参考,具体如下: 客户端: <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"> <html xmlns="http://www.w3.org/1999/xhtml">

  • Python魔术方法详解

    介绍 此教程为我的数篇文章中的一个重点.主题是魔术方法. 什么是魔术方法?他们是面向对象的Python的一切.他们是可以给你的类增加"magic"的特殊方法.他们总是被双下划线所包围(e.g. __init__ 或者 __lt__).然而他们的文档却远没有提供应该有的内容.Python中所有的魔术方法均在Python官方文档中有相应描述,但是对于他们的描述比较混乱而且组织比较松散.很难找到有一个例子(也许他们原本打算的很好,在开始语言参考中有描述很详细,然而随之而来的确是枯燥的语法描述

  • 决策树剪枝算法的python实现方法详解

    本文实例讲述了决策树剪枝算法的python实现方法.分享给大家供大家参考,具体如下: 决策树是一种依托决策而建立起来的一种树.在机器学习中,决策树是一种预测模型,代表的是一种对象属性与对象值之间的一种映射关系,每一个节点代表某个对象,树中的每一个分叉路径代表某个可能的属性值,而每一个叶子节点则对应从根节点到该叶子节点所经历的路径所表示的对象的值.决策树仅有单一输出,如果有多个输出,可以分别建立独立的决策树以处理不同的输出. ID3算法:ID3算法是决策树的一种,是基于奥卡姆剃刀原理的,即用尽量用

  • 机器学习之KNN算法原理及Python实现方法详解

    本文实例讲述了机器学习之KNN算法原理及Python实现方法.分享给大家供大家参考,具体如下: 文中代码出自<机器学习实战>CH02,可参考本站: 机器学习实战 (Peter Harrington著) 中文版 机器学习实战 (Peter Harrington著) 英文原版 [附源代码] KNN算法介绍 KNN是一种监督学习算法,通过计算新数据与训练数据特征值之间的距离,然后选取K(K>=1)个距离最近的邻居进行分类判(投票法)或者回归.若K=1,新数据被简单分配给其近邻的类. KNN算法

  • Python魔法方法详解

    据说,Python 的对象天生拥有一些神奇的方法,它们总被双下划线所包围,他们是面向对象的 Python 的一切. 他们是可以给你的类增加魔力的特殊方法,如果你的对象实现(重载)了这些方法中的某一个,那么这个方法就会在特殊的情况下被 Python 所调用,你可以定义自己想要的行为,而这一切都是自动发生的. Python 的魔术方法非常强大,然而随之而来的则是责任.了解正确的方法去使用非常重要! 魔法方法 含义 基本的魔法方法 __new__(cls[, ...]) new 是在一个对象实例化的时

  • Anconda环境下Vscode安装Python的方法详解

    这里使用的操作系统为win7/10,安装环境是使用Anconda搭建Python环境,然后在Vscode编辑器中安装Python插件,最终能够在Vscode环境下使用Python. 一.Anconda软件的安装 Anaconda is a completely free Python distribution (including for commercial use and redistribution). It includes over 195 of the most popularPyt

  • 对Python 多线程统计所有csv文件的行数方法详解

    如下所示: #统计某文件夹下的所有csv文件的行数(多线程) import threading import csv import os class MyThreadLine(threading.Thread): #用于统计csv文件的行数的线程类 def __init__(self,path): threading.Thread.__init__(self) #父类初始化 self.path=path #路径 self.line=-1 #统计行数 def run(self): reader =

随机推荐