1055 lines
38 KiB
Python
1055 lines
38 KiB
Python
"""
Create and load Studio pipeline runs with frozen pipeline snapshots.
"""
|
||
from __future__ import annotations
|
||
|
||
import copy
|
||
import json
|
||
import logging
|
||
import shutil
|
||
import uuid
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||
|
||
from core.config import settings
|
||
from models.studio_models import (
|
||
PipelineDefinition,
|
||
StudioNode,
|
||
StudioNodeRunState,
|
||
StudioRun,
|
||
StudioRunStatus,
|
||
StudioRunSummary,
|
||
TurnSnapshot,
|
||
)
|
||
from services.character_service import CharacterService
|
||
from services.studio_context_service import assemble_prompt_blocks, store_context_on_run
|
||
from services.studio_project_service import studio_project_service
|
||
from services.studio_step_respond import (
|
||
resolve_api_config,
|
||
studio_step_respond,
|
||
studio_step_respond_stream,
|
||
)
|
||
from models.converters import WorldBookConverter
|
||
from services.worldbook_service import worldbook_service
|
||
|
||
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _read_json(path: Path) -> dict:
|
||
with path.open("r", encoding="utf-8") as f:
|
||
return json.load(f)
|
||
|
||
|
||
def _write_json(path: Path, data: dict) -> None:
|
||
path.parent.mkdir(parents=True, exist_ok=True)
|
||
with path.open("w", encoding="utf-8") as f:
|
||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||
|
||
|
||
def _build_node_states(pipeline: PipelineDefinition) -> tuple[list[StudioNodeRunState], Optional[str]]:
    """Derive the initial per-node run states for a pipeline.

    Disabled nodes are marked ``skipped``. The first enabled node becomes
    ``active`` and every later enabled node is ``pending``.

    Returns:
        ``(states, first_active_id)``; ``first_active_id`` stays ``None``
        when the pipeline has no enabled node.
    """
    activated = False
    first_active_id: Optional[str] = None
    states: list[StudioNodeRunState] = []

    for node in pipeline.nodes:
        if not node.enabled:
            status = "skipped"
        elif activated:
            status = "pending"
        else:
            # Only the very first enabled node starts out active.
            activated = True
            first_active_id = node.id
            status = "active"

        node_state = StudioNodeRunState(
            nodeId=node.id,
            displayName=node.displayName,
            skillId=node.skillId,
            status=status,
            loopUntilSatisfied=node.loopUntilSatisfied,
        )
        states.append(node_state)

    return states, first_active_id
|
||
|
||
|
||
def _find_node(pipeline: PipelineDefinition, node_id: str) -> Optional[StudioNode]:
|
||
for node in pipeline.nodes:
|
||
if node.id == node_id:
|
||
return node
|
||
return None
|
||
|
||
|
||
def _next_enabled_node_id(pipeline: PipelineDefinition, after_node_id: str) -> Optional[str]:
|
||
seen = False
|
||
for node in pipeline.nodes:
|
||
if node.id == after_node_id:
|
||
seen = True
|
||
continue
|
||
if seen and node.enabled:
|
||
return node.id
|
||
return None
|
||
|
||
|
||
def _snapshot_node_state(state: StudioNodeRunState) -> TurnSnapshot:
    """Capture the draft, tool response and messages of *state* as a snapshot.

    The draft is deep-copied and the tool response and each message are
    copied with ``model_copy()``. A falsy draft or tool response is stored
    as ``None``. The snapshot is stamped with the current local time.
    """
    draft = state.lastDraft
    tool_response = state.lastToolResponse
    messages = state.stepMessages or []

    return TurnSnapshot(
        lastDraft=copy.deepcopy(draft) if draft else None,
        lastToolResponse=tool_response.model_copy() if tool_response else None,
        stepMessages=[message.model_copy() for message in messages],
        timestamp=datetime.now().isoformat(),
    )
|
||
|
||
|
||
def _apply_snapshot(
|
||
state: StudioNodeRunState, snapshot: TurnSnapshot
|
||
) -> StudioNodeRunState:
|
||
updated = state.model_copy()
|
||
updated.lastDraft = (
|
||
copy.deepcopy(snapshot.lastDraft) if snapshot.lastDraft else None
|
||
)
|
||
updated.lastToolResponse = (
|
||
snapshot.lastToolResponse.model_copy()
|
||
if snapshot.lastToolResponse
|
||
else None
|
||
)
|
||
updated.stepMessages = [
|
||
m.model_copy() for m in (snapshot.stepMessages or [])
|
||
]
|
||
return updated
|
||
|
||
|
||
def _find_last_user_message_index(step_messages: list) -> int:
|
||
for i in range(len(step_messages) - 1, -1, -1):
|
||
if step_messages[i].role == "user":
|
||
return i
|
||
return -1
|
||
|
||
|
||
class StudioRunService:
    """Service that creates, loads and mutates Studio pipeline runs.

    Each run is stored as ``run.json`` inside a per-run directory under
    ``runs_root / <project_id> / <run_id>``.
    """

    def __init__(self) -> None:
        """Create the character service used when binding a new character."""
        self._character_service = CharacterService()
|
||
|
||
@property
def runs_root(self) -> Path:
    """Directory configured in ``settings.AGENT_STUDIO_RUNS_PATH`` for run storage."""
    root = settings.AGENT_STUDIO_RUNS_PATH
    return root
