refactor(injection): centralize legacy injected parameter filtering

This commit is contained in:
Lishiling
2026-03-19 11:06:41 +08:00
parent d078e51051
commit 090724a7d4
6 changed files with 279 additions and 62 deletions

View File

@@ -1,7 +1,13 @@
from __future__ import annotations
import inspect
from typing import Any
try:
from typing import get_type_hints
except ImportError: # pragma: no cover
get_type_hints = None
from ._typing_utils import unwrap_optional
_INJECTED_PARAMETER_NAMES = {
@@ -32,6 +38,34 @@ def is_framework_injected_parameter(name: str, annotation: Any) -> bool:
return False
def legacy_arg_parameter_names(handler: Any) -> list[str]:
try:
signature = inspect.signature(handler)
except (TypeError, ValueError):
return []
try:
if get_type_hints is None:
type_hints = {}
else:
type_hints = get_type_hints(handler)
except Exception:
type_hints = {}
names: list[str] = []
for parameter in signature.parameters.values():
if parameter.kind not in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
continue
if is_framework_injected_parameter(
parameter.name, type_hints.get(parameter.name)
):
continue
names.append(parameter.name)
return names
def _framework_injected_types() -> tuple[type[Any], ...]:
from .clients.llm import LLMResponse
from .context import Context
@@ -52,4 +86,4 @@ def _framework_injected_types() -> tuple[type[Any], ...]:
)
__all__ = ["is_framework_injected_parameter"]
__all__ = ["is_framework_injected_parameter", "legacy_arg_parameter_names"]

View File

