将原始 SAM3D 中的 SAM 替换为 FastSAM，即得到 FastSAM3D；通过 ONNX Runtime 或 TensorRT 框架部署后，推理速度能够大幅提升。

Python 实现（ONNX Runtime 推理）：

import os
import cv2
import numpy as np
import open3d as o3d
import torch
import json
import pointops
import argparse
import time
from os.path import join
from util import *

import onnxruntime
from fastsam import *
# Module-level ONNX Runtime session for the FastSAM-s model, created once at
# import time and used by get_seg(). Only the CUDA execution provider is requested.
onnx_session = onnxruntime.InferenceSession("FastSAM-s.onnx", providers=['CUDAExecutionProvider'])


def get_seg(image):
    """Run FastSAM (ONNX Runtime) on one BGR image and return a per-pixel id map.

    Args:
        image: BGR image as returned by ``cv2.imread`` (indexed as H x W x 3).

    Returns:
        ``int`` array of shape ``image.shape[:2]``. Each pixel holds the index of
        the last kept mask covering it (masks are painted in the order returned
        by ``nms``, so later masks overwrite earlier ones), or -1 where no mask
        applies.

    Relies on ``letterbox``, ``xywh2xyxy``, ``nms``, ``crop_mask``,
    ``scale_mask``, ``input_shape``, ``score_threshold`` and ``nms_threshold``
    provided by ``fastsam``, and on the module-level ``onnx_session``.
    """
    group_ids = np.full(image.shape[:2], -1, dtype=int)

    # Letterbox, BGR -> RGB, HWC -> NCHW, scale to [0, 1].
    blob = letterbox(image, input_shape)
    blob = blob[:, :, ::-1].transpose(2, 0, 1)[np.newaxis, :]
    blob = (blob / 255.0).astype(np.float32)
    outputs = onnx_session.run(None, {'images': blob})

    output = np.squeeze(outputs[0]).astype(dtype=np.float32)
    keep = output[..., 4] > score_threshold
    scores = output[keep, 4]
    if scores.size == 0:
        # Nothing above the score threshold: every pixel stays unlabeled.
        return group_ids

    output[..., :4] = xywh2xyxy(output[..., :4])
    preds = output[keep, :]
    boxes = np.concatenate([preds[:, :4], scores[:, np.newaxis]], axis=1)

    indices = nms(boxes, scores, score_threshold, nms_threshold)
    boxes = boxes[indices]
    masks_in = preds[indices][..., -32:]  # last 32 channels: mask coefficients

    proto = np.squeeze(outputs[1]).astype(dtype=np.float32)
    c, mh, mw = proto.shape
    masks = (masks_in @ proto.reshape(c, -1)).reshape(-1, mh, mw)

    # Rescale boxes from network-input pixels to the prototype-mask grid.
    proto_boxes = boxes.copy()
    proto_boxes[:, 0] *= mw / input_shape[0]
    proto_boxes[:, 2] *= mw / input_shape[0]
    proto_boxes[:, 3] *= mh / input_shape[1]
    proto_boxes[:, 1] *= mh / input_shape[1]
    masks = crop_mask(masks, proto_boxes)

    for group_id, proto_mask in enumerate(masks):
        # interpolation must be passed by keyword: the 3rd positional arg of
        # cv2.resize is ``dst``, not the interpolation flag.
        full_mask = cv2.resize(proto_mask, input_shape, interpolation=cv2.INTER_LINEAR)
        full_mask = scale_mask(full_mask, input_shape, image.shape)
        group_ids[full_mask > 0] = group_id
    return group_ids


