# Copyright © The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.

"""Base support for tasks that run on server-side Celery workers."""

from abc import ABCMeta
from collections.abc import Callable
from typing import Any, override

from debusine.artifacts.models import TaskTypes
from debusine.db.context import context
from debusine.db.models import WorkRequest
from debusine.db.models.permissions import (
    PermissionUser,
    format_permission_check_error,
)
from debusine.db.models.tasks import DBTask
from debusine.tasks.models import BaseDynamicTaskData, BaseTaskData


class ServerTaskPermissionDenied(Exception):
    """Raised when a server task's permission predicate check does not pass."""


def analyze_celery_worker_tasks() -> dict[str, Any]:
    """
    Collect worker metadata from every registered server task.

    Each subclass of ``DBTask`` is recorded in ``DBTask._sub_tasks``, grouped
    by task type.  This walks the server-task group and merges the
    dictionaries returned by each class's ``analyze_worker()`` into one;
    when two classes report the same key, the one visited later wins.

    This function is executed in the worker when submitting the dynamic
    metadata.
    """
    server_tasks = DBTask._sub_tasks.get(TaskTypes.SERVER, {})
    merged: dict[str, Any] = {}
    for server_task in server_tasks.values():
        # Narrows the type for static checkers: only BaseServerTask
        # subclasses provide analyze_worker() as used here.
        assert issubclass(server_task, BaseServerTask)
        merged |= server_task.analyze_worker()
    return merged


class BaseServerTask[TD: BaseTaskData, DTD: BaseDynamicTaskData](
    DBTask[TD, DTD], metaclass=ABCMeta
):
    """Base class for tasks that run on server-side Celery workers."""

    TASK_TYPE = TaskTypes.SERVER

    # If True, the task manages its own transactions.  If False, the task is
    # automatically run within a single transaction.
    TASK_MANAGES_TRANSACTIONS = False

    work_request: WorkRequest

    @override
    def execute(self) -> WorkRequest.Results:
        """Execute task, setting up a suitable context."""
        context.reset()
        context.set_scope(self.workspace.scope)
        context.set_user(self.work_request.created_by)
        self.workspace.set_current()
        try:
            return super().execute()
        finally:
            context.reset()

    def enforce(self, predicate: Callable[[PermissionUser], bool]) -> None:
        """Enforce a permission predicate."""
        if predicate(self.work_request.created_by):
            return

        raise ServerTaskPermissionDenied(
            format_permission_check_error(predicate, context.user)
        )

    @classmethod
    def analyze_worker(cls) -> dict[str, Any]:
        """
        Return dynamic metadata about the current worker.

        This method is called on the worker to collect information about the
        worker. The information is stored as a set of key-value pairs in a
        dictionary.

        That information is then reused on the scheduler to be fed to
        :py:meth:`can_run_on` and determine if a task is suitable to be
        executed on the worker.

        Derived objects can extend the behaviour by overriding
        the method, calling ``metadata = super().analyze_worker()``,
        and then adding supplementary data in the dictionary.

        To avoid conflicts on the names of the keys used by different tasks
        you should use key names obtained with
        ``self.prefix_with_task_name(...)``.

        :return: a dictionary describing the worker.
        :rtype: dict.
        """
        version_key_name = cls.prefix_with_task_name("version")
        return {
            version_key_name: cls.TASK_VERSION,
        }
