Protocol Conversion (tensormsg)
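The tensormsg package serves as a **bidirectional protocol converter** between ROS 2 message streams and machine-learning tensor representations. It lets LeRobot policies consume robot observations and produce actions without hand-written serialization code, and it uses a contract-driven specification system to keep data collection, training, and deployment consistent.
For how contracts are defined and loaded, see Contract Definitions. For details on the inference pipeline that uses tensormsg, see Inference Pipeline.
Sources: README.md:30-31, docs/architecture.md:203-208, src/tensormsg/tensormsg/converter.py:1-262
System Role and Data Flow
Within the IB-Robot architecture, the tensormsg package is the central protocol hub and performs the following data conversions: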
graph TB
subgraph "ROS 2 World"
CAM["/camera/top/image_raw<br/>(sensor_msgs/Image)"]
JS["/joint_states<br/>(sensor_msgs/JointState)"]
CMD["/arm_position_controller/commands<br/>(Float64MultiArray)"]
end
subgraph "tensormsg Package"
CONV["TensorMsgConverter"]
DECODE["decode()<br/>ROS → NumPy"]
ENCODE["encode()<br/>Tensor → ROS"]
VARIANT_TO["to_variant()<br/>Dict[Tensor] → VariantsList"]
VARIANT_FROM["from_variant()<br/>VariantsList → Dict[Tensor]"]
CONV --> DECODE
CONV --> ENCODE
CONV --> VARIANT_TO
CONV --> VARIANT_FROM
end
subgraph "ML/LeRobot World"
OBS["observation.images.top<br/>(480, 640, 3) float32"]
STATE["observation.state<br/>(6,) float32"]
ACTION["action<br/>(6,) float32"]
end
CAM -->|"subscribe"| DECODE
JS -->|"subscribe"| DECODE
DECODE --> OBS
DECODE --> STATE
ACTION --> ENCODE
ENCODE -->|"publish"| CMD
OBS -.->|"distributed mode"| VARIANT_TO
STATE -.->|"distributed mode"| VARIANT_TO
VARIANT_FROM -.->|"distributed mode"| ACTION
style CONV fill:#e8f5e9,stroke:#388e3c,stroke-width:3px
style DECODE fill:#fff3e0,stroke:#ff9800,stroke-width:2px
style ENCODE fill:#fff3e0,stroke:#ff9800,stroke-width:2px
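Key conversion paths:
- Observation path (ROS → tensor): Camera images and joint states are decoded from ROS messages into NumPy arrays, resized/normalized per the contract spec, and batched into tensors for policy input.
- Action path (tensor → ROS): Policy output tensors are encoded back into ROS messages (such as Float64MultiArray or JointState) and published to controller topics.
- Distributed inference path: Observations and actions are serialized as ibrobot_msgs/msg/VariantsList for network transmission between edge and cloud nodes.
Sources: README.md:30-31, src/tensormsg/tensormsg/converter.py:11-72, src/robot_config/config/robots/so101_single_arm.yaml:203-301
TensorMsgConverter Class
The TensorMsgConverter class provides the core API for all protocol conversions. It operates as a stateless utility whose static methods delegate to registered encoder/decoder functions.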
graph LR
subgraph "TensorMsgConverter API"
API["TensorMsgConverter"]
subgraph "Core Methods"
DEC["decode(msg, spec)<br/>→ np.ndarray"]
ENC["encode(ros_type, data, names, clamp)<br/>→ ROS Message"]
TO_VAR["to_variant(batch)<br/>→ VariantsList"]
FROM_VAR["from_variant(msg, device)<br/>→ Dict[str, Tensor]"]
end
API --> DEC
API --> ENC
API --> TO_VAR
API --> FROM_VAR
end
subgraph "Registry System"
ENC_REG["ENCODER_REGISTRY"]
DEC_REG["DECODER_REGISTRY"]
end
subgraph "Registered Handlers"
IMG_DEC["_dec_image()<br/>sensor_msgs/Image"]
JS_DEC["_dec_joint_state()<br/>sensor_msgs/JointState"]
F32_DEC["_dec_f32()<br/>Float32MultiArray"]
TWIST_ENC["_enc_twist()<br/>geometry_msgs/Twist"]
JS_ENC["_enc_joint_state()"]
end
DEC -->|"lookup"| DEC_REG
ENC -->|"lookup"| ENC_REG
DEC_REG --> IMG_DEC
DEC_REG --> JS_DEC
DEC_REG --> F32_DEC
ENC_REG --> TWIST_ENC
ENC_REG --> JS_ENC
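Registry System
The package uses a decorator-based registry pattern to associate ROS message types with conversion functions. Custom message types can be added without modifying the core logic.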
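The registry pattern described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the registry dictionaries and decorator names mirror those shown in the diagram, but the factory `_make_register`, the module-level `decode` function, and the message type `example_msgs/msg/Scalar` are hypothetical and not taken from converter.py.

```python
from typing import Any, Callable, Dict

# Stand-ins for the package's registries (names mirror the diagram above).
ENCODER_REGISTRY: Dict[str, Callable[..., Any]] = {}
DECODER_REGISTRY: Dict[str, Callable[..., Any]] = {}


def _make_register(registry: Dict[str, Callable[..., Any]]):
    """Build a decorator factory that stores handlers in `registry`."""
    def register(ros_type: str):
        def wrap(fn):
            registry[ros_type] = fn  # associate the ROS type string with the handler
            return fn                # return the function unchanged
        return wrap
    return register


register_encoder = _make_register(ENCODER_REGISTRY)
register_decoder = _make_register(DECODER_REGISTRY)


@register_decoder("example_msgs/msg/Scalar")  # hypothetical message type
def _dec_scalar(msg, spec=None):
    # A dict stands in for a ROS message in this sketch.
    return [float(msg["value"])]


def decode(ros_type: str, msg, spec=None):
    """Look up the handler by type string and delegate to it."""
    handler = DECODER_REGISTRY.get(ros_type)
    if handler is None:
        raise KeyError(f"no decoder registered for {ros_type}")
    return handler(msg, spec)


result = decode("example_msgs/msg/Scalar", {"value": 3})
```

Because the lookup key is just the type string, supporting a new message type in a scheme like this amounts to defining one decorated function.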
Encoder Registration
Encoders convert Python data (tensors, arrays, sequences) into ROS messages:
# From converter.py:162-170
@register_encoder("geometry_msgs/msg/Twist")
def _enc_twist(names, data, clamp):
    if names:
        return _encode_via_dotted_paths("geometry_msgs/msg/Twist", names, data, clamp)
    msg = get_message("geometry_msgs/msg/Twist")()
    arr = np.asarray(data, dtype=np.float32).reshape(-1)
    if clamp:
        arr = np.clip(arr, clamp[0], clamp[1])
    if len(arr) >= 1:
        msg.linear.x = float(arr[0])
    if len(arr) >= 2:
        msg.angular.z = float(arr[1])
    return msg
Decoder Registration
Decoders extract NumPy arrays from ROS messages and apply the contract spec:
# From converter.py:172-232 (excerpt; only the rgb8/bgr8 branch is shown)
@register_decoder("sensor_msgs/msg/Image")
def _dec_image(msg, spec):
    h, w = int(msg.height), int(msg.width)
    # Row stride in bytes. Assumed here: the excerpt uses `step` without showing its definition.
    step = int(msg.step)
    enc = getattr(msg, "encoding", "bgr8").lower()
    raw = np.frombuffer(msg.data, dtype=np.uint8)
    resize_hw = spec.image_resize if spec and hasattr(spec, 'image_resize') else None
    # Handle RGB/BGR/RGBA/BGRA
    if enc in ("rgb8", "bgr8"):
        ch = 3
        row = raw.reshape(h, step)[:, : w * ch]  # drop any row padding
        arr = row.reshape(h, w, ch)
        hwc_rgb = arr if enc == "rgb8" else arr[..., ::-1]  # BGR -> RGB
        if resize_hw:
            hwc_rgb = nearest_resize_rgb(hwc_rgb, int(resize_hw[0]), int(resize_hw[1]))
        return hwc_rgb.astype(np.float32) / 255.0
Main registrations in converter.py:
| ROS type | Decoder | Encoder | Line range |
|---|---|---|---|
|  | ✅ | ❌ |  |
|  | ✅ | ✅ |  |
|  | ✅ | ❌ |  |
|  | ✅ | ❌ |  |
|  | ❌ | ✅ |  |
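Sources: src/tensormsg/tensormsg/converter.py:11-262, src/tensormsg/tensormsg/registry.py
Image Decoding Pipeline
The image decoder supports multiple ROS encodings and applies the transformations specified by the contract: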
graph TB
MSG["sensor_msgs/Image<br/>encoding, width, height, data"]
subgraph "Encoding Detection"
ENC_CHECK{encoding type?}
end
subgraph "RGB/BGR Path"
RGB_DECODE["Extract HWC array<br/>reshape(h, w, 3)"]
BGR_FLIP["BGR → RGB<br/>arr[..., ::-1]"]
end
subgraph "Depth Path"
DEPTH_16["16UC1:<br/>uint16 → meters"]
DEPTH_32["32FC1:<br/>float32 direct"]
DEPTH_NORM["Normalize to [0,1]<br/>clip & divide"]
DEPTH_REPEAT["Repeat to 3 channels"]
end
subgraph "Contract Transformations"
RESIZE_CHECK{resize specified?}
RESIZE_OP["nearest_resize_rgb()<br/>to target HW"]
NORMALIZE["/ 255.0<br/>→ float32 [0,1]"]
end
OUT["np.ndarray<br/>(H, W, 3) float32"]
MSG --> ENC_CHECK
ENC_CHECK -->|"rgb8/bgr8"| RGB_DECODE
ENC_CHECK -->|"16uc1/mono16"| DEPTH_16
ENC_CHECK -->|"32fc1"| DEPTH_32
RGB_DECODE --> BGR_FLIP
BGR_FLIP --> RESIZE_CHECK
DEPTH_16 --> DEPTH_NORM
DEPTH_32 --> DEPTH_NORM
DEPTH_NORM --> DEPTH_REPEAT
DEPTH_REPEAT --> RESIZE_CHECK
RESIZE_CHECK -->|"yes"| RESIZE_OP
RESIZE_CHECK -->|"no"| NORMALIZE
RESIZE_OP --> NORMALIZE
NORMALIZE --> OUT
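Supported encodings:
- RGB/BGR: rgb8, bgr8, rgba8, bgra8 (alpha channel dropped)
- Grayscale: mono8, 8uc1 (replicated to 3 channels)
- Depth: 16uc1, mono16 (converted to meters, then normalized); 32fc1 (float32 used directly)
Contract-driven transformations:
- Resize: if spec.image_resize is [480, 640], nearest-neighbor resizing to that size is applied
- Normalization: the output is always float32 in the range [0, 1]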
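To make the rgb8/bgr8 path concrete, here is a minimal NumPy sketch of the steps listed above: strip row padding, flip BGR to RGB, resize by nearest neighbor, and scale to [0, 1]. The functions `nearest_resize_hwc` and `decode_bgr8` are hypothetical stand-ins written for this example; the package's own `nearest_resize_rgb()` may be implemented differently.

```python
import numpy as np


def nearest_resize_hwc(img: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbor resize of an (H, W, C) array via integer index lookup."""
    in_h, in_w = img.shape[:2]
    rows = np.arange(out_h) * in_h // out_h  # source row for each output row
    cols = np.arange(out_w) * in_w // out_w  # source column for each output column
    return img[rows[:, None], cols[None, :]]


def decode_bgr8(data: bytes, h: int, w: int, step: int, resize_hw=None) -> np.ndarray:
    """Decode a bgr8 buffer into an (H, W, 3) float32 RGB array in [0, 1]."""
    raw = np.frombuffer(data, dtype=np.uint8)
    rows = raw.reshape(h, step)[:, : w * 3]  # drop padding bytes at the end of each row
    hwc = rows.reshape(h, w, 3)[..., ::-1]   # BGR -> RGB
    if resize_hw is not None:
        hwc = nearest_resize_hwc(hwc, int(resize_hw[0]), int(resize_hw[1]))
    return hwc.astype(np.float32) / 255.0


# A 2x2 bgr8 image with 2 padding bytes per row (step = 8).
data = bytes([
    0, 0, 255,   0, 255, 0,       0, 0,  # row 0: red, green, padding
    255, 0, 0,   255, 255, 255,   0, 0,  # row 1: blue, white, padding
])
out = decode_bgr8(data, h=2, w=2, step=8, resize_hw=(4, 4))
```

Nearest-neighbor lookup only copies existing pixel values, so the result stays within the original value range before the division by 255.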
Sources: src/tensormsg/tensormsg/converter.py:172-232, src/robot_config/config/robots/so101_single_arm.yaml:208-209
Joint State Decoding with Name Selection
The JointState decoder supports contract-driven field selection using dotted-path notation:
graph TB
JS_MSG["sensor_msgs/JointState<br/>name: ['1','2','3','4','5','6']<br/>position: [0.1, 0.2, ..., 0.6]<br/>velocity: [...]<br/>effort: [...]"]
SPEC["Contract Spec<br/>names: ['position.1', 'position.2', ..., 'position.6']"]
subgraph "Decoding Logic"
HAS_NAMES{spec.names<br/>provided?}
DOTTED["_decode_via_names()"]
DEFAULT["Return msg.position<br/>as np.array"]
LOOP["For each name in spec.names"]
DOT_GET["dot_get(msg, 'position.1')<br/>→ Extract by index"]
BUILD["Accumulate values"]
end
OUT["np.ndarray<br/>(6,) float32"]
JS_MSG --> HAS_NAMES
SPEC --> HAS_NAMES
HAS_NAMES -->|"yes"| DOTTED
HAS_NAMES -->|"no"| DEFAULT
DOTTED --> LOOP
LOOP --> DOT_GET
DOT_GET --> BUILD
BUILD --> OUT
DEFAULT --> OUT
Contract spec example:
# From so101_single_arm.yaml:249-260
- key: observation.state
  topic: /joint_states
  type: sensor_msgs/msg/JointState
  selector:
    names:
      - "position.1"
      - "position.2"
      - "position.3"
      - "position.4"
      - "position.5"
      - "position.6"
The dotted notation position.1 is resolved by the dot_get() helper, which navigates the message structure and extracts the value at index 1 of the position array.
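A dotted-path resolver of this kind can be sketched as follows. This is an assumption-laden illustration, not the package's code: the `dot_get` below is written from scratch, treats numeric segments as zero-based list indices, and uses a `SimpleNamespace` in place of a real JointState message. The real helper may resolve entries differently, for example by joint name.

```python
from types import SimpleNamespace
from typing import Any


def dot_get(obj: Any, path: str) -> Any:
    """Walk a dotted path: numeric segments index sequences, others read attributes."""
    cur = obj
    for part in path.split("."):
        cur = cur[int(part)] if part.isdigit() else getattr(cur, part)
    return cur


# Stand-in for a sensor_msgs/JointState message.
msg = SimpleNamespace(position=[0.1, 0.2, 0.3], velocity=[0.0, 0.0, 0.0])

# Selecting fields by name list, as a contract selector would.
names = ["position.0", "position.2"]
values = [dot_get(msg, n) for n in names]
```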
Sources: src/tensormsg/tensormsg/converter.py:234-238, src/robot_config/config/robots/so101_single_arm.yaml:249-260
Action Encoding Pipeline
Actions are encoded from tensors back into ROS messages for the controllers:
graph TB
TENSOR["Tensor or np.ndarray<br/>(6,) float32<br/>[j1, j2, j3, j4, j5, gripper]"]
CONTRACT["Contract Action Spec<br/>topic: /arm_position_controller/commands<br/>type: std_msgs/msg/Float64MultiArray<br/>names: ['action.0', ..., 'action.4']<br/>clamp: [-3.14, 3.14]"]
subgraph "Encoding Logic"
LOOKUP["Lookup encoder for<br/>std_msgs/msg/Float64MultiArray"]
FOUND{encoder<br/>registered?}
FALLBACK["_encode_via_dotted_paths()<br/>Manual field assignment"]
REGISTERED["Registered encoder function"]
FLATTEN["Flatten to 1D array"]
CLAMP["Apply joint limits<br/>np.clip(arr, min, max)"]
ASSIGN["Assign to msg fields<br/>via names or direct"]
end
MSG["std_msgs/Float64MultiArray<br/>data: [0.1, 0.2, 0.3, 0.4, 0.5]"]
TENSOR --> FLATTEN
CONTRACT --> LOOKUP
FLATTEN --> CLAMP
LOOKUP --> FOUND
FOUND -->|"no"| FALLBACK
FOUND -->|"yes"| REGISTERED
FALLBACK --> CLAMP
REGISTERED --> CLAMP
CLAMP --> ASSIGN
ASSIGN --> MSG
Encoding via dotted paths:
# From converter.py:75-84
def _encode_via_dotted_paths(ros_type: str, names: List[str], data: Any, clamp: Optional[Tuple[float, float]] = None) -> Any:
    msg_cls = get_message(ros_type)
    msg = msg_cls()
    arr = np.asarray(data, dtype=np.float32).reshape(-1)
    if clamp:
        arr = np.clip(arr, clamp[0], clamp[1])
    for i, path in enumerate(names):
        if i < arr.size:
            dot_set(msg, path, float(arr[i]))
    return msg
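The same flatten, clamp, and assign sequence can be tried without a ROS installation. In the sketch below, `dot_set` and `encode_via_paths` are hypothetical re-implementations for illustration only, and nested `SimpleNamespace` objects stand in for a geometry_msgs/Twist message.

```python
from types import SimpleNamespace

import numpy as np


def dot_set(obj, path: str, value) -> None:
    """Set a nested attribute addressed by a dotted path such as 'linear.x'."""
    *parents, leaf = path.split(".")
    cur = obj
    for part in parents:
        cur = getattr(cur, part)
    setattr(cur, leaf, value)


def encode_via_paths(msg, names, data, clamp=None):
    """Flatten the data, optionally clamp it, and assign values path by path."""
    arr = np.asarray(data, dtype=np.float32).reshape(-1)
    if clamp:
        arr = np.clip(arr, clamp[0], clamp[1])
    # zip stops at the shorter input, so surplus names or values are ignored.
    for path, value in zip(names, arr):
        dot_set(msg, path, float(value))
    return msg


# Stand-in for a Twist message with only the fields used here.
twist = SimpleNamespace(linear=SimpleNamespace(x=0.0), angular=SimpleNamespace(z=0.0))
encode_via_paths(twist, ["linear.x", "angular.z"], [5.0, -0.25], clamp=(-1.0, 1.0))
```

Clamping before assignment means an out-of-range policy output (5.0 here) is limited to the configured bound rather than being passed on to the controller.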
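Sources: src/tensormsg/tensormsg/converter.py:14-22, src/tensormsg/tensormsg/converter.py:75-84, src/robot_config/config/robots/so101_single_arm.yaml:264-283
Variant Serialization for Distributed Inference
For distributed inference mode, tensormsg provides the to_variant() and from_variant() methods, which serialize batch data for transport over ROS topics: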
graph TB
subgraph "Edge Node (Device)"
BATCH_EDGE["Dict[str, Tensor]<br/>observation.images.top: (1,3,480,640)<br/>observation.state: (1,6)"]
TO_VAR["to_variant(batch)"]
PUB["/preprocessed/batch<br/>publisher"]
end
subgraph "VariantsList Message"
VLIST["ibrobot_msgs/msg/VariantsList<br/>variants: [Variant, Variant, ...]"]
V1["Variant<br/>key: 'observation.images.top'<br/>type: 'float_32_array'<br/>float_32_array.data: [...]<br/>float_32_array.layout.dim: [(1), (3), (480), (640)]"]
V2["Variant<br/>key: 'observation.state'<br/>type: 'float_32_array'<br/>float_32_array.data: [...]<br/>float_32_array.layout.dim: [(1), (6)]"]
VLIST --> V1
VLIST --> V2
end
subgraph "Cloud Node (GPU)"
SUB["/preprocessed/batch<br/>subscriber"]
FROM_VAR["from_variant(msg, device)"]
BATCH_CLOUD["Dict[str, Tensor]<br/>on GPU"]
end
BATCH_EDGE --> TO_VAR
TO_VAR --> VLIST
VLIST --> PUB
PUB -.->|"ROS 2 DDS"| SUB
SUB --> FROM_VAR
FROM_VAR --> BATCH_CLOUD
style VLIST fill:#fff3e0,stroke:#ff9800,stroke-width:2px
to_variant() Implementation
Converts a dictionary of tensors into a VariantsList message:
# From converter.py:42-63
@staticmethod
def to_variant(batch: Dict[str, Any]) -> Any:
    msg_cls = get_message("ibrobot_msgs/msg/VariantsList")
    msg = msg_cls()
    msg.variants = []
    for key, value in batch.items():
        # Only keys under the task/observation/action namespaces are serialized
        if not any(key.startswith(p) for p in ['task', 'observation', 'action']):
            continue
        variant_msg = get_message("ibrobot_msgs/msg/Variant")()
        variant_msg.key = key
        if isinstance(value, Tensor):
            _fill_variant_from_tensor(variant_msg, value)
        elif isinstance(value, list) and all(isinstance(x, str) for x in value):
            variant_msg.type = "string_array"
            variant_msg.string_array = value
        else:
            continue  # unsupported value types are skipped
        msg.variants.append(variant_msg)
    return msg
Supported Tensor Types
| Torch dtype | Variant type | MultiArray message |
|---|---|---|
|  |  | Direct list |
|  |  |  |
|  |  |  |
|  |  |  |
|  |  |  |
Each MultiArray message contains:
- data: a flattened 1D array of values
- layout.dim: a list of MultiArrayDimension messages that preserves the tensor shape
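The flatten-plus-layout idea can be illustrated with a round trip in NumPy. The dataclasses `Dim` and `Float32Array` below are simplified stand-ins for std_msgs MultiArrayDimension and Float32MultiArray, and `pack`/`unpack` are hypothetical helpers; the stride values follow the std_msgs convention where dim[i].stride is the product of the sizes from dimension i onward.

```python
from dataclasses import dataclass, field
from typing import List

import numpy as np


@dataclass
class Dim:  # stand-in for std_msgs/MultiArrayDimension
    label: str
    size: int
    stride: int


@dataclass
class Float32Array:  # stand-in for std_msgs/Float32MultiArray
    data: List[float] = field(default_factory=list)
    dim: List[Dim] = field(default_factory=list)


def pack(arr: np.ndarray) -> Float32Array:
    """Flatten an array and record its shape as a list of dimensions."""
    dims = [
        Dim(label=f"dim{i}", size=int(size), stride=int(np.prod(arr.shape[i:])))
        for i, size in enumerate(arr.shape)
    ]
    return Float32Array(data=arr.astype(np.float32).reshape(-1).tolist(), dim=dims)


def unpack(msg: Float32Array) -> np.ndarray:
    """Rebuild the original shape from the recorded dimension sizes."""
    shape = [d.size for d in msg.dim]
    return np.asarray(msg.data, dtype=np.float32).reshape(shape)


state = np.arange(6, dtype=np.float32).reshape(1, 6)  # shaped like observation.state
packed = pack(state)
restored = unpack(packed)
```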
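Sources: src/tensormsg/tensormsg/converter.py:42-158, src/inference_service/README.en.md:46-79
Contract-Driven Conversion
The contract system ensures that observations and actions are processed identically during recording, training, and inference: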
graph TB
subgraph "robot_config YAML (Single Source of Truth)"
CONTRACT["contract:<br/> rate_hz: 20<br/> observations: [...]<br/> actions: [...]"]
end
subgraph "Observation Specification"
OBS_SPEC["- key: observation.images.top<br/> topic: /camera/top/image_raw<br/> type: sensor_msgs/msg/Image<br/> peripheral: top<br/> image:<br/> resize: [480, 640]<br/> align:<br/> strategy: hold<br/> stamp: header<br/> tol_ms: 1500"]
end
subgraph "Action Specification"
ACT_SPEC["- key: action<br/> selector:<br/> names: ['action.0', ..., 'action.4']<br/> publish:<br/> topic: /arm_position_controller/commands<br/> type: std_msgs/msg/Float64MultiArray<br/> strategy:<br/> mode: nearest<br/> tolerance_ms: 500"]
end
subgraph "tensormsg Usage"
DECODE_CALL["decode(img_msg, spec)<br/>→ (480, 640, 3) float32"]
ENCODE_CALL["encode('std_msgs/msg/Float64MultiArray',<br/> action_tensor, names, clamp)"]
end
CONTRACT --> OBS_SPEC
CONTRACT --> ACT_SPEC
OBS_SPEC -.->|"spec.image_resize"| DECODE_CALL
OBS_SPEC -.->|"spec.peripheral.width/height"| DECODE_CALL
ACT_SPEC -.->|"spec.selector.names"| ENCODE_CALL
ACT_SPEC -.->|"spec.publish.type"| ENCODE_CALL
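The contract ensures:
- Resize consistency: training data and live inference use the same image dimensions
- Field mapping: the same joint indices/names are extracted during recording and inference
- QoS settings: identical reliability/history settings prevent data loss
- Alignment strategy: identical timestamp tolerances prevent synchronization drift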
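One way to picture how a single contract feeds every stage is to parse it once into immutable spec objects that recording, training, and inference code all share. The sketch below is an assumption: the `ObsSpec` dataclass and `parse_observation` function are hypothetical, and only the YAML keys (key, topic, type, image.resize, selector.names) come from the contract shown in this page. The actual loader in robot_config may expose a different structure.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass(frozen=True)
class ObsSpec:  # hypothetical spec object
    key: str
    topic: str
    ros_type: str
    image_resize: Optional[Tuple[int, int]] = None
    names: Tuple[str, ...] = ()


def parse_observation(entry: Dict[str, Any]) -> ObsSpec:
    """Build a spec from one entry of the contract's observations list."""
    resize = entry.get("image", {}).get("resize")
    names = entry.get("selector", {}).get("names", [])
    return ObsSpec(
        key=entry["key"],
        topic=entry["topic"],
        ros_type=entry["type"],
        image_resize=tuple(resize) if resize else None,
        names=tuple(names),
    )


# A dict shaped like the parsed YAML contract (values abbreviated).
contract = {
    "rate_hz": 20,
    "observations": [
        {"key": "observation.images.top", "topic": "/camera/top/image_raw",
         "type": "sensor_msgs/msg/Image", "image": {"resize": [480, 640]}},
        {"key": "observation.state", "topic": "/joint_states",
         "type": "sensor_msgs/msg/JointState",
         "selector": {"names": ["position.1", "position.2"]}},
    ],
}
specs = {e["key"]: parse_observation(e) for e in contract["observations"]}
```

Freezing the dataclass keeps a spec from being modified after loading, which fits the single-source-of-truth role the contract plays.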
Sources: src/robot_config/config/robots/so101_single_arm.yaml:198-301, src/tensormsg/tensormsg/converter.py:24-39
Usage Examples
Decoding Observations in an Inference Node
# Simplified from lerobot_policy_node
import torch
from tensormsg.converter import TensorMsgConverter
from robot_config.loader import load_robot_config

config = load_robot_config("so101_single_arm.yaml")
# Find observation spec by key
obs_spec = next(o for o in config.contract.observations if o.key == "observation.images.top")

# Callback receives ROS message
def image_callback(msg):
    # Decode with contract spec
    img_array = TensorMsgConverter.decode(msg, obs_spec)
    # img_array is now (480, 640, 3) float32 in [0, 1]
    # Convert to torch tensor
    img_tensor = torch.from_numpy(img_array).permute(2, 0, 1)  # CHW
    # Ready for policy input
Encoding Actions for Dispatch
# Simplified from action_dispatcher_node
from tensormsg.converter import TensorMsgConverter

# Policy outputs action tensor
action_tensor = policy.predict(observations)  # (6,) float32
# Find action spec
action_spec = config.contract.actions[0]
ros_type = action_spec.publish['type']
names = action_spec.selector['names']
# Encode to ROS message
msg = TensorMsgConverter.encode(
    ros_type=ros_type,
    data=action_tensor,
    names=names,
    clamp=(-3.14, 3.14)  # Joint limits
)
# Publish to controller
publisher.publish(msg)
Variant Serialization in Distributed Mode
# Edge node: serialize batch for network transmission
import torch
from tensormsg.converter import TensorMsgConverter

batch = {
    "observation.images.top": torch.randn(1, 3, 480, 640),
    "observation.state": torch.randn(1, 6)
}
variants_msg = TensorMsgConverter.to_variant(batch)
preprocessed_pub.publish(variants_msg)

# Cloud node: deserialize and move to GPU
def batch_callback(variants_msg):
    batch = TensorMsgConverter.from_variant(variants_msg, device=torch.device("cuda"))
    # batch tensors are now on GPU, ready for inference
    output = model(batch)
Sources: src/tensormsg/tensormsg/converter.py:11-158, src/inference_service/inference_service/nodes/lerobot_policy_node.py
Key Design Principles
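- Stateless conversion: All methods are static; with no internal state, they can be used in parallel across threads.
- Contract-driven consistency: Specs from robot_config ensure that processing is identical across recording, training, and inference.
- Extension via registry: New message types are added by decorating functions with @register_encoder or @register_decoder.
- Fallback mechanism: When no handler is registered, dotted-path notation allows generic field access.
- Zero-copy where possible: NumPy arrays are created directly from ROS message buffers (e.g., np.frombuffer(msg.data)) to minimize allocation overhead.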
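The zero-copy point can be checked directly with NumPy. In this small sketch a `bytearray` stands in for `msg.data` (an assumption made for illustration); it shows that `np.frombuffer` returns a view over the existing bytes, while a later `astype` call allocates a new array.

```python
import numpy as np

buf = bytearray([10, 20, 30, 40])          # stands in for msg.data
view = np.frombuffer(buf, dtype=np.uint8)  # no copy: the array shares memory with buf

buf[0] = 99                                # the change is visible through the view
first_after_write = int(view[0])

scaled = view.astype(np.float32) / 255.0   # astype allocates a separate float32 array
independent = not np.shares_memory(view, scaled)

# An immutable bytes object yields a read-only array.
ro = np.frombuffer(bytes(buf), dtype=np.uint8)
read_only = not ro.flags.writeable
```

So the buffer wrap itself is allocation-free, and the first copy happens at the dtype conversion used for normalization.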
Sources: src/tensormsg/tensormsg/converter.py:11-262, README.md:30-31