diff --git a/helm/blueapi/config_schema.json b/helm/blueapi/config_schema.json index 68fc79539..603474d8d 100644 --- a/helm/blueapi/config_schema.json +++ b/helm/blueapi/config_schema.json @@ -344,10 +344,20 @@ "tiled_service_account_check": { "title": "Tiled Service Account Check", "type": "string" + }, + "submit_task_check": { + "title": "Submit Task Check", + "type": "string" + }, + "admin_check": { + "title": "Admin Check", + "type": "string" } }, "required": [ - "tiled_service_account_check" + "tiled_service_account_check", + "submit_task_check", + "admin_check" ], "title": "OpaConfig", "type": "object", diff --git a/helm/blueapi/values.schema.json b/helm/blueapi/values.schema.json index 7472c37a2..825ad21d0 100644 --- a/helm/blueapi/values.schema.json +++ b/helm/blueapi/values.schema.json @@ -756,9 +756,15 @@ "title": "OpaConfig", "type": "object", "required": [ - "tiled_service_account_check" + "tiled_service_account_check", + "submit_task_check", + "admin_check" ], "properties": { + "admin_check": { + "title": "Admin Check", + "type": "string" + }, "root": { "title": "Root", "default": "http://localhost:8181/", @@ -767,6 +773,10 @@ "maxLength": 2083, "minLength": 1 }, + "submit_task_check": { + "title": "Submit Task Check", + "type": "string" + }, "tiled_service_account_check": { "title": "Tiled Service Account Check", "type": "string" diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 06c149955..d0eb4cf55 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -299,6 +299,8 @@ class Tag(StrEnum): class OpaConfig(BlueapiBaseModel): root: HttpUrl = HttpUrl("http://localhost:8181") tiled_service_account_check: str + submit_task_check: str + admin_check: str class ApplicationConfig(BlueapiBaseModel): diff --git a/src/blueapi/service/authorization.py b/src/blueapi/service/authorization.py index 3be0814d4..cefd99c7e 100644 --- a/src/blueapi/service/authorization.py +++ b/src/blueapi/service/authorization.py @@ -1,15 +1,20 @@ import logging +import re from collections.abc import Mapping from contextlib import AbstractAsyncContextManager, aclosing, nullcontext -from typing import Any, Self +from typing import Annotated, Any, Self, cast import aiohttp from aiohttp import ClientSession +from fastapi import Depends, HTTPException, Request +from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from blueapi.config import OIDCConfig, OpaConfig, ServiceAccount -from blueapi.service.authentication import TiledAuth +from blueapi.service.authentication import TiledAuth, unchecked_bearer_token +from blueapi.service.model import TaskRequest LOGGER = logging.getLogger(__name__) +INSTRUMENT_SESSION_RE = re.compile(r"^[a-z]{2}(?P\d+)-(?P\d+)$") class OpaClient: @@ -62,6 +67,23 @@ async def require_tiled_service_account(self, token: str): f"Tiled service account is not valid for '{self._instrument}'" ) + async def require_submit_task(self, instrument_session: str, token: str): + if not (match := INSTRUMENT_SESSION_RE.match(instrument_session)): + raise ValueError("Invalid instrument session") + + if not await self._call_opa( + self._conf.submit_task_check, + { + "token": token, + "proposal": int(match["proposal"]), + "visit": int(match["visit"]), + }, + ): + raise HTTPException(status_code=HTTP_403_FORBIDDEN) + + async def is_admin(self, token: str) -> bool: + return await self._call_opa(self._conf.admin_check, {"token": token}) + class OpaUserClient: client: OpaClient @@ -71,6 +93,13 @@ def __init__(self, client: OpaClient, token: str): self.client = client self.token = token + async def can_submit_task(self, task: TaskRequest): + LOGGER.info("Checking permissions to run task") + await self.client.require_submit_task(task.instrument_session, self.token) + + async def admin(self) -> bool: + return await self.client.is_admin(self.token) + async def validate_tiled_config( tiled: ServiceAccount | str | None, oidc: OIDCConfig | None, opa: OpaClient | None @@ -87,3 +116,22 @@ async def validate_tiled_config( tiled.token_url = oidc.token_endpoint auth = TiledAuth(tiled) await opa.require_tiled_service_account(auth.get_access_token()) + + +async def opa( + request: Request, token: str | None = Depends(unchecked_bearer_token) +) -> OpaUserClient | None: + + if opa := cast(OpaClient | None, getattr(request.app.state, "authz", None)): + if not token: + raise HTTPException(status_code=HTTP_401_UNAUTHORIZED) + return OpaUserClient(opa, token) + return None + + +async def submit_permission( + opa: Annotated[OpaUserClient | None, Depends(opa)], + task_request: TaskRequest, +): + if opa: + await opa.can_submit_task(task_request) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 3114fa73f..06c8765c6 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -40,7 +40,13 @@ from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum -from .authorization import OpaClient, validate_tiled_config +from .authorization import ( + OpaClient, + OpaUserClient, + opa, + submit_permission, + validate_tiled_config, +) from .model import ( DeviceModel, DeviceResponse, @@ -146,6 +152,33 @@ def get_app(config: ApplicationConfig): return app +def access_task_permission( + opa: Annotated[OpaUserClient | None, Depends(opa)], + task_id: str, + fedid: Fedid, + runner: Annotated[WorkerDispatcher, Depends(_runner)], +): + task = runner.run(interface.get_task_by_id, task_id) + + if opa and not opa.admin() and (task and fedid != task.task.metadata.get("user")): + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + + +# start_task_permission is used when there is WorkerTask +def start_task_permission( + task: WorkerTask, + opa: Annotated[OpaUserClient, Depends(opa)], + fedid: Fedid, + runner: Annotated[WorkerDispatcher, Depends(_runner)], +): + if not task.task_id: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail="No task id provided", + ) + access_task_permission(opa, task.task_id, fedid, runner) + + async def on_key_error_404(_: Request, __: Exception): return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, @@ -271,13 +304,13 @@ def submit_task( request: Request, response: Response, task_request: Annotated[TaskRequest, Body(..., examples=[example_task_request])], + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], - user: Fedid, + fedid: Fedid, ) -> TaskResponse: """Submit a task to the worker.""" try: - user = user or "Unknown" - task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) + task_id: str = runner.run(interface.submit_task, task_request, {"user": fedid}) response.headers["Location"] = f"{request.url}/{task_id}" return TaskResponse(task_id=task_id) except ValidationError as e: @@ -309,6 +342,7 @@ def submit_task( @start_as_current_span(TRACER, "task_id") def delete_submitted_task( task_id: str, + _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TaskResponse: return TaskResponse(task_id=runner.run(interface.clear_task, task_id)) @@ -326,6 +360,7 @@ def validate_task_status(v: str) -> TaskStatusEnum: @secure_router.get("/tasks", status_code=status.HTTP_200_OK, tags=[Tag.TASK]) @start_as_current_span(TRACER) def get_tasks( + fedid: Fedid, runner: Annotated[WorkerDispatcher, Depends(_runner)], task_status: str | SkipJsonSchema[None] = None, ) -> TasksListResponse: @@ -346,6 +381,9 @@ def get_tasks( tasks = runner.run(interface.get_tasks_by_status, desired_status) else: tasks = runner.run(interface.get_tasks) + + tasks = [t for t in tasks if t.task.metadata.get("user") == fedid] + return TasksListResponse(tasks=tasks) @@ -363,6 +401,7 @@ def get_tasks( def set_active_task( request: Request, task: WorkerTask, + _: Annotated[None, Depends(start_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerTask: """Set a task to active status, the worker should begin it as soon as possible. @@ -393,6 +432,7 @@ def get_passthrough_headers(request: Request) -> dict[str, str]: @start_as_current_span(TRACER, "task_id") def get_task( task_id: str, + _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TrackableTask: """Retrieve a task""" @@ -470,6 +510,9 @@ def get_state(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> WorkerSt def set_state( state_change_request: StateChangeRequest, response: Response, + fedid: Fedid, + opa: Annotated[OpaUserClient, Depends(opa)], + # _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerState: """ @@ -496,6 +539,16 @@ def set_state( current_state in _ALLOWED_TRANSITIONS and new_state in _ALLOWED_TRANSITIONS[current_state] ): + active = runner.run(interface.get_active_task) + + if ( + opa + and not opa.admin() + and active + and active.task.metadata.get("user") != fedid + ): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + if new_state == WorkerState.PAUSED: runner.run(interface.pause_worker, state_change_request.defer) elif new_state == WorkerState.RUNNING: diff --git a/tests/unit_tests/service/test_authorization.py b/tests/unit_tests/service/test_authorization.py index 249198580..65f5c44a6 100644 --- a/tests/unit_tests/service/test_authorization.py +++ b/tests/unit_tests/service/test_authorization.py @@ -2,13 +2,18 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest +from fastapi import HTTPException from pydantic import HttpUrl from blueapi.config import OIDCConfig, OpaConfig, ServiceAccount from blueapi.service.authorization import ( OpaClient, + OpaUserClient, + opa, + submit_permission, validate_tiled_config, ) +from blueapi.service.model import TaskRequest # Reusable client patch decorator patch_client_session = patch( @@ -22,6 +27,8 @@ def opa_config() -> OpaConfig: return OpaConfig( root=HttpUrl("http://auth.example.com"), + submit_task_check="/auth/submit", + admin_check="/auth/admin", tiled_service_account_check="/auth/tiled", ) @@ -105,6 +112,105 @@ async def test_opa_adds_input_fields(session: MagicMock, opa_config: OpaConfig): ) +@pytest.mark.parametrize( + "result,context", + [(True, nullcontext()), (False, pytest.raises(HTTPException, match="403"))], +) +@patch_client_session +async def test_require_submit_task( + session: MagicMock, + opa_config: OpaConfig, + result: bool, + context: AbstractContextManager, +): + session.return_value.post = AsyncMock( + return_value=MagicMock(json=AsyncMock(return_value={"result": result})) + ) + + client = OpaClient(instrument="p99", config=opa_config) + + session.assert_called_once_with(base_url="http://auth.example.com/") + with context: + await client.require_submit_task( + instrument_session="cm12345-1", token="foo_bar" + ) + + session().post.assert_called_once_with( + "/auth/submit", + json={ + "input": { + "token": "foo_bar", + "beamline": "p99", + "audience": "account", + "visit": 1, + "proposal": 12345, + } + }, + ) + + +@patch_client_session +async def test_opa_require_submit_task_invalid_session( + session: MagicMock, opa_config: OpaConfig +): + client = OpaClient(instrument="p45", config=opa_config) + + with pytest.raises(ValueError): + await client.require_submit_task( + instrument_session="not a session", token="foo_bar" + ) + + +@pytest.mark.parametrize("result", [True, False]) +@patch_client_session +async def test_opa_is_admin(session: MagicMock, opa_config: OpaConfig, result: bool): + session.return_value.post = AsyncMock( + return_value=MagicMock(json=AsyncMock(return_value={"result": result})) + ) + client = OpaClient(instrument="p45", config=opa_config) + + admin = await client.is_admin("foo_bar") + + assert admin == result + + session().post.assert_called_once_with( + "/auth/admin", + json={"input": {"token": "foo_bar", "beamline": "p45", "audience": "account"}}, + ) + + +@pytest.mark.parametrize( + "result,context", + [ + (None, nullcontext()), + (HTTPException(status_code=403), pytest.raises(HTTPException, match="403")), + ], +) +async def test_user_client_can_submit_task(result, context: AbstractContextManager): + opa = MagicMock(spec=OpaUserClient) + opa.require_submit_task = AsyncMock(side_effect=result) + + user_client = OpaUserClient(opa, "foo_bar") + + with context: + await user_client.can_submit_task( + TaskRequest(name="foo", params={}, instrument_session="cm12345-1") + ) + opa.require_submit_task.assert_called_once_with("cm12345-1", "foo_bar") + + +@pytest.mark.parametrize("result", [True, False]) +async def test_user_client_admin(result: bool): + opa = MagicMock(spec=OpaUserClient) + opa.is_admin = AsyncMock(return_value=result) + + user_client = OpaUserClient(opa, "foo_bar") + + admin = await user_client.admin() + + assert admin == result + + async def test_validate_tiled_config(): opa = MagicMock(spec=OpaClient) tiled = ServiceAccount() @@ -149,3 +255,46 @@ async def test_validate_tiled_config_with_missing_config( assert await validate_tiled_config(tiled_auth, oidc, opa_client) is None if opa_client is not None: opa_client.require_tiled_service_account.assert_not_called() + + +async def test_opa_dependency_method(): + request = MagicMock() + + user_client = await opa(request, "foo_bar") + + assert user_client is not None + assert user_client.client == request.app.state.authz + assert user_client.token == "foo_bar" + + +async def test_opa_dependency_without_token(): + request = MagicMock() + + with pytest.raises(HTTPException, match="401"): + await opa(request, None) + + +@pytest.mark.parametrize("token", ["foo_bar", None]) +async def test_opa_dependency_without_authz(token): + request = MagicMock() + del request.app.state.authz + user_client = await opa(request, token) + assert user_client is None + + +@pytest.mark.parametrize( + "result,context", + [ + (None, nullcontext()), + (HTTPException(status_code=403), pytest.raises(HTTPException, match="403")), + ], +) +async def test_submit_permission_dependency(result, context: AbstractContextManager): + opa = MagicMock(spec=OpaUserClient) + opa.can_submit_task.side_effect = result + with context: + await submit_permission(opa, Mock()) + + +async def test_submit_permission_dependency_without_opa(): + assert await submit_permission(None, Mock()) is None diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index c1d3b6a95..a2248e798 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -251,7 +251,7 @@ def test_create_task(mock_runner: Mock, client: TestClient) -> None: response = client.post("/tasks", json=task.model_dump()) - mock_runner.run.assert_called_with(submit_task, task, {"user": "Unknown"}) + mock_runner.run.assert_called_with(submit_task, task, {"user": None}) assert response.json() == {"task_id": task_id} @@ -574,7 +574,12 @@ def test_get_state(mock_runner: Mock, client: TestClient): def test_set_state_running_to_paused(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.PAUSED - mock_runner.run.side_effect = [current_state, None, final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + None, + final_state, + ] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() @@ -588,7 +593,12 @@ def test_set_state_running_to_paused(mock_runner: Mock, client: TestClient): def test_set_state_paused_to_running(mock_runner: Mock, client: TestClient): current_state = WorkerState.PAUSED final_state = WorkerState.RUNNING - mock_runner.run.side_effect = [current_state, None, final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + None, + final_state, + ] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() @@ -602,7 +612,12 @@ def test_set_state_paused_to_running(mock_runner: Mock, client: TestClient): def test_set_state_running_to_aborting(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.ABORTING - mock_runner.run.side_effect = [current_state, None, final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + None, + final_state, + ] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() @@ -619,7 +634,12 @@ def test_set_state_running_to_stopping_including_reason( current_state = WorkerState.RUNNING final_state = WorkerState.STOPPING reason = "blueapi is being stopped" - mock_runner.run.side_effect = [current_state, None, final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + None, + final_state, + ] response = client.put( "/worker/state", @@ -635,7 +655,12 @@ def test_set_state_transition_error(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.STOPPING - mock_runner.run.side_effect = [current_state, TransitionError(), final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + TransitionError(), + final_state, + ] response = client.put( "/worker/state", diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index b6e52c393..abc2851e6 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -340,6 +340,8 @@ def test_config_yaml_parsed(temp_yaml_config_file): "opa": { "root": "http://opa.example.com/", "tiled_service_account_check": "v1/tiled_service_account", + "submit_task_check": "v1/submit_task", + "admin_check": "v1/admin_check", }, }, {