|
||
|
||
def _project_runs_dir(self, project_id: str) -> Path:
|
||
return self.runs_root / project_id
|
||
|
||
def _run_dir(self, project_id: str, run_id: str) -> Path:
|
||
return self._project_runs_dir(project_id) / run_id
|
||
|
||
def _run_path(self, project_id: str, run_id: str) -> Path:
|
||
return self._run_dir(project_id, run_id) / "run.json"
|
||
|
||
def _save_run(self, project_id: str, run_id: str, run: StudioRun) -> None:
    """Serialize *run* in JSON mode and write it to the run's ``run.json``."""
    destination = self._run_path(project_id, run_id)
    serialized = run.model_dump(mode="json")
    _write_json(destination, serialized)
|
||
|
||
def create_run(self, project_id: str) -> StudioRun:
    """Create, persist and return a new run for *project_id*.

    The project's pipeline is deep-copied into the run so that the run keeps
    its own pipeline definition. The run starts as ``RUNNING`` when the
    pipeline has an enabled node and as ``PENDING`` otherwise.
    """
    project = studio_project_service.get_project(project_id)
    pipeline_data = copy.deepcopy(project.pipeline.model_dump())
    frozen_pipeline = PipelineDefinition(**pipeline_data)

    timestamp = datetime.now().isoformat()
    run_id = str(uuid.uuid4())
    node_states, current_node_id = _build_node_states(frozen_pipeline)

    if current_node_id is not None:
        initial_status = StudioRunStatus.RUNNING
    else:
        initial_status = StudioRunStatus.PENDING

    run = StudioRun(
        id=run_id,
        projectId=project_id,
        status=initial_status,
        pipelineSnapshot=frozen_pipeline,
        pipelineVersionNote=timestamp,
        currentNodeId=current_node_id,
        nodeStates=node_states,
        workflowVariables={"workflow.goal": frozen_pipeline.workflowGoal},
        createdAt=timestamp,
        updatedAt=timestamp,
    )
    self._save_run(project_id, run_id, run)
    return run
|
||
|
||
def list_runs(self, project_id: str) -> List[StudioRunSummary]:
    """Return summaries of the project's runs, newest directory first.

    Runs are ordered by the modification time of their run directory.
    Entries that are not directories or have no ``run.json`` are ignored.
    A run whose ``run.json`` cannot be read, parsed or validated is skipped
    with a warning so that one damaged run does not hide all the others.

    Returns:
        A possibly empty list; empty when the project has no runs directory.
    """
    root = self._project_runs_dir(project_id)
    if not root.exists():
        return []

    def _mtime(entry: Path) -> float:
        # An entry can disappear between listing and stat(); sort it last.
        try:
            return entry.stat().st_mtime
        except OSError:
            return 0.0

    summaries: List[StudioRunSummary] = []
    for child in sorted(root.iterdir(), key=_mtime, reverse=True):
        run_path = child / "run.json"
        if not child.is_dir() or not run_path.exists():
            continue
        try:
            raw = _read_json(run_path)
            if not isinstance(raw, dict):
                raise ValueError("run.json does not contain a JSON object")
            summary = StudioRunSummary(
                id=raw.get("id", child.name),
                projectId=raw.get("projectId", project_id),
                status=StudioRunStatus(raw.get("status", StudioRunStatus.PENDING.value)),
                currentNodeId=raw.get("currentNodeId"),
                title=raw.get("title", ""),
                createdAt=raw.get("createdAt", ""),
                updatedAt=raw.get("updatedAt", ""),
            )
        except (OSError, ValueError) as exc:
            # JSON decode errors and unknown status values are ValueErrors.
            logger.warning("Skipping unreadable Studio run %s: %s", run_path, exc)
            continue
        summaries.append(summary)
    return summaries
|
||
|
||
def get_run(self, project_id: str, run_id: str) -> StudioRun:
    """Load a run from its ``run.json``.

    Raises:
        FileNotFoundError: if the run file does not exist.
    """
    run_path = self._run_path(project_id, run_id)
    if run_path.exists():
        return StudioRun(**_read_json(run_path))
    raise FileNotFoundError(f"Studio run not found: {project_id}/{run_id}")