def get_pcd(color_name, rgb_path, save_2dmask_path):
    """Back-project one RGB-D frame into a labeled world-space point cloud.

    Args:
        color_name: file name of the color frame inside ``<rgb_path>/color``.
            Its stem selects ``depth/<stem>.png`` and ``pose/<stem>.txt``.
        rgb_path: root directory with ``color``, ``depth``, ``pose`` and
            ``intrinsics`` sub-directories.
        save_2dmask_path: unused; kept for call-site compatibility.

    Returns:
        dict with ``coord`` (N, 3) world coordinates, ``color`` (N, 3) RGB
        uint8 values and ``group`` (N,) FastSAM ids passed through
        ``num_to_natural``. There is one entry per pixel with non-zero depth,
        in row-major pixel order.

    Raises:
        FileNotFoundError: if the depth or color image cannot be read.
    """
    stem = os.path.splitext(color_name)[0]
    depth_intrinsic = np.loadtxt(join(rgb_path, 'intrinsics', 'intrinsic_depth.txt'))

    depth_path = join(rgb_path, 'depth', stem + '.png')
    color_path = join(rgb_path, 'color', color_name)
    depth_img = cv2.imread(depth_path, -1)  # -1 = IMREAD_UNCHANGED (keep 16-bit values)
    if depth_img is None:
        raise FileNotFoundError(f"cannot read depth image: {depth_path}")
    color_image = cv2.imread(color_path)
    if color_image is None:
        raise FileNotFoundError(f"cannot read color image: {color_path}")
    # NOTE(review): assumes the depth map is 480x640 so the label map from the
    # resized color image lines up pixel-for-pixel with it.
    color_image = cv2.resize(color_image, (640, 480))

    valid = depth_img != 0
    group_ids = get_seg(color_image)[valid]
    colors = np.ascontiguousarray(color_image[valid][:, ::-1])  # BGR -> RGB

    # Boolean-mask indexing and np.nonzero both walk pixels in row-major order,
    # so coordinates, colors and labels stay aligned. Indexing with the mask
    # directly also keeps the (N, 3) shape when N == 1 (a squeeze would not).
    rows, cols = np.nonzero(valid)
    z = depth_img[valid] / 1000.0  # raw depth / 1000 (presumably mm -> m)

    fx = depth_intrinsic[0, 0]
    fy = depth_intrinsic[1, 1]
    cx = depth_intrinsic[0, 2]
    cy = depth_intrinsic[1, 2]
    bx = depth_intrinsic[0, 3]
    by = depth_intrinsic[1, 3]
    points = np.ones((z.shape[0], 4))  # homogeneous camera coordinates
    points[:, 0] = (cols - cx) * z / fx + bx
    points[:, 1] = (rows - cy) * z / fy + by
    points[:, 2] = z

    pose = np.loadtxt(join(rgb_path, 'pose', stem + '.txt'))
    points_world = points @ pose.T
    return dict(coord=points_world[:, :3], color=colors, group=num_to_natural(group_ids))


def make_open3d_point_cloud(input_dict, th):
    """Filter small groups in place, then wrap the coordinates in an Open3D cloud.

    Args:
        input_dict: dict with ``coord`` and ``group`` arrays. Its ``group``
            entry is replaced by ``remove_small_group(group, th)``, even when
            None is returned.
        th: threshold forwarded to ``remove_small_group``.

    Returns:
        ``o3d.geometry.PointCloud`` built from ``coord``, or None if any
        coordinate is NaN.
    """
    input_dict["group"] = remove_small_group(input_dict["group"], th)
    coords = input_dict["coord"]
    if not np.isnan(coords).any():
        cloud = o3d.geometry.PointCloud()
        cloud.points = o3d.utility.Vector3dVector(coords)
        return cloud
    return None