@@ -35,7 +35,7 @@ from .._command_model import (
parse_command_model_remainder,
resolve_command_model_param,
)
from .._injected_params import is_framework_injected_parameter
from .._injected_params import legacy_arg_parameter_names
from .._invocation_context import caller_plugin_scope
from .._plugin_logger import PluginLogger
from .._star_runtime import bind_star_runtime
@@ -333,9 +333,7 @@ class HandlerDispatcher:
return build_command_args(
[
ParamSpec(name=name, type="str")
for name in self._legacy_arg_parameter_names(
loaded.callable
)
for name in legacy_arg_parameter_names(loaded.callable)
],
remainder,
)
@@ -349,7 +347,7 @@ class HandlerDispatcher:
return build_regex_args(
[
ParamSpec(name=name, type="str")
for name in self._legacy_arg_parameter_names(loaded.callable)
for name in legacy_arg_parameter_names(loaded.callable)
],
match,
)
@@ -922,34 +920,6 @@ class HandlerDispatcher:
except Exception:
return None
@classmethod
def _legacy_arg_parameter_names(cls, handler) -> list[str]:
try:
signature = inspect.signature(handler)
except (TypeError, ValueError):
return []
try:
type_hints = get_type_hints(handler)
except Exception:
type_hints = {}
names: list[str] = []
for parameter in signature.parameters.values():
if parameter.kind not in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
continue
if cls._is_injected_parameter(
parameter.name, type_hints.get(parameter.name)
):
continue
names.append(parameter.name)
return names
@classmethod
def _is_injected_parameter(cls, name: str, annotation: Any) -> bool:
return is_framework_injected_parameter(name, annotation)
async def _handle_error(
self,
owner: Any,

View File

@@ -18,9 +18,8 @@ import inspect
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, get_type_hints
from typing import Any
from ._injected_params import is_framework_injected_parameter
from ._star_runtime import bind_star_runtime
from ._testing_support import (
InMemoryDB,
@@ -730,32 +729,6 @@ class PluginHarness:
return hook
return None
def _legacy_arg_parameter_names(self, handler) -> list[str]:
try:
signature = inspect.signature(handler)
except (TypeError, ValueError):
return []
try:
type_hints = get_type_hints(handler)
except Exception:
type_hints = {}
names: list[str] = []
for parameter in signature.parameters.values():
if parameter.kind not in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
continue
if self._is_injected_parameter(
parameter.name, type_hints.get(parameter.name)
):
continue
names.append(parameter.name)
return names
def _is_injected_parameter(self, name: str, annotation: Any) -> bool:
return is_framework_injected_parameter(name, annotation)
def _next_request_id(self, prefix: str) -> str:
self._request_counter += 1
return f"{prefix}_{self._request_counter:04d}"

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
import re
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
from astrbot_sdk.protocol.descriptors import ParamSpec
from astrbot_sdk.runtime._command_matching import (
build_command_args,
build_regex_args,
match_command_name,
split_command_remainder,
)
def test_match_command_name_trims_input_consistently() -> None:
assert match_command_name(" ping ", "ping") == ""
assert match_command_name(" ping hello world ", "ping") == "hello world"
assert match_command_name("pingpong", "ping") is None
def test_build_command_args_supports_quotes_and_greedy_tail() -> None:
param_specs = [
ParamSpec(name="name", type="str"),
ParamSpec(name="message", type="greedy_str"),
]
args = build_command_args(param_specs, '"alpha beta" "hello world" tail')
assert args == {"name": "alpha beta", "message": "hello world tail"}
def test_split_command_remainder_falls_back_on_invalid_quotes() -> None:
assert split_command_remainder('"unterminated quote test') == [
'"unterminated',
"quote",
"test",
]
def test_build_regex_args_preserves_named_and_positional_mapping() -> None:
param_specs = [
ParamSpec(name="first", type="str"),
ParamSpec(name="second", type="str"),
ParamSpec(name="third", type="str"),
]
match = re.search(r"(?P<second>\w+)-(\w+)-(\w+)", "named-positional-tail")
assert match is not None
assert build_regex_args(param_specs, match) == {
"second": "named",
"first": "named",
"third": "positional",
}

View File

@@ -0,0 +1,83 @@
from __future__ import annotations
import sys
from pathlib import Path
from types import SimpleNamespace
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
from pydantic import BaseModel
from astrbot_sdk._command_model import resolve_command_model_param
from astrbot_sdk._injected_params import (
is_framework_injected_parameter,
legacy_arg_parameter_names,
)
from astrbot_sdk.conversation import ConversationSession
from astrbot_sdk.schedule import ScheduleContext
from astrbot_sdk.protocol.descriptors import CommandTrigger, HandlerDescriptor
from astrbot_sdk.runtime.handler_dispatcher import HandlerDispatcher
from astrbot_sdk.runtime.loader import LoadedHandler, _build_param_specs
class _Payload(BaseModel):
name: str
def test_legacy_arg_parameter_names_excludes_injected_aliases() -> None:
def handler(
ctx,
conversation,
conv,
sched,
schedule,
name,
extra="fallback",
) -> None: ...
assert legacy_arg_parameter_names(handler) == ["name", "extra"]
def test_resolve_command_model_param_ignores_injected_aliases() -> None:
def handler(conversation, sched, payload: _Payload) -> None: ...
resolved = resolve_command_model_param(handler)
assert resolved is not None
assert resolved.name == "payload"
assert resolved.model_cls is _Payload
def test_is_framework_injected_parameter_supports_type_based_injection() -> None:
assert is_framework_injected_parameter("custom_conv", ConversationSession)
assert is_framework_injected_parameter("custom_schedule", ScheduleContext)
def test_loader_build_param_specs_excludes_injected_aliases() -> None:
def handler(conversation, schedule, name: str, count: int = 0) -> None: ...
specs = _build_param_specs(handler)
assert [spec.name for spec in specs] == ["name", "count"]
def test_handler_dispatcher_derive_args_skips_injected_aliases() -> None:
def handler(conversation, name, sched) -> None: ...
loaded = LoadedHandler(
descriptor=HandlerDescriptor(
id="plugin.handler",
trigger=CommandTrigger(command="ping"),
),
callable=handler,
owner=object(),
)
dispatcher = HandlerDispatcher(
plugin_id="plugin",
peer=SimpleNamespace(),
handlers=[loaded],
)
args = dispatcher._derive_args(loaded, SimpleNamespace(text="ping alice"))
assert args == {"name": "alice"}

View File

@@ -0,0 +1,101 @@
from __future__ import annotations
import sys
from pathlib import Path
from types import SimpleNamespace
import pytest
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
from astrbot_sdk.errors import AstrBotError
from astrbot_sdk.runtime.handler_dispatcher import HandlerDispatcher
from astrbot_sdk.star import Star
class _DummyEvent:
def __init__(self) -> None:
self.replies: list[str] = []
async def reply(self, message: str) -> None:
self.replies.append(message)
@pytest.mark.asyncio
async def test_handle_error_fallback_does_not_instantiate_star(
monkeypatch: pytest.MonkeyPatch,
) -> None:
async def _fake_default_on_error(error: Exception, event, ctx) -> None:
del ctx
await event.reply(str(error))
def _fail_init(self) -> None:
raise AssertionError("Star should not be instantiated for fallback on_error")
monkeypatch.setattr(Star, "default_on_error", staticmethod(_fake_default_on_error))
monkeypatch.setattr(Star, "__init__", _fail_init)
dispatcher = HandlerDispatcher(
plugin_id="plugin", peer=SimpleNamespace(), handlers=[]
)
event = _DummyEvent()
await dispatcher._handle_error(
object(),
RuntimeError("boom"),
event,
SimpleNamespace(),
)
assert event.replies == ["boom"]
@pytest.mark.asyncio
async def test_default_on_error_formats_astrbot_error_reply() -> None:
event = _DummyEvent()
error = AstrBotError.invalid_input(
"bad payload",
hint="check payload",
docs_url="https://example.com/docs",
details={"b": 2, "a": 1},
)
await Star.default_on_error(error, event, SimpleNamespace())
assert len(event.replies) == 1
assert "check payload" in event.replies[0]
assert "https://example.com/docs" in event.replies[0]
assert '"a": 1' in event.replies[0]
assert '"b": 2' in event.replies[0]
@pytest.mark.asyncio
async def test_default_on_error_replies_generic_message_for_unknown_errors() -> None:
event = _DummyEvent()
await Star.default_on_error(RuntimeError("boom"), event, SimpleNamespace())
assert len(event.replies) == 1
assert event.replies[0]
@pytest.mark.asyncio
async def test_on_error_does_not_dispatch_via_subclass_default_on_error() -> None:
class PluginWithShadowedDefault(Star):
async def default_on_error(self, error: Exception, event, ctx) -> None:
del error, event, ctx
raise AssertionError(
"Star.on_error should not virtual-dispatch default_on_error"
)
expected_event = _DummyEvent()
actual_event = _DummyEvent()
await Star.default_on_error(RuntimeError("boom"), expected_event, SimpleNamespace())
await PluginWithShadowedDefault().on_error(
RuntimeError("boom"),
actual_event,
SimpleNamespace(),
)
assert actual_event.replies == expected_event.replies