Unity使用NanoSAM实现图像分割
NanoSAM是一个可实时运行的Segment Anything (SAM)模型变体,可在NVIDIA Jetson Orin平台上通过TensorRT加速运行。本文展示了如何在Unity中实现NanoSAM的图像预处理和Mask解码器推理:ImageEncoder类负责将Unity的Texture2D转换为SAM所需的1024x1024输入格式,并进行标准化处理;NanoSAM类封装了Mask解码器推理功能,根据提示点输出分割Mask。
·
原始工程
https://github.com/NVIDIA-AI-IOT/nanosam
NanoSAM is a Segment Anything (SAM) model variant that is capable of running in 🔥 real-time 🔥 on NVIDIA Jetson Orin Platforms with NVIDIA TensorRT.
Unity中实现
using UnityEngine;
/// <summary>
/// Converts a Unity Texture2D into the input tensor expected by the SAM / NanoSAM image
/// encoder, and optionally runs an ONNX image encoder to obtain the image embedding.
/// For real projects, prefer running image_encoder.onnx (see <see cref="RunEncoderOnnx"/>).
/// </summary>
/// <remarks>
/// The source image is stretched to 1024x1024; its aspect ratio is NOT preserved, so prompt
/// points must be expressed in the same stretched (normalized) image frame.
/// </remarks>
public static class ImageEncoder
{
    private const int SAM_SIZE = 1024;

    // Input node name used by the NanoSAM encoder export. Used unless the model's only
    // input has a different name.
    private const string DEFAULT_INPUT_NAME = "image";

    // ImageNet per-channel mean / std for RGB values in [0, 1].
    private static readonly float[] MEAN = { 0.485f, 0.456f, 0.406f };
    private static readonly float[] STD = { 0.229f, 0.224f, 0.225f };

    /// <summary>
    /// Preprocesses a texture into the SAM input layout [1,3,1024,1024]: CHW, RGB, with
    /// row 0 = top of the image, each channel normalized as (v - mean) / std.
    /// </summary>
    /// <param name="tex">Source texture (any size; it is stretched to 1024x1024).</param>
    /// <returns>A new float array of length 3 * 1024 * 1024.</returns>
    /// <exception cref="System.ArgumentNullException"><paramref name="tex"/> is null.</exception>
    public static float[] PreprocessImage(Texture2D tex)
    {
        if (tex == null)
            throw new System.ArgumentNullException(nameof(tex));

        var resized = new Texture2D(SAM_SIZE, SAM_SIZE, TextureFormat.RGB24, false);
        try
        {
            BlitInto(tex, resized);
            // GetPixels32 needs 4 bytes per pixel instead of 16 for Color[].
            return ToNormalizedChw(resized.GetPixels32());
        }
        finally
        {
            DestroyTexture(resized);
        }
    }

    // Stretches src into dst (SAM_SIZE x SAM_SIZE) through a temporary RenderTexture.
    private static void BlitInto(Texture src, Texture2D dst)
    {
        // Capture the active target BEFORE Blit: Graphics.Blit makes its destination the
        // active render target, so reading RenderTexture.active afterwards would "restore"
        // the temporary texture that is released below.
        RenderTexture prev = RenderTexture.active;
        RenderTexture rt = RenderTexture.GetTemporary(SAM_SIZE, SAM_SIZE, 0,
            RenderTextureFormat.ARGB32);
        try
        {
            Graphics.Blit(src, rt);
            RenderTexture.active = rt;
            dst.ReadPixels(new Rect(0, 0, SAM_SIZE, SAM_SIZE), 0, 0);
            dst.Apply();
        }
        finally
        {
            RenderTexture.active = prev;
            RenderTexture.ReleaseTemporary(rt);
        }
    }

