Files
SillyTavern_replica/backend/services/studio_run_service.py

1055 lines
38 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Create and load Studio pipeline runs with frozen pipeline snapshots.
"""
from __future__ import annotations
import copy
import json
import logging
import shutil
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional
from core.config import settings
from models.studio_models import (
PipelineDefinition,
StudioNode,
StudioNodeRunState,
StudioRun,
StudioRunStatus,
StudioRunSummary,
TurnSnapshot,
)
from services.character_service import CharacterService
from services.studio_context_service import assemble_prompt_blocks, store_context_on_run
from services.studio_project_service import studio_project_service
from services.studio_step_respond import (
resolve_api_config,
studio_step_respond,
studio_step_respond_stream,
)
from models.converters import WorldBookConverter
from services.worldbook_service import worldbook_service
logger = logging.getLogger(__name__)
def _read_json(path: Path) -> dict:
with path.open("r", encoding="utf-8") as f:
return json.load(f)
def _write_json(path: Path, data: dict) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
def _build_node_states(pipeline: PipelineDefinition) -> tuple[list[StudioNodeRunState], Optional[str]]:
    """Build the initial run state for every node of ``pipeline``.

    Disabled nodes are ``skipped``, the first enabled node is ``active`` and
    every later enabled node is ``pending``.

    Returns:
        ``(states, first_active_id)``; ``first_active_id`` is ``None`` when
        no node is enabled.
    """
    first_active_id: Optional[str] = None
    activated = False
    states: list[StudioNodeRunState] = []

    for node in pipeline.nodes:
        if node.enabled and not activated:
            status = "active"
            first_active_id = node.id
            activated = True
        elif node.enabled:
            status = "pending"
        else:
            status = "skipped"

        node_state = StudioNodeRunState(
            nodeId=node.id,
            displayName=node.displayName,
            skillId=node.skillId,
            status=status,
            loopUntilSatisfied=node.loopUntilSatisfied,
        )
        states.append(node_state)

    return states, first_active_id
def _find_node(pipeline: PipelineDefinition, node_id: str) -> Optional[StudioNode]:
for node in pipeline.nodes:
if node.id == node_id:
return node
return None
def _next_enabled_node_id(pipeline: PipelineDefinition, after_node_id: str) -> Optional[str]:
seen = False
for node in pipeline.nodes:
if node.id == after_node_id:
seen = True
continue
if seen and node.enabled:
return node.id
return None
def _snapshot_node_state(state: StudioNodeRunState) -> TurnSnapshot:
    """Capture ``state``'s draft, tool response and messages in a ``TurnSnapshot``.

    The draft is deep-copied and each model is copied with ``model_copy()``.
    A falsy draft or tool response is stored as ``None``. The snapshot is
    stamped with the current local time.
    """
    draft_copy = copy.deepcopy(state.lastDraft) if state.lastDraft else None

    tool_response = state.lastToolResponse
    tool_response_copy = tool_response.model_copy() if tool_response else None

    message_copies = [message.model_copy() for message in (state.stepMessages or [])]

    return TurnSnapshot(
        lastDraft=draft_copy,
        lastToolResponse=tool_response_copy,
        stepMessages=message_copies,
        timestamp=datetime.now().isoformat(),
    )
def _apply_snapshot(
state: StudioNodeRunState, snapshot: TurnSnapshot
) -> StudioNodeRunState:
updated = state.model_copy()
updated.lastDraft = (
copy.deepcopy(snapshot.lastDraft) if snapshot.lastDraft else None
)
updated.lastToolResponse = (
snapshot.lastToolResponse.model_copy()
if snapshot.lastToolResponse
else None
)
updated.stepMessages = [
m.model_copy() for m in (snapshot.stepMessages or [])
]
return updated
def _find_last_user_message_index(step_messages: list) -> int:
for i in range(len(step_messages) - 1, -1, -1):
if step_messages[i].role == "user":
return i
return -1
class StudioRunService:
    """Create, load, persist and drive Studio pipeline runs stored as ``run.json`` files."""

    def __init__(self) -> None:
        """Create the character service used for character lookup and creation."""
        self._character_service = CharacterService()
@property
def runs_root(self) -> Path:
    """Root path under which run directories live (``settings.AGENT_STUDIO_RUNS_PATH``)."""
    configured_root = settings.AGENT_STUDIO_RUNS_PATH
    return configured_root
def _project_runs_dir(self, project_id: str) -> Path:
return self.runs_root / project_id
def _run_dir(self, project_id: str, run_id: str) -> Path:
return self._project_runs_dir(project_id) / run_id
def _run_path(self, project_id: str, run_id: str) -> Path:
return self._run_dir(project_id, run_id) / "run.json"
def _save_run(self, project_id: str, run_id: str, run: StudioRun) -> None:
    """Serialize ``run`` in JSON mode and write it to the run's ``run.json``."""
    destination = self._run_path(project_id, run_id)
    serialized = run.model_dump(mode="json")
    _write_json(destination, serialized)