|
||
|
||
def _validate_display_params(
|
||
self, node: StudioNode, display_params: Dict[str, str]
|
||
) -> Dict[str, str]:
|
||
normalized: Dict[str, str] = {}
|
||
for dp in node.displayParams:
|
||
raw = display_params.get(dp.key, "")
|
||
value = (raw or "").strip()
|
||
if dp.required and not value:
|
||
raise ValueError(f"缺少必填项:{dp.label}")
|
||
normalized[dp.key] = value
|
||
return normalized
|
||
|
||
def _execute_init_bind(
    self,
    project_id: str,
    run: StudioRun,
    node: StudioNode,
    display_params: Dict[str, str],
) -> tuple[Dict[str, Any], Dict[str, Any]]:
    """Create a new worldbook and character and bind both to the project.

    The names come from the ``characterName`` / ``worldbookName`` display
    parameters. The worldbook is created first, then the character (linked
    through ``worldInfoId``), then the project bindings are updated.

    Returns:
        ``(workflow_variables, last_draft)``: the run's workflow variables
        extended with the bound character/worldbook descriptions, and the
        draft recording the created ids and names.

    Raises:
        ValueError: if a required display parameter is missing, either name
            is empty, or a character/worldbook with that name already exists.
    """
    params = self._validate_display_params(node, display_params)
    character_name = params.get("characterName", "")
    worldbook_name = params.get("worldbookName", "")

    # The node may not mark these params as required, so an empty name could
    # otherwise reach the create calls below.
    if not character_name:
        raise ValueError("角色名称不能为空")
    if not worldbook_name:
        raise ValueError("世界书名称不能为空")

    if self._character_service.get_character_by_name(character_name):
        raise ValueError(f"角色「{character_name}」已存在")
    if worldbook_service._get_worldbook_path(worldbook_name).exists():
        raise ValueError(f"世界书「{worldbook_name}」已存在")

    # NOTE(review): there is no rollback — if character creation fails the
    # worldbook created here is left behind.
    worldbook = worldbook_service.create_worldbook(worldbook_name)
    worldbook_id = worldbook["id"]

    character = self._character_service.create_character(
        {
            "name": character_name,
            "description": "",
            "personality": "",
            "scenario": "",
            "first_mes": "",
            "mes_example": "",
            "categories": [],
            "tags": [],
            "worldInfoId": worldbook_id,
        }
    )
    character_id = character.id

    studio_project_service.update_project_bindings(
        project_id, character_id, worldbook_id
    )

    workflow_variables = dict(run.workflowVariables or {})
    workflow_variables["workflow.goal"] = run.pipelineSnapshot.workflowGoal
    workflow_variables["workflow.boundCharacter"] = (
        f"名称:{character_name}\nID:{character_id}"
    )
    workflow_variables["workflow.boundWorldbook"] = (
        f"名称:{worldbook_name}\nID:{worldbook_id}"
    )

    last_draft = {
        "displayParams": params,
        "characterId": character_id,
        "worldbookId": worldbook_id,
        "characterName": character_name,
        "worldbookName": worldbook_name,
    }
    return workflow_variables, last_draft
|
||
|
||
@staticmethod
|
||
def _parse_keyword_field(raw: Any) -> List[str]:
|
||
if raw is None:
|
||
return []
|
||
if isinstance(raw, list):
|
||
return [str(x).strip() for x in raw if str(x).strip()]
|
||
text = str(raw).strip()
|
||
if not text:
|
||
return []
|
||
return [part.strip() for part in text.replace(",", ",").split(",") if part.strip()]
|
||
|
||
def _resolve_bound_worldbook_name(
|
||
self, project_id: str, run: StudioRun
|
||
) -> str:
|
||
for state in run.nodeStates:
|
||
if state.skillId != "studio.init_bind" or state.status != "completed":
|
||
continue
|
||
draft = state.lastDraft or {}
|
||
name = (draft.get("worldbookName") or "").strip()
|
||
if name:
|
||
return name
|
||
|
||
try:
|
||
project = studio_project_service.get_project(project_id)
|
||
worldbook_id = project.meta.worldbookId
|
||
except FileNotFoundError:
|
||
worldbook_id = None
|
||
|
||
if worldbook_id:
|
||
for summary in worldbook_service.list_worldbooks():
|
||
wb_name = summary.get("name")
|
||
if not wb_name:
|
||
continue
|
||
try:
|
||
data = worldbook_service.get_worldbook(wb_name)
|
||
except FileNotFoundError:
|
||
continue
|
||
if data.get("id") == worldbook_id:
|
||
return wb_name
|
||
|
||
raise ValueError("未找到绑定的世界书,请先完成「创建并绑定」步骤")
|
||
|
||
def _build_worldbook_entry_payload(
|
||
self, node: StudioNode, draft: Dict[str, Any]
|
||
) -> Dict[str, Any]:
|
||
insertion = (node.config or {}).get("insertion") or {}
|
||
content = (draft.get("entryContent") or "").strip()
|
||
if not content:
|
||
raise ValueError("当前步骤产物为空,请先生成世界书条目内容后再进入下一步")
|
||
|
||
activation = (insertion.get("activationType") or "permanent").strip()
|
||
position = insertion.get("position", 0)
|
||
if isinstance(position, str):
|
||
position = WorldBookConverter.POSITION_MAP_ST_TO_INTERNAL.get(position, 0)
|
||
|
||
key = self._parse_keyword_field(
|
||
draft.get("insertionKey") or insertion.get("key")
|
||
)
|
||
keysecondary = self._parse_keyword_field(insertion.get("keysecondary"))
|
||
comment = (
|
||
(draft.get("insertionComment") or insertion.get("comment") or "").strip()
|
||
or f"Studio · {node.displayName}"
|
||
)
|
||
|
||
payload: Dict[str, Any] = {
|
||
"content": content,
|
||
"comment": comment,
|
||
"activationType": activation,
|
||
"position": position,
|
||
"key": key,
|
||
"keysecondary": keysecondary,
|
||
"order": insertion.get("order", 100),
|
||
"depth": insertion.get("depth", 4),
|
||
"probability": insertion.get("probability", 100),
|
||
"group": insertion.get("group") or [],
|
||
"disable": bool(insertion.get("disable", False)),
|
||
}
|
||
if activation == "rag" and insertion.get("ragConfig"):
|
||
payload["ragConfig"] = insertion["ragConfig"]
|
||
return payload
|
||
|
||
def _write_worldbook_entry_on_advance(
    self,
    worldbook_name: str,
    node: StudioNode,
    draft: Dict[str, Any],
) -> Dict[str, Any]:
    """Append the entry built from *draft* to the worldbook and return it.

    Storage errors are translated into ``ValueError`` with the original
    exception chained as the cause.

    Raises:
        ValueError: if the draft is empty, the worldbook does not exist, or
            the write fails.
    """
    entry_payload = self._build_worldbook_entry_payload(node, draft)
    try:
        return worldbook_service.append_entry(worldbook_name, entry_payload)
    except FileNotFoundError as exc:
        # Must precede OSError: FileNotFoundError is one of its subclasses.
        raise ValueError(f"世界书「{worldbook_name}」不存在,无法写入条目") from exc
    except (ValueError, OSError) as exc:
        raise ValueError(f"写入世界书失败:{exc}") from exc
