Source code for sql.adapters.base

# Copyright 2025 Softwell S.r.l. - SPDX-License-Identifier: Apache-2.0
"""Base adapter class for async database backends with CRUD helpers."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from collections.abc import Sequence


[docs] class DbAdapter(ABC): """Abstract base class for async database adapters with CRUD helpers. Provides a unified interface for SQLite and PostgreSQL with: - Connection management (connect, close) - Raw query execution (execute, fetch_one, fetch_all) - CRUD helpers (insert, select, update, delete) Subclasses must implement the abstract methods and set the placeholder attribute for parameter binding (`:name` for SQLite, `%(name)s` for PostgreSQL). """ placeholder: str = ":name" # Override in subclass
[docs] def pk_column(self, name: str) -> str: """Return SQL definition for autoincrement primary key column.""" return f'"{name}" INTEGER PRIMARY KEY'
[docs] def for_update_clause(self) -> str: """Return FOR UPDATE clause if supported, empty string otherwise.""" return ""
[docs] @abstractmethod async def connect(self) -> None: """Establish database connection.""" ...
[docs] @abstractmethod async def close(self) -> None: """Close database connection.""" ...
[docs] @abstractmethod async def execute(self, query: str, params: dict[str, Any] | None = None) -> int: """Execute query, return affected row count.""" ...
[docs] @abstractmethod async def execute_many(self, query: str, params_list: Sequence[dict[str, Any]]) -> int: """Execute query multiple times with different params (batch insert).""" ...
[docs] @abstractmethod async def fetch_one( self, query: str, params: dict[str, Any] | None = None ) -> dict[str, Any] | None: """Execute query, return single row as dict or None.""" ...
[docs] @abstractmethod async def fetch_all( self, query: str, params: dict[str, Any] | None = None ) -> list[dict[str, Any]]: """Execute query, return all rows as list of dicts.""" ...
[docs] @abstractmethod async def execute_script(self, script: str) -> None: """Execute multiple statements (for schema creation).""" ...
[docs] @abstractmethod async def commit(self) -> None: """Commit current transaction.""" ...
[docs] @abstractmethod async def rollback(self) -> None: """Rollback current transaction.""" ...
# ------------------------------------------------------------------------- # CRUD Helpers # ------------------------------------------------------------------------- def _sql_name(self, name: str) -> str: """Return quoted SQL identifier for column/table name. Quotes with double quotes to handle reserved words like 'user'. Works for both SQLite and PostgreSQL. """ return f'"{name}"' def _placeholder(self, name: str) -> str: """Return placeholder for named parameter.""" return self.placeholder.replace("name", name)
[docs] async def insert(self, table: str, values: dict[str, Any]) -> int: """Insert a row, return rowcount. Args: table: Table name. values: Column-value pairs. Returns: Number of affected rows (typically 1). """ cols = list(values.keys()) placeholders = ", ".join(self._placeholder(c) for c in cols) col_list = ", ".join(self._sql_name(c) for c in cols) query = f"INSERT INTO {table} ({col_list}) VALUES ({placeholders})" return await self.execute(query, values)
[docs] async def insert_returning_id( self, table: str, values: dict[str, Any], pk_col: str = "id" ) -> Any: """Insert a row and return the generated primary key. Override in subclasses for database-specific implementation (e.g., RETURNING for PostgreSQL, lastrowid for SQLite). Args: table: Table name. values: Column-value pairs. pk_col: Primary key column name. Returns: The generated primary key value, or None if not supported. """ # Default: just insert and return None (subclasses override) await self.insert(table, values) return None
[docs] async def select( self, table: str, columns: list[str] | None = None, where: dict[str, Any] | None = None, order_by: str | None = None, limit: int | None = None, ) -> list[dict[str, Any]]: """Select rows, return list of dicts. Args: table: Table name. columns: Columns to select (None = all). where: WHERE conditions (AND). order_by: ORDER BY clause. limit: LIMIT clause. Returns: List of row dicts. """ cols_sql = ", ".join(self._sql_name(c) for c in columns) if columns else "*" query = f"SELECT {cols_sql} FROM {table}" params: dict[str, Any] = {} if where: conditions = [f"{self._sql_name(k)} = {self._placeholder(k)}" for k in where] query += " WHERE " + " AND ".join(conditions) params.update(where) if order_by: query += f" ORDER BY {order_by}" if limit: query += f" LIMIT {limit}" return await self.fetch_all(query, params)
[docs] async def select_one( self, table: str, columns: list[str] | None = None, where: dict[str, Any] | None = None, ) -> dict[str, Any] | None: """Select single row, return dict or None.""" results = await self.select(table, columns, where, limit=1) return results[0] if results else None
[docs] async def update(self, table: str, values: dict[str, Any], where: dict[str, Any]) -> int: """Update rows, return rowcount. Args: table: Table name. values: Column-value pairs to update. where: WHERE conditions. Returns: Number of affected rows. """ # Prefix value params to avoid collision with where params set_parts = [f"{self._sql_name(k)} = {self._placeholder('val_' + k)}" for k in values] where_parts = [f"{self._sql_name(k)} = {self._placeholder('whr_' + k)}" for k in where] query = f"UPDATE {table} SET {', '.join(set_parts)} WHERE {' AND '.join(where_parts)}" params = {f"val_{k}": v for k, v in values.items()} params.update({f"whr_{k}": v for k, v in where.items()}) return await self.execute(query, params)
[docs] async def delete(self, table: str, where: dict[str, Any]) -> int: """Delete rows, return rowcount. Args: table: Table name. where: WHERE conditions. Returns: Number of deleted rows. """ where_parts = [f"{self._sql_name(k)} = {self._placeholder(k)}" for k in where] query = f"DELETE FROM {table} WHERE {' AND '.join(where_parts)}" return await self.execute(query, where)
[docs] async def exists(self, table: str, where: dict[str, Any]) -> bool: """Check if row exists.""" conditions = [f"{self._sql_name(k)} = {self._placeholder(k)}" for k in where] query = f"SELECT 1 FROM {table} WHERE {' AND '.join(conditions)} LIMIT 1" result = await self.fetch_one(query, where) return result is not None
[docs] async def count(self, table: str, where: dict[str, Any] | None = None) -> int: """Count rows in table. Args: table: Table name. where: Optional WHERE conditions. Returns: Row count. """ query = f"SELECT COUNT(*) as cnt FROM {table}" params: dict[str, Any] = {} if where: conditions = [f"{self._sql_name(k)} = {self._placeholder(k)}" for k in where] query += " WHERE " + " AND ".join(conditions) params.update(where) result = await self.fetch_one(query, params) return result["cnt"] if result else 0