Graph 定义与描述分析
请关注公众号:阿呆-bot
概述
本文档分析 spring-ai-alibaba-graph-core 模块中 Graph 的定义和描述机制,包括入口类、关键类关系、核心实现代码和设计模式。
入口类说明
StateGraph - Graph 定义入口
StateGraph 是定义工作流图的主要入口类,用于构建包含节点(Node)和边(Edge)的状态图。
核心职责:
- 管理图中的节点集合(Nodes)
- 管理图中的边集合(Edges)
- 提供节点和边的添加方法
- 支持条件路由和子图
- 编译为可执行的 CompiledGraph
关键代码:- public class StateGraph {
- /**
- * Constant representing the END of the graph.
- */
- public static final String END = "__END__";
- /**
- * Constant representing the START of the graph.
- */
- public static final String START = "__START__";
- /**
- * Constant representing the ERROR of the graph.
- */
- public static final String ERROR = "__ERROR__";
- /**
- * Constant representing the NODE_BEFORE of the graph.
- */
- public static final String NODE_BEFORE = "__NODE_BEFORE__";
- /**
- * Constant representing the NODE_AFTER of the graph.
- */
- public static final String NODE_AFTER = "__NODE_AFTER__";
- /**
- * Collection of nodes in the graph.
- */
- final Nodes nodes = new Nodes();
- /**
- * Collection of edges in the graph.
- */
- final Edges edges = new Edges();
- /**
- * Factory for providing key strategies.
- */
- private KeyStrategyFactory keyStrategyFactory;
- /**
- * Name of the graph.
- */
- private String name;
- /**
- * Serializer for the state.
- */
- private final StateSerializer stateSerializer;
复制代码 构造函数:- public StateGraph(String name, KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer) {
- this.name = name;
- this.keyStrategyFactory = keyStrategyFactory;
- this.stateSerializer = stateSerializer;
- }
- public StateGraph(KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer) {
- this.keyStrategyFactory = keyStrategyFactory;
- this.stateSerializer = stateSerializer;
- }
- /**
- * Constructs a StateGraph with the specified name, key strategy factory, and SpringAI
- * state serializer.
- * @param name the name of the graph
- * @param keyStrategyFactory the factory for providing key strategies
- * @param stateSerializer the SpringAI state serializer to use
- */
- public StateGraph(String name, KeyStrategyFactory keyStrategyFactory, SpringAIStateSerializer stateSerializer) {
- this.name = name;
- this.keyStrategyFactory = keyStrategyFactory;
- this.stateSerializer = stateSerializer;
- }
- /**
- * Constructs a StateGraph with the specified key strategy factory and SpringAI state
- * serializer.
- * @param keyStrategyFactory the factory for providing key strategies
- * @param stateSerializer the SpringAI state serializer to use
- */
- public StateGraph(KeyStrategyFactory keyStrategyFactory, SpringAIStateSerializer stateSerializer) {
- this.keyStrategyFactory = keyStrategyFactory;
- this.stateSerializer = stateSerializer;
- }
- public StateGraph(String name, KeyStrategyFactory keyStrategyFactory) {
- this.name = name;
- this.keyStrategyFactory = keyStrategyFactory;
- this.stateSerializer = new JacksonSerializer();
- }
- /**
- * Constructs a StateGraph with the provided key strategy factory.
- * @param keyStrategyFactory the factory for providing key strategies
- */
- public StateGraph(KeyStrategyFactory keyStrategyFactory) {
- this.keyStrategyFactory = keyStrategyFactory;
- this.stateSerializer = new JacksonSerializer();
- }
- /**
- * Default constructor that initializes a StateGraph with a Gson-based state
- * serializer.
- */
- public StateGraph() {
- this.stateSerializer = new JacksonSerializer();
- this.keyStrategyFactory = HashMap::new;
- }
复制代码 CompiledGraph - 编译后的可执行图
CompiledGraph 是 StateGraph 编译后的可执行形式,包含了优化后的节点工厂和边映射。
核心职责:
- 存储编译后的节点工厂映射
- 管理边映射关系
- 提供节点执行入口
- 支持中断和恢复
关键代码:- public class CompiledGraph {
- private static final Logger log = LoggerFactory.getLogger(CompiledGraph.class);
- private static String INTERRUPT_AFTER = "__INTERRUPTED__";
- /**
- * The State graph.
- */
- public final StateGraph stateGraph;
- /**
- * The Compile config.
- */
- public final CompileConfig compileConfig;
- /**
- * The Node Factories - stores factory functions instead of instances to ensure thread safety.
- */
- final Map<String, Node.ActionFactory> nodeFactories = new LinkedHashMap<>();
- /**
- * The Edges.
- */
- final Map<String, EdgeValue> edges = new LinkedHashMap<>();
- private final Map<String, KeyStrategy> keyStrategyMap;
- private final ProcessedNodesEdgesAndConfig processedData;
- private int maxIterations = 25;
- /**
- * Constructs a CompiledGraph with the given StateGraph.
- * @param stateGraph the StateGraph to be used in this CompiledGraph
- * @param compileConfig the compile config
- * @throws GraphStateException the graph state exception
- */
- protected CompiledGraph(StateGraph stateGraph, CompileConfig compileConfig) throws GraphStateException {
- maxIterations = compileConfig.recursionLimit();
复制代码 关键类关系
Node - 节点抽象
Node 表示图中的节点,包含唯一标识符和动作工厂。
关键代码:- public class Node {
- public static final String PRIVATE_PREFIX = "__";
- public interface ActionFactory {
- AsyncNodeActionWithConfig apply(CompileConfig config) throws GraphStateException;
- }
- private final String id;
- private final ActionFactory actionFactory;
- public Node(String id, ActionFactory actionFactory) {
- this.id = id;
- this.actionFactory = actionFactory;
- }
- /**
- * Constructor that accepts only the `id` and sets `actionFactory` to null.
- * @param id the unique identifier for the node
- */
- public Node(String id) {
- this(id, null);
- }
- public void validate() throws GraphStateException {
- if (Objects.equals(id, StateGraph.END) || Objects.equals(id, StateGraph.START)) {
- return;
- }
- if (id.isBlank()) {
- throw Errors.invalidNodeIdentifier.exception("blank node id");
- }
- if (id.startsWith(PRIVATE_PREFIX)) {
- throw Errors.invalidNodeIdentifier.exception("id that start with %s", PRIVATE_PREFIX);
- }
- }
- /**
- * id
- * @return the unique identifier for the node.
- */
- public String id() {
- return id;
- }
- /**
- * actionFactory
- * @return a factory function that takes a {@link CompileConfig} and returns an
- * {@link AsyncNodeActionWithConfig} instance for the specified {@code State}.
- */
- public ActionFactory actionFactory() {
- return actionFactory;
- }
复制代码 Edge - 边抽象
Edge 表示图中的边,连接源节点和目标节点,支持条件路由。
关键代码:- public record Edge(String sourceId, List<EdgeValue> targets) {
- public Edge(String sourceId, EdgeValue target) {
- this(sourceId, List.of(target));
- }
- public Edge(String id) {
- this(id, List.of());
- }
- public boolean isParallel() {
- return targets.size() > 1;
- }
- public EdgeValue target() {
- if (isParallel()) {
- throw new IllegalStateException(format("Edge '%s' is parallel", sourceId));
- }
- return targets.get(0);
- }
- public boolean anyMatchByTargetId(String targetId) {
- return targets().stream()
- .anyMatch(v -> (v.id() != null) ? Objects.equals(v.id(), targetId)
- : v.value().mappings().containsValue(targetId)
- );
- }
- public Edge withSourceAndTargetIdsUpdated(Node node, Function<String, String> newSourceId,
- Function<String, EdgeValue> newTarget) {
- var newTargets = targets().stream().map(t -> t.withTargetIdsUpdated(newTarget)).toList();
- return new Edge(newSourceId.apply(sourceId), newTargets);
- }
复制代码 OverAllState - 全局状态
OverAllState 是贯穿整个图执行过程的全局状态对象,用于在节点间传递数据。
关键代码:- public final class OverAllState implements Serializable {
- public static final Object MARK_FOR_REMOVAL = new Object();
- /**
- * Internal map storing the actual state data. All get/set operations on state values
- * go through this map.
- */
- private final Map<String, Object> data;
- /**
- * Mapping of keys to their respective update strategies. Determines how values for
- * each key should be merged or updated.
- */
- private final Map<String, KeyStrategy> keyStrategies;
- /**
- * Store instance for long-term memory storage across different executions.
- */
- private Store store;
- /**
- * The default key used for standard input injection into the state. Typically used
- * when initializing the state with user or external input.
- */
复制代码 关键类关系图
以下 PlantUML 类图展示了 Graph 定义相关的关键类及其关系:- @startuml
- !theme plain
- skinparam classAttributeIconSize 0
- package "Graph Definition" {
- class StateGraph {
- -Nodes nodes
- -Edges edges
- -KeyStrategyFactory keyStrategyFactory
- -String name
- -StateSerializer stateSerializer
- +addNode(String, NodeAction)
- +addEdge(String, String)
- +addConditionalEdges(...)
- +compile(CompileConfig): CompiledGraph
- }
- class CompiledGraph {
- +StateGraph stateGraph
- +CompileConfig compileConfig
- -Map<String, Node.ActionFactory> nodeFactories
- -Map<String, EdgeValue> edges
- -Map<String, KeyStrategy> keyStrategyMap
- }
- class Node {
- -String id
- -ActionFactory actionFactory
- +id(): String
- +actionFactory(): ActionFactory
- +validate(): void
- }
- class Edge {
- -String sourceId
- -List<EdgeValue> targets
- +isParallel(): boolean
- +target(): EdgeValue
- +validate(Nodes): void
- }
- class OverAllState {
- -Map<String, Object> data
- -Map<String, KeyStrategy> keyStrategies
- -Store store
- +get(String): Object
- +put(String, Object): void
- +registerKeyAndStrategy(String, KeyStrategy): void
- }
- interface NodeAction {
- +apply(OverAllState): Map<String, Object>
- }
- interface EdgeAction {
- +apply(OverAllState): String
- }
- class KeyStrategy {
- +merge(Object, Object): Object
- }
- class StateSerializer {
- +serialize(OverAllState): String
- +deserialize(String): OverAllState
- }
- }
- StateGraph "1" --> "1" CompiledGraph : compiles to
- StateGraph "1" --> "*" Node : contains
- StateGraph "1" --> "*" Edge : contains
- StateGraph --> KeyStrategyFactory : uses
- StateGraph --> StateSerializer : uses
- Node --> NodeAction : creates via ActionFactory
- Edge --> EdgeAction : uses for conditional routing
- CompiledGraph --> Node : stores factories
- CompiledGraph --> Edge : stores mappings
- OverAllState --> KeyStrategy : uses
- OverAllState --> Store : uses
- note right of StateGraph
- 入口类:用于定义工作流图
- 支持添加节点和边
- 支持条件路由和子图
- end note
- note right of CompiledGraph
- 编译后的可执行图
- 包含优化的节点工厂
- 线程安全的工厂函数
- end note
- note right of OverAllState
- 全局状态对象
- 在节点间传递数据
- 支持键策略管理
- end note
- @enduml
复制代码 实现关键点说明
1. Builder 模式
StateGraph 使用链式调用模式构建图,支持流畅的 API:- StateGraph graph = new StateGraph("MyGraph", keyStrategyFactory)
- .addNode("node1", nodeAction1)
- .addNode("node2", nodeAction2)
- .addEdge(START, "node1")
- .addConditionalEdges("node1", edgeAction, Map.of("yes", "node2", "no", END))
- .addEdge("node2", END);
复制代码 2. 工厂模式
Node 使用 ActionFactory 接口延迟创建节点动作,确保线程安全:- public interface ActionFactory {
- AsyncNodeActionWithConfig apply(CompileConfig config) throws GraphStateException;
- }
复制代码 这种设计允许在编译时创建工厂函数,在执行时根据配置创建实际的动作实例。
3. 策略模式
KeyStrategy 用于控制状态键的更新策略,支持不同的合并逻辑(Replace、Append、Reduce 等)。
4. 序列化支持
StateGraph 支持多种状态序列化器:
- PlainTextStateSerializer:纯文本序列化
- SpringAIStateSerializer:Spring AI 标准序列化
- JacksonSerializer:Jackson JSON 序列化(默认)
5. 验证机制
Node 和 Edge 都实现了 validate() 方法,确保图的完整性:
- 节点 ID 不能为空或使用保留前缀
- 边引用的节点必须存在
- 并行边不能有重复目标
总结说明
核心设计理念
- 声明式 API:通过 StateGraph 提供声明式的图定义方式,隐藏底层复杂性
- 编译时优化:StateGraph 编译为 CompiledGraph,将定义转换为可执行形式
- 状态管理:OverAllState 作为全局状态容器,支持键策略和序列化
- 类型安全:使用泛型和接口确保类型安全
- 可扩展性:通过接口和工厂模式支持自定义节点和边动作
关键优势
- 灵活性:支持同步和异步节点、条件路由、并行执行
- 可维护性:清晰的类层次结构和职责分离
- 可测试性:接口抽象便于单元测试
- 性能:编译时优化和工厂模式减少运行时开销
使用流程
- 定义图:使用 StateGraph 添加节点和边
- 编译图:调用 compile() 方法生成 CompiledGraph
- 执行图:通过 GraphRunner 执行编译后的图
- 状态传递:使用 OverAllState 在节点间传递数据
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作! |