|
||
|
||
def _overwrite_worldbook_entry(
    self,
    worldbook_name: str,
    node: StudioNode,
    draft: Dict[str, Any],
    entry_uid: str,
) -> Dict[str, Any]:
    """Replace the worldbook entry *entry_uid* with the entry built from *draft*.

    Raises:
        ValueError: if the draft is empty, the entry/worldbook is missing, or
            the update fails (the original error is chained as the cause).
    """
    replacement = self._build_worldbook_entry_payload(node, draft)
    try:
        updated_entry = worldbook_service.update_entry(
            worldbook_name, entry_uid, replacement
        )
    except FileNotFoundError as exc:
        raise ValueError(f"世界书条目不存在或无法更新:{exc}") from exc
    except (ValueError, OSError) as exc:
        raise ValueError(f"覆盖世界书失败:{exc}") from exc
    return updated_entry
|
||
|
||
@staticmethod
|
||
def _node_has_written_entry(state: StudioNodeRunState) -> bool:
|
||
draft = state.lastDraft or {}
|
||
return bool(draft.get("writtenEntryUid")) or state.status == "completed"
|
||
|
||
def _save_worldbook_entry_only(
    self,
    project_id: str,
    run_id: str,
    run: StudioRun,
    node: StudioNode,
    state: StudioNodeRunState,
    node_id: str,
    save_mode: str,
) -> StudioRun:
    """Write the node's current draft to the worldbook without advancing.

    ``save_mode`` ``"append"`` adds a new entry. When the draft was already
    written once, the new entry is ordered after the worldbook's current
    highest ``order``. ``"overwrite"`` replaces the previously written
    entry. The written entry uid and worldbook name are recorded on the
    node's draft and the run is persisted.

    Raises:
        ValueError: if there is no draft, no bound worldbook, the mode is
            unsupported, nothing was written yet in overwrite mode, or the
            worldbook read/write fails.
    """
    if not state.lastDraft:
        raise ValueError("请先生成并确认当前产物后再保存")

    worldbook_name = self._resolve_bound_worldbook_name(project_id, run)
    draft = state.lastDraft

    if save_mode == "append":
        # Built outside the try below so an empty-draft error keeps its message.
        entry_payload = self._build_worldbook_entry_payload(node, draft)
        if (draft.get("writtenEntryUid") or "").strip():
            # Storage errors are translated exactly like the first-write path.
            try:
                try:
                    wb_data = worldbook_service.get_worldbook(worldbook_name)
                    entries = wb_data.get("entries") or []
                    next_order = (
                        max((int(e.get("order", 0)) for e in entries), default=0) + 1
                    )
                except (TypeError, ValueError):
                    # Unusable order values: fall back to the configured order.
                    next_order = int(entry_payload.get("order", 100)) + 1
                entry_payload["order"] = next_order
                written_entry = worldbook_service.append_entry(
                    worldbook_name, entry_payload
                )
            except FileNotFoundError as exc:
                raise ValueError(
                    f"世界书「{worldbook_name}」不存在,无法写入条目"
                ) from exc
            except (TypeError, ValueError, OSError) as exc:
                raise ValueError(f"写入世界书失败:{exc}") from exc
        else:
            written_entry = self._write_worldbook_entry_on_advance(
                worldbook_name, node, draft
            )
    elif save_mode == "overwrite":
        entry_uid = (draft.get("writtenEntryUid") or "").strip()
        if not entry_uid:
            raise ValueError(
                "尚未写入过世界书条目,请使用「下一步」完成首次导出,或使用「增量保存」"
            )
        written_entry = self._overwrite_worldbook_entry(
            worldbook_name, node, draft, entry_uid
        )
    else:
        raise ValueError(f"不支持的保存模式:{save_mode}")

    last_draft = copy.deepcopy(draft)
    last_draft["writtenEntryUid"] = written_entry.get("uid")
    last_draft["writtenWorldbookName"] = worldbook_name

    now = datetime.now().isoformat()
    new_node_states: list[StudioNodeRunState] = []
    for ns in run.nodeStates:
        updated = ns.model_copy()
        if ns.nodeId == node_id:
            updated.lastDraft = last_draft
        new_node_states.append(updated)

    updated_run = run.model_copy(
        update={
            "nodeStates": new_node_states,
            "updatedAt": now,
        }
    )
    self._save_run(project_id, run_id, updated_run)
    return updated_run