    // Converts Unity's bottom-up pixel array into a top-down, ImageNet-normalized CHW array.
    private static float[] ToNormalizedChw(Color32[] pixels)
    {
        const int plane = SAM_SIZE * SAM_SIZE;
        float[] data = new float[3 * plane];
        for (int y = 0; y < SAM_SIZE; y++)
        {
            // Unity pixel arrays start at the bottom row; the model expects row 0 = top.
            int srcRow = (SAM_SIZE - 1 - y) * SAM_SIZE;
            int dstRow = y * SAM_SIZE;
            for (int x = 0; x < SAM_SIZE; x++)
            {
                Color32 c = pixels[srcRow + x];
                int idx = dstRow + x;
                data[idx] = (c.r / 255f - MEAN[0]) / STD[0];
                data[plane + idx] = (c.g / 255f - MEAN[1]) / STD[1];
                data[2 * plane + idx] = (c.b / 255f - MEAN[2]) / STD[2];
            }
        }
        return data;
    }

    // Object.Destroy is not allowed outside Play Mode (e.g. from editor tooling).
    private static void DestroyTexture(UnityEngine.Object obj)
    {
        if (Application.isPlaying)
            UnityEngine.Object.Destroy(obj);
        else
            UnityEngine.Object.DestroyImmediate(obj);
    }

    /// <summary>
    /// Runs an ONNX image encoder (e.g. image_encoder.onnx) on a preprocessed image and
    /// returns the embedding output as a flat float array.
    /// </summary>
    /// <param name="encoderSession">Loaded encoder session.</param>
    /// <param name="preprocessedImage">Output of <see cref="PreprocessImage"/> (3*1024*1024 floats).</param>
    /// <returns>
    /// The first output whose name contains "embed" or "feature"; if the model has exactly
    /// one output, that output is used regardless of its name.
    /// </returns>
    /// <exception cref="System.ArgumentNullException">An argument is null.</exception>
    /// <exception cref="System.ArgumentException">The image array has the wrong length.</exception>
    /// <exception cref="System.InvalidOperationException">No usable float output was found.</exception>
    public static float[] RunEncoderOnnx(
        Microsoft.ML.OnnxRuntime.InferenceSession encoderSession,
        float[] preprocessedImage)
    {
        if (encoderSession == null)
            throw new System.ArgumentNullException(nameof(encoderSession));
        if (preprocessedImage == null)
            throw new System.ArgumentNullException(nameof(preprocessedImage));
        if (preprocessedImage.Length != 3 * SAM_SIZE * SAM_SIZE)
            throw new System.ArgumentException(
                $"preprocessedImage长度应为{3 * SAM_SIZE * SAM_SIZE},实际为{preprocessedImage.Length}",
                nameof(preprocessedImage));

        var tensor = new Microsoft.ML.OnnxRuntime.Tensors.DenseTensor<float>(
            preprocessedImage, new[] { 1, 3, SAM_SIZE, SAM_SIZE });
        var inputs = new[]
        {
            Microsoft.ML.OnnxRuntime.NamedOnnxValue
                .CreateFromTensor(ResolveInputName(encoderSession), tensor)
        };

        using var results = encoderSession.Run(inputs);

        Microsoft.ML.OnnxRuntime.NamedOnnxValue onlyOutput = null;
        int outputCount = 0;
        foreach (var r in results)
        {
            outputCount++;
            onlyOutput = r;
            if (r.Name.Contains("embed") || r.Name.Contains("feature"))
            {
                float[] embedding = ToFloatArray(r);
                if (embedding != null)
                    return embedding;
            }
        }

        // Fallback for exports whose single output is not named "*embed*" / "*feature*".
        if (outputCount == 1)
        {
            float[] embedding = ToFloatArray(onlyOutput);
            if (embedding != null)
                return embedding;
        }
        throw new System.InvalidOperationException("未找到encoder输出");
    }

    // Uses "image" when the model declares it (or has several inputs); otherwise the sole input's name.
    private static string ResolveInputName(Microsoft.ML.OnnxRuntime.InferenceSession session)
    {
        var meta = session.InputMetadata;
        if (meta.ContainsKey(DEFAULT_INPUT_NAME) || meta.Count != 1)
            return DEFAULT_INPUT_NAME;
        foreach (string name in meta.Keys)
            return name;
        return DEFAULT_INPUT_NAME;
    }