def cal_group(input_dict, new_input_dict, match_inds, ratio=0.5):
    """Relabel ``new_input_dict["group"]`` using point matches into ``input_dict``.

    Steps:
      1. Shift every labeled id (!= -1) of the new cloud past the largest
         reference id, so the two id sets are disjoint.
      2. For each ``(i, j)`` in ``match_inds`` (i indexes the new cloud,
         j the reference cloud):
         - an unlabeled new point i simply inherits the reference label of j;
         - otherwise, if j is labeled, count one vote for (new group, ref group).
      3. For each voted new group, take the reference group with the most
         votes (first one on ties). Merge the new group into it when
         votes / min(ref group size, new group size) >= ``ratio``.
         Sizes are measured after step 1 and before the inheritance in step 2.

    Returns:
        The new cloud's group array. It is modified in place, so this is the
        same object as ``new_input_dict["group"]``.
    """
    ref_groups = input_dict["group"]
    new_groups = new_input_dict["group"]
    new_groups[new_groups != -1] += ref_groups.max() + 1

    ref_sizes = dict(zip(*np.unique(ref_groups, return_counts=True)))
    new_sizes = dict(zip(*np.unique(new_groups, return_counts=True)))

    # votes[new_id][ref_id] = number of matched point pairs between the two groups
    votes = {}
    for src, dst in match_inds:
        src_id = new_groups[src]
        dst_id = ref_groups[dst]
        if src_id == -1:
            new_groups[src] = dst_id
        elif dst_id != -1:
            tally = votes.setdefault(src_id, {})
            tally[dst_id] = tally.get(dst_id, 0) + 1

    for src_id, tally in votes.items():
        # max() keeps the first maximal entry, matching argmax tie-breaking.
        best_id, best_count = max(tally.items(), key=lambda kv: kv[1])
        denom = np.float32(min(ref_sizes[best_id], new_sizes[src_id]))
        if best_count / denom >= ratio:
            new_groups[new_groups == src_id] = best_id
    return new_groups


def cal_2_scenes(pcd_list, index, voxel_size, voxelize, th=50):
    """Merge two frame clouds from ``pcd_list`` (or pass a lone one through).

    Args:
        pcd_list: list of dicts with ``coord``, ``color`` and ``group`` arrays.
        index: one or two positions into ``pcd_list``.
        voxel_size: voxel edge length; 1.5x of it is the matching radius.
        voxelize: callable applied to the merged dict.
        th: threshold forwarded to ``remove_small_group`` before matching.

    Returns:
        The merged, re-voxelized dict. If one cloud has NaN coordinates, the
        other cloud's dict is returned unchanged. If both do, returns None
        (callers filter that out).
    """
    if len(index) == 1:
        return pcd_list[index[0]]
    input_dict_0 = pcd_list[index[0]]
    input_dict_1 = pcd_list[index[1]]
    pcd0 = make_open3d_point_cloud(input_dict_0, th)
    pcd1 = make_open3d_point_cloud(input_dict_1, th)
    # make_open3d_point_cloud returns None for NaN coordinates; such a cloud
    # cannot be matched, so keep the survivor (or drop both).
    if pcd0 is None:
        return None if pcd1 is None else input_dict_1
    if pcd1 is None:
        return input_dict_0

    # Bidirectional overlap: relabel cloud 1 against cloud 0, then cloud 0
    # against the relabeled cloud 1.
    match_inds = get_matching_indices(pcd1, pcd0, 1.5 * voxel_size, 1)
    pcd1_new_group = cal_group(input_dict_0, input_dict_1, match_inds)

    match_inds = get_matching_indices(pcd0, pcd1, 1.5 * voxel_size, 1)
    input_dict_1["group"] = pcd1_new_group
    pcd0_new_group = cal_group(input_dict_1, input_dict_0, match_inds)

    pcd_new_group = np.concatenate((pcd0_new_group, pcd1_new_group), axis=0)
    pcd_new_group = num_to_natural(pcd_new_group)
    pcd_new_coord = np.concatenate((input_dict_0["coord"], input_dict_1["coord"]), axis=0)
    pcd_new_color = np.concatenate((input_dict_0["color"], input_dict_1["color"]), axis=0)
    pcd_dict = dict(coord=pcd_new_coord, color=pcd_new_color, group=pcd_new_group)

    return voxelize(pcd_dict)