|
||
|
||
def advance_run(
    self,
    project_id: str,
    run_id: str,
    display_params: Optional[Dict[str, str]] = None,
    save_mode: str = "advance",
) -> StudioRun:
    """Execute the run's current node and activate the next enabled node.

    With ``save_mode`` ``"append"`` or ``"overwrite"`` the pipeline is not
    advanced; the current worldbook-entry draft is saved instead (this is
    also allowed on a completed run). Otherwise the current node's skill is
    executed, the node is marked completed and the next enabled node becomes
    active; the run is completed when no enabled node follows.

    Raises:
        ValueError: if the run or node state does not allow the operation or
            required input is missing.
        NotImplementedError: for skills other than ``studio.init_bind`` and
            ``studio.worldbook_entry``.
    """
    run = self.get_run(project_id, run_id)
    save_only = save_mode in ("append", "overwrite")

    if save_only:
        savable_statuses = (
            StudioRunStatus.RUNNING,
            StudioRunStatus.PENDING,
            StudioRunStatus.COMPLETED,
        )
        if run.status not in savable_statuses:
            raise ValueError("运行已结束,无法保存")
    elif run.status not in (StudioRunStatus.RUNNING, StudioRunStatus.PENDING):
        raise ValueError("运行已结束,无法继续推进")

    current_node_id = run.currentNodeId
    if not current_node_id:
        raise ValueError("当前运行无活动节点")

    current_node = _find_node(run.pipelineSnapshot, current_node_id)
    if not current_node:
        raise ValueError(f"节点不存在:{current_node_id}")

    matching_states = [s for s in run.nodeStates if s.nodeId == current_node_id]
    current_state = matching_states[0] if matching_states else None
    if not current_state:
        raise ValueError("当前节点状态不存在")

    if save_only:
        if current_node.skillId != "studio.worldbook_entry":
            raise ValueError("当前步骤不支持增量/覆盖保存")
        if current_state.status not in ("active", "completed"):
            raise ValueError("当前节点不可保存")
        return self._save_worldbook_entry_only(
            project_id,
            run_id,
            run,
            current_node,
            current_state,
            current_node_id,
            save_mode,
        )

    if current_state.status != "active":
        raise ValueError("当前节点不可执行")

    workflow_variables = dict(run.workflowVariables or {})
    last_draft: Optional[Dict[str, Any]] = None
    skill_id = current_node.skillId

    if skill_id == "studio.init_bind":
        if not display_params:
            raise ValueError("请填写引导表单后再提交")
        workflow_variables, last_draft = self._execute_init_bind(
            project_id, run, current_node, display_params
        )
    elif skill_id == "studio.worldbook_entry":
        if not current_state.lastDraft:
            raise ValueError("请先生成并确认当前产物后再进入下一步")
        worldbook_name = self._resolve_bound_worldbook_name(project_id, run)
        written_entry = self._write_worldbook_entry_on_advance(
            worldbook_name,
            current_node,
            current_state.lastDraft,
        )
        # Record where the draft was written so later saves can overwrite it.
        last_draft = copy.deepcopy(current_state.lastDraft)
        last_draft["writtenEntryUid"] = written_entry.get("uid")
        last_draft["writtenWorldbookName"] = worldbook_name
    else:
        raise NotImplementedError(
            f"技能「{current_node.skillId}」的执行尚未实现(R2+)"
        )

    next_node_id = _next_enabled_node_id(run.pipelineSnapshot, current_node_id)
    now = datetime.now().isoformat()

    new_node_states: list[StudioNodeRunState] = []
    for state in run.nodeStates:
        updated = state.model_copy()
        if state.nodeId == current_node_id:
            updated.status = "completed"
            if last_draft is not None:
                updated.lastDraft = last_draft
        elif next_node_id and state.nodeId == next_node_id:
            updated.status = "active"
        new_node_states.append(updated)

    if next_node_id:
        new_status = StudioRunStatus.RUNNING
    else:
        new_status = StudioRunStatus.COMPLETED

    updated_run = run.model_copy(
        update={
            "currentNodeId": next_node_id,
            "nodeStates": new_node_states,
            "workflowVariables": workflow_variables,
            "status": new_status,
            "updatedAt": now,
        }
    )
    updated_run = store_context_on_run(updated_run, next_node_id)
    self._save_run(project_id, run_id, updated_run)
    return updated_run
|
||
|
||
def save_run(
    self, project_id: str, run_id: str, mode: str
) -> StudioRun:
    """Save the current worldbook entry without advancing the pipeline.

    ``mode`` is ``"incremental"`` (append a new entry) or ``"overwrite"``
    (replace the previously written entry).

    Raises:
        ValueError: for any other *mode*.
    """
    internal_mode = {"incremental": "append", "overwrite": "overwrite"}.get(mode)
    if internal_mode is None:
        raise ValueError(f"不支持的保存模式:{mode}")
    return self.advance_run(project_id, run_id, save_mode=internal_mode)