def create_run(self, project_id: str) -> StudioRun:
    """Create, persist and return a new run for ``project_id``.

    The project's pipeline is deep-copied into the run as a snapshot. The
    run starts ``RUNNING`` when the pipeline has an enabled node, otherwise
    ``PENDING``.
    """
    project = studio_project_service.get_project(project_id)
    pipeline_data = copy.deepcopy(project.pipeline.model_dump())
    frozen_pipeline = PipelineDefinition(**pipeline_data)

    timestamp = datetime.now().isoformat()
    new_run_id = str(uuid.uuid4())

    node_states, active_node_id = _build_node_states(frozen_pipeline)
    if active_node_id is not None:
        initial_status = StudioRunStatus.RUNNING
    else:
        initial_status = StudioRunStatus.PENDING

    run = StudioRun(
        id=new_run_id,
        projectId=project_id,
        status=initial_status,
        pipelineSnapshot=frozen_pipeline,
        pipelineVersionNote=timestamp,
        currentNodeId=active_node_id,
        nodeStates=node_states,
        workflowVariables={"workflow.goal": frozen_pipeline.workflowGoal},
        createdAt=timestamp,
        updatedAt=timestamp,
    )
    self._save_run(project_id, new_run_id, run)
    return run
def list_runs(self, project_id: str) -> List[StudioRunSummary]:
    """Return summaries of the project's runs, most recently modified directory first.

    Entries that are not directories, or that have no ``run.json``, are
    ignored. Fields missing from a ``run.json`` fall back to defaults.
    """
    runs_dir = self._project_runs_dir(project_id)
    if not runs_dir.exists():
        return []

    def _modified_at(entry: Path) -> float:
        return entry.stat().st_mtime

    newest_first = sorted(runs_dir.iterdir(), key=_modified_at, reverse=True)

    summaries: List[StudioRunSummary] = []
    for entry in newest_first:
        manifest = entry / "run.json"
        if not (entry.is_dir() and manifest.exists()):
            continue
        raw = _read_json(manifest)
        summary = StudioRunSummary(
            id=raw.get("id", entry.name),
            projectId=raw.get("projectId", project_id),
            status=StudioRunStatus(raw.get("status", StudioRunStatus.PENDING.value)),
            currentNodeId=raw.get("currentNodeId"),
            title=raw.get("title", ""),
            createdAt=raw.get("createdAt", ""),
            updatedAt=raw.get("updatedAt", ""),
        )
        summaries.append(summary)
    return summaries
def get_run(self, project_id: str, run_id: str) -> StudioRun:
    """Load the run stored in the run's ``run.json``.

    Raises:
        FileNotFoundError: If the run file does not exist.
    """
    manifest = self._run_path(project_id, run_id)
    if manifest.exists():
        return StudioRun(**_read_json(manifest))
    raise FileNotFoundError(f"Studio run not found: {project_id}/{run_id}")
def _validate_display_params(
self, node: StudioNode, display_params: Dict[str, str]
) -> Dict[str, str]:
normalized: Dict[str, str] = {}
for dp in node.displayParams:
raw = display_params.get(dp.key, "")
value = (raw or "").strip()
if dp.required and not value:
raise ValueError(f"缺少必填项:{dp.label}")
normalized[dp.key] = value
return normalized
def _execute_init_bind(
    self,
    project_id: str,
    run: StudioRun,
    node: StudioNode,
    display_params: Dict[str, str],
) -> tuple[Dict[str, Any], Dict[str, Any]]:
    """Create a new worldbook and character and bind both to the project.

    Args:
        project_id: Project whose bindings are updated.
        run: Run supplying the existing workflow variables and pipeline snapshot.
        node: The init-bind node; its ``displayParams`` drive validation.
        display_params: Raw form values submitted by the user.

    Returns:
        ``(workflow_variables, last_draft)``: the run's workflow variables
        extended with the goal and the bound character/worldbook, and the
        draft recording the validated parameters plus the created ids/names.

    Raises:
        ValueError: If a required parameter is missing, or a character or
            worldbook with the requested name already exists.
    """
    params = self._validate_display_params(node, display_params)
    # NOTE(review): both names fall back to "" when the node does not declare
    # these display params - confirm the pipeline always declares them.
    character_name = params.get("characterName", "")
    worldbook_name = params.get("worldbookName", "")
    # Both names must be unused before anything is created.
    if self._character_service.get_character_by_name(character_name):
        raise ValueError(f"角色「{character_name}」已存在")
    if worldbook_service._get_worldbook_path(worldbook_name).exists():
        raise ValueError(f"世界书「{worldbook_name}」已存在")
    # The worldbook is created first so its id can be stored on the character.
    # NOTE(review): nothing here removes the worldbook if character creation
    # fails afterwards - confirm whether that is acceptable.
    worldbook = worldbook_service.create_worldbook(worldbook_name)
    worldbook_id = worldbook["id"]
    character = self._character_service.create_character(
        {
            "name": character_name,
            "description": "",
            "personality": "",
            "scenario": "",
            "first_mes": "",
            "mes_example": "",
            "categories": [],
            "tags": [],
            "worldInfoId": worldbook_id,
        }
    )
    character_id = character.id
    studio_project_service.update_project_bindings(
        project_id, character_id, worldbook_id
    )
    # Work on a copy so the run's own variables dict is left untouched.
    workflow_variables = dict(run.workflowVariables or {})
    workflow_variables["workflow.goal"] = run.pipelineSnapshot.workflowGoal
    workflow_variables["workflow.boundCharacter"] = (
        f"名称:{character_name}\nID{character_id}"
    )
    workflow_variables["workflow.boundWorldbook"] = (
        f"名称:{worldbook_name}\nID{worldbook_id}"
    )
    last_draft = {
        "displayParams": params,
        "characterId": character_id,
        "worldbookId": worldbook_id,
        "characterName": character_name,
        "worldbookName": worldbook_name,
    }
    return workflow_variables, last_draft