    // Copies a float tensor output into a managed array; returns null if it is not a float tensor.
    private static float[] ToFloatArray(Microsoft.ML.OnnxRuntime.NamedOnnxValue value)
    {
        var t = value.AsTensor<float>();
        if (t == null)
            return null;
        var dense = t as Microsoft.ML.OnnxRuntime.Tensors.DenseTensor<float> ?? t.ToDenseTensor();
        return dense.Buffer.ToArray();
    }
}
using System;
using System.Collections.Generic;
using UnityEngine;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
/// <summary>
/// NanoSAM / MobileSAM mask-decoder inference wrapper.
/// Input: image embedding + prompt points → output: segmentation mask texture.
/// </summary>
public class NanoSAM : IDisposable
{
    // ─── Model constants (match the standard MobileSAM decoder export) ───
    private const int IMAGE_SIZE = 1024;      // SAM input image size in pixels
    private const int EMBED_DIM = 256;        // image-embedding channel count
    private const int EMBED_H = 64;           // embedding height (1024 / 16)
    private const int EMBED_W = 64;           // embedding width
    private const int MASK_INPUT_SIZE = 256;  // low-res mask prompt size
    private const int NUM_MASK_TOKENS = 4;    // SAM mask token count

    private const int EMBED_LENGTH = EMBED_DIM * EMBED_H * EMBED_W;

    private InferenceSession _decoderSession;
    private bool _disposed;

    // ─── Construction ───────────────────────────────────
    /// <summary>
    /// Loads the mask-decoder ONNX model and logs its input/output nodes.
    /// </summary>
    /// <param name="decoderModelPath">File-system path of the decoder .onnx model.</param>
    public NanoSAM(string decoderModelPath)
    {
        // SessionOptions wraps a native handle and must be disposed; ONNX Runtime reads the
        // options while creating the session, so they can be released right afterwards.
        using (var options = new SessionOptions())
        {
            // CPU
            options.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL;
            // DirectML (Windows GPU) alternative:
            //options.AppendExecutionProvider_DML(0);
            _decoderSession = new InferenceSession(decoderModelPath, options);
        }

        Debug.Log("[NanoSAM] Decoder loaded.");
        LogSessionInfo(_decoderSession, "Decoder");
    }

