当传统的特征点,遇上深度学习的关键点:计算机视觉特征提取的演进与融合

引言:特征点在计算机视觉中的重要性

特征点(Feature Points)或关键点(Keypoints)是计算机视觉和图像处理领域的基石概念。它们代表了图像中具有独特性的局部区域,如角点、边缘交叉点、斑点等,能够抵抗光照变化、视角变化、旋转和尺度变换。从图像配准、三维重建到目标跟踪、视觉SLAM,特征点技术无处不在。

传统特征点方法基于手工设计的特征描述符,如SIFT、SURF、ORB等,这些方法在过去的二十年里取得了巨大成功。然而,随着深度学习的发展,基于神经网络的关键点检测和描述方法逐渐崭露头角,如SuperPoint、LF-Net、D2-Net等,它们在某些任务上展现出了超越传统方法的性能。

本文将深入探讨传统特征点与深度学习关键点的发展历程、技术原理、实现方法以及融合应用,通过详细的代码示例展示两者的实际应用和对比分析。

第一部分:传统特征点检测与描述

1.1 特征点的基本概念与评价标准

特征点检测的目标是找到图像中具有以下特性的点:

  • 可重复性:在不同图像(不同视角、光照等)中能够稳定检测到相同的特征点

  • 显著性:特征点周围的局部区域具有足够的独特性

  • 定位精度:能够精确定位到像素级别

  • 高效性:计算效率高,适合实时应用

传统特征点方法通常分为两个阶段:检测描述。检测阶段确定特征点的位置,描述阶段则为每个特征点生成一个描述符向量,用于特征匹配。

1.2 经典传统特征点方法

1.2.1 SIFT (Scale-Invariant Feature Transform)

SIFT由David Lowe于1999年提出,是最具影响力的特征点算法之一。其核心思想是通过尺度空间极值检测来寻找关键点,并使用局部梯度方向直方图生成描述符。

SIFT算法的主要步骤:

  1. 尺度空间极值检测:通过高斯差分金字塔检测局部极值点

  2. 关键点定位:精确定位关键点位置,去除低对比度和边缘响应点

  3. 方向分配:根据局部梯度方向为关键点分配主方向

  4. 描述符生成:计算关键点周围区域的梯度方向直方图

python

import cv2
import numpy as np
import matplotlib.pyplot as plt

# SIFT特征点检测与匹配示例
def sift_feature_demo():
    # 读取图像
    img1 = cv2.imread('image1.jpg', cv2.IMREAD_GRAYSCALE)
    img2 = cv2.imread('image2.jpg', cv2.IMREAD_GRAYSCALE)
    
    # 初始化SIFT检测器
    sift = cv2.SIFT_create()
    
    # 检测关键点和计算描述符
    keypoints1, descriptors1 = sift.detectAndCompute(img1, None)
    keypoints2, descriptors2 = sift.detectAndCompute(img2, None)
    
    # 可视化关键点
    img1_keypoints = cv2.drawKeypoints(img1, keypoints1, None, 
                                      flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
    img2_keypoints = cv2.drawKeypoints(img2, keypoints2, None,
                                      flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
    
    # 使用FLANN匹配器进行特征匹配
    FLANN_INDEX_KDTREE = 1
    index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
    search_params = dict(checks=50)
    
    flann = cv2.FlannBasedMatcher(index_params, search_params)
    matches = flann.knnMatch(descriptors1, descriptors2, k=2)
    
    # 应用Lowe's比率测试筛选好的匹配
    good_matches = []
    for m, n in matches:
        if m.distance < 0.7 * n.distance:
            good_matches.append(m)
    
    # 绘制匹配结果
    img_matches = cv2.drawMatches(img1, keypoints1, img2, keypoints2, 
                                 good_matches, None, flags=2)
    
    # 显示结果
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    axes[0, 0].imshow(img1, cmap='gray')
    axes[0, 0].set_title('Image 1')
    axes[0, 1].imshow(img2, cmap='gray')
    axes[0, 1].set_title('Image 2')
    axes[1, 0].imshow(img1_keypoints, cmap='gray')
    axes[1, 0].set_title(f'Image 1 Keypoints: {len(keypoints1)}')
    axes[1, 1].imshow(img_matches)
    axes[1, 1].set_title(f'Feature Matches: {len(good_matches)}')
    
    plt.tight_layout()
    plt.show()
    
    return {
        'keypoints1': keypoints1,
        'keypoints2': keypoints2,
        'descriptors1': descriptors1,
        'descriptors2': descriptors2,
        'matches': good_matches
    }

# 执行SIFT示例
# 注意:需要准备image1.jpg和image2.jpg图像文件
# result = sift_feature_demo()
1.2.2 SURF (Speeded-Up Robust Features)

SURF是SIFT的加速版本,使用盒式滤波器和积分图像加速计算,同时保持了良好的旋转和尺度不变性。

python

def surf_feature_demo():
    # 读取图像
    img1 = cv2.imread('image1.jpg', cv2.IMREAD_GRAYSCALE)
    img2 = cv2.imread('image2.jpg', cv2.IMREAD_GRAYSCALE)
    
    # 初始化SURF检测器
    # 注意:OpenCV中SURF已移至xfeatures2d模块,部分版本需要contrib
    surf = cv2.xfeatures2d.SURF_create(hessianThreshold=400)
    
    # 检测关键点和计算描述符
    keypoints1, descriptors1 = surf.detectAndCompute(img1, None)
    keypoints2, descriptors2 = surf.detectAndCompute(img2, None)
    
    # 可视化关键点
    img1_keypoints = cv2.drawKeypoints(img1, keypoints1, None,
                                      flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
    img2_keypoints = cv2.drawKeypoints(img2, keypoints2, None,
                                      flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
    
    # 使用BFMatcher进行特征匹配
    bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
    matches = bf.match(descriptors1, descriptors2)
    
    # 按距离排序
    matches = sorted(matches, key=lambda x: x.distance)
    
    # 绘制前50个匹配
    img_matches = cv2.drawMatches(img1, keypoints1, img2, keypoints2,
                                 matches[:50], None, flags=2)
    
    # 显示结果
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    axes[0, 0].imshow(img1, cmap='gray')
    axes[0, 0].set_title('Image 1')
    axes[0, 1].imshow(img2, cmap='gray')
    axes[0, 1].set_title('Image 2')
    axes[1, 0].imshow(img1_keypoints, cmap='gray')
    axes[1, 0].set_title(f'Image 1 Keypoints: {len(keypoints1)}')
    axes[1, 1].imshow(img_matches)
    axes[1, 1].set_title(f'Feature Matches: {len(matches)}')
    
    plt.tight_layout()
    plt.show()
    
    return {
        'keypoints1': keypoints1,
        'keypoints2': keypoints2,
        'descriptors1': descriptors1,
        'descriptors2': descriptors2,
        'matches': matches
    }
1.2.3 ORB (Oriented FAST and Rotated BRIEF)

ORB结合了FAST关键点检测器和BRIEF描述符,并添加了方向不变性和尺度不变性,是一种高效且免费的特征点算法。

python

def orb_feature_demo():
    # 读取图像
    img1 = cv2.imread('image1.jpg', cv2.IMREAD_GRAYSCALE)
    img2 = cv2.imread('image2.jpg', cv2.IMREAD_GRAYSCALE)
    
    # 初始化ORB检测器
    orb = cv2.ORB_create(nfeatures=1000)
    
    # 检测关键点和计算描述符
    keypoints1, descriptors1 = orb.detectAndCompute(img1, None)
    keypoints2, descriptors2 = orb.detectAndCompute(img2, None)
    
    # 可视化关键点
    img1_keypoints = cv2.drawKeypoints(img1, keypoints1, None, 
                                      color=(0, 255, 0), flags=0)
    img2_keypoints = cv2.drawKeypoints(img2, keypoints2, None,
                                      color=(0, 255, 0), flags=0)
    
    # 使用BFMatcher进行特征匹配(ORB使用汉明距离)
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
    matches = bf.match(descriptors1, descriptors2)
    
    # 按距离排序
    matches = sorted(matches, key=lambda x: x.distance)
    
    # 计算单应性矩阵进行几何验证
    if len(matches) > 10:
        src_pts = np.float32([keypoints1[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([keypoints2[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
        
        # 使用RANSAC计算单应性矩阵
        M, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
        matches_mask = mask.ravel().tolist()
        
        # 绘制经过几何验证的匹配
        draw_params = dict(matchColor=(0, 255, 0),
                          singlePointColor=None,
                          matchesMask=matches_mask,
                          flags=2)
        
        img_matches = cv2.drawMatches(img1, keypoints1, img2, keypoints2,
                                     matches, None, **draw_params)
    else:
        img_matches = cv2.drawMatches(img1, keypoints1, img2, keypoints2,
                                     matches[:30], None, flags=2)
    
    # 显示结果
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    axes[0, 0].imshow(img1, cmap='gray')
    axes[0, 0].set_title('Image 1')
    axes[0, 1].imshow(img2, cmap='gray')
    axes[0, 1].set_title('Image 2')
    axes[1, 0].imshow(img1_keypoints, cmap='gray')
    axes[1, 0].set_title(f'Image 1 Keypoints: {len(keypoints1)}')
    axes[1, 1].imshow(img_matches)
    axes[1, 1].set_title(f'Feature Matches: {len(matches)}')
    
    plt.tight_layout()
    plt.show()
    
    return {
        'keypoints1': keypoints1,
        'keypoints2': keypoints2,
        'descriptors1': descriptors1,
        'descriptors2': descriptors2,
        'matches': matches
    }

1.3 传统特征点的性能分析

为了全面评估传统特征点算法的性能,我们设计了一个综合测试:

python

import time
from scipy.spatial.distance import cdist

def evaluate_feature_detectors(images):
    """
    评估不同特征点检测器的性能
    
    参数:
        images: 图像列表
        
    返回:
        评估结果字典
    """
    detectors = {
        'SIFT': cv2.SIFT_create(),
        'SURF': cv2.xfeatures2d.SURF_create(hessianThreshold=400),
        'ORB': cv2.ORB_create(nfeatures=1000),
        'AKAZE': cv2.AKAZE_create(),
        'BRISK': cv2.BRISK_create()
    }
    
    results = {}
    
    for name, detector in detectors.items():
        print(f"测试 {name} 检测器...")
        
        detector_results = {
            'detection_time': [],
            'keypoint_counts': [],
            'repeatability': [],
            'matching_score': []
        }
        
        for i, img in enumerate(images):
            # 检测关键点和计算描述符
            start_time = time.time()
            keypoints, descriptors = detector.detectAndCompute(img, None)
            detection_time = time.time() - start_time
            
            detector_results['detection_time'].append(detection_time)
            detector_results['keypoint_counts'].append(len(keypoints))
            
            # 如果有多张图像,计算重复性和匹配分数
            if i > 0:
                # 计算特征匹配
                if name in ['SIFT', 'SURF']:
                    bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
                else:
                    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
                
                matches = bf.match(descriptors_prev, descriptors)
                detector_results['matching_score'].append(len(matches))
            
            # 保存当前描述符供下一次使用
            descriptors_prev = descriptors
        
        # 计算平均指标
        results[name] = {
            'avg_detection_time': np.mean(detector_results['detection_time']),
            'avg_keypoints': np.mean(detector_results['keypoint_counts']),
            'avg_matching_score': np.mean(detector_results['matching_score']) if detector_results['matching_score'] else 0
        }
    
    return results

def plot_evaluation_results(results):
    """
    绘制评估结果
    """
    names = list(results.keys())
    detection_times = [results[name]['avg_detection_time'] for name in names]
    keypoint_counts = [results[name]['avg_keypoints'] for name in names]
    matching_scores = [results[name]['avg_matching_score'] for name in names]
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # 检测时间柱状图
    axes[0].bar(names, detection_times, color='skyblue')
    axes[0].set_title('平均检测时间')
    axes[0].set_ylabel('时间(秒)')
    axes[0].tick_params(axis='x', rotation=45)
    
    # 关键点数量柱状图
    axes[1].bar(names, keypoint_counts, color='lightgreen')
    axes[1].set_title('平均关键点数量')
    axes[1].set_ylabel('数量')
    axes[1].tick_params(axis='x', rotation=45)
    
    # 匹配分数柱状图
    axes[2].bar(names, matching_scores, color='lightcoral')
    axes[2].set_title('平均匹配分数')
    axes[2].set_ylabel('匹配对数')
    axes[2].tick_params(axis='x', rotation=45)
    
    plt.tight_layout()
    plt.show()

# 示例使用
# 假设我们有一组测试图像
# test_images = [img1, img2, img3, ...]
# eval_results = evaluate_feature_detectors(test_images)
# plot_evaluation_results(eval_results)

第二部分:深度学习关键点检测与描述

2.1 深度学习关键点检测的基本原理

与传统手工设计的特征点不同,深度学习关键点检测通过神经网络自动学习图像中的显著特征。这些方法通常使用卷积神经网络(CNN)来预测关键点位置和描述符,能够更好地适应复杂场景和变化。

深度学习关键点检测的主要优势:

  1. 更强的表示能力:神经网络能够学习复杂的特征表示

  2. 端到端学习:可以直接从数据中学习最优的特征表示

  3. 更好的鲁棒性:对光照变化、视角变化等具有更好的适应性

  4. 上下文感知:能够利用周围区域的上下文信息

2.2 经典深度学习关键点方法

2.2.1 SuperPoint:自监督关键点检测与描述

SuperPoint是一种自监督学习的关键点检测和描述方法,使用单应性适应进行训练,能够在无需人工标注的情况下学习关键点检测和描述。

python

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class SuperPointNet(nn.Module):
    """
    SuperPoint网络结构
    参考论文: "SuperPoint: Self-Supervised Interest Point Detection and Description"
    """
    def __init__(self):
        super(SuperPointNet, self).__init__()
        
        # 共享编码器
        self.conv1a = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.conv1b = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        
        self.conv2a = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv2b = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        
        self.conv3a = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3b = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        
        self.conv4a = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.conv4b = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # 检测头
        self.convPa = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.convPb = nn.Conv2d(256, 65, kernel_size=1, stride=1, padding=0)
        
        # 描述头
        self.convDa = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.convDb = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)
        
    def forward(self, x):
        """
        前向传播
        
        参数:
            x: 输入图像 [B, 1, H, W]
            
        返回:
            semi: 关键点检测得分 [B, 65, H/8, W/8]
            desc: 描述符 [B, 256, H/8, W/8]
        """
        # 共享编码器
        x = F.relu(self.conv1a(x))
        x = F.relu(self.conv1b(x))
        x = self.pool(x)  # H/2, W/2
        
        x = F.relu(self.conv2a(x))
        x = F.relu(self.conv2b(x))
        x = self.pool(x)  # H/4, W/4
        
        x = F.relu(self.conv3a(x))
        x = F.relu(self.conv3b(x))
        x = self.pool(x)  # H/8, W/8
        
        x = F.relu(self.conv4a(x))
        x = F.relu(self.conv4b(x))
        
        # 关键点检测头
        cPa = F.relu(self.convPa(x))
        semi = self.convPb(cPa)  # [B, 65, H/8, W/8]
        
        # 描述符头
        cDa = F.relu(self.convDa(x))
        desc = self.convDb(cDa)  # [B, 256, H/8, W/8]
        
        # 描述符L2归一化
        dn = torch.norm(desc, p=2, dim=1)  # 计算L2范数
        desc = desc.div(torch.unsqueeze(dn, 1))  # 归一化
        
        return semi, desc
    
def superpoint_detector_and_descriptor(image, model, device='cpu'):
    """
    使用SuperPoint模型检测关键点和计算描述符
    
    参数:
        image: 输入图像 [H, W]
        model: SuperPoint模型
        device: 设备
        
    返回:
        keypoints: 关键点列表
        descriptors: 描述符矩阵
    """
    # 将图像转换为张量
    if len(image.shape) == 2:
        image = np.expand_dims(image, axis=0)  # [1, H, W]
    
    image_tensor = torch.from_numpy(image).unsqueeze(0).float()  # [1, 1, H, W]
    image_tensor = image_tensor.to(device)
    
    # 前向传播
    model.eval()
    with torch.no_grad():
        semi, desc = model(image_tensor)
    
    # 将semi转换为关键点概率
    dense = semi.exp()
    dense = dense / (torch.sum(dense, dim=1, keepdim=True) + 1e-8)
    
    # 移除dustbin通道
    nodust = dense[:, :-1, :, :]
    
    # 重塑为热图
    Hc = nodust.shape[2]
    Wc = nodust.shape[3]
    nodust = nodust.permute(0, 2, 3, 1)  # [B, Hc, Wc, 64]
    heatmap = nodust.reshape(-1, Hc, Wc, 8, 8)
    heatmap = heatmap.permute(0, 1, 3, 2, 4)  # [B, Hc, 8, Wc, 8]
    heatmap = heatmap.reshape(-1, Hc*8, Wc*8)  # [B, H, W]
    
    # 提取关键点
    heatmap_np = heatmap[0].cpu().numpy()
    keypoints = []
    
    # 使用非极大值抑制提取关键点
    from scipy.ndimage import maximum_filter
    
    # 非极大值抑制
    neighborhood_size = 5
    threshold = 0.015
    
    data_max = maximum_filter(heatmap_np, neighborhood_size)
    maxima = (heatmap_np == data_max)
    
    # 应用阈值
    maxima[heatmap_np < threshold] = 0
    
    # 获取关键点坐标
    yx = np.argwhere(maxima)
    
    for y, x in yx:
        score = heatmap_np[y, x]
        keypoints.append(cv2.KeyPoint(x, y, 1, -1, score))
    
    # 提取描述符
    desc_np = desc[0].cpu().numpy()  # [256, Hc, Wc]
    
    # 在关键点位置插值描述符
    descriptors = []
    for kp in keypoints:
        x = kp.pt[0] / 8.0
        y = kp.pt[1] / 8.0
        
        # 双线性插值
        x0, y0 = int(np.floor(x)), int(np.floor(y))
        x1, y1 = x0 + 1, y0 + 1
        
        # 边界检查
        x0 = max(0, min(x0, desc_np.shape[2]-1))
        x1 = max(0, min(x1, desc_np.shape[2]-1))
        y0 = max(0, min(y0, desc_np.shape[1]-1))
        y1 = max(0, min(y1, desc_np.shape[1]-1))
        
        # 权重
        wa = (x1 - x) * (y1 - y)
        wb = (x1 - x) * (y - y0)
        wc = (x - x0) * (y1 - y)
        wd = (x - x0) * (y - y0)
        
        # 插值描述符
        desc_a = desc_np[:, y0, x0]
        desc_b = desc_np[:, y1, x0]
        desc_c = desc_np[:, y0, x1]
        desc_d = desc_np[:, y1, x1]
        
        descriptor = wa * desc_a + wb * desc_b + wc * desc_c + wd * desc_d
        descriptors.append(descriptor)
    
    descriptors = np.array(descriptors)
    
    return keypoints, descriptors
2.2.2 D2-Net:联合检测和描述的特征点

D2-Net提出了一种联合检测和描述的方法,使用单个网络同时进行特征点检测和描述符计算。

python

class D2Net(nn.Module):
    """
    D2-Net网络结构
    参考论文: "D2-Net: A Trainable CNN for Joint Detection and Description of Local Features"
    """
    def __init__(self, use_relu=True):
        super(D2Net, self).__init__()
        
        # 特征提取骨干网络
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True) if use_relu else nn.ReLU(),
            
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True) if use_relu else nn.ReLU(),
            
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True) if use_relu else nn.ReLU(),
            
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True) if use_relu else nn.ReLU(),
            
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True) if use_relu else nn.ReLU(),
            
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True) if use_relu else nn.ReLU(),
            
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True) if use_relu else nn.ReLU(),
            
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True) if use_relu else nn.ReLU()
        )
        
        # 描述符输出层
        self.descriptor = nn.Conv2d(512, 128, kernel_size=1)
        
    def forward(self, x):
        """
        前向传播
        
        参数:
            x: 输入图像 [B, 3, H, W]
            
        返回:
            features: 特征图 [B, 512, H/8, W/8]
            descriptors: 描述符 [B, 128, H/8, W/8]
        """
        # 提取特征
        features = self.features(x)
        
        # 计算描述符
        descriptors = self.descriptor(features)
        
        # L2归一化
        descriptors = F.normalize(descriptors, p=2, dim=1)
        
        return features, descriptors
    
