#onnx
import onnx
from onnx import helper
def remove_node_and_connect(model, target_node_name):
# 获取模型的图
graph = model.graph
# 查找目标节点、前驱和后继节点
target_node = None
predecessor_nodes = []
successor_nodes = []
# 遍历图中的所有节点,找到目标节点及其前后节点
for node in graph.node:
if node.name == target_node_name:
target_node = node
# 记录目标节点的输入(前驱节点的输出)
input_name = node.input[0]
# 记录目标节点的输出(后继节点的输入)
output_name = node.output[0]
else:
if target_node_name in node.input:
successor_nodes.append(node)
if target_node_name in node.output:
predecessor_nodes.append(node)
if not target_node:
print(f"节点 {target_node_name} 未找到")
return model
# 更新前后节点以跳过目标节点
for successor in successor_nodes:
for i, input_name in enumerate(successor.input):
if input_name == target_node.output[0]: # 检查是否为目标节点的输出
successor.input[i] = target_node.input[0] # 将前驱节点的输出连接到后继节点
# 从图中移除目标节点
graph.node.remove(target_node)
return model
# 加载 ONNX 模型
model_path = "path/to/your_model.onnx"
model = onnx.load(model_path)
# 删除节点并连接前后节点
target_node_name = "name_of_target_node" # 替换为实际节点名称
modified_model = remove_node_and_connect(model, target_node_name)
# 保存修改后的模型
onnx.save(modified_model, "path/to/modified_model.onnx")
print("模型已保存")