    // ─── Main inference entry point ─────────────────────
    /// <summary>
    /// Runs the mask decoder and returns the highest-IOU mask as an overlay texture.
    /// </summary>
    /// <param name="imageEmbedding">Image embedding, flattened [1,256,64,64] (262144 floats).</param>
    /// <param name="promptPoints">
    /// Normalized (x, y) ∈ [0,1] in the model's image frame, with (0,0) at the TOP-left
    /// corner (mask row 0 is written to the top row of the output texture).
    /// </param>
    /// <param name="promptLabels">One label per point: 1 = foreground, 0 = background.</param>
    /// <param name="outputWidth">Width of the returned mask texture.</param>
    /// <param name="outputHeight">Height of the returned mask texture.</param>
    /// <returns>A new RGBA texture: tinted, semi-transparent where mask logits &gt; 0, clear elsewhere.</returns>
    /// <exception cref="ObjectDisposedException">The instance was disposed.</exception>
    /// <exception cref="ArgumentNullException">An array/list argument is null.</exception>
    /// <exception cref="ArgumentException">No points, too few labels, or wrong embedding length.</exception>
    /// <exception cref="ArgumentOutOfRangeException">Output size is not positive.</exception>
    /// <exception cref="InvalidOperationException">The model output could not be interpreted.</exception>
    public Texture2D RunDecoder(
        float[] imageEmbedding,
        List<Vector2> promptPoints,
        List<int> promptLabels,
        int outputWidth,
        int outputHeight)
    {
        if (_disposed)
            throw new ObjectDisposedException(nameof(NanoSAM));
        if (imageEmbedding == null)
            throw new ArgumentNullException(nameof(imageEmbedding));
        if (promptPoints == null)
            throw new ArgumentNullException(nameof(promptPoints));
        if (promptLabels == null)
            throw new ArgumentNullException(nameof(promptLabels));
        if (promptPoints.Count == 0)
            throw new ArgumentException("至少需要一个提示点");
        if (promptLabels.Count < promptPoints.Count)
            throw new ArgumentException("标签数量少于提示点数量", nameof(promptLabels));
        if (imageEmbedding.Length != EMBED_LENGTH)
            throw new ArgumentException(
                $"imageEmbedding长度应为{EMBED_LENGTH},实际为{imageEmbedding.Length}",
                nameof(imageEmbedding));
        if (outputWidth <= 0)
            throw new ArgumentOutOfRangeException(nameof(outputWidth));
        if (outputHeight <= 0)
            throw new ArgumentOutOfRangeException(nameof(outputHeight));

        // —— Build input tensors and run ——
        var inputs = BuildDecoderInputs(imageEmbedding, promptPoints, promptLabels);
        using var results = _decoderSession.Run(inputs);

        // —— Parse outputs ——
        // MobileSAM decoder outputs: masks [1,4,256,256], iou_predictions [1,4]
        DenseTensor<float> maskTensor = null;
        float[] iouPreds = null;
        foreach (var r in results)
        {
            bool isIou = r.Name.Contains("iou");
            if (!isIou && r.Name.Contains("mask"))
                maskTensor = ReadFloatTensor(r);
            else if (isIou)
                iouPreds = ReadFloatTensor(r)?.Buffer.ToArray();
        }

        if (maskTensor == null)
            throw new InvalidOperationException("未找到masks输出,请检查模型输出节点名称");
        if (maskTensor.Rank != 4)
            throw new InvalidOperationException($"masks输出应为4维 [1,N,H,W],实际为{maskTensor.Rank}维");

        int numMasks = maskTensor.Dimensions[1];
        int mH = maskTensor.Dimensions[2];
        int mW = maskTensor.Dimensions[3];
        if (numMasks <= 0 || mH <= 0 || mW <= 0)
            throw new InvalidOperationException("masks输出为空");

        // —— Pick the mask with the highest predicted IOU (index 0 if no IOU output) ——
        int bestIdx = SelectBestMask(iouPreds, numMasks);
        if (iouPreds != null && bestIdx < iouPreds.Length)
            Debug.Log($"[NanoSAM] Best mask index: {bestIdx}, IOU: {iouPreds[bestIdx]:F3}");
        else
            Debug.Log($"[NanoSAM] Best mask index: {bestIdx} (no IOU output)");

        // Copy only the selected mask before `results` (which owns the native buffer) is disposed.
        int maskSize = mH * mW;
        float[] bestMask = maskTensor.Buffer.Slice(bestIdx * maskSize, maskSize).ToArray();
        return BuildMaskTexture(bestMask, mW, mH, outputWidth, outputHeight);
    }

