feat: refactor injected parameter handling and introduce is_framework_injected_parameter utility

This commit is contained in:
whatevertogo
2026-03-19 09:33:47 +08:00
parent 461f72764a
commit d078e51051
9 changed files with 65 additions and 91 deletions

View File

@@ -6,6 +6,7 @@ from typing import Any
from pydantic import BaseModel
from ._injected_params import is_framework_injected_parameter
from ._typing_utils import unwrap_optional
from .errors import AstrBotError
from .runtime._command_matching import split_command_remainder
@@ -211,26 +212,7 @@ def _command_parse_error(message: str) -> AstrBotError:
def _is_injected_parameter(name: str, annotation: Any) -> bool:
if name in {"event", "ctx", "context", "sched", "schedule", "conversation", "conv"}:
return True
normalized, _is_optional = unwrap_optional(annotation)
if normalized is None:
return False
try:
from .context import Context
from .conversation import ConversationSession
from .events import MessageEvent
from .schedule import ScheduleContext
except Exception:
return False
if normalized in {Context, MessageEvent, ScheduleContext, ConversationSession}:
return True
if isinstance(normalized, type):
return issubclass(
normalized,
(Context, MessageEvent, ScheduleContext, ConversationSession),
)
return False
return is_framework_injected_parameter(name, annotation)
__all__ = [

View File

@@ -0,0 +1,55 @@
from __future__ import annotations
from typing import Any
from ._typing_utils import unwrap_optional
_INJECTED_PARAMETER_NAMES = {
"event",
"ctx",
"context",
"sched",
"schedule",
"conversation",
"conv",
}
def is_framework_injected_parameter(name: str, annotation: Any) -> bool:
if name in _INJECTED_PARAMETER_NAMES:
return True
normalized, _is_optional = unwrap_optional(annotation)
if normalized is None:
return False
try:
injected_types = _framework_injected_types()
except Exception:
return False
if normalized in injected_types:
return True
if isinstance(normalized, type):
return issubclass(normalized, injected_types)
return False
def _framework_injected_types() -> tuple[type[Any], ...]:
from .clients.llm import LLMResponse
from .context import Context
from .conversation import ConversationSession
from .events import MessageEvent
from .llm.entities import ProviderRequest
from .message_result import MessageEventResult
from .schedule import ScheduleContext
return (
Context,
MessageEvent,
ScheduleContext,
ConversationSession,
ProviderRequest,
LLMResponse,
MessageEventResult,
)
__all__ = ["is_framework_injected_parameter"]

View File

@@ -62,4 +62,3 @@ class LLMCapabilityMixin(CapabilityRouterBridgeBase):
"text": "".join(item.get("text", "") for item in chunks)
},
)

View File

@@ -228,4 +228,3 @@ class PlatformCapabilityMixin(CapabilityRouterBridgeBase):
),
call_handler=self._platform_manager_get_stats,
)

View File

@@ -451,4 +451,3 @@ class SystemCapabilityMixin(CapabilityRouterBridgeBase):
}
)
return {"supported": True}

View File

@@ -18,10 +18,10 @@ import inspect
import typing
from typing import Any, Literal, TypeAlias, cast
from .._injected_params import is_framework_injected_parameter
from .._typing_utils import unwrap_optional
from ..decorators import get_capability_meta, get_handler_meta
from ..protocol.descriptors import ParamSpec
from ..schedule import ScheduleContext
from ..types import GreedyStr
ParamTypeName: TypeAlias = Literal[
@@ -31,19 +31,7 @@ OptionalInnerType: TypeAlias = Literal["str", "int", "float", "bool"] | None
def is_injected_parameter(annotation: Any, parameter_name: str) -> bool:
if parameter_name in {"event", "ctx", "context", "sched", "schedule"}:
return True
normalized, _is_optional = unwrap_optional(annotation)
if normalized is None:
return False
if normalized in {ScheduleContext}:
return True
if isinstance(normalized, type):
from ..context import Context
from ..events import MessageEvent
return issubclass(normalized, (Context, MessageEvent, ScheduleContext))
return False
return is_framework_injected_parameter(parameter_name, annotation)
def param_type_name(annotation: Any) -> tuple[ParamTypeName, OptionalInnerType, bool]:

View File

@@ -35,6 +35,7 @@ from .._command_model import (
parse_command_model_remainder,
resolve_command_model_param,
)
from .._injected_params import is_framework_injected_parameter
from .._invocation_context import caller_plugin_scope
from .._plugin_logger import PluginLogger
from .._star_runtime import bind_star_runtime
@@ -947,19 +948,7 @@ class HandlerDispatcher:
@classmethod
def _is_injected_parameter(cls, name: str, annotation: Any) -> bool:
if name in {"event", "ctx", "context", "conversation", "conv"}:
return True
normalized, _is_optional = unwrap_optional(annotation)
if normalized is None:
return False
if normalized in {Context, MessageEvent, ConversationSession}:
return True
if isinstance(normalized, type) and issubclass(
normalized,
(Context, MessageEvent, ConversationSession),
):
return True
return False
return is_framework_injected_parameter(name, annotation)
async def _handle_error(
self,

View File

@@ -69,6 +69,7 @@ from typing import Any, Literal, TypeAlias, cast
import yaml
from .._command_model import resolve_command_model_param
from .._injected_params import is_framework_injected_parameter
from .._typing_utils import unwrap_optional
from ..decorators import (
ConversationMeta,
@@ -86,7 +87,6 @@ from ..protocol.descriptors import (
ParamSpec,
ScheduleTrigger,
)
from ..schedule import ScheduleContext
from ..types import GreedyStr
from .environment_groups import (
EnvironmentGroup,
@@ -223,31 +223,7 @@ def _iter_discoverable_names(instance: Any) -> list[str]:
def _is_injected_parameter(annotation: Any, parameter_name: str) -> bool:
if parameter_name in {
"event",
"ctx",
"context",
"sched",
"schedule",
"conversation",
"conv",
}:
return True
normalized, _is_optional = unwrap_optional(annotation)
if normalized is None:
return False
if normalized in {ScheduleContext}:
return True
if isinstance(normalized, type):
from ..context import Context
from ..conversation import ConversationSession
from ..events import MessageEvent
return issubclass(
normalized,
(Context, MessageEvent, ScheduleContext, ConversationSession),
)
return False
return is_framework_injected_parameter(parameter_name, annotation)
def _param_type_name(annotation: Any) -> tuple[ParamTypeName, OptionalInnerType, bool]:

View File

@@ -20,6 +20,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Any, get_type_hints
from ._injected_params import is_framework_injected_parameter
from ._star_runtime import bind_star_runtime
from ._testing_support import (
InMemoryDB,
@@ -33,7 +34,6 @@ from ._testing_support import (
RecordedSend,
StdoutPlatformSink,
)
from ._typing_utils import unwrap_optional
from .context import CancelToken
from .context import Context as RuntimeContext
from .errors import AstrBotError
@@ -754,20 +754,7 @@ class PluginHarness:
return names
def _is_injected_parameter(self, name: str, annotation: Any) -> bool:
if name in {"event", "ctx", "context"}:
return True
normalized, _is_optional = unwrap_optional(annotation)
if normalized is None:
return False
if normalized is RuntimeContext:
return True
if normalized is MessageEvent:
return True
if isinstance(normalized, type) and issubclass(
normalized, (RuntimeContext, MessageEvent)
):
return True
return False
return is_framework_injected_parameter(name, annotation)
def _next_request_id(self, prefix: str) -> str:
self._request_counter += 1