def seg_pcd(scene_name, rgb_path, data_path, voxel_size, voxelize, th, save_2dmask_path):
    """Segment a scene by merging per-frame FastSAM clouds, then label its points.

    Pipeline:
      1. Build a labeled cloud with ``get_pcd`` for every frame in
         ``<scene_name>/color`` (sorted by numeric file stem) and voxelize it.
         Frames without valid depth are skipped.
      2. Merge the clouds pairwise with ``cal_2_scenes`` until one remains.
      3. Apply ``remove_small_group(..., th)`` and ``num_to_natural``.
      4. Transfer labels to the points of ``torch.load(data_path)["coord"]``
         via 1-NN (pointops, CUDA). Points whose knn distance exceeds 0.6
         get -1.
      5. Write the labels to ``<basename(scene_name)>.txt`` in the working
         directory.

    ``save_2dmask_path`` is forwarded to ``get_pcd``.

    Raises:
        RuntimeError: if no frame yields points, or every cloud is discarded
            while merging (otherwise the merge loop would never terminate).
    """
    color_dir = join(scene_name, 'color')
    pcd_list = []
    color_names = sorted(os.listdir(color_dir), key=lambda a: int(os.path.basename(a).split('.')[0]))
    for color_name in color_names:
        start = time.time()
        pcd_dict = get_pcd(color_name, rgb_path, save_2dmask_path)
        if len(pcd_dict["coord"]) == 0:
            continue
        pcd_dict = voxelize(pcd_dict)
        pcd_list.append(pcd_dict)
        print(color_name, time.time()-start, flush=True)
    if not pcd_list:
        raise RuntimeError(f"no frame with valid depth found in {color_dir}")

    while len(pcd_list) > 1:
        print(len(pcd_list), flush=True)
        new_pcd_list = []
        for indice in pairwise_indices(len(pcd_list)):
            pcd_frame = cal_2_scenes(pcd_list, indice, voxel_size=voxel_size, voxelize=voxelize)
            if pcd_frame is not None:
                new_pcd_list.append(pcd_frame)
        if not new_pcd_list:
            raise RuntimeError("all point clouds were discarded while merging (NaN coordinates)")
        pcd_list = new_pcd_list
    seg_dict = pcd_list[0]
    seg_dict["group"] = num_to_natural(remove_small_group(seg_dict["group"], th))

    data_dict = torch.load(data_path)
    scene_coord = torch.tensor(data_dict["coord"]).cuda().contiguous()
    new_offset = torch.tensor(scene_coord.shape[0]).cuda()
    gen_coord = torch.tensor(seg_dict["coord"]).cuda().contiguous().float()
    offset = torch.tensor(gen_coord.shape[0]).cuda()
    gen_group = seg_dict["group"]
    # For each scene point, find its nearest point in the merged FastSAM cloud.
    indices, dis = pointops.knn_query(1, gen_coord, offset, scene_coord, new_offset)
    indices = indices.cpu().numpy()
    group = gen_group[indices.reshape(-1)].astype(np.int16)
    mask_dis = dis.reshape(-1).cpu().numpy() > 0.6
    group[mask_dis] = -1
    np.savetxt(os.path.basename(scene_name) + ".txt", num_to_natural(group), fmt="%d")