@staticmethod
def _parse_keyword_field(raw: Any) -> List[str]:
if raw is None:
return []
if isinstance(raw, list):
return [str(x).strip() for x in raw if str(x).strip()]
text = str(raw).strip()
if not text:
return []
return [part.strip() for part in text.replace("", ",").split(",") if part.strip()]
def _resolve_bound_worldbook_name(
    self, project_id: str, run: StudioRun
) -> str:
    """Return the name of the worldbook bound to this run or project.

    The name recorded by a completed ``studio.init_bind`` node wins.
    Otherwise the project's ``worldbookId`` is matched against the ids of
    all listed worldbooks.

    Raises:
        ValueError: If no bound worldbook can be determined.
    """
    for node_state in run.nodeStates:
        if node_state.skillId == "studio.init_bind" and node_state.status == "completed":
            recorded = ((node_state.lastDraft or {}).get("worldbookName") or "").strip()
            if recorded:
                return recorded

    try:
        bound_id = studio_project_service.get_project(project_id).meta.worldbookId
    except FileNotFoundError:
        bound_id = None

    if bound_id:
        for summary in worldbook_service.list_worldbooks():
            listed_name = summary.get("name")
            if not listed_name:
                continue
            try:
                book = worldbook_service.get_worldbook(listed_name)
            except FileNotFoundError:
                continue
            if book.get("id") == bound_id:
                return listed_name

    raise ValueError("未找到绑定的世界书,请先完成「创建并绑定」步骤")
def _build_worldbook_entry_payload(
self, node: StudioNode, draft: Dict[str, Any]
) -> Dict[str, Any]:
insertion = (node.config or {}).get("insertion") or {}
content = (draft.get("entryContent") or "").strip()
if not content:
raise ValueError("当前步骤产物为空,请先生成世界书条目内容后再进入下一步")
activation = (insertion.get("activationType") or "permanent").strip()
position = insertion.get("position", 0)
if isinstance(position, str):
position = WorldBookConverter.POSITION_MAP_ST_TO_INTERNAL.get(position, 0)
key = self._parse_keyword_field(
draft.get("insertionKey") or insertion.get("key")
)
keysecondary = self._parse_keyword_field(insertion.get("keysecondary"))
comment = (
(draft.get("insertionComment") or insertion.get("comment") or "").strip()
or f"Studio · {node.displayName}"
)
payload: Dict[str, Any] = {
"content": content,
"comment": comment,
"activationType": activation,
"position": position,
"key": key,
"keysecondary": keysecondary,
"order": insertion.get("order", 100),
"depth": insertion.get("depth", 4),
"probability": insertion.get("probability", 100),
"group": insertion.get("group") or [],
"disable": bool(insertion.get("disable", False)),
}
if activation == "rag" and insertion.get("ragConfig"):
payload["ragConfig"] = insertion["ragConfig"]
return payload
def _write_worldbook_entry_on_advance(
    self,
    worldbook_name: str,
    node: StudioNode,
    draft: Dict[str, Any],
) -> Dict[str, Any]:
    """Append the entry built from ``draft`` to ``worldbook_name`` and return the written entry.

    Raises:
        ValueError: If the draft is empty, the worldbook is missing, or the
            write fails.
    """
    payload = self._build_worldbook_entry_payload(node, draft)
    try:
        written = worldbook_service.append_entry(worldbook_name, payload)
    except FileNotFoundError:
        # Must stay ahead of the OSError handler: FileNotFoundError is a subclass.
        raise ValueError(f"世界书「{worldbook_name}」不存在,无法写入条目")
    except (ValueError, OSError) as exc:
        raise ValueError(f"写入世界书失败:{exc}") from exc
    return written
def _overwrite_worldbook_entry(
    self,
    worldbook_name: str,
    node: StudioNode,
    draft: Dict[str, Any],
    entry_uid: str,
) -> Dict[str, Any]:
    """Replace entry ``entry_uid`` of ``worldbook_name`` with the entry built from ``draft``.

    Raises:
        ValueError: If the draft is empty, the entry or worldbook cannot be
            found, or the update fails.
    """
    payload = self._build_worldbook_entry_payload(node, draft)
    try:
        updated_entry = worldbook_service.update_entry(
            worldbook_name, entry_uid, payload
        )
    except FileNotFoundError as exc:
        # Must stay ahead of the OSError handler: FileNotFoundError is a subclass.
        raise ValueError(f"世界书条目不存在或无法更新:{exc}") from exc
    except (ValueError, OSError) as exc:
        raise ValueError(f"覆盖世界书失败:{exc}") from exc
    return updated_entry
