Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions backend/btrixcloud/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,17 @@

import importlib.util
import os
import urllib
import urllib.parse
import asyncio
from uuid import UUID, uuid4

from typing import Optional, Union, TypeVar, Type, TYPE_CHECKING
from typing import (
Optional,
Type,
TypeVar,
Union,
TYPE_CHECKING,
)

from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from pydantic import BaseModel
Expand Down
1 change: 1 addition & 0 deletions backend/btrixcloud/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def main() -> None:

org_ops = init_orgs_api(
app,
dbclient,
mdb,
user_manager,
crawl_manager,
Expand Down
2 changes: 1 addition & 1 deletion backend/btrixcloud/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def init_ops() -> Tuple[

user_manager = UserManager(mdb, email, invite_ops)

org_ops = OrgOps(mdb, invite_ops, user_manager, crawl_manager)
org_ops = OrgOps(dbclient, mdb, invite_ops, user_manager, crawl_manager)

event_webhook_ops = EventWebhookOps(mdb, org_ops)

Expand Down
173 changes: 102 additions & 71 deletions backend/btrixcloud/orgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
Any,
)

from motor.motor_asyncio import AsyncIOMotorDatabase
from motor.motor_asyncio import (
AsyncIOMotorClient,
AsyncIOMotorClientSession,
AsyncIOMotorDatabase,
)
from pydantic import ValidationError
from pymongo import ReturnDocument
from pymongo.errors import AutoReconnect, DuplicateKeyError
Expand Down Expand Up @@ -202,11 +206,14 @@ class OrgOps(BaseOrgs):

def __init__(
self,
mdb,
dbclient: AsyncIOMotorClient,
mdb: AsyncIOMotorDatabase,
invites: InviteOps,
user_manager: UserManager,
crawl_manager: CrawlManager,
):
self.dbclient = dbclient

self.orgs = mdb["organizations"]
self.crawls_db = mdb["crawls"]
self.crawl_configs_db = mdb["crawl_configs"]
Expand Down Expand Up @@ -384,9 +391,11 @@ async def get_users_for_org(
users.append(User(**user_dict))
return users

async def get_org_by_id(self, oid: UUID) -> Organization:
async def get_org_by_id(
self, oid: UUID, session: AsyncIOMotorClientSession | None = None
) -> Organization:
"""Get an org by id"""
res = await self.orgs.find_one({"_id": oid})
res = await self.orgs.find_one({"_id": oid}, session=session)
if not res:
raise HTTPException(status_code=400, detail="invalid_org_id")

Expand Down Expand Up @@ -665,79 +674,100 @@ async def update_quotas(
quotas: OrgQuotasIn,
mode: Literal["set", "add"],
sub_event_id: str | None = None,
session: AsyncIOMotorClientSession | None = None,
) -> None:
"""update organization quotas"""

previous_extra_mins = (
org.quotas.extraExecMinutes
if (org.quotas and org.quotas.extraExecMinutes)
else 0
)
previous_gifted_mins = (
org.quotas.giftedExecMinutes
if (org.quotas and org.quotas.giftedExecMinutes)
else 0
)

if mode == "add":
increment_update: dict[str, Any] = {
"$inc": {},
}

for field, value in quotas.model_dump(
exclude_unset=True, exclude_defaults=True, exclude_none=True
).items():
if field == "context" or value is None:
continue
inc = max(value, -org.quotas.model_dump().get(field, 0))
increment_update["$inc"][f"quotas.{field}"] = inc

updated_org = await self.orgs.find_one_and_update(
{"_id": org.id},
increment_update,
projection={"quotas": True},
return_document=ReturnDocument.AFTER,
)
quotas = OrgQuotasIn(**updated_org["quotas"])

update: dict[str, dict[str, dict[str, Any] | int]] = {
"$push": {
"quotaUpdates": OrgQuotaUpdate(
modified=dt_now(),
update=OrgQuotas(
**quotas.model_dump(
exclude_unset=True, exclude_defaults=True, exclude_none=True
)
),
subEventId=sub_event_id,
).model_dump()
},
"$inc": {},
"$set": {},
}
async with await self.dbclient.start_session(
causal_consistency=True
) as session:
try:
# Re-fetch the organization within the session
# so that the operation as a whole is atomic.
org = await self.get_org_by_id(org.id, session=session)

previous_extra_mins = (
org.quotas.extraExecMinutes
if (org.quotas and org.quotas.extraExecMinutes)
else 0
)
previous_gifted_mins = (
org.quotas.giftedExecMinutes
if (org.quotas and org.quotas.giftedExecMinutes)
else 0
)

if mode == "set":
increment_update = quotas.model_dump(
exclude_unset=True, exclude_defaults=True, exclude_none=True
)
update["$set"]["quotas"] = increment_update
if mode == "add":
increment_update: dict[str, Any] = {
"$inc": {},
}

# Inc org available fields for extra/gifted execution time as needed
if quotas.extraExecMinutes is not None:
extra_secs_diff = (quotas.extraExecMinutes - previous_extra_mins) * 60
if org.extraExecSecondsAvailable + extra_secs_diff <= 0:
update["$set"]["extraExecSecondsAvailable"] = 0
else:
update["$inc"]["extraExecSecondsAvailable"] = extra_secs_diff
for field, value in quotas.model_dump(
exclude_unset=True, exclude_defaults=True, exclude_none=True
).items():
if value is None:
continue
inc = max(value, -org.quotas.model_dump().get(field, 0))
increment_update["$inc"][f"quotas.{field}"] = inc

updated_org = await self.orgs.find_one_and_update(
{"_id": org.id},
increment_update,
projection={"quotas": True},
return_document=ReturnDocument.AFTER,
session=session,
)
quotas = OrgQuotasIn(**updated_org["quotas"])

update: dict[str, dict[str, dict[str, Any] | int]] = {
"$push": {
"quotaUpdates": OrgQuotaUpdate(
modified=dt_now(),
update=OrgQuotas(
**quotas.model_dump(
exclude_unset=True,
exclude_defaults=True,
exclude_none=True,
)
),
subEventId=sub_event_id,
).model_dump()
},
"$inc": {},
"$set": {},
}

if quotas.giftedExecMinutes is not None:
gifted_secs_diff = (quotas.giftedExecMinutes - previous_gifted_mins) * 60
if org.giftedExecSecondsAvailable + gifted_secs_diff <= 0:
update["$set"]["giftedExecSecondsAvailable"] = 0
else:
update["$inc"]["giftedExecSecondsAvailable"] = gifted_secs_diff
if mode == "set":
increment_update = quotas.model_dump(
exclude_unset=True, exclude_defaults=True, exclude_none=True
)
update["$set"]["quotas"] = increment_update

# Inc org available fields for extra/gifted execution time as needed
if quotas.extraExecMinutes is not None:
extra_secs_diff = (
quotas.extraExecMinutes - previous_extra_mins
) * 60
if org.extraExecSecondsAvailable + extra_secs_diff <= 0:
update["$set"]["extraExecSecondsAvailable"] = 0
else:
update["$inc"]["extraExecSecondsAvailable"] = extra_secs_diff

if quotas.giftedExecMinutes is not None:
gifted_secs_diff = (
quotas.giftedExecMinutes - previous_gifted_mins
) * 60
if org.giftedExecSecondsAvailable + gifted_secs_diff <= 0:
update["$set"]["giftedExecSecondsAvailable"] = 0
else:
update["$inc"]["giftedExecSecondsAvailable"] = gifted_secs_diff

await self.orgs.find_one_and_update({"_id": org.id}, update)
await self.orgs.find_one_and_update(
{"_id": org.id}, update, session=session
)
except Exception as e:
print(f"Error updating organization quotas: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e

async def update_event_webhook_urls(
self, org: Organization, urls: OrgWebhookUrls
Expand Down Expand Up @@ -1543,6 +1573,7 @@ async def inc_org_bytes_stored_field(self, oid: UUID, field: str, size: int):
# pylint: disable=too-many-statements, too-many-arguments
def init_orgs_api(
app: APIRouter,
dbclient: AsyncIOMotorClient,
mdb: AsyncIOMotorDatabase[Any],
user_manager: UserManager,
crawl_manager: CrawlManager,
Expand All @@ -1552,7 +1583,7 @@ def init_orgs_api(
"""Init organizations api router for /orgs"""
# pylint: disable=too-many-locals,invalid-name

ops = OrgOps(mdb, invites, user_manager, crawl_manager)
ops = OrgOps(dbclient, mdb, invites, user_manager, crawl_manager)

async def org_dep(oid: UUID, user: User = Depends(user_dep)):
org = await ops.get_org_for_user_by_id(oid, user)
Expand Down
Loading