# Source code for core.mail_proxy.interface.endpoint_base

# Copyright 2025 Softwell S.r.l. - SPDX-License-Identifier: Apache-2.0
"""Base class for endpoint introspection and command dispatch.

This module provides the foundation for automatic API/CLI generation
from endpoint classes via method introspection.

Components:
    POST: Decorator to mark methods as HTTP POST.
    BaseEndpoint: Base class with introspection capabilities.
    EndpointDispatcher: Routes commands to endpoint methods.

Example:
    Define an endpoint::

        from core.mail_proxy.interface.endpoint_base import BaseEndpoint, POST

        class MyEndpoint(BaseEndpoint):
            name = "items"

            async def list(self, active_only: bool = False) -> list[dict]:
                \"\"\"List all items.\"\"\"
                return await self.table.list_all(active_only=active_only)

            @POST
            async def add(self, id: str, name: str) -> dict:
                \"\"\"Add a new item.\"\"\"
                return await self.table.add({"id": id, "name": name})

    Use with dispatcher::

        dispatcher = EndpointDispatcher(db)
        result = await dispatcher.dispatch("addMessages", {"messages": [...]})

Note:
    BaseEndpoint.discover() scans CE and EE packages for endpoint classes
    and composes them when both exist for an entity.
"""

from __future__ import annotations

import importlib
import inspect
import pkgutil
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, get_origin, get_type_hints

from pydantic import create_model

if TYPE_CHECKING:
    from sql import SqlDb

# Dotted package paths that BaseEndpoint.discover() scans for entity
# endpoint modules: the community edition (CE) package and the optional
# enterprise edition (EE) package.
_CE_ENTITIES_PACKAGE: str = "core.mail_proxy.entities"
_EE_ENTITIES_PACKAGE: str = "enterprise.mail_proxy.entities"


def POST(method: Callable) -> Callable:
    """Flag an endpoint coroutine as an HTTP POST operation.

    A POST-flagged method receives its parameters through the JSON request
    body rather than through query-string parameters.

    Args:
        method: The async endpoint method to flag.

    Returns:
        The very same function object, now carrying ``_http_post = True``.

    Example:
        ::

            @POST
            async def add(self, id: str, data: dict) -> dict:
                \"\"\"Add item with complex data.\"\"\"
                ...
    """
    setattr(method, "_http_post", True)
    return method