    // ─── Decoder inputs ─────────────────────────────────
    /// <summary>
    /// Builds the five decoder inputs. Node names follow the standard MobileSAM export;
    /// use LogSessionInfo output to adjust them if the model differs:
    ///   1. image_embeddings [1, 256, 64, 64]
    ///   2. point_coords     [1, N+1, 2]  (one extra padding point)
    ///   3. point_labels     [1, N+1]     (padding label = -1)
    ///   4. mask_input       [1, 1, 256, 256]
    ///   5. has_mask_input   [1]
    /// </summary>
    private List<NamedOnnxValue> BuildDecoderInputs(
        float[] imageEmbedding,
        List<Vector2> points,
        List<int> labels)
    {
        int numPoints = points.Count;

        // —— 1. image_embeddings ——
        var embTensor = new DenseTensor<float>(
            imageEmbedding,
            new[] { 1, EMBED_DIM, EMBED_H, EMBED_W });

        // —— 2. point_coords: SAM pixel coordinates in [0, IMAGE_SIZE], not normalized ——
        int totalPts = numPoints + 1; // +1 padding point
        float[] coordData = new float[totalPts * 2];
        for (int i = 0; i < numPoints; i++)
        {
            coordData[i * 2] = points[i].x * IMAGE_SIZE;
            coordData[i * 2 + 1] = points[i].y * IMAGE_SIZE;
        }
        // The padding point sits at (0,0); its -1 label marks it as "not a point".
        coordData[numPoints * 2] = 0f;
        coordData[numPoints * 2 + 1] = 0f;
        var coordTensor = new DenseTensor<float>(coordData, new[] { 1, totalPts, 2 });

        // —— 3. point_labels (fed as float) ——
        float[] labelData = new float[totalPts];
        for (int i = 0; i < numPoints; i++)
            labelData[i] = labels[i];
        labelData[numPoints] = -1f; // padding
        var labelTensor = new DenseTensor<float>(labelData, new[] { 1, totalPts });

        // —— 4. mask_input: all zeros = no prior mask ——
        float[] maskInputData = new float[MASK_INPUT_SIZE * MASK_INPUT_SIZE];
        var maskInputTensor = new DenseTensor<float>(
            maskInputData,
            new[] { 1, 1, MASK_INPUT_SIZE, MASK_INPUT_SIZE });

        // —— 5. has_mask_input = 0 ——
        var hasMaskTensor = new DenseTensor<float>(new float[] { 0f }, new[] { 1 });

        return new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor("image_embeddings", embTensor),
            NamedOnnxValue.CreateFromTensor("point_coords", coordTensor),
            NamedOnnxValue.CreateFromTensor("point_labels", labelTensor),
            NamedOnnxValue.CreateFromTensor("mask_input", maskInputTensor),
            NamedOnnxValue.CreateFromTensor("has_mask_input", hasMaskTensor),
        };
    }

    // ─── Helpers ────────────────────────────────────────
    // Index of the highest IOU among the first numMasks entries; 0 when no IOU data exists.
    private static int SelectBestMask(float[] iouPreds, int numMasks)
    {
        if (iouPreds == null)
            return 0;
        int count = Math.Min(iouPreds.Length, numMasks);
        int best = 0;
        for (int i = 1; i < count; i++)
            if (iouPreds[i] > iouPreds[best]) best = i;
        return best;
    }

    // Returns the output as a dense float tensor, or null if it is not a float tensor.
    private static DenseTensor<float> ReadFloatTensor(NamedOnnxValue value)
    {
        Tensor<float> tensor = value.AsTensor<float>();
        if (tensor == null)
            return null;
        return tensor as DenseTensor<float> ?? tensor.ToDenseTensor();
    }

    /// <summary>
    /// Converts a logits mask (srcW x srcH, row 0 = top) into an RGBA texture of dstW x dstH
    /// using bilinear resampling. Logits &gt; 0 become foreground (tinted, alpha 0.6).
    /// </summary>
    private Texture2D BuildMaskTexture(float[] mask, int srcW, int srcH,
        int dstW, int dstH)
    {
        var tex = new Texture2D(dstW, dstH, TextureFormat.RGBA32, false);
        Color[] pixels = new Color[dstW * dstH];
        float scaleX = (float)srcW / dstW;
        float scaleY = (float)srcH / dstH;
        for (int y = 0; y < dstH; y++)
        {
            for (int x = 0; x < dstW; x++)
            {
                // Pixel-center aligned bilinear sampling.
                float sx = (x + 0.5f) * scaleX - 0.5f;
                float sy = (y + 0.5f) * scaleY - 0.5f;
                float val = BilinearSample(mask, srcW, srcH, sx, sy);
                // Logits > 0 count as foreground.
                float alpha = val > 0f ? 1f : 0f;
                // Unity texture rows run bottom-up, so mask row y goes to row dstH-1-y.
                pixels[(dstH - 1 - y) * dstW + x] =
                    new Color(0.2f, 0.8f, 1f, alpha * 0.6f);
            }
        }
        tex.SetPixels(pixels);
        tex.Apply();
        return tex;
    }

    // Bilinear sample of a row-major w x h array with edge clamping.
    private float BilinearSample(float[] data, int w, int h, float x, float y)
    {
        int x0 = Mathf.Clamp(Mathf.FloorToInt(x), 0, w - 1);
        int y0 = Mathf.Clamp(Mathf.FloorToInt(y), 0, h - 1);
        int x1 = Mathf.Clamp(x0 + 1, 0, w - 1);
        int y1 = Mathf.Clamp(y0 + 1, 0, h - 1);
        // Mathf.Lerp clamps t to [0,1], so negative offsets near the edge are safe.
        float fx = x - x0;
        float fy = y - y0;
        float v00 = data[y0 * w + x0];
        float v10 = data[y0 * w + x1];
        float v01 = data[y1 * w + x0];
        float v11 = data[y1 * w + x1];
        return Mathf.Lerp(
            Mathf.Lerp(v00, v10, fx),
            Mathf.Lerp(v01, v11, fx), fy);
    }

    // Logs the model's input/output node names, shapes and element types.
    private void LogSessionInfo(InferenceSession session, string name)
    {
        Debug.Log($"[{name}] ── 输入节点 ──");
        foreach (var kv in session.InputMetadata)
            Debug.Log($"  {kv.Key}: [{string.Join(",", kv.Value.Dimensions)}] {kv.Value.ElementType}");
        Debug.Log($"[{name}] ── 输出节点 ──");
        foreach (var kv in session.OutputMetadata)
            Debug.Log($"  {kv.Key}: [{string.Join(",", kv.Value.Dimensions)}] {kv.Value.ElementType}");
    }

    /// <summary>Releases the native decoder session. Safe to call more than once.</summary>
    public void Dispose()
    {
        if (_disposed)
            return;
        _decoderSession?.Dispose();
        _decoderSession = null;
        _disposed = true;
    }
}
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.UI;
using System.IO;
/// <summary>
/// Interactive NanoSAM demo. It loads the decoder (and optionally the encoder) from
/// StreamingAssets and computes an embedding for testTexture. It then segments the image
/// from clicked prompt points: left click = foreground, right click = background,
/// Space = run, R = reset.
/// Model downloads:
///   resnet18_image_encoder.onnx   https://github.com/NVIDIA-AI-IOT/nanosam/issues/41
///   mobile_sam_mask_decoder.onnx  https://huggingface.co/dragonSwing/nanosam/tree/main
/// </summary>
public class NanoSAMDemo : MonoBehaviour
{
    [Header("模型路径(相对StreamingAssets)")]
    public string decoderModelName = "models/NanoSAM/mobile_sam_mask_decoder.onnx";
    public string encoderModelName = "models/NanoSAM/resnet18_image_encoder.onnx"; // optional
    [Header("UI引用")]
    public RawImage targetImage;   // shows the source image
    public RawImage maskOverlay;   // shows the segmentation mask, layered over targetImage
    public Text statusText;
    [Header("测试图像")]
    public Texture2D testTexture;

