mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-02 10:40:15 +08:00
feat: refactor injected parameter handling and introduce is_framework_injected_parameter utility
This commit is contained in:
@@ -6,6 +6,7 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._injected_params import is_framework_injected_parameter
|
||||
from ._typing_utils import unwrap_optional
|
||||
from .errors import AstrBotError
|
||||
from .runtime._command_matching import split_command_remainder
|
||||
@@ -211,26 +212,7 @@ def _command_parse_error(message: str) -> AstrBotError:
|
||||
|
||||
|
||||
def _is_injected_parameter(name: str, annotation: Any) -> bool:
|
||||
if name in {"event", "ctx", "context", "sched", "schedule", "conversation", "conv"}:
|
||||
return True
|
||||
normalized, _is_optional = unwrap_optional(annotation)
|
||||
if normalized is None:
|
||||
return False
|
||||
try:
|
||||
from .context import Context
|
||||
from .conversation import ConversationSession
|
||||
from .events import MessageEvent
|
||||
from .schedule import ScheduleContext
|
||||
except Exception:
|
||||
return False
|
||||
if normalized in {Context, MessageEvent, ScheduleContext, ConversationSession}:
|
||||
return True
|
||||
if isinstance(normalized, type):
|
||||
return issubclass(
|
||||
normalized,
|
||||
(Context, MessageEvent, ScheduleContext, ConversationSession),
|
||||
)
|
||||
return False
|
||||
return is_framework_injected_parameter(name, annotation)
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
55
src/astrbot_sdk/_injected_params.py
Normal file
55
src/astrbot_sdk/_injected_params.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ._typing_utils import unwrap_optional
|
||||
|
||||
_INJECTED_PARAMETER_NAMES = {
|
||||
"event",
|
||||
"ctx",
|
||||
"context",
|
||||
"sched",
|
||||
"schedule",
|
||||
"conversation",
|
||||
"conv",
|
||||
}
|
||||
|
||||
|
||||
def is_framework_injected_parameter(name: str, annotation: Any) -> bool:
|
||||
if name in _INJECTED_PARAMETER_NAMES:
|
||||
return True
|
||||
normalized, _is_optional = unwrap_optional(annotation)
|
||||
if normalized is None:
|
||||
return False
|
||||
try:
|
||||
injected_types = _framework_injected_types()
|
||||
except Exception:
|
||||
return False
|
||||
if normalized in injected_types:
|
||||
return True
|
||||
if isinstance(normalized, type):
|
||||
return issubclass(normalized, injected_types)
|
||||
return False
|
||||
|
||||
|
||||
def _framework_injected_types() -> tuple[type[Any], ...]:
|
||||
from .clients.llm import LLMResponse
|
||||
from .context import Context
|
||||
from .conversation import ConversationSession
|
||||
from .events import MessageEvent
|
||||
from .llm.entities import ProviderRequest
|
||||
from .message_result import MessageEventResult
|
||||
from .schedule import ScheduleContext
|
||||
|
||||
return (
|
||||
Context,
|
||||
MessageEvent,
|
||||
ScheduleContext,
|
||||
ConversationSession,
|
||||
ProviderRequest,
|
||||
LLMResponse,
|
||||
MessageEventResult,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["is_framework_injected_parameter"]
|
||||
@@ -62,4 +62,3 @@ class LLMCapabilityMixin(CapabilityRouterBridgeBase):
|
||||
"text": "".join(item.get("text", "") for item in chunks)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -228,4 +228,3 @@ class PlatformCapabilityMixin(CapabilityRouterBridgeBase):
|
||||
),
|
||||
call_handler=self._platform_manager_get_stats,
|
||||
)
|
||||
|
||||
|
||||
@@ -451,4 +451,3 @@ class SystemCapabilityMixin(CapabilityRouterBridgeBase):
|
||||
}
|
||||
)
|
||||
return {"supported": True}
|
||||
|
||||
|
||||
@@ -18,10 +18,10 @@ import inspect
|
||||
import typing
|
||||
from typing import Any, Literal, TypeAlias, cast
|
||||
|
||||
from .._injected_params import is_framework_injected_parameter
|
||||
from .._typing_utils import unwrap_optional
|
||||
from ..decorators import get_capability_meta, get_handler_meta
|
||||
from ..protocol.descriptors import ParamSpec
|
||||
from ..schedule import ScheduleContext
|
||||
from ..types import GreedyStr
|
||||
|
||||
ParamTypeName: TypeAlias = Literal[
|
||||
@@ -31,19 +31,7 @@ OptionalInnerType: TypeAlias = Literal["str", "int", "float", "bool"] | None
|
||||
|
||||
|
||||
def is_injected_parameter(annotation: Any, parameter_name: str) -> bool:
|
||||
if parameter_name in {"event", "ctx", "context", "sched", "schedule"}:
|
||||
return True
|
||||
normalized, _is_optional = unwrap_optional(annotation)
|
||||
if normalized is None:
|
||||
return False
|
||||
if normalized in {ScheduleContext}:
|
||||
return True
|
||||
if isinstance(normalized, type):
|
||||
from ..context import Context
|
||||
from ..events import MessageEvent
|
||||
|
||||
return issubclass(normalized, (Context, MessageEvent, ScheduleContext))
|
||||
return False
|
||||
return is_framework_injected_parameter(parameter_name, annotation)
|
||||
|
||||
|
||||
def param_type_name(annotation: Any) -> tuple[ParamTypeName, OptionalInnerType, bool]:
|
||||
|
||||
@@ -35,6 +35,7 @@ from .._command_model import (
|
||||
parse_command_model_remainder,
|
||||
resolve_command_model_param,
|
||||
)
|
||||
from .._injected_params import is_framework_injected_parameter
|
||||
from .._invocation_context import caller_plugin_scope
|
||||
from .._plugin_logger import PluginLogger
|
||||
from .._star_runtime import bind_star_runtime
|
||||
@@ -947,19 +948,7 @@ class HandlerDispatcher:
|
||||
|
||||
@classmethod
|
||||
def _is_injected_parameter(cls, name: str, annotation: Any) -> bool:
|
||||
if name in {"event", "ctx", "context", "conversation", "conv"}:
|
||||
return True
|
||||
normalized, _is_optional = unwrap_optional(annotation)
|
||||
if normalized is None:
|
||||
return False
|
||||
if normalized in {Context, MessageEvent, ConversationSession}:
|
||||
return True
|
||||
if isinstance(normalized, type) and issubclass(
|
||||
normalized,
|
||||
(Context, MessageEvent, ConversationSession),
|
||||
):
|
||||
return True
|
||||
return False
|
||||
return is_framework_injected_parameter(name, annotation)
|
||||
|
||||
async def _handle_error(
|
||||
self,
|
||||
|
||||
@@ -69,6 +69,7 @@ from typing import Any, Literal, TypeAlias, cast
|
||||
import yaml
|
||||
|
||||
from .._command_model import resolve_command_model_param
|
||||
from .._injected_params import is_framework_injected_parameter
|
||||
from .._typing_utils import unwrap_optional
|
||||
from ..decorators import (
|
||||
ConversationMeta,
|
||||
@@ -86,7 +87,6 @@ from ..protocol.descriptors import (
|
||||
ParamSpec,
|
||||
ScheduleTrigger,
|
||||
)
|
||||
from ..schedule import ScheduleContext
|
||||
from ..types import GreedyStr
|
||||
from .environment_groups import (
|
||||
EnvironmentGroup,
|
||||
@@ -223,31 +223,7 @@ def _iter_discoverable_names(instance: Any) -> list[str]:
|
||||
|
||||
|
||||
def _is_injected_parameter(annotation: Any, parameter_name: str) -> bool:
|
||||
if parameter_name in {
|
||||
"event",
|
||||
"ctx",
|
||||
"context",
|
||||
"sched",
|
||||
"schedule",
|
||||
"conversation",
|
||||
"conv",
|
||||
}:
|
||||
return True
|
||||
normalized, _is_optional = unwrap_optional(annotation)
|
||||
if normalized is None:
|
||||
return False
|
||||
if normalized in {ScheduleContext}:
|
||||
return True
|
||||
if isinstance(normalized, type):
|
||||
from ..context import Context
|
||||
from ..conversation import ConversationSession
|
||||
from ..events import MessageEvent
|
||||
|
||||
return issubclass(
|
||||
normalized,
|
||||
(Context, MessageEvent, ScheduleContext, ConversationSession),
|
||||
)
|
||||
return False
|
||||
return is_framework_injected_parameter(parameter_name, annotation)
|
||||
|
||||
|
||||
def _param_type_name(annotation: Any) -> tuple[ParamTypeName, OptionalInnerType, bool]:
|
||||
|
||||
@@ -20,6 +20,7 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, get_type_hints
|
||||
|
||||
from ._injected_params import is_framework_injected_parameter
|
||||
from ._star_runtime import bind_star_runtime
|
||||
from ._testing_support import (
|
||||
InMemoryDB,
|
||||
@@ -33,7 +34,6 @@ from ._testing_support import (
|
||||
RecordedSend,
|
||||
StdoutPlatformSink,
|
||||
)
|
||||
from ._typing_utils import unwrap_optional
|
||||
from .context import CancelToken
|
||||
from .context import Context as RuntimeContext
|
||||
from .errors import AstrBotError
|
||||
@@ -754,20 +754,7 @@ class PluginHarness:
|
||||
return names
|
||||
|
||||
def _is_injected_parameter(self, name: str, annotation: Any) -> bool:
|
||||
if name in {"event", "ctx", "context"}:
|
||||
return True
|
||||
normalized, _is_optional = unwrap_optional(annotation)
|
||||
if normalized is None:
|
||||
return False
|
||||
if normalized is RuntimeContext:
|
||||
return True
|
||||
if normalized is MessageEvent:
|
||||
return True
|
||||
if isinstance(normalized, type) and issubclass(
|
||||
normalized, (RuntimeContext, MessageEvent)
|
||||
):
|
||||
return True
|
||||
return False
|
||||
return is_framework_injected_parameter(name, annotation)
|
||||
|
||||
def _next_request_id(self, prefix: str) -> str:
|
||||
self._request_counter += 1
|
||||
|
||||
Reference in New Issue
Block a user