diff --git a/backend/btrixcloud/db.py b/backend/btrixcloud/db.py
index 7f0bfbd9b2..cb9637c254 100644
--- a/backend/btrixcloud/db.py
+++ b/backend/btrixcloud/db.py
@@ -4,11 +4,17 @@
 import importlib.util
 import os
-import urllib
+import urllib.parse
 import asyncio
 
 from uuid import UUID, uuid4
 
-from typing import Optional, Union, TypeVar, Type, TYPE_CHECKING
+from typing import (
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+    TYPE_CHECKING,
+)
 
 from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
 from pydantic import BaseModel
diff --git a/backend/btrixcloud/main.py b/backend/btrixcloud/main.py
index 4fc8e5510e..78bd76be6b 100644
--- a/backend/btrixcloud/main.py
+++ b/backend/btrixcloud/main.py
@@ -176,6 +176,7 @@ def main() -> None:
 
     org_ops = init_orgs_api(
         app,
+        dbclient,
         mdb,
         user_manager,
         crawl_manager,
diff --git a/backend/btrixcloud/ops.py b/backend/btrixcloud/ops.py
index 869c13c755..148299e3a5 100644
--- a/backend/btrixcloud/ops.py
+++ b/backend/btrixcloud/ops.py
@@ -56,7 +56,7 @@ def init_ops() -> Tuple[
 
     user_manager = UserManager(mdb, email, invite_ops)
 
-    org_ops = OrgOps(mdb, invite_ops, user_manager, crawl_manager)
+    org_ops = OrgOps(dbclient, mdb, invite_ops, user_manager, crawl_manager)
 
     event_webhook_ops = EventWebhookOps(mdb, org_ops)
 
diff --git a/backend/btrixcloud/orgs.py b/backend/btrixcloud/orgs.py
index 29c06aadb9..8323a6266d 100644
--- a/backend/btrixcloud/orgs.py
+++ b/backend/btrixcloud/orgs.py
@@ -24,7 +24,11 @@
     Any,
 )
 
-from motor.motor_asyncio import AsyncIOMotorDatabase
+from motor.motor_asyncio import (
+    AsyncIOMotorClient,
+    AsyncIOMotorClientSession,
+    AsyncIOMotorDatabase,
+)
 from pydantic import ValidationError
 from pymongo import ReturnDocument
 from pymongo.errors import AutoReconnect, DuplicateKeyError
@@ -202,11 +206,14 @@ class OrgOps(BaseOrgs):
 
     def __init__(
         self,
-        mdb,
+        dbclient: AsyncIOMotorClient,
+        mdb: AsyncIOMotorDatabase,
         invites: InviteOps,
        user_manager: UserManager,
         crawl_manager: CrawlManager,
     ):
+        self.dbclient = dbclient
+
         self.orgs = mdb["organizations"]
         self.crawls_db = mdb["crawls"]
         self.crawl_configs_db = mdb["crawl_configs"]
@@ -384,9 +391,11 @@ async def get_users_for_org(
             users.append(User(**user_dict))
         return users
 
-    async def get_org_by_id(self, oid: UUID) -> Organization:
+    async def get_org_by_id(
+        self, oid: UUID, session: AsyncIOMotorClientSession | None = None
+    ) -> Organization:
         """Get an org by id"""
-        res = await self.orgs.find_one({"_id": oid})
+        res = await self.orgs.find_one({"_id": oid}, session=session)
         if not res:
             raise HTTPException(status_code=400, detail="invalid_org_id")
 
@@ -665,79 +674,100 @@ async def update_quotas(
         quotas: OrgQuotasIn,
         mode: Literal["set", "add"],
         sub_event_id: str | None = None,
+        session: AsyncIOMotorClientSession | None = None,
     ) -> None:
         """update organization quotas"""
 
-        previous_extra_mins = (
-            org.quotas.extraExecMinutes
-            if (org.quotas and org.quotas.extraExecMinutes)
-            else 0
-        )
-        previous_gifted_mins = (
-            org.quotas.giftedExecMinutes
-            if (org.quotas and org.quotas.giftedExecMinutes)
-            else 0
-        )
-
-        if mode == "add":
-            increment_update: dict[str, Any] = {
-                "$inc": {},
-            }
-
-            for field, value in quotas.model_dump(
-                exclude_unset=True, exclude_defaults=True, exclude_none=True
-            ).items():
-                if field == "context" or value is None:
-                    continue
-                inc = max(value, -org.quotas.model_dump().get(field, 0))
-                increment_update["$inc"][f"quotas.{field}"] = inc
-
-            updated_org = await self.orgs.find_one_and_update(
-                {"_id": org.id},
-                increment_update,
-                projection={"quotas": True},
-                return_document=ReturnDocument.AFTER,
-            )
-            quotas = OrgQuotasIn(**updated_org["quotas"])
-
-        update: dict[str, dict[str, dict[str, Any] | int]] = {
-            "$push": {
-                "quotaUpdates": OrgQuotaUpdate(
-                    modified=dt_now(),
-                    update=OrgQuotas(
-                        **quotas.model_dump(
-                            exclude_unset=True, exclude_defaults=True, exclude_none=True
-                        )
-                    ),
-                    subEventId=sub_event_id,
-                ).model_dump()
-            },
-            "$inc": {},
-            "$set": {},
-        }
+        async with await self.dbclient.start_session(
+            causal_consistency=True
+        ) as session:
+            try:
+                # Re-fetch the organization within the session
+                # so that the operation as a whole is atomic.
+                org = await self.get_org_by_id(org.id, session=session)
+
+                previous_extra_mins = (
+                    org.quotas.extraExecMinutes
+                    if (org.quotas and org.quotas.extraExecMinutes)
+                    else 0
+                )
+                previous_gifted_mins = (
+                    org.quotas.giftedExecMinutes
+                    if (org.quotas and org.quotas.giftedExecMinutes)
+                    else 0
+                )
 
-        if mode == "set":
-            increment_update = quotas.model_dump(
-                exclude_unset=True, exclude_defaults=True, exclude_none=True
-            )
-            update["$set"]["quotas"] = increment_update
+                if mode == "add":
+                    increment_update: dict[str, Any] = {
+                        "$inc": {},
+                    }
 
-        # Inc org available fields for extra/gifted execution time as needed
-        if quotas.extraExecMinutes is not None:
-            extra_secs_diff = (quotas.extraExecMinutes - previous_extra_mins) * 60
-            if org.extraExecSecondsAvailable + extra_secs_diff <= 0:
-                update["$set"]["extraExecSecondsAvailable"] = 0
-            else:
-                update["$inc"]["extraExecSecondsAvailable"] = extra_secs_diff
+                    for field, value in quotas.model_dump(
+                        exclude_unset=True, exclude_defaults=True, exclude_none=True
+                    ).items():
+                        if value is None:
+                            continue
+                        inc = max(value, -org.quotas.model_dump().get(field, 0))
+                        increment_update["$inc"][f"quotas.{field}"] = inc
+
+                    updated_org = await self.orgs.find_one_and_update(
+                        {"_id": org.id},
+                        increment_update,
+                        projection={"quotas": True},
+                        return_document=ReturnDocument.AFTER,
+                        session=session,
+                    )
+                    quotas = OrgQuotasIn(**updated_org["quotas"])
+
+                update: dict[str, dict[str, dict[str, Any] | int]] = {
+                    "$push": {
+                        "quotaUpdates": OrgQuotaUpdate(
+                            modified=dt_now(),
+                            update=OrgQuotas(
+                                **quotas.model_dump(
+                                    exclude_unset=True,
+                                    exclude_defaults=True,
+                                    exclude_none=True,
+                                )
+                            ),
+                            subEventId=sub_event_id,
+                        ).model_dump()
+                    },
+                    "$inc": {},
+                    "$set": {},
+                }
 
-        if quotas.giftedExecMinutes is not None:
-            gifted_secs_diff = (quotas.giftedExecMinutes - previous_gifted_mins) * 60
-            if org.giftedExecSecondsAvailable + gifted_secs_diff <= 0:
-                update["$set"]["giftedExecSecondsAvailable"] = 0
-            else:
-                update["$inc"]["giftedExecSecondsAvailable"] = gifted_secs_diff
+                if mode == "set":
+                    increment_update = quotas.model_dump(
+                        exclude_unset=True, exclude_defaults=True, exclude_none=True
+                    )
+                    update["$set"]["quotas"] = increment_update
+
+                # Inc org available fields for extra/gifted execution time as needed
+                if quotas.extraExecMinutes is not None:
+                    extra_secs_diff = (
+                        quotas.extraExecMinutes - previous_extra_mins
+                    ) * 60
+                    if org.extraExecSecondsAvailable + extra_secs_diff <= 0:
+                        update["$set"]["extraExecSecondsAvailable"] = 0
+                    else:
+                        update["$inc"]["extraExecSecondsAvailable"] = extra_secs_diff
+
+                if quotas.giftedExecMinutes is not None:
+                    gifted_secs_diff = (
+                        quotas.giftedExecMinutes - previous_gifted_mins
+                    ) * 60
+                    if org.giftedExecSecondsAvailable + gifted_secs_diff <= 0:
+                        update["$set"]["giftedExecSecondsAvailable"] = 0
+                    else:
+                        update["$inc"]["giftedExecSecondsAvailable"] = gifted_secs_diff
 
-        await self.orgs.find_one_and_update({"_id": org.id}, update)
+                await self.orgs.find_one_and_update(
+                    {"_id": org.id}, update, session=session
+                )
+            except Exception as e:
+                print(f"Error updating organization quotas: {e}")
+                raise HTTPException(status_code=500, detail=str(e)) from e
 
     async def update_event_webhook_urls(
         self, org: Organization, urls: OrgWebhookUrls
@@ -1543,6 +1573,7 @@ async def inc_org_bytes_stored_field(self, oid: UUID, field: str, size: int):
 # pylint: disable=too-many-statements, too-many-arguments
 def init_orgs_api(
     app: APIRouter,
+    dbclient: AsyncIOMotorClient,
     mdb: AsyncIOMotorDatabase[Any],
     user_manager: UserManager,
     crawl_manager: CrawlManager,
@@ -1552,7 +1583,7 @@ def init_orgs_api(
     """Init organizations api router for /orgs"""
     # pylint: disable=too-many-locals,invalid-name
 
-    ops = OrgOps(mdb, invites, user_manager, crawl_manager)
+    ops = OrgOps(dbclient, mdb, invites, user_manager, crawl_manager)
 
     async def org_dep(oid: UUID, user: User = Depends(user_dep)):
         org = await ops.get_org_for_user_by_id(oid, user)