    // ─── Private state ───
    private NanoSAM _segmenter;
    private Microsoft.ML.OnnxRuntime.InferenceSession _encoderSession;
    private float[] _cachedEmbedding;   // embedding of the current image; null until computed
    private Texture2D _currentMaskTex;  // mask texture currently owned/displayed by the demo
    private readonly List<Vector2> _promptPoints = new List<Vector2>();
    private readonly List<int> _promptLabels = new List<int>();

    // ─── Lifecycle ──────────────────────────────────────
    void Start()
    {
        if (testTexture != null)
        {
            if (targetImage != null) targetImage.texture = testTexture;
            // The overlay starts out showing the plain image (presumably so it does not
            // render as a blank quad before the first mask exists).
            if (maskOverlay != null) maskOverlay.texture = testTexture;
        }
        StartCoroutine(InitializeModels());
    }

    IEnumerator InitializeModels()
    {
        SetStatus("正在加载模型...");
        yield return null;

        // NOTE(review): File.Exists / plain paths only work where StreamingAssets is a real
        // folder (desktop / editor); on Android or WebGL the files are not directly accessible.
        string decoderPath = Path.Combine(Application.streamingAssetsPath, decoderModelName);
        if (!File.Exists(decoderPath))
        {
            SetStatus($"❌ 模型未找到: {decoderPath}");
            yield break;
        }

        if (!TryLoadModels(decoderPath))
            yield break;

        SetStatus("✅ 模型加载完成。点击图像进行分割(左键=前景,右键=背景)");

        // Show the test image and precompute its embedding.
        if (testTexture != null)
        {
            if (targetImage != null) targetImage.texture = testTexture;
            yield return StartCoroutine(ComputeEmbedding(testTexture));
        }
    }

