Protocol Conversion (tensormsg)


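The tensormsg package acts as a **bidirectional protocol converter** between ROS 2 message streams and machine-learning tensor representations. It lets LeRobot policies consume robot observations and produce actions without hand-written serialization code, and it relies on a contract-driven specification system to keep data collection, training, and deployment consistent.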

For how contracts are defined and loaded, see Contract Definition. For details of the inference pipeline that uses tensormsg, see Inference Pipeline.

Sources: README.md:30-31, docs/architecture.md:203-208, src/tensormsg/tensormsg/converter.py:1-262


System Role and Data Flow

Within the IB-Robot architecture, the tensormsg package is the central protocol hub and performs the following data conversions:

        graph TB
    subgraph "ROS 2 World"
        CAM["/camera/top/image_raw<br/>(sensor_msgs/Image)"]
        JS["/joint_states<br/>(sensor_msgs/JointState)"]
        CMD["/arm_position_controller/commands<br/>(Float64MultiArray)"]
    end

    subgraph "tensormsg Package"
        CONV["TensorMsgConverter"]
        DECODE["decode()<br/>ROS → NumPy"]
        ENCODE["encode()<br/>Tensor → ROS"]
        VARIANT_TO["to_variant()<br/>Dict[Tensor] → VariantsList"]
        VARIANT_FROM["from_variant()<br/>VariantsList → Dict[Tensor]"]

        CONV --> DECODE
        CONV --> ENCODE
        CONV --> VARIANT_TO
        CONV --> VARIANT_FROM
    end

    subgraph "ML/LeRobot World"
        OBS["observation.images.top<br/>(480, 640, 3) float32"]
        STATE["observation.state<br/>(6,) float32"]
        ACTION["action<br/>(6,) float32"]
    end

    CAM -->|"subscribe"| DECODE
    JS -->|"subscribe"| DECODE
    DECODE --> OBS
    DECODE --> STATE

    ACTION --> ENCODE
    ENCODE -->|"publish"| CMD

    OBS -.->|"distributed mode"| VARIANT_TO
    STATE -.->|"distributed mode"| VARIANT_TO
    VARIANT_FROM -.->|"distributed mode"| ACTION

    style CONV fill:#e8f5e9,stroke:#388e3c,stroke-width:3px
    style DECODE fill:#fff3e0,stroke:#ff9800,stroke-width:2px
    style ENCODE fill:#fff3e0,stroke:#ff9800,stroke-width:2px
    

Key conversion paths:

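  1. Observation path (ROS → tensor): camera images and joint states are decoded from ROS messages into NumPy arrays, resized/normalized according to the contract spec, and batched into tensors for policy input.

  2. Action path (tensor → ROS): policy output tensors are encoded back into ROS messages (e.g. Float64MultiArray or JointState) and published to controller topics.

  3. Distributed inference path: observations and actions are serialized as ibrobot_msgs/msg/VariantsList for network transport between edge and cloud nodes.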

Sources: README.md:30-31, src/tensormsg/tensormsg/converter.py:11-72, src/robot_config/config/robots/so101_single_arm.yaml:203-301


The TensorMsgConverter Class

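The TensorMsgConverter class provides the core API for all protocol conversions. It is stateless: its static methods simply delegate to the registered encoder/decoder functions.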

        graph LR
    subgraph "TensorMsgConverter API"
        API["TensorMsgConverter"]

        subgraph "Core Methods"
            DEC["decode(msg, spec)<br/>→ np.ndarray"]
            ENC["encode(ros_type, data, names, clamp)<br/>→ ROS Message"]
            TO_VAR["to_variant(batch)<br/>→ VariantsList"]
            FROM_VAR["from_variant(msg, device)<br/>→ Dict[str, Tensor]"]
        end

        API --> DEC
        API --> ENC
        API --> TO_VAR
        API --> FROM_VAR
    end

    subgraph "Registry System"
        ENC_REG["ENCODER_REGISTRY"]
        DEC_REG["DECODER_REGISTRY"]
    end

    subgraph "Registered Handlers"
        IMG_DEC["_dec_image()<br/>sensor_msgs/Image"]
        JS_DEC["_dec_joint_state()<br/>sensor_msgs/JointState"]
        F32_DEC["_dec_f32()<br/>Float32MultiArray"]
        TWIST_ENC["_enc_twist()<br/>geometry_msgs/Twist"]
        JS_ENC["_enc_joint_state()"]
    end

    DEC -->|"lookup"| DEC_REG
    ENC -->|"lookup"| ENC_REG

    DEC_REG --> IMG_DEC
    DEC_REG --> JS_DEC
    DEC_REG --> F32_DEC
    ENC_REG --> TWIST_ENC
    ENC_REG --> JS_ENC
    

Sources: src/tensormsg/tensormsg/converter.py:11-72


Registry System

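The package uses a decorator-based registry pattern to associate ROS message types with conversion functions. Support for a custom message type can therefore be added by registering a function, without modifying the core logic.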
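
A minimal sketch of the decorator-registry pattern, assuming plain module-level dicts. ENCODER_REGISTRY, DECODER_REGISTRY, register_encoder and register_decoder reuse the names from the diagram above, but their bodies here are illustrative only; Converter, _lookup, the example_msgs/msg/Scalar type and its two handlers are hypothetical. Unlike the real decode(msg, spec), this decode takes the ROS type as an explicit argument.

```python
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

# Illustrative registries: ROS type string -> handler function.
ENCODER_REGISTRY: Dict[str, Callable[..., Any]] = {}
DECODER_REGISTRY: Dict[str, Callable[..., Any]] = {}


def register_encoder(ros_type: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator: store the wrapped encoder under ros_type and return it unchanged."""
    def wrap(fn: Callable[..., Any]) -> Callable[..., Any]:
        ENCODER_REGISTRY[ros_type] = fn
        return fn
    return wrap


def register_decoder(ros_type: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator: store the wrapped decoder under ros_type and return it unchanged."""
    def wrap(fn: Callable[..., Any]) -> Callable[..., Any]:
        DECODER_REGISTRY[ros_type] = fn
        return fn
    return wrap


def _lookup(registry: Dict[str, Callable[..., Any]], ros_type: str, kind: str) -> Callable[..., Any]:
    if ros_type not in registry:
        raise KeyError(f"no {kind} registered for {ros_type}")
    return registry[ros_type]


class Converter:
    """Hypothetical stateless facade: static methods that only dispatch via the registries."""

    @staticmethod
    def decode(ros_type: str, msg: Any, spec: Any = None) -> Any:
        return _lookup(DECODER_REGISTRY, ros_type, "decoder")(msg, spec)

    @staticmethod
    def encode(ros_type: str, data: Any, names: Optional[Sequence[str]] = None,
               clamp: Optional[Tuple[float, float]] = None) -> Any:
        # Same (names, data, clamp) argument order as _enc_twist in this section.
        return _lookup(ENCODER_REGISTRY, ros_type, "encoder")(names, data, clamp)


# Supporting a new type only means writing decorated functions; Converter is untouched.
@register_decoder("example_msgs/msg/Scalar")
def _dec_scalar(msg: Any, spec: Any) -> list:
    return [float(msg)]


@register_encoder("example_msgs/msg/Scalar")
def _enc_scalar(names: Any, data: Any, clamp: Optional[Tuple[float, float]]) -> float:
    value = float(data)
    if clamp:
        value = min(max(value, clamp[0]), clamp[1])
    return value
```

Because dispatch is a dictionary lookup keyed by the type string, the facade needs no per-type branches and holds no state of its own.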

Encoder registration

Encoders convert Python data (tensors, arrays, sequences) into ROS messages:

# From converter.py:162-170
@register_encoder("geometry_msgs/msg/Twist")
def _enc_twist(names, data, clamp):
    if names: return _encode_via_dotted_paths("geometry_msgs/msg/Twist", names, data, clamp)
    msg = get_message("geometry_msgs/msg/Twist")()
    arr = np.asarray(data, dtype=np.float32).reshape(-1)
    if clamp: arr = np.clip(arr, clamp[0], clamp[1])
    if len(arr) >= 1: msg.linear.x = float(arr[0])
    if len(arr) >= 2: msg.angular.z = float(arr[1])
    return msg

Decoder registration

Decoders extract NumPy arrays from ROS messages and apply the contract spec:

# From converter.py:172-232 (abridged: only the rgb8/bgr8 branch is shown)
@register_decoder("sensor_msgs/msg/Image")
def _dec_image(msg, spec):
    h, w = int(msg.height), int(msg.width)
    step = int(msg.step)  # bytes per row, including any row padding
    enc = getattr(msg, "encoding", "bgr8").lower()
    raw = np.frombuffer(msg.data, dtype=np.uint8)

    resize_hw = spec.image_resize if spec and hasattr(spec, 'image_resize') else None

    # rgb8/bgr8; the rgba8/bgra8, mono and depth branches are omitted here
    if enc in ("rgb8", "bgr8"):
        ch = 3
        row = raw.reshape(h, step)[:, : w * ch]  # drop padding bytes at the end of each row
        arr = row.reshape(h, w, ch)
        hwc_rgb = arr if enc == "rgb8" else arr[..., ::-1]  # BGR -> RGB

    if resize_hw:
        hwc_rgb = nearest_resize_rgb(hwc_rgb, int(resize_hw[0]), int(resize_hw[1]))

    return hwc_rgb.astype(np.float32) / 255.0
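
The reshape(h, step) followed by the [:, : w * ch] slice is what discards per-row padding, since sensor_msgs/Image rows may be longer than width × channels bytes. The sketch below runs only that branch on a types.SimpleNamespace stand-in for the message, so no ROS installation is needed; decode_rgb_like is a hypothetical name and the resize step is left out.

```python
from types import SimpleNamespace

import numpy as np


def decode_rgb_like(msg) -> np.ndarray:
    """rgb8/bgr8 branch of the image decoder, without the resize step."""
    h, w, step = int(msg.height), int(msg.width), int(msg.step)
    enc = msg.encoding.lower()
    if enc not in ("rgb8", "bgr8"):
        raise ValueError(f"unsupported encoding in this sketch: {enc}")
    raw = np.frombuffer(msg.data, dtype=np.uint8)
    rows = raw.reshape(h, step)[:, : w * 3]              # drop per-row padding bytes
    hwc = rows.reshape(h, w, 3)
    hwc_rgb = hwc if enc == "rgb8" else hwc[..., ::-1]   # BGR -> RGB
    return hwc_rgb.astype(np.float32) / 255.0


# A 2x2 bgr8 image whose rows carry 2 padding bytes each (step = 2 * 3 + 2 = 8).
pixels_bgr = np.array(
    [[[255, 0, 0], [0, 255, 0]],
     [[0, 0, 255], [10, 20, 30]]], dtype=np.uint8)
padded = np.zeros((2, 8), dtype=np.uint8)
padded[:, :6] = pixels_bgr.reshape(2, 6)
msg = SimpleNamespace(height=2, width=2, step=8, encoding="bgr8", data=padded.tobytes())

out = decode_rgb_like(msg)
```

Without the slice, the two padding bytes per row would shift every following pixel and the reshape to (h, w, 3) would fail.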

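Main registrations in converter.py:

| ROS type | Registered handler(s) | converter.py lines |
|---|---|---|
| sensor_msgs/msg/Image | decoder _dec_image | 172-232 |
| sensor_msgs/msg/JointState | decoder _dec_joint_state, encoder _enc_joint_state | 234-249 |
| std_msgs/msg/Float32MultiArray | decoder _dec_f32 | 251-253 |
| std_msgs/msg/Float64MultiArray | (handler not shown) | 255-257 |
| geometry_msgs/msg/Twist | encoder _enc_twist | 162-170 |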

Sources: src/tensormsg/tensormsg/converter.py:11-262, src/tensormsg/tensormsg/registry.py


Image Decoding Pipeline

The image decoder supports several ROS encodings and applies the transformations specified by the contract:

        graph TB
    MSG["sensor_msgs/Image<br/>encoding, width, height, data"]

    subgraph "Encoding Detection"
        ENC_CHECK{encoding type?}
    end

    subgraph "RGB/BGR Path"
        RGB_DECODE["Extract HWC array<br/>reshape(h, w, 3)"]
        BGR_FLIP["BGR → RGB<br/>arr[..., ::-1]"]
    end

    subgraph "Depth Path"
        DEPTH_16["16UC1:<br/>uint16 → meters"]
        DEPTH_32["32FC1:<br/>float32 direct"]
        DEPTH_NORM["Normalize to [0,1]<br/>clip & divide"]
        DEPTH_REPEAT["Repeat to 3 channels"]
    end

    subgraph "Contract Transformations"
        RESIZE_CHECK{resize specified?}
        RESIZE_OP["nearest_resize_rgb()<br/>to target HW"]
        NORMALIZE["/ 255.0<br/>→ float32 [0,1]"]
    end

    OUT["np.ndarray<br/>(H, W, 3) float32"]

    MSG --> ENC_CHECK
    ENC_CHECK -->|"rgb8/bgr8"| RGB_DECODE
    ENC_CHECK -->|"16uc1/mono16"| DEPTH_16
    ENC_CHECK -->|"32fc1"| DEPTH_32

    RGB_DECODE --> BGR_FLIP
    BGR_FLIP --> RESIZE_CHECK

    DEPTH_16 --> DEPTH_NORM
    DEPTH_32 --> DEPTH_NORM
    DEPTH_NORM --> DEPTH_REPEAT
    DEPTH_REPEAT --> RESIZE_CHECK

    RESIZE_CHECK -->|"yes"| RESIZE_OP
    RESIZE_CHECK -->|"no"| NORMALIZE
    RESIZE_OP --> NORMALIZE
    NORMALIZE --> OUT
    

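Supported encodings:
- RGB/BGR: rgb8, bgr8, rgba8, bgra8 (the alpha channel is dropped)
- Grayscale: mono8, 8uc1 (replicated to 3 channels)
- Depth: 16uc1, mono16 (converted to meters, then normalized); 32fc1 (float32 values used directly)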
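
A sketch of the depth path from the diagram, assuming 16uc1/mono16 pixels are millimeters (the common ROS convention) and 32fc1 pixels are meters with NaN marking missing readings. normalize_depth and its max_depth_m clip range are hypothetical; the real decoder's normalization constant is not given on this page.

```python
import numpy as np


def normalize_depth(raw: np.ndarray, encoding: str, max_depth_m: float = 2.0) -> np.ndarray:
    """To meters, clip to [0, max_depth_m], scale to [0, 1], repeat to 3 channels.

    max_depth_m is a hypothetical parameter; 16uc1/mono16 input is assumed to be in millimeters.
    """
    enc = encoding.lower()
    if enc in ("16uc1", "mono16"):
        meters = raw.astype(np.float32) / 1000.0                   # mm -> m (assumed unit)
    elif enc == "32fc1":
        meters = np.nan_to_num(raw.astype(np.float32), nan=0.0)    # already meters; NaN = no reading
    else:
        raise ValueError(f"not a depth encoding: {encoding}")
    norm = np.clip(meters, 0.0, max_depth_m) / max_depth_m
    return np.repeat(norm[..., None], 3, axis=-1)                  # (H, W) -> (H, W, 3)


depth_mm = np.array([[0, 1000], [2000, 5000]], dtype=np.uint16)
out = normalize_depth(depth_mm, "16UC1")
```

Repeating the single channel three times lets depth frames flow through the same (H, W, 3) resize and batching code as RGB frames.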

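Contract-driven transformations:
- Resize: if spec.image_resize is set (e.g. [480, 640]), a nearest-neighbor resize is applied
- Normalization: the output is always float32 in the range [0, 1]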
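
A sketch of nearest-neighbor resizing for an (H, W, C) array. nearest_resize is a hypothetical stand-in for nearest_resize_rgb: it maps each output pixel center back to the source grid and truncates, and the real helper may use a different rounding rule.

```python
import numpy as np


def nearest_resize(img: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbor resize of an (H, W, C) array to (out_h, out_w, C)."""
    in_h, in_w = img.shape[:2]
    # Map each output pixel center back into source coordinates and truncate to an index.
    rows = np.minimum(((np.arange(out_h) + 0.5) * in_h / out_h).astype(np.int64), in_h - 1)
    cols = np.minimum(((np.arange(out_w) + 0.5) * in_w / out_w).astype(np.int64), in_w - 1)
    # Broadcasting the index vectors gathers an (out_h, out_w) grid; channels are kept.
    return img[rows[:, None], cols[None, :]]


img = np.arange(2 * 2 * 3).reshape(2, 2, 3)
up = nearest_resize(img, 4, 4)
```

Nearest-neighbor only copies existing pixel values, so a frame that is already in [0, 1] stays in that range and no interpolation artifacts are introduced.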

Sources: src/tensormsg/tensormsg/converter.py:172-232, src/robot_config/config/robots/so101_single_arm.yaml:208-209


Joint State Decoding with Name Selection

The JointState decoder supports contract-driven field selection using dotted-path notation:

        graph TB
    JS_MSG["sensor_msgs/JointState<br/>name: ['1','2','3','4','5','6']<br/>position: [0.1, 0.2, ..., 0.6]<br/>velocity: [...]<br/>effort: [...]"]

    SPEC["Contract Spec<br/>names: ['position.1', 'position.2', ..., 'position.6']"]

    subgraph "Decoding Logic"
        HAS_NAMES{spec.names<br/>provided?}

        DOTTED["_decode_via_names()"]
        DEFAULT["Return msg.position<br/>as np.array"]

        LOOP["For each name in spec.names"]
        DOT_GET["dot_get(msg, 'position.1')<br/>→ Extract by index"]
        BUILD["Accumulate values"]
    end

    OUT["np.ndarray<br/>(6,) float32"]

    JS_MSG --> HAS_NAMES
    SPEC --> HAS_NAMES

    HAS_NAMES -->|"yes"| DOTTED
    HAS_NAMES -->|"no"| DEFAULT

    DOTTED --> LOOP
    LOOP --> DOT_GET
    DOT_GET --> BUILD
    BUILD --> OUT
    DEFAULT --> OUT
    

Example contract spec:

# From so101_single_arm.yaml:249-260
- key: observation.state
  topic: /joint_states
  type: sensor_msgs/msg/JointState
  selector:
    names:
      - "position.1"
      - "position.2"
      - "position.3"
      - "position.4"
      - "position.5"
      - "position.6"

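The dotted notation position.1 is resolved by the dot_get() helper, which navigates the message structure and extracts entry 1 of the position array. Because the six joints in this example are named '1' through '6' and the selector runs up to position.6, the suffix presumably identifies a joint by name (or as a 1-based index); read as a 0-based index, position.6 would fall outside a six-element array.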
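
A sketch of name-based selection over a JointState-like object. dot_get and decode_via_names here are hypothetical reimplementations: the suffix is matched against msg.name first and only then treated as a 0-based index, which is an assumption about how the real dot_get() behaves.

```python
from types import SimpleNamespace
from typing import Any, List

import numpy as np


def dot_get(msg: Any, path: str) -> Any:
    """Resolve 'position.1': match '1' against msg.name, else use it as a 0-based index."""
    field, _, key = path.partition(".")
    values = getattr(msg, field)
    if not key:
        return values
    names: List[str] = list(getattr(msg, "name", []))
    if key in names:
        return values[names.index(key)]
    return values[int(key)]


def decode_via_names(msg: Any, names: List[str]) -> np.ndarray:
    """Collect one float per selector entry, in selector order."""
    return np.array([float(dot_get(msg, n)) for n in names], dtype=np.float32)


joint_state = SimpleNamespace(
    name=["1", "2", "3", "4", "5", "6"],
    position=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
    velocity=[0.0] * 6,
    effort=[0.0] * 6,
)
state = decode_via_names(joint_state, [f"position.{i}" for i in range(1, 7)])
```

Selecting by name makes the output order follow the contract rather than the order in which the driver publishes joints.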

Sources: src/tensormsg/tensormsg/converter.py:234-238, src/robot_config/config/robots/so101_single_arm.yaml:249-260


Action Encoding Pipeline

Actions are encoded from tensors back into ROS messages for the controllers:

        graph TB
    TENSOR["Tensor or np.ndarray<br/>(6,) float32<br/>[j1, j2, j3, j4, j5, gripper]"]

    CONTRACT["Contract Action Spec<br/>topic: /arm_position_controller/commands<br/>type: std_msgs/msg/Float64MultiArray<br/>names: ['action.0', ..., 'action.4']<br/>clamp: [-3.14, 3.14]"]

    subgraph "Encoding Logic"
        LOOKUP["Lookup encoder for<br/>std_msgs/msg/Float64MultiArray"]

        FOUND{encoder<br/>registered?}

        FALLBACK["_encode_via_dotted_paths()<br/>Manual field assignment"]
        REGISTERED["Registered encoder function"]

        FLATTEN["Flatten to 1D array"]
        CLAMP["Apply joint limits<br/>np.clip(arr, min, max)"]
        ASSIGN["Assign to msg fields<br/>via names or direct"]
    end

    MSG["std_msgs/Float64MultiArray<br/>data: [0.1, 0.2, 0.3, 0.4, 0.5]"]

    TENSOR --> FLATTEN
    CONTRACT --> LOOKUP
    FLATTEN --> CLAMP

    LOOKUP --> FOUND
    FOUND -->|"no"| FALLBACK
    FOUND -->|"yes"| REGISTERED

    FALLBACK --> CLAMP
    REGISTERED --> CLAMP
    CLAMP --> ASSIGN
    ASSIGN --> MSG
    

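Encoding via dotted paths:

# From converter.py:75-84. Imports added for readability; get_message is assumed to be
# rosidl_runtime_py's, and dot_set is a tensormsg helper defined elsewhere.
from typing import Any, List, Optional, Tuple

import numpy as np
from rosidl_runtime_py.utilities import get_message

def _encode_via_dotted_paths(
    ros_type: str,
    names: List[str],
    data: Any,
    clamp: Optional[Tuple[float, float]] = None,
) -> Any:
    msg_cls = get_message(ros_type)  # e.g. "geometry_msgs/msg/Twist" -> message class
    msg = msg_cls()
    arr = np.asarray(data, dtype=np.float32).reshape(-1)  # flatten to 1D
    if clamp:
        arr = np.clip(arr, clamp[0], clamp[1])
    for i, path in enumerate(names):
        if i < arr.size:  # names without a matching value keep their defaults
            dot_set(msg, path, float(arr[i]))
    return msg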
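
To see the clamp-then-assign behavior without ROS, the sketch below takes an already constructed message object instead of a ros_type string. dot_set here is a hypothetical implementation that handles nested attributes (linear.x) and list slots (data.2); the real helper may support more path forms. The loop keeps the truncation rule of the original: values beyond len(names) are ignored.

```python
from types import SimpleNamespace
from typing import Any, List, Optional, Tuple

import numpy as np


def dot_set(msg: Any, path: str, value: float) -> None:
    """Hypothetical setter: 'linear.x' sets an attribute, 'data.2' sets a list slot."""
    *parents, leaf = path.split(".")
    target = msg
    for part in parents:
        target = target[int(part)] if part.isdigit() else getattr(target, part)
    if leaf.isdigit():
        target[int(leaf)] = value
    else:
        setattr(target, leaf, value)


def encode_via_dotted_paths(msg: Any, names: List[str], data: Any,
                            clamp: Optional[Tuple[float, float]] = None) -> Any:
    arr = np.asarray(data, dtype=np.float32).reshape(-1)
    if clamp:
        arr = np.clip(arr, clamp[0], clamp[1])
    for i, path in enumerate(names):
        if i < arr.size:  # extra data values beyond len(names) are never written
            dot_set(msg, path, float(arr[i]))
    return msg


twist = SimpleNamespace(linear=SimpleNamespace(x=0.0, y=0.0, z=0.0),
                        angular=SimpleNamespace(x=0.0, y=0.0, z=0.0))
encode_via_dotted_paths(twist, ["linear.x", "angular.z"], [5.0, -0.25, 9.9], clamp=(-1.0, 1.0))

flat_msg = SimpleNamespace(data=[0.0, 0.0, 0.0])
encode_via_dotted_paths(flat_msg, ["data.0", "data.2"], [0.5, 0.25])
```

The same truncation explains the action diagram: a (6,) tensor encoded with the five names action.0 to action.4 yields a five-element message.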

Sources: src/tensormsg/tensormsg/converter.py:14-22, src/tensormsg/tensormsg/converter.py:75-84, src/robot_config/config/robots/so101_single_arm.yaml:264-283


Variant Serialization for Distributed Inference

For distributed inference mode, tensormsg provides to_variant() and from_variant() methods that serialize batch data over ROS topics:

        graph TB
    subgraph "Edge Node (Device)"
        BATCH_EDGE["Dict[str, Tensor]<br/>observation.images.top: (1,3,480,640)<br/>observation.state: (1,6)"]
        TO_VAR["to_variant(batch)"]
        PUB["/preprocessed/batch<br/>publisher"]
    end

    subgraph "VariantsList Message"
        VLIST["ibrobot_msgs/msg/VariantsList<br/>variants: [Variant, Variant, ...]"]

        V1["Variant<br/>key: 'observation.images.top'<br/>type: 'float_32_array'<br/>float_32_array.data: [...]<br/>float_32_array.layout.dim: [(1), (3), (480), (640)]"]

        V2["Variant<br/>key: 'observation.state'<br/>type: 'float_32_array'<br/>float_32_array.data: [...]<br/>float_32_array.layout.dim: [(1), (6)]"]

        VLIST --> V1
        VLIST --> V2
    end

    subgraph "Cloud Node (GPU)"
        SUB["/preprocessed/batch<br/>subscriber"]
        FROM_VAR["from_variant(msg, device)"]
        BATCH_CLOUD["Dict[str, Tensor]<br/>on GPU"]
    end

    BATCH_EDGE --> TO_VAR
    TO_VAR --> VLIST
    VLIST --> PUB
    PUB -.->|"ROS 2 DDS"| SUB
    SUB --> FROM_VAR
    FROM_VAR --> BATCH_CLOUD

    style VLIST fill:#fff3e0,stroke:#ff9800,stroke-width:2px
    

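to_variant() implementation

Converts a dictionary of tensors into a VariantsList message. Only keys starting with task, observation or action are serialized; tensors are filled in by _fill_variant_from_tensor, lists of strings become a string_array, and any other value is skipped: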

# From converter.py:42-63
@staticmethod
def to_variant(batch: Dict[str, Any]) -> Any:
    msg_cls = get_message("ibrobot_msgs/msg/VariantsList")
    msg = msg_cls()
    msg.variants = []

    for key, value in batch.items():
        if not any(key.startswith(p) for p in ['task', 'observation', 'action']):
            continue

        variant_msg = get_message("ibrobot_msgs/msg/Variant")()
        variant_msg.key = key

        if isinstance(value, Tensor):
            _fill_variant_from_tensor(variant_msg, value)
        elif isinstance(value, list) and all(isinstance(x, str) for x in value):
            variant_msg.type = "string_array"
            variant_msg.string_array = value
        else:
            continue
        msg.variants.append(variant_msg)
    return msg

Supported tensor types

| Torch dtype | Variant type | MultiArray message |
|---|---|---|
| torch.bool | "bool_array" | none (plain list) |
| torch.int32 | "int_32_array" | Int32MultiArray |
| torch.int64 | "int_64_array" | Int64MultiArray |
| torch.float32 | "float_32_array" | Float32MultiArray |
| torch.float64 | "float_64_array" | Float64MultiArray |

Each MultiArray message contains:
- data: the flattened 1D array of values
- layout.dim: a list of MultiArrayDimension messages that preserves the tensor shape
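
A round-trip sketch of the flatten-plus-layout scheme, using NumPy dtypes in place of torch dtypes. to_flat, from_flat, Dim and FlatArray are hypothetical stand-ins for the real helpers and message classes. The stride values follow the convention documented in std_msgs/MultiArrayLayout (dim[0].stride is the total element count); whether tensormsg fills strides at all is an assumption.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np

# Mirrors the dtype -> Variant type table above, with NumPy dtypes standing in for torch dtypes.
VARIANT_TYPE: Dict[np.dtype, str] = {
    np.dtype(np.bool_): "bool_array",
    np.dtype(np.int32): "int_32_array",
    np.dtype(np.int64): "int_64_array",
    np.dtype(np.float32): "float_32_array",
    np.dtype(np.float64): "float_64_array",
}


@dataclass
class Dim:  # stand-in for std_msgs/msg/MultiArrayDimension
    label: str
    size: int
    stride: int


@dataclass
class FlatArray:  # stand-in for a Variant holding layout.dim plus flat data
    dims: List[Dim]
    data: list
    type: str


def to_flat(arr: np.ndarray) -> FlatArray:
    """Flatten row-major and record each dimension's size and stride."""
    dims = [Dim(label=f"dim{i}", size=int(size), stride=int(np.prod(arr.shape[i:])))
            for i, size in enumerate(arr.shape)]
    return FlatArray(dims=dims, data=arr.reshape(-1).tolist(), type=VARIANT_TYPE[arr.dtype])


def from_flat(flat: FlatArray) -> np.ndarray:
    """Rebuild the array: dtype from the variant type, shape from layout sizes."""
    dtype = {v: k for k, v in VARIANT_TYPE.items()}[flat.type]
    shape: Tuple[int, ...] = tuple(d.size for d in flat.dims)
    return np.asarray(flat.data, dtype=dtype).reshape(shape)


obs = np.arange(24, dtype=np.float32).reshape(1, 2, 3, 4)
flat = to_flat(obs)
restored = from_flat(flat)
```

Keeping the shape in layout.dim is what lets the receiving side restore a (1, 3, 480, 640) image batch from a flat data array without any out-of-band metadata.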

Sources: src/tensormsg/tensormsg/converter.py:42-158, src/inference_service/README.en.md:46-79


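Contract-Driven Conversion

The contract system ensures that observations and actions are processed identically during recording, training, and inference: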

        graph TB
    subgraph "robot_config YAML (Single Source of Truth)"
        CONTRACT["contract:<br/>  rate_hz: 20<br/>  observations: [...]<br/>  actions: [...]"]
    end

    subgraph "Observation Specification"
        OBS_SPEC["- key: observation.images.top<br/>  topic: /camera/top/image_raw<br/>  type: sensor_msgs/msg/Image<br/>  peripheral: top<br/>  image:<br/>    resize: [480, 640]<br/>  align:<br/>    strategy: hold<br/>    stamp: header<br/>    tol_ms: 1500"]
    end

    subgraph "Action Specification"
        ACT_SPEC["- key: action<br/>  selector:<br/>    names: ['action.0', ..., 'action.4']<br/>  publish:<br/>    topic: /arm_position_controller/commands<br/>    type: std_msgs/msg/Float64MultiArray<br/>    strategy:<br/>      mode: nearest<br/>      tolerance_ms: 500"]
    end

    subgraph "tensormsg Usage"
        DECODE_CALL["decode(img_msg, spec)<br/>→ (480, 640, 3) float32"]
        ENCODE_CALL["encode('std_msgs/msg/Float64MultiArray',<br/>       action_tensor, names, clamp)"]
    end

    CONTRACT --> OBS_SPEC
    CONTRACT --> ACT_SPEC

    OBS_SPEC -.->|"spec.image_resize"| DECODE_CALL
    OBS_SPEC -.->|"spec.peripheral.width/height"| DECODE_CALL

    ACT_SPEC -.->|"spec.selector.names"| ENCODE_CALL
    ACT_SPEC -.->|"spec.publish.type"| ENCODE_CALL
    

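The contract guarantees:
- Resize consistency: training data and live inference use the same image dimensions
- Field mapping: recording and inference extract the same joint indices/names
- QoS settings: identical reliability/history settings guard against data loss
- Alignment strategy: identical timestamp tolerances prevent synchronization drift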
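
The alignment item can be made concrete with a sketch of a hold strategy, under the assumption that hold means "use the newest message stamped at or before the control tick, unless it is older than tol_ms". hold_align and tick_times are hypothetical; the keys rate_hz and tol_ms come from the contract above, but their exact semantics in IB-Robot are not documented on this page.

```python
from bisect import bisect_right
from typing import List, Optional, Sequence


def hold_align(stamps_ms: Sequence[float], tick_ms: float, tol_ms: float) -> Optional[int]:
    """Index of the newest message stamped at or before tick_ms, or None if missing/stale.

    stamps_ms must be sorted in ascending order.
    """
    i = bisect_right(stamps_ms, tick_ms) - 1
    if i < 0 or tick_ms - stamps_ms[i] > tol_ms:
        return None
    return i


def tick_times(rate_hz: float, n: int, start_ms: float = 0.0) -> List[float]:
    """The first n control ticks of a contract running at rate_hz."""
    period = 1000.0 / rate_hz
    return [start_ms + k * period for k in range(n)]


camera_stamps = [0.0, 33.3, 66.7, 100.0]
ticks = tick_times(20, 4)  # 0, 50, 100, 150 ms
picked = [hold_align(camera_stamps, t, tol_ms=1500) for t in ticks]
```

If recording and inference used different tolerances, the same camera stream could be paired with different ticks, which is the drift the contract is meant to rule out.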

Sources: src/robot_config/config/robots/so101_single_arm.yaml:198-301, src/tensormsg/tensormsg/converter.py:24-39


Usage Examples

Decoding observations in an inference node

# Simplified from lerobot_policy_node
import torch
from tensormsg.converter import TensorMsgConverter
from robot_config.loader import load_robot_config

config = load_robot_config("so101_single_arm.yaml")

# Find the observation spec by key
obs_spec = next(o for o in config.contract.observations if o.key == "observation.images.top")

# Subscription callback; receives a sensor_msgs/Image
def image_callback(msg):
    # Decode according to the contract spec
    img_array = TensorMsgConverter.decode(msg, obs_spec)
    # img_array is now (480, 640, 3) float32 in [0, 1]

    # Convert to a torch tensor in CHW layout
    img_tensor = torch.from_numpy(img_array).permute(2, 0, 1)  # (3, 480, 640)
    # Ready for policy input

Encoding actions for dispatch

# Simplified from action_dispatcher_node
from tensormsg.converter import TensorMsgConverter

# policy, observations, config and publisher are assumed to be set up elsewhere in the node

# Policy outputs an action tensor
action_tensor = policy.predict(observations)  # (6,) float32

# Look up the action spec
action_spec = config.contract.actions[0]
ros_type = action_spec.publish['type']
names = action_spec.selector['names']

# Encode into a ROS message
msg = TensorMsgConverter.encode(
    ros_type=ros_type,
    data=action_tensor,
    names=names,
    clamp=(-3.14, 3.14)  # joint limits
)

# Publish to the controller
publisher.publish(msg)

Variant serialization in distributed mode

# Edge node: serialize the batch for network transmission
import torch
from tensormsg.converter import TensorMsgConverter

batch = {
    "observation.images.top": torch.randn(1, 3, 480, 640),
    "observation.state": torch.randn(1, 6)
}

variants_msg = TensorMsgConverter.to_variant(batch)
preprocessed_pub.publish(variants_msg)  # publisher on /preprocessed/batch, created elsewhere

# Cloud node: deserialize and move to GPU
def batch_callback(variants_msg):
    batch = TensorMsgConverter.from_variant(variants_msg, device=torch.device("cuda"))
    # batch tensors are now on the GPU, ready for inference
    output = model(batch)  # model is assumed to be loaded elsewhere

Sources: src/tensormsg/tensormsg/converter.py:11-158, src/inference_service/inference_service/nodes/lerobot_policy_node.py


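Key Design Principles

  1. Stateless conversion: all methods are static, and because the converter holds no internal state it can be used from several threads in parallel.

  2. Contract-driven consistency: specs from robot_config ensure identical processing during recording, training, and inference.

  3. Extension via registries: new message types are added by decorating functions with @register_encoder or @register_decoder.

  4. Fallback mechanism: when no handler is registered, dotted-path notation provides generic field access.

  5. Zero-copy where possible: NumPy arrays are created directly over the ROS message buffer (e.g. np.frombuffer(msg.data)) to minimize allocation overhead; the final astype(np.float32) conversion does allocate a new array.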
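
The zero-copy claim in principle 5 holds for the frombuffer, reshape and channel-flip steps only. This sketch uses a bytes object as a stand-in for msg.data and checks with np.shares_memory which arrays are views of the original buffer.

```python
import numpy as np

buffer = bytes(range(12))                      # stand-in for msg.data of a 2x2 bgr8 image
view = np.frombuffer(buffer, dtype=np.uint8)   # no copy: wraps the existing buffer
hwc = view.reshape(2, 2, 3)                    # still a view
rgb = hwc[..., ::-1]                           # BGR -> RGB flip is a strided view, not a copy
as_float = rgb.astype(np.float32) / 255.0      # first real allocation

print(np.shares_memory(view, rgb))       # True
print(view.flags.writeable)              # False: the view wraps immutable bytes
print(np.shares_memory(rgb, as_float))   # False
```

Because such views are read-only when the buffer is immutable, any in-place edit has to happen after the float conversion.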

Sources: src/tensormsg/tensormsg/converter.py:11-262, README.md:30-31