def d2net_detector_and_descriptor(image, model, device='cpu', detection_threshold=0.015):
    """
    使用D2-Net模型检测关键点和计算描述符
    
    参数:
        image: 输入图像 [H, W, 3]
        model: D2-Net模型
        device: 设备
        detection_threshold: 检测阈值
        
    返回:
        keypoints: 关键点列表
        descriptors: 描述符矩阵
    """
    # 预处理图像
    if len(image.shape) == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    
    # 调整大小为8的倍数
    H, W = image.shape[:2]
    H_new = (H // 8) * 8
    W_new = (W // 8) * 8
    image_resized = cv2.resize(image, (W_new, H_new))
    
    # 转换为张量
    image_tensor = torch.from_numpy(image_resized).permute(2, 0, 1).unsqueeze(0).float()
    image_tensor = image_tensor.to(device)
    
    # 前向传播
    model.eval()
    with torch.no_grad():
        features, descriptors = model(image_tensor)
    
    # 转换为numpy
    features_np = features[0].cpu().numpy()  # [512, H/8, W/8]
    descriptors_np = descriptors[0].cpu().numpy()  # [128, H/8, W/8]
    
    # 检测关键点
    keypoints = []
    
    # D2-Net的检测策略:在特征维度上寻找最大值
    # 对于每个空间位置,计算特征维度上的L2范数
    detection_map = np.linalg.norm(features_np, axis=0)  # [H/8, W/8]
    
    # 应用非极大值抑制
    from scipy.ndimage import maximum_filter
    
    neighborhood_size = 3
    data_max = maximum_filter(detection_map, neighborhood_size)
    maxima = (detection_map == data_max)
    
    # 应用阈值
    maxima[detection_map < detection_threshold] = 0
    
    # 获取关键点坐标
    yx = np.argwhere(maxima)
    
    # 转换为原始图像坐标
    scale_factor = 8
    for y, x in yx:
        # 原始坐标
        x_orig = x * scale_factor
        y_orig = y * scale_factor
        
        # 调整到原始图像尺寸
        x_orig = x_orig * (W / W_new)
        y_orig = y_orig * (H / H_new)
        
        score = detection_map[y, x]
        keypoints.append(cv2.KeyPoint(x_orig, y_orig, 1, -1, score))
    
    # 提取描述符
    desc_list = []
    for kp in keypoints:
        # 在特征图上的坐标
        x_feat = kp.pt[0] * (W_new / W) / scale_factor
        y_feat = kp.pt[1] * (H_new / H) / scale_factor
        
        # 双线性插值获取描述符
        x0, y0 = int(np.floor(x_feat)), int(np.floor(y_feat))
        x1, y1 = x0 + 1, y0 + 1
        
        # 边界检查
        x0 = max(0, min(x0, descriptors_np.shape[2]-1))
        x1 = max(0, min(x1, descriptors_np.shape[2]-1))
        y0 = max(0, min(y0, descriptors_np.shape[1]-1))
        y1 = max(0, min(y1, descriptors_np.shape[1]-1))
        
        # 权重
        wa = (x1 - x_feat) * (y1 - y_feat)
        wb = (x1 - x_feat) * (y_feat - y0)
        wc = (x_feat - x0) * (y1 - y_feat)
        wd = (x_feat - x0) * (y_feat - y0)
        
        # 插值描述符
        desc_a = descriptors_np[:, y0, x0]
        desc_b = descriptors_np[:, y1, x0]
        desc_c = descriptors_np[:, y0, x1]
        desc_d = descriptors_np[:, y1, x1]
        
        descriptor = wa * desc_a + wb * desc_b + wc * desc_c + wd * desc_d
        desc_list.append(descriptor)
    
    descriptors = np.array(desc_list)
    
    return keypoints, descriptors

2.3 深度学习关键点训练策略

深度学习关键点检测的训练通常需要特殊策略,因为关键点位置难以直接标注。以下是一些常用的训练方法:

python

class KeypointDetectionLoss(nn.Module):
    """
    关键点检测的损失函数
    """
    def __init__(self, detector_loss_weight=1.0, descriptor_loss_weight=1.0):
        super(KeypointDetectionLoss, self).__init__()
        self.detector_loss_weight = detector_loss_weight
        self.descriptor_loss_weight = descriptor_loss_weight
        
    def detector_loss(self, pred_heatmaps, gt_heatmaps):
        """
        检测器损失:均方误差损失
        
        参数:
            pred_heatmaps: 预测的热图 [B, 1, H, W]
            gt_heatmaps: 真实的热图 [B, 1, H, W]
            
        返回:
            loss: 检测器损失
        """
        return F.mse_loss(pred_heatmaps, gt_heatmaps)
    
    def descriptor_loss(self, pred_descriptors1, pred_descriptors2, matches):
        """
        描述符损失:对比损失
        
        参数:
            pred_descriptors1: 图像1的描述符 [B, D, H, W]
            pred_descriptors2: 图像2的描述符 [B, D, H, W]
            matches: 匹配对列表
            
        返回:
            loss: 描述符损失
        """
        # 提取匹配位置的描述符
        loss = 0
        pos_pairs = 0
        
        for match in matches:
            # 获取匹配的关键点位置
            y1, x1 = match[0]  # 图像1中的位置
            y2, x2 = match[1]  # 图像2中的位置
            
            # 提取描述符
            desc1 = pred_descriptors1[:, :, y1, x1]
            desc2 = pred_descriptors2[:, :, y2, x2]
            
            # 计算余弦相似度
            sim = F.cosine_similarity(desc1, desc2)
            
            # 正样本应该相似度高,使用1 - similarity作为损失
            loss += (1 - sim).mean()
            pos_pairs += 1
        
        if pos_pairs > 0:
            loss = loss / pos_pairs
        
        return loss
    
    def forward(self, pred_heatmaps1, pred_heatmaps2, 
                pred_descriptors1, pred_descriptors2,
                gt_heatmaps1, gt_heatmaps2, matches):
        """
        总损失
        
        参数:
            pred_heatmaps1: 图像1的预测热图
            pred_heatmaps2: 图像2的预测热图
            pred_descriptors1: 图像1的预测描述符
            pred_descriptors2: 图像2的预测描述符
            gt_heatmaps1: 图像1的真实热图
            gt_heatmaps2: 图像2的真实热图
            matches: 匹配对
            
        返回:
            total_loss: 总损失
        """
        # 检测器损失
        detector_loss1 = self.detector_loss(pred_heatmaps1, gt_heatmaps1)
        detector_loss2 = self.detector_loss(pred_heatmaps2, gt_heatmaps2)
        detector_loss = (detector_loss1 + detector_loss2) / 2
        
        # 描述符损失
        descriptor_loss = self.descriptor_loss(pred_descriptors1, pred_descriptors2, matches)
        
        # 总损失
        total_loss = (self.detector_loss_weight * detector_loss + 
                     self.descriptor_loss_weight * descriptor_loss)
        
        return total_loss, detector_loss, descriptor_loss

class HomographyAdaptation:
    """
    单应性适应:用于自监督关键点训练的数据增强方法
    """
    def __init__(self, num_samples=10, perspective=True, scaling=True, rotation=True):
        self.num_samples = num_samples
        self.perspective = perspective
        self.scaling = scaling
        self.rotation = rotation
        
    def generate_homography(self, H, W):
        """
        生成随机单应性变换
        
        参数:
            H: 图像高度
            W: 图像宽度
            
        返回:
            H_mat: 单应性矩阵 [3, 3]
        """
        # 基础变换:恒等变换
        H_mat = np.eye(3)
        
        # 添加透视变换
        if self.perspective:
            perspective_scale = 0.0005
            H_mat[2, 0] = np.random.uniform(-perspective_scale, perspective_scale)
            H_mat[2, 1] = np.random.uniform(-perspective_scale, perspective_scale)
        
        # 添加缩放
        if self.scaling:
            scale = np.random.uniform(0.8, 1.2)
            H_mat[0, 0] *= scale
            H_mat[1, 1] *= scale
        
        # 添加旋转
        if self.rotation:
            angle = np.random.uniform(-30, 30) * np.pi / 180
            cos_a = np.cos(angle)
            sin_a = np.sin(angle)
            
            R = np.array([[cos_a, -sin_a, 0],
                         [sin_a, cos_a, 0],
                         [0, 0, 1]])
            
            H_mat = R @ H_mat
        
        # 添加平移
        translate_x = np.random.uniform(-0.1, 0.1) * W
        translate_y = np.random.uniform(-0.1, 0.1) * H
        
        T = np.array([[1, 0, translate_x],
                     [0, 1, translate_y],
                     [0, 0, 1]])
        
        H_mat = T @ H_mat
        
        return H_mat
    
    def apply_homography(self, image, H_mat):
        """
        应用单应性变换到图像
        
        参数:
            image: 输入图像 [H, W, C]
            H_mat: 单应性矩阵 [3, 3]
            
        返回:
            warped_image: 变换后的图像
            valid_mask: 有效区域掩码
        """
        H, W = image.shape[:2]
        
        # 应用透视变换
        warped_image = cv2.warpPerspective(image, H_mat, (W, H))
        
        # 创建有效区域掩码
        valid_mask = np.ones((H, W), dtype=np.uint8)
        valid_mask = cv2.warpPerspective(valid_mask, H_mat, (W, H))
        
        return warped_image, valid_mask
    
    def adapt_image(self, image):
        """
        对图像进行单应性适应
        
        参数:
            image: 输入图像 [H, W, C]
            
        返回:
            warped_images: 变换后的图像列表
            homographies: 单应性矩阵列表
            valid_masks: 有效区域掩码列表
        """
        H, W = image.shape[:2]
        
        warped_images = []
        homographies = []
        valid_masks = []
        
        for _ in range(self.num_samples):
            # 生成随机单应性变换
            H_mat = self.generate_homography(H, W)
            
            # 应用变换
            warped_image, valid_mask = self.apply_homography(image, H_mat)
            
            warped_images.append(warped_image)
            homographies.append(H_mat)
            valid_masks.append(valid_mask)
        
        return warped_images, homographies, valid_masks

第三部分:传统方法与深度学习方法的比较

3.1 性能对比实验

为了全面比较传统特征点方法和深度学习关键点方法,我们设计了一个综合评估实验:

python

def compare_traditional_vs_deeplearning(images, traditional_methods, deeplearning_methods):
    """
    比较传统方法和深度学习方法的性能
    
    参数:
        images: 测试图像列表
        traditional_methods: 传统方法字典 {名称: 检测器}
        deeplearning_methods: 深度学习方法字典 {名称: (模型, 处理函数)}
        
    返回:
        comparison_results: 比较结果
    """
    comparison_results = {
        'traditional': {},
        'deeplearning': {}
    }
    
    # 评估传统方法
    print("评估传统方法...")
    for name, detector in traditional_methods.items():
        print(f"  测试 {name}...")
        
        method_results = {
            'detection_time': [],
            'keypoint_counts': [],
            'repeatability': [],
            'matching_accuracy': []
        }
        
        for i in range(len(images)-1):
            img1 = images[i]
            img2 = images[i+1]
            
            # 检测关键点和描述符
            start_time = time.time()
            kp1, desc1 = detector.detectAndCompute(img1, None)
            kp2, desc2 = detector.detectAndCompute(img2, None)
            detection_time = time.time() - start_time
            
            method_results['detection_time'].append(detection_time)
            method_results['keypoint_counts'].append((len(kp1), len(kp2)))
            
            # 特征匹配
            if desc1 is not None and desc2 is not None and len(desc1) > 0 and len(desc2) > 0:
                if name in ['SIFT', 'SURF']:
                    bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
                else:
                    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
                
                matches = bf.match(desc1, desc2)
                method_results['matching_accuracy'].append(len(matches))
        
        # 计算平均指标
        if method_results['detection_time']:
            comparison_results['traditional'][name] = {
                'avg_detection_time': np.mean(method_results['detection_time']),
                'avg_keypoints': np.mean([k[0] for k in method_results['keypoint_counts']]),
                'avg_matching_accuracy': np.mean(method_results['matching_accuracy']) if method_results['matching_accuracy'] else 0
            }
    
    # 评估深度学习方法
    print("\n评估深度学习方法...")
    for name, (model, processor) in deeplearning_methods.items():
        print(f"  测试 {name}...")
        
        method_results = {
            'detection_time': [],
            'keypoint_counts': [],
            'repeatability': [],
            'matching_accuracy': []
        }
        
        for i in range(len(images)-1):
            img1 = images[i]
            img2 = images[i+1]
            
            # 检测关键点和描述符
            start_time = time.time()
            kp1, desc1 = processor(img1, model)
            kp2, desc2 = processor(img2, model)
            detection_time = time.time() - start_time
            
            method_results['detection_time'].append(detection_time)
            method_results['keypoint_counts'].append((len(kp1), len(kp2)))
            
            # 特征匹配
            if desc1 is not None and desc2 is not None and len(desc1) > 0 and len(desc2) > 0:
                # 深度学习描述符通常使用L2距离
                bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
                matches = bf.match(desc1, desc2)
                method_results['matching_accuracy'].append(len(matches))
        
        # 计算平均指标
        if method_results['detection_time']:
            comparison_results['deeplearning'][name] = {
                'avg_detection_time': np.mean(method_results['detection_time']),
                'avg_keypoints': np.mean([k[0] for k in method_results['keypoint_counts']]),
                'avg_matching_accuracy': np.mean(method_results['matching_accuracy']) if method_results['matching_accuracy'] else 0
            }
    
    return comparison_results

def plot_comparison_results(comparison_results):
    """
    绘制比较结果
    """
    # 提取数据
    traditional_names = list(comparison_results['traditional'].keys())
    deeplearning_names = list(comparison_results['deeplearning'].keys())
    
    all_names = traditional_names + deeplearning_names
    categories = ['传统'] * len(traditional_names) + ['深度学习'] * len(deeplearning_names)
    
    detection_times = []
    keypoint_counts = []
    matching_accuracies = []
    
    for name in traditional_names:
        detection_times.append(comparison_results['traditional'][name]['avg_detection_time'])
        keypoint_counts.append(comparison_results['traditional'][name]['avg_keypoints'])
        matching_accuracies.append(comparison_results['traditional'][name]['avg_matching_accuracy'])
    
    for name in deeplearning_names:
        detection_times.append(comparison_results['deeplearning'][name]['avg_detection_time'])
        keypoint_counts.append(comparison_results['deeplearning'][name]['avg_keypoints'])
        matching_accuracies.append(comparison_results['deeplearning'][name]['avg_matching_accuracy'])
    
    # 创建子图
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # 检测时间对比
    colors = ['skyblue' if cat == '传统' else 'lightcoral' for cat in categories]
    axes[0, 0].bar(all_names, detection_times, color=colors)
    axes[0, 0].set_title('平均检测时间对比')
    axes[0, 0].set_ylabel('时间(秒)')
    axes[0, 0].tick_params(axis='x', rotation=45)
    axes[0, 0].axhline(y=np.mean(detection_times[:len(traditional_names)]), 
                      color='blue', linestyle='--', alpha=0.5, label='传统方法平均')
    axes[0, 0].axhline(y=np.mean(detection_times[len(traditional_names):]), 
                      color='red', linestyle='--', alpha=0.5, label='深度学习方法平均')
    axes[0, 0].legend()
    
    # 关键点数量对比
    axes[0, 1].bar(all_names, keypoint_counts, color=colors)
    axes[0, 1].set_title('平均关键点数量对比')
    axes[0, 1].set_ylabel('数量')
    axes[0, 1].tick_params(axis='x', rotation=45)
    
    # 匹配精度对比
    axes[1, 0].bar(all_names, matching_accuracies, color=colors)
    axes[1, 0].set_title('平均匹配精度对比')
    axes[1, 0].set_ylabel('匹配对数')
    axes[1, 0].tick_params(axis='x', rotation=45)
    
    # 综合雷达图
    axes[1, 1].axis('off')
    
    # 为每个方法创建归一化的雷达图数据
    radar_fig, radar_ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(projection='polar'))
    
    # 归一化数据
    norm_detection_times = 1 - (detection_times - np.min(detection_times)) / (np.max(detection_times) - np.min(detection_times) + 1e-8)
    norm_keypoint_counts = (keypoint_counts - np.min(keypoint_counts)) / (np.max(keypoint_counts) - np.min(keypoint_counts) + 1e-8)
    norm_matching_accuracies = (matching_accuracies - np.min(matching_accuracies)) / (np.max(matching_accuracies) - np.min(matching_accuracies) + 1e-8)
    
    # 角度
    angles = np.linspace(0, 2*np.pi, 3, endpoint=False).tolist()
    angles += angles[:1]  # 闭合图形
    
    # 绘制每个方法的雷达图
    for idx, name in enumerate(all_names):
        values = [norm_detection_times[idx], norm_keypoint_counts[idx], norm_matching_accuracies[idx]]
        values += values[:1]  # 闭合图形
        
        color = 'blue' if categories[idx] == '传统' else 'red'
        radar_ax.plot(angles, values, 'o-', linewidth=2, label=name, color=color, alpha=0.7)
        radar_ax.fill(angles, values, alpha=0.1, color=color)
    
    radar_ax.set_xticks(angles[:-1])
    radar_ax.set_xticklabels(['速度(逆)', '关键点数量', '匹配精度'])
    radar_ax.set_title('方法性能雷达图')
    radar_ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0))
    
    plt.tight_layout()
    plt.show()