@staticmethod
def _node_has_written_entry(state: StudioNodeRunState) -> bool:
draft = state.lastDraft or {}
return bool(draft.get("writtenEntryUid")) or state.status == "completed"
def _save_worldbook_entry_only(
    self,
    project_id: str,
    run_id: str,
    run: StudioRun,
    node: StudioNode,
    state: StudioNodeRunState,
    node_id: str,
    save_mode: str,
) -> StudioRun:
    """Write the node's current draft to the bound worldbook without advancing.

    Args:
        project_id: Owning project.
        run_id: Run being saved.
        run: The loaded run.
        node: Pipeline node whose insertion settings shape the entry.
        state: Run state of that node; its ``lastDraft`` is the content source.
        node_id: Id of the node whose state receives the updated draft.
        save_mode: ``"append"`` adds a new entry; ``"overwrite"`` replaces
            the entry recorded in ``writtenEntryUid``.

    Returns:
        The persisted run, with ``writtenEntryUid`` and
        ``writtenWorldbookName`` stored on the node's draft.

    Raises:
        ValueError: If there is no draft, the worldbook cannot be resolved
            or written, overwrite is requested before any entry was
            written, or ``save_mode`` is unsupported.
    """
    if not state.lastDraft:
        raise ValueError("请先生成并确认当前产物后再保存")
    worldbook_name = self._resolve_bound_worldbook_name(project_id, run)
    draft = state.lastDraft

    if save_mode == "append":
        entry_payload = self._build_worldbook_entry_payload(node, draft)
        if (draft.get("writtenEntryUid") or "").strip():
            # An entry already exists for this step: add another one, ordered
            # after everything currently in the worldbook.
            written_entry = self._append_entry_after_existing(
                worldbook_name, entry_payload
            )
        else:
            written_entry = self._write_worldbook_entry_on_advance(
                worldbook_name, node, draft
            )
    elif save_mode == "overwrite":
        entry_uid = (draft.get("writtenEntryUid") or "").strip()
        if not entry_uid:
            raise ValueError(
                "尚未写入过世界书条目,请使用「下一步」完成首次导出,或使用「增量保存」"
            )
        written_entry = self._overwrite_worldbook_entry(
            worldbook_name, node, draft, entry_uid
        )
    else:
        raise ValueError(f"不支持的保存模式:{save_mode}")

    last_draft = copy.deepcopy(draft)
    last_draft["writtenEntryUid"] = written_entry.get("uid")
    last_draft["writtenWorldbookName"] = worldbook_name
    now = datetime.now().isoformat()

    new_node_states: list[StudioNodeRunState] = []
    for ns in run.nodeStates:
        updated = ns.model_copy()
        if ns.nodeId == node_id:
            updated.lastDraft = last_draft
        new_node_states.append(updated)

    updated_run = run.model_copy(
        update={
            "nodeStates": new_node_states,
            "updatedAt": now,
        }
    )
    self._save_run(project_id, run_id, updated_run)
    return updated_run

def _append_entry_after_existing(
    self, worldbook_name: str, entry_payload: Dict[str, Any]
) -> Dict[str, Any]:
    """Append ``entry_payload`` with an ``order`` one above the worldbook's current maximum.

    Failures are reported as ``ValueError`` with the same messages as the
    first-write path, so callers see one error type for every save mode.
    """
    try:
        try:
            wb_data = worldbook_service.get_worldbook(worldbook_name)
            entries = wb_data.get("entries") or []
            max_order = max(
                (int(e.get("order", 0)) for e in entries),
                default=0,
            )
            entry_payload["order"] = max_order + 1
        except (TypeError, ValueError):
            # Existing orders are not numeric: fall back to the configured order.
            entry_payload["order"] = int(entry_payload.get("order", 100)) + 1
        return worldbook_service.append_entry(worldbook_name, entry_payload)
    except FileNotFoundError as exc:
        raise ValueError(f"世界书「{worldbook_name}」不存在,无法写入条目") from exc
    except (TypeError, ValueError, OSError) as exc:
        raise ValueError(f"写入世界书失败:{exc}") from exc
