本地模型推理#

import依赖#

引入需要的包依赖

[ ]:
from aiearth.predict.checkpoint import ModelCheckpoint
from aiearth.predict.predictors import TensorrtPredictor
from aiearth.predict.processors import (
    Chw2HwcProcessor,
    MeanNormProcessor,
    Hwc2ChwProcessor,
    SqueezeNdimProcessor,
    MaskBinarizationProcessor,
)
from aiearth.predict.pipelines import GeoSegmentationPredictPipeline

from aiearth.predict.logging import root_logger as logger

初始化checkpoint#

从本地模型的onnx路径初始化ModelCheckpoint,image_size参数是模型输入图像的大小,bound参数是模型预测需要忽略的边界大小,不计入最后结果

[ ]:
ckpt = ModelCheckpoint.from_local_path(
    "/path/to/your/onnx",
    image_size=1024,
    bound=128,
)

运行的时候需要将 /path/to/your/onnx 路径替换为本地onnx模型的绝对路径,并将image_size和bound修改为模型对应的参数

初始化pipeline#

初始化遥感分割推理任务的pipeline:

  • model_checkpoint:模型checkpoint

  • predictor_cls:模型推理类,通过model_checkpoint进行延迟初始化,可根据需求自定义predictor

  • pre_processors:模型前处理,参数类型为List,可传递多个处理算子,在pipeline里面会将算子组合成一个调用链

  • post_processors:模型后处理,参数类型为List,可传递多个处理算子,在pipeline里面会将算子组合成一个调用链

[ ]:
pipe = GeoSegmentationPredictPipeline(
    model_checkpoint=ckpt,
    predictor_cls=TensorrtPredictor,
    pre_processors=[
        Chw2HwcProcessor(["image"]),
        MeanNormProcessor(
            ["image"], [123.675, 116.28, 103.53], [0.01712475, 0.017507, 0.01742919]
        ),
        Hwc2ChwProcessor(["image"]),
    ],
    post_processors=[
        SqueezeNdimProcessor(["image"]),
        MaskBinarizationProcessor(["image"], 127.5),
    ],
)

示例这里predictor推理类使用的是TensorrtPredictor,可使用TensorRT转换的trt engine文件进行推理加速

模型前处理使用的是Chw2HwcProcessor->MeanNormProcessor->Hwc2ChwProcessor,Chw2HwcProcessor将rasterio读取的CHW图像格式转换为HWC格式,MeanNormProcessor将图像进行归一化处理,计算逻辑为(X-mean)*norm,Hwc2ChwProcessor再将HWC格式转换为CHW格式作为推理的输入

模型后处理使用的是SqueezeNdimProcessor->MaskBinarizationProcessor,推理的batch_size设置为1,SqueezeNdimProcessor为去掉batch维度,MaskBinarizationProcessor通过mask_threshold阈值将推理结果转换为mask灰度图

运行pipeline#

运行pipeline,pipeline运行完成之后会在当前目录输出结果shape文件

[ ]:
logger.info("run pipeline")
pipe(
    "/path/to/your/tiff",
)

运行的时候需要将/path/to/your/tiff路径替换为本地tiff影像的绝对路径