3.2 鲁棒性测试

为了测试不同方法在不同变换下的鲁棒性,我们设计了以下实验:

python

def robustness_test(image, methods):
    """
    测试不同方法在各种图像变换下的鲁棒性
    
    参数:
        image: 原始图像
        methods: 方法字典 {名称: 处理函数}
        
    返回:
        robustness_results: 鲁棒性测试结果
    """
    transformations = {
        '原始': lambda img: img,
        '旋转30度': lambda img: cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE),
        '缩放0.5倍': lambda img: cv2.resize(img, None, fx=0.5, fy=0.5),
        '高斯噪声': lambda img: add_gaussian_noise(img, mean=0, sigma=25),
        '亮度变化': lambda img: adjust_brightness(img, factor=1.5),
        '模糊': lambda img: cv2.GaussianBlur(img, (5, 5), 1.5),
        'JPEG压缩': lambda img: jpeg_compression(img, quality=50)
    }
    
    robustness_results = {}
    
    for method_name, method_func in methods.items():
        print(f"测试方法: {method_name}")
        
        method_results = {}
        
        # 在原始图像上检测关键点
        orig_keypoints, orig_descriptors = method_func(image)
        
        for transform_name, transform_func in transformations.items():
            print(f"  变换: {transform_name}")
            
            # 应用变换
            transformed_image = transform_func(image.copy())
            
            # 检测关键点
            trans_keypoints, trans_descriptors = method_func(transformed_image)
            
            # 计算重复性
            repeatability = calculate_repeatability(orig_keypoints, trans_keypoints, 
                                                   orig_image=image, trans_image=transformed_image)
            
            # 计算描述符匹配率
            if orig_descriptors is not None and trans_descriptors is not None:
                match_rate = calculate_match_rate(orig_descriptors, trans_descriptors)
            else:
                match_rate = 0
            
            method_results[transform_name] = {
                'repeatability': repeatability,
                'match_rate': match_rate,
                'keypoint_count': len(trans_keypoints)
            }
        
        robustness_results[method_name] = method_results
    
    return robustness_results

def add_gaussian_noise(image, mean=0, sigma=25):
    """添加高斯噪声"""
    gauss = np.random.normal(mean, sigma, image.shape).astype(np.float32)
    noisy_image = image.astype(np.float32) + gauss
    return np.clip(noisy_image, 0, 255).astype(np.uint8)

def adjust_brightness(image, factor=1.5):
    """调整亮度"""
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) if len(image.shape) == 3 else image
    hsv = hsv.astype(np.float32)
    hsv[..., 2] = hsv[..., 2] * factor
    hsv = np.clip(hsv, 0, 255).astype(np.uint8)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) if len(image.shape) == 3 else hsv

def jpeg_compression(image, quality=50):
    """JPEG压缩"""
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    result, encimg = cv2.imencode('.jpg', image, encode_param)
    return cv2.imdecode(encimg, 1)

def calculate_repeatability(keypoints1, keypoints2, orig_image=None, trans_image=None, threshold=3):
    """
    计算关键点重复性
    
    参数:
        keypoints1: 第一幅图像的关键点
        keypoints2: 第二幅图像的关键点
        threshold: 距离阈值(像素)
        
    返回:
        repeatability: 重复性分数
    """
    if not keypoints1 or not keypoints2:
        return 0
    
    # 提取关键点坐标
    pts1 = np.array([kp.pt for kp in keypoints1])
    pts2 = np.array([kp.pt for kp in keypoints2])
    
    # 如果有图像,考虑变换
    if orig_image is not None and trans_image is not None:
        # 这里可以添加几何变换的考虑
        pass
    
    # 计算最近邻距离
    from scipy.spatial import KDTree
    if len(pts2) > 0:
        tree = KDTree(pts2)
        distances, _ = tree.query(pts1)
        
        # 计算重复关键点数量
        repeatable_count = np.sum(distances < threshold)
        repeatability = repeatable_count / len(keypoints1)
    else:
        repeatability = 0
    
    return repeatability

def calculate_match_rate(descriptors1, descriptors2, ratio_threshold=0.8):
    """
    计算描述符匹配率
    
    参数:
        descriptors1: 第一幅图像的描述符
        descriptors2: 第二幅图像的描述符
        ratio_threshold: 比率测试阈值
        
    返回:
        match_rate: 匹配率
    """
    if len(descriptors1) == 0 or len(descriptors2) == 0:
        return 0
    
    # 使用FLANN匹配器
    FLANN_INDEX_KDTREE = 1
    index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
    search_params = dict(checks=50)
    
    flann = cv2.FlannBasedMatcher(index_params, search_params)
    matches = flann.knnMatch(descriptors1, descriptors2, k=2)
    
    # 应用Lowe's比率测试
    good_matches = []
    for m, n in matches:
        if m.distance < ratio_threshold * n.distance:
            good_matches.append(m)
    
    match_rate = len(good_matches) / min(len(descriptors1), len(descriptors2))
    
    return match_rate

def plot_robustness_results(robustness_results):
    """
    绘制鲁棒性测试结果
    """
    method_names = list(robustness_results.keys())
    transform_names = list(robustness_results[method_names[0]].keys())
    
    # 创建子图
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # 重复性热图
    repeatability_matrix = np.zeros((len(method_names), len(transform_names)))
    for i, method in enumerate(method_names):
        for j, transform in enumerate(transform_names):
            repeatability_matrix[i, j] = robustness_results[method][transform]['repeatability']
    
    im1 = axes[0, 0].imshow(repeatability_matrix, cmap='RdYlGn', vmin=0, vmax=1)
    axes[0, 0].set_xticks(range(len(transform_names)))
    axes[0, 0].set_xticklabels(transform_names, rotation=45, ha='right')
    axes[0, 0].set_yticks(range(len(method_names)))
    axes[0, 0].set_yticklabels(method_names)
    axes[0, 0].set_title('关键点重复性热图')
    plt.colorbar(im1, ax=axes[0, 0])
    
    # 匹配率热图
    match_rate_matrix = np.zeros((len(method_names), len(transform_names)))
    for i, method in enumerate(method_names):
        for j, transform in enumerate(transform_names):
            match_rate_matrix[i, j] = robustness_results[method][transform]['match_rate']
    
    im2 = axes[0, 1].imshow(match_rate_matrix, cmap='RdYlGn', vmin=0, vmax=1)
    axes[0, 1].set_xticks(range(len(transform_names)))
    axes[0, 1].set_xticklabels(transform_names, rotation=45, ha='right')
    axes[0, 1].set_yticks(range(len(method_names)))
    axes[0, 1].set_yticklabels(method_names)
    axes[0, 1].set_title('描述符匹配率热图')
    plt.colorbar(im2, ax=axes[0, 1])
    
    # 关键点数量柱状图
    keypoint_counts = {}
    for method in method_names:
        keypoint_counts[method] = [robustness_results[method][t]['keypoint_count'] for t in transform_names]
    
    x = np.arange(len(transform_names))
    width = 0.8 / len(method_names)
    
    for i, method in enumerate(method_names):
        offset = (i - len(method_names)/2 + 0.5) * width
        axes[1, 0].bar(x + offset, keypoint_counts[method], width, label=method)
    
    axes[1, 0].set_xticks(x)
    axes[1, 0].set_xticklabels(transform_names, rotation=45, ha='right')
    axes[1, 0].set_ylabel('关键点数量')
    axes[1, 0].set_title('不同变换下的关键点数量')
    axes[1, 0].legend()
    
    # 综合性能雷达图
    axes[1, 1].axis('off')
    
    # 创建雷达图
    radar_fig, radar_ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(projection='polar'))
    
    angles = np.linspace(0, 2*np.pi, len(transform_names), endpoint=False).tolist()
    angles += angles[:1]  # 闭合图形
    
    colors = plt.cm.tab10(np.linspace(0, 1, len(method_names)))
    
    for idx, method in enumerate(method_names):
        values = [robustness_results[method][t]['repeatability'] for t in transform_names]
        values += values[:1]  # 闭合图形
        
        radar_ax.plot(angles, values, 'o-', linewidth=2, label=method, color=colors[idx], alpha=0.7)
        radar_ax.fill(angles, values, alpha=0.1, color=colors[idx])
    
    radar_ax.set_xticks(angles[:-1])
    radar_ax.set_xticklabels(transform_names)
    radar_ax.set_title('方法鲁棒性雷达图')
    radar_ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0))
    
    plt.tight_layout()
    plt.show()

第四部分:传统方法与深度学习方法的融合

4.1 混合特征点系统

将传统方法的效率与深度学习方法的鲁棒性相结合,可以构建更强大的特征点系统:

python

class HybridFeatureDetector:
    """
    混合特征检测器:结合传统方法和深度学习方法
    """
    def __init__(self, traditional_detector, deeplearning_detector, fusion_strategy='weighted'):
        """
        初始化混合特征检测器
        
        参数:
            traditional_detector: 传统特征检测器
            deeplearning_detector: 深度学习特征检测器
            fusion_strategy: 融合策略 ('weighted', 'adaptive', 'cascade')
        """
        self.traditional_detector = traditional_detector
        self.deeplearning_detector = deeplearning_detector
        self.fusion_strategy = fusion_strategy
        
    def detect_and_compute(self, image):
        """
        检测关键点和计算描述符
        
        参数:
            image: 输入图像
            
        返回:
            keypoints: 融合后的关键点
            descriptors: 融合后的描述符
        """
        # 传统方法检测
        trad_keypoints, trad_descriptors = self.traditional_detector.detectAndCompute(image, None)
        
        # 深度学习方法检测
        if isinstance(self.deeplearning_detector, tuple):
            # 如果是(模型, 处理函数)的元组
            model, processor = self.deeplearning_detector
            dl_keypoints, dl_descriptors = processor(image, model)
        else:
            # 如果是直接可调用的函数
            dl_keypoints, dl_descriptors = self.deeplearning_detector(image)
        
        # 根据融合策略合并结果
        if self.fusion_strategy == 'weighted':
            keypoints, descriptors = self._weighted_fusion(trad_keypoints, trad_descriptors,
                                                         dl_keypoints, dl_descriptors)
        elif self.fusion_strategy == 'adaptive':
            keypoints, descriptors = self._adaptive_fusion(trad_keypoints, trad_descriptors,
                                                         dl_keypoints, dl_descriptors, image)
        elif self.fusion_strategy == 'cascade':
            keypoints, descriptors = self._cascade_fusion(trad_keypoints, trad_descriptors,
                                                        dl_keypoints, dl_descriptors)
        else:
            raise ValueError(f"未知的融合策略: {self.fusion_strategy}")
        
        return keypoints, descriptors
    
    def _weighted_fusion(self, trad_kps, trad_descs, dl_kps, dl_descs, trad_weight=0.4):
        """
        加权融合策略
        
        参数:
            trad_kps: 传统方法关键点
            trad_descs: 传统方法描述符
            dl_kps: 深度学习方法关键点
            dl_descs: 深度学习方法描述符
            trad_weight: 传统方法权重
            
        返回:
            fused_keypoints: 融合后的关键点
            fused_descriptors: 融合后的描述符
        """
        # 如果没有关键点,直接返回空
        if not trad_kps and not dl_kps:
            return [], None
        
        # 加权选择关键点
        fused_keypoints = []
        fused_descriptors = []
        
        # 传统方法关键点
        if trad_kps and trad_descs is not None:
            for i, kp in enumerate(trad_kps):
                # 根据权重随机选择
                if np.random.rand() < trad_weight:
                    fused_keypoints.append(kp)
                    fused_descriptors.append(trad_descs[i])
        
        # 深度学习方法关键点
        if dl_kps and dl_descs is not None:
            for i, kp in enumerate(dl_kps):
                # 根据权重随机选择
                if np.random.rand() < (1 - trad_weight):
                    fused_keypoints.append(kp)
                    fused_descriptors.append(dl_descs[i])
        
        # 如果都没有选择到,至少选择一个
        if not fused_keypoints:
            if trad_kps:
                fused_keypoints.append(trad_kps[0])
                if trad_descs is not None:
                    fused_descriptors.append(trad_descs[0])
            elif dl_kps:
                fused_keypoints.append(dl_kps[0])
                if dl_descs is not None:
                    fused_descriptors.append(dl_descs[0])
        
        if fused_descriptors:
            fused_descriptors = np.array(fused_descriptors)
        else:
            fused_descriptors = None
        
        return fused_keypoints, fused_descriptors
    
    def _adaptive_fusion(self, trad_kps, trad_descs, dl_kps, dl_descs, image):
        """
        自适应融合策略:根据图像特性选择融合方式
        
        参数:
            trad_kps: 传统方法关键点
            trad_descs: 传统方法描述符
            dl_kps: 深度学习方法关键点
            dl_descs: 深度学习方法描述符
            image: 输入图像
            
        返回:
            fused_keypoints: 融合后的关键点
            fused_descriptors: 融合后的描述符
        """
        # 分析图像特性
        image_characteristics = self._analyze_image_characteristics(image)
        
        # 根据图像特性调整融合策略
        if image_characteristics['low_contrast']:
            # 低对比度图像,更依赖深度学习方法
            trad_weight = 0.2
        elif image_characteristics['high_frequency']:
            # 高频纹理丰富,更依赖传统方法
            trad_weight = 0.7
        elif image_characteristics['blurry']:
            # 模糊图像,更依赖深度学习方法
            trad_weight = 0.3
        else:
            # 正常情况,均衡融合
            trad_weight = 0.5
        
        return self._weighted_fusion(trad_kps, trad_descs, dl_kps, dl_descs, trad_weight)
    
    def _analyze_image_characteristics(self, image):
        """
        分析图像特性
        
        参数:
            image: 输入图像
            
        返回:
            characteristics: 图像特性字典
        """
        characteristics = {}
        
        # 计算图像对比度
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image
        
        # 计算对比度(标准差)
        contrast = np.std(gray)
        characteristics['low_contrast'] = contrast < 30
        
        # 计算高频成分(通过拉普拉斯算子)
        laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()
        characteristics['high_frequency'] = laplacian_var > 1000
        
        # 估计模糊程度
        characteristics['blurry'] = laplacian_var < 100
        
        return characteristics
    
    def _cascade_fusion(self, trad_kps, trad_descs, dl_kps, dl_descs):
        """
        级融合策略:先用传统方法,如果不够再用深度学习方法补充
        
        参数:
            trad_kps: 传统方法关键点
            trad_descs: 传统方法描述符
            dl_kps: 深度学习方法关键点
            dl_descs: 深度学习方法描述符
            
        返回:
            fused_keypoints: 融合后的关键点
            fused_descriptors: 融合后的描述符
        """
        fused_keypoints = list(trad_kps)
        fused_descriptors = []
        
        # 添加传统方法描述符
        if trad_descs is not None:
            for i in range(len(trad_kps)):
                fused_descriptors.append(trad_descs[i])
        
        # 如果传统方法关键点太少,用深度学习方法补充
        min_keypoints = 50
        if len(fused_keypoints) < min_keypoints and dl_kps:
            # 选择深度学习方法的Top-N关键点
            num_to_add = min_keypoints - len(fused_keypoints)
            
            # 按响应值排序
            if hasattr(dl_kps[0], 'response'):
                sorted_indices = np.argsort([-kp.response for kp in dl_kps])
            else:
                sorted_indices = range(min(num_to_add, len(dl_kps)))
            
            for idx in sorted_indices[:num_to_add]:
                kp = dl_kps[idx]
                # 检查是否与现有关键点太近
                too_close = False
                for existing_kp in fused_keypoints:
                    dist = np.sqrt((kp.pt[0] - existing_kp.pt[0])**2 + 
                                 (kp.pt[1] - existing_kp.pt[1])**2)
                    if dist < 10:  # 10像素阈值
                        too_close = True
                        break
                
                if not too_close:
                    fused_keypoints.append(kp)
                    if dl_descs is not None:
                        fused_descriptors.append(dl_descs[idx])
        
        if fused_descriptors:
            fused_descriptors = np.array(fused_descriptors)
        else:
            fused_descriptors = None
        
        return fused_keypoints, fused_descriptors

