mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-03 11:10:14 +08:00
refactor(injection): centralize legacy injected parameter filtering
This commit is contained in:
@@ -1,7 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from typing import get_type_hints
|
||||
except ImportError: # pragma: no cover
|
||||
get_type_hints = None
|
||||
|
||||
from ._typing_utils import unwrap_optional
|
||||
|
||||
_INJECTED_PARAMETER_NAMES = {
|
||||
@@ -32,6 +38,34 @@ def is_framework_injected_parameter(name: str, annotation: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def legacy_arg_parameter_names(handler: Any) -> list[str]:
|
||||
try:
|
||||
signature = inspect.signature(handler)
|
||||
except (TypeError, ValueError):
|
||||
return []
|
||||
try:
|
||||
if get_type_hints is None:
|
||||
type_hints = {}
|
||||
else:
|
||||
type_hints = get_type_hints(handler)
|
||||
except Exception:
|
||||
type_hints = {}
|
||||
|
||||
names: list[str] = []
|
||||
for parameter in signature.parameters.values():
|
||||
if parameter.kind not in (
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
):
|
||||
continue
|
||||
if is_framework_injected_parameter(
|
||||
parameter.name, type_hints.get(parameter.name)
|
||||
):
|
||||
continue
|
||||
names.append(parameter.name)
|
||||
return names
|
||||
|
||||
|
||||
def _framework_injected_types() -> tuple[type[Any], ...]:
|
||||
from .clients.llm import LLMResponse
|
||||
from .context import Context
|
||||
@@ -52,4 +86,4 @@ def _framework_injected_types() -> tuple[type[Any], ...]:
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["is_framework_injected_parameter"]
|
||||
__all__ = ["is_framework_injected_parameter", "legacy_arg_parameter_names"]
|
||||
|
||||
@@ -35,7 +35,7 @@ from .._command_model import (
|
||||
parse_command_model_remainder,
|
||||
resolve_command_model_param,
|
||||
)
|
||||
from .._injected_params import is_framework_injected_parameter
|
||||
from .._injected_params import legacy_arg_parameter_names
|
||||
from .._invocation_context import caller_plugin_scope
|
||||
from .._plugin_logger import PluginLogger
|
||||
from .._star_runtime import bind_star_runtime
|
||||
@@ -333,9 +333,7 @@ class HandlerDispatcher:
|
||||
return build_command_args(
|
||||
[
|
||||
ParamSpec(name=name, type="str")
|
||||
for name in self._legacy_arg_parameter_names(
|
||||
loaded.callable
|
||||
)
|
||||
for name in legacy_arg_parameter_names(loaded.callable)
|
||||
],
|
||||
remainder,
|
||||
)
|
||||
@@ -349,7 +347,7 @@ class HandlerDispatcher:
|
||||
return build_regex_args(
|
||||
[
|
||||
ParamSpec(name=name, type="str")
|
||||
for name in self._legacy_arg_parameter_names(loaded.callable)
|
||||
for name in legacy_arg_parameter_names(loaded.callable)
|
||||
],
|
||||
match,
|
||||
)
|
||||
@@ -922,34 +920,6 @@ class HandlerDispatcher:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _legacy_arg_parameter_names(cls, handler) -> list[str]:
|
||||
try:
|
||||
signature = inspect.signature(handler)
|
||||
except (TypeError, ValueError):
|
||||
return []
|
||||
try:
|
||||
type_hints = get_type_hints(handler)
|
||||
except Exception:
|
||||
type_hints = {}
|
||||
names: list[str] = []
|
||||
for parameter in signature.parameters.values():
|
||||
if parameter.kind not in (
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
):
|
||||
continue
|
||||
if cls._is_injected_parameter(
|
||||
parameter.name, type_hints.get(parameter.name)
|
||||
):
|
||||
continue
|
||||
names.append(parameter.name)
|
||||
return names
|
||||
|
||||
@classmethod
|
||||
def _is_injected_parameter(cls, name: str, annotation: Any) -> bool:
|
||||
return is_framework_injected_parameter(name, annotation)
|
||||
|
||||
async def _handle_error(
|
||||
self,
|
||||
owner: Any,
|
||||
|
||||
@@ -18,9 +18,8 @@ import inspect
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, get_type_hints
|
||||
from typing import Any
|
||||
|
||||
from ._injected_params import is_framework_injected_parameter
|
||||
from ._star_runtime import bind_star_runtime
|
||||
from ._testing_support import (
|
||||
InMemoryDB,
|
||||
@@ -730,32 +729,6 @@ class PluginHarness:
|
||||
return hook
|
||||
return None
|
||||
|
||||
def _legacy_arg_parameter_names(self, handler) -> list[str]:
|
||||
try:
|
||||
signature = inspect.signature(handler)
|
||||
except (TypeError, ValueError):
|
||||
return []
|
||||
try:
|
||||
type_hints = get_type_hints(handler)
|
||||
except Exception:
|
||||
type_hints = {}
|
||||
names: list[str] = []
|
||||
for parameter in signature.parameters.values():
|
||||
if parameter.kind not in (
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
):
|
||||
continue
|
||||
if self._is_injected_parameter(
|
||||
parameter.name, type_hints.get(parameter.name)
|
||||
):
|
||||
continue
|
||||
names.append(parameter.name)
|
||||
return names
|
||||
|
||||
def _is_injected_parameter(self, name: str, annotation: Any) -> bool:
|
||||
return is_framework_injected_parameter(name, annotation)
|
||||
|
||||
def _next_request_id(self, prefix: str) -> str:
|
||||
self._request_counter += 1
|
||||
return f"{prefix}_{self._request_counter:04d}"
|
||||
|
||||
56
tests/test_command_matching.py
Normal file
56
tests/test_command_matching.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
|
||||
|
||||
from astrbot_sdk.protocol.descriptors import ParamSpec
|
||||
from astrbot_sdk.runtime._command_matching import (
|
||||
build_command_args,
|
||||
build_regex_args,
|
||||
match_command_name,
|
||||
split_command_remainder,
|
||||
)
|
||||
|
||||
|
||||
def test_match_command_name_trims_input_consistently() -> None:
|
||||
assert match_command_name(" ping ", "ping") == ""
|
||||
assert match_command_name(" ping hello world ", "ping") == "hello world"
|
||||
assert match_command_name("pingpong", "ping") is None
|
||||
|
||||
|
||||
def test_build_command_args_supports_quotes_and_greedy_tail() -> None:
|
||||
param_specs = [
|
||||
ParamSpec(name="name", type="str"),
|
||||
ParamSpec(name="message", type="greedy_str"),
|
||||
]
|
||||
|
||||
args = build_command_args(param_specs, '"alpha beta" "hello world" tail')
|
||||
|
||||
assert args == {"name": "alpha beta", "message": "hello world tail"}
|
||||
|
||||
|
||||
def test_split_command_remainder_falls_back_on_invalid_quotes() -> None:
|
||||
assert split_command_remainder('"unterminated quote test') == [
|
||||
'"unterminated',
|
||||
"quote",
|
||||
"test",
|
||||
]
|
||||
|
||||
|
||||
def test_build_regex_args_preserves_named_and_positional_mapping() -> None:
|
||||
param_specs = [
|
||||
ParamSpec(name="first", type="str"),
|
||||
ParamSpec(name="second", type="str"),
|
||||
ParamSpec(name="third", type="str"),
|
||||
]
|
||||
match = re.search(r"(?P<second>\w+)-(\w+)-(\w+)", "named-positional-tail")
|
||||
|
||||
assert match is not None
|
||||
assert build_regex_args(param_specs, match) == {
|
||||
"second": "named",
|
||||
"first": "named",
|
||||
"third": "positional",
|
||||
}
|
||||
83
tests/test_injected_params.py
Normal file
83
tests/test_injected_params.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from astrbot_sdk._command_model import resolve_command_model_param
|
||||
from astrbot_sdk._injected_params import (
|
||||
is_framework_injected_parameter,
|
||||
legacy_arg_parameter_names,
|
||||
)
|
||||
from astrbot_sdk.conversation import ConversationSession
|
||||
from astrbot_sdk.schedule import ScheduleContext
|
||||
from astrbot_sdk.protocol.descriptors import CommandTrigger, HandlerDescriptor
|
||||
from astrbot_sdk.runtime.handler_dispatcher import HandlerDispatcher
|
||||
from astrbot_sdk.runtime.loader import LoadedHandler, _build_param_specs
|
||||
|
||||
|
||||
class _Payload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
def test_legacy_arg_parameter_names_excludes_injected_aliases() -> None:
|
||||
def handler(
|
||||
ctx,
|
||||
conversation,
|
||||
conv,
|
||||
sched,
|
||||
schedule,
|
||||
name,
|
||||
extra="fallback",
|
||||
) -> None: ...
|
||||
|
||||
assert legacy_arg_parameter_names(handler) == ["name", "extra"]
|
||||
|
||||
|
||||
def test_resolve_command_model_param_ignores_injected_aliases() -> None:
|
||||
def handler(conversation, sched, payload: _Payload) -> None: ...
|
||||
|
||||
resolved = resolve_command_model_param(handler)
|
||||
|
||||
assert resolved is not None
|
||||
assert resolved.name == "payload"
|
||||
assert resolved.model_cls is _Payload
|
||||
|
||||
|
||||
def test_is_framework_injected_parameter_supports_type_based_injection() -> None:
|
||||
assert is_framework_injected_parameter("custom_conv", ConversationSession)
|
||||
assert is_framework_injected_parameter("custom_schedule", ScheduleContext)
|
||||
|
||||
|
||||
def test_loader_build_param_specs_excludes_injected_aliases() -> None:
|
||||
def handler(conversation, schedule, name: str, count: int = 0) -> None: ...
|
||||
|
||||
specs = _build_param_specs(handler)
|
||||
|
||||
assert [spec.name for spec in specs] == ["name", "count"]
|
||||
|
||||
|
||||
def test_handler_dispatcher_derive_args_skips_injected_aliases() -> None:
|
||||
def handler(conversation, name, sched) -> None: ...
|
||||
|
||||
loaded = LoadedHandler(
|
||||
descriptor=HandlerDescriptor(
|
||||
id="plugin.handler",
|
||||
trigger=CommandTrigger(command="ping"),
|
||||
),
|
||||
callable=handler,
|
||||
owner=object(),
|
||||
)
|
||||
dispatcher = HandlerDispatcher(
|
||||
plugin_id="plugin",
|
||||
peer=SimpleNamespace(),
|
||||
handlers=[loaded],
|
||||
)
|
||||
|
||||
args = dispatcher._derive_args(loaded, SimpleNamespace(text="ping alice"))
|
||||
|
||||
assert args == {"name": "alice"}
|
||||
101
tests/test_star_on_error_fallback.py
Normal file
101
tests/test_star_on_error_fallback.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
|
||||
|
||||
from astrbot_sdk.errors import AstrBotError
|
||||
from astrbot_sdk.runtime.handler_dispatcher import HandlerDispatcher
|
||||
from astrbot_sdk.star import Star
|
||||
|
||||
|
||||
class _DummyEvent:
|
||||
def __init__(self) -> None:
|
||||
self.replies: list[str] = []
|
||||
|
||||
async def reply(self, message: str) -> None:
|
||||
self.replies.append(message)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_error_fallback_does_not_instantiate_star(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _fake_default_on_error(error: Exception, event, ctx) -> None:
|
||||
del ctx
|
||||
await event.reply(str(error))
|
||||
|
||||
def _fail_init(self) -> None:
|
||||
raise AssertionError("Star should not be instantiated for fallback on_error")
|
||||
|
||||
monkeypatch.setattr(Star, "default_on_error", staticmethod(_fake_default_on_error))
|
||||
monkeypatch.setattr(Star, "__init__", _fail_init)
|
||||
|
||||
dispatcher = HandlerDispatcher(
|
||||
plugin_id="plugin", peer=SimpleNamespace(), handlers=[]
|
||||
)
|
||||
event = _DummyEvent()
|
||||
|
||||
await dispatcher._handle_error(
|
||||
object(),
|
||||
RuntimeError("boom"),
|
||||
event,
|
||||
SimpleNamespace(),
|
||||
)
|
||||
|
||||
assert event.replies == ["boom"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_on_error_formats_astrbot_error_reply() -> None:
|
||||
event = _DummyEvent()
|
||||
error = AstrBotError.invalid_input(
|
||||
"bad payload",
|
||||
hint="check payload",
|
||||
docs_url="https://example.com/docs",
|
||||
details={"b": 2, "a": 1},
|
||||
)
|
||||
|
||||
await Star.default_on_error(error, event, SimpleNamespace())
|
||||
|
||||
assert len(event.replies) == 1
|
||||
assert "check payload" in event.replies[0]
|
||||
assert "https://example.com/docs" in event.replies[0]
|
||||
assert '"a": 1' in event.replies[0]
|
||||
assert '"b": 2' in event.replies[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_on_error_replies_generic_message_for_unknown_errors() -> None:
|
||||
event = _DummyEvent()
|
||||
|
||||
await Star.default_on_error(RuntimeError("boom"), event, SimpleNamespace())
|
||||
|
||||
assert len(event.replies) == 1
|
||||
assert event.replies[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_error_does_not_dispatch_via_subclass_default_on_error() -> None:
|
||||
class PluginWithShadowedDefault(Star):
|
||||
async def default_on_error(self, error: Exception, event, ctx) -> None:
|
||||
del error, event, ctx
|
||||
raise AssertionError(
|
||||
"Star.on_error should not virtual-dispatch default_on_error"
|
||||
)
|
||||
|
||||
expected_event = _DummyEvent()
|
||||
actual_event = _DummyEvent()
|
||||
|
||||
await Star.default_on_error(RuntimeError("boom"), expected_event, SimpleNamespace())
|
||||
await PluginWithShadowedDefault().on_error(
|
||||
RuntimeError("boom"),
|
||||
actual_event,
|
||||
SimpleNamespace(),
|
||||
)
|
||||
|
||||
assert actual_event.replies == expected_event.replies
|
||||
Reference in New Issue
Block a user