目录

核心步骤概览

第一步:准备标注数据 (这是最耗时但最关键的一步)

第二步:搭建数据集和数据加载器

第三步:构建 U-Net 模型

第四步:编写训练脚本 (train.py)

第五步:训练模型

第六步:对新影像进行预测 (predict.py)


核心步骤概览

  1. 准备标注数据 (最关键!)
  2. 搭建数据集和数据加载器
  3. 构建 U-Net 模型
  4. 编写训练脚本
  5. 训练模型
  6. 对新影像进行预测

第一步:准备标注数据 (这是最耗时但最关键的一步)

只有影像,但没有标注,模型是无法学习的。需要为这些 TIF 影像创建对应的像素级标注掩码 (Mask)。

  • 标注工具:
    • QGIS: 免费、强大的开源地理信息系统软件。可以加载 TIF 影像,然后创建新的矢量图层(多边形),手动绘制建筑物、水系、工程车辆的边界。绘制完成后,需要将矢量图层栅格化 (Rasterize) 成与原始影像分辨率、范围完全一致的 GeoTIFF 或 PNG 文件。这是最专业但也最耗时的方法。
    • Labelbox / Supervisely / VGG Image Annotator (VIA): 这些是在线或桌面的图像标注平台,专门为机器学习设计。它们通常提供多边形、多边形套索等工具,操作比 QGIS 更直观。标注完成后,导出为 PNG 格式的掩码文件(每个像素值代表一个类别)。
    • ArcGIS: 商业软件,功能强大,类似 QGIS。
  • 标注规范 (非常重要!):
    • 定义类别: 明确类别 ID。
      • 0: 背景 (Background)
      • 1: 建筑物 (Building)
      • 2: 水系 (Water)
      • 3: 工程车辆 (Construction Vehicle)
    • 掩码格式: 推荐使用 单通道 PNG 文件。文件名应与原始 TIF 影像对应(如 image_001.tif 对应 image_001_mask.png)。
    • 精度: 尽量精确地描绘边界,尤其是建筑物的直角和水系的蜿蜒轮廓。
    • 工程车辆: 由于目标小,标注时要特别仔细,确保不遗漏。
  • 数据集划分:

将标注好的数据划分为:

    • 训练集 (Training Set): ~70-80% 的数据,用于训练模型。
    • 验证集 (Validation Set): ~10-15% 的数据,用于在训练过程中监控模型性能,防止过拟合。
    • 测试集 (Test Set): ~10-15% 的数据,用于最终评估模型性能,在整个训练过程中绝对不能使用。

第二步:搭建数据集和数据加载器

创建一个 Python 脚本 dataset.py。

import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2

class RemoteSensingDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        """
        Args:
            image_dir (str): 存放原始遥感影像 (.tif) 的目录路径。
            mask_dir (str): 存放标注掩码 (.png) 的目录路径。
            transform (callable, optional): 数据增强和预处理的转换函数。
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # 获取所有影像文件名 (假设 .tif 和 .png 同名)
        self.images = [f for f in os.listdir(image_dir) if f.endswith(('.tif', '.tiff'))]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # 获取文件名
        img_name = self.images[idx]
        mask_name = img_name.replace('.tif', '.png').replace('.tiff', '.png') # 假设掩码是png
        
        # 构建完整路径
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, mask_name)

        # 加载影像 (PIL Image)
        # 注意:TIF 可能有多个波段,这里假设是标准的 RGB 3波段
        image = Image.open(img_path).convert("RGB") # 转为 RGB 模式
        image = np.array(image) # 转为 numpy array (H, W, C)

        # 加载掩码 (PIL Image)
        mask = Image.open(mask_path)
        mask = np.array(mask) # 转为 numpy array (H, W) 单通道
        # 确保 mask 的值是 0, 1, 2, 3 (你的类别ID)

        # 应用数据增强和预处理
        if self.transform is not None:
            # Albumentations 的 transform 接受字典 {'image': image, 'mask': mask}
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']
        else:
            # 如果没有 transform,手动进行基本处理
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            mask = torch.from_numpy(mask).long()

        return image, mask

# --- 定义数据增强和预处理 ---
# 强烈建议使用数据增强来提高模型鲁棒性
def get_transforms(train=True):
    if train:
        return A.Compose([
            A.Resize(512, 512), # U-Net 通常需要固定尺寸输入,或使用能处理任意尺寸的变体
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
            # A.GaussNoise(var_limit=(10.0, 50.0), p=0.2), # 可选,模拟噪声
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # ImageNet 标准化,常用
            ToTensorV2(), # 将 numpy array 转为 torch tensor, 并归一化到 [0,1]
        ])
    else:
        return A.Compose([
            A.Resize(512, 512),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

# --- 创建数据加载器 ---
# 假设你的数据目录结构如下:
# data/
#   ├── train/
#   │   ├── images/
#   │   └── masks/
#   ├── val/
#   │   ├── images/
#   │   └── masks/
#   └── test/
#       ├── images/
#       └── masks/

train_dataset = RemoteSensingDataset(
    image_dir="data/train/images",
    mask_dir="data/train/masks",
    transform=get_transforms(train=True)
)

val_dataset = RemoteSensingDataset(
    image_dir="data/val/images",
    mask_dir="data/val/masks",
    transform=get_transforms(train=False) # 验证集通常不做强增强
)

# DataLoader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)

安装依赖:

pip install pillow numpy albumentations

第三步:构建 U-Net 模型

创建 model.py。我们可以自己实现一个简单的 U-Net,但更推荐使用 segmentation_models_pytorch。

安装依赖

pip install segmentation-models-pytorch

生成model.py文件

# model.py
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp

def create_unet_model(num_classes=4, encoder_name='resnet34', encoder_weights='imagenet'):
    """
    使用 smp 库创建 U-Net 模型。
    Args:
        num_classes: 分类数 (4: 背景, 建筑物, 水系, 工程车辆)
        encoder_name: 骨干网络名称,如 'resnet34', 'resnet50', 'efficientnet-b0' 等。
        encoder_weights: 预训练权重,'imagenet' 表示使用在 ImageNet 上预训练的权重。
    Returns:
        PyTorch 模型
    """
    model = smp.Unet(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=3, # 输入是 RGB 3波段
        classes=num_classes,
        activation=None # 让损失函数 (如 CrossEntropyLoss) 处理
    )
    return model

# --- 创建模型 ---
model = create_unet_model(num_classes=4, encoder_name='resnet34', encoder_weights='imagenet')

第四步:编写训练脚本 (train.py)

# train.py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
from dataset import RemoteSensingDataset, get_transforms, train_loader, val_loader # 假设 dataset.py 已定义
from model import create_unet_model # 或 UNet

# --- 1. 设备 ---
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# --- 2. 超参数 ---
num_classes = 4
lr = 1e-4
batch_size = 8 # 根据你的显存调整,MPS 可能支持 8-16
num_epochs = 100
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# --- 3. 模型、损失、优化器 ---
model = create_unet_model(num_classes=num_classes).to(device)
# model = UNet(n_classes=num_classes).to(device) # 如果使用手动实现

criterion = nn.CrossEntropyLoss() # 适用于多类别分割
optimizer = Adam(model.parameters(), lr=lr)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, verbose=True) # 学习率调度

# --- 4. 训练循环 ---
best_val_loss = float('inf')

for epoch in range(num_epochs):
    # --- 训练阶段 ---
    model.train()
    train_loss = 0.0
    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device) # [B, H, W]

        optimizer.zero_grad()
        outputs = model(images) # [B, num_classes, H, W]
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    train_loss /= len(train_loader)

    # --- 验证阶段 ---
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item()

    val_loss /= len(val_loader)

    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {train_loss:.4f} "
          f"Val Loss: {val_loss:.4f}")

    # --- 学习率调度 ---
    scheduler.step(val_loss)

    # --- 保存最佳模型 ---
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, "best_model.pth"))
        print(f"  --> Best model saved at epoch {epoch+1}")

    # 保存每个 epoch 的模型 (可选)
    # torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}.pth"))

print("Training completed! Best model saved as 'best_model.pth'")

第五步:训练模型

  1. 组织好数据目录,确保 data/train/images, data/train/masks 等路径正确。
  2. 运行训练脚本:
python train.py
  1. 监控训练过程,观察训练损失和验证损失是否下降。如果验证损失不再下降甚至上升,说明可能过拟合。

第六步:对新影像进行预测 (predict.py)

# predict.py
import torch
import numpy as np
from PIL import Image
import os
from model import create_unet_model
from albumentations import Compose, Resize, Normalize, ToTensorV2
from albumentations.pytorch import ToTensorV2

def load_model(model_path, num_classes=4, device='cpu'):
    model = create_unet_model(num_classes=num_classes)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model

def preprocess_image(image_path, transform):
    image = Image.open(image_path).convert("RGB")
    image = np.array(image)
    # 应用与训练时相同的 transform (Resize, Normalize, ToTensor)
    transformed = transform(image=image)
    image = transformed['image'].unsqueeze(0) # 添加 batch 维度 [1, C, H, W]
    return image

def postprocess_mask(mask_tensor):
    # mask_tensor shape: [1, num_classes, H, W]
    mask = mask_tensor.argmax(dim=1).squeeze(0) # 取最大概率的类别,[H, W]
    return mask.cpu().numpy().astype(np.uint8)

def save_prediction(mask, output_path):
    # 将预测结果保存为 PNG
    result = Image.fromarray(mask)
    result.save(output_path)

# --- 预测流程 ---
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = load_model("checkpoints/best_model.pth", num_classes=4, device=device)

# 定义预处理 transform (与训练时验证集相同)
transform = Compose([
    Resize(512, 512),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

# 预测单张影像
image_path = "path/to/your/new_image.tif"
output_path = "predictions/prediction.png"

image_tensor = preprocess_image(image_path, transform).to(device)

with torch.no_grad():
    output = model(image_tensor) # [1, 4, H, W]
    predicted_mask = postprocess_mask(output)

save_prediction(predicted_mask, output_path)
print(f"Prediction saved to {output_path}")

Logo

立足具身智能前沿赛道,致力于搭建全球化、开源化、全栈式技术交流与实践共创平台。

更多推荐