找回密码
 立即注册
首页 业界区 安全 Spring AI Alibaba 项目源码学习(二)-Graph 定义与描 ...

Spring AI Alibaba 项目源码学习(二)-Graph 定义与描述分析

拴茅劾 2025-11-11 03:15:02
Graph 定义与描述分析

请关注公众号:阿呆-bot
概述

本文档分析 spring-ai-alibaba-graph-core 模块中 Graph 的定义和描述机制,包括入口类、关键类关系、核心实现代码和设计模式。
入口类说明

StateGraph - Graph 定义入口

StateGraph 是定义工作流图的主要入口类,用于构建包含节点(Node)和边(Edge)的状态图。
核心职责

  • 管理图中的节点集合(Nodes)
  • 管理图中的边集合(Edges)
  • 提供节点和边的添加方法
  • 支持条件路由和子图
  • 编译为可执行的 CompiledGraph
关键代码
  1. public class StateGraph {
  2.         /**
  3.          * Constant representing the END of the graph.
  4.          */
  5.         public static final String END = "__END__";
  6.         /**
  7.          * Constant representing the START of the graph.
  8.          */
  9.         public static final String START = "__START__";
  10.         /**
  11.          * Constant representing the ERROR of the graph.
  12.          */
  13.         public static final String ERROR = "__ERROR__";
  14.         /**
  15.          * Constant representing the NODE_BEFORE of the graph.
  16.          */
  17.         public static final String NODE_BEFORE = "__NODE_BEFORE__";
  18.         /**
  19.          * Constant representing the NODE_AFTER of the graph.
  20.          */
  21.         public static final String NODE_AFTER = "__NODE_AFTER__";
  22.         /**
  23.          * Collection of nodes in the graph.
  24.          */
  25.         final Nodes nodes = new Nodes();
  26.         /**
  27.          * Collection of edges in the graph.
  28.          */
  29.         final Edges edges = new Edges();
  30.         /**
  31.          * Factory for providing key strategies.
  32.          */
  33.         private KeyStrategyFactory keyStrategyFactory;
  34.         /**
  35.          * Name of the graph.
  36.          */
  37.         private String name;
  38.         /**
  39.          * Serializer for the state.
  40.          */
  41.         private final StateSerializer stateSerializer;
复制代码
构造函数
  1. public StateGraph(String name, KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer) {
  2.                 this.name = name;
  3.                 this.keyStrategyFactory = keyStrategyFactory;
  4.                 this.stateSerializer = stateSerializer;
  5.         }
  6.         public StateGraph(KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer) {
  7.                 this.keyStrategyFactory = keyStrategyFactory;
  8.                 this.stateSerializer = stateSerializer;
  9.         }
  10.         /**
  11.          * Constructs a StateGraph with the specified name, key strategy factory, and SpringAI
  12.          * state serializer.
  13.          * @param name the name of the graph
  14.          * @param keyStrategyFactory the factory for providing key strategies
  15.          * @param stateSerializer the SpringAI state serializer to use
  16.          */
  17.         public StateGraph(String name, KeyStrategyFactory keyStrategyFactory, SpringAIStateSerializer stateSerializer) {
  18.                 this.name = name;
  19.                 this.keyStrategyFactory = keyStrategyFactory;
  20.                 this.stateSerializer = stateSerializer;
  21.         }
  22.         /**
  23.          * Constructs a StateGraph with the specified key strategy factory and SpringAI state
  24.          * serializer.
  25.          * @param keyStrategyFactory the factory for providing key strategies
  26.          * @param stateSerializer the SpringAI state serializer to use
  27.          */
  28.         public StateGraph(KeyStrategyFactory keyStrategyFactory, SpringAIStateSerializer stateSerializer) {
  29.                 this.keyStrategyFactory = keyStrategyFactory;
  30.                 this.stateSerializer = stateSerializer;
  31.         }
  32.         public StateGraph(String name, KeyStrategyFactory keyStrategyFactory) {
  33.                 this.name = name;
  34.                 this.keyStrategyFactory = keyStrategyFactory;
  35.                 this.stateSerializer = new JacksonSerializer();
  36.         }
  37.         /**
  38.          * Constructs a StateGraph with the provided key strategy factory.
  39.          * @param keyStrategyFactory the factory for providing key strategies
  40.          */
  41.         public StateGraph(KeyStrategyFactory keyStrategyFactory) {
  42.                 this.keyStrategyFactory = keyStrategyFactory;
  43.                 this.stateSerializer = new JacksonSerializer();
  44.         }
  45.         /**
  46.          * Default constructor that initializes a StateGraph with a Gson-based state
  47.          * serializer.
  48.          */
  49.         public StateGraph() {
  50.                 this.stateSerializer = new JacksonSerializer();
  51.                 this.keyStrategyFactory = HashMap::new;
  52.         }
复制代码
CompiledGraph - 编译后的可执行图

CompiledGraph 是 StateGraph 编译后的可执行形式,包含了优化后的节点工厂和边映射。
核心职责

  • 存储编译后的节点工厂映射
  • 管理边映射关系
  • 提供节点执行入口
  • 支持中断和恢复
关键代码
  1. public class CompiledGraph {
  2.         private static final Logger log = LoggerFactory.getLogger(CompiledGraph.class);
  3.         private static String INTERRUPT_AFTER = "__INTERRUPTED__";
  4.         /**
  5.          * The State graph.
  6.          */
  7.         public final StateGraph stateGraph;
  8.         /**
  9.          * The Compile config.
  10.          */
  11.         public final CompileConfig compileConfig;
  12.         /**
  13.          * The Node Factories - stores factory functions instead of instances to ensure thread safety.
  14.          */
  15.         final Map<String, Node.ActionFactory> nodeFactories = new LinkedHashMap<>();
  16.         /**
  17.          * The Edges.
  18.          */
  19.         final Map<String, EdgeValue> edges = new LinkedHashMap<>();
  20.         private final Map<String, KeyStrategy> keyStrategyMap;
  21.         private final ProcessedNodesEdgesAndConfig processedData;
  22.         private int maxIterations = 25;
  23.         /**
  24.          * Constructs a CompiledGraph with the given StateGraph.
  25.          * @param stateGraph the StateGraph to be used in this CompiledGraph
  26.          * @param compileConfig the compile config
  27.          * @throws GraphStateException the graph state exception
  28.          */
  29.         protected CompiledGraph(StateGraph stateGraph, CompileConfig compileConfig) throws GraphStateException {
  30.                 maxIterations = compileConfig.recursionLimit();
复制代码
关键类关系

Node - 节点抽象

Node 表示图中的节点,包含唯一标识符和动作工厂。
关键代码
  1. public class Node {
  2.         public static final String PRIVATE_PREFIX = "__";
  3.         public interface ActionFactory {
  4.                 AsyncNodeActionWithConfig apply(CompileConfig config) throws GraphStateException;
  5.         }
  6.         private final String id;
  7.         private final ActionFactory actionFactory;
  8.         public Node(String id, ActionFactory actionFactory) {
  9.                 this.id = id;
  10.                 this.actionFactory = actionFactory;
  11.         }
  12.         /**
  13.          * Constructor that accepts only the `id` and sets `actionFactory` to null.
  14.          * @param id the unique identifier for the node
  15.          */
  16.         public Node(String id) {
  17.                 this(id, null);
  18.         }
  19.         public void validate() throws GraphStateException {
  20.                 if (Objects.equals(id, StateGraph.END) || Objects.equals(id, StateGraph.START)) {
  21.                         return;
  22.                 }
  23.                 if (id.isBlank()) {
  24.                         throw Errors.invalidNodeIdentifier.exception("blank node id");
  25.                 }
  26.                 if (id.startsWith(PRIVATE_PREFIX)) {
  27.                         throw Errors.invalidNodeIdentifier.exception("id that start with %s", PRIVATE_PREFIX);
  28.                 }
  29.         }
  30.         /**
  31.          * id
  32.          * @return the unique identifier for the node.
  33.          */
  34.         public String id() {
  35.                 return id;
  36.         }
  37.         /**
  38.          * actionFactory
  39.          * @return a factory function that takes a {@link CompileConfig} and returns an
  40.          * {@link AsyncNodeActionWithConfig} instance for the specified {@code State}.
  41.          */
  42.         public ActionFactory actionFactory() {
  43.                 return actionFactory;
  44.         }
复制代码
Edge - 边抽象

Edge 表示图中的边,连接源节点和目标节点,支持条件路由。
关键代码
  1. public record Edge(String sourceId, List<EdgeValue> targets) {
  2.         public Edge(String sourceId, EdgeValue target) {
  3.                 this(sourceId, List.of(target));
  4.         }
  5.         public Edge(String id) {
  6.                 this(id, List.of());
  7.         }
  8.         public boolean isParallel() {
  9.                 return targets.size() > 1;
  10.         }
  11.         public EdgeValue target() {
  12.                 if (isParallel()) {
  13.                         throw new IllegalStateException(format("Edge '%s' is parallel", sourceId));
  14.                 }
  15.                 return targets.get(0);
  16.         }
  17.         public boolean anyMatchByTargetId(String targetId) {
  18.                 return targets().stream()
  19.                         .anyMatch(v -> (v.id() != null) ? Objects.equals(v.id(), targetId)
  20.                                         : v.value().mappings().containsValue(targetId)
  21.                         );
  22.         }
  23.         public Edge withSourceAndTargetIdsUpdated(Node node, Function<String, String> newSourceId,
  24.                         Function<String, EdgeValue> newTarget) {
  25.                 var newTargets = targets().stream().map(t -> t.withTargetIdsUpdated(newTarget)).toList();
  26.                 return new Edge(newSourceId.apply(sourceId), newTargets);
  27.         }
复制代码
OverAllState - 全局状态

OverAllState 是贯穿整个图执行过程的全局状态对象,用于在节点间传递数据。
关键代码
  1. public final class OverAllState implements Serializable {
  2.         public static final Object MARK_FOR_REMOVAL = new Object();
  3.         /**
  4.          * Internal map storing the actual state data. All get/set operations on state values
  5.          * go through this map.
  6.          */
  7.         private final Map<String, Object> data;
  8.         /**
  9.          * Mapping of keys to their respective update strategies. Determines how values for
  10.          * each key should be merged or updated.
  11.          */
  12.         private final Map<String, KeyStrategy> keyStrategies;
  13.         /**
  14.          * Store instance for long-term memory storage across different executions.
  15.          */
  16.         private Store store;
  17.         /**
  18.          * The default key used for standard input injection into the state. Typically used
  19.          * when initializing the state with user or external input.
  20.          */
复制代码
关键类关系图

以下 PlantUML 类图展示了 Graph 定义相关的关键类及其关系:
  1. @startuml
  2. !theme plain
  3. skinparam classAttributeIconSize 0
  4. package "Graph Definition" {
  5.     class StateGraph {
  6.         -Nodes nodes
  7.         -Edges edges
  8.         -KeyStrategyFactory keyStrategyFactory
  9.         -String name
  10.         -StateSerializer stateSerializer
  11.         +addNode(String, NodeAction)
  12.         +addEdge(String, String)
  13.         +addConditionalEdges(...)
  14.         +compile(CompileConfig): CompiledGraph
  15.     }
  16.     class CompiledGraph {
  17.         +StateGraph stateGraph
  18.         +CompileConfig compileConfig
  19.         -Map<String, Node.ActionFactory> nodeFactories
  20.         -Map<String, EdgeValue> edges
  21.         -Map<String, KeyStrategy> keyStrategyMap
  22.     }
  23.     class Node {
  24.         -String id
  25.         -ActionFactory actionFactory
  26.         +id(): String
  27.         +actionFactory(): ActionFactory
  28.         +validate(): void
  29.     }
  30.     class Edge {
  31.         -String sourceId
  32.         -List<EdgeValue> targets
  33.         +isParallel(): boolean
  34.         +target(): EdgeValue
  35.         +validate(Nodes): void
  36.     }
  37.     class OverAllState {
  38.         -Map<String, Object> data
  39.         -Map<String, KeyStrategy> keyStrategies
  40.         -Store store
  41.         +get(String): Object
  42.         +put(String, Object): void
  43.         +registerKeyAndStrategy(String, KeyStrategy): void
  44.     }
  45.     interface NodeAction {
  46.         +apply(OverAllState): Map<String, Object>
  47.     }
  48.     interface EdgeAction {
  49.         +apply(OverAllState): String
  50.     }
  51.     class KeyStrategy {
  52.         +merge(Object, Object): Object
  53.     }
  54.     class StateSerializer {
  55.         +serialize(OverAllState): String
  56.         +deserialize(String): OverAllState
  57.     }
  58. }
  59. StateGraph "1" --> "1" CompiledGraph : compiles to
  60. StateGraph "1" --> "*" Node : contains
  61. StateGraph "1" --> "*" Edge : contains
  62. StateGraph --> KeyStrategyFactory : uses
  63. StateGraph --> StateSerializer : uses
  64. Node --> NodeAction : creates via ActionFactory
  65. Edge --> EdgeAction : uses for conditional routing
  66. CompiledGraph --> Node : stores factories
  67. CompiledGraph --> Edge : stores mappings
  68. OverAllState --> KeyStrategy : uses
  69. OverAllState --> Store : uses
  70. note right of StateGraph
  71.   入口类:用于定义工作流图
  72.   支持添加节点和边
  73.   支持条件路由和子图
  74. end note
  75. note right of CompiledGraph
  76.   编译后的可执行图
  77.   包含优化的节点工厂
  78.   线程安全的工厂函数
  79. end note
  80. note right of OverAllState
  81.   全局状态对象
  82.   在节点间传递数据
  83.   支持键策略管理
  84. end note
  85. @enduml
复制代码
实现关键点说明

1. Builder 模式

StateGraph 使用链式调用模式构建图,支持流畅的 API:
  1. StateGraph graph = new StateGraph("MyGraph", keyStrategyFactory)
  2.     .addNode("node1", nodeAction1)
  3.     .addNode("node2", nodeAction2)
  4.     .addEdge(START, "node1")
  5.     .addConditionalEdges("node1", edgeAction, Map.of("yes", "node2", "no", END))
  6.     .addEdge("node2", END);
复制代码
2. 工厂模式

Node 使用 ActionFactory 接口延迟创建节点动作,确保线程安全:
  1. public interface ActionFactory {
  2.     AsyncNodeActionWithConfig apply(CompileConfig config) throws GraphStateException;
  3. }
复制代码
这种设计允许在编译时创建工厂函数,在执行时根据配置创建实际的动作实例。
3. 策略模式

KeyStrategy 用于控制状态键的更新策略,支持不同的合并逻辑(Replace、Append、Reduce 等)。
4. 序列化支持

StateGraph 支持多种状态序列化器:

  • PlainTextStateSerializer:纯文本序列化
  • SpringAIStateSerializer:Spring AI 标准序列化
  • JacksonSerializer:Jackson JSON 序列化(默认)
5. 验证机制

Node 和 Edge 都实现了 validate() 方法,确保图的完整性:

  • 节点 ID 不能为空或使用保留前缀
  • 边引用的节点必须存在
  • 并行边不能有重复目标
总结说明

核心设计理念


  • 声明式 API:通过 StateGraph 提供声明式的图定义方式,隐藏底层复杂性
  • 编译时优化:StateGraph 编译为 CompiledGraph,将定义转换为可执行形式
  • 状态管理:OverAllState 作为全局状态容器,支持键策略和序列化
  • 类型安全:使用泛型和接口确保类型安全
  • 可扩展性:通过接口和工厂模式支持自定义节点和边动作
关键优势


  • 灵活性:支持同步和异步节点、条件路由、并行执行
  • 可维护性:清晰的类层次结构和职责分离
  • 可测试性:接口抽象便于单元测试
  • 性能:编译时优化和工厂模式减少运行时开销
使用流程


  • 定义图:使用 StateGraph 添加节点和边
  • 编译图:调用 compile() 方法生成 CompiledGraph
  • 执行图:通过 GraphRunner 执行编译后的图
  • 状态传递:使用 OverAllState 在节点间传递数据

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册