def advance_run(
    self,
    project_id: str,
    run_id: str,
    display_params: Optional[Dict[str, str]] = None,
    save_mode: str = "advance",
) -> StudioRun:
    """Execute the current node and move the run to the next enabled node.

    With ``save_mode`` ``"append"`` or ``"overwrite"`` the current worldbook
    step is only saved (via ``_save_worldbook_entry_only``) and the pipeline
    does not advance.

    Args:
        project_id: Owning project.
        run_id: Run to advance.
        display_params: Form values; required for ``studio.init_bind`` nodes.
        save_mode: ``"advance"`` (default), ``"append"`` or ``"overwrite"``.

    Returns:
        The updated, persisted run.

    Raises:
        FileNotFoundError: If the run does not exist.
        ValueError: If the run or node is not in a state that allows the
            requested operation, or the node's own execution fails.
        NotImplementedError: If the current node's skill has no executor here.
    """
    run = self.get_run(project_id, run_id)
    # Saving is also allowed on COMPLETED runs; advancing is not.
    if save_mode in ("append", "overwrite"):
        if run.status not in (
            StudioRunStatus.RUNNING,
            StudioRunStatus.PENDING,
            StudioRunStatus.COMPLETED,
        ):
            raise ValueError("运行已结束,无法保存")
    elif run.status not in (StudioRunStatus.RUNNING, StudioRunStatus.PENDING):
        raise ValueError("运行已结束,无法继续推进")
    current_node_id = run.currentNodeId
    if not current_node_id:
        raise ValueError("当前运行无活动节点")
    current_node = _find_node(run.pipelineSnapshot, current_node_id)
    if not current_node:
        raise ValueError(f"节点不存在:{current_node_id}")
    current_state = next(
        (s for s in run.nodeStates if s.nodeId == current_node_id), None
    )
    if not current_state:
        raise ValueError("当前节点状态不存在")
    # Save-only path: write the worldbook entry and return without advancing.
    if save_mode in ("append", "overwrite"):
        if current_node.skillId != "studio.worldbook_entry":
            raise ValueError("当前步骤不支持增量/覆盖保存")
        if current_state.status not in ("active", "completed"):
            raise ValueError("当前节点不可保存")
        return self._save_worldbook_entry_only(
            project_id,
            run_id,
            run,
            current_node,
            current_state,
            current_node_id,
            save_mode,
        )
    if current_state.status != "active":
        raise ValueError("当前节点不可执行")
    workflow_variables = dict(run.workflowVariables or {})
    last_draft: Optional[Dict[str, Any]] = None
    # Run the node's skill; each branch yields the draft stored on the node.
    if current_node.skillId == "studio.init_bind":
        if not display_params:
            raise ValueError("请填写引导表单后再提交")
        workflow_variables, last_draft = self._execute_init_bind(
            project_id, run, current_node, display_params
        )
    elif current_node.skillId == "studio.worldbook_entry":
        if not current_state.lastDraft:
            raise ValueError("请先生成并确认当前产物后再进入下一步")
        worldbook_name = self._resolve_bound_worldbook_name(project_id, run)
        written_entry = self._write_worldbook_entry_on_advance(
            worldbook_name,
            current_node,
            current_state.lastDraft,
        )
        # Record where the entry went so later saves can overwrite it.
        last_draft = copy.deepcopy(current_state.lastDraft)
        last_draft["writtenEntryUid"] = written_entry.get("uid")
        last_draft["writtenWorldbookName"] = worldbook_name
    else:
        raise NotImplementedError(
            f"技能「{current_node.skillId}」的执行尚未实现R2+"
        )
    next_node_id = _next_enabled_node_id(run.pipelineSnapshot, current_node_id)
    now = datetime.now().isoformat()
    # Mark the current node completed and activate the next enabled one.
    new_node_states: list[StudioNodeRunState] = []
    for state in run.nodeStates:
        updated = state.model_copy()
        if state.nodeId == current_node_id:
            updated.status = "completed"
            if last_draft is not None:
                updated.lastDraft = last_draft
        elif next_node_id and state.nodeId == next_node_id:
            updated.status = "active"
        new_node_states.append(updated)
    # No next node means the whole run is finished.
    new_status = (
        StudioRunStatus.COMPLETED if not next_node_id else StudioRunStatus.RUNNING
    )
    updated_run = run.model_copy(
        update={
            "currentNodeId": next_node_id,
            "nodeStates": new_node_states,
            "workflowVariables": workflow_variables,
            "status": new_status,
            "updatedAt": now,
        }
    )
    updated_run = store_context_on_run(updated_run, next_node_id)
    self._save_run(project_id, run_id, updated_run)
    return updated_run
def save_run(
    self, project_id: str, run_id: str, mode: str
) -> StudioRun:
    """Save the current worldbook entry without advancing the pipeline.

    ``mode`` is ``"incremental"`` (append a new entry) or ``"overwrite"``.

    Raises:
        ValueError: If ``mode`` is neither of those.
    """
    save_modes = {"incremental": "append", "overwrite": "overwrite"}
    save_mode = save_modes.get(mode)
    if save_mode is None:
        raise ValueError(f"不支持的保存模式:{mode}")
    return self.advance_run(project_id, run_id, save_mode=save_mode)
def switch_run_node(
    self, project_id: str, run_id: str, node_id: str
) -> StudioRun:
    """Manually focus a pipeline node (e.g. revisit a completed worldbook step).

    Only enabled ``studio.worldbook_entry`` nodes that are ``active`` or
    ``completed`` can be targeted.

    Returns:
        The persisted run with ``currentNodeId`` set to ``node_id``.

    Raises:
        FileNotFoundError: If the run does not exist.
        ValueError: If the run status or the target node does not allow
            switching.
    """
    run = self.get_run(project_id, run_id)
    if run.status not in (
        StudioRunStatus.RUNNING,
        StudioRunStatus.PENDING,
        StudioRunStatus.COMPLETED,
    ):
        raise ValueError("运行已结束,无法切换节点")
    target_node = _find_node(run.pipelineSnapshot, node_id)
    if not target_node:
        raise ValueError(f"节点不存在:{node_id}")
    if not target_node.enabled:
        raise ValueError("该节点已禁用,无法切换")
    if target_node.skillId != "studio.worldbook_entry":
        raise ValueError("仅支持切换到世界书创作步骤")
    target_state = next(
        (s for s in run.nodeStates if s.nodeId == node_id), None
    )
    if not target_state:
        raise ValueError("节点状态不存在")
    if target_state.status not in ("active", "completed"):
        raise ValueError("仅可切换到进行中或已完成的步骤")
    prev_node_id = run.currentNodeId
    now = datetime.now().isoformat()
    new_node_states: list[StudioNodeRunState] = []
    for state in run.nodeStates:
        updated = state.model_copy()
        if state.nodeId == node_id:
            updated.status = "active"
        elif prev_node_id and state.nodeId == prev_node_id and prev_node_id != node_id:
            # The node being left: a worldbook step goes back to "completed"
            # if it already wrote an entry, otherwise to "pending". Nodes of
            # other skills keep their status.
            if state.skillId == "studio.worldbook_entry":
                if self._node_has_written_entry(state):
                    updated.status = "completed"
                else:
                    updated.status = "pending"
        new_node_states.append(updated)
    # The run stays COMPLETED only when it was completed and the target is
    # the last enabled node; otherwise it is (re)opened as RUNNING.
    new_status = (
        StudioRunStatus.COMPLETED
        if run.status == StudioRunStatus.COMPLETED
        and not _next_enabled_node_id(run.pipelineSnapshot, node_id)
        else StudioRunStatus.RUNNING
    )
    updated_run = run.model_copy(
        update={
            "currentNodeId": node_id,
            "nodeStates": new_node_states,
            "status": new_status,
            "updatedAt": now,
        }
    )
    updated_run = store_context_on_run(updated_run, node_id)
    self._save_run(project_id, run_id, updated_run)
    return updated_run