|
||
|
||
def switch_run_node(
    self, project_id: str, run_id: str, node_id: str
) -> StudioRun:
    """Make *node_id* the run's current node (e.g. to revisit a finished step).

    Only enabled ``studio.worldbook_entry`` nodes whose state is ``active``
    or ``completed`` can be targeted. The previously current worldbook node
    becomes ``completed`` if it has written an entry (or was completed) and
    ``pending`` otherwise. A completed run stays completed only when no
    enabled node follows the target.

    Raises:
        ValueError: if the run or the target node does not allow switching.
    """
    run = self.get_run(project_id, run_id)
    switchable_statuses = (
        StudioRunStatus.RUNNING,
        StudioRunStatus.PENDING,
        StudioRunStatus.COMPLETED,
    )
    if run.status not in switchable_statuses:
        raise ValueError("运行已结束,无法切换节点")

    target_node = _find_node(run.pipelineSnapshot, node_id)
    if not target_node:
        raise ValueError(f"节点不存在:{node_id}")
    if not target_node.enabled:
        raise ValueError("该节点已禁用,无法切换")
    if target_node.skillId != "studio.worldbook_entry":
        raise ValueError("仅支持切换到世界书创作步骤")

    matching_states = [s for s in run.nodeStates if s.nodeId == node_id]
    target_state = matching_states[0] if matching_states else None
    if not target_state:
        raise ValueError("节点状态不存在")
    if target_state.status not in ("active", "completed"):
        raise ValueError("仅可切换到进行中或已完成的步骤")

    prev_node_id = run.currentNodeId
    now = datetime.now().isoformat()

    new_node_states: list[StudioNodeRunState] = []
    for state in run.nodeStates:
        updated = state.model_copy()
        is_previous = bool(
            prev_node_id and state.nodeId == prev_node_id and prev_node_id != node_id
        )
        if state.nodeId == node_id:
            updated.status = "active"
        elif is_previous and state.skillId == "studio.worldbook_entry":
            # Park the node we are leaving in the state matching its progress.
            updated.status = (
                "completed" if self._node_has_written_entry(state) else "pending"
            )
        new_node_states.append(updated)

    stays_completed = (
        run.status == StudioRunStatus.COMPLETED
        and not _next_enabled_node_id(run.pipelineSnapshot, node_id)
    )
    new_status = (
        StudioRunStatus.COMPLETED if stays_completed else StudioRunStatus.RUNNING
    )

    updated_run = run.model_copy(
        update={
            "currentNodeId": node_id,
            "nodeStates": new_node_states,
            "status": new_status,
            "updatedAt": now,
        }
    )
    updated_run = store_context_on_run(updated_run, node_id)
    self._save_run(project_id, run_id, updated_run)
    return updated_run
|
||
|
||
async def send_run_message(
    self,
    project_id: str,
    run_id: str,
    content: str,
    *,
    stream: bool = False,
    profile_id: Optional[str] = None,
    api_config: Optional[Dict[str, str]] = None,
) -> StudioRun:
    """Send a user message to the active worldbook step and store the reply.

    The stripped message is passed to ``studio_step_respond`` together with
    the step's prompt blocks, message history and current draft; the
    resulting turn is then persisted on the run.

    Raises:
        ValueError: if the message is blank or the run has no active
            worldbook step.
    """
    message = (content or "").strip()
    if not message:
        raise ValueError("消息内容不能为空")

    prepared = self._prepare_run_message(project_id, run_id, message)
    run, current_node, current_state, current_node_id = prepared

    resolved_api = resolve_api_config(profile_id, api_config)
    prompt_blocks = assemble_prompt_blocks(run, current_node_id)
    history = list(current_state.stepMessages or [])

    turn = await studio_step_respond(
        node=current_node,
        prompt_blocks=prompt_blocks,
        step_messages=history,
        user_message=message,
        existing_draft=current_state.lastDraft,
        api_config=resolved_api,
        stream=stream,
    )
    last_draft, last_tool_response, user_msg, assistant_msg = turn

    return self._persist_run_message_turn(
        project_id,
        run_id,
        run,
        current_node_id,
        prompt_blocks,
        history,
        last_draft,
        last_tool_response,
        user_msg,
        assistant_msg,
    )
