diff --git a/bson/json_util.py b/bson/json_util.py index ecae103b55..1a3b0bd833 100644 --- a/bson/json_util.py +++ b/bson/json_util.py @@ -844,7 +844,7 @@ def _encode_binary(data: bytes, subtype: int, json_options: JSONOptions) -> Any: return {"$binary": {"base64": base64.b64encode(data).decode(), "subType": "%02x" % subtype}} -def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict: +def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict: # type: ignore[type-arg] if ( json_options.datetime_representation == DatetimeRepresentation.ISO8601 and 0 <= int(obj) <= _MAX_UTC_MS @@ -855,7 +855,7 @@ def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict: return {"$date": {"$numberLong": str(int(obj))}} -def _encode_code(obj: Code, json_options: JSONOptions) -> dict: +def _encode_code(obj: Code, json_options: JSONOptions) -> dict: # type: ignore[type-arg] if obj.scope is None: return {"$code": str(obj)} else: @@ -873,7 +873,7 @@ def _encode_noop(obj: Any, dummy0: Any) -> Any: return obj -def _encode_regex(obj: Any, json_options: JSONOptions) -> dict: +def _encode_regex(obj: Any, json_options: JSONOptions) -> dict: # type: ignore[type-arg] flags = "" if obj.flags & re.IGNORECASE: flags += "i" @@ -918,7 +918,7 @@ def _encode_float(obj: float, json_options: JSONOptions) -> Any: return obj -def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict: +def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict: # type: ignore[type-arg] if json_options.datetime_representation == DatetimeRepresentation.ISO8601: if not obj.tzinfo: obj = obj.replace(tzinfo=utc) @@ -941,15 +941,15 @@ def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict: return {"$date": {"$numberLong": str(millis)}} -def _encode_bytes(obj: bytes, json_options: JSONOptions) -> dict: +def _encode_bytes(obj: bytes, json_options: JSONOptions) -> dict: # type: ignore[type-arg] return _encode_binary(obj, 0, json_options) -def _encode_binary_obj(obj: Binary, json_options: JSONOptions) -> dict: +def _encode_binary_obj(obj: Binary, json_options: JSONOptions) -> dict: # type: ignore[type-arg] return _encode_binary(obj, obj.subtype, json_options) -def _encode_uuid(obj: uuid.UUID, json_options: JSONOptions) -> dict: +def _encode_uuid(obj: uuid.UUID, json_options: JSONOptions) -> dict: # type: ignore[type-arg] if json_options.strict_uuid: binval = Binary.from_uuid(obj, uuid_representation=json_options.uuid_representation) return _encode_binary(binval, binval.subtype, json_options) @@ -957,27 +957,27 @@ def _encode_uuid(obj: uuid.UUID, json_options: JSONOptions) -> dict: return {"$uuid": obj.hex} -def _encode_objectid(obj: ObjectId, dummy0: Any) -> dict: +def _encode_objectid(obj: ObjectId, dummy0: Any) -> dict: # type: ignore[type-arg] return {"$oid": str(obj)} -def _encode_timestamp(obj: Timestamp, dummy0: Any) -> dict: +def _encode_timestamp(obj: Timestamp, dummy0: Any) -> dict: # type: ignore[type-arg] return {"$timestamp": {"t": obj.time, "i": obj.inc}} -def _encode_decimal128(obj: Timestamp, dummy0: Any) -> dict: +def _encode_decimal128(obj: Timestamp, dummy0: Any) -> dict: # type: ignore[type-arg] return {"$numberDecimal": str(obj)} -def _encode_dbref(obj: DBRef, json_options: JSONOptions) -> dict: +def _encode_dbref(obj: DBRef, json_options: JSONOptions) -> dict: # type: ignore[type-arg] return _json_convert(obj.as_doc(), json_options=json_options) -def _encode_minkey(dummy0: Any, dummy1: Any) -> dict: +def _encode_minkey(dummy0: Any, dummy1: Any) -> dict: # type: ignore[type-arg] return {"$minKey": 1} -def _encode_maxkey(dummy0: Any, dummy1: Any) -> dict: +def _encode_maxkey(dummy0: Any, dummy1: Any) -> dict: # type: ignore[type-arg] return {"$maxKey": 1} @@ -985,7 +985,7 @@ def _encode_maxkey(dummy0: Any, dummy1: Any) -> dict: # Each encoder function's signature is: # - obj: a Python data type, e.g. a Python int for _encode_int # - json_options: a JSONOptions -_ENCODERS: dict[Type, Callable[[Any, JSONOptions], Any]] = { +_ENCODERS: dict[Type, Callable[[Any, JSONOptions], Any]] = { # type: ignore[type-arg] bool: _encode_noop, bytes: _encode_bytes, datetime.datetime: _encode_datetime, @@ -1056,7 +1056,7 @@ def _get_datetime_size(obj: datetime.datetime) -> int: return 5 + len(str(obj.time())) -def _get_regex_size(obj: Regex) -> int: +def _get_regex_size(obj: Regex) -> int: # type: ignore[type-arg] return 18 + len(obj.pattern) diff --git a/bson/typings.py b/bson/typings.py index b80c661454..55e90b19a5 100644 --- a/bson/typings.py +++ b/bson/typings.py @@ -28,4 +28,4 @@ _DocumentOut = Union[MutableMapping[str, Any], "RawBSONDocument"] _DocumentType = TypeVar("_DocumentType", bound=Mapping[str, Any]) _DocumentTypeArg = TypeVar("_DocumentTypeArg", bound=Mapping[str, Any]) -_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"] +_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"] # type: ignore[type-arg] diff --git a/gridfs/asynchronous/grid_file.py b/gridfs/asynchronous/grid_file.py index 3c7d4ef0e9..e512f796a8 100644 --- a/gridfs/asynchronous/grid_file.py +++ b/gridfs/asynchronous/grid_file.py @@ -70,7 +70,7 @@ def _disallow_transactions(session: Optional[AsyncClientSession]) -> None: class AsyncGridFS: """An instance of GridFS on top of a single Database.""" - def __init__(self, database: AsyncDatabase, collection: str = "fs"): + def __init__(self, database: AsyncDatabase[Any], collection: str = "fs"): """Create a new instance of :class:`GridFS`. Raises :class:`TypeError` if `database` is not an instance of @@ -463,7 +463,7 @@ class AsyncGridFSBucket: def __init__( self, - db: AsyncDatabase, + db: AsyncDatabase[Any], bucket_name: str = "fs", chunk_size_bytes: int = DEFAULT_CHUNK_SIZE, write_concern: Optional[WriteConcern] = None, @@ -513,11 +513,11 @@ def __init__( self._bucket_name = bucket_name self._collection = db[bucket_name] - self._chunks: AsyncCollection = self._collection.chunks.with_options( + self._chunks: AsyncCollection[Any] = self._collection.chunks.with_options( write_concern=write_concern, read_preference=read_preference ) - self._files: AsyncCollection = self._collection.files.with_options( + self._files: AsyncCollection[Any] = self._collection.files.with_options( write_concern=write_concern, read_preference=read_preference ) @@ -1085,7 +1085,7 @@ class AsyncGridIn: def __init__( self, - root_collection: AsyncCollection, + root_collection: AsyncCollection[Any], session: Optional[AsyncClientSession] = None, **kwargs: Any, ) -> None: @@ -1172,7 +1172,7 @@ def __init__( object.__setattr__(self, "_buffered_docs_size", 0) async def _create_index( - self, collection: AsyncCollection, index_key: Any, unique: bool + self, collection: AsyncCollection[Any], index_key: Any, unique: bool ) -> None: doc = await collection.find_one(projection={"_id": 1}, session=self._session) if doc is None: @@ -1456,7 +1456,7 @@ class AsyncGridOut(GRIDOUT_BASE_CLASS): # type: ignore def __init__( self, - root_collection: AsyncCollection, + root_collection: AsyncCollection[Any], file_id: Optional[int] = None, file_document: Optional[Any] = None, session: Optional[AsyncClientSession] = None, @@ -1829,7 +1829,7 @@ class _AsyncGridOutChunkIterator: def __init__( self, grid_out: AsyncGridOut, - chunks: AsyncCollection, + chunks: AsyncCollection[Any], session: Optional[AsyncClientSession], next_chunk: Any, ) -> None: @@ -1842,7 +1842,7 @@ def __init__( self._num_chunks = math.ceil(float(self._length) / self._chunk_size) self._cursor = None - _cursor: Optional[AsyncCursor] + _cursor: Optional[AsyncCursor[Any]] def expected_chunk_length(self, chunk_n: int) -> int: if chunk_n < self._num_chunks - 1: @@ -1921,7 +1921,7 @@ async def close(self) -> None: class AsyncGridOutIterator: def __init__( - self, grid_out: AsyncGridOut, chunks: AsyncCollection, session: AsyncClientSession + self, grid_out: AsyncGridOut, chunks: AsyncCollection[Any], session: AsyncClientSession ): self._chunk_iter = _AsyncGridOutChunkIterator(grid_out, chunks, session, 0) @@ -1935,14 +1935,14 @@ async def next(self) -> bytes: __anext__ = next -class AsyncGridOutCursor(AsyncCursor): +class AsyncGridOutCursor(AsyncCursor): # type: ignore[type-arg] """A cursor / iterator for returning GridOut objects as the result of an arbitrary query against the GridFS files collection. """ def __init__( self, - collection: AsyncCollection, + collection: AsyncCollection[Any], filter: Optional[Mapping[str, Any]] = None, skip: int = 0, limit: int = 0, diff --git a/gridfs/synchronous/grid_file.py b/gridfs/synchronous/grid_file.py index d0a4c7fc7f..70a4f80774 100644 --- a/gridfs/synchronous/grid_file.py +++ b/gridfs/synchronous/grid_file.py @@ -70,7 +70,7 @@ def _disallow_transactions(session: Optional[ClientSession]) -> None: class GridFS: """An instance of GridFS on top of a single Database.""" - def __init__(self, database: Database, collection: str = "fs"): + def __init__(self, database: Database[Any], collection: str = "fs"): """Create a new instance of :class:`GridFS`. Raises :class:`TypeError` if `database` is not an instance of @@ -461,7 +461,7 @@ class GridFSBucket: def __init__( self, - db: Database, + db: Database[Any], bucket_name: str = "fs", chunk_size_bytes: int = DEFAULT_CHUNK_SIZE, write_concern: Optional[WriteConcern] = None, @@ -511,11 +511,11 @@ def __init__( self._bucket_name = bucket_name self._collection = db[bucket_name] - self._chunks: Collection = self._collection.chunks.with_options( + self._chunks: Collection[Any] = self._collection.chunks.with_options( write_concern=write_concern, read_preference=read_preference ) - self._files: Collection = self._collection.files.with_options( + self._files: Collection[Any] = self._collection.files.with_options( write_concern=write_concern, read_preference=read_preference ) @@ -1077,7 +1077,7 @@ class GridIn: def __init__( self, - root_collection: Collection, + root_collection: Collection[Any], session: Optional[ClientSession] = None, **kwargs: Any, ) -> None: @@ -1163,7 +1163,7 @@ def __init__( object.__setattr__(self, "_buffered_docs", []) object.__setattr__(self, "_buffered_docs_size", 0) - def _create_index(self, collection: Collection, index_key: Any, unique: bool) -> None: + def _create_index(self, collection: Collection[Any], index_key: Any, unique: bool) -> None: doc = collection.find_one(projection={"_id": 1}, session=self._session) if doc is None: try: @@ -1444,7 +1444,7 @@ class GridOut(GRIDOUT_BASE_CLASS): # type: ignore def __init__( self, - root_collection: Collection, + root_collection: Collection[Any], file_id: Optional[int] = None, file_document: Optional[Any] = None, session: Optional[ClientSession] = None, @@ -1817,7 +1817,7 @@ class GridOutChunkIterator: def __init__( self, grid_out: GridOut, - chunks: Collection, + chunks: Collection[Any], session: Optional[ClientSession], next_chunk: Any, ) -> None: @@ -1830,7 +1830,7 @@ def __init__( self._num_chunks = math.ceil(float(self._length) / self._chunk_size) self._cursor = None - _cursor: Optional[Cursor] + _cursor: Optional[Cursor[Any]] def expected_chunk_length(self, chunk_n: int) -> int: if chunk_n < self._num_chunks - 1: @@ -1908,7 +1908,7 @@ def close(self) -> None: class GridOutIterator: - def __init__(self, grid_out: GridOut, chunks: Collection, session: ClientSession): + def __init__(self, grid_out: GridOut, chunks: Collection[Any], session: ClientSession): self._chunk_iter = GridOutChunkIterator(grid_out, chunks, session, 0) def __iter__(self) -> GridOutIterator: @@ -1921,14 +1921,14 @@ def next(self) -> bytes: __next__ = next -class GridOutCursor(Cursor): +class GridOutCursor(Cursor): # type: ignore[type-arg] """A cursor / iterator for returning GridOut objects as the result of an arbitrary query against the GridFS files collection. """ def __init__( self, - collection: Collection, + collection: Collection[Any], filter: Optional[Mapping[str, Any]] = None, skip: int = 0, limit: int = 0, diff --git a/pymongo/_asyncio_lock.py b/pymongo/_asyncio_lock.py index a9c409d486..5ca09982fa 100644 --- a/pymongo/_asyncio_lock.py +++ b/pymongo/_asyncio_lock.py @@ -93,7 +93,7 @@ class Lock(_ContextManagerMixin, _LoopBoundMixin): """ def __init__(self) -> None: - self._waiters: Optional[collections.deque] = None + self._waiters: Optional[collections.deque[Any]] = None self._locked = False def __repr__(self) -> str: @@ -196,7 +196,7 @@ def __init__(self, lock: Optional[Lock] = None) -> None: self.acquire = lock.acquire self.release = lock.release - self._waiters: collections.deque = collections.deque() + self._waiters: collections.deque[Any] = collections.deque() def __repr__(self) -> str: res = super().__repr__() @@ -260,7 +260,7 @@ async def wait(self) -> bool: self._notify(1) raise - async def wait_for(self, predicate: Any) -> Coroutine: + async def wait_for(self, predicate: Any) -> Coroutine[Any, Any, Any]: """Wait until a predicate becomes true. The predicate should be a callable whose result will be diff --git a/pymongo/_asyncio_task.py b/pymongo/_asyncio_task.py index 7a528f027d..118471963a 100644 --- a/pymongo/_asyncio_task.py +++ b/pymongo/_asyncio_task.py @@ -24,7 +24,7 @@ # TODO (https://jira.mongodb.org/browse/PYTHON-4981): Revisit once the underlying cause of the swallowed cancellations is uncovered -class _Task(asyncio.Task): +class _Task(asyncio.Task[Any]): def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None: super().__init__(coro, name=name) self._cancel_requests = 0 @@ -43,7 +43,7 @@ def cancelling(self) -> int: return self._cancel_requests -def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task: +def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task[Any]: if sys.version_info >= (3, 11): return asyncio.create_task(coro, name=name) return _Task(coro, name=name) diff --git a/pymongo/_csot.py b/pymongo/_csot.py index c5681e345a..ce72a66486 100644 --- a/pymongo/_csot.py +++ b/pymongo/_csot.py @@ -68,7 +68,7 @@ def clamp_remaining(max_timeout: float) -> float: return min(timeout, max_timeout) -class _TimeoutContext(AbstractContextManager): +class _TimeoutContext(AbstractContextManager[Any]): """Internal timeout context manager. Use :func:`pymongo.timeout` instead:: diff --git a/pymongo/asynchronous/aggregation.py b/pymongo/asynchronous/aggregation.py index daccd1bcb0..059d698772 100644 --- a/pymongo/asynchronous/aggregation.py +++ b/pymongo/asynchronous/aggregation.py @@ -46,8 +46,8 @@ class _AggregationCommand: def __init__( self, - target: Union[AsyncDatabase, AsyncCollection], - cursor_class: type[AsyncCommandCursor], + target: Union[AsyncDatabase[Any], AsyncCollection[Any]], + cursor_class: type[AsyncCommandCursor[Any]], pipeline: _Pipeline, options: MutableMapping[str, Any], explicit_session: bool, @@ -111,12 +111,12 @@ def _cursor_namespace(self) -> str: """The namespace in which the aggregate command is run.""" raise NotImplementedError - def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> AsyncCollection: + def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> AsyncCollection[Any]: """The AsyncCollection used for the aggregate command cursor.""" raise NotImplementedError @property - def _database(self) -> AsyncDatabase: + def _database(self) -> AsyncDatabase[Any]: """The database against which the aggregation command is run.""" raise NotImplementedError @@ -205,7 +205,7 @@ async def get_cursor( class _CollectionAggregationCommand(_AggregationCommand): - _target: AsyncCollection + _target: AsyncCollection[Any] @property def _aggregation_target(self) -> str: @@ -215,12 +215,12 @@ def _aggregation_target(self) -> str: def _cursor_namespace(self) -> str: return self._target.full_name - def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection: + def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection[Any]: """The AsyncCollection used for the aggregate command cursor.""" return self._target @property - def _database(self) -> AsyncDatabase: + def _database(self) -> AsyncDatabase[Any]: return self._target.database @@ -234,7 +234,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: class _DatabaseAggregationCommand(_AggregationCommand): - _target: AsyncDatabase + _target: AsyncDatabase[Any] @property def _aggregation_target(self) -> int: @@ -245,10 +245,10 @@ def _cursor_namespace(self) -> str: return f"{self._target.name}.$cmd.aggregate" @property - def _database(self) -> AsyncDatabase: + def _database(self) -> AsyncDatabase[Any]: return self._target - def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection: + def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection[Any]: """The AsyncCollection used for the aggregate command cursor.""" # AsyncCollection level aggregate may not always return the "ns" field # according to our MockupDB tests. Let's handle that case for db level diff --git a/pymongo/asynchronous/auth_oidc.py b/pymongo/asynchronous/auth_oidc.py index 20b8340060..f8f046bd94 100644 --- a/pymongo/asynchronous/auth_oidc.py +++ b/pymongo/asynchronous/auth_oidc.py @@ -259,7 +259,7 @@ async def _sasl_continue_jwt( ) -> Mapping[str, Any]: self.access_token = None self.refresh_token = None - start_payload: dict = bson.decode(start_resp["payload"]) + start_payload: dict[str, Any] = bson.decode(start_resp["payload"]) if "issuer" in start_payload: self.idp_info = OIDCIdPInfo(**start_payload) access_token = await self._get_access_token() diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index ac514db98f..4a54f9eb3f 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -248,7 +248,7 @@ async def write_command( request_id: int, msg: bytes, docs: list[Mapping[str, Any]], - client: AsyncMongoClient, + client: AsyncMongoClient[Any], ) -> dict[str, Any]: """A proxy for SocketInfo.write_command that handles event publishing.""" cmd[bwc.field] = docs @@ -334,7 +334,7 @@ async def unack_write( msg: bytes, max_doc_size: int, docs: list[Mapping[str, Any]], - client: AsyncMongoClient, + client: AsyncMongoClient[Any], ) -> Optional[Mapping[str, Any]]: """A proxy for AsyncConnection.unack_write that handles event publishing.""" if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): @@ -419,7 +419,7 @@ async def _execute_batch_unack( bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext], cmd: dict[str, Any], ops: list[Mapping[str, Any]], - client: AsyncMongoClient, + client: AsyncMongoClient[Any], ) -> list[Mapping[str, Any]]: if self.is_encrypted: _, batched_cmd, to_send = bwc.batch_command(cmd, ops) @@ -446,7 +446,7 @@ async def _execute_batch( bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext], cmd: dict[str, Any], ops: list[Mapping[str, Any]], - client: AsyncMongoClient, + client: AsyncMongoClient[Any], ) -> tuple[dict[str, Any], list[Mapping[str, Any]]]: if self.is_encrypted: _, batched_cmd, to_send = bwc.batch_command(cmd, ops) diff --git a/pymongo/asynchronous/change_stream.py b/pymongo/asynchronous/change_stream.py index 6c37f9d05f..3940111df2 100644 --- a/pymongo/asynchronous/change_stream.py +++ b/pymongo/asynchronous/change_stream.py @@ -164,7 +164,7 @@ def _aggregation_command_class(self) -> Type[_AggregationCommand]: raise NotImplementedError @property - def _client(self) -> AsyncMongoClient: + def _client(self) -> AsyncMongoClient: # type: ignore[type-arg] """The client against which the aggregation commands for this AsyncChangeStream will be run. """ @@ -206,7 +206,7 @@ def _command_options(self) -> dict[str, Any]: def _aggregation_pipeline(self) -> list[dict[str, Any]]: """Return the full aggregation pipeline for this AsyncChangeStream.""" options = self._change_stream_options() - full_pipeline: list = [{"$changeStream": options}] + full_pipeline: list[dict[str, Any]] = [{"$changeStream": options}] full_pipeline.extend(self._pipeline) return full_pipeline @@ -237,7 +237,7 @@ def _process_result(self, result: Mapping[str, Any], conn: AsyncConnection) -> N async def _run_aggregation_cmd( self, session: Optional[AsyncClientSession], explicit_session: bool - ) -> AsyncCommandCursor: + ) -> AsyncCommandCursor: # type: ignore[type-arg] """Run the full aggregation pipeline for this AsyncChangeStream and return the corresponding AsyncCommandCursor. """ @@ -257,7 +257,7 @@ async def _run_aggregation_cmd( operation=_Op.AGGREGATE, ) - async def _create_cursor(self) -> AsyncCommandCursor: + async def _create_cursor(self) -> AsyncCommandCursor: # type: ignore[type-arg] async with self._client._tmp_session(self._session, close=False) as s: return await self._run_aggregation_cmd( session=s, explicit_session=self._session is not None diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 5f7ac013e9..45812b3400 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -88,7 +88,7 @@ class _AsyncClientBulk: def __init__( self, - client: AsyncMongoClient, + client: AsyncMongoClient[Any], write_concern: WriteConcern, ordered: bool = True, bypass_document_validation: Optional[bool] = None, @@ -233,7 +233,7 @@ async def write_command( msg: Union[bytes, dict[str, Any]], op_docs: list[Mapping[str, Any]], ns_docs: list[Mapping[str, Any]], - client: AsyncMongoClient, + client: AsyncMongoClient[Any], ) -> dict[str, Any]: """A proxy for AsyncConnection.write_command that handles event publishing.""" cmd["ops"] = op_docs @@ -324,7 +324,7 @@ async def unack_write( msg: bytes, op_docs: list[Mapping[str, Any]], ns_docs: list[Mapping[str, Any]], - client: AsyncMongoClient, + client: AsyncMongoClient[Any], ) -> Optional[Mapping[str, Any]]: """A proxy for AsyncConnection.unack_write that handles event publishing.""" if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 1225445710..c30fc6679f 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -396,7 +396,7 @@ class _TxnState: class _Transaction: """Internal class to hold transaction information in a AsyncClientSession.""" - def __init__(self, opts: Optional[TransactionOptions], client: AsyncMongoClient): + def __init__(self, opts: Optional[TransactionOptions], client: AsyncMongoClient[Any]): self.opts = opts self.state = _TxnState.NONE self.sharded = False @@ -459,7 +459,7 @@ def _max_time_expired_error(exc: PyMongoError) -> bool: # From the transactions spec, all the retryable writes errors plus # WriteConcernTimeout. -_UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( +_UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( # type: ignore[type-arg] [ 64, # WriteConcernTimeout 50, # MaxTimeMSExpired @@ -499,13 +499,13 @@ class AsyncClientSession: def __init__( self, - client: AsyncMongoClient, + client: AsyncMongoClient[Any], server_session: Any, options: SessionOptions, implicit: bool, ) -> None: # An AsyncMongoClient, a _ServerSession, a SessionOptions, and a set. - self._client: AsyncMongoClient = client + self._client: AsyncMongoClient[Any] = client self._server_session = server_session self._options = options self._cluster_time: Optional[Mapping[str, Any]] = None @@ -551,7 +551,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self._end_session(lock=True) @property - def client(self) -> AsyncMongoClient: + def client(self) -> AsyncMongoClient[Any]: """The :class:`~pymongo.asynchronous.mongo_client.AsyncMongoClient` this session was created from. """ @@ -751,7 +751,7 @@ async def start_transaction( write_concern: Optional[WriteConcern] = None, read_preference: Optional[_ServerMode] = None, max_commit_time_ms: Optional[int] = None, - ) -> AsyncContextManager: + ) -> AsyncContextManager[Any]: """Start a multi-statement transaction. Takes the same arguments as :class:`TransactionOptions`. @@ -1123,7 +1123,7 @@ def inc_transaction_id(self) -> None: self._transaction_id += 1 -class _ServerSessionPool(collections.deque): +class _ServerSessionPool(collections.deque): # type: ignore[type-arg] """Pool of _ServerSession objects. This class is thread-safe. diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 7fb20b7ab3..313c8c7c04 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -581,7 +581,7 @@ async def _command( conn: AsyncConnection, command: MutableMapping[str, Any], read_preference: Optional[_ServerMode] = None, - codec_options: Optional[CodecOptions] = None, + codec_options: Optional[CodecOptions[Mapping[str, Any]]] = None, check: bool = True, allowable_errors: Optional[Sequence[Union[str, int]]] = None, read_concern: Optional[ReadConcern] = None, @@ -704,7 +704,7 @@ async def bulk_write( bypass_document_validation: Optional[bool] = None, session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, - let: Optional[Mapping] = None, + let: Optional[Mapping[str, Any]] = None, ) -> BulkWriteResult: """Send a batch of write operations to the server. @@ -2525,7 +2525,7 @@ async def _list_indexes( session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, ) -> AsyncCommandCursor[MutableMapping[str, Any]]: - codec_options: CodecOptions = CodecOptions(SON) + codec_options: CodecOptions[Mapping[str, Any]] = CodecOptions(SON) coll = cast( AsyncCollection[MutableMapping[str, Any]], self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY), @@ -2871,7 +2871,7 @@ async def _aggregate( self, aggregation_command: Type[_AggregationCommand], pipeline: _Pipeline, - cursor_class: Type[AsyncCommandCursor], + cursor_class: Type[AsyncCommandCursor], # type: ignore[type-arg] session: Optional[AsyncClientSession], explicit_session: bool, let: Optional[Mapping[str, Any]] = None, @@ -3114,7 +3114,7 @@ async def distinct( comment: Optional[Any] = None, hint: Optional[_IndexKeyHint] = None, **kwargs: Any, - ) -> list: + ) -> list[str]: """Get a list of distinct values for `key` among all documents in this collection. @@ -3177,7 +3177,7 @@ async def _cmd( _server: Server, conn: AsyncConnection, read_preference: Optional[_ServerMode], - ) -> list: + ) -> list: # type: ignore[type-arg] return ( await self._command( conn, @@ -3202,7 +3202,7 @@ async def _find_and_modify( array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, session: Optional[AsyncClientSession] = None, - let: Optional[Mapping] = None, + let: Optional[Mapping[str, Any]] = None, **kwargs: Any, ) -> Any: """Internal findAndModify helper.""" diff --git a/pymongo/asynchronous/command_cursor.py b/pymongo/asynchronous/command_cursor.py index 353c5e91c2..db7c2b6638 100644 --- a/pymongo/asynchronous/command_cursor.py +++ b/pymongo/asynchronous/command_cursor.py @@ -350,7 +350,7 @@ async def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]: else: return None - async def _next_batch(self, result: list, total: Optional[int] = None) -> bool: + async def _next_batch(self, result: list, total: Optional[int] = None) -> bool: # type: ignore[type-arg] """Get all or some available documents from the cursor.""" if not len(self._data) and not self._killed: await self._refresh() @@ -457,7 +457,7 @@ def _unpack_response( # type: ignore[override] self, response: Union[_OpReply, _OpMsg], cursor_id: Optional[int], - codec_options: CodecOptions, + codec_options: CodecOptions[dict[str, Any]], user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, ) -> list[Mapping[str, Any]]: diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 51efab4f43..ab2d0e873c 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -216,7 +216,7 @@ def __init__( # it anytime we change __limit. self._empty = False - self._data: deque = deque() + self._data: deque = deque() # type: ignore[type-arg] self._address: Optional[_Address] = None self._retrieved = 0 @@ -280,7 +280,7 @@ def clone(self) -> AsyncCursor[_DocumentType]: """ return self._clone(True) - def _clone(self, deepcopy: bool = True, base: Optional[AsyncCursor] = None) -> AsyncCursor: + def _clone(self, deepcopy: bool = True, base: Optional[AsyncCursor] = None) -> AsyncCursor: # type: ignore[type-arg] """Internal clone helper.""" if not base: if self._explicit_session: @@ -322,7 +322,7 @@ def _clone(self, deepcopy: bool = True, base: Optional[AsyncCursor] = None) -> A base.__dict__.update(data) return base - def _clone_base(self, session: Optional[AsyncClientSession]) -> AsyncCursor: + def _clone_base(self, session: Optional[AsyncClientSession]) -> AsyncCursor: # type: ignore[type-arg] """Creates an empty AsyncCursor object for information to be copied into.""" return self.__class__(self._collection, session=session) @@ -864,7 +864,7 @@ def where(self, code: Union[str, Code]) -> AsyncCursor[_DocumentType]: if self._has_filter: spec = dict(self._spec) else: - spec = cast(dict, self._spec) + spec = cast(dict, self._spec) # type: ignore[type-arg] spec["$where"] = code self._spec = spec return self @@ -888,7 +888,7 @@ def _unpack_response( self, response: Union[_OpReply, _OpMsg], cursor_id: Optional[int], - codec_options: CodecOptions, + codec_options: CodecOptions, # type: ignore[type-arg] user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, ) -> Sequence[_DocumentOut]: @@ -964,29 +964,33 @@ def __deepcopy__(self, memo: Any) -> Any: return self._clone(deepcopy=True) @overload - def _deepcopy(self, x: Iterable, memo: Optional[dict[int, Union[list, dict]]] = None) -> list: + def _deepcopy(self, x: Iterable, memo: Optional[dict[int, Union[list, dict]]] = None) -> list: # type: ignore[type-arg] ... @overload def _deepcopy( - self, x: SupportsItems, memo: Optional[dict[int, Union[list, dict]]] = None - ) -> dict: + self, + x: SupportsItems, # type: ignore[type-arg] + memo: Optional[dict[int, Union[list, dict]]] = None, # type: ignore[type-arg] + ) -> dict: # type: ignore[type-arg] ... def _deepcopy( - self, x: Union[Iterable, SupportsItems], memo: Optional[dict[int, Union[list, dict]]] = None - ) -> Union[list, dict]: + self, + x: Union[Iterable, SupportsItems], # type: ignore[type-arg] + memo: Optional[dict[int, Union[list, dict]]] = None, # type: ignore[type-arg] + ) -> Union[list[Any], dict[str, Any]]: """Deepcopy helper for the data dictionary or list. Regular expressions cannot be deep copied but as they are immutable we don't have to copy them when cloning. """ - y: Union[list, dict] + y: Union[list[Any], dict[str, Any]] iterator: Iterable[tuple[Any, Any]] if not hasattr(x, "items"): y, is_list, iterator = [], True, enumerate(x) else: - y, is_list, iterator = {}, False, cast("SupportsItems", x).items() + y, is_list, iterator = {}, False, cast("SupportsItems", x).items() # type: ignore[type-arg] if memo is None: memo = {} val_id = id(x) @@ -1060,7 +1064,7 @@ async def close(self) -> None: """Explicitly close / kill this cursor.""" await self._die_lock() - async def distinct(self, key: str) -> list: + async def distinct(self, key: str) -> list[str]: """Get a list of distinct values for `key` among all documents in the result set of this query. @@ -1265,7 +1269,7 @@ async def next(self) -> _DocumentType: else: raise StopAsyncIteration - async def _next_batch(self, result: list, total: Optional[int] = None) -> bool: + async def _next_batch(self, result: list, total: Optional[int] = None) -> bool: # type: ignore[type-arg] """Get all or some documents from the cursor.""" if not self._exhaust_checked: self._exhaust_checked = True @@ -1325,7 +1329,7 @@ async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]: return res -class AsyncRawBatchCursor(AsyncCursor, Generic[_DocumentType]): +class AsyncRawBatchCursor(AsyncCursor, Generic[_DocumentType]): # type: ignore[type-arg] """An asynchronous cursor / iterator over raw batches of BSON data from a query result.""" _query_class = _RawBatchQuery diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index d0089eb4ee..09713c37ec 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -771,7 +771,7 @@ async def _command( self._name, command, read_preference, - codec_options, + codec_options, # type: ignore[arg-type] check, allowable_errors, write_concern=write_concern, diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 3488030166..b988120d7c 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -161,10 +161,10 @@ _IS_SYNC = False _WriteOp = Union[ - InsertOne, + InsertOne, # type: ignore[type-arg] DeleteOne, DeleteMany, - ReplaceOne, + ReplaceOne, # type: ignore[type-arg] UpdateOne, UpdateMany, ] @@ -176,7 +176,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): # Define order to retrieve options from ClientOptions for __repr__. # No host/port; these are retrieved from TopologySettings. _constructor_args = ("document_class", "tz_aware", "connect") - _clients: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + _clients: weakref.WeakValueDictionary = weakref.WeakValueDictionary() # type: ignore[type-arg] def __init__( self, @@ -847,7 +847,7 @@ def __init__( self._default_database_name = dbase self._lock = _async_create_lock() - self._kill_cursors_queue: list = [] + self._kill_cursors_queue: list = [] # type: ignore[type-arg] self._encrypter: Optional[_Encrypter] = None @@ -1064,7 +1064,7 @@ def _after_fork(self) -> None: # Reset the session pool to avoid duplicate sessions in the child process. self._topology._session_pool.reset() - def _duplicate(self, **kwargs: Any) -> AsyncMongoClient: + def _duplicate(self, **kwargs: Any) -> AsyncMongoClient: # type: ignore[type-arg] args = self._init_kwargs.copy() args.update(kwargs) return AsyncMongoClient(**args) @@ -1548,7 +1548,7 @@ def get_database( self, name, codec_options, read_preference, write_concern, read_concern ) - def _database_default_options(self, name: str) -> database.AsyncDatabase: + def _database_default_options(self, name: str) -> database.AsyncDatabase: # type: ignore[type-arg] """Get a AsyncDatabase instance with the default settings.""" return self.get_database( name, @@ -1887,7 +1887,7 @@ async def _conn_for_reads( async def _run_operation( self, operation: Union[_Query, _GetMore], - unpack_res: Callable, + unpack_res: Callable, # type: ignore[type-arg] address: Optional[_Address] = None, ) -> Response: """Run a _Query/_GetMore operation and return a Response. @@ -2261,7 +2261,7 @@ def _return_server_session( @contextlib.asynccontextmanager async def _tmp_session( self, session: Optional[client_session.AsyncClientSession], close: bool = True - ) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None, None]: + ) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None]: """If provided session is None, lend a temporary session.""" if session is not None: if not isinstance(session, client_session.AsyncClientSession): @@ -2308,8 +2308,8 @@ async def server_info( .. versionchanged:: 3.6 Added ``session`` parameter. """ - return cast( - dict, + return cast( # type: ignore[redundant-cast] + dict[str, Any], await self.admin.command( "buildinfo", read_preference=ReadPreference.PRIMARY, session=session ), @@ -2438,13 +2438,13 @@ async def drop_database( @_csot.apply async def bulk_write( self, - models: Sequence[_WriteOp[_DocumentType]], + models: Sequence[_WriteOp], session: Optional[AsyncClientSession] = None, ordered: bool = True, verbose_results: bool = False, bypass_document_validation: Optional[bool] = None, comment: Optional[Any] = None, - let: Optional[Mapping] = None, + let: Optional[Mapping[str, Any]] = None, write_concern: Optional[WriteConcern] = None, ) -> ClientBulkWriteResult: """Send a batch of write operations, potentially across multiple namespaces, to the server. @@ -2631,7 +2631,10 @@ class _MongoClientErrorHandler: ) def __init__( - self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession] + self, + client: AsyncMongoClient, # type: ignore[type-arg] + server: Server, + session: Optional[AsyncClientSession], ): if not isinstance(client, AsyncMongoClient): # This is for compatibility with mocked and subclassed types, such as in Motor. @@ -2704,7 +2707,7 @@ class _ClientConnectionRetryable(Generic[T]): def __init__( self, - mongo_client: AsyncMongoClient, + mongo_client: AsyncMongoClient, # type: ignore[type-arg] func: _WriteCall[T] | _ReadCall[T], bulk: Optional[Union[_AsyncBulk, _AsyncClientBulk]], operation: str, diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index 32b545380a..e067bd8c54 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -351,7 +351,7 @@ async def _check_once(self) -> ServerDescription: ) return sd - async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float]: + async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float]: # type: ignore[type-arg] """Return (Hello, round_trip_time). Can raise ConnectionFailure or OperationFailure. diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py index 1605efe92d..5a5dc7fa2c 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/network.py @@ -66,7 +66,7 @@ async def command( read_preference: Optional[_ServerMode], codec_options: CodecOptions[_DocumentType], session: Optional[AsyncClientSession], - client: Optional[AsyncMongoClient], + client: Optional[AsyncMongoClient[Any]], check: bool = True, allowable_errors: Optional[Sequence[Union[str, int]]] = None, address: Optional[_Address] = None, diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 9a39883fc2..e215cafdc1 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -201,7 +201,7 @@ def set_conn_timeout(self, timeout: Optional[float]) -> None: self.conn.get_conn.settimeout(timeout) def apply_timeout( - self, client: AsyncMongoClient, cmd: Optional[MutableMapping[str, Any]] + self, client: AsyncMongoClient[Any], cmd: Optional[MutableMapping[str, Any]] ) -> Optional[float]: # CSOT: use remaining timeout when set. timeout = _csot.remaining() @@ -255,7 +255,7 @@ def hello_cmd(self) -> dict[str, Any]: else: return {HelloCompat.LEGACY_CMD: 1, "helloOk": True} - async def hello(self) -> Hello: + async def hello(self) -> Hello[dict[str, Any]]: return await self._hello(None, None) async def _hello( @@ -357,7 +357,7 @@ async def command( dbname: str, spec: MutableMapping[str, Any], read_preference: _ServerMode = ReadPreference.PRIMARY, - codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS, + codec_options: CodecOptions[Mapping[str, Any]] = DEFAULT_CODEC_OPTIONS, # type: ignore[assignment] check: bool = True, allowable_errors: Optional[Sequence[Union[str, int]]] = None, read_concern: Optional[ReadConcern] = None, @@ -365,7 +365,7 @@ async def command( parse_write_concern_error: bool = False, collation: Optional[_CollationIn] = None, session: Optional[AsyncClientSession] = None, - client: Optional[AsyncMongoClient] = None, + client: Optional[AsyncMongoClient[Any]] = None, retryable_write: bool = False, publish_events: bool = True, user_fields: Optional[Mapping[str, Any]] = None, @@ -417,7 +417,7 @@ async def command( spec, self.is_mongos, read_preference, - codec_options, + codec_options, # type: ignore[arg-type] session, client, check, @@ -489,7 +489,7 @@ async def unack_write(self, msg: bytes, max_doc_size: int) -> None: await self.send_message(msg, max_doc_size) async def write_command( - self, request_id: int, msg: bytes, codec_options: CodecOptions + self, request_id: int, msg: bytes, codec_options: CodecOptions[Mapping[str, Any]] ) -> dict[str, Any]: """Send "insert" etc. command, returning response as a dict. @@ -541,7 +541,7 @@ async def authenticate(self, reauthenticate: bool = False) -> None: ) def validate_session( - self, client: Optional[AsyncMongoClient], session: Optional[AsyncClientSession] + self, client: Optional[AsyncMongoClient[Any]], session: Optional[AsyncClientSession] ) -> None: """Validate this session before use with client. @@ -598,7 +598,7 @@ def send_cluster_time( self, command: MutableMapping[str, Any], session: Optional[AsyncClientSession], - client: Optional[AsyncMongoClient], + client: Optional[AsyncMongoClient[Any]], ) -> None: """Add $clusterTime.""" if client: @@ -732,7 +732,7 @@ def __init__( # LIFO pool. Sockets are ordered on idle time. Sockets claimed # and returned to pool from the left side. Stale sockets removed # from the right side. - self.conns: collections.deque = collections.deque() + self.conns: collections.deque[AsyncConnection] = collections.deque() self.active_contexts: set[_CancellationContext] = set() self.lock = _async_create_lock() self._max_connecting_cond = _async_create_condition(self.lock) @@ -839,8 +839,8 @@ async def _reset( if service_id is None: sockets, self.conns = self.conns, collections.deque() else: - discard: collections.deque = collections.deque() - keep: collections.deque = collections.deque() + discard: collections.deque = collections.deque() # type: ignore[type-arg] + keep: collections.deque = collections.deque() # type: ignore[type-arg] for conn in self.conns: if conn.service_id == service_id: discard.append(conn) @@ -866,7 +866,7 @@ async def _reset( if close: if not _IS_SYNC: await asyncio.gather( - *[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], + *[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], # type: ignore[func-returns-value] return_exceptions=True, ) else: @@ -903,7 +903,7 @@ async def _reset( ) if not _IS_SYNC: await asyncio.gather( - *[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], + *[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], # type: ignore[func-returns-value] return_exceptions=True, ) else: @@ -917,7 +917,7 @@ async def update_is_writable(self, is_writable: Optional[bool]) -> None: self.is_writable = is_writable async with self.lock: for _socket in self.conns: - _socket.update_is_writable(self.is_writable) + _socket.update_is_writable(self.is_writable) # type: ignore[arg-type] async def reset( self, service_id: Optional[ObjectId] = None, interrupt_connections: bool = False @@ -956,7 +956,7 @@ async def remove_stale_sockets(self, reference_generation: int) -> None: close_conns.append(self.conns.pop()) if not _IS_SYNC: await asyncio.gather( - *[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns], + *[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns], # type: ignore[func-returns-value] return_exceptions=True, ) else: @@ -1477,4 +1477,4 @@ def __del__(self) -> None: # not safe to acquire a lock in __del__. if _IS_SYNC: for conn in self.conns: - conn.close_conn(None) + conn.close_conn(None) # type: ignore[unused-coroutine] diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index 0e0d53b96f..0f8565f6cc 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -66,7 +66,7 @@ def __init__( monitor: Monitor, topology_id: Optional[ObjectId] = None, listeners: Optional[_EventListeners] = None, - events: Optional[ReferenceType[Queue]] = None, + events: Optional[ReferenceType[Queue[Any]]] = None, ) -> None: """Represent one MongoDB server.""" self._description = server_description @@ -142,7 +142,7 @@ async def run_operation( read_preference: _ServerMode, listeners: Optional[_EventListeners], unpack_res: Callable[..., list[_DocumentOut]], - client: AsyncMongoClient, + client: AsyncMongoClient[Any], ) -> Response: """Run a _Query or _GetMore operation and return a Response object. diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 052f91afee..283aabc690 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -84,7 +84,7 @@ _pymongo_dir = str(Path(__file__).parent) -def process_events_queue(queue_ref: weakref.ReferenceType[queue.Queue]) -> bool: +def process_events_queue(queue_ref: weakref.ReferenceType[queue.Queue]) -> bool: # type: ignore[type-arg] q = queue_ref() if not q: return False # Cancel PeriodicExecutor. @@ -186,7 +186,7 @@ def __init__(self, topology_settings: TopologySettings): if self._publish_server or self._publish_tp: assert self._events is not None - weak: weakref.ReferenceType[queue.Queue] + weak: weakref.ReferenceType[queue.Queue[Any]] async def target() -> bool: return process_events_queue(weak) diff --git a/pymongo/client_options.py b/pymongo/client_options.py index bd27dd4eb0..8b4eea7e65 100644 --- a/pymongo/client_options.py +++ b/pymongo/client_options.py @@ -247,7 +247,7 @@ def connect(self) -> Optional[bool]: return self.__connect @property - def codec_options(self) -> CodecOptions: + def codec_options(self) -> CodecOptions[Any]: """A :class:`~bson.codec_options.CodecOptions` instance.""" return self.__codec_options diff --git a/pymongo/common.py b/pymongo/common.py index 96f9f87459..5210e72189 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -56,7 +56,7 @@ from pymongo.typings import _AgnosticClientSession -ORDERED_TYPES: Sequence[Type] = (SON, OrderedDict) +ORDERED_TYPES: Sequence[Type[Any]] = (SON, OrderedDict) # Defaults until we connect to a server and get updated limits. MAX_BSON_SIZE = 16 * (1024**2) @@ -166,7 +166,7 @@ def clean_node(node: str) -> tuple[str, int]: return host.lower(), port -def raise_config_error(key: str, suggestions: Optional[list] = None) -> NoReturn: +def raise_config_error(key: str, suggestions: Optional[list[str]] = None) -> NoReturn: """Raise ConfigurationError with the given key name.""" msg = f"Unknown option: {key}." if suggestions: @@ -411,7 +411,7 @@ def validate_read_preference_tags(name: str, value: Any) -> list[dict[str, str]] if not isinstance(value, list): value = [value] - tag_sets: list = [] + tag_sets: list[dict[str, Any]] = [] for tag_set in value: if tag_set == "": tag_sets.append({}) @@ -497,7 +497,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni def validate_document_class( option: str, value: Any -) -> Union[Type[MutableMapping], Type[RawBSONDocument]]: +) -> Union[Type[MutableMapping[str, Any]], Type[RawBSONDocument]]: """Validate the document_class option.""" # issubclass can raise TypeError for generic aliases like SON[str, Any]. # In that case we can use the base class for the comparison. @@ -523,14 +523,14 @@ def validate_type_registry(option: Any, value: Any) -> Optional[TypeRegistry]: return value -def validate_list(option: str, value: Any) -> list: +def validate_list(option: str, value: Any) -> list[Any]: """Validates that 'value' is a list.""" if not isinstance(value, list): raise TypeError(f"{option} must be a list, not {type(value)}") return value -def validate_list_or_none(option: Any, value: Any) -> Optional[list]: +def validate_list_or_none(option: Any, value: Any) -> Optional[list[Any]]: """Validates that 'value' is a list or None.""" if value is None: return value @@ -597,7 +597,7 @@ def validate_server_api_or_none(option: Any, value: Any) -> Optional[ServerApi]: return value -def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable]: +def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable[..., Any]]: """Validates that 'value' is a callable.""" if value is None: return value @@ -829,7 +829,7 @@ def validate_auth_option(option: str, value: Any) -> tuple[str, Any]: def _get_validator( key: str, validators: dict[str, Callable[[Any, Any], Any]], normed_key: Optional[str] = None -) -> Callable: +) -> Callable[[Any, Any], Any]: normed_key = normed_key or key try: return validators[normed_key] @@ -917,7 +917,7 @@ class BaseObject: def __init__( self, - codec_options: CodecOptions, + codec_options: CodecOptions[Any], read_preference: _ServerMode, write_concern: WriteConcern, read_concern: ReadConcern, @@ -947,7 +947,7 @@ def __init__( self._read_concern = read_concern @property - def codec_options(self) -> CodecOptions: + def codec_options(self) -> CodecOptions[Any]: """Read only access to the :class:`~bson.codec_options.CodecOptions` of this instance. """ diff --git a/pymongo/encryption_options.py b/pymongo/encryption_options.py index e9ad1c1e01..cf686f6ab5 100644 --- a/pymongo/encryption_options.py +++ b/pymongo/encryption_options.py @@ -37,7 +37,7 @@ if TYPE_CHECKING: from pymongo.pyopenssl_context import SSLContext - from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg + from pymongo.typings import _AgnosticMongoClient class AutoEncryptionOpts: @@ -47,7 +47,7 @@ def __init__( self, kms_providers: Mapping[str, Any], key_vault_namespace: str, - key_vault_client: Optional[_AgnosticMongoClient[_DocumentTypeArg]] = None, + key_vault_client: Optional[_AgnosticMongoClient] = None, schema_map: Optional[Mapping[str, Any]] = None, bypass_auto_encryption: bool = False, mongocryptd_uri: str = "mongodb://localhost:27020", diff --git a/pymongo/helpers_shared.py b/pymongo/helpers_shared.py index a664e87a69..9646c0691a 100644 --- a/pymongo/helpers_shared.py +++ b/pymongo/helpers_shared.py @@ -52,7 +52,7 @@ # From the SDAM spec, the "node is shutting down" codes. -_SHUTDOWN_CODES: frozenset = frozenset( +_SHUTDOWN_CODES: frozenset[int] = frozenset( [ 11600, # InterruptedAtShutdown 91, # ShutdownInProgress @@ -61,7 +61,7 @@ # From the SDAM spec, the "not primary" error codes are combined with the # "node is recovering" error codes (of which the "node is shutting down" # errors are a subset). -_NOT_PRIMARY_CODES: frozenset = ( +_NOT_PRIMARY_CODES: frozenset[int] = ( frozenset( [ 10058, # LegacyNotPrimary <=3.2 "not primary" error code @@ -75,7 +75,7 @@ | _SHUTDOWN_CODES ) # From the retryable writes spec. -_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset( +_RETRYABLE_ERROR_CODES: frozenset[int] = _NOT_PRIMARY_CODES | frozenset( [ 7, # HostNotFound 6, # HostUnreachable @@ -95,7 +95,7 @@ # Note - to avoid bugs from forgetting which if these is all lowercase and # which are camelCase, and at the same time avoid having to add a test for # every command, use all lowercase here and test against command_name.lower(). -_SENSITIVE_COMMANDS: set = { +_SENSITIVE_COMMANDS: set[str] = { "authenticate", "saslstart", "saslcontinue", diff --git a/pymongo/message.py b/pymongo/message.py index d51c77a174..b2e5a685af 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -333,7 +333,7 @@ def _op_msg_no_header( command: Mapping[str, Any], identifier: str, docs: Optional[list[Mapping[str, Any]]], - opts: CodecOptions, + opts: CodecOptions[Any], ) -> tuple[bytes, int, int]: """Get a OP_MSG message. @@ -365,7 +365,7 @@ def _op_msg_compressed( command: Mapping[str, Any], identifier: str, docs: Optional[list[Mapping[str, Any]]], - opts: CodecOptions, + opts: CodecOptions[Any], ctx: Union[SnappyContext, ZlibContext, ZstdContext], ) -> tuple[int, bytes, int, int]: """Internal OP_MSG message helper.""" @@ -379,7 +379,7 @@ def _op_msg_uncompressed( command: Mapping[str, Any], identifier: str, docs: Optional[list[Mapping[str, Any]]], - opts: CodecOptions, + opts: CodecOptions[Any], ) -> tuple[int, bytes, int, int]: """Internal compressed OP_MSG message helper.""" data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) @@ -396,7 +396,7 @@ def _op_msg( command: MutableMapping[str, Any], dbname: str, read_preference: Optional[_ServerMode], - opts: CodecOptions, + opts: CodecOptions[Any], ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, ) -> tuple[int, bytes, int, int]: """Get a OP_MSG message.""" @@ -430,7 +430,7 @@ def _query_impl( num_to_return: int, query: Mapping[str, Any], field_selector: Optional[Mapping[str, Any]], - opts: CodecOptions, + opts: CodecOptions[Any], ) -> tuple[bytes, int]: """Get an OP_QUERY message.""" encoded = _dict_to_bson(query, False, opts) @@ -461,7 +461,7 @@ def _query_compressed( num_to_return: int, query: Mapping[str, Any], field_selector: Optional[Mapping[str, Any]], - opts: CodecOptions, + opts: CodecOptions[Any], ctx: Union[SnappyContext, ZlibContext, ZstdContext], ) -> tuple[int, bytes, int]: """Internal compressed query message helper.""" @@ -479,7 +479,7 @@ def _query_uncompressed( num_to_return: int, query: Mapping[str, Any], field_selector: Optional[Mapping[str, Any]], - opts: CodecOptions, + opts: CodecOptions[Any], ) -> tuple[int, bytes, int]: """Internal query message helper.""" op_query, max_bson_size = _query_impl( @@ -500,7 +500,7 @@ def _query( num_to_return: int, query: Mapping[str, Any], field_selector: Optional[Mapping[str, Any]], - opts: CodecOptions, + opts: CodecOptions[Any], ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, ) -> tuple[int, bytes, int]: """Get a **query** message.""" @@ -598,7 +598,7 @@ def __init__( listeners: _EventListeners, session: Optional[_AgnosticClientSession], op_type: int, - codec: CodecOptions, + codec: CodecOptions[Any], ): self.db_name = database_name self.conn = conn @@ -679,7 +679,7 @@ def __init__( listeners: _EventListeners, session: Optional[_AgnosticClientSession], op_type: int, - codec: CodecOptions, + codec: CodecOptions[Any], ): super().__init__( database_name, @@ -771,7 +771,7 @@ def _batched_op_msg_impl( command: Mapping[str, Any], docs: list[Mapping[str, Any]], ack: bool, - opts: CodecOptions, + opts: CodecOptions[Any], ctx: _BulkWriteContext, buf: _BytesIO, ) -> tuple[list[Mapping[str, Any]], int]: @@ -839,7 +839,7 @@ def _encode_batched_op_msg( command: Mapping[str, Any], docs: list[Mapping[str, Any]], ack: bool, - opts: CodecOptions, + opts: CodecOptions[Any], ctx: _BulkWriteContext, ) -> tuple[bytes, list[Mapping[str, Any]]]: """Encode the next batched insert, update, or delete operation @@ -860,7 +860,7 @@ def _batched_op_msg_compressed( command: Mapping[str, Any], docs: list[Mapping[str, Any]], ack: bool, - opts: CodecOptions, + opts: CodecOptions[Any], ctx: _BulkWriteContext, ) -> tuple[int, bytes, list[Mapping[str, Any]]]: """Create the next batched insert, update, or delete operation @@ -878,7 +878,7 @@ def _batched_op_msg( command: Mapping[str, Any], docs: list[Mapping[str, Any]], ack: bool, - opts: CodecOptions, + opts: CodecOptions[Any], ctx: _BulkWriteContext, ) -> tuple[int, bytes, list[Mapping[str, Any]]]: """OP_MSG implementation entry point.""" @@ -910,7 +910,7 @@ def _do_batched_op_msg( operation: int, command: MutableMapping[str, Any], docs: list[Mapping[str, Any]], - opts: CodecOptions, + opts: CodecOptions[Any], ctx: _BulkWriteContext, ) -> tuple[int, bytes, list[Mapping[str, Any]]]: """Create the next batched insert, update, or delete operation @@ -939,7 +939,7 @@ def __init__( operation_id: int, listeners: _EventListeners, session: Optional[_AgnosticClientSession], - codec: CodecOptions, + codec: CodecOptions[Any], ): super().__init__( database_name, @@ -1043,7 +1043,7 @@ def _client_batched_op_msg_impl( operations: list[tuple[str, Mapping[str, Any]]], namespaces: list[str], ack: bool, - opts: CodecOptions, + opts: CodecOptions[Any], ctx: _ClientBulkWriteContext, buf: _BytesIO, ) -> tuple[list[Mapping[str, Any]], list[Mapping[str, Any]], int]: @@ -1161,7 +1161,7 @@ def _client_encode_batched_op_msg( operations: list[tuple[str, Mapping[str, Any]]], namespaces: list[str], ack: bool, - opts: CodecOptions, + opts: CodecOptions[Any], ctx: _ClientBulkWriteContext, ) -> tuple[bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]: """Encode the next batched client-level bulkWrite @@ -1180,7 +1180,7 @@ def _client_batched_op_msg_compressed( operations: list[tuple[str, Mapping[str, Any]]], namespaces: list[str], ack: bool, - opts: CodecOptions, + opts: CodecOptions[Any], ctx: _ClientBulkWriteContext, ) -> tuple[int, bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]: """Create the next batched client-level bulkWrite operation @@ -1200,7 +1200,7 @@ def _client_batched_op_msg( operations: list[tuple[str, Mapping[str, Any]]], namespaces: list[str], ack: bool, - opts: CodecOptions, + opts: CodecOptions[Any], ctx: _ClientBulkWriteContext, ) -> tuple[int, bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]: """OP_MSG implementation entry point for client-level bulkWrite.""" @@ -1229,7 +1229,7 @@ def _client_do_batched_op_msg( command: MutableMapping[str, Any], operations: list[tuple[str, Mapping[str, Any]]], namespaces: list[str], - opts: CodecOptions, + opts: CodecOptions[Any], ctx: _ClientBulkWriteContext, ) -> tuple[int, bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]: """Create the next batched client-level bulkWrite @@ -1253,7 +1253,7 @@ def _encode_batched_write_command( operation: int, command: MutableMapping[str, Any], docs: list[Mapping[str, Any]], - opts: CodecOptions, + opts: CodecOptions[Any], ctx: _BulkWriteContext, ) -> tuple[bytes, list[Mapping[str, Any]]]: """Encode the next batched insert, update, or delete command.""" @@ -1272,7 +1272,7 @@ def _batched_write_command_impl( operation: int, command: MutableMapping[str, Any], docs: list[Mapping[str, Any]], - opts: CodecOptions, + opts: CodecOptions[Any], ctx: _BulkWriteContext, buf: _BytesIO, ) -> tuple[list[Mapping[str, Any]], int]: @@ -1383,7 +1383,7 @@ def raw_response( errobj = {"ok": 0, "errmsg": msg, "code": 43} raise CursorNotFound(msg, 43, errobj) elif self.flags & 2: - error_object: dict = bson.BSON(self.documents).decode() + error_object: dict[str, Any] = bson.BSON(self.documents).decode() # Fake the ok field if it doesn't exist. error_object.setdefault("ok", 0) if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR): @@ -1405,7 +1405,7 @@ def raw_response( def unpack_response( self, cursor_id: Optional[int] = None, - codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, + codec_options: CodecOptions[Any] = _UNICODE_REPLACE_CODEC_OPTIONS, user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, ) -> list[dict[str, Any]]: @@ -1431,7 +1431,7 @@ def unpack_response( return bson.decode_all(self.documents, codec_options) return bson._decode_all_selective(self.documents, codec_options, user_fields) - def command_response(self, codec_options: CodecOptions) -> dict[str, Any]: + def command_response(self, codec_options: CodecOptions[Any]) -> dict[str, Any]: """Unpack a command response.""" docs = self.unpack_response(codec_options=codec_options) assert self.number_returned == 1 @@ -1491,7 +1491,7 @@ def raw_response( def unpack_response( self, cursor_id: Optional[int] = None, - codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, + codec_options: CodecOptions[Any] = _UNICODE_REPLACE_CODEC_OPTIONS, user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, ) -> list[dict[str, Any]]: @@ -1508,7 +1508,7 @@ def unpack_response( assert not legacy_response return bson._decode_all_selective(self.payload_document, codec_options, user_fields) - def command_response(self, codec_options: CodecOptions) -> dict[str, Any]: + def command_response(self, codec_options: CodecOptions[Any]) -> dict[str, Any]: """Unpack a command response.""" return self.unpack_response(codec_options=codec_options)[0] @@ -1583,7 +1583,7 @@ def __init__( ntoskip: int, spec: Mapping[str, Any], fields: Optional[Mapping[str, Any]], - codec_options: CodecOptions, + codec_options: CodecOptions[Any], read_preference: _ServerMode, limit: int, batch_size: int, @@ -1757,7 +1757,7 @@ def __init__( coll: str, ntoreturn: int, cursor_id: int, - codec_options: CodecOptions, + codec_options: CodecOptions[Any], read_preference: _ServerMode, session: Optional[_AgnosticClientSession], client: _AgnosticMongoClient, @@ -1871,7 +1871,7 @@ def use_command(self, conn: _AgnosticConnection) -> bool: return False -class _CursorAddress(tuple): +class _CursorAddress(tuple[Any, ...]): """The server address (host, port) of a cursor, with namespace property.""" __namespace: Any diff --git a/pymongo/monitoring.py b/pymongo/monitoring.py index 101a8fbc37..46a78aea0b 100644 --- a/pymongo/monitoring.py +++ b/pymongo/monitoring.py @@ -1347,7 +1347,11 @@ class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent): __slots__ = ("__duration", "__reply") def __init__( - self, duration: float, reply: Hello, connection_id: _Address, awaited: bool = False + self, + duration: float, + reply: Hello[dict[str, Any]], + connection_id: _Address, + awaited: bool = False, ) -> None: super().__init__(connection_id, awaited) self.__duration = duration @@ -1359,7 +1363,7 @@ def duration(self) -> float: return self.__duration @property - def reply(self) -> Hello: + def reply(self) -> Hello[dict[str, Any]]: """An instance of :class:`~pymongo.hello.Hello`.""" return self.__reply @@ -1647,7 +1651,7 @@ def publish_server_heartbeat_started(self, connection_id: _Address, awaited: boo _handle_exception() def publish_server_heartbeat_succeeded( - self, connection_id: _Address, duration: float, reply: Hello, awaited: bool + self, connection_id: _Address, duration: float, reply: Hello[dict[str, Any]], awaited: bool ) -> None: """Publish a ServerHeartbeatSucceededEvent to all server heartbeat listeners. diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 78eefc7177..2f7f9c320f 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -96,7 +96,7 @@ async def _async_socket_sendall_ssl( view = memoryview(buf) sent = 0 - def _is_ready(fut: Future) -> None: + def _is_ready(fut: Future[Any]) -> None: if fut.done(): return fut.set_result(None) @@ -139,7 +139,7 @@ async def _async_socket_receive_ssl( mv = memoryview(bytearray(length)) total_read = 0 - def _is_ready(fut: Future) -> None: + def _is_ready(fut: Future[Any]) -> None: if fut.done(): return fut.set_result(None) @@ -486,15 +486,15 @@ def __init__(self, timeout: Optional[float] = None): self._message_size = 0 self._op_code = 0 self._connection_lost = False - self._read_waiter: Optional[Future] = None + self._read_waiter: Optional[Future[Any]] = None self._timeout = timeout self._is_compressed = False self._compressor_id: Optional[int] = None self._max_message_size = MAX_MESSAGE_SIZE self._response_to: Optional[int] = None self._closed = asyncio.get_running_loop().create_future() - self._pending_messages: collections.deque[Future] = collections.deque() - self._done_messages: collections.deque[Future] = collections.deque() + self._pending_messages: collections.deque[Future[Any]] = collections.deque() + self._done_messages: collections.deque[Future[Any]] = collections.deque() def settimeout(self, timeout: float | None) -> None: self._timeout = timeout diff --git a/pymongo/periodic_executor.py b/pymongo/periodic_executor.py index ed369a2b21..82f506f039 100644 --- a/pymongo/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -53,7 +53,7 @@ def __init__( self._min_interval = min_interval self._target = target self._stopped = False - self._task: Optional[asyncio.Task] = None + self._task: Optional[asyncio.Task[Any]] = None self._name = name self._skip_sleep = False diff --git a/pymongo/server_description.py b/pymongo/server_description.py index afc5346bb7..d038c04b1c 100644 --- a/pymongo/server_description.py +++ b/pymongo/server_description.py @@ -69,7 +69,7 @@ class ServerDescription: def __init__( self, address: _Address, - hello: Optional[Hello] = None, + hello: Optional[Hello[dict[str, Any]]] = None, round_trip_time: Optional[float] = None, error: Optional[Exception] = None, min_round_trip_time: float = 0.0, @@ -299,4 +299,4 @@ def __repr__(self) -> str: ) # For unittesting only. Use under no circumstances! - _host_to_round_trip_time: dict = {} + _host_to_round_trip_time: dict = {} # type: ignore[type-arg] diff --git a/pymongo/ssl_support.py b/pymongo/ssl_support.py index beafc717eb..7dbd0f2148 100644 --- a/pymongo/ssl_support.py +++ b/pymongo/ssl_support.py @@ -56,17 +56,22 @@ if HAVE_PYSSL: PYSSLError: Any = _pyssl.SSLError - BLOCKING_IO_ERRORS: tuple = _ssl.BLOCKING_IO_ERRORS + _pyssl.BLOCKING_IO_ERRORS - BLOCKING_IO_READ_ERROR: tuple = (_pyssl.BLOCKING_IO_READ_ERROR, _ssl.BLOCKING_IO_READ_ERROR) - BLOCKING_IO_WRITE_ERROR: tuple = ( + BLOCKING_IO_ERRORS: tuple = ( # type: ignore[type-arg] + _ssl.BLOCKING_IO_ERRORS + _pyssl.BLOCKING_IO_ERRORS + ) + BLOCKING_IO_READ_ERROR: tuple = ( # type: ignore[type-arg] + _pyssl.BLOCKING_IO_READ_ERROR, + _ssl.BLOCKING_IO_READ_ERROR, + ) + BLOCKING_IO_WRITE_ERROR: tuple = ( # type: ignore[type-arg] _pyssl.BLOCKING_IO_WRITE_ERROR, _ssl.BLOCKING_IO_WRITE_ERROR, ) else: PYSSLError = _ssl.SSLError - BLOCKING_IO_ERRORS = _ssl.BLOCKING_IO_ERRORS - BLOCKING_IO_READ_ERROR = (_ssl.BLOCKING_IO_READ_ERROR,) - BLOCKING_IO_WRITE_ERROR = (_ssl.BLOCKING_IO_WRITE_ERROR,) + BLOCKING_IO_ERRORS: tuple = _ssl.BLOCKING_IO_ERRORS # type: ignore[type-arg, no-redef] + BLOCKING_IO_READ_ERROR: tuple = (_ssl.BLOCKING_IO_READ_ERROR,) # type: ignore[type-arg, no-redef] + BLOCKING_IO_WRITE_ERROR: tuple = (_ssl.BLOCKING_IO_WRITE_ERROR,) # type: ignore[type-arg, no-redef] SSLError = _ssl.SSLError BLOCKING_IO_LOOKUP_ERROR = BLOCKING_IO_READ_ERROR @@ -131,7 +136,7 @@ class SSLError(Exception): # type: ignore pass IPADDR_SAFE = False - BLOCKING_IO_ERRORS = () + BLOCKING_IO_ERRORS: tuple = () # type: ignore[type-arg, no-redef] def _has_sni(is_sync: bool) -> bool: # noqa: ARG001 return False diff --git a/pymongo/synchronous/aggregation.py b/pymongo/synchronous/aggregation.py index 3eb0c8bf54..9845f28b08 100644 --- a/pymongo/synchronous/aggregation.py +++ b/pymongo/synchronous/aggregation.py @@ -46,8 +46,8 @@ class _AggregationCommand: def __init__( self, - target: Union[Database, Collection], - cursor_class: type[CommandCursor], + target: Union[Database[Any], Collection[Any]], + cursor_class: type[CommandCursor[Any]], pipeline: _Pipeline, options: MutableMapping[str, Any], explicit_session: bool, @@ -111,12 +111,12 @@ def _cursor_namespace(self) -> str: """The namespace in which the aggregate command is run.""" raise NotImplementedError - def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> Collection: + def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> Collection[Any]: """The Collection used for the aggregate command cursor.""" raise NotImplementedError @property - def _database(self) -> Database: + def _database(self) -> Database[Any]: """The database against which the aggregation command is run.""" raise NotImplementedError @@ -205,7 +205,7 @@ def get_cursor( class _CollectionAggregationCommand(_AggregationCommand): - _target: Collection + _target: Collection[Any] @property def _aggregation_target(self) -> str: @@ -215,12 +215,12 @@ def _aggregation_target(self) -> str: def _cursor_namespace(self) -> str: return self._target.full_name - def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection: + def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection[Any]: """The Collection used for the aggregate command cursor.""" return self._target @property - def _database(self) -> Database: + def _database(self) -> Database[Any]: return self._target.database @@ -234,7 +234,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: class _DatabaseAggregationCommand(_AggregationCommand): - _target: Database + _target: Database[Any] @property def _aggregation_target(self) -> int: @@ -245,10 +245,10 @@ def _cursor_namespace(self) -> str: return f"{self._target.name}.$cmd.aggregate" @property - def _database(self) -> Database: + def _database(self) -> Database[Any]: return self._target - def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection: + def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection[Any]: """The Collection used for the aggregate command cursor.""" # Collection level aggregate may not always return the "ns" field # according to our MockupDB tests. Let's handle that case for db level diff --git a/pymongo/synchronous/auth_oidc.py b/pymongo/synchronous/auth_oidc.py index f4d754687d..583ee39f67 100644 --- a/pymongo/synchronous/auth_oidc.py +++ b/pymongo/synchronous/auth_oidc.py @@ -257,7 +257,7 @@ def _sasl_continue_jwt( ) -> Mapping[str, Any]: self.access_token = None self.refresh_token = None - start_payload: dict = bson.decode(start_resp["payload"]) + start_payload: dict[str, Any] = bson.decode(start_resp["payload"]) if "issuer" in start_payload: self.idp_info = OIDCIdPInfo(**start_payload) access_token = self._get_access_token() diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index a528b09add..22d6a7a76a 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -248,7 +248,7 @@ def write_command( request_id: int, msg: bytes, docs: list[Mapping[str, Any]], - client: MongoClient, + client: MongoClient[Any], ) -> dict[str, Any]: """A proxy for SocketInfo.write_command that handles event publishing.""" cmd[bwc.field] = docs @@ -334,7 +334,7 @@ def unack_write( msg: bytes, max_doc_size: int, docs: list[Mapping[str, Any]], - client: MongoClient, + client: MongoClient[Any], ) -> Optional[Mapping[str, Any]]: """A proxy for Connection.unack_write that handles event publishing.""" if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): @@ -419,7 +419,7 @@ def _execute_batch_unack( bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext], cmd: dict[str, Any], ops: list[Mapping[str, Any]], - client: MongoClient, + client: MongoClient[Any], ) -> list[Mapping[str, Any]]: if self.is_encrypted: _, batched_cmd, to_send = bwc.batch_command(cmd, ops) @@ -446,7 +446,7 @@ def _execute_batch( bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext], cmd: dict[str, Any], ops: list[Mapping[str, Any]], - client: MongoClient, + client: MongoClient[Any], ) -> tuple[dict[str, Any], list[Mapping[str, Any]]]: if self.is_encrypted: _, batched_cmd, to_send = bwc.batch_command(cmd, ops) diff --git a/pymongo/synchronous/change_stream.py b/pymongo/synchronous/change_stream.py index 304427b89b..f5f6352186 100644 --- a/pymongo/synchronous/change_stream.py +++ b/pymongo/synchronous/change_stream.py @@ -164,7 +164,7 @@ def _aggregation_command_class(self) -> Type[_AggregationCommand]: raise NotImplementedError @property - def _client(self) -> MongoClient: + def _client(self) -> MongoClient: # type: ignore[type-arg] """The client against which the aggregation commands for this ChangeStream will be run. """ @@ -206,7 +206,7 @@ def _command_options(self) -> dict[str, Any]: def _aggregation_pipeline(self) -> list[dict[str, Any]]: """Return the full aggregation pipeline for this ChangeStream.""" options = self._change_stream_options() - full_pipeline: list = [{"$changeStream": options}] + full_pipeline: list[dict[str, Any]] = [{"$changeStream": options}] full_pipeline.extend(self._pipeline) return full_pipeline @@ -237,7 +237,7 @@ def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None: def _run_aggregation_cmd( self, session: Optional[ClientSession], explicit_session: bool - ) -> CommandCursor: + ) -> CommandCursor: # type: ignore[type-arg] """Run the full aggregation pipeline for this ChangeStream and return the corresponding CommandCursor. """ @@ -257,7 +257,7 @@ def _run_aggregation_cmd( operation=_Op.AGGREGATE, ) - def _create_cursor(self) -> CommandCursor: + def _create_cursor(self) -> CommandCursor: # type: ignore[type-arg] with self._client._tmp_session(self._session, close=False) as s: return self._run_aggregation_cmd(session=s, explicit_session=self._session is not None) diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index d73bfb2a2b..1076ceba99 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -88,7 +88,7 @@ class _ClientBulk: def __init__( self, - client: MongoClient, + client: MongoClient[Any], write_concern: WriteConcern, ordered: bool = True, bypass_document_validation: Optional[bool] = None, @@ -233,7 +233,7 @@ def write_command( msg: Union[bytes, dict[str, Any]], op_docs: list[Mapping[str, Any]], ns_docs: list[Mapping[str, Any]], - client: MongoClient, + client: MongoClient[Any], ) -> dict[str, Any]: """A proxy for Connection.write_command that handles event publishing.""" cmd["ops"] = op_docs @@ -324,7 +324,7 @@ def unack_write( msg: bytes, op_docs: list[Mapping[str, Any]], ns_docs: list[Mapping[str, Any]], - client: MongoClient, + client: MongoClient[Any], ) -> Optional[Mapping[str, Any]]: """A proxy for Connection.unack_write that handles event publishing.""" if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 8d5bf7697b..68a01dd7e7 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -395,7 +395,7 @@ class _TxnState: class _Transaction: """Internal class to hold transaction information in a ClientSession.""" - def __init__(self, opts: Optional[TransactionOptions], client: MongoClient): + def __init__(self, opts: Optional[TransactionOptions], client: MongoClient[Any]): self.opts = opts self.state = _TxnState.NONE self.sharded = False @@ -458,7 +458,7 @@ def _max_time_expired_error(exc: PyMongoError) -> bool: # From the transactions spec, all the retryable writes errors plus # WriteConcernTimeout. -_UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( +_UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( # type: ignore[type-arg] [ 64, # WriteConcernTimeout 50, # MaxTimeMSExpired @@ -498,13 +498,13 @@ class ClientSession: def __init__( self, - client: MongoClient, + client: MongoClient[Any], server_session: Any, options: SessionOptions, implicit: bool, ) -> None: # A MongoClient, a _ServerSession, a SessionOptions, and a set. - self._client: MongoClient = client + self._client: MongoClient[Any] = client self._server_session = server_session self._options = options self._cluster_time: Optional[Mapping[str, Any]] = None @@ -550,7 +550,7 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self._end_session(lock=True) @property - def client(self) -> MongoClient: + def client(self) -> MongoClient[Any]: """The :class:`~pymongo.mongo_client.MongoClient` this session was created from. """ @@ -748,7 +748,7 @@ def start_transaction( write_concern: Optional[WriteConcern] = None, read_preference: Optional[_ServerMode] = None, max_commit_time_ms: Optional[int] = None, - ) -> ContextManager: + ) -> ContextManager[Any]: """Start a multi-statement transaction. Takes the same arguments as :class:`TransactionOptions`. @@ -1118,7 +1118,7 @@ def inc_transaction_id(self) -> None: self._transaction_id += 1 -class _ServerSessionPool(collections.deque): +class _ServerSessionPool(collections.deque): # type: ignore[type-arg] """Pool of _ServerSession objects. This class is thread-safe. diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 8a71768318..32da83b0c2 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -582,7 +582,7 @@ def _command( conn: Connection, command: MutableMapping[str, Any], read_preference: Optional[_ServerMode] = None, - codec_options: Optional[CodecOptions] = None, + codec_options: Optional[CodecOptions[Mapping[str, Any]]] = None, check: bool = True, allowable_errors: Optional[Sequence[Union[str, int]]] = None, read_concern: Optional[ReadConcern] = None, @@ -703,7 +703,7 @@ def bulk_write( bypass_document_validation: Optional[bool] = None, session: Optional[ClientSession] = None, comment: Optional[Any] = None, - let: Optional[Mapping] = None, + let: Optional[Mapping[str, Any]] = None, ) -> BulkWriteResult: """Send a batch of write operations to the server. @@ -2522,7 +2522,7 @@ def _list_indexes( session: Optional[ClientSession] = None, comment: Optional[Any] = None, ) -> CommandCursor[MutableMapping[str, Any]]: - codec_options: CodecOptions = CodecOptions(SON) + codec_options: CodecOptions[Mapping[str, Any]] = CodecOptions(SON) coll = cast( Collection[MutableMapping[str, Any]], self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY), @@ -2864,7 +2864,7 @@ def _aggregate( self, aggregation_command: Type[_AggregationCommand], pipeline: _Pipeline, - cursor_class: Type[CommandCursor], + cursor_class: Type[CommandCursor], # type: ignore[type-arg] session: Optional[ClientSession], explicit_session: bool, let: Optional[Mapping[str, Any]] = None, @@ -3107,7 +3107,7 @@ def distinct( comment: Optional[Any] = None, hint: Optional[_IndexKeyHint] = None, **kwargs: Any, - ) -> list: + ) -> list[str]: """Get a list of distinct values for `key` among all documents in this collection. @@ -3170,7 +3170,7 @@ def _cmd( _server: Server, conn: Connection, read_preference: Optional[_ServerMode], - ) -> list: + ) -> list: # type: ignore[type-arg] return ( self._command( conn, @@ -3195,7 +3195,7 @@ def _find_and_modify( array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, session: Optional[ClientSession] = None, - let: Optional[Mapping] = None, + let: Optional[Mapping[str, Any]] = None, **kwargs: Any, ) -> Any: """Internal findAndModify helper.""" diff --git a/pymongo/synchronous/command_cursor.py b/pymongo/synchronous/command_cursor.py index e23519d740..bcdeed5f94 100644 --- a/pymongo/synchronous/command_cursor.py +++ b/pymongo/synchronous/command_cursor.py @@ -350,7 +350,7 @@ def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]: else: return None - def _next_batch(self, result: list, total: Optional[int] = None) -> bool: + def _next_batch(self, result: list, total: Optional[int] = None) -> bool: # type: ignore[type-arg] """Get all or some available documents from the cursor.""" if not len(self._data) and not self._killed: self._refresh() @@ -457,7 +457,7 @@ def _unpack_response( # type: ignore[override] self, response: Union[_OpReply, _OpMsg], cursor_id: Optional[int], - codec_options: CodecOptions, + codec_options: CodecOptions[dict[str, Any]], user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, ) -> list[Mapping[str, Any]]: diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index e49141e811..eb45d9c5d1 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -216,7 +216,7 @@ def __init__( # it anytime we change __limit. self._empty = False - self._data: deque = deque() + self._data: deque = deque() # type: ignore[type-arg] self._address: Optional[_Address] = None self._retrieved = 0 @@ -280,7 +280,7 @@ def clone(self) -> Cursor[_DocumentType]: """ return self._clone(True) - def _clone(self, deepcopy: bool = True, base: Optional[Cursor] = None) -> Cursor: + def _clone(self, deepcopy: bool = True, base: Optional[Cursor] = None) -> Cursor: # type: ignore[type-arg] """Internal clone helper.""" if not base: if self._explicit_session: @@ -322,7 +322,7 @@ def _clone(self, deepcopy: bool = True, base: Optional[Cursor] = None) -> Cursor base.__dict__.update(data) return base - def _clone_base(self, session: Optional[ClientSession]) -> Cursor: + def _clone_base(self, session: Optional[ClientSession]) -> Cursor: # type: ignore[type-arg] """Creates an empty Cursor object for information to be copied into.""" return self.__class__(self._collection, session=session) @@ -862,7 +862,7 @@ def where(self, code: Union[str, Code]) -> Cursor[_DocumentType]: if self._has_filter: spec = dict(self._spec) else: - spec = cast(dict, self._spec) + spec = cast(dict, self._spec) # type: ignore[type-arg] spec["$where"] = code self._spec = spec return self @@ -886,7 +886,7 @@ def _unpack_response( self, response: Union[_OpReply, _OpMsg], cursor_id: Optional[int], - codec_options: CodecOptions, + codec_options: CodecOptions, # type: ignore[type-arg] user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, ) -> Sequence[_DocumentOut]: @@ -962,29 +962,33 @@ def __deepcopy__(self, memo: Any) -> Any: return self._clone(deepcopy=True) @overload - def _deepcopy(self, x: Iterable, memo: Optional[dict[int, Union[list, dict]]] = None) -> list: + def _deepcopy(self, x: Iterable, memo: Optional[dict[int, Union[list, dict]]] = None) -> list: # type: ignore[type-arg] ... @overload def _deepcopy( - self, x: SupportsItems, memo: Optional[dict[int, Union[list, dict]]] = None - ) -> dict: + self, + x: SupportsItems, # type: ignore[type-arg] + memo: Optional[dict[int, Union[list, dict]]] = None, # type: ignore[type-arg] + ) -> dict: # type: ignore[type-arg] ... def _deepcopy( - self, x: Union[Iterable, SupportsItems], memo: Optional[dict[int, Union[list, dict]]] = None - ) -> Union[list, dict]: + self, + x: Union[Iterable, SupportsItems], # type: ignore[type-arg] + memo: Optional[dict[int, Union[list, dict]]] = None, # type: ignore[type-arg] + ) -> Union[list[Any], dict[str, Any]]: """Deepcopy helper for the data dictionary or list. Regular expressions cannot be deep copied but as they are immutable we don't have to copy them when cloning. """ - y: Union[list, dict] + y: Union[list[Any], dict[str, Any]] iterator: Iterable[tuple[Any, Any]] if not hasattr(x, "items"): y, is_list, iterator = [], True, enumerate(x) else: - y, is_list, iterator = {}, False, cast("SupportsItems", x).items() + y, is_list, iterator = {}, False, cast("SupportsItems", x).items() # type: ignore[type-arg] if memo is None: memo = {} val_id = id(x) @@ -1058,7 +1062,7 @@ def close(self) -> None: """Explicitly close / kill this cursor.""" self._die_lock() - def distinct(self, key: str) -> list: + def distinct(self, key: str) -> list[str]: """Get a list of distinct values for `key` among all documents in the result set of this query. @@ -1263,7 +1267,7 @@ def next(self) -> _DocumentType: else: raise StopIteration - def _next_batch(self, result: list, total: Optional[int] = None) -> bool: + def _next_batch(self, result: list, total: Optional[int] = None) -> bool: # type: ignore[type-arg] """Get all or some documents from the cursor.""" if not self._exhaust_checked: self._exhaust_checked = True @@ -1323,7 +1327,7 @@ def to_list(self, length: Optional[int] = None) -> list[_DocumentType]: return res -class RawBatchCursor(Cursor, Generic[_DocumentType]): +class RawBatchCursor(Cursor, Generic[_DocumentType]): # type: ignore[type-arg] """A cursor / iterator over raw batches of BSON data from a query result.""" _query_class = _RawBatchQuery diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index a11674b9aa..dd9ea01558 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -771,7 +771,7 @@ def _command( self._name, command, read_preference, - codec_options, + codec_options, # type: ignore[arg-type] check, allowable_errors, write_concern=write_concern, diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 1fd506e052..5d95e9c9d5 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -158,10 +158,10 @@ _IS_SYNC = True _WriteOp = Union[ - InsertOne, + InsertOne, # type: ignore[type-arg] DeleteOne, DeleteMany, - ReplaceOne, + ReplaceOne, # type: ignore[type-arg] UpdateOne, UpdateMany, ] @@ -173,7 +173,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): # Define order to retrieve options from ClientOptions for __repr__. # No host/port; these are retrieved from TopologySettings. _constructor_args = ("document_class", "tz_aware", "connect") - _clients: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + _clients: weakref.WeakValueDictionary = weakref.WeakValueDictionary() # type: ignore[type-arg] def __init__( self, @@ -847,7 +847,7 @@ def __init__( self._default_database_name = dbase self._lock = _create_lock() - self._kill_cursors_queue: list = [] + self._kill_cursors_queue: list = [] # type: ignore[type-arg] self._encrypter: Optional[_Encrypter] = None @@ -1064,7 +1064,7 @@ def _after_fork(self) -> None: # Reset the session pool to avoid duplicate sessions in the child process. self._topology._session_pool.reset() - def _duplicate(self, **kwargs: Any) -> MongoClient: + def _duplicate(self, **kwargs: Any) -> MongoClient: # type: ignore[type-arg] args = self._init_kwargs.copy() args.update(kwargs) return MongoClient(**args) @@ -1546,7 +1546,7 @@ def get_database( self, name, codec_options, read_preference, write_concern, read_concern ) - def _database_default_options(self, name: str) -> database.Database: + def _database_default_options(self, name: str) -> database.Database: # type: ignore[type-arg] """Get a Database instance with the default settings.""" return self.get_database( name, @@ -1883,7 +1883,7 @@ def _conn_for_reads( def _run_operation( self, operation: Union[_Query, _GetMore], - unpack_res: Callable, + unpack_res: Callable, # type: ignore[type-arg] address: Optional[_Address] = None, ) -> Response: """Run a _Query/_GetMore operation and return a Response. @@ -2257,7 +2257,7 @@ def _return_server_session( @contextlib.contextmanager def _tmp_session( self, session: Optional[client_session.ClientSession], close: bool = True - ) -> Generator[Optional[client_session.ClientSession], None, None]: + ) -> Generator[Optional[client_session.ClientSession], None]: """If provided session is None, lend a temporary session.""" if session is not None: if not isinstance(session, client_session.ClientSession): @@ -2300,8 +2300,8 @@ def server_info(self, session: Optional[client_session.ClientSession] = None) -> .. versionchanged:: 3.6 Added ``session`` parameter. """ - return cast( - dict, + return cast( # type: ignore[redundant-cast] + dict[str, Any], self.admin.command( "buildinfo", read_preference=ReadPreference.PRIMARY, session=session ), @@ -2428,13 +2428,13 @@ def drop_database( @_csot.apply def bulk_write( self, - models: Sequence[_WriteOp[_DocumentType]], + models: Sequence[_WriteOp], session: Optional[ClientSession] = None, ordered: bool = True, verbose_results: bool = False, bypass_document_validation: Optional[bool] = None, comment: Optional[Any] = None, - let: Optional[Mapping] = None, + let: Optional[Mapping[str, Any]] = None, write_concern: Optional[WriteConcern] = None, ) -> ClientBulkWriteResult: """Send a batch of write operations, potentially across multiple namespaces, to the server. @@ -2620,7 +2620,12 @@ class _MongoClientErrorHandler: "handled", ) - def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]): + def __init__( + self, + client: MongoClient, # type: ignore[type-arg] + server: Server, + session: Optional[ClientSession], + ): if not isinstance(client, MongoClient): # This is for compatibility with mocked and subclassed types, such as in Motor. if not any(cls.__name__ == "MongoClient" for cls in type(client).__mro__): @@ -2692,7 +2697,7 @@ class _ClientConnectionRetryable(Generic[T]): def __init__( self, - mongo_client: MongoClient, + mongo_client: MongoClient, # type: ignore[type-arg] func: _WriteCall[T] | _ReadCall[T], bulk: Optional[Union[_Bulk, _ClientBulk]], operation: str, diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index f41040801f..d5dd5caf82 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -349,7 +349,7 @@ def _check_once(self) -> ServerDescription: ) return sd - def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]: + def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]: # type: ignore[type-arg] """Return (Hello, round_trip_time). Can raise ConnectionFailure or OperationFailure. diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index 9559a5a542..7d9bca4d58 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -66,7 +66,7 @@ def command( read_preference: Optional[_ServerMode], codec_options: CodecOptions[_DocumentType], session: Optional[ClientSession], - client: Optional[MongoClient], + client: Optional[MongoClient[Any]], check: bool = True, allowable_errors: Optional[Sequence[Union[str, int]]] = None, address: Optional[_Address] = None, diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 505f58c60f..4ea5cb1c1e 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -201,7 +201,7 @@ def set_conn_timeout(self, timeout: Optional[float]) -> None: self.conn.get_conn.settimeout(timeout) def apply_timeout( - self, client: MongoClient, cmd: Optional[MutableMapping[str, Any]] + self, client: MongoClient[Any], cmd: Optional[MutableMapping[str, Any]] ) -> Optional[float]: # CSOT: use remaining timeout when set. timeout = _csot.remaining() @@ -255,7 +255,7 @@ def hello_cmd(self) -> dict[str, Any]: else: return {HelloCompat.LEGACY_CMD: 1, "helloOk": True} - def hello(self) -> Hello: + def hello(self) -> Hello[dict[str, Any]]: return self._hello(None, None) def _hello( @@ -357,7 +357,7 @@ def command( dbname: str, spec: MutableMapping[str, Any], read_preference: _ServerMode = ReadPreference.PRIMARY, - codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS, + codec_options: CodecOptions[Mapping[str, Any]] = DEFAULT_CODEC_OPTIONS, # type: ignore[assignment] check: bool = True, allowable_errors: Optional[Sequence[Union[str, int]]] = None, read_concern: Optional[ReadConcern] = None, @@ -365,7 +365,7 @@ def command( parse_write_concern_error: bool = False, collation: Optional[_CollationIn] = None, session: Optional[ClientSession] = None, - client: Optional[MongoClient] = None, + client: Optional[MongoClient[Any]] = None, retryable_write: bool = False, publish_events: bool = True, user_fields: Optional[Mapping[str, Any]] = None, @@ -417,7 +417,7 @@ def command( spec, self.is_mongos, read_preference, - codec_options, + codec_options, # type: ignore[arg-type] session, client, check, @@ -489,7 +489,7 @@ def unack_write(self, msg: bytes, max_doc_size: int) -> None: self.send_message(msg, max_doc_size) def write_command( - self, request_id: int, msg: bytes, codec_options: CodecOptions + self, request_id: int, msg: bytes, codec_options: CodecOptions[Mapping[str, Any]] ) -> dict[str, Any]: """Send "insert" etc. command, returning response as a dict. @@ -541,7 +541,7 @@ def authenticate(self, reauthenticate: bool = False) -> None: ) def validate_session( - self, client: Optional[MongoClient], session: Optional[ClientSession] + self, client: Optional[MongoClient[Any]], session: Optional[ClientSession] ) -> None: """Validate this session before use with client. @@ -596,7 +596,7 @@ def send_cluster_time( self, command: MutableMapping[str, Any], session: Optional[ClientSession], - client: Optional[MongoClient], + client: Optional[MongoClient[Any]], ) -> None: """Add $clusterTime.""" if client: @@ -730,7 +730,7 @@ def __init__( # LIFO pool. Sockets are ordered on idle time. Sockets claimed # and returned to pool from the left side. Stale sockets removed # from the right side. - self.conns: collections.deque = collections.deque() + self.conns: collections.deque[Connection] = collections.deque() self.active_contexts: set[_CancellationContext] = set() self.lock = _create_lock() self._max_connecting_cond = _create_condition(self.lock) @@ -837,8 +837,8 @@ def _reset( if service_id is None: sockets, self.conns = self.conns, collections.deque() else: - discard: collections.deque = collections.deque() - keep: collections.deque = collections.deque() + discard: collections.deque = collections.deque() # type: ignore[type-arg] + keep: collections.deque = collections.deque() # type: ignore[type-arg] for conn in self.conns: if conn.service_id == service_id: discard.append(conn) @@ -864,7 +864,7 @@ def _reset( if close: if not _IS_SYNC: asyncio.gather( - *[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], + *[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], # type: ignore[func-returns-value] return_exceptions=True, ) else: @@ -901,7 +901,7 @@ def _reset( ) if not _IS_SYNC: asyncio.gather( - *[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], + *[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], # type: ignore[func-returns-value] return_exceptions=True, ) else: @@ -915,7 +915,7 @@ def update_is_writable(self, is_writable: Optional[bool]) -> None: self.is_writable = is_writable with self.lock: for _socket in self.conns: - _socket.update_is_writable(self.is_writable) + _socket.update_is_writable(self.is_writable) # type: ignore[arg-type] def reset( self, service_id: Optional[ObjectId] = None, interrupt_connections: bool = False @@ -952,7 +952,7 @@ def remove_stale_sockets(self, reference_generation: int) -> None: close_conns.append(self.conns.pop()) if not _IS_SYNC: asyncio.gather( - *[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns], + *[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns], # type: ignore[func-returns-value] return_exceptions=True, ) else: @@ -1473,4 +1473,4 @@ def __del__(self) -> None: # not safe to acquire a lock in __del__. if _IS_SYNC: for conn in self.conns: - conn.close_conn(None) + conn.close_conn(None) # type: ignore[unused-coroutine] diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index c3643ba815..a85f1b0db7 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -66,7 +66,7 @@ def __init__( monitor: Monitor, topology_id: Optional[ObjectId] = None, listeners: Optional[_EventListeners] = None, - events: Optional[ReferenceType[Queue]] = None, + events: Optional[ReferenceType[Queue[Any]]] = None, ) -> None: """Represent one MongoDB server.""" self._description = server_description @@ -142,7 +142,7 @@ def run_operation( read_preference: _ServerMode, listeners: Optional[_EventListeners], unpack_res: Callable[..., list[_DocumentOut]], - client: MongoClient, + client: MongoClient[Any], ) -> Response: """Run a _Query or _GetMore operation and return a Response object. diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index 28370d4adc..a4ca0e6e0f 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -84,7 +84,7 @@ _pymongo_dir = str(Path(__file__).parent) -def process_events_queue(queue_ref: weakref.ReferenceType[queue.Queue]) -> bool: +def process_events_queue(queue_ref: weakref.ReferenceType[queue.Queue]) -> bool: # type: ignore[type-arg] q = queue_ref() if not q: return False # Cancel PeriodicExecutor. @@ -186,7 +186,7 @@ def __init__(self, topology_settings: TopologySettings): if self._publish_server or self._publish_tp: assert self._events is not None - weak: weakref.ReferenceType[queue.Queue] + weak: weakref.ReferenceType[queue.Queue[Any]] def target() -> bool: return process_events_queue(weak) diff --git a/pymongo/topology_description.py b/pymongo/topology_description.py index e226992b45..de67a8f94a 100644 --- a/pymongo/topology_description.py +++ b/pymongo/topology_description.py @@ -569,8 +569,8 @@ def _update_rs_from_primary( return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id if server_description.max_wire_version is None or server_description.max_wire_version < 17: - new_election_tuple: tuple = (server_description.set_version, server_description.election_id) - max_election_tuple: tuple = (max_set_version, max_election_id) + new_election_tuple: tuple = (server_description.set_version, server_description.election_id) # type: ignore[type-arg] + max_election_tuple: tuple = (max_set_version, max_election_id) # type: ignore[type-arg] if None not in new_election_tuple: if None not in max_election_tuple and new_election_tuple < max_election_tuple: # Stale primary, set to type Unknown. diff --git a/pymongo/typings.py b/pymongo/typings.py index ce6f369d1f..e678720db9 100644 --- a/pymongo/typings.py +++ b/pymongo/typings.py @@ -51,7 +51,7 @@ _T = TypeVar("_T") # Type hinting types for compatibility between async and sync classes -_AgnosticMongoClient = Union["AsyncMongoClient", "MongoClient"] +_AgnosticMongoClient = Union["AsyncMongoClient", "MongoClient"] # type: ignore[type-arg] _AgnosticConnection = Union["AsyncConnection", "Connection"] _AgnosticClientSession = Union["AsyncClientSession", "ClientSession"] _AgnosticBulk = Union["_AsyncBulk", "_Bulk"] diff --git a/pyproject.toml b/pyproject.toml index e7e3161906..9740fe8e6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,11 +148,12 @@ markers = [ strict = true show_error_codes = true pretty = true -disable_error_code = ["type-arg", "no-any-return"] +disable_error_code = ["no-any-return"] +disallow_any_generics = true [[tool.mypy.overrides]] module = ["test.*"] -disable_error_code = ["no-untyped-def", "no-untyped-call"] +disable_error_code = ["type-arg", "no-untyped-def", "no-untyped-call"] [[tool.mypy.overrides]] module = ["service_identity.*"]