def pcd_ensemble(org_path, new_path, data_path, vis_path):
    """Combine saved per-point labels with the ``segIndices`` of a segs JSON and visualize.

    Args:
        org_path: JSON file with a ``segIndices`` list (one entry per point).
        new_path: text file with one integer label per point (as written by
            ``seg_pcd``).
        data_path: ``torch.load``-able file whose ``coord`` is visualized.
        vis_path: output path passed to ``visualize_partition``.

    Raises:
        ValueError: if the two label files do not describe the same number of
            points. The point-wise identity matching below would otherwise
            index out of range or silently ignore points.
    """
    new_pcd = np.loadtxt(new_path).astype(np.int16)
    new_pcd = num_to_natural(remove_small_group(new_pcd, 20))
    with open(org_path) as f:
        segments = json.load(f)
    org_pcd = np.array(segments['segIndices'])
    if len(org_pcd) != len(new_pcd):
        raise ValueError(
            f"point count mismatch: {len(new_pcd)} labels in {new_path} "
            f"vs {len(org_pcd)} segIndices in {org_path}"
        )
    # Both files label the same points in the same order: match i <-> i.
    match_inds = [(i, i) for i in range(len(new_pcd))]
    new_group = cal_group(dict(group=new_pcd), dict(group=org_pcd), match_inds)
    print(new_group.shape)
    data = torch.load(data_path)
    visualize_partition(data["coord"], new_group, vis_path)


def get_args():
    '''Parse command-line arguments.

    Numeric options declare an explicit ``type``. Without it, argparse passes
    values given on the command line through as strings, which breaks e.g.
    ``1.5 * voxel_size`` downstream. Defaults are unchanged.
    '''
    parser = argparse.ArgumentParser(description='Segment Anything on ScanNet.')
    parser.add_argument('--rgb_path', type=str, default='/data1/tfy/scannet_images/scene0000_00', help='the path of rgb data')
    parser.add_argument('--data_path', type=str, default='/data1/tfy/scannet_processed/train/scene0000_00.pth', help='the path of pointcload data')
    parser.add_argument('--save_2dmask_path', type=str, default='./', help='Where to save 2D segmentation result from SAM')
    parser.add_argument('--sam_checkpoint_path', type=str, default='/data1/tfy/scannet/sam_vit_h_4b8939.pth', help='the path of checkpoint for SAM')
    parser.add_argument('--img_size', type=int, nargs=2, default=[640, 480])
    parser.add_argument('--voxel_size', type=float, default=0.05)
    parser.add_argument('--th', type=int, default=50, help='threshold of ignoring small groups to avoid noise pixel')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # Parse options and echo them so each run's configuration is in the log.
    args = get_args()
    print(args)
    point_keys = ("coord", "color", "group")
    voxelize = Voxelize(voxel_size=args.voxel_size, mode="train", keys=point_keys)
    # The same directory is used both to list the frames and to load them.
    seg_pcd(args.rgb_path, args.rgb_path, args.data_path, args.voxel_size, voxelize, args.th, args.save_2dmask_path)

    # Combine the saved labels with the scene's segIndices and write a visualization.
    ensemble_paths = dict(
        org_path="/data1/tfy/scannet/scans/scene0000_00/scene0000_00_vh_clean_2.0.010000.segs.json",
        new_path="build/scene0000_00.txt",
        data_path="/data1/tfy/scannet_processed/train/scene0000_00.pth",
        vis_path="pred/cloud.ply",
    )
    pcd_ensemble(**ensemble_paths)

C++ 实现（TensorRT 推理）：

#include "util.h"
#include "logger.h"
#include "fastsam.h"