[docs] class BaseEndpoint: """Base class for all endpoints with introspection capabilities. Provides method discovery, HTTP method inference, and Pydantic model generation from signatures for automatic API/CLI generation. Attributes: name: Endpoint name used in URL paths and CLI groups. table: Database table instance for operations. Example: Create a custom endpoint:: class ItemEndpoint(BaseEndpoint): name = "items" async def get(self, item_id: str) -> dict: item = await self.table.get(item_id) if not item: raise ValueError(f"Item '{item_id}' not found") return item @POST async def add(self, id: str, name: str) -> dict: return await self.table.add({"id": id, "name": name}) # Register with FastAPI endpoint = ItemEndpoint(db.table("items")) register_endpoint(app, endpoint) """ name: str = ""
[docs] def __init__(self, table: Any): """Initialize endpoint with table reference. Args: table: Database table instance for operations. """ self.table = table
[docs] def get_methods(self) -> list[tuple[str, Callable]]: """Return all public async methods for API/CLI generation. Returns: List of (method_name, method) tuples for all public async methods (excluding those starting with underscore). """ methods = [] for method_name in dir(self): if method_name.startswith("_"): continue method = getattr(self, method_name) if callable(method) and inspect.iscoroutinefunction(method): methods.append((method_name, method)) return methods
[docs] def get_http_method(self, method_name: str) -> str: """Determine HTTP method for an endpoint method. Args: method_name: Name of the endpoint method. Returns: "POST" if decorated with @POST, otherwise "GET". """ method = getattr(self, method_name) if getattr(method, "_http_post", False): return "POST" return "GET"
[docs] def create_request_model(self, method_name: str) -> type: """Create Pydantic model from method signature. Used by API layer to validate and parse request bodies. Args: method_name: Name of the method to introspect. Returns: Dynamically created Pydantic model class. """ method = getattr(self, method_name) sig = inspect.signature(method) try: hints = get_type_hints(method) except Exception: hints = {} fields = {} for param_name, param in sig.parameters.items(): if param_name == "self": continue annotation = hints.get(param_name, param.annotation) if annotation is inspect.Parameter.empty: annotation = Any fields[param_name] = self._annotation_to_field(annotation, param.default) model_name = f"{method_name.title().replace('_', '')}Request" return create_model(model_name, **fields)
[docs] def is_simple_params(self, method_name: str) -> bool: """Check if method has only simple params suitable for query string. Args: method_name: Name of the method to check. Returns: False if any parameter is list or dict (including Optional[list]). """ method = getattr(self, method_name) try: hints = get_type_hints(method) except Exception: hints = {} sig = inspect.signature(method) for param_name, param in sig.parameters.items(): if param_name == "self": continue ann = hints.get(param_name, param.annotation) if self._is_complex_type(ann): return False return True
def _is_complex_type(self, ann: Any) -> bool: """Check if annotation is a complex type (list, dict, or contains them).""" import types from typing import Union, get_args if ann in (list, dict): return True origin = get_origin(ann) if origin in (list, dict): return True if origin is Union or isinstance(origin, type) and origin is types.UnionType: for arg in get_args(ann): if arg is type(None): continue if self._is_complex_type(arg): return True if type(ann).__name__ == "UnionType": for arg in get_args(ann): if arg is type(None): continue if self._is_complex_type(arg): return True return False
[docs] def count_params(self, method_name: str) -> int: """Count non-self parameters for a method. Args: method_name: Name of the method. Returns: Number of parameters excluding 'self'. """ method = getattr(self, method_name) sig = inspect.signature(method) return sum(1 for p in sig.parameters if p != "self")
def _annotation_to_field(self, annotation: Any, default: Any) -> tuple[Any, Any]: """Convert Python annotation to Pydantic field tuple (type, default).""" if default is inspect.Parameter.empty: return (annotation, ...) # Required field return (annotation, default)
[docs] @classmethod def discover(cls) -> list[type[BaseEndpoint]]: """Autodiscover all endpoint classes from entities/ directories. Scans CE and EE packages for endpoint.py and endpoint_ee.py modules. When both exist for an entity, composes them with EE mixin first. Returns: List of endpoint classes ready for instantiation. Example: :: for endpoint_class in BaseEndpoint.discover(): table = db.table(endpoint_class.name) endpoint = endpoint_class(table) register_endpoint(app, endpoint) """ ce_modules = cls._find_entity_modules(_CE_ENTITIES_PACKAGE, "endpoint") ee_modules = cls._find_entity_modules(_EE_ENTITIES_PACKAGE, "endpoint_ee") endpoints: list[type[BaseEndpoint]] = [] for entity_name, ce_module in ce_modules.items(): ce_class = cls._get_class_from_module(ce_module, "Endpoint") if not ce_class: continue ee_module = ee_modules.get(entity_name) if ee_module: ee_mixin = cls._get_ee_mixin_from_module(ee_module, "_EE") if ee_mixin: composed_class = type( ce_class.__name__, (ee_mixin, ce_class), {"__module__": ce_class.__module__} ) endpoints.append(composed_class) continue endpoints.append(ce_class) return endpoints
@classmethod def _find_entity_modules(cls, base_package: str, module_name: str) -> dict[str, Any]: """Find entity modules in a package.""" result: dict[str, Any] = {} try: package = importlib.import_module(base_package) except ImportError: return result package_path = getattr(package, "__path__", None) if not package_path: return result for _, name, is_pkg in pkgutil.iter_modules(package_path): if not is_pkg: continue full_module_name = f"{base_package}.{name}.{module_name}" try: module = importlib.import_module(full_module_name) result[name] = module except ImportError: pass return result @classmethod def _get_class_from_module(cls, module: Any, class_suffix: str) -> type | None: """Extract a class from module by suffix pattern.""" for attr_name in dir(module): if attr_name.startswith("_"): continue obj = getattr(module, attr_name) if isinstance(obj, type) and attr_name.endswith(class_suffix): if "_EE" in attr_name or "Mixin" in attr_name: continue if attr_name in ("BaseEndpoint", "Endpoint"): continue if not hasattr(obj, "name"): continue return obj return None @classmethod def _get_ee_mixin_from_module(cls, module: Any, class_suffix: str) -> type | None: """Extract an EE mixin class from module.""" for name in dir(module): if name.startswith("_"): continue obj = getattr(module, name) if isinstance(obj, type) and name.endswith(class_suffix): return obj return None
[docs] class EndpointDispatcher: """Dispatches commands to appropriate endpoint methods. Centralizes command routing, mapping legacy camelCase commands to endpoint.method pairs for backward compatibility. Attributes: COMMAND_MAP: Maps command names to (endpoint_name, method_name). db: Database instance for table access. proxy: Optional MailProxy for operations needing runtime state. Example: Use dispatcher for legacy API compatibility:: dispatcher = EndpointDispatcher(db, proxy=proxy) # Dispatch legacy command result = await dispatcher.dispatch( "addMessages", {"messages": [{"to": "user@example.com"}]} ) # Returns: {"ok": True, "count": 1} # Direct endpoint access messages_endpoint = dispatcher.get_endpoint("messages") await messages_endpoint.add_batch(messages=[...]) """ COMMAND_MAP: dict[str, tuple[str, str]] = { # Messages "addMessages": ("messages", "add_batch"), "deleteMessages": ("messages", "delete_batch"), "listMessages": ("messages", "list"), "cleanupMessages": ("messages", "cleanup"), # Accounts "addAccount": ("accounts", "add"), "listAccounts": ("accounts", "list"), "deleteAccount": ("accounts", "delete"), # Tenants "addTenant": ("tenants", "add"), "getTenant": ("tenants", "get"), "listTenants": ("tenants", "list"), "updateTenant": ("tenants", "update"), "deleteTenant": ("tenants", "delete"), "suspend": ("tenants", "suspend_batch"), "activate": ("tenants", "activate_batch"), # Instance "getInstance": ("instance", "get"), "updateInstance": ("instance", "update"), "listTenantsSyncStatus": ("instance", "get_sync_status"), } # Result wrapping rules for legacy API compatibility _RESULT_WRAP_KEYS: dict[str, str] = { "listTenants": "tenants", "listAccounts": "accounts", "listMessages": "messages", }
[docs] def __init__(self, db: SqlDb, proxy: Any = None): """Initialize dispatcher with database and optional proxy. Args: db: MailProxyDb instance for table access. proxy: Optional MailProxy for operations needing runtime state. """ self.db = db self.proxy = proxy self._endpoints: dict[str, BaseEndpoint] = {}
def _get_endpoint(self, endpoint_name: str) -> BaseEndpoint: """Get or create endpoint instance by name.""" if endpoint_name not in self._endpoints: self._endpoints[endpoint_name] = self._create_endpoint(endpoint_name) return self._endpoints[endpoint_name] def _create_endpoint(self, endpoint_name: str) -> BaseEndpoint: """Create endpoint instance for the given name.""" from ..entities.account import AccountEndpoint from ..entities.instance import InstanceEndpoint from ..entities.message import MessageEndpoint from ..entities.tenant import TenantEndpoint table = self.db.table(endpoint_name) match endpoint_name: case "messages": return MessageEndpoint(table) case "accounts": return AccountEndpoint(table) case "tenants": return TenantEndpoint(table) case "instance": return InstanceEndpoint(table, proxy=self.proxy) case _: raise ValueError(f"Unknown endpoint: {endpoint_name}")
[docs] async def dispatch(self, cmd: str, payload: dict[str, Any]) -> dict[str, Any]: """Dispatch a command to the appropriate endpoint method. Args: cmd: Command name (e.g., "addMessages", "listTenants"). payload: Command parameters as dict. Returns: Result dict in legacy format {"ok": True/False, ...}. Example: :: result = await dispatcher.dispatch( "addTenant", {"id": "acme", "name": "Acme Corp"} ) if result["ok"]: print(f"Created tenant: {result['id']}") """ if cmd not in self.COMMAND_MAP: return {"ok": False, "error": f"unknown command: {cmd}"} validation_error = self._validate_payload(cmd, payload) if validation_error: return {"ok": False, "error": validation_error} endpoint_name, method_name = self.COMMAND_MAP[cmd] endpoint = self._get_endpoint(endpoint_name) method = getattr(endpoint, method_name) mapped_payload = self._map_payload(cmd, payload) try: result = await method(**mapped_payload) return self._wrap_result(cmd, result) except ValueError as e: return {"ok": False, "error": str(e)} except Exception as e: return {"ok": False, "error": f"Internal error: {e}"}
def _validate_payload(self, cmd: str, payload: dict[str, Any]) -> str | None: """Validate payload before dispatch. Returns error message or None.""" if cmd == "updateTenant": if "id" not in payload: return "tenant id required" return None def _wrap_result(self, cmd: str, result: Any) -> dict[str, Any]: """Wrap endpoint result in legacy API format.""" if isinstance(result, list): key = self._RESULT_WRAP_KEYS.get(cmd, "items") return {"ok": True, key: result} if isinstance(result, bool): if result: return {"ok": True} return {"ok": False, "error": "not found"} if result is None: return {"ok": False, "error": "not found"} if isinstance(result, dict): if "ok" not in result: result["ok"] = True return result return {"ok": True, "value": result} def _map_payload(self, cmd: str, payload: dict[str, Any]) -> dict[str, Any]: """Map legacy payload keys to endpoint method parameters.""" result = dict(payload) if cmd in ("getTenant", "deleteTenant", "updateTenant"): if "id" in result: result["tenant_id"] = result.pop("id") elif cmd == "deleteAccount": if "id" in result: result["account_id"] = result.pop("id") if cmd == "listMessages": result.setdefault("active_only", False) result.setdefault("include_history", False) return result
[docs] def get_endpoint(self, name: str) -> BaseEndpoint: """Get endpoint by name for direct access. Args: name: Endpoint name (e.g., "messages", "accounts"). Returns: BaseEndpoint instance for direct method calls. """ return self._get_endpoint(name)
# Public API of this module (what ``from ... import *`` exposes).
__all__ = [
    "BaseEndpoint",
    "EndpointDispatcher",
    "POST",
]