    // Creates the decoder and, if its file exists, the optional encoder.
    // Reports a failure through the status text instead of killing the coroutine.
    bool TryLoadModels(string decoderPath)
    {
        try
        {
            _segmenter = new NanoSAM(decoderPath);

            string encoderPath = Path.Combine(Application.streamingAssetsPath, encoderModelName);
            if (File.Exists(encoderPath))
            {
                _encoderSession = new Microsoft.ML.OnnxRuntime.InferenceSession(encoderPath);
                Debug.Log("[Demo] Encoder loaded.");
            }
            return true;
        }
        catch (System.Exception e)
        {
            SetStatus($"❌ 模型加载失败: {e.Message}");
            Debug.LogException(e);
            return false;
        }
    }

    // ─── Image embedding ────────────────────────────────
    IEnumerator ComputeEmbedding(Texture2D tex)
    {
        SetStatus("正在计算图像Embedding...");
        yield return null;

        if (_encoderSession != null)
        {
            if (!TryComputeEncoderEmbedding(tex))
                yield break;
        }
        else
        {
            // Without an encoder, use a random embedding (debugging only).
            Debug.LogWarning("[Demo] 未找到encoder,使用随机embedding(结果无意义,仅测试流程)");
            _cachedEmbedding = new float[256 * 64 * 64];
            var rng = new System.Random(42);
            for (int i = 0; i < _cachedEmbedding.Length; i++)
                _cachedEmbedding[i] = (float)(rng.NextDouble() * 0.1);
        }
        SetStatus("✅ 准备好。左键点击=前景点,右键=背景点,Space=运行分割,R=重置");
    }

    // Preprocesses tex and runs the ONNX encoder; on failure clears the embedding and reports it.
    bool TryComputeEncoderEmbedding(Texture2D tex)
    {
        try
        {
            float[] preprocessed = ImageEncoder.PreprocessImage(tex);
            _cachedEmbedding = ImageEncoder.RunEncoderOnnx(_encoderSession, preprocessed);
            Debug.Log($"[Demo] Embedding computed, size={_cachedEmbedding.Length}");
            return true;
        }
        catch (System.Exception e)
        {
            _cachedEmbedding = null;
            SetStatus($"❌ Embedding计算失败: {e.Message}");
            Debug.LogException(e);
            return false;
        }
    }

    // ─── Input handling ─────────────────────────────────
    void Update()
    {
        // Nothing to do until an embedding is available.
        if (_cachedEmbedding == null) return;

        // Mouse clicks collect prompt points.
        if (Input.GetMouseButtonDown(0) || Input.GetMouseButtonDown(1))
        {
            if (TryGetNormalizedClickPos(out Vector2 normPos))
            {
                int label = Input.GetMouseButtonDown(0) ? 1 : 0;
                _promptPoints.Add(normPos);
                _promptLabels.Add(label);
                Debug.Log($"[Demo] Point added: {normPos}, label={label}");
                DrawPointMarker(normPos, label == 1 ? Color.green : Color.red);
            }
        }

        // Space runs segmentation.
        if (Input.GetKeyDown(KeyCode.Space) && _promptPoints.Count > 0)
        {
            StartCoroutine(RunSegmentation());
        }

        // R resets prompts.
        if (Input.GetKeyDown(KeyCode.R))
        {
            ResetPrompts();
        }
    }