int main(int argc, char** argv)
{
    TRTLogger logger;
	nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);
    auto engine_data = load_file("FastSAM-s.engine");
	nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(engine_data.data(), engine_data.size());
	if (engine == nullptr)
	{
		printf("Deserialize cuda engine failed.\n");
		return -1;
	}

	nvinfer1::IExecutionContext* execution_context = engine->createExecutionContext();

    float* output0_h = nullptr;
    cudaMallocHost(&output0_h, sizeof(float) * OUTPUT0_SIZE);
    float* output1_h = nullptr;
    cudaMallocHost(&output1_h, sizeof(float) * OUTPUT1_SIZE);

    uint8_t* input_d = nullptr;
    cudaMalloc(&input_d, sizeof(uint8_t) * 3 * 640 * 480);
    float* input0_d = nullptr;
    cudaMalloc(&input0_d, sizeof(float) * INPUT_SIZE);
    float* output0_d = nullptr;
    cudaMalloc(&output0_d, sizeof(float) * OUTPUT0_SIZE);
    float* output1_d = nullptr;
    cudaMalloc(&output1_d, sizeof(float) * OUTPUT1_SIZE);

    cv::Mat depth_intrinsic = loadPoseFromTxt("/data1/tfy/scannet_images/scene0000_00/intrinsics/intrinsic_depth.txt");
    float fx = depth_intrinsic.at<float>(0, 0);
    float fy = depth_intrinsic.at<float>(1, 1);
    float cx = depth_intrinsic.at<float>(0, 2);
    float cy = depth_intrinsic.at<float>(1, 2);
    float bx = depth_intrinsic.at<float>(0, 3);
    float by = depth_intrinsic.at<float>(1, 3);

    std::vector<pcd_tuple> pcd_list(279);
    auto start = std::chrono::steady_clock::now();
    for(int k=0; k<pcd_list.size();k++)
    {
        //auto start = std::chrono::steady_clock::now();
        cv::Mat color_image = cv::imread("/data1/tfy/scannet_images/scene0000_00/color/" + std::to_string(k*20) + ".jpg");
        cv::Mat depth_image = cv::imread("/data1/tfy/scannet_images/scene0000_00/depth/" + std::to_string(k*20) + ".png", cv::IMREAD_UNCHANGED);
        cv::Mat pose = loadPoseFromTxt("/data1/tfy/scannet_images/scene0000_00/pose/" + std::to_string(k*20) + ".txt");

        cv::Mat mask;
        cv::compare(depth_image, 0, mask, cv::CMP_NE); 

        cv::Mat group_id(color_image.rows, color_image.cols, CV_32S, cv::Scalar(-1));
        cudaMemcpy(input_d, color_image.data, sizeof(uint8_t) * 3 * 640 * 480, cudaMemcpyHostToDevice);
        get_seg(color_image, group_id, execution_context, input_d, input0_d, output0_d, output1_d);

        std::vector<int> group_ids;
        int* group_ptr = (int*)group_id.data;
        uchar* ptr = mask.data;
        for (int y = 0; y < mask.rows; ++y)
         {
            for (int x = 0; x < mask.cols; ++x) 
            {
                int id = (mask.cols * y + x);
                if (int(ptr[id]) != 0)
                    group_ids.push_back(group_ptr[id]);
            }
        }
           
        std::vector<cv::Vec3b> colors;
        uchar* p = color_image.data;
        ptr = mask.data;
        for (int y = 0; y < mask.rows; ++y)
         {
            for (int x = 0; x < mask.cols; ++x) 
            {
                int id = (mask.cols * y + x);
                if (int(ptr[id]) != 0)
                { 
                    uchar b = p[id * 3];
                    uchar g = p[id * 3 + 1];
                    uchar r = p[id * 3 + 2];
                    colors.push_back(cv::Vec3b(r, g, b)); 
                }
            }
        }

        cv::Mat uv_depth(depth_image.rows, depth_image.cols, CV_32FC3);
        ushort *pu = (ushort*)depth_image.data;
        float *pf = (float*)uv_depth.data;
        for (int y = 0; y < depth_image.rows; ++y)
         {
            for (int x = 0; x < depth_image.cols; ++x) 
            {
                int id = depth_image.cols * y + x;
                pf[id * 3] = x;
                pf[id * 3 + 1] = y;
                pf[id * 3 + 2] = (float) pu[id] / 1000.0f;
            }
        }

        std::vector<cv::Vec3f> valid_uv_depth;
        pf = (float*)uv_depth.data;
        for (int y = 0; y < depth_image.rows; ++y)
         {
            for (int x = 0; x < depth_image.cols; ++x) 
            {
                int id = (depth_image.cols * y + x) * 3;
                float u = pf[id];
                float v = pf[id + 1];
                float d = pf[id + 2];
                if (d != 0) 
                    valid_uv_depth.push_back(cv::Vec3f(u, v, d));
            }
        }

        cv::Mat points = cv::Mat::ones(valid_uv_depth.size(), 4, CV_32F); 
        pf = (float*)points.data;
        for (int i = 0; i < valid_uv_depth.size(); ++i) 
        {
            float u = valid_uv_depth[i][0];
            float v = valid_uv_depth[i][1];
            float z = valid_uv_depth[i][2];
            
            pf[i * 4 + 0] = (u - cx) * z / fx + bx;
            pf[i * 4 + 1] = (v - cy) * z / fy + by;
            pf[i * 4 + 2] = z;
        }

        cv::Mat points_world = points * pose.t(); 
        cv::Mat coord = points_world.colRange(0, 3);
        
        std::vector<cv::Point3f> coords(coord.rows);
        for (int i = 0; i < coord.rows; ++i) 
        {
            const float* row = coord.ptr<float>(i);
            coords[i] = cv::Point3f(row[0], row[1], row[2]);
        }

        std::vector<int> group_ids_natural = num_to_natural(group_ids);

        pcd_tuple save_dict;
        save_dict.coord = coords;
        save_dict.color = colors;
        save_dict.group = group_ids_natural;     
        pcd_list[k] = voxelize(save_dict);

        // auto end = std::chrono::steady_clock::now();
        // std::chrono::duration<double> duration = std::chrono::duration_cast<std::chrono::duration<double>>(end - start);		
        // std::cout <<k <<" "<< duration.count() << "s" << std::endl;
    }
    auto end = std::chrono::steady_clock::now();
    std::chrono::duration<double> duration = std::chrono::duration_cast<std::chrono::duration<double>>(end - start);		
    std::cout <<" "<< duration.count() << "s" << std::endl;

    while(pcd_list.size() !=1)
    {
        std::cout << "pcd_list.size() = " << pcd_list.size() << std::endl;
        std::vector<pcd_tuple> new_pcd_list;
        for(auto indice : pairwise_indices(pcd_list.size()))
        {
            pcd_tuple pcd_frame = cal_2_scenes(pcd_list, indice, 0.05, 50);
            new_pcd_list.push_back(pcd_frame);
        }
        pcd_list = new_pcd_list;
    }

    pcd_tuple seg_dict = pcd_list[0];

    std::vector<cv::Point3f> pcd = seg_dict.coord;
    seg_dict.group = num_to_natural(remove_small_group(seg_dict.group, 50));

    auto scene_coord = readPointsFromFile("../coord.txt");;
    auto gen_coord = seg_dict.coord;
    auto gen_group = seg_dict.group;
    int *indices; float *dis;
    cudaMallocHost(&indices, scene_coord.size() * sizeof(int));
    cudaMallocHost(&dis, scene_coord.size() * sizeof(float));
    knn_query(scene_coord, gen_coord, indices, dis, 0.05f, 1);

    std::vector<int> group(scene_coord.size());
    for(int i=0; i<scene_coord.size(); i++) 
    {
         group[i] = gen_group[indices[i]];
        if(sqrt(dis[i]) > 0.6)
            group[i] = -1;
    }

    std::vector<int> group_natural = num_to_natural(group);
    std::ofstream file_out("scene0000_00.txt");
    for( int i = 0; i < group_natural.size(); i++)
        file_out << group_natural[i] << std::endl;

    cudaFree(input_d);
    cudaFree(input0_d);
    cudaFree(output0_d);
    cudaFree(output1_d);
    cudaFreeHost(indices);
    cudaFreeHost(dis);

	return 0;
}

分割效果图：
（此处原为分割效果图，图片在转载/提取过程中丢失。）

完整代码工程见:https://github.com/taifyang/fastsam3d

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