原始工程

https://github.com/NVIDIA-AI-IOT/nanosam
NanoSAM is a Segment Anything (SAM) model variant that is capable of running in 🔥 real-time 🔥 on NVIDIA Jetson Orin Platforms with NVIDIA TensorRT.
在这里插入图片描述

在 Unity 中的实现

using UnityEngine;

/// <summary>
/// Converts a Unity <see cref="Texture2D"/> into the input tensor layout expected by the SAM image
/// encoder, and optionally runs an ONNX image encoder to obtain the image embedding.
/// For real projects an image encoder .onnx model should be used to produce the embedding.
/// </summary>
public static class ImageEncoder
{
    private const int SAM_SIZE = 1024;

    // ImageNet per-channel mean / standard deviation (RGB, values in [0,1]).
    private static readonly float[] MEAN = { 0.485f, 0.456f, 0.406f };
    private static readonly float[] STD = { 0.229f, 0.224f, 0.225f };

    /// <summary>
    /// Preprocesses a texture into the SAM encoder input: [1,3,1024,1024], CHW, RGB,
    /// ImageNet-normalized, with row 0 of each plane being the TOP of the image.
    /// </summary>
    /// <remarks>
    /// The texture is stretched to 1024x1024; the aspect ratio is not preserved.
    /// NOTE(review): the reference NanoSAM predictor resizes the longest side and zero-pads instead.
    /// Stretching keeps the point/mask mapping a plain [0,1] scale, but distorts non-square images.
    /// Confirm this is acceptable for your inputs.
    /// </remarks>
    /// <param name="tex">Source texture. It is resampled on the GPU via Graphics.Blit.</param>
    /// <returns>Flat float array of length 3 * 1024 * 1024.</returns>
    /// <exception cref="System.ArgumentNullException"><paramref name="tex"/> is null (or destroyed).</exception>
    public static float[] PreprocessImage(Texture2D tex)
    {
        if (tex == null)
            throw new System.ArgumentNullException(nameof(tex));

        Texture2D resized = BlitToReadableTexture(tex, SAM_SIZE);
        try
        {
            Color32[] pixels = resized.GetPixels32();
            int plane = SAM_SIZE * SAM_SIZE;
            float[] data = new float[3 * plane];

            for (int y = 0; y < SAM_SIZE; y++)
            {
                // Unity returns pixel rows bottom-to-top; SAM expects row 0 at the top.
                int srcRow = (SAM_SIZE - 1 - y) * SAM_SIZE;
                int dstRow = y * SAM_SIZE;
                for (int x = 0; x < SAM_SIZE; x++)
                {
                    Color32 c = pixels[srcRow + x];
                    int idx = dstRow + x;
                    data[idx] = (c.r / 255f - MEAN[0]) / STD[0];
                    data[plane + idx] = (c.g / 255f - MEAN[1]) / STD[1];
                    data[2 * plane + idx] = (c.b / 255f - MEAN[2]) / STD[2];
                }
            }
            return data;
        }
        finally
        {
            DestroyTexture(resized);
        }
    }

    /// <summary>
    /// Runs an ONNX image encoder on a preprocessed image and returns the flattened embedding.
    /// </summary>
    /// <param name="encoderSession">Loaded encoder session.</param>
    /// <param name="preprocessedImage">Output of <see cref="PreprocessImage"/> (3 * 1024 * 1024 floats).</param>
    /// <returns>
    /// The first float output whose name contains "embed" or "feature". If none matches and the
    /// model has exactly one float output, that output is returned.
    /// </returns>
    /// <exception cref="System.ArgumentNullException">An argument is null.</exception>
    /// <exception cref="System.InvalidOperationException">No suitable output was found.</exception>
    public static float[] RunEncoderOnnx(
        Microsoft.ML.OnnxRuntime.InferenceSession encoderSession,
        float[] preprocessedImage)
    {
        if (encoderSession == null)
            throw new System.ArgumentNullException(nameof(encoderSession));
        if (preprocessedImage == null)
            throw new System.ArgumentNullException(nameof(preprocessedImage));

        var tensor = new Microsoft.ML.OnnxRuntime.Tensors.DenseTensor<float>(
            preprocessedImage, new[] { 1, 3, SAM_SIZE, SAM_SIZE });

        var inputs = new[]
        {
            Microsoft.ML.OnnxRuntime.NamedOnnxValue
                .CreateFromTensor(ResolveInputName(encoderSession), tensor)
        };

        using var results = encoderSession.Run(inputs);

        Microsoft.ML.OnnxRuntime.Tensors.DenseTensor<float> onlyFloatOutput = null;
        int floatOutputCount = 0;
        foreach (var r in results)
        {
            // Skip outputs that are not float tensors instead of dereferencing a failed cast.
            if (!(r.Value is Microsoft.ML.OnnxRuntime.Tensors.DenseTensor<float> t))
                continue;

            if (r.Name.Contains("embed") || r.Name.Contains("feature"))
                return t.Buffer.ToArray();

            onlyFloatOutput = t;
            floatOutputCount++;
        }

        // A single float output is unambiguous even if its name does not match the heuristic.
        if (floatOutputCount == 1)
            return onlyFloatOutput.Buffer.ToArray();

        throw new System.InvalidOperationException("未找到encoder输出");
    }

    // Returns "image" (the standard NanoSAM encoder input name) unless the model has a single, differently named input.
    private static string ResolveInputName(Microsoft.ML.OnnxRuntime.InferenceSession session)
    {
        var meta = session.InputMetadata;
        if (meta.ContainsKey("image") || meta.Count != 1)
            return "image";
        foreach (string name in meta.Keys)
            return name;
        return "image";
    }

    // Resamples src into a size x size CPU-readable RGB24 texture through a temporary RenderTexture.
    private static Texture2D BlitToReadableTexture(Texture src, int size)
    {
        // Capture the active target BEFORE Blit: Graphics.Blit makes its destination the active render target.
        RenderTexture prev = RenderTexture.active;
        RenderTexture rt = RenderTexture.GetTemporary(size, size, 0, RenderTextureFormat.ARGB32);
        var resized = new Texture2D(size, size, TextureFormat.RGB24, false);
        try
        {
            Graphics.Blit(src, rt);
            RenderTexture.active = rt;
            resized.ReadPixels(new Rect(0, 0, size, size), 0, 0);
            resized.Apply();
        }
        catch
        {
            DestroyTexture(resized);
            throw;
        }
        finally
        {
            RenderTexture.active = prev;
            RenderTexture.ReleaseTemporary(rt);
        }
        return resized;
    }

    // Destroy is only valid in play mode; DestroyImmediate is required from editor code.
    private static void DestroyTexture(Texture2D texture)
    {
        if (Application.isPlaying)
            UnityEngine.Object.Destroy(texture);
        else
            UnityEngine.Object.DestroyImmediate(texture);
    }
}
using System;
using System.Collections.Generic;
using UnityEngine;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

/// <summary>
/// Inference wrapper around the NanoSAM / MobileSAM mask decoder ONNX model.
/// Input: an image embedding plus prompt points → output: a segmentation mask texture.
/// </summary>
public class NanoSAM : IDisposable
{
    // ─── Model constants (must match the MobileSAM export) ───
    private const int IMAGE_SIZE = 1024;       // side length of the square SAM input image
    private const int EMBED_DIM = 256;         // image embedding channels
    private const int EMBED_H = 64;            // embedding height (1024 / 16)
    private const int EMBED_W = 64;            // embedding width
    private const int MASK_INPUT_SIZE = 256;   // side length of the low-res mask prompt input
    private const int NUM_MASK_TOKENS = 4;     // SAM mask token count (informational; the mask count is read from the output shape)

    private InferenceSession _decoderSession;
    private bool _disposed;

    // ─── Construction ───────────────────────────────────
    /// <summary>
    /// Loads the mask decoder model and logs its input/output node metadata.
    /// </summary>
    /// <param name="decoderModelPath">File path of the decoder .onnx model.</param>
    public NanoSAM(string decoderModelPath)
    {
        // SessionOptions owns a native handle. ONNX Runtime copies the options into the session
        // when it is created, so the options object can be released once the session exists.
        using var options = new SessionOptions();
        // CPU execution with all graph optimizations enabled.
        options.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL;

        // DirectML (requires an ONNX Runtime build that includes the DML provider):
        //options.AppendExecutionProvider_DML(0);
        _decoderSession = new InferenceSession(decoderModelPath, options);

        Debug.Log("[NanoSAM] Decoder loaded.");
        LogSessionInfo(_decoderSession, "Decoder");
    }

    // ─── Main inference entry point ─────────────────────
    /// <summary>
    /// Runs the mask decoder and returns the highest-scoring mask as an overlay texture.
    /// </summary>
    /// <param name="imageEmbedding">Image embedding, flattened [1,256,64,64].</param>
    /// <param name="promptPoints">
    /// Normalized prompt points (x, y) ∈ [0,1] with the origin at the BOTTOM-LEFT (Unity UV / UI local
    /// space). This is the same orientation as the returned texture.
    /// </param>
    /// <param name="promptLabels">One label per point: 1 = foreground, 0 = background.</param>
    /// <param name="outputWidth">Width of the returned mask texture.</param>
    /// <param name="outputHeight">Height of the returned mask texture.</param>
    /// <returns>An RGBA mask texture: translucent where the mask is foreground, transparent elsewhere.</returns>
    /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
    /// <exception cref="ArgumentException">Invalid embedding, points, labels or output size.</exception>
    /// <exception cref="InvalidOperationException">The model produced no recognizable mask output.</exception>
    public Texture2D RunDecoder(
        float[] imageEmbedding,
        List<Vector2> promptPoints,
        List<int> promptLabels,
        int outputWidth,
        int outputHeight)
    {
        if (_disposed)
            throw new ObjectDisposedException(nameof(NanoSAM));
        if (imageEmbedding == null)
            throw new ArgumentNullException(nameof(imageEmbedding));
        if (imageEmbedding.Length != EMBED_DIM * EMBED_H * EMBED_W)
            throw new ArgumentException(
                $"image embedding长度应为 {EMBED_DIM * EMBED_H * EMBED_W},实际为 {imageEmbedding.Length}",
                nameof(imageEmbedding));
        if (promptPoints == null)
            throw new ArgumentNullException(nameof(promptPoints));
        if (promptPoints.Count == 0)
            throw new ArgumentException("至少需要一个提示点");
        if (promptLabels == null || promptLabels.Count < promptPoints.Count)
            throw new ArgumentException("每个提示点都需要一个对应的标签", nameof(promptLabels));
        if (outputWidth <= 0)
            throw new ArgumentOutOfRangeException(nameof(outputWidth));
        if (outputHeight <= 0)
            throw new ArgumentOutOfRangeException(nameof(outputHeight));

        var inputs = BuildDecoderInputs(imageEmbedding, promptPoints, promptLabels);
        using var results = _decoderSession.Run(inputs);

        // MobileSAM decoder outputs (typical): masks [1,4,256,256], iou_predictions [1,4]
        ReadDecoderOutputs(results, out float[] masks, out int[] masksShape, out float[] iouPreds);
        if (masks == null)
            throw new InvalidOperationException("未找到masks输出,请检查模型输出节点名称");

        int bestIdx = SelectBestMask(iouPreds, masksShape[1]);
        if (iouPreds != null && bestIdx < iouPreds.Length)
            Debug.Log($"[NanoSAM] Best mask index: {bestIdx}, IOU: {iouPreds[bestIdx]:F3}");
        else
            Debug.LogWarning($"[NanoSAM] No IoU output found; using mask index {bestIdx}.");

        int mH = masksShape[2];
        int mW = masksShape[3];
        float[] bestMask = ExtractMask(masks, bestIdx, mH, mW);

        return BuildMaskTexture(bestMask, mW, mH, outputWidth, outputHeight);
    }

    // ─── Decoder inputs ─────────────────────────────────
    // Builds the five decoder input tensors (node names follow the standard MobileSAM export):
    // 1. image_embeddings [1, 256, 64, 64]
    // 2. point_coords     [1, N+1, 2]   (one extra padding point)
    // 3. point_labels     [1, N+1]      (padding label = -1)
    // 4. mask_input       [1, 1, 256, 256]
    // 5. has_mask_input   [1]
    private List<NamedOnnxValue> BuildDecoderInputs(
        float[] imageEmbedding,
        List<Vector2> points,
        List<int> labels)
    {
        int numPoints = points.Count;

        // 1. image_embeddings
        var embTensor = new DenseTensor<float>(
            imageEmbedding,
            new[] { 1, EMBED_DIM, EMBED_H, EMBED_W });

        // 2. point_coords: SAM pixel coordinates in [0, IMAGE_SIZE] (not normalized), origin top-left.
        int totalPts = numPoints + 1; // +1 padding
        float[] coordData = new float[1 * totalPts * 2];
        for (int i = 0; i < numPoints; i++)
        {
            coordData[i * 2] = points[i].x * IMAGE_SIZE;
            // Prompt points use a bottom-left origin, while the encoder input and the mask output rows
            // start at the image top, so y has to be flipped.
            coordData[i * 2 + 1] = (1f - points[i].y) * IMAGE_SIZE;
        }
        // Padding point at the top-left corner.
        coordData[numPoints * 2] = 0f;
        coordData[numPoints * 2 + 1] = 0f;

        var coordTensor = new DenseTensor<float>(
            coordData,
            new[] { 1, totalPts, 2 });

        // 3. point_labels
        float[] labelData = new float[1 * totalPts];
        for (int i = 0; i < numPoints; i++)
            labelData[i] = labels[i];
        labelData[numPoints] = -1f; // padding

        var labelTensor = new DenseTensor<float>(
            labelData,
            new[] { 1, totalPts });

        // 4. mask_input: all zeros, i.e. no prior mask.
        float[] maskInputData = new float[1 * 1 * MASK_INPUT_SIZE * MASK_INPUT_SIZE];
        var maskInputTensor = new DenseTensor<float>(
            maskInputData,
            new[] { 1, 1, MASK_INPUT_SIZE, MASK_INPUT_SIZE });

        // 5. has_mask_input = 0 tells the decoder to ignore mask_input.
        var hasMaskTensor = new DenseTensor<float>(
            new float[] { 0f },
            new[] { 1 });

        // If these names do not match your model, check the names printed by LogSessionInfo.
        return new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor("image_embeddings", embTensor),
            NamedOnnxValue.CreateFromTensor("point_coords",     coordTensor),
            NamedOnnxValue.CreateFromTensor("point_labels",     labelTensor),
            NamedOnnxValue.CreateFromTensor("mask_input",       maskInputTensor),
            NamedOnnxValue.CreateFromTensor("has_mask_input",   hasMaskTensor),
        };
    }

    // ─── Output handling ────────────────────────────────
    // Picks the IoU scores (name contains "iou") and the 4-D mask logits (name contains "mask") from the decoder results.
    private static void ReadDecoderOutputs(
        IEnumerable<NamedOnnxValue> results,
        out float[] masks,
        out int[] masksShape,
        out float[] iouPreds)
    {
        masks = null;
        masksShape = null;
        iouPreds = null;

        foreach (var r in results)
        {
            if (!(r.Value is DenseTensor<float> tensor))
                continue;

            if (r.Name.Contains("iou"))
            {
                iouPreds = tensor.Buffer.ToArray();
            }
            else if (r.Name.Contains("mask") && tensor.Dimensions.Length == 4)
            {
                masks = tensor.Buffer.ToArray();
                masksShape = tensor.Dimensions.ToArray();
            }
        }
    }

    // Index of the highest IoU among the first maskCount scores; 0 when no scores are available.
    private static int SelectBestMask(float[] iouPreds, int maskCount)
    {
        if (iouPreds == null)
            return 0;

        int count = Math.Min(iouPreds.Length, maskCount);
        int best = 0;
        for (int i = 1; i < count; i++)
            if (iouPreds[i] > iouPreds[best]) best = i;
        return best;
    }

    // Copies mask maskIdx (h*w logits) out of the flattened [1, numMasks, h, w] output; assumes batch index 0.
    private float[] ExtractMask(float[] allMasks, int maskIdx, int h, int w)
    {
        float[] result = new float[h * w];
        int offset = maskIdx * h * w;
        Array.Copy(allMasks, offset, result, 0, h * w);
        return result;
    }

    /// <summary>
    /// Converts mask logits into an RGBA overlay texture, bilinearly upsampled to dstW x dstH.
    /// Logits > 0 become a translucent light-blue pixel (alpha 0.6); all others are fully transparent.
    /// </summary>
    private Texture2D BuildMaskTexture(float[] mask, int srcW, int srcH,
                                        int dstW, int dstH)
    {
        var tex = new Texture2D(dstW, dstH, TextureFormat.RGBA32, false);
        Color[] pixels = new Color[dstW * dstH];

        float scaleX = (float)srcW / dstW;
        float scaleY = (float)srcH / dstH;

        for (int y = 0; y < dstH; y++)
        {
            for (int x = 0; x < dstW; x++)
            {
                // Pixel-center aligned bilinear sampling.
                float sx = (x + 0.5f) * scaleX - 0.5f;
                float sy = (y + 0.5f) * scaleY - 0.5f;
                float val = BilinearSample(mask, srcW, srcH, sx, sy);

                // Logits > 0 are treated as foreground.
                float alpha = val > 0f ? 1f : 0f;
                // Mask row 0 is the image top, while Unity textures store the bottom row first.
                pixels[(dstH - 1 - y) * dstW + x] =
                    new Color(0.2f, 0.8f, 1f, alpha * 0.6f);
            }
        }

        tex.SetPixels(pixels);
        tex.Apply();
        return tex;
    }

    // Bilinear sample of a row-major w x h grid; coordinates are clamped to the grid edges.
    private float BilinearSample(float[] data, int w, int h, float x, float y)
    {
        int x0 = Mathf.Clamp(Mathf.FloorToInt(x), 0, w - 1);
        int y0 = Mathf.Clamp(Mathf.FloorToInt(y), 0, h - 1);
        int x1 = Mathf.Clamp(x0 + 1, 0, w - 1);
        int y1 = Mathf.Clamp(y0 + 1, 0, h - 1);

        float fx = x - x0;
        float fy = y - y0;

        float v00 = data[y0 * w + x0];
        float v10 = data[y0 * w + x1];
        float v01 = data[y1 * w + x0];
        float v11 = data[y1 * w + x1];

        // Mathf.Lerp clamps t to [0,1], which also handles the negative offsets at the left/top edge.
        return Mathf.Lerp(
            Mathf.Lerp(v00, v10, fx),
            Mathf.Lerp(v01, v11, fx), fy);
    }

    // Logs every input and output node (name, dimensions, element type) of the session.
    private void LogSessionInfo(InferenceSession session, string name)
    {
        Debug.Log($"[{name}] ── 输入节点 ──");
        foreach (var kv in session.InputMetadata)
            Debug.Log($"  {kv.Key}: [{string.Join(",", kv.Value.Dimensions)}] {kv.Value.ElementType}");

        Debug.Log($"[{name}] ── 输出节点 ──");
        foreach (var kv in session.OutputMetadata)
            Debug.Log($"  {kv.Key}: [{string.Join(",", kv.Value.Dimensions)}] {kv.Value.ElementType}");
    }

    /// <summary>Releases the native decoder session. Safe to call more than once.</summary>
    public void Dispose()
    {
        if (_disposed)
            return;

        _decoderSession?.Dispose();
        _decoderSession = null;
        _disposed = true;
    }
}
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.UI;
using System.IO;

/// <summary>
/// Interactive NanoSAM demo. It loads the models from StreamingAssets, encodes the test texture,
/// collects prompt points from clicks on the target image, and shows the resulting mask on the overlay.
/// Controls: left click = foreground point, right click = background point, Space = segment, R = reset.
/// resnet18_image_encoder.onnx https://github.com/NVIDIA-AI-IOT/nanosam/issues/41
/// mobile_sam_mask_decoder.onnx https://huggingface.co/dragonSwing/nanosam/tree/main
/// </summary>
public class NanoSAMDemo : MonoBehaviour
{
    [Header("模型路径(相对StreamingAssets)")]
    public string decoderModelName = "models/NanoSAM/mobile_sam_mask_decoder.onnx";
    public string encoderModelName = "models/NanoSAM/resnet18_image_encoder.onnx"; // optional

    [Header("UI引用")]
    public RawImage targetImage;      // shows the source image
    public RawImage maskOverlay;      // shows the segmentation mask (layered over targetImage)
    public Text statusText;

    [Header("测试图像")]
    public Texture2D testTexture;

    // ─── Private state ───
    private NanoSAM _segmenter;
    private Microsoft.ML.OnnxRuntime.InferenceSession _encoderSession;
    private float[] _cachedEmbedding;   // embedding of the current image; null until computed
    private Texture2D _currentMaskTex;

    private List<Vector2> _promptPoints = new List<Vector2>();
    private List<int> _promptLabels = new List<int>();

    // ─── Lifecycle ──────────────────────────────────────
    void Start()
    {
        if (testTexture != null)
        {
            if (targetImage != null) targetImage.texture = testTexture;
            // The overlay initially shows the source image; the first segmentation replaces it with the mask.
            if (maskOverlay != null) maskOverlay.texture = testTexture;
        }
        StartCoroutine(InitializeModels());
    }

    IEnumerator InitializeModels()
    {
        SetStatus("正在加载模型...");
        yield return null;

        string decoderPath = Path.Combine(Application.streamingAssetsPath, decoderModelName);

        if (!File.Exists(decoderPath))
        {
            SetStatus($"❌ 模型未找到: {decoderPath}");
            yield break;
        }

        if (!TryLoadModels(decoderPath))
            yield break;

        SetStatus("✅ 模型加载完成。点击图像进行分割(左键=前景,右键=背景)");

        // Show the test image and precompute its embedding.
        if (testTexture != null)
        {
            if (targetImage != null) targetImage.texture = testTexture;
            yield return StartCoroutine(ComputeEmbedding(testTexture));
        }
    }

    // Loads the decoder (required) and the encoder (only if its file exists). It reports a load
    // failure in the status text instead of letting the exception silently end the coroutine.
    bool TryLoadModels(string decoderPath)
    {
        try
        {
            _segmenter = new NanoSAM(decoderPath);

            string encoderPath = Path.Combine(Application.streamingAssetsPath, encoderModelName);
            if (File.Exists(encoderPath))
            {
                _encoderSession = new Microsoft.ML.OnnxRuntime.InferenceSession(encoderPath);
                Debug.Log("[Demo] Encoder loaded.");
            }
            return true;
        }
        catch (System.Exception e)
        {
            SetStatus($"❌ 模型加载失败: {e.Message}");
            Debug.LogException(e);
            return false;
        }
    }

    // ─── Image embedding ────────────────────────────────
    IEnumerator ComputeEmbedding(Texture2D tex)
    {
        SetStatus("正在计算图像Embedding...");
        yield return null;

        if (!TryComputeEmbedding(tex))
            yield break;

        SetStatus("✅ 准备好。左键点击=前景点,右键=背景点,Space=运行分割,R=重置");
    }

    // Fills _cachedEmbedding with the ONNX encoder when loaded, otherwise with a fixed-seed random
    // embedding (debug only). On failure it leaves _cachedEmbedding null and reports the error.
    bool TryComputeEmbedding(Texture2D tex)
    {
        try
        {
            if (_encoderSession != null)
            {
                float[] preprocessed = ImageEncoder.PreprocessImage(tex);
                _cachedEmbedding = ImageEncoder.RunEncoderOnnx(_encoderSession, preprocessed);
                Debug.Log($"[Demo] Embedding computed, size={_cachedEmbedding.Length}");
            }
            else
            {
                Debug.LogWarning("[Demo] 未找到encoder,使用随机embedding(结果无意义,仅测试流程)");
                _cachedEmbedding = CreateDebugEmbedding();
            }
            return true;
        }
        catch (System.Exception e)
        {
            _cachedEmbedding = null;
            SetStatus($"❌ Embedding计算失败: {e.Message}");
            Debug.LogException(e);
            return false;
        }
    }

    // Deterministic (seed 42) pseudo-embedding of size 256*64*64 with values in [0, 0.1); only useful to exercise the pipeline.
    static float[] CreateDebugEmbedding()
    {
        var embedding = new float[256 * 64 * 64];
        var rng = new System.Random(42);
        for (int i = 0; i < embedding.Length; i++)
            embedding[i] = (float)(rng.NextDouble() * 0.1);
        return embedding;
    }

    // ─── Input handling ─────────────────────────────────
    void Update()
    {
        if (_cachedEmbedding == null) return;

        // Mouse clicks add prompt points: left = foreground (1), right = background (0).
        if (Input.GetMouseButtonDown(0) || Input.GetMouseButtonDown(1))
        {
            if (TryGetNormalizedClickPos(out Vector2 normPos))
            {
                int label = Input.GetMouseButtonDown(0) ? 1 : 0;
                _promptPoints.Add(normPos);
                _promptLabels.Add(label);
                Debug.Log($"[Demo] Point added: {normPos}, label={label}");
                DrawPointMarker(normPos, label == 1 ? Color.green : Color.red);
            }
        }

        // Space runs the segmentation.
        if (Input.GetKeyDown(KeyCode.Space) && _promptPoints.Count > 0)
        {
            StartCoroutine(RunSegmentation());
        }

        // R resets the prompts.
        if (Input.GetKeyDown(KeyCode.R))
        {
            ResetPrompts();
        }
    }

    IEnumerator RunSegmentation()
    {
        SetStatus("分割中...");
        yield return null;

        try
        {
            int w = testTexture != null ? testTexture.width : 512;
            int h = testTexture != null ? testTexture.height : 512;

            var maskTex = _segmenter.RunDecoder(
                _cachedEmbedding,
                _promptPoints,
                _promptLabels,
                w, h);

            // Free the previous mask texture; textures are not garbage-collected automatically.
            if (_currentMaskTex != null)
                Destroy(_currentMaskTex);
            _currentMaskTex = maskTex;

            if (maskOverlay != null)
                maskOverlay.texture = maskTex;
            SetStatus($"✅ 分割完成 ({_promptPoints.Count}个提示点)");
        }
        catch (System.Exception e)
        {
            SetStatus($"❌ 分割失败: {e.Message}");
            Debug.LogException(e);
        }
    }

    // ─── Helpers ────────────────────────────────────────
    // Maps the mouse position to normalized coordinates inside targetImage, origin bottom-left.
    // Returns false when there is no target image or the click falls outside it.
    bool TryGetNormalizedClickPos(out Vector2 normPos)
    {
        normPos = Vector2.zero;
        if (targetImage == null) return false;

        // Screen Space - Overlay canvases require a null camera; other render modes need the canvas camera.
        Canvas canvas = targetImage.canvas;
        Camera eventCamera =
            (canvas == null || canvas.renderMode == RenderMode.ScreenSpaceOverlay)
                ? null
                : canvas.worldCamera;

        var rectT = targetImage.rectTransform;
        if (!RectTransformUtility.ScreenPointToLocalPointInRectangle(
                rectT, Input.mousePosition,
                eventCamera, out Vector2 localPos)) return false;

        Rect rect = rectT.rect;
        normPos = new Vector2(
            (localPos.x - rect.x) / rect.width,
            (localPos.y - rect.y) / rect.height);

        // Reject clicks outside the image.
        return normPos.x >= 0 && normPos.x <= 1 && normPos.y >= 0 && normPos.y <= 1;
    }

    void DrawPointMarker(Vector2 normPos, Color color)
    {
        // Placeholder: only logs the marker (could be extended to draw a UI marker).
        Debug.Log($"[Demo] Marker at {normPos} color={color}");
    }

    void ResetPrompts()
    {
        _promptPoints.Clear();
        _promptLabels.Clear();
        // The previous mask intentionally stays on the overlay until the next segmentation run.
        SetStatus("已重置。重新点击添加提示点");
    }

    void SetStatus(string msg)
    {
        Debug.Log("[Demo] " + msg);
        if (statusText != null) statusText.text = msg;
    }

    // ─── Cleanup ────────────────────────────────────────
    void OnDestroy()
    {
        _segmenter?.Dispose();
        _segmenter = null;
        _encoderSession?.Dispose();
        _encoderSession = null;
        if (_currentMaskTex != null) Destroy(_currentMaskTex);
    }
}

效果图

在这里插入图片描述
在这里插入图片描述

模型文件

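resnet18_image_encoder.onnx(图像编码器,可选;缺失时 Demo 使用随机 embedding,仅用于测试流程)
https://github.com/NVIDIA-AI-IOT/nanosam/issues/41

mobile_sam_mask_decoder.onnx(Mask 解码器,必需)https://huggingface.co/dragonSwing/nanosam/tree/main

两个模型默认放在 StreamingAssets/models/NanoSAM/ 目录下。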

完整工程地址:

https://github.com/xue-fei/onnxruntime-unity-samples.git

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