4.2 深度学习增强的传统特征点

使用深度学习来增强传统特征点的检测和描述能力:

python

class DeepEnhancedORB:
    """
    深度学习增强的ORB特征点
    """
    def __init__(self, enhancement_model=None):
        """
        初始化深度学习增强的ORB
        
        参数:
            enhancement_model: 用于增强的深度学习模型
        """
        self.orb = cv2.ORB_create(nfeatures=1000)
        self.enhancement_model = enhancement_model
        
    def detectAndCompute(self, image, mask=None):
        """
        检测关键点和计算描述符
        
        参数:
            image: 输入图像
            mask: 掩码
            
        返回:
            keypoints: 关键点
            descriptors: 描述符
        """
        # 使用传统ORB检测关键点
        keypoints = self.orb.detect(image, mask)
        
        # 如果有增强模型,优化关键点位置
        if self.enhancement_model is not None:
            keypoints = self._enhance_keypoints(image, keypoints)
        
        # 计算描述符
        keypoints, descriptors = self.orb.compute(image, keypoints)
        
        # 如果有增强模型,优化描述符
        if self.enhancement_model is not None:
            descriptors = self._enhance_descriptors(image, keypoints, descriptors)
        
        return keypoints, descriptors
    
    def _enhance_keypoints(self, image, keypoints):
        """
        使用深度学习模型增强关键点
        
        参数:
            image: 输入图像
            keypoints: 原始关键点
            
        返回:
            enhanced_keypoints: 增强后的关键点
        """
        enhanced_keypoints = []
        
        # 将图像转换为模型输入格式
        if self.enhancement_model is not None:
            # 这里假设enhancement_model是一个可以预测关键点位置的模型
            # 实际实现需要根据具体模型调整
            pass
        
        # 如果没有模型或处理失败,返回原始关键点
        if not enhanced_keypoints:
            return keypoints
        
        return enhanced_keypoints
    
    def _enhance_descriptors(self, image, keypoints, descriptors):
        """
        使用深度学习模型增强描述符
        
        参数:
            image: 输入图像
            keypoints: 关键点
            descriptors: 原始描述符
            
        返回:
            enhanced_descriptors: 增强后的描述符
        """
        if self.enhancement_model is not None:
            # 这里可以添加深度学习增强描述符的代码
            # 例如,使用CNN提取更强大的特征表示
            pass
        
        # 暂时返回原始描述符
        return descriptors

class NeuralFeatureRefiner:
    """
    神经网络特征优化器:使用CNN优化传统特征点
    """
    def __init__(self):
        # 初始化神经网络
        self.model = self._build_refinement_network()
        
    def _build_refinement_network(self):
        """
        构建特征优化网络
        
        返回:
            model: 优化网络模型
        """
        model = nn.Sequential(
            # 输入: 局部图像块
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            
            nn.MaxPool2d(kernel_size=2),
            
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            # 输出: 关键点位置偏移和置信度
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 3)  # (dx, dy, confidence)
        )
        
        return model
    
    def refine_keypoints(self, image, keypoints, patch_size=32):
        """
        优化关键点位置
        
        参数:
            image: 输入图像
            keypoints: 原始关键点
            patch_size: 图像块大小
            
        返回:
            refined_keypoints: 优化后的关键点
        """
        refined_keypoints = []
        
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image
        
        gray_tensor = torch.from_numpy(gray).float().unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
        
        for kp in keypoints:
            x, y = int(kp.pt[0]), int(kp.pt[1])
            
            # 提取局部图像块
            half_size = patch_size // 2
            x_start = max(0, x - half_size)
            x_end = min(gray.shape[1], x + half_size)
            y_start = max(0, y - half_size)
            y_end = min(gray.shape[0], y + half_size)
            
            # 如果关键点太靠近边缘,跳过
            if (x_end - x_start) < patch_size or (y_end - y_start) < patch_size:
                refined_keypoints.append(kp)
                continue
            
            patch = gray[y_start:y_end, x_start:x_end]
            
            # 调整大小为固定尺寸
            patch_resized = cv2.resize(patch, (patch_size, patch_size))
            patch_tensor = torch.from_numpy(patch_resized).float().unsqueeze(0).unsqueeze(0)  # [1, 1, patch_size, patch_size]
            
            # 使用模型预测偏移
            with torch.no_grad():
                output = self.model(patch_tensor)
                dx, dy, confidence = output[0].numpy()
            
            # 应用偏移(根据置信度加权)
            scale_factor = 2.0  # 偏移缩放因子
            new_x = x + dx * scale_factor * confidence
            new_y = y + dy * scale_factor * confidence
            
            # 创建新的关键点
            refined_kp = cv2.KeyPoint(new_x, new_y, kp.size, kp.angle, kp.response * confidence, kp.octave, kp.class_id)
            refined_keypoints.append(refined_kp)
        
        return refined_keypoints

4.3 应用示例:混合特征点的图像拼接

python