|
||
|
||
async def send_run_message_stream(
    self,
    project_id: str,
    run_id: str,
    content: str,
    *,
    profile_id: Optional[str] = None,
    api_config: Optional[Dict[str, str]] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
    """Stream a message turn: thinking deltas first, then the persisted run.

    ``thinking_delta`` events are forwarded unchanged. On the ``complete``
    event the turn is persisted and a ``{"type": "complete", "run": ...}``
    event carrying the JSON dump of the updated run is yielded. Events of
    any other type are not forwarded.

    Raises:
        ValueError: if the message is blank or the run has no active
            worldbook step.
    """
    message = (content or "").strip()
    if not message:
        raise ValueError("消息内容不能为空")

    prepared = self._prepare_run_message(project_id, run_id, message)
    run, current_node, current_state, current_node_id = prepared

    resolved_api = resolve_api_config(profile_id, api_config)
    prompt_blocks = assemble_prompt_blocks(run, current_node_id)
    history = list(current_state.stepMessages or [])

    event_stream = studio_step_respond_stream(
        node=current_node,
        prompt_blocks=prompt_blocks,
        step_messages=history,
        user_message=message,
        existing_draft=current_state.lastDraft,
        api_config=resolved_api,
    )
    async for event in event_stream:
        event_type = event.get("type")
        if event_type == "thinking_delta":
            yield event
        elif event_type == "complete":
            from models.studio_models import LastToolResponse, StepMessage

            # The stream delivers plain dicts; rebuild the models for storage.
            updated_run = self._persist_run_message_turn(
                project_id,
                run_id,
                run,
                current_node_id,
                prompt_blocks,
                history,
                event["last_draft"],
                LastToolResponse(**event["last_tool_response"]),
                StepMessage(**event["user_msg"]),
                StepMessage(**event["assistant_msg"]),
            )
            yield {
                "type": "complete",
                "run": updated_run.model_dump(mode="json"),
            }
|
||
|
||
def _prepare_active_worldbook_step(
    self,
    project_id: str,
    run_id: str,
) -> tuple[StudioRun, StudioNode, StudioNodeRunState, str]:
    """Load the run and check that its current node is an active worldbook step.

    Returns:
        ``(run, current_node, current_state, current_node_id)``.

    Raises:
        ValueError: if the run is not running, has no current node, the node
            is missing or not a ``studio.worldbook_entry`` step, or its state
            is not ``active``.
    """
    run = self.get_run(project_id, run_id)
    if run.status != StudioRunStatus.RUNNING:
        raise ValueError("运行未处于进行中,无法操作")

    current_node_id = run.currentNodeId
    if not current_node_id:
        raise ValueError("当前运行无活动节点")

    current_node = _find_node(run.pipelineSnapshot, current_node_id)
    if not current_node:
        raise ValueError(f"节点不存在:{current_node_id}")
    if current_node.skillId != "studio.worldbook_entry":
        raise ValueError(
            f"当前步骤「{current_node.displayName}」不支持对话消息"
        )

    matching_states = [s for s in run.nodeStates if s.nodeId == current_node_id]
    current_state = matching_states[0] if matching_states else None
    if not current_state or current_state.status != "active":
        raise ValueError("当前节点不可执行")

    return run, current_node, current_state, current_node_id
|
||
|
||
def _prepare_run_message(
|
||
self,
|
||
project_id: str,
|
||
run_id: str,
|
||
trimmed: str,
|
||
) -> tuple[StudioRun, StudioNode, StudioNodeRunState, str]:
|
||
if not trimmed:
|
||
raise ValueError("消息内容不能为空")
|
||
return self._prepare_active_worldbook_step(project_id, run_id)
|
||
|
||
def _persist_run_message_turn(
    self,
    project_id: str,
    run_id: str,
    run: StudioRun,
    current_node_id: str,
    prompt_blocks: list,
    step_messages: list,
    last_draft: Dict[str, Any],
    last_tool_response,
    user_msg,
    assistant_msg,
    *,
    push_snapshot: bool = True,
) -> StudioRun:
    """Store a finished message turn on the current node and save the run.

    The user and assistant messages are appended to a copy of
    *step_messages*. With ``push_snapshot`` the node's state from before the
    turn is pushed onto its ``turnHistory``.
    """
    transcript = [*step_messages, user_msg, assistant_msg]
    now = datetime.now().isoformat()

    new_node_states: list[StudioNodeRunState] = []
    for state in run.nodeStates:
        updated = state.model_copy()
        if state.nodeId == current_node_id:
            if push_snapshot:
                # Snapshot the pre-turn state (``state``, not ``updated``).
                updated.turnHistory = [
                    *(state.turnHistory or []),
                    _snapshot_node_state(state),
                ]
            updated.lastDraft = last_draft
            updated.lastToolResponse = last_tool_response
            updated.stepMessages = transcript
        new_node_states.append(updated)

    changes = {
        "nodeStates": new_node_states,
        "lastPromptBlocks": prompt_blocks,
        "updatedAt": now,
    }
    updated_run = store_context_on_run(run.model_copy(update=changes), current_node_id)
    self._save_run(project_id, run_id, updated_run)
    return updated_run
|
||
|
||
def undo_run(self, project_id: str, run_id: str) -> StudioRun:
    """Roll the active step back to the most recent ``turnHistory`` snapshot.

    The snapshot is removed from the history and applied to the node state.

    Raises:
        ValueError: if the step check fails or there is no snapshot to restore.
    """
    run, _, current_state, current_node_id = self._prepare_active_worldbook_step(
        project_id, run_id
    )

    history = list(current_state.turnHistory or [])
    if not history:
        raise ValueError("无可回退的回合")

    latest_snapshot = history.pop()
    restored_state = _apply_snapshot(current_state, latest_snapshot)
    restored_state.turnHistory = history

    now = datetime.now().isoformat()
    new_node_states: list[StudioNodeRunState] = [
        restored_state if state.nodeId == current_node_id else state.model_copy()
        for state in run.nodeStates
    ]

    updated_run = run.model_copy(
        update={
            "nodeStates": new_node_states,
            "updatedAt": now,
        }
    )
    updated_run = store_context_on_run(updated_run, current_node_id)
    self._save_run(project_id, run_id, updated_run)
    return updated_run
|
||
|
||
async def reroll_run(
    self,
    project_id: str,
    run_id: str,
    *,
    stream: bool = False,
    profile_id: Optional[str] = None,
    api_config: Optional[Dict[str, str]] = None,
) -> StudioRun:
    """Regenerate the reply to the last user message of the active step.

    The last user message is re-sent with the messages that preceded it.
    The draft passed along is the one from the latest ``turnHistory``
    snapshot when history exists, otherwise the current draft.

    Raises:
        ValueError: if the step check fails or the step has no user message.
    """
    run, current_node, current_state, current_node_id = (
        self._prepare_active_worldbook_step(project_id, run_id)
    )

    all_messages = list(current_state.stepMessages or [])
    last_user_idx = _find_last_user_message_index(all_messages)
    if last_user_idx < 0:
        raise ValueError("当前步骤尚无用户消息,无法重 roll")

    user_content = all_messages[last_user_idx].content
    prior_messages = all_messages[:last_user_idx]

    if current_state.turnHistory:
        pre_turn_draft = current_state.turnHistory[-1].lastDraft
    else:
        pre_turn_draft = current_state.lastDraft

    resolved_api = resolve_api_config(profile_id, api_config)
    prompt_blocks = assemble_prompt_blocks(run, current_node_id)

    turn = await studio_step_respond(
        node=current_node,
        prompt_blocks=prompt_blocks,
        step_messages=prior_messages,
        user_message=user_content,
        existing_draft=pre_turn_draft,
        api_config=resolved_api,
        stream=stream,
    )
    last_draft, last_tool_response, user_msg, assistant_msg = turn

    return self._persist_reroll_turn(
        project_id,
        run_id,
        run,
        current_node_id,
        current_state,
        prompt_blocks,
        prior_messages,
        last_draft,
        last_tool_response,
        user_msg,
        assistant_msg,
    )
|
||
|
||
async def reroll_run_stream(
    self,
    project_id: str,
    run_id: str,
    *,
    profile_id: Optional[str] = None,
    api_config: Optional[Dict[str, str]] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
    """Streaming reroll: forward thinking deltas, then persist the new turn.

    ``thinking_delta`` events are yielded unchanged. On ``complete`` the
    replaced turn is persisted and a ``{"type": "complete", "run": ...}``
    event with the JSON dump of the updated run is yielded. Events of any
    other type are not forwarded.

    Raises:
        ValueError: if the step check fails or the step has no user message.
    """
    run, current_node, current_state, current_node_id = (
        self._prepare_active_worldbook_step(project_id, run_id)
    )

    all_messages = list(current_state.stepMessages or [])
    last_user_idx = _find_last_user_message_index(all_messages)
    if last_user_idx < 0:
        raise ValueError("当前步骤尚无用户消息,无法重 roll")

    user_content = all_messages[last_user_idx].content
    prior_messages = all_messages[:last_user_idx]

    if current_state.turnHistory:
        pre_turn_draft = current_state.turnHistory[-1].lastDraft
    else:
        pre_turn_draft = current_state.lastDraft

    resolved_api = resolve_api_config(profile_id, api_config)
    prompt_blocks = assemble_prompt_blocks(run, current_node_id)

    event_stream = studio_step_respond_stream(
        node=current_node,
        prompt_blocks=prompt_blocks,
        step_messages=prior_messages,
        user_message=user_content,
        existing_draft=pre_turn_draft,
        api_config=resolved_api,
    )
    async for event in event_stream:
        event_type = event.get("type")
        if event_type == "thinking_delta":
            yield event
        elif event_type == "complete":
            from models.studio_models import LastToolResponse, StepMessage

            # The stream delivers plain dicts; rebuild the models for storage.
            updated_run = self._persist_reroll_turn(
                project_id,
                run_id,
                run,
                current_node_id,
                current_state,
                prompt_blocks,
                prior_messages,
                event["last_draft"],
                LastToolResponse(**event["last_tool_response"]),
                StepMessage(**event["user_msg"]),
                StepMessage(**event["assistant_msg"]),
            )
            yield {
                "type": "complete",
                "run": updated_run.model_dump(mode="json"),
            }
|
||
|
||
def _persist_reroll_turn(
    self,
    project_id: str,
    run_id: str,
    run: StudioRun,
    current_node_id: str,
    current_state: StudioNodeRunState,
    prompt_blocks: list,
    step_messages: list,
    last_draft: Dict[str, Any],
    last_tool_response,
    user_msg,
    assistant_msg,
) -> StudioRun:
    """Store a rerolled turn in place of the last one and save the run.

    *step_messages* holds the messages before the rerolled user message; the
    new user/assistant pair is appended to a copy of it. A snapshot of
    *current_state* is pushed onto the turn history so the reroll can be
    undone.
    """
    transcript = [*step_messages, user_msg, assistant_msg]
    now = datetime.now().isoformat()

    new_node_states: list[StudioNodeRunState] = []
    for state in run.nodeStates:
        updated = state.model_copy()
        if state.nodeId == current_node_id:
            # History and snapshot both come from the caller's current_state.
            updated.turnHistory = [
                *(current_state.turnHistory or []),
                _snapshot_node_state(current_state),
            ]
            updated.lastDraft = last_draft
            updated.lastToolResponse = last_tool_response
            updated.stepMessages = transcript
        new_node_states.append(updated)

    changes = {
        "nodeStates": new_node_states,
        "lastPromptBlocks": prompt_blocks,
        "updatedAt": now,
    }
    updated_run = store_context_on_run(run.model_copy(update=changes), current_node_id)
    self._save_run(project_id, run_id, updated_run)
    return updated_run
|
||
|
||
def delete_run(self, project_id: str, run_id: str) -> None:
    """Delete the directory of run *run_id* and everything inside it.

    The target must be a direct child of the project's runs directory. An
    empty id, ``"."`` or a traversal id such as ``".."`` would otherwise
    resolve to the runs directory itself (or above it) and the recursive
    delete would remove every run.

    Raises:
        ValueError: if *run_id* does not name a direct child directory.
        FileNotFoundError: if the run directory does not exist.
    """
    run_dir = self._run_dir(project_id, run_id)
    runs_dir = self._project_runs_dir(project_id)
    # Compare resolved paths so ".." segments and symlinks cannot escape.
    if not run_id or run_dir.resolve().parent != runs_dir.resolve():
        raise ValueError(f"Invalid Studio run id: {run_id!r}")
    if not run_dir.exists():
        raise FileNotFoundError(f"Studio run not found: {project_id}/{run_id}")
    shutil.rmtree(run_dir)
|
||
|
||
def rename_run(self, project_id: str, run_id: str, title: str) -> StudioRun:
    """Set the run's title to the stripped *title* and save the run.

    The run is loaded before the title is checked, so a missing run is
    reported before an empty title.

    Raises:
        FileNotFoundError: if the run does not exist.
        ValueError: if *title* is blank.
    """
    run = self.get_run(project_id, run_id)
    new_title = (title or "").strip()
    if not new_title:
        raise ValueError("运行名称不能为空")

    changes = {"title": new_title, "updatedAt": datetime.now().isoformat()}
    renamed = run.model_copy(update=changes)
    self._save_run(project_id, run_id, renamed)
    return renamed
|
||
|
||
|
||
# Module-level service instance created at import time.
studio_run_service = StudioRunService()
|