Source code for sql.sqldb

# Copyright 2025 Softwell S.r.l. - SPDX-License-Identifier: Apache-2.0
"""Async database manager with adapter pattern and table registration."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from .adapters import DbAdapter, get_adapter

if TYPE_CHECKING:
    from .table import Table


[docs] class SqlDb: """Async database manager with adapter pattern. Supports multiple database types via adapters: - SQLite: "/path/to/db.sqlite" or "sqlite:/path/to/db" - PostgreSQL: "postgresql://user:pass@host/db" Features: - Table class registration via add_table() - Table access via table(name) - Schema creation and verification - CRUD operations via adapter - Encryption key access via parent.encryption_key Usage: db = SqlDb("/data/mail.db", parent=proxy) await db.connect() db.add_table(TenantsTable) db.add_table(AccountsTable) await db.check_structure() tenant = await db.table('tenants').select_one(where={"id": "acme"}) await db.close() """
[docs] def __init__(self, connection_string: str, parent: Any = None): """Initialize database manager. Args: connection_string: Database connection string. parent: Parent object (e.g., proxy) that provides encryption_key. """ self.connection_string = connection_string self.parent = parent self.adapter: DbAdapter = get_adapter(connection_string) self.tables: dict[str, Table] = {}
@property def encryption_key(self) -> bytes | None: """Get encryption key from parent. Returns None if not configured.""" if self.parent is None: return None return getattr(self.parent, "encryption_key", None)
[docs] async def connect(self) -> None: """Connect to database.""" await self.adapter.connect()
[docs] async def close(self) -> None: """Close database connection.""" await self.adapter.close()
[docs] def add_table(self, table_class: type[Table]) -> Table: """Register and instantiate a table class. Args: table_class: Table manager class (must have name attribute). Returns: The instantiated table. """ if not hasattr(table_class, "name") or not table_class.name: raise ValueError(f"Table class {table_class.__name__} must define 'name'") instance = table_class(self) self.tables[instance.name] = instance return instance
[docs] def table(self, name: str) -> Table: """Get table instance by name. Args: name: Table name. Returns: Table instance. Raises: ValueError: If table not registered. """ if name not in self.tables: raise ValueError(f"Table '{name}' not registered. Use add_table() first.") return self.tables[name]
[docs] async def check_structure(self) -> None: """Create all registered tables if they don't exist.""" for table in self.tables.values(): await table.create_schema()
# ------------------------------------------------------------------------- # Direct adapter access # -------------------------------------------------------------------------
[docs] async def execute(self, query: str, params: dict[str, Any] | None = None) -> int: """Execute raw query, return affected row count.""" return await self.adapter.execute(query, params)
[docs] async def fetch_one( self, query: str, params: dict[str, Any] | None = None ) -> dict[str, Any] | None: """Execute raw query, return single row.""" return await self.adapter.fetch_one(query, params)
[docs] async def fetch_all( self, query: str, params: dict[str, Any] | None = None ) -> list[dict[str, Any]]: """Execute raw query, return all rows.""" return await self.adapter.fetch_all(query, params)
[docs] async def commit(self) -> None: """Commit transaction.""" await self.adapter.commit()
[docs] async def rollback(self) -> None: """Rollback transaction.""" await self.adapter.rollback()
__all__ = ["SqlDb"]