    IEnumerator RunSegmentation()
    {
        SetStatus("分割中...");
        yield return null;
        try
        {
            int w = testTexture != null ? testTexture.width : 512;
            int h = testTexture != null ? testTexture.height : 512;
            var maskTex = _segmenter.RunDecoder(
                _cachedEmbedding,
                _promptPoints,
                _promptLabels,
                w, h);

            // Replace (and free) the previous mask.
            if (_currentMaskTex != null)
                Destroy(_currentMaskTex);
            _currentMaskTex = maskTex;
            if (maskOverlay != null)
                maskOverlay.texture = maskTex;
            SetStatus($"✅ 分割完成 ({_promptPoints.Count}个提示点)");
        }
        catch (System.Exception e)
        {
            SetStatus($"❌ 分割失败: {e.Message}");
            Debug.LogException(e);
        }
    }

    // ─── Utilities ──────────────────────────────────────
    /// <summary>
    /// Maps the mouse position to normalized coordinates on targetImage with (0,0) at the
    /// TOP-left. This matches the image frame the encoder input and the decoder mask use
    /// (both are laid out top-down). UI local space has y pointing up, so y is flipped here.
    /// Returns false when the click is outside the image.
    /// </summary>
    bool TryGetNormalizedClickPos(out Vector2 normPos)
    {
        normPos = Vector2.zero;
        if (targetImage == null) return false;

        var rectT = targetImage.rectTransform;
        // Screen Space - Overlay canvases require a null camera; the other render modes
        // need the canvas camera to convert screen points correctly.
        var canvas = targetImage.canvas;
        var cam = canvas != null && canvas.renderMode != RenderMode.ScreenSpaceOverlay
            ? canvas.worldCamera
            : null;
        if (!RectTransformUtility.ScreenPointToLocalPointInRectangle(
            rectT, Input.mousePosition,
            cam, out Vector2 localPos)) return false;

        Rect rect = rectT.rect;
        float u = (localPos.x - rect.x) / rect.width;
        float vFromBottom = (localPos.y - rect.y) / rect.height;

        // Ignore clicks outside the image.
        if (u < 0 || u > 1 || vFromBottom < 0 || vFromBottom > 1)
            return false;

        normPos = new Vector2(u, 1f - vFromBottom);
        return true;
    }

    void DrawPointMarker(Vector2 normPos, Color color)
    {
        // Placeholder: only logs the marker (could be extended to a UI point marker).
        Debug.Log($"[Demo] Marker at {normPos} color={color}");
    }

    void ResetPrompts()
    {
        _promptPoints.Clear();
        _promptLabels.Clear();
        if (_currentMaskTex != null)
        {
            // Put the overlay back to its initial content (see Start) before freeing the mask,
            // so the old segmentation no longer stays on screen after a reset.
            if (maskOverlay != null && maskOverlay.texture == _currentMaskTex)
                maskOverlay.texture = testTexture;
            Destroy(_currentMaskTex);
            _currentMaskTex = null;
        }
        SetStatus("已重置。重新点击添加提示点");
    }

    void SetStatus(string msg)
    {
        Debug.Log("[Demo] " + msg);
        if (statusText != null) statusText.text = msg;
    }

    // ─── Cleanup ────────────────────────────────────────
    void OnDestroy()
    {
        _segmenter?.Dispose();
        _encoderSession?.Dispose();
        if (_currentMaskTex != null) Destroy(_currentMaskTex);
    }
}
效果图


模型文件
resnet18_image_encoder.onnx
https://github.com/NVIDIA-AI-IOT/nanosam/issues/41
mobile_sam_mask_decoder.onnx https://huggingface.co/dragonSwing/nanosam/tree/main
最后是工程地址
更多推荐
所有评论(0)