def hybrid_feature_matching(img1, img2, hybrid_detector):
    """
    使用混合特征检测器进行图像匹配
    
    参数:
        img1: 第一幅图像
        img2: 第二幅图像
        hybrid_detector: 混合特征检测器
        
    返回:
        matches: 特征匹配
        homography: 单应性矩阵
        matched_image: 匹配结果可视化
    """
    # 检测关键点和描述符
    kp1, desc1 = hybrid_detector.detectAndCompute(img1, None)
    kp2, desc2 = hybrid_detector.detectAndCompute(img2, None)
    
    # 特征匹配
    if desc1 is not None and desc2 is not None and len(desc1) > 0 and len(desc2) > 0:
        # 根据描述符类型选择匹配器
        if desc1.dtype == np.uint8:  # 二进制描述符
            bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
            matches = bf.match(desc1, desc2)
        else:  # 浮点描述符
            bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
            matches = bf.match(desc1.astype(np.float32), desc2.astype(np.float32))
        
        # 按距离排序
        matches = sorted(matches, key=lambda x: x.distance)
        
        # 计算单应性矩阵
        if len(matches) > 10:
            src_pts = np.float32([kp1[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
            dst_pts = np.float32([kp2[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
            
            # 使用RANSAC计算单应性矩阵
            M, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
            
            # 绘制匹配结果
            matches_mask = mask.ravel().tolist()
            draw_params = dict(matchColor=(0, 255, 0),
                             singlePointColor=None,
                             matchesMask=matches_mask,
                             flags=2)
            
            matched_image = cv2.drawMatches(img1, kp1, img2, kp2, 
                                          matches, None, **draw_params)
            
            return matches, M, matched_image
    
    return None, None, None

def image_stitching(images, hybrid_detector):
    """
    使用混合特征进行图像拼接
    
    参数:
        images: 图像列表
        hybrid_detector: 混合特征检测器
        
    返回:
        panorama: 拼接后的全景图
    """
    if len(images) < 2:
        return images[0] if images else None
    
    # 将第一幅图像作为基准
    panorama = images[0].copy()
    
    for i in range(1, len(images)):
        # 特征匹配
        matches, homography, _ = hybrid_feature_matching(panorama, images[i], hybrid_detector)
        
        if homography is not None:
            # 计算新全景图的尺寸
            h1, w1 = panorama.shape[:2]
            h2, w2 = images[i].shape[:2]
            
            # 计算变换后图像的角点
            corners1 = np.float32([[0, 0], [0, h1], [w1, h1], [w1, 0]]).reshape(-1, 1, 2)
            corners2 = np.float32([[0, 0], [0, h2], [w2, h2], [w2, 0]]).reshape(-1, 1, 2)
            corners2_transformed = cv2.perspectiveTransform(corners2, homography)
            
            # 合并角点
            all_corners = np.concatenate((corners1, corners2_transformed), axis=0)
            
            # 计算新图像尺寸
            [x_min, y_min] = np.int32(all_corners.min(axis=0).ravel() - 0.5)
            [x_max, y_max] = np.int32(all_corners.max(axis=0).ravel() + 0.5)
            
            # 计算平移矩阵
            translation = np.array([[1, 0, -x_min],
                                   [0, 1, -y_min],
                                   [0, 0, 1]])
            
            # 应用平移
            panorama_transformed = cv2.warpPerspective(panorama, translation, 
                                                      (x_max - x_min, y_max - y_min))
            
            # 变换第二幅图像
            img2_transformed = cv2.warpPerspective(images[i], translation.dot(homography),
                                                  (x_max - x_min, y_max - y_min))
            
            # 图像融合
            mask1 = panorama_transformed > 0
            mask2 = img2_transformed > 0
            
            # 简单叠加(实际应用中可以使用更复杂的融合方法)
            panorama_result = panorama_transformed.copy()
            panorama_result[mask2] = img2_transformed[mask2]
            
            # 重叠区域混合
            overlap = mask1 & mask2
            if np.any(overlap):
                # 简单平均混合
                panorama_result[overlap] = (panorama_transformed[overlap].astype(np.float32) * 0.5 + 
                                          img2_transformed[overlap].astype(np.float32) * 0.5).astype(np.uint8)
            
            panorama = panorama_result
    
    return panorama

# 使用示例
def demo_hybrid_feature_stitching():
    """
    演示混合特征图像拼接
    """
    # 读取图像
    images = []
    for i in range(1, 4):  # 假设有3张图像
        img = cv2.imread(f'image{i}.jpg')
        if img is not None:
            images.append(img)
    
    if len(images) < 2:
        print("需要至少2张图像进行拼接")
        return
    
    # 创建混合特征检测器
    traditional_detector = cv2.ORB_create(nfeatures=1000)
    
    # 注意:这里需要实际加载深度学习模型
    # deeplearning_detector = load_deeplearning_model()
    # 为了演示,我们使用一个简单的替代
    deeplearning_detector = None
    
    hybrid_detector = HybridFeatureDetector(traditional_detector, deeplearning_detector)
    
    # 图像拼接
    panorama = image_stitching(images, hybrid_detector)
    
    # 显示结果
    plt.figure(figsize=(15, 10))
    
    # 显示原始图像
    for i, img in enumerate(images):
        plt.subplot(2, len(images), i+1)
        plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        plt.title(f'Image {i+1}')
        plt.axis('off')
    
    # 显示拼接结果
    plt.subplot(2, 1, 2)
    if panorama is not None:
        plt.imshow(cv2.cvtColor(panorama, cv2.COLOR_BGR2RGB))
        plt.title('拼接结果')
    else:
        plt.text(0.5, 0.5, '拼接失败', horizontalalignment='center',
                verticalalignment='center', transform=plt.gca().transAxes)
    plt.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # 保存结果
    if panorama is not None:
        cv2.imwrite('panorama_result.jpg', panorama)
        print("拼接结果已保存为 panorama_result.jpg")

第五部分:未来发展趋势与总结

5.1 特征点技术的发展趋势

  1. 更强的表示能力:未来的特征点方法将具有更强的表示能力,能够处理更复杂的场景和变换。

  2. 更高的效率:随着硬件的发展和算法优化,特征点检测和匹配将更加高效,满足实时应用需求。

  3. 更好的鲁棒性:对光照变化、视角变化、遮挡等具有更好的鲁棒性。

  4. 多模态融合:结合RGB、深度、红外等多种传感器信息,提高特征点的可靠性和准确性。

  5. 语义理解:结合语义分割和场景理解,提取更具语义意义的特征点。

5.2 传统与深度学习方法的融合趋势

  1. 优势互补:传统方法的高效性与深度学习方法的强表示能力相结合。

  2. 自适应选择:根据场景特性自动选择最合适的特征点方法。

  3. 联合优化:端到端的联合优化传统方法和深度学习方法。

  4. 知识蒸馏:使用深度学习模型指导传统方法的优化。

5.3 总结

传统特征点方法和深度学习关键点方法各有优势和局限性:

传统方法的优势

  • 计算效率高,适合实时应用

  • 原理明确,可解释性强

  • 不需要大量训练数据

  • 在特定条件下性能稳定

深度学习方法的优势

  • 表示能力强,能处理复杂场景

  • 端到端学习,可以直接优化目标任务

  • 对光照变化、视角变化等具有更好的鲁棒性

  • 可以结合上下文信息

融合方向

  1. 使用深度学习优化传统特征点的检测和描述

  2. 结合传统方法的效率和深度学习方法的鲁棒性

  3. 根据场景特性自适应选择特征点方法

  4. 构建多尺度、多模态的特征点系统

5.4 实际应用建议

  1. 实时应用:优先考虑传统方法(如ORB)或轻量级深度学习方法。

  2. 高精度要求:考虑使用深度学习方法(如SuperPoint、D2-Net)。

  3. 复杂场景:使用混合方法或自适应选择策略。

  4. 资源受限:考虑传统方法或模型压缩后的深度学习方法。

  5. 特定领域:根据领域特点选择或定制特征点方法。

5.5 代码资源与进一步学习

  1. OpenCV:提供了丰富的传统特征点实现。

  2. PyTorch/ TensorFlow:深度学习框架,用于实现和训练关键点检测模型。

  3. Kornia:基于PyTorch的计算机视觉库,包含特征点检测模块。

  4. GitHub项目:SuperPoint、D2-Net、LF-Net等开源实现。

  5. 学术论文:关注CVPR、ICCV、ECCV等顶级会议的最新研究成果。

特征点技术作为计算机视觉的基础,仍在不断发展和演进。传统方法与深度学习方法的融合将为计算机视觉应用带来新的可能性,推动图像匹配、三维重建、视觉定位等领域的进一步发展。


参考文献

  1. Lowe, D. G. (2004). Distinctive image features from scale-invariant keypoints. International journal of computer vision, 60(2), 91-110.

  2. Bay, H., Tuytelaars, T., & Van Gool, L. (2006). Surf: Speeded up robust features. In European conference on computer vision (pp. 404-417).

  3. Rublee, E., Rabaud, V., Konolige, K., & Bradski, G. (2011). ORB: An efficient alternative to SIFT or SURF. In International conference on computer vision (pp. 2564-2571).

  4. DeTone, D., Malisiewicz, T., & Rabinovich, A. (2018). Superpoint: Self-supervised interest point detection and description. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops (pp. 224-236).

  5. Dusmanu, M., Rocco, I., Pajdla, T., Pollefeys, M., Sivic, J., Torii, A., & Sattler, T. (2019). D2-net: A trainable cnn for joint detection and description of local features. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 8092-8101).

注意:本文中的代码示例需要适当的环境配置和依赖库,部分深度学习模型需要预训练权重。实际应用中需要根据具体需求进行调整和优化。

完整增強版代碼 一個文件完成

我们將创建一个完整的增强版代码文件,包含以下内容:

  1. 导入必要的库

  2. 传统特征点检测器(SIFT, SURF, ORB, AKAZE, BRISK)的封装

  3. 深度学习关键点检测器(SuperPoint, D2-Net)的模型定义和封装(由于预训练模型需要下载,这里将提供模型定义和加载代码,但实际使用需要下载权重)

  4. 混合特征点检测器的实现

  5. 评估函数和可视化函数

  6. 示例使用代码

注意:由于深度学习模型需要预训练权重,我们将提供模型定义,但权重需要用户自行下载。同时,我们将使用PyTorch来实现深度学习模型。

由于代码量很大,我们将尽量保持结构清晰,并添加详细注释。

我们将创建一个类,名为FeatureDetector,它可以根据配置选择不同的特征检测方法,包括传统方法和深度学习方法。

另外,我们还将提供一些工具函数,用于评估和可视化特征点。

由于篇幅限制,我们可能不会实现所有细节,但会提供一个完整的框架。

注意:SURF在OpenCV的默认版本中可能不可用,需要opencv-contrib-python。

开始编写代码:

完整增強版代碼:傳統與深度學習特徵點融合系統

python

"""
特徵點檢測與匹配完整系統:傳統方法與深度學習的融合
版本:2.0
作者:計算機視覺工程師
日期:2023年10月
"""

import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
from scipy.spatial import KDTree
from scipy.ndimage import maximum_filter, gaussian_filter
from typing import List, Tuple, Dict, Any, Optional, Union
import time
import warnings
import logging
from dataclasses import dataclass
from enum import Enum
import json
from pathlib import Path

# 深度學習相關導入
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    warnings.warn("PyTorch未安裝,深度學習功能將不可用")

# 設置日誌
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# 枚舉定義
class FeatureType(Enum):
    """特徵點類型枚舉"""
    SIFT = "SIFT"
    SURF = "SURF"
    ORB = "ORB"
    AKAZE = "AKAZE"
    BRISK = "BRISK"
    SUPERPOINT = "SuperPoint"
    D2NET = "D2-Net"
    HYBRID = "Hybrid"
    CUSTOM = "Custom"

class FusionStrategy(Enum):
    """融合策略枚舉"""
    WEIGHTED = "weighted"
    ADAPTIVE = "adaptive"
    CASCADE = "cascade"
    ENSEMBLE = "ensemble"

# 數據類定義
@dataclass
class FeaturePoint:
    """特徵點數據類"""
    x: float
    y: float
    size: float = 1.0
    angle: float = -1.0
    response: float = 0.0
    octave: int = 0
    class_id: int = -1
    descriptor: Optional[np.ndarray] = None
    
    def to_cv_keypoint(self) -> cv2.KeyPoint:
        """轉換為OpenCV KeyPoint對象"""
        return cv2.KeyPoint(
            x=self.x,
            y=self.y,
            size=self.size,
            angle=self.angle,
            response=self.response,
            octave=self.octave,
            class_id=self.class_id
        )
    
    @classmethod
    def from_cv_keypoint(cls, kp: cv2.KeyPoint, descriptor: Optional[np.ndarray] = None):
        """從OpenCV KeyPoint創建"""
        return cls(
            x=kp.pt[0],
            y=kp.pt[1],
            size=kp.size,
            angle=kp.angle,
            response=kp.response,
            octave=kp.octave,
            class_id=kp.class_id,
            descriptor=descriptor
        )

@dataclass
class MatchResult:
    """匹配結果數據類"""
    matches: List[cv2.DMatch]
    homography: Optional[np.ndarray] = None
    inlier_mask: Optional[np.ndarray] = None
    match_score: float = 0.0
    processing_time: float = 0.0

# ==================== 圖像預處理模塊 ====================

class ImagePreprocessor:
    """圖像預處理器"""
    
    @staticmethod
    def load_image(path: str, grayscale: bool = True) -> np.ndarray:
        """加載圖像"""
        try:
            if grayscale:
                img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            else:
                img = cv2.imread(path, cv2.IMREAD_COLOR)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if img is None:
                raise ValueError(f"無法加載圖像: {path}")
            return img
        except Exception as e:
            logger.error(f"加載圖像失敗: {e}")
            raise
    
    @staticmethod
    def resize_image(image: np.ndarray, max_dim: int = 1024) -> np.ndarray:
        """調整圖像大小,保持長寬比"""
        h, w = image.shape[:2]
        
        if max(h, w) <= max_dim:
            return image
        
        scale = max_dim / max(h, w)
        new_w = int(w * scale)
        new_h = int(h * scale)
        
        return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
    
    @staticmethod
    def normalize_image(image: np.ndarray) -> np.ndarray:
        """圖像歸一化"""
        if image.dtype == np.uint8:
            image = image.astype(np.float32) / 255.0
        
        # 去均值,除方差
        if len(image.shape) == 3:
            mean = np.mean(image, axis=(0, 1), keepdims=True)
            std = np.std(image, axis=(0, 1), keepdims=True)
        else:
            mean = np.mean(image)
            std = np.std(image)
        
        return (image - mean) / (std + 1e-8)
    
    @staticmethod
    def enhance_contrast(image: np.ndarray, method: str = 'clahe') -> np.ndarray:
        """增強圖像對比度"""
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            gray = image
        
        if method == 'clahe':
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            enhanced = clahe.apply(gray.astype(np.uint8))
        elif method == 'histogram':
            enhanced = cv2.equalizeHist(gray.astype(np.uint8))
        elif method == 'gamma':
            # Gamma校正
            gamma = 1.5
            enhanced = np.power(gray.astype(np.float32) / 255.0, gamma) * 255
            enhanced = enhanced.astype(np.uint8)
        else:
            enhanced = gray
        
        return enhanced
    
    @staticmethod
    def prepare_pyramid(image: np.ndarray, levels: int = 3, scale_factor: float = 0.5) -> List[np.ndarray]:
        """創建圖像金字塔"""
        pyramid = [image]
        for i in range(1, levels):
            h, w = pyramid[-1].shape[:2]
            new_w = int(w * scale_factor)
            new_h = int(h * scale_factor)
            resized = cv2.resize(pyramid[-1], (new_w, new_h), interpolation=cv2.INTER_AREA)
            pyramid.append(resized)
        return pyramid
    
    @staticmethod
    def extract_patches(image: np.ndarray, keypoints: List[FeaturePoint], patch_size: int = 32) -> List[np.ndarray]:
        """提取關鍵點周圍的圖像塊"""
        patches = []
        half_size = patch_size // 2
        
        for kp in keypoints:
            x, y = int(kp.x), int(kp.y)
            
            # 確保坐標在圖像範圍內
            x_start = max(0, x - half_size)
            x_end = min(image.shape[1], x + half_size)
            y_start = max(0, y - half_size)
            y_end = min(image.shape[0], y + half_size)
            
            # 提取圖像塊
            patch = image[y_start:y_end, x_start:x_end]
            
            # 如果圖像塊太小,跳過
            if patch.shape[0] < patch_size or patch.shape[1] < patch_size:
                # 調整大小到目標尺寸
                patch = cv2.resize(patch, (patch_size, patch_size), interpolation=cv2.INTER_AREA)
            
            patches.append(patch)
        
        return patches

# ==================== 傳統特徵點檢測器模塊 ====================

class TraditionalFeatureDetector:
    """傳統特徵點檢測器基類"""
    
    def __init__(self, config: Dict[str, Any] = None):
        self.config = config or {}
        self.detector = None
        self.initialized = False
        
    def initialize(self):
        """初始化檢測器"""
        raise NotImplementedError
    
    def detect_and_compute(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[List[FeaturePoint], Optional[np.ndarray]]:
        """檢測關鍵點並計算描述符"""
        raise NotImplementedError
    
    def get_config_summary(self) -> Dict[str, Any]:
        """獲取配置摘要"""
        return {
            "type": "Traditional",
            "name": self.__class__.__name__,
            "config": self.config
        }

class SIFTDetector(TraditionalFeatureDetector):
    """SIFT特徵點檢測器"""
    
    def initialize(self):
        try:
            self.detector = cv2.SIFT_create(
                nfeatures=self.config.get('nfeatures', 0),
                nOctaveLayers=self.config.get('nOctaveLayers', 3),
                contrastThreshold=self.config.get('contrastThreshold', 0.04),
                edgeThreshold=self.config.get('edgeThreshold', 10),
                sigma=self.config.get('sigma', 1.6)
            )
            self.initialized = True
            logger.info("SIFT檢測器初始化成功")
        except Exception as e:
            logger.error(f"SIFT檢測器初始化失敗: {e}")
    
    def detect_and_compute(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[List[FeaturePoint], Optional[np.ndarray]]:
        if not self.initialized:
            self.initialize()
        
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            gray = image
        
        # 檢測關鍵點和描述符
        keypoints, descriptors = self.detector.detectAndCompute(gray, mask)
        
        # 轉換為FeaturePoint對象
        feature_points = [FeaturePoint.from_cv_keypoint(kp) for kp in keypoints]
        
        # 為每個特徵點分配描述符
        if descriptors is not None:
            for i, fp in enumerate(feature_points):
                fp.descriptor = descriptors[i]
        
        return feature_points, descriptors
    
    def get_config_summary(self) -> Dict[str, Any]:
        summary = super().get_config_summary()
        summary.update({
            "algorithm": "SIFT",
            "descriptor_size": 128,
            "descriptor_type": "float32"
        })
        return summary

class SURFDetector(TraditionalFeatureDetector):
    """SURF特徵點檢測器"""
    
    def initialize(self):
        try:
            # 檢查是否有SURF支持
            if hasattr(cv2, 'xfeatures2d'):
                self.detector = cv2.xfeatures2d.SURF_create(
                    hessianThreshold=self.config.get('hessianThreshold', 100),
                    nOctaves=self.config.get('nOctaves', 4),
                    nOctaveLayers=self.config.get('nOctaveLayers', 3),
                    extended=self.config.get('extended', False),
                    upright=self.config.get('upright', False)
                )
                self.initialized = True
                logger.info("SURF檢測器初始化成功")
            else:
                logger.warning("OpenCV未包含SURF支持,請安裝opencv-contrib-python")
                self.detector = None
        except Exception as e:
            logger.error(f"SURF檢測器初始化失敗: {e}")
    
    def detect_and_compute(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[List[FeaturePoint], Optional[np.ndarray]]:
        if not self.initialized:
            self.initialize()
        
        if self.detector is None:
            return [], None
        
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            gray = image
        
        # 檢測關鍵點和描述符
        keypoints, descriptors = self.detector.detectAndCompute(gray, mask)
        
        # 轉換為FeaturePoint對象
        feature_points = [FeaturePoint.from_cv_keypoint(kp) for kp in keypoints]
        
        # 為每個特徵點分配描述符
        if descriptors is not None:
            for i, fp in enumerate(feature_points):
                fp.descriptor = descriptors[i]
        
        return feature_points, descriptors
    
    def get_config_summary(self) -> Dict[str, Any]:
        summary = super().get_config_summary()
        summary.update({
            "algorithm": "SURF",
            "descriptor_size": 64 if not self.config.get('extended', False) else 128,
            "descriptor_type": "float32"
        })
        return summary

class ORBDetector(TraditionalFeatureDetector):
    """ORB特徵點檢測器"""
    
    def initialize(self):
        try:
            self.detector = cv2.ORB_create(
                nfeatures=self.config.get('nfeatures', 500),
                scaleFactor=self.config.get('scaleFactor', 1.2),
                nlevels=self.config.get('nlevels', 8),
                edgeThreshold=self.config.get('edgeThreshold', 31),
                firstLevel=self.config.get('firstLevel', 0),
                WTA_K=self.config.get('WTA_K', 2),
                scoreType=self.config.get('scoreType', cv2.ORB_HARRIS_SCORE),
                patchSize=self.config.get('patchSize', 31),
                fastThreshold=self.config.get('fastThreshold', 20)
            )
            self.initialized = True
            logger.info("ORB檢測器初始化成功")
        except Exception as e:
            logger.error(f"ORB檢測器初始化失敗: {e}")
    
    def detect_and_compute(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[List[FeaturePoint], Optional[np.ndarray]]:
        if not self.initialized:
            self.initialize()
        
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            gray = image
        
        # 檢測關鍵點和描述符
        keypoints, descriptors = self.detector.detectAndCompute(gray, mask)
        
        # 轉換為FeaturePoint對象
        feature_points = [FeaturePoint.from_cv_keypoint(kp) for kp in keypoints]
        
        # 為每個特徵點分配描述符
        if descriptors is not None:
            for i, fp in enumerate(feature_points):
                fp.descriptor = descriptors[i]
        
        return feature_points, descriptors
    
    def get_config_summary(self) -> Dict[str, Any]:
        summary = super().get_config_summary()
        summary.update({
            "algorithm": "ORB",
            "descriptor_size": 32,
            "descriptor_type": "uint8"
        })
        return summary

class AKAZEDetector(TraditionalFeatureDetector):
    """AKAZE特徵點檢測器"""
    
    def initialize(self):
        try:
            self.detector = cv2.AKAZE_create(
                descriptor_type=self.config.get('descriptor_type', cv2.AKAZE_DESCRIPTOR_MLDB),
                descriptor_size=self.config.get('descriptor_size', 0),
                descriptor_channels=self.config.get('descriptor_channels', 3),
                threshold=self.config.get('threshold', 0.001),
                nOctaves=self.config.get('nOctaves', 4),
                nOctaveLayers=self.config.get('nOctaveLayers', 4),
                diffusivity=self.config.get('diffusivity', cv2.KAZE_DIFF_PM_G2)
            )
            self.initialized = True
            logger.info("AKAZE檢測器初始化成功")
        except Exception as e:
            logger.error(f"AKAZE檢測器初始化失敗: {e}")
    
    def detect_and_compute(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[List[FeaturePoint], Optional[np.ndarray]]:
        if not self.initialized:
            self.initialize()
        
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            gray = image
        
        # 檢測關鍵點和描述符
        keypoints, descriptors = self.detector.detectAndCompute(gray, mask)
        
        # 轉換為FeaturePoint對象
        feature_points = [FeaturePoint.from_cv_keypoint(kp) for kp in keypoints]
        
        # 為每個特徵點分配描述符
        if descriptors is not None:
            for i, fp in enumerate(feature_points):
                fp.descriptor = descriptors[i]
        
        return feature_points, descriptors
    
    def get_config_summary(self) -> Dict[str, Any]:
        summary = super().get_config_summary()
        summary.update({
            "algorithm": "AKAZE",
            "descriptor_size": 61,
            "descriptor_type": "uint8"
        })
        return summary

# ==================== 深度學習特徵點檢測器模塊 ====================

if TORCH_AVAILABLE:
    
    class BaseFeatureNetwork(nn.Module):
        """基礎特徵網絡"""
        
        def __init__(self, input_channels: int = 1, descriptor_dim: int = 256):
            super().__init__()
            self.input_channels = input_channels
            self.descriptor_dim = descriptor_dim
            
        def forward(self, x):
            raise NotImplementedError
        
        def detect_and_compute(self, image: np.ndarray, device: str = 'cpu'):
            """檢測關鍵點並計算描述符"""
            raise NotImplementedError
        
        def save_model(self, path: str):
            """保存模型"""
            torch.save({
                'model_state_dict': self.state_dict(),
                'config': self.get_config()
            }, path)
            logger.info(f"模型已保存到: {path}")
        
        def load_model(self, path: str, device: str = 'cpu'):
            """加載模型"""
            checkpoint = torch.load(path, map_location=device)
            self.load_state_dict(checkpoint['model_state_dict'])
            self.eval()
            logger.info(f"模型已從加載: {path}")
        
        def get_config(self) -> Dict[str, Any]:
            """獲取模型配置"""
            return {
                'input_channels': self.input_channels,
                'descriptor_dim': self.descriptor_dim
            }
    
    class SuperPoint(BaseFeatureNetwork):
        """SuperPoint網絡實現"""
        
        def __init__(self, input_channels: int = 1, descriptor_dim: int = 256):
            super().__init__(input_channels, descriptor_dim)
            
            # 共享編碼器
            self.encoder = nn.Sequential(
                nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
                
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
                
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
                
                nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True)
            )
            
            # 檢測頭
            self.detector = nn.Sequential(
                nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 65, kernel_size=1, stride=1, padding=0)
            )
            
            # 描述符頭
            self.descriptor = nn.Sequential(
                nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, descriptor_dim, kernel_size=1, stride=1, padding=0)
            )
            
        def forward(self, x):
            """前向傳播"""
            features = self.encoder(x)
            detector_output = self.detector(features)
            descriptor_output = self.descriptor(features)
            
            # 描述符L2歸一化
            descriptor_output = F.normalize(descriptor_output, p=2, dim=1)
            
            return detector_output, descriptor_output
        
        def detect_and_compute(self, image: np.ndarray, device: str = 'cpu', 
                             detection_threshold: float = 0.015, 
                             nms_size: int = 5, max_keypoints: int = 1000):
            """檢測關鍵點並計算描述符"""
            self.eval()
            
            # 預處理圖像
            if len(image.shape) == 3:
                gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            else:
                gray = image
            
            # 調整大小為8的倍數
            h, w = gray.shape
            new_h = (h // 8) * 8
            new_w = (w // 8) * 8
            if new_h != h or new_w != w:
                gray = cv2.resize(gray, (new_w, new_h))
            
            # 歸一化
            gray_tensor = torch.from_numpy(gray).float() / 255.0
            gray_tensor = gray_tensor.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
            gray_tensor = gray_tensor.to(device)
            
            with torch.no_grad():
                detector_output, descriptor_output = self.forward(gray_tensor)
            
            # 處理檢測器輸出
            detector_output = detector_output.squeeze(0)  # [65, H/8, W/8]
            
            # 轉換為概率
            prob = F.softmax(detector_output, dim=0)
            
            # 移除dustbin通道
            prob = prob[:-1, :, :]  # [64, H/8, W/8]
            
            # 重塑為熱圖
            hc, wc = prob.shape[1], prob.shape[2]
            prob = prob.permute(1, 2, 0).contiguous()  # [H/8, W/8, 64]
            heatmap = prob.view(hc, wc, 8, 8).permute(0, 2, 1, 3).contiguous()
            heatmap = heatmap.view(hc * 8, wc * 8)  # [H, W]
            
            heatmap_np = heatmap.cpu().numpy()
            
            # 非極大值抑制
            data_max = maximum_filter(heatmap_np, size=nms_size)
            maxima = (heatmap_np == data_max)
            
            # 應用閾值
            maxima[heatmap_np < detection_threshold] = 0
            
            # 獲取關鍵點坐標
            yx = np.argwhere(maxima)
            
            # 按響應值排序
            responses = heatmap_np[yx[:, 0], yx[:, 1]]
            if len(responses) > max_keypoints:
                idx = np.argsort(responses)[::-1][:max_keypoints]
                yx = yx[idx]
                responses = responses[idx]
            
            # 轉換為原始圖像坐標
            scale_h = h / heatmap_np.shape[0]
            scale_w = w / heatmap_np.shape[1]
            
            keypoints = []
            for (y, x), response in zip(yx, responses):
                x_orig = x * scale_w
                y_orig = y * scale_h
                keypoints.append(FeaturePoint(x=x_orig, y=y_orig, response=response))
            
            # 提取描述符
            descriptor_output = descriptor_output.squeeze(0)  # [D, H/8, W/8]
            descriptor_np = descriptor_output.cpu().numpy()
            
            descriptors = []
            for kp in keypoints:
                # 在特徵圖上的坐標
                x_feat = kp.x * (wc * 8 / w) / 8
                y_feat = kp.y * (hc * 8 / h) / 8
                
                # 雙線性插值
                x0, y0 = int(np.floor(x_feat)), int(np.floor(y_feat))
                x1, y1 = x0 + 1, y0 + 1
                
                # 邊界檢查
                x0 = max(0, min(x0, descriptor_np.shape[2] - 1))
                x1 = max(0, min(x1, descriptor_np.shape[2] - 1))
                y0 = max(0, min(y0, descriptor_np.shape[1] - 1))
                y1 = max(0, min(y1, descriptor_np.shape[1] - 1))
                
                # 權重
                wa = (x1 - x_feat) * (y1 - y_feat)
                wb = (x1 - x_feat) * (y_feat - y0)
                wc = (x_feat - x0) * (y1 - y_feat)
                wd = (x_feat - x0) * (y_feat - y0)
                
                # 插值描述符
                desc_a = descriptor_np[:, y0, x0]
                desc_b = descriptor_np[:, y1, x0]
                desc_c = descriptor_np[:, y0, x1]
                desc_d = descriptor_np[:, y1, x1]
                
                descriptor = wa * desc_a + wb * desc_b + wc * desc_c + wd * desc_d
                kp.descriptor = descriptor
                descriptors.append(descriptor)
            
            descriptors = np.array(descriptors) if descriptors else None
            
            return keypoints, descriptors
        
    class D2Net(BaseFeatureNetwork):
        """D2-Net網絡實現"""
        
        def __init__(self, input_channels: int = 3, descriptor_dim: int = 128):
            super().__init__(input_channels, descriptor_dim)
            
            # 特徵提取網絡
            self.features = nn.Sequential(
                nn.Conv2d(input_channels, 64, kernel_size=3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, 64, kernel_size=3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
                
                nn.Conv2d(64, 128, kernel_size=3, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 128, kernel_size=3, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
                
                nn.Conv2d(128, 256, kernel_size=3, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 256, kernel_size=3, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 256, kernel_size=3, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
                
                nn.Conv2d(256, 512, kernel_size=3, padding=1),
                nn.BatchNorm2d(512),
                nn.ReLU(inplace=True)
            )
            
            # 描述符輸出層
            self.descriptor = nn.Conv2d(512, descriptor_dim, kernel_size=1)
            
        def forward(self, x):
            """前向傳播"""
            features = self.features(x)
            descriptors = self.descriptor(features)
            
            # L2歸一化
            descriptors = F.normalize(descriptors, p=2, dim=1)
            
            return features, descriptors
        
        def detect_and_compute(self, image: np.ndarray, device: str = 'cpu',
                             detection_threshold: float = 0.015,
                             nms_size: int = 3, max_keypoints: int = 1000):
            """檢測關鍵點並計算描述符"""
            self.eval()
            
            # 確保圖像是RGB
            if len(image.shape) == 2:
                image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            elif image.shape[2] == 4:
                image_rgb = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
            else:
                image_rgb = image
            
            h, w = image_rgb.shape[:2]
            
            # 調整大小為8的倍數
            new_h = (h // 8) * 8
            new_w = (w // 8) * 8
            if new_h != h or new_w != w:
                image_rgb = cv2.resize(image_rgb, (new_w, new_h))
            
            # 轉換為張量
            image_tensor = torch.from_numpy(image_rgb).float().permute(2, 0, 1) / 255.0
            image_tensor = image_tensor.unsqueeze(0)  # [1, 3, H, W]
            image_tensor = image_tensor.to(device)
            
            with torch.no_grad():
                features, descriptors = self.forward(image_tensor)
            
            # 提取特徵和描述符
            features_np = features.squeeze(0).cpu().numpy()  # [512, H/8, W/8]
            descriptors_np = descriptors.squeeze(0).cpu().numpy()  # [D, H/8, W/8]
            
            # D2-Net檢測策略:在特徵維度上尋找最大值
            detection_map = np.linalg.norm(features_np, axis=0)  # [H/8, W/8]
            
            # 非極大值抑制
            data_max = maximum_filter(detection_map, size=nms_size)
            maxima = (detection_map == data_max)
            
            # 應用閾值
            maxima[detection_map < detection_threshold] = 0
            
            # 獲取關鍵點坐標
            yx = np.argwhere(maxima)
            
            # 按響應值排序
            responses = detection_map[yx[:, 0], yx[:, 1]]
            if len(responses) > max_keypoints:
                idx = np.argsort(responses)[::-1][:max_keypoints]
                yx = yx[idx]
                responses = responses[idx]
            
            # 轉換為原始圖像坐標
            scale_h = h / (detection_map.shape[0] * 8)
            scale_w = w / (detection_map.shape[1] * 8)
            
            keypoints = []
            for (y, x), response in zip(yx, responses):
                x_orig = x * 8 * scale_w
                y_orig = y * 8 * scale_h
                keypoints.append(FeaturePoint(x=x_orig, y=y_orig, response=response))
            
            # 提取描述符
            descriptors_list = []
            for kp in keypoints:
                # 在特徵圖上的坐標
                x_feat = kp.x * (descriptors_np.shape[2] * 8 / w) / 8
                y_feat = kp.y * (descriptors_np.shape[1] * 8 / h) / 8
                
                # 雙線性插值
                x0, y0 = int(np.floor(x_feat)), int(np.floor(y_feat))
                x1, y1 = x0 + 1, y0 + 1
                
                # 邊界檢查
                x0 = max(0, min(x0, descriptors_np.shape[2] - 1))
                x1 = max(0, min(x1, descriptors_np.shape[2] - 1))
                y0 = max(0, min(y0, descriptors_np.shape[1] - 1))
                y1 = max(0, min(y1, descriptors_np.shape[1] - 1))
                
                # 權重
                wa = (x1 - x_feat) * (y1 - y_feat)
                wb = (x1 - x_feat) * (y_feat - y0)
                wc = (x_feat - x0) * (y1 - y_feat)
                wd = (x_feat - x0) * (y_feat - y0)
                
                # 插值描述符
                desc_a = descriptors_np[:, y0, x0]
                desc_b = descriptors_np[:, y1, x0]
                desc_c = descriptors_np[:, y0, x1]
                desc_d = descriptors_np[:, y1, x1]
                
                descriptor = wa * desc_a + wb * desc_b + wc * desc_c + wd * desc_d
                kp.descriptor = descriptor
                descriptors_list.append(descriptor)
            
            descriptors = np.array(descriptors_list) if descriptors_list else None
            
            return keypoints, descriptors

# ==================== 混合特徵點檢測器模塊 ====================

class HybridFeatureDetector:
    """混合特徵點檢測器:結合傳統和深度學習方法"""
    
    def __init__(self, 
                 traditional_detector: TraditionalFeatureDetector,
                 deeplearning_detector: Optional[Any] = None,
                 fusion_strategy: FusionStrategy = FusionStrategy.ADAPTIVE,
                 config: Dict[str, Any] = None):
        """
        初始化混合檢測器
        
        參數:
            traditional_detector: 傳統檢測器
            deeplearning_detector: 深度學習檢測器(可選)
            fusion_strategy: 融合策略
            config: 配置字典
        """
        self.traditional_detector = traditional_detector
        self.deeplearning_detector = deeplearning_detector
        self.fusion_strategy = fusion_strategy
        self.config = config or {}
        
        # 初始化權重
        self.traditional_weight = self.config.get('traditional_weight', 0.5)
        self.deeplearning_weight = self.config.get('deeplearning_weight', 0.5)
        
        # 性能統計
        self.stats = {
            'detection_time': [],
            'keypoint_counts': {'traditional': 0, 'deeplearning': 0, 'total': 0},
            'quality_scores': {'traditional': 0.0, 'deeplearning': 0.0}
        }
        
        logger.info(f"初始化混合檢測器,融合策略: {fusion_strategy.value}")
    
    def analyze_image_characteristics(self, image: np.ndarray) -> Dict[str, Any]:
        """分析圖像特性以自適應調整融合策略"""
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            gray = image
        
        characteristics = {}
        
        # 計算對比度(標準差)
        contrast = np.std(gray)
        characteristics['contrast'] = contrast
        characteristics['low_contrast'] = contrast < 30
        
        # 計算高頻成分(拉普拉斯方差)
        laplacian = cv2.Laplacian(gray, cv2.CV_64F)
        laplacian_var = laplacian.var()
        characteristics['laplacian_variance'] = laplacian_var
        characteristics['high_frequency'] = laplacian_var > 1000
        characteristics['blurry'] = laplacian_var < 100
        
        # 計算熵(紋理複雜度)
        hist = cv2.calcHist([gray], [0], None, [256], [0, 256])
        hist = hist / hist.sum() + 1e-8
        entropy = -np.sum(hist * np.log2(hist))
        characteristics['entropy'] = entropy
        characteristics['texture_rich'] = entropy > 6.0
        
        # 計算梯度幅度
        sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        gradient_magnitude = np.sqrt(sobelx**2 + sobely**2)
        characteristics['gradient_mean'] = np.mean(gradient_magnitude)
        
        return characteristics
    
    def adaptive_weight_calculation(self, image: np.ndarray) -> float:
        """自適應計算傳統方法權重"""
        characteristics = self.analyze_image_characteristics(image)
        
        # 基礎權重
        base_weight = 0.5
        
        # 根據圖像特性調整權重
        if characteristics['low_contrast']:
            # 低對比度圖像,更依賴深度學習方法
            base_weight -= 0.3
        elif characteristics['texture_rich']:
            # 紋理豐富,更依賴傳統方法
            base_weight += 0.2
        
        if characteristics['blurry']:
            # 模糊圖像,更依賴深度學習方法
            base_weight -= 0.2
        elif characteristics['high_frequency']:
            # 高頻圖像,更依賴傳統方法
            base_weight += 0.1
        
        # 確保權重在[0, 1]範圍內
        base_weight = max(0.1, min(0.9, base_weight))
        
        return base_weight
    
    def weighted_fusion(self, trad_kps: List[FeaturePoint], trad_descs: Optional[np.ndarray],
                       dl_kps: List[FeaturePoint], dl_descs: Optional[np.ndarray],
                       trad_weight: float) -> Tuple[List[FeaturePoint], Optional[np.ndarray]]:
        """加權融合策略"""
        if not trad_kps and not dl_kps:
            return [], None
        
        fused_keypoints = []
        fused_descriptors = []
        
        # 傳統方法關鍵點
        if trad_kps and trad_descs is not None:
            for i, kp in enumerate(trad_kps):
                # 根據權重隨機選擇
                if np.random.rand() < trad_weight:
                    fused_keypoints.append(kp)
                    if trad_descs is not None and i < len(trad_descs):
                        fused_descriptors.append(trad_descs[i])
        
        # 深度學習方法關鍵點
        if dl_kps and dl_descs is not None:
            for i, kp in enumerate(dl_kps):
                # 根據權重隨機選擇
                if np.random.rand() < (1 - trad_weight):
                    fused_keypoints.append(kp)
                    if dl_descs is not None and i < len(dl_descs):
                        fused_descriptors.append(dl_descs[i])
        
        # 如果都沒有選擇到,至少選擇一些關鍵點
        if len(fused_keypoints) < 10:
            # 從兩種方法中各選擇一些
            num_to_add = max(5, 10 - len(fused_keypoints))
            
            if trad_kps:
                for kp in trad_kps[:num_to_add//2]:
                    if kp not in fused_keypoints:
                        fused_keypoints.append(kp)
            
            if dl_kps:
                for kp in dl_kps[:num_to_add//2]:
                    if kp not in fused_keypoints:
                        fused_keypoints.append(kp)
        
        # 構建描述符矩陣
        if fused_descriptors:
            # 檢查描述符類型是否一致
            first_desc = fused_descriptors[0]
            desc_type = first_desc.dtype
            
            # 轉換為相同類型
            converted_descriptors = []
            for desc in fused_descriptors:
                if desc.dtype != desc_type:
                    if desc_type == np.float32:
                        desc = desc.astype(np.float32)
                    else:
                        desc = desc.astype(np.uint8)
                converted_descriptors.append(desc)
            
            fused_descriptors = np.array(converted_descriptors)
        else:
            fused_descriptors = None
        
        return fused_keypoints, fused_descriptors
    
    def cascade_fusion(self, trad_kps: List[FeaturePoint], trad_descs: Optional[np.ndarray],
                      dl_kps: List[FeaturePoint], dl_descs: Optional[np.ndarray]) -> Tuple[List[FeaturePoint], Optional[np.ndarray]]:
        """級聯融合策略"""
        fused_keypoints = list(trad_kps)
        fused_descriptors = []
        
        # 添加傳統方法描述符
        if trad_descs is not None:
            for i in range(len(trad_kps)):
                fused_descriptors.append(trad_descs[i])
        
        # 如果傳統方法關鍵點太少,用深度學習方法補充
        min_keypoints = self.config.get('min_keypoints', 50)
        if len(fused_keypoints) < min_keypoints and dl_kps:
            # 按響應值排序
            dl_kps_sorted = sorted(dl_kps, key=lambda kp: kp.response, reverse=True)
            
            num_to_add = min(min_keypoints - len(fused_keypoints), len(dl_kps_sorted))
            
            for i in range(num_to_add):
                kp = dl_kps_sorted[i]
                
                # 檢查是否與現有關鍵點太近
                too_close = False
                for existing_kp in fused_keypoints:
                    dist = np.sqrt((kp.x - existing_kp.x)**2 + (kp.y - existing_kp.y)**2)
                    if dist < self.config.get('min_distance', 10):
                        too_close = True
                        break
                
                if not too_close:
                    fused_keypoints.append(kp)
                    if dl_descs is not None and i < len(dl_descs):
                        fused_descriptors.append(dl_descs[i])
        
        # 構建描述符矩陣
        if fused_descriptors:
            fused_descriptors = np.array(fused_descriptors)
        else:
            fused_descriptors = None
        
        return fused_keypoints, fused_descriptors
    
    def ensemble_fusion(self, trad_kps: List[FeaturePoint], trad_descs: Optional[np.ndarray],
                       dl_kps: List[FeaturePoint], dl_descs: Optional[np.ndarray]) -> Tuple[List[FeaturePoint], Optional[np.ndarray]]:
        """集成融合策略:保留所有關鍵點,但去重"""
        all_keypoints = list(trad_kps) + list(dl_kps)
        all_descriptors = []
        
        # 收集所有描述符
        if trad_descs is not None:
            all_descriptors.extend(trad_descs)
        if dl_descs is not None:
            all_descriptors.extend(dl_descs)
        
        # 空間聚類去重
        if len(all_keypoints) > 0:
            positions = np.array([[kp.x, kp.y] for kp in all_keypoints])
            
            # 使用KDTree查找近鄰
            tree = KDTree(positions)
            
            # 標記要保留的關鍵點
            keep_mask = np.ones(len(all_keypoints), dtype=bool)
            min_distance = self.config.get('min_distance', 5)
            
            for i in range(len(all_keypoints)):
                if keep_mask[i]:
                    # 查找所有距離小於閾值的點
                    indices = tree.query_ball_point(positions[i], min_distance)
                    
                    # 在同類點中保留響應值最高的
                    if len(indices) > 1:
                        responses = [all_keypoints[idx].response for idx in indices]
                        max_idx = indices[np.argmax(responses)]
                        
                        # 標記其他點為刪除
                        for idx in indices:
                            if idx != max_idx:
                                keep_mask[idx] = False
            
            # 應用篩選
            fused_keypoints = [all_keypoints[i] for i in range(len(all_keypoints)) if keep_mask[i]]
            
            if all_descriptors:
                fused_descriptors = [all_descriptors[i] for i in range(len(all_keypoints)) if keep_mask[i]]
                fused_descriptors = np.array(fused_descriptors)
            else:
                fused_descriptors = None
        else:
            fused_keypoints = []
            fused_descriptors = None
        
        return fused_keypoints, fused_descriptors
    
    def detect_and_compute(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[List[FeaturePoint], Optional[np.ndarray]]:
        """檢測關鍵點並計算描述符"""
        start_time = time.time()
        
        # 傳統方法檢測
        trad_kps, trad_descs = self.traditional_detector.detect_and_compute(image, mask)
        trad_time = time.time()
        
        # 深度學習方法檢測(如果可用)
        dl_kps, dl_descs = [], None
        if self.deeplearning_detector is not None and TORCH_AVAILABLE:
            try:
                dl_kps, dl_descs = self.deeplearning_detector.detect_and_compute(image)
            except Exception as e:
                logger.warning(f"深度學習檢測失敗: {e}")
                dl_kps, dl_descs = [], None
        
        dl_time = time.time()
        
        # 更新統計信息
        self.stats['detection_time'].append({
            'traditional': trad_time - start_time,
            'deeplearning': dl_time - trad_time if dl_kps else 0,
            'total': time.time() - start_time
        })
        
        self.stats['keypoint_counts']['traditional'] = len(trad_kps)
        self.stats['keypoint_counts']['deeplearning'] = len(dl_kps)
        
        # 根據融合策略合併結果
        if self.fusion_strategy == FusionStrategy.WEIGHTED:
            # 使用固定權重
            trad_weight = self.traditional_weight
            fused_kps, fused_descs = self.weighted_fusion(trad_kps, trad_descs, dl_kps, dl_descs, trad_weight)
            
        elif self.fusion_strategy == FusionStrategy.ADAPTIVE:
            # 自適應權重
            trad_weight = self.adaptive_weight_calculation(image)
            fused_kps, fused_descs = self.weighted_fusion(trad_kps, trad_descs, dl_kps, dl_descs, trad_weight)
            
        elif self.fusion_strategy == FusionStrategy.CASCADE:
            # 級聯融合
            fused_kps, fused_descs = self.cascade_fusion(trad_kps, trad_descs, dl_kps, dl_descs)
            
        elif self.fusion_strategy == FusionStrategy.ENSEMBLE:
            # 集成融合
            fused_kps, fused_descs = self.ensemble_fusion(trad_kps, trad_descs, dl_kps, dl_descs)
            
        else:
            # 默認使用傳統方法
            fused_kps, fused_descs = trad_kps, trad_descs
        
        self.stats['keypoint_counts']['total'] = len(fused_kps)
        
        # 計算質量分數
        if trad_kps:
            trad_responses = [kp.response for kp in trad_kps]
            self.stats['quality_scores']['traditional'] = np.mean(trad_responses) if trad_responses else 0
        
        if dl_kps:
            dl_responses = [kp.response for kp in dl_kps]
            self.stats['quality_scores']['deeplearning'] = np.mean(dl_responses) if dl_responses else 0
        
        logger.info(f"混合檢測完成: 傳統{len(trad_kps)}個, 深度學習{len(dl_kps)}個, 融合後{len(fused_kps)}個關鍵點")
        
        return fused_kps, fused_descs
    
    def get_stats(self) -> Dict[str, Any]:
        """獲取統計信息"""
        if self.stats['detection_time']:
            avg_times = {
                'traditional': np.mean([t['traditional'] for t in self.stats['detection_time']]),
                'deeplearning': np.mean([t['deeplearning'] for t in self.stats['detection_time']]),
                'total': np.mean([t['total'] for t in self.stats['detection_time']])
            }
        else:
            avg_times = {'traditional': 0, 'deeplearning': 0, 'total': 0}
        
        return {
            'detection_times': avg_times,
            'keypoint_counts': self.stats['keypoint_counts'],
            'quality_scores': self.stats['quality_scores'],
            'fusion_strategy': self.fusion_strategy.value
        }

# ==================== 特徵匹配器模塊 ====================

class FeatureMatcher:
    """特徵匹配器"""
    
    def __init__(self, config: Dict[str, Any] = None):
        self.config = config or {}
        self.matcher_type = self.config.get('matcher_type', 'bf')
        self.distance_threshold = self.config.get('distance_threshold', 0.75)
        self.ransac_threshold = self.config.get('ransac_threshold', 5.0)
        self.min_matches = self.config.get('min_matches', 10)
        
    def match_features(self, 
                      kp1: List[FeaturePoint], desc1: np.ndarray,
                      kp2: List[FeaturePoint], desc2: np.ndarray) -> MatchResult:
        """匹配特徵點"""
        start_time = time.time()
        
        # 檢查輸入
        if desc1 is None or desc2 is None or len(desc1) == 0 or len(desc2) == 0:
            return MatchResult([], processing_time=time.time() - start_time)
        
        # 轉換為OpenCV格式
        cv_kp1 = [kp.to_cv_keypoint() for kp in kp1]
        cv_kp2 = [kp.to_cv_keypoint() for kp in kp2]
        
        matches = []
        homography = None
        inlier_mask = None
        
        try:
            # 根據描述符類型選擇匹配器
            if desc1.dtype == np.uint8 or desc2.dtype == np.uint8:
                # 二進制描述符(ORB, AKAZE等)
                if self.matcher_type == 'bf':
                    matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
                    matches = matcher.match(desc1, desc2)
                elif self.matcher_type == 'flann':
                    # FLANN對於二進制描述符需要特殊索引
                    FLANN_INDEX_LSH = 6
                    index_params = dict(algorithm=FLANN_INDEX_LSH,
                                       table_number=6,
                                       key_size=12,
                                       multi_probe_level=1)
                    search_params = dict(checks=50)
                    
                    flann = cv2.FlannBasedMatcher(index_params, search_params)
                    knn_matches = flann.knnMatch(desc1, desc2, k=2)
                    
                    # 應用Lowe's比率測試
                    matches = []
                    for m, n in knn_matches:
                        if m.distance < self.distance_threshold * n.distance:
                            matches.append(m)
            else:
                # 浮點描述符(SIFT, SURF等)
                if self.matcher_type == 'bf':
                    matcher = cv2.BFMatcher(cv2.NORM_L2, crossCheck=False)
                    knn_matches = matcher.knnMatch(desc1, desc2, k=2)
                    
                    # 應用Lowe's比率測試
                    matches = []
                    for m, n in knn_matches:
                        if m.distance < self.distance_threshold * n.distance:
                            matches.append(m)
                elif self.matcher_type == 'flann':
                    # FLANN對於浮點描述符
                    FLANN_INDEX_KDTREE = 1
                    index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
                    search_params = dict(checks=50)
                    
                    flann = cv2.FlannBasedMatcher(index_params, search_params)
                    knn_matches = flann.knnMatch(desc1, desc2, k=2)
                    
                    # 應用Lowe's比率測試
                    matches = []
                    for m, n in knn_matches:
                        if m.distance < self.distance_threshold * n.distance:
                            matches.append(m)
            
            # 按距離排序
            matches = sorted(matches, key=lambda x: x.distance)
            
            # 計算單應性矩陣(如果有足夠的匹配)
            if len(matches) >= self.min_matches:
                src_pts = np.float32([kp1[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
                dst_pts = np.float32([kp2[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
                
                # 使用RANSAC計算單應性矩陣
                homography, inlier_mask = cv2.findHomography(
                    src_pts, dst_pts, cv2.RANSAC, self.ransac_threshold
                )
                
                # 計算匹配分數
                if inlier_mask is not None:
                    inlier_count = np.sum(inlier_mask)
                    match_score = inlier_count / max(len(kp1), len(kp2))
                else:
                    match_score = len(matches) / max(len(kp1), len(kp2))
            else:
                match_score = len(matches) / max(len(kp1), len(kp2)) if max(len(kp1), len(kp2)) > 0 else 0
            
        except Exception as e:
            logger.error(f"特徵匹配失敗: {e}")
            matches = []
            match_score = 0
        
        processing_time = time.time() - start_time
        
        return MatchResult(
            matches=matches,
            homography=homography,
            inlier_mask=inlier_mask,
            match_score=match_score,
            processing_time=processing_time
        )
    
    def geometric_verification(self, kp1: List[FeaturePoint], kp2: List[FeaturePoint],
                              matches: List[cv2.DMatch], method: str = 'homography') -> Tuple[np.ndarray, np.ndarray]:
        """幾何驗證"""
        if len(matches) < 4:
            return None, None
        
        src_pts = np.float32([kp1[m.queryIdx].pt for m in matches]).reshape(-1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in matches]).reshape(-1, 2)
        
        if method == 'homography':
            # 計算單應性矩陣
            M, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
        elif method == 'fundamental':
            # 計算基礎矩陣
            M, mask = cv2.findFundamentalMat(src_pts, dst_pts, cv2.FM_RANSAC)
        elif method == 'affine':
            # 計算仿射變換
            M, mask = cv2.estimateAffinePartial2D(src_pts, dst_pts, method=cv2.RANSAC)
            if M is not None:
                # 將仿射矩陣轉換為3x3齊次矩陣
                M_homog = np.eye(3)
                M_homog[:2, :] = M
                M = M_homog
        else:
            M, mask = None, None
        
        return M, mask

# ==================== 可視化模塊 ====================

class Visualization:
    """可視化工具"""
    
    @staticmethod
    def draw_keypoints(image: np.ndarray, keypoints: List[FeaturePoint], 
                      color: Tuple[int, int, int] = (0, 255, 0),
                      radius: int = 3, thickness: int = 1) -> np.ndarray:
        """繪製關鍵點"""
        if len(image.shape) == 2:
            vis = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        else:
            vis = image.copy()
        
        for kp in keypoints:
            center = (int(kp.x), int(kp.y))
            cv2.circle(vis, center, radius, color, thickness)
            
            # 繪製方向(如果有)
            if kp.angle >= 0:
                angle_rad = kp.angle * np.pi / 180.0
                end_x = int(kp.x + radius * 2 * np.cos(angle_rad))
                end_y = int(kp.y + radius * 2 * np.sin(angle_rad))
                cv2.line(vis, center, (end_x, end_y), color, 1)
        
        return vis
    
    @staticmethod
    def draw_matches(img1: np.ndarray, kp1: List[FeaturePoint],
                    img2: np.ndarray, kp2: List[FeaturePoint],
                    matches: List[cv2.DMatch],
                    inlier_mask: Optional[np.ndarray] = None,
                    max_matches: int = 50) -> np.ndarray:
        """繪製匹配結果"""
        # 轉換為OpenCV格式
        cv_kp1 = [kp.to_cv_keypoint() for kp in kp1]
        cv_kp2 = [kp.to_cv_keypoint() for kp in kp2]
        
        # 限制匹配數量
        if len(matches) > max_matches:
            matches = matches[:max_matches]
        
        # 創建掩碼
        if inlier_mask is not None and len(inlier_mask) >= len(matches):
            matches_mask = inlier_mask[:len(matches)].ravel().tolist()
        else:
            matches_mask = None
        
        # 繪製匹配
        draw_params = dict(
            matchColor=(0, 255, 0),
            singlePointColor=(255, 0, 0),
            matchesMask=matches_mask,
            flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS
        )
        
        return cv2.drawMatches(img1, cv_kp1, img2, cv_kp2, matches, None, **draw_params)
    
    @staticmethod
    def plot_feature_comparison(images: List[np.ndarray], 
                               keypoints_list: List[List[FeaturePoint]],
                               titles: List[str]) -> plt.Figure:
        """繪製特徵點比較圖"""
        fig, axes = plt.subplots(2, len(images), figsize=(5*len(images), 10))
        
        if len(images) == 1:
            axes = axes.reshape(2, 1)
        
        for i, (img, kps, title) in enumerate(zip(images, keypoints_list, titles)):
            # 原始圖像
            if len(img.shape) == 2:
                axes[0, i].imshow(img, cmap='gray')
            else:
                axes[0, i].imshow(img)
            axes[0, i].set_title(f"{title}\n原始圖像")
            axes[0, i].axis('off')
            
            # 特徵點圖像
            vis_img = Visualization.draw_keypoints(img, kps)
            axes[1, i].imshow(vis_img)
            axes[1, i].set_title(f"{title}\n特徵點數量: {len(kps)}")
            axes[1, i].axis('off')
        
        plt.tight_layout()
        return fig
    
    @staticmethod
    def plot_matching_results(img1: np.ndarray, kp1: List[FeaturePoint],
                             img2: np.ndarray, kp2: List[FeaturePoint],
                             match_result: MatchResult) -> plt.Figure:
        """繪製匹配結果"""
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        
        # 圖像1特徵點
        vis1 = Visualization.draw_keypoints(img1, kp1)
        axes[0, 0].imshow(vis1)
        axes[0, 0].set_title(f"圖像1 - 特徵點: {len(kp1)}")
        axes[0, 0].axis('off')
        
        # 圖像2特徵點
        vis2 = Visualization.draw_keypoints(img2, kp2)
        axes[0, 1].imshow(vis2)
        axes[0, 1].set_title(f"圖像2 - 特徵點: {len(kp2)}")
        axes[0, 1].axis('off')
        
        # 匹配結果
        if match_result.matches:
            matches_img = Visualization.draw_matches(img1, kp1, img2, kp2, match_result.matches, match_result.inlier_mask)
            axes[1, 0].imshow(matches_img)
            
            # 統計信息
            total_matches = len(match_result.matches)
            inlier_count = np.sum(match_result.inlier_mask) if match_result.inlier_mask is not None else total_matches
            
            axes[1, 0].set_title(f"特徵匹配\n總匹配: {total_matches}, 內點: {inlier_count}, 分數: {match_result.match_score:.3f}")
            axes[1, 0].axis('off')
        else:
            axes[1, 0].text(0.5, 0.5, "無匹配", ha='center', va='center', transform=axes[1, 0].transAxes)
            axes[1, 0].axis('off')
        
        # 性能信息
        axes[1, 1].axis('off')
        info_text = f"匹配性能\n"
        info_text += f"處理時間: {match_result.processing_time:.3f}秒\n"
        info_text += f"匹配分數: {match_result.match_score:.3f}\n"
        
        if match_result.homography is not None:
            info_text += f"單應性矩陣計算成功\n"
        
        axes[1, 1].text(0.1, 0.9, info_text, ha='left', va='top', transform=axes[1, 1].transAxes, fontsize=10)
        
        plt.tight_layout()
        return fig
    
    @staticmethod
    def plot_performance_comparison(results: Dict[str, Dict[str, Any]]) -> plt.Figure:
        """繪製性能比較圖"""
        methods = list(results.keys())
        
        # 提取數據
        detection_times = [results[m]['detection_time'] for m in methods]
        keypoint_counts = [results[m]['keypoint_count'] for m in methods]
        match_scores = [results[m]['match_score'] for m in methods]
        
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        # 檢測時間
        axes[0].bar(methods, detection_times, color='skyblue')
        axes[0].set_title('平均檢測時間')
        axes[0].set_ylabel('時間(秒)')
        axes[0].tick_params(axis='x', rotation=45)
        
        # 關鍵點數量
        axes[1].bar(methods, keypoint_counts, color='lightgreen')
        axes[1].set_title('平均關鍵點數量')
        axes[1].set_ylabel('數量')
        axes[1].tick_params(axis='x', rotation=45)
        
        # 匹配分數
        axes[2].bar(methods, match_scores, color='lightcoral')
        axes[2].set_title('平均匹配分數')
        axes[2].set_ylabel('分數')
        axes[2].tick_params(axis='x', rotation=45)
        
        plt.tight_layout()
        return fig

# ==================== 評估模塊 ====================

class Evaluator:
    """評估器"""
    
    @staticmethod
    def calculate_repeatability(kp1: List[FeaturePoint], kp2: List[FeaturePoint],
                               homography: Optional[np.ndarray] = None,
                               threshold: float = 3.0) -> float:
        """計算重複性"""
        if not kp1 or not kp2:
            return 0.0
        
        # 轉換為坐標數組
        pts1 = np.array([[kp.x, kp.y] for kp in kp1])
        pts2 = np.array([[kp.x, kp.y] for kp in kp2])
        
        # 如果有單應性矩陣,變換點集
        if homography is not None:
            # 將pts1變換到圖像2的坐標系
            pts1_homo = np.hstack([pts1, np.ones((len(pts1), 1))])
            pts1_transformed = (homography @ pts1_homo.T).T
            pts1_transformed = pts1_transformed[:, :2] / pts1_transformed[:, 2:3]
            pts1_to_use = pts1_transformed
        else:
            pts1_to_use = pts1
        
        # 使用KDTree查找最近鄰
        tree = KDTree(pts2)
        distances, _ = tree.query(pts1_to_use)
        
        # 計算重複關鍵點數量
        repeatable_count = np.sum(distances < threshold)
        repeatability = repeatable_count / len(kp1)
        
        return repeatability
    
    @staticmethod
    def calculate_precision_recall(matches: List[cv2.DMatch], 
                                  inlier_mask: Optional[np.ndarray] = None,
                                  total_possible: int = 0) -> Tuple[float, float]:
        """計算精確率和召回率"""
        if not matches:
            return 0.0, 0.0
        
        if inlier_mask is not None:
            true_positives = np.sum(inlier_mask)
            false_positives = len(matches) - true_positives
        else:
            # 如果沒有內點掩碼,假設所有匹配都是真陽性
            true_positives = len(matches)
            false_positives = 0
        
        precision = true_positives / (true_positives + false_positives + 1e-8)
        
        if total_possible > 0:
            recall = true_positives / total_possible
        else:
            recall = precision  # 近似值
        
        return precision, recall
    
    @staticmethod
    def evaluate_detector(detector, images: List[np.ndarray], 
                         ground_truth_homographies: Optional[List[np.ndarray]] = None) -> Dict[str, Any]:
        """評估檢測器性能"""
        results = {
            'detection_times': [],
            'keypoint_counts': [],
            'repeatabilities': [],
            'match_scores': []
        }
        
        for i in range(len(images) - 1):
            img1 = images[i]
            img2 = images[i + 1]
            
            # 檢測特徵點
            start_time = time.time()
            kp1, desc1 = detector.detect_and_compute(img1)
            kp2, desc2 = detector.detect_and_compute(img2)
            detection_time = time.time() - start_time
            
            # 匹配特徵點
            matcher = FeatureMatcher()
            match_result = matcher.match_features(kp1, desc1, kp2, desc2)
            
            # 計算重複性
            repeatability = 0
            if ground_truth_homographies and i < len(ground_truth_homographies):
                homography = ground_truth_homographies[i]
                repeatability = Evaluator.calculate_repeatability(kp1, kp2, homography)
            
            # 記錄結果
            results['detection_times'].append(detection_time)
            results['keypoint_counts'].append((len(kp1), len(kp2)))
            results['repeatabilities'].append(repeatability)
            results['match_scores'].append(match_result.match_score)
        
        # 計算平均值
        summary = {
            'avg_detection_time': np.mean(results['detection_times']) if results['detection_times'] else 0,
            'avg_keypoint_count': np.mean([k[0] for k in results['keypoint_counts']]) if results['keypoint_counts'] else 0,
            'avg_repeatability': np.mean(results['repeatabilities']) if results['repeatabilities'] else 0,
            'avg_match_score': np.mean(results['match_scores']) if results['match_scores'] else 0
        }
        
        return summary

# ==================== 主系統類 ====================

class FeaturePointSystem:
    """特徵點系統主類"""
    
    def __init__(self, config: Dict[str, Any] = None):
        self.config = config or {}
        self.detectors = {}
        self.matchers = {}
        self.preprocessor = ImagePreprocessor()
        self.visualizer = Visualization()
        self.evaluator = Evaluator()
        
        # 初始化默認檢測器
        self._init_default_detectors()
        
        logger.info("特徵點系統初始化完成")
    
    def _init_default_detectors(self):
        """初始化默認檢測器"""
        # SIFT
        sift_config = {
            'nfeatures': self.config.get('sift_nfeatures', 0),
            'contrastThreshold': self.config.get('sift_contrastThreshold', 0.04)
        }
        self.detectors['sift'] = SIFTDetector(sift_config)
        
        # ORB
        orb_config = {
            'nfeatures': self.config.get('orb_nfeatures', 1000),
            'scaleFactor': self.config.get('orb_scaleFactor', 1.2)
        }
        self.detectors['orb'] = ORBDetector(orb_config)
        
        # AKAZE
        akaze_config = {
            'descriptor_type': self.config.get('akaze_descriptor_type', cv2.AKAZE_DESCRIPTOR_MLDB),
            'threshold': self.config.get('akaze_threshold', 0.001)
        }
        self.detectors['akaze'] = AKAZEDetector(akaze_config)
        
        # 嘗試初始化SURF
        try:
            surf_config = {
                'hessianThreshold': self.config.get('surf_hessianThreshold', 100),
                'nOctaves': self.config.get('surf_nOctaves', 4)
            }
            self.detectors['surf'] = SURFDetector(surf_config)
        except:
            logger.warning("SURF檢測器初始化失敗,可能缺少opencv-contrib-python")
        
        # 深度學習檢測器(如果可用)
        if TORCH_AVAILABLE:
            try:
                self.detectors['superpoint'] = SuperPoint()
                logger.info("SuperPoint檢測器初始化成功")
            except Exception as e:
                logger.warning(f"SuperPoint檢測器初始化失敗: {e}")
            
            try:
                self.detectors['d2net'] = D2Net()
                logger.info("D2-Net檢測器初始化成功")
            except Exception as e:
                logger.warning(f"D2-Net檢測器初始化失敗: {e}")
    
    def create_hybrid_detector(self, 
                              traditional_name: str = 'orb',
                              deeplearning_name: Optional[str] = None,
                              fusion_strategy: Union[str, FusionStrategy] = FusionStrategy.ADAPTIVE,
                              config: Dict[str, Any] = None) -> HybridFeatureDetector:
        """創建混合檢測器"""
        # 獲取傳統檢測器
        if traditional_name not in self.detectors:
            raise ValueError(f"未知的傳統檢測器: {traditional_name}")
        traditional_detector = self.detectors[traditional_name]
        
        # 獲取深度學習檢測器
        deeplearning_detector = None
        if deeplearning_name:
            if deeplearning_name not in self.detectors:
                raise ValueError(f"未知的深度學習檢測器: {deeplearning_name}")
            deeplearning_detector = self.detectors[deeplearning_name]
        
        # 轉換融合策略
        if isinstance(fusion_strategy, str):
            fusion_strategy = FusionStrategy(fusion_strategy)
        
        # 創建混合檢測器
        hybrid_config = config or {}
        hybrid_detector = HybridFeatureDetector(
            traditional_detector=traditional_detector,
            deeplearning_detector=deeplearning_detector,
            fusion_strategy=fusion_strategy,
            config=hybrid_config
        )
        
        # 註冊到檢測器列表
        hybrid_name = f"hybrid_{traditional_name}_{deeplearning_name or 'only'}_{fusion_strategy.value}"
        self.detectors[hybrid_name] = hybrid_detector
        
        return hybrid_detector
    
    def process_image_pair(self, 
                          img1_path: str, 
                          img2_path: str,
                          detector_name: str = 'sift',
                          preprocess: bool = True) -> Dict[str, Any]:
        """處理圖像對"""
        # 加載圖像
        img1 = self.preprocessor.load_image(img1_path, grayscale=False)
        img2 = self.preprocessor.load_image(img2_path, grayscale=False)
        
        # 預處理
        if preprocess:
            img1_gray = self.preprocessor.enhance_contrast(img1, method='clahe')
            img2_gray = self.preprocessor.enhance_contrast(img2, method='clahe')
        else:
            img1_gray = cv2.cvtColor(img1, cv2.COLOR_RGB2GRAY) if len(img1.shape) == 3 else img1
            img2_gray = cv2.cvtColor(img2, cv2.COLOR_RGB2GRAY) if len(img2.shape) == 3 else img2
        
        # 獲取檢測器
        if detector_name not in self.detectors:
            raise ValueError(f"未知的檢測器: {detector_name}")
        detector = self.detectors[detector_name]
        
        # 檢測特徵點
        kp1, desc1 = detector.detect_and_compute(img1_gray)
        kp2, desc2 = detector.detect_and_compute(img2_gray)
        
        # 匹配特徵點
        matcher = FeatureMatcher()
        match_result = matcher.match_features(kp1, desc1, kp2, desc2)
        
        # 準備結果
        result = {
            'images': {
                'img1': img1,
                'img2': img2,
                'img1_gray': img1_gray,
                'img2_gray': img2_gray
            },
            'keypoints': {
                'kp1': kp1,
                'kp2': kp2
            },
            'descriptors': {
                'desc1': desc1,
                'desc2': desc2
            },
            'match_result': match_result,
            'detector_info': detector.get_config_summary() if hasattr(detector, 'get_config_summary') else {},
            'detector_stats': detector.get_stats() if hasattr(detector, 'get_stats') else {}
        }
        
        return result
    
    def compare_detectors(self, 
                         img_paths: List[str],
                         detector_names: List[str],
                         save_results: bool = False,
                         output_dir: str = 'results') -> Dict[str, Dict[str, Any]]:
        """比較多個檢測器"""
        results = {}
        
        for detector_name in detector_names:
            if detector_name not in self.detectors:
                logger.warning(f"跳過未知檢測器: {detector_name}")
                continue
            
            logger.info(f"測試檢測器: {detector_name}")
            detector_results = []
            
            for i in range(len(img_paths) - 1):
                try:
                    # 處理圖像對
                    result = self.process_image_pair(
                        img_paths[i], img_paths[i+1], detector_name
                    )
                    
                    detector_results.append({
                        'image_pair': (img_paths[i], img_paths[i+1]),
                        'keypoint_counts': (len(result['keypoints']['kp1']), len(result['keypoints']['kp2'])),
                        'match_score': result['match_result'].match_score,
                        'detection_time': result['match_result'].processing_time
                    })
                    
                    # 可視化並保存
                    if save_results:
                        self._save_detector_results(result, detector_name, i, output_dir)
                        
                except Exception as e:
                    logger.error(f"處理圖像對失敗 ({img_paths[i]}, {img_paths[i+1]}): {e}")
            
            # 計算平均性能
            if detector_results:
                avg_keypoints = np.mean([r['keypoint_counts'][0] for r in detector_results])
                avg_match_score = np.mean([r['match_score'] for r in detector_results])
                avg_detection_time = np.mean([r['detection_time'] for r in detector_results])
                
                results[detector_name] = {
                    'avg_keypoint_count': avg_keypoints,
                    'avg_match_score': avg_match_score,
                    'avg_detection_time': avg_detection_time,
                    'num_tests': len(detector_results),
                    'detailed_results': detector_results
                }
        
        return results
    
    def _save_detector_results(self, result: Dict[str, Any], detector_name: str, 
                              pair_idx: int, output_dir: str):
        """保存檢測器結果"""
        import os
        os.makedirs(output_dir, exist_ok=True)
        
        # 保存匹配圖像
        match_img = self.visualizer.draw_matches(
            result['images']['img1'], result['keypoints']['kp1'],
            result['images']['img2'], result['keypoints']['kp2'],
            result['match_result'].matches,
            result['match_result'].inlier_mask
        )
        
        match_img_path = os.path.join(output_dir, f"{detector_name}_pair{pair_idx}_matches.jpg")
        cv2.imwrite(match_img_path, cv2.cvtColor(match_img, cv2.COLOR_RGB2BGR))
        
        # 保存統計信息
        stats = {
            'detector': detector_name,
            'image_pair': pair_idx,
            'keypoint_counts': {
                'img1': len(result['keypoints']['kp1']),
                'img2': len(result['keypoints']['kp2'])
            },
            'match_score': float(result['match_result'].match_score),
            'num_matches': len(result['match_result'].matches),
            'num_inliers': int(np.sum(result['match_result'].inlier_mask)) if result['match_result'].inlier_mask is not None else 0,
            'processing_time': float(result['match_result'].processing_time)
        }
        
        stats_path = os.path.join(output_dir, f"{detector_name}_pair{pair_idx}_stats.json")
        with open(stats_path, 'w') as f:
            json.dump(stats, f, indent=2)
    
    def create_performance_report(self, comparison_results: Dict[str, Dict[str, Any]]) -> str:
        """創建性能報告"""
        report_lines = []
        report_lines.append("=" * 60)
        report_lines.append("特徵點檢測器性能比較報告")
        report_lines.append("=" * 60)
        report_lines.append("")
        
        # 收集所有數據
        detectors = list(comparison_results.keys())
        
        # 表頭
        report_lines.append(f"{'檢測器':<20} {'關鍵點數量':<15} {'匹配分數':<15} {'檢測時間(秒)':<15}")
        report_lines.append("-" * 65)
        
        # 數據行
        for detector in detectors:
            result = comparison_results[detector]
            report_lines.append(
                f"{detector:<20} {result['avg_keypoint_count']:<15.1f} "
                f"{result['avg_match_score']:<15.3f} {result['avg_detection_time']:<15.3f}"
            )
        
        report_lines.append("")
        
        # 找出最佳檢測器
        best_by_keypoints = max(detectors, key=lambda d: comparison_results[d]['avg_keypoint_count'])
        best_by_score = max(detectors, key=lambda d: comparison_results[d]['avg_match_score'])
        best_by_speed = min(detectors, key=lambda d: comparison_results[d]['avg_detection_time'])
        
        report_lines.append("最佳檢測器:")
        report_lines.append(f"  關鍵點數量: {best_by_keypoints} ({comparison_results[best_by_keypoints]['avg_keypoint_count']:.1f})")
        report_lines.append(f"  匹配分數: {best_by_score} ({comparison_results[best_by_score]['avg_match_score']:.3f})")
        report_lines.append(f"  檢測速度: {best_by_speed} ({comparison_results[best_by_speed]['avg_detection_time']:.3f}秒)")
        report_lines.append("")
        report_lines.append("=" * 60)
        
        return "\n".join(report_lines)

# ==================== 應用示例模塊 ====================

class ApplicationExamples:
    """應用示例"""
    
    @staticmethod
    def demo_basic_feature_matching(system: FeaturePointSystem):
        """演示基礎特徵匹配"""
        print("=" * 60)
        print("基礎特徵匹配演示")
        print("=" * 60)
        
        # 創建示例圖像(如果不存在)
        img1 = np.random.randint(0, 255, (400, 600, 3), dtype=np.uint8)
        img2 = cv2.warpAffine(img1, np.float32([[0.9, 0.1, 50], [0.1, 0.9, 30]]), (600, 400))
        
        # 保存示例圖像
        cv2.imwrite('demo_img1.jpg', cv2.cvtColor(img1, cv2.COLOR_RGB2BGR))
        cv2.imwrite('demo_img2.jpg', cv2.cvtColor(img2, cv2.COLOR_RGB2BGR))
        
        # 測試不同檢測器
        detector_names = ['sift', 'orb', 'akaze']
        
        for detector_name in detector_names:
            print(f"\n測試檢測器: {detector_name}")
            
            try:
                result = system.process_image_pair('demo_img1.jpg', 'demo_img2.jpg', detector_name)
                
                print(f"  圖像1特徵點: {len(result['keypoints']['kp1'])}")
                print(f"  圖像2特徵點: {len(result['keypoints']['kp2'])}")
                print(f"  匹配數量: {len(result['match_result'].matches)}")
                print(f"  匹配分數: {result['match_result'].match_score:.3f}")
                
                # 可視化
                fig = system.visualizer.plot_matching_results(
                    result['images']['img1'], result['keypoints']['kp1'],
                    result['images']['img2'], result['keypoints']['kp2'],
                    result['match_result']
                )
                plt.savefig(f'demo_{detector_name}_matches.png', dpi=150, bbox_inches='tight')
                plt.close(fig)
                print(f"  結果已保存到: demo_{detector_name}_matches.png")
                
            except Exception as e:
                print(f"  錯誤: {e}")
        
        print("\n演示完成!")
    
    @staticmethod
    def demo_hybrid_detector(system: FeaturePointSystem):
        """演示混合檢測器"""
        print("=" * 60)
        print("混合檢測器演示")
        print("=" * 60)
        
        # 創建混合檢測器
        hybrid_detector = system.create_hybrid_detector(
            traditional_name='orb',
            deeplearning_name='superpoint' if 'superpoint' in system.detectors else None,
            fusion_strategy='adaptive'
        )
        
        # 創建示例圖像
        img = np.random.randint(0, 255, (400, 600, 3), dtype=np.uint8)
        cv2.imwrite('demo_hybrid.jpg', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        
        # 測試混合檢測器
        print("\n測試混合檢測器...")
        
        try:
            kp, desc = hybrid_detector.detect_and_compute(img)
            stats = hybrid_detector.get_stats()
            
            print(f"  檢測到特徵點: {len(kp)}")
            print(f"  傳統方法特徵點: {stats['keypoint_counts']['traditional']}")
            print(f"  深度學習方法特徵點: {stats['keypoint_counts']['deeplearning']}")
            print(f"  融合策略: {stats['fusion_strategy']}")
            print(f"  總檢測時間: {stats['detection_times']['total']:.3f}秒")
            
            # 可視化特徵點
            vis_img = system.visualizer.draw_keypoints(img, kp)
            cv2.imwrite('demo_hybrid_keypoints.jpg', cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR))
            print("  特徵點可視化已保存到: demo_hybrid_keypoints.jpg")
            
        except Exception as e:
            print(f"  錯誤: {e}")
        
        print("\n演示完成!")
    
    @staticmethod
    def demo_performance_comparison(system: FeaturePointSystem):
        """演示性能比較"""
        print("=" * 60)
        print("性能比較演示")
        print("=" * 60)
        
        # 創建測試圖像集
        images = []
        for i in range(5):
            img = np.random.randint(0, 255, (300 + i*50, 400 + i*50, 3), dtype=np.uint8)
            # 添加一些變換
            if i > 0:
                angle = i * 5
                M = cv2.getRotationMatrix2D((img.shape[1]//2, img.shape[0]//2), angle, 1.0)
                img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
            images.append(img)
            cv2.imwrite(f'test_img_{i}.jpg', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        
        # 測試不同檢測器
        detector_names = ['sift', 'orb', 'akaze']
        
        # 如果有深度學習檢測器,也測試
        if 'superpoint' in system.detectors:
            detector_names.append('superpoint')
        
        print("\n比較檢測器性能...")
        
        # 運行比較
        comparison_results = system.compare_detectors(
            img_paths=[f'test_img_{i}.jpg' for i in range(5)],
            detector_names=detector_names,
            save_results=True,
            output_dir='comparison_results'
        )
        
        # 生成報告
        report = system.create_performance_report(comparison_results)
        print(report)
        
        # 保存報告
        with open('performance_report.txt', 'w') as f:
            f.write(report)
        
        # 可視化比較結果
        fig = system.visualizer.plot_performance_comparison(comparison_results)
        plt.savefig('performance_comparison.png', dpi=150, bbox_inches='tight')
        plt.close(fig)
        
        print("性能比較完成!結果已保存到文件。")

# ==================== 主函數 ====================

def main():
    """主函數"""
    print("特徵點檢測系統 - 傳統與深度學習融合")
    print("版本: 2.0")
    print("=" * 60)
    
    # 初始化系統
    config = {
        'sift_nfeatures': 0,
        'orb_nfeatures': 1000,
        'akaze_threshold': 0.001
    }
    
    system = FeaturePointSystem(config)
    
    # 運行演示
    print("\n1. 基礎特徵匹配演示")
    ApplicationExamples.demo_basic_feature_matching(system)
    
    print("\n2. 混合檢測器演示")
    ApplicationExamples.demo_hybrid_detector(system)
    
    print("\n3. 性能比較演示")
    ApplicationExamples.demo_performance_comparison(system)
    
    print("\n" + "=" * 60)
    print("所有演示完成!")
    print("生成的結果文件:")
    print("  - demo_*.jpg/png: 演示結果圖像")
    print("  - test_img_*.jpg: 測試圖像")
    print("  - comparison_results/: 比較結果文件夾")
    print("  - performance_report.txt: 性能報告")
    print("  - performance_comparison.png: 性能比較圖")
    print("=" * 60)

if __name__ == "__main__":
    main()

系統說明

主要特點

  1. 模塊化設計:系統分為多個模塊,每個模塊負責特定功能

  2. 多種特徵點算法:支持SIFT、SURF、ORB、AKAZE等傳統算法

  3. 深度學習支持:集成SuperPoint和D2-Net深度學習方法

  4. 智能融合:提供多種融合策略,可自適應選擇最佳方法

  5. 完整評估系統:包含性能評估、可視化和比較功能

  6. 易用性:提供簡單的API和演示示例

使用示例

基本使用

python

# 初始化系統
system = FeaturePointSystem()

# 處理圖像對
result = system.process_image_pair('img1.jpg', 'img2.jpg', detector_name='sift')

# 獲取結果
keypoints1 = result['keypoints']['kp1']
keypoints2 = result['keypoints']['kp2']
matches = result['match_result'].matches

# 可視化
fig = system.visualizer.plot_matching_results(
    result['images']['img1'], keypoints1,
    result['images']['img2'], keypoints2,
    result['match_result']
)
plt.show()
創建混合檢測器

python

# 創建ORB+SuperPoint混合檢測器
hybrid_detector = system.create_hybrid_detector(
    traditional_name='orb',
    deeplearning_name='superpoint',
    fusion_strategy='adaptive'
)

# 使用混合檢測器
keypoints, descriptors = hybrid_detector.detect_and_compute(image)
性能比較

python

# 比較多個檢測器
results = system.compare_detectors(
    img_paths=['img1.jpg', 'img2.jpg', 'img3.jpg'],
    detector_names=['sift', 'orb', 'akaze', 'hybrid_orb_superpoint_adaptive']
)

# 生成報告
report = system.create_performance_report(results)
print(report)

輸出文件

系統運行後會生成以下文件:

  1. 演示圖像和結果

  2. 性能比較圖表

  3. 詳細的性能報告

  4. 匹配結果可視化

系統要求

  1. Python 3.7+

  2. OpenCV 4.0+

  3. NumPy

  4. Matplotlib

  5. SciPy

  6. (可選) PyTorch 1.7+(用於深度學習功能)

擴展性

系統設計易於擴展:

  1. 添加新的傳統檢測器:繼承TraditionalFeatureDetector

  2. 添加新的深度學習檢測器:繼承BaseFeatureNetwork

  3. 添加新的融合策略:擴展FusionStrategy枚舉和HybridFeatureDetector

  4. 添加新的評估指標:擴展Evaluator

這個完整系統提供了從傳統特徵點到深度學習關鍵點的完整解決方案,並通過智能融合策略結合了兩者的優勢,適用於各種計算機視覺應用。

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