async def send_run_message(
    self,
    project_id: str,
    run_id: str,
    content: str,
    *,
    stream: bool = False,
    profile_id: Optional[str] = None,
    api_config: Optional[Dict[str, str]] = None,
) -> StudioRun:
    """Send a user message to the active worldbook step and invoke the LLM (R3).

    The resulting turn (draft, tool response, user and assistant messages)
    is persisted on the run, which is then returned.

    Raises:
        ValueError: If ``content`` is blank or the run has no active
            worldbook step.
    """
    message_text = (content or "").strip()
    if not message_text:
        raise ValueError("消息内容不能为空")

    prepared = self._prepare_run_message(project_id, run_id, message_text)
    run, active_node, active_state, active_node_id = prepared

    llm_config = resolve_api_config(profile_id, api_config)
    blocks = assemble_prompt_blocks(run, active_node_id)
    prior_messages = list(active_state.stepMessages or [])

    turn_result = await studio_step_respond(
        node=active_node,
        prompt_blocks=blocks,
        step_messages=prior_messages,
        user_message=message_text,
        existing_draft=active_state.lastDraft,
        api_config=llm_config,
        stream=stream,
    )
    new_draft, tool_response, user_msg, assistant_msg = turn_result

    return self._persist_run_message_turn(
        project_id,
        run_id,
        run,
        active_node_id,
        blocks,
        prior_messages,
        new_draft,
        tool_response,
        user_msg,
        assistant_msg,
    )
