Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 25 additions & 31 deletions datagateway_api/src/common/config.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import logging
from pathlib import Path
import sys
from typing import Optional
from typing import Annotated, Optional

from pydantic import (
AfterValidator,
BaseModel,
field_validator,
StrictBool,
StrictInt,
StrictStr,
ValidationError,
validator,
)
import yaml

Expand Down Expand Up @@ -38,6 +39,9 @@ def validate_extension(extension):
return extension


DataGatewayAPIExtension = Annotated[StrictStr, AfterValidator(validate_extension)]


class UseReaderForPerformance(BaseModel):
enabled: StrictBool
reader_mechanism: StrictStr
Expand All @@ -54,21 +58,14 @@ class DataGatewayAPI(BaseModel):
client_cache_size: StrictInt
client_pool_init_size: StrictInt
client_pool_max_size: StrictInt
extension: StrictStr
extension: DataGatewayAPIExtension
icat_check_cert: StrictBool
icat_url: StrictStr
use_reader_for_performance: Optional[UseReaderForPerformance]

_validate_extension = validator("extension", allow_reuse=True)(validate_extension)
use_reader_for_performance: Optional[UseReaderForPerformance] = None

def __getitem__(self, item):
return getattr(self, item)

class Config:
"""
The behaviour of the BaseModel class can be controlled via this class.
"""


class SearchScoring(BaseModel):
enabled: StrictBool
Expand All @@ -84,16 +81,14 @@ class SearchAPI(BaseModel):
validation of the SearchAPI config data using Python type annotations.
"""

extension: StrictStr
extension: DataGatewayAPIExtension
icat_check_cert: StrictBool
icat_url: StrictStr
mechanism: StrictStr
username: StrictStr
password: StrictStr
search_scoring: SearchScoring

_validate_extension = validator("extension", allow_reuse=True)(validate_extension)

def __getitem__(self, item):
return getattr(self, item)

Expand Down Expand Up @@ -124,20 +119,18 @@ class APIConfig(BaseModel):
API startup so any missing options will be caught quickly.
"""

datagateway_api: Optional[DataGatewayAPI]
debug_mode: Optional[StrictBool]
flask_reloader: Optional[StrictBool]
datagateway_api: Optional[DataGatewayAPI] = None
debug_mode: Optional[StrictBool] = None
flask_reloader: Optional[StrictBool] = None
generate_swagger: StrictBool
host: Optional[StrictStr]
host: Optional[StrictStr] = None
log_level: StrictStr
log_location: StrictStr
port: Optional[StrictStr]
search_api: Optional[SearchAPI]
test_mechanism: Optional[StrictStr]
test_user_credentials: Optional[TestUserCredentials]
url_prefix: StrictStr

_validate_extension = validator("url_prefix", allow_reuse=True)(validate_extension)
port: Optional[StrictStr] = None
search_api: Optional[SearchAPI] = None
test_mechanism: Optional[StrictStr] = None
url_prefix: DataGatewayAPIExtension
test_user_credentials: Optional[TestUserCredentials] = None

def __getitem__(self, item):
return getattr(self, item)
Expand Down Expand Up @@ -170,22 +163,23 @@ def load(cls, path=None):
except (IOError, ValidationError) as error:
sys.exit(f"An error occurred while trying to load the config data: {error}")

@validator("search_api")
def validate_api_extensions(cls, value, values): # noqa: B902, N805
@field_validator("search_api")
@classmethod
def validate_api_extensions(cls, value, info): # noqa: B902, N805
"""
Checks that the DataGateway API and Search API extensions are not the same. An
error is raised, at which point the application exits, if the extensions are the
same.

:param cls: :class:`APIConfig` pointer
:param value: The value of the given config field
:param values: The config field values loaded before the given config field
:param info: The config field values loaded before the given config field
"""
if (
"datagateway_api" in values
and values["datagateway_api"] is not None
"datagateway_api" in info.data
and info.data["datagateway_api"] is not None
and value is not None
and values["datagateway_api"].extension == value.extension
and info.data["datagateway_api"].extension == value.extension
):
raise ValueError(
"extension cannot be the same as datagateway_api extension",
Expand Down
Loading
Loading