async def send_run_message_stream(
    self,
    project_id: str,
    run_id: str,
    content: str,
    *,
    profile_id: Optional[str] = None,
    api_config: Optional[Dict[str, str]] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
    """Stream thinking deltas, then persist and emit complete run (R4).

    Yields:
        ``thinking_delta`` events unchanged, then one
        ``{"type": "complete", "run": ...}`` event carrying the persisted
        run. Events of any other type are dropped.

    Raises:
        ValueError: If ``content`` is blank or the run has no active
            worldbook step.
    """
    trimmed = (content or "").strip()
    if not trimmed:
        raise ValueError("消息内容不能为空")
    run, current_node, current_state, current_node_id = self._prepare_run_message(
        project_id, run_id, trimmed
    )
    resolved_api = resolve_api_config(profile_id, api_config)
    prompt_blocks = assemble_prompt_blocks(run, current_node_id)
    step_messages = list(current_state.stepMessages or [])
    async for event in studio_step_respond_stream(
        node=current_node,
        prompt_blocks=prompt_blocks,
        step_messages=step_messages,
        user_message=trimmed,
        existing_draft=current_state.lastDraft,
        api_config=resolved_api,
    ):
        if event.get("type") == "thinking_delta":
            yield event
            continue
        if event.get("type") == "complete":
            from models.studio_models import LastToolResponse, StepMessage
            # The event carries plain dicts; rebuild the models before persisting.
            last_draft = event["last_draft"]
            last_tool_response = LastToolResponse(**event["last_tool_response"])
            user_msg = StepMessage(**event["user_msg"])
            assistant_msg = StepMessage(**event["assistant_msg"])
            updated_run = self._persist_run_message_turn(
                project_id,
                run_id,
                run,
                current_node_id,
                prompt_blocks,
                step_messages,
                last_draft,
                last_tool_response,
                user_msg,
                assistant_msg,
            )
            yield {
                "type": "complete",
                "run": updated_run.model_dump(mode="json"),
            }
def _prepare_active_worldbook_step(
    self,
    project_id: str,
    run_id: str,
) -> tuple[StudioRun, StudioNode, StudioNodeRunState, str]:
    """Load the run and check that its current node is an active worldbook step.

    Returns:
        ``(run, node, node_state, node_id)`` for the current node.

    Raises:
        FileNotFoundError: If the run does not exist.
        ValueError: If the run is not ``RUNNING``, has no current node, the
            node is not a ``studio.worldbook_entry`` step, or its state is
            not ``active``.
    """
    run = self.get_run(project_id, run_id)
    if run.status != StudioRunStatus.RUNNING:
        raise ValueError("运行未处于进行中,无法操作")

    node_id = run.currentNodeId
    if not node_id:
        raise ValueError("当前运行无活动节点")

    node = _find_node(run.pipelineSnapshot, node_id)
    if not node:
        raise ValueError(f"节点不存在:{node_id}")
    if node.skillId != "studio.worldbook_entry":
        raise ValueError(
            f"当前步骤「{node.displayName}」不支持对话消息"
        )

    node_state = None
    for candidate in run.nodeStates:
        if candidate.nodeId == node_id:
            node_state = candidate
            break

    if not (node_state and node_state.status == "active"):
        raise ValueError("当前节点不可执行")
    return run, node, node_state, node_id
def _prepare_run_message(
self,
project_id: str,
run_id: str,
trimmed: str,
) -> tuple[StudioRun, StudioNode, StudioNodeRunState, str]:
if not trimmed:
raise ValueError("消息内容不能为空")
return self._prepare_active_worldbook_step(project_id, run_id)
def _persist_run_message_turn(
    self,
    project_id: str,
    run_id: str,
    run: StudioRun,
    current_node_id: str,
    prompt_blocks: list,
    step_messages: list,
    last_draft: Dict[str, Any],
    last_tool_response,
    user_msg,
    assistant_msg,
    *,
    push_snapshot: bool = True,
) -> StudioRun:
    """Store a finished message turn on the current node and persist the run.

    The user and assistant messages are appended to ``step_messages``. When
    ``push_snapshot`` is true, the node's state from before the turn is
    pushed onto its ``turnHistory``.

    Returns:
        The persisted run.
    """
    messages_after_turn = [*step_messages, user_msg, assistant_msg]
    timestamp = datetime.now().isoformat()

    node_states: list[StudioNodeRunState] = []
    for node_state in run.nodeStates:
        clone = node_state.model_copy()
        if node_state.nodeId != current_node_id:
            node_states.append(clone)
            continue
        if push_snapshot:
            clone.turnHistory = [
                *(node_state.turnHistory or []),
                _snapshot_node_state(node_state),
            ]
        clone.lastDraft = last_draft
        clone.lastToolResponse = last_tool_response
        clone.stepMessages = messages_after_turn
        node_states.append(clone)

    changes = {
        "nodeStates": node_states,
        "lastPromptBlocks": prompt_blocks,
        "updatedAt": timestamp,
    }
    persisted = store_context_on_run(run.model_copy(update=changes), current_node_id)
    self._save_run(project_id, run_id, persisted)
    return persisted
def undo_run(self, project_id: str, run_id: str) -> StudioRun:
    """Restore the active step to the newest snapshot in its turn history.

    The snapshot is removed from the history and the run is persisted.

    Raises:
        ValueError: If the active step check fails or there is no snapshot
            to go back to.
    """
    run, _, node_state, node_id = self._prepare_active_worldbook_step(
        project_id, run_id
    )
    history = list(node_state.turnHistory or [])
    if not history:
        raise ValueError("无可回退的回合")

    *remaining_history, latest_snapshot = history
    restored = _apply_snapshot(node_state, latest_snapshot)
    restored.turnHistory = remaining_history

    timestamp = datetime.now().isoformat()
    node_states: list[StudioNodeRunState] = [
        restored if entry.nodeId == node_id else entry.model_copy()
        for entry in run.nodeStates
    ]
    changes = {
        "nodeStates": node_states,
        "updatedAt": timestamp,
    }
    persisted = store_context_on_run(run.model_copy(update=changes), node_id)
    self._save_run(project_id, run_id, persisted)
    return persisted
async def reroll_run(
    self,
    project_id: str,
    run_id: str,
    *,
    stream: bool = False,
    profile_id: Optional[str] = None,
    api_config: Optional[Dict[str, str]] = None,
) -> StudioRun:
    """Re-run the last user turn with identical inputs (no new user message).

    The last user message is replayed against the messages that preceded
    it and the draft that existed before that turn. The replaced turn is
    persisted through ``_persist_reroll_turn``.

    Raises:
        ValueError: If the active step check fails or the step has no user
            message yet.
    """
    run, current_node, current_state, current_node_id = (
        self._prepare_active_worldbook_step(project_id, run_id)
    )
    step_messages_all = list(current_state.stepMessages or [])
    last_user_idx = _find_last_user_message_index(step_messages_all)
    if last_user_idx < 0:
        raise ValueError("当前步骤尚无用户消息,无法重 roll")
    user_content = step_messages_all[last_user_idx].content
    step_messages = step_messages_all[:last_user_idx]

    # Find the draft from *before* the turn being replayed. The newest
    # snapshot is not always that state: a previous reroll pushes the
    # post-turn state, so a second reroll would otherwise reuse the first
    # attempt's output. The pre-turn snapshot is the newest one holding
    # exactly the messages that precede the last user message.
    turn_history = current_state.turnHistory or []
    pre_turn_draft = current_state.lastDraft
    if turn_history:
        pre_turn_draft = turn_history[-1].lastDraft
        for snapshot in reversed(turn_history):
            if len(snapshot.stepMessages or []) == last_user_idx:
                pre_turn_draft = snapshot.lastDraft
                break

    resolved_api = resolve_api_config(profile_id, api_config)
    prompt_blocks = assemble_prompt_blocks(run, current_node_id)
    last_draft, last_tool_response, user_msg, assistant_msg = (
        await studio_step_respond(
            node=current_node,
            prompt_blocks=prompt_blocks,
            step_messages=step_messages,
            user_message=user_content,
            existing_draft=pre_turn_draft,
            api_config=resolved_api,
            stream=stream,
        )
    )
    return self._persist_reroll_turn(
        project_id,
        run_id,
        run,
        current_node_id,
        current_state,
        prompt_blocks,
        step_messages,
        last_draft,
        last_tool_response,
        user_msg,
        assistant_msg,
    )
async def reroll_run_stream(
    self,
    project_id: str,
    run_id: str,
    *,
    profile_id: Optional[str] = None,
    api_config: Optional[Dict[str, str]] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
    """Stream thinking during reroll, then persist replaced turn.

    Yields:
        ``thinking_delta`` events unchanged, then one
        ``{"type": "complete", "run": ...}`` event carrying the persisted
        run. Events of any other type are dropped.

    Raises:
        ValueError: If the active step check fails or the step has no user
            message yet.
    """
    run, current_node, current_state, current_node_id = (
        self._prepare_active_worldbook_step(project_id, run_id)
    )
    step_messages_all = list(current_state.stepMessages or [])
    last_user_idx = _find_last_user_message_index(step_messages_all)
    if last_user_idx < 0:
        raise ValueError("当前步骤尚无用户消息,无法重 roll")
    user_content = step_messages_all[last_user_idx].content
    step_messages = step_messages_all[:last_user_idx]

    # Find the draft from *before* the turn being replayed. The newest
    # snapshot is not always that state: a previous reroll pushes the
    # post-turn state, so a second reroll would otherwise reuse the first
    # attempt's output. The pre-turn snapshot is the newest one holding
    # exactly the messages that precede the last user message.
    turn_history = current_state.turnHistory or []
    pre_turn_draft = current_state.lastDraft
    if turn_history:
        pre_turn_draft = turn_history[-1].lastDraft
        for snapshot in reversed(turn_history):
            if len(snapshot.stepMessages or []) == last_user_idx:
                pre_turn_draft = snapshot.lastDraft
                break

    resolved_api = resolve_api_config(profile_id, api_config)
    prompt_blocks = assemble_prompt_blocks(run, current_node_id)
    async for event in studio_step_respond_stream(
        node=current_node,
        prompt_blocks=prompt_blocks,
        step_messages=step_messages,
        user_message=user_content,
        existing_draft=pre_turn_draft,
        api_config=resolved_api,
    ):
        if event.get("type") == "thinking_delta":
            yield event
            continue
        if event.get("type") == "complete":
            from models.studio_models import LastToolResponse, StepMessage
            # The event carries plain dicts; rebuild the models before persisting.
            last_draft = event["last_draft"]
            last_tool_response = LastToolResponse(**event["last_tool_response"])
            user_msg = StepMessage(**event["user_msg"])
            assistant_msg = StepMessage(**event["assistant_msg"])
            updated_run = self._persist_reroll_turn(
                project_id,
                run_id,
                run,
                current_node_id,
                current_state,
                prompt_blocks,
                step_messages,
                last_draft,
                last_tool_response,
                user_msg,
                assistant_msg,
            )
            yield {
                "type": "complete",
                "run": updated_run.model_dump(mode="json"),
            }
def _persist_reroll_turn(
    self,
    project_id: str,
    run_id: str,
    run: StudioRun,
    current_node_id: str,
    current_state: StudioNodeRunState,
    prompt_blocks: list,
    step_messages: list,
    last_draft: Dict[str, Any],
    last_tool_response,
    user_msg,
    assistant_msg,
) -> StudioRun:
    """Replace the last turn and persist the run; the reroll stays undoable.

    A snapshot of ``current_state`` (the state before the reroll) is pushed
    onto the node's ``turnHistory`` before the new turn is stored.
    """
    messages_after_turn = [*step_messages, user_msg, assistant_msg]
    timestamp = datetime.now().isoformat()

    node_states: list[StudioNodeRunState] = []
    for node_state in run.nodeStates:
        clone = node_state.model_copy()
        if node_state.nodeId == current_node_id:
            clone.turnHistory = [
                *(current_state.turnHistory or []),
                _snapshot_node_state(current_state),
            ]
            clone.lastDraft = last_draft
            clone.lastToolResponse = last_tool_response
            clone.stepMessages = messages_after_turn
        node_states.append(clone)

    changes = {
        "nodeStates": node_states,
        "lastPromptBlocks": prompt_blocks,
        "updatedAt": timestamp,
    }
    persisted = store_context_on_run(run.model_copy(update=changes), current_node_id)
    self._save_run(project_id, run_id, persisted)
    return persisted
def delete_run(self, project_id: str, run_id: str) -> None:
    """Remove the run's directory and everything inside it.

    Raises:
        FileNotFoundError: If the run directory does not exist.
    """
    target = self._run_dir(project_id, run_id)
    if target.exists():
        shutil.rmtree(target)
        return
    raise FileNotFoundError(f"Studio run not found: {project_id}/{run_id}")
def rename_run(self, project_id: str, run_id: str, title: str) -> StudioRun:
    """Set the run's title to the trimmed ``title`` and persist it.

    Raises:
        FileNotFoundError: If the run does not exist.
        ValueError: If ``title`` is blank.
    """
    run = self.get_run(project_id, run_id)
    new_title = (title or "").strip()
    if new_title == "":
        raise ValueError("运行名称不能为空")
    changes = {"title": new_title, "updatedAt": datetime.now().isoformat()}
    renamed = run.model_copy(update=changes)
    self._save_run(project_id, run_id, renamed)
    return renamed
# Module-level service instance, created once at import time.
studio_run_service = StudioRunService()