diff --git a/manim/camera/camera.py b/manim/camera/camera.py index ed27907cd7..dc30cc06c9 100644 --- a/manim/camera/camera.py +++ b/manim/camera/camera.py @@ -10,7 +10,7 @@ import pathlib from collections.abc import Iterable from functools import reduce -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable import cairo import numpy as np @@ -18,12 +18,11 @@ from scipy.spatial.distance import pdist from typing_extensions import Self -from manim.typing import Point3D_Array +from manim.typing import MatrixMN, PixelArray, Point3D, Point3D_Array from .. import config, logger from ..constants import * from ..mobject.mobject import Mobject -from ..mobject.types.image_mobject import AbstractImageMobject from ..mobject.types.point_cloud_mobject import PMobject from ..mobject.types.vectorized_mobject import VMobject from ..utils.color import ManimColor, ParsableManimColor, color_to_int_rgba @@ -32,6 +31,10 @@ from ..utils.iterables import list_difference_update from ..utils.space_ops import angle_of_vector +if TYPE_CHECKING: + from ..mobject.types.image_mobject import AbstractImageMobject + + LINE_JOIN_MAP = { LineJointType.AUTO: None, # TODO: this could be improved LineJointType.ROUND: cairo.LineJoin.ROUND, @@ -73,13 +76,13 @@ class Camera: def __init__( self, background_image: str | None = None, - frame_center: np.ndarray = ORIGIN, + frame_center: Point3D = ORIGIN, image_mode: str = "RGBA", n_channels: int = 4, pixel_array_dtype: str = "uint8", cairo_line_width_multiple: float = 0.01, use_z_index: bool = True, - background: np.ndarray | None = None, + background: PixelArray | None = None, pixel_height: int | None = None, pixel_width: int | None = None, frame_height: float | None = None, @@ -87,8 +90,8 @@ def __init__( frame_rate: float | None = None, background_color: ParsableManimColor | None = None, background_opacity: float | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: self.background_image = background_image self.frame_center = frame_center self.image_mode = image_mode @@ -97,6 +100,9 @@ def __init__( self.cairo_line_width_multiple = cairo_line_width_multiple self.use_z_index = use_z_index self.background = background + self.background_colored_vmobject_displayer: ( + BackgroundColoredVMobjectDisplayer | None + ) = None if pixel_height is None: pixel_height = config["pixel_height"] @@ -119,11 +125,13 @@ def __init__( self.frame_rate = frame_rate if background_color is None: - self._background_color = ManimColor.parse(config["background_color"]) + self._background_color: ManimColor = ManimColor.parse( + config["background_color"] + ) else: self._background_color = ManimColor.parse(background_color) if background_opacity is None: - self._background_opacity = config["background_opacity"] + self._background_opacity: float = config["background_opacity"] else: self._background_opacity = background_opacity @@ -132,7 +140,7 @@ def __init__( self.max_allowable_norm = config["frame_width"] self.rgb_max_val = np.iinfo(self.pixel_array_dtype).max - self.pixel_array_to_cairo_context = {} + self.pixel_array_to_cairo_context: dict[int, cairo.Context] = {} # Contains the correct method to process a list of Mobjects of the # corresponding class. If a Mobject is not an instance of a class in @@ -143,7 +151,7 @@ def __init__( self.resize_frame_shape() self.reset() - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: Any) -> Camera: # This is to address a strange bug where deepcopying # will result in a segfault, which is somehow related # to the aggdraw library @@ -151,24 +159,26 @@ def __deepcopy__(self, memo): return copy.copy(self) @property - def background_color(self): + def background_color(self) -> ManimColor: return self._background_color @background_color.setter - def background_color(self, color): + def background_color(self, color: ManimColor) -> None: self._background_color = color self.init_background() @property - def background_opacity(self): + def background_opacity(self) -> float: return self._background_opacity @background_opacity.setter - def background_opacity(self, alpha): + def background_opacity(self, alpha: float) -> None: self._background_opacity = alpha self.init_background() - def type_or_raise(self, mobject: Mobject): + def type_or_raise( + self, mobject: Mobject + ) -> type[VMobject] | type[PMobject] | type[AbstractImageMobject] | type[Mobject]: """Return the type of mobject, if it is a type that can be rendered. If `mobject` is an instance of a class that inherits from a class that @@ -195,8 +205,12 @@ def type_or_raise(self, mobject: Mobject): :exc:`TypeError` When mobject is not an instance of a class that can be rendered. """ - self.display_funcs = { - VMobject: self.display_multiple_vectorized_mobjects, + from ..mobject.types.image_mobject import AbstractImageMobject + + self.display_funcs: dict[ + type[Mobject], Callable[[list[Mobject], PixelArray], Any] + ] = { + VMobject: self.display_multiple_vectorized_mobjects, # type: ignore[dict-item] PMobject: self.display_multiple_point_cloud_mobjects, AbstractImageMobject: self.display_multiple_image_mobjects, Mobject: lambda batch, pa: batch, # Do nothing @@ -209,7 +223,7 @@ def type_or_raise(self, mobject: Mobject): return _type raise TypeError(f"Displaying an object of class {_type} is not supported") - def reset_pixel_shape(self, new_height: float, new_width: float): + def reset_pixel_shape(self, new_height: float, new_width: float) -> None: """This method resets the height and width of a single pixel to the passed new_height and new_width. @@ -226,7 +240,7 @@ def reset_pixel_shape(self, new_height: float, new_width: float): self.resize_frame_shape() self.reset() - def resize_frame_shape(self, fixed_dimension: int = 0): + def resize_frame_shape(self, fixed_dimension: int = 0) -> None: """ Changes frame_shape to match the aspect ratio of the pixels, where fixed_dimension determines @@ -251,7 +265,7 @@ def resize_frame_shape(self, fixed_dimension: int = 0): self.frame_height = frame_height self.frame_width = frame_width - def init_background(self): + def init_background(self) -> None: """Initialize the background. If self.background_image is the path of an image the image is set as background; else, the default @@ -277,7 +291,9 @@ def init_background(self): ) self.background[:, :] = background_rgba - def get_image(self, pixel_array: np.ndarray | list | tuple | None = None): + def get_image( + self, pixel_array: PixelArray | list | tuple | None = None + ) -> Image.Image: """Returns an image from the passed pixel array, or from the current frame if the passed pixel array is none. @@ -289,7 +305,7 @@ def get_image(self, pixel_array: np.ndarray | list | tuple | None = None): Returns ------- - PIL.Image + PIL.Image.Image The PIL image of the array. """ if pixel_array is None: @@ -297,8 +313,8 @@ def get_image(self, pixel_array: np.ndarray | list | tuple | None = None): return Image.fromarray(pixel_array, mode=self.image_mode) def convert_pixel_array( - self, pixel_array: np.ndarray | list | tuple, convert_from_floats: bool = False - ): + self, pixel_array: PixelArray | list | tuple, convert_from_floats: bool = False + ) -> PixelArray: """Converts a pixel array from values that have floats in then to proper RGB values. @@ -324,8 +340,8 @@ def convert_pixel_array( return retval def set_pixel_array( - self, pixel_array: np.ndarray | list | tuple, convert_from_floats: bool = False - ): + self, pixel_array: PixelArray | list | tuple, convert_from_floats: bool = False + ) -> None: """Sets the pixel array of the camera to the passed pixel array. Parameters @@ -335,19 +351,21 @@ def set_pixel_array( convert_from_floats Whether or not to convert float values to proper RGB values, by default False """ - converted_array = self.convert_pixel_array(pixel_array, convert_from_floats) + converted_array: PixelArray = self.convert_pixel_array( + pixel_array, convert_from_floats + ) if not ( hasattr(self, "pixel_array") and self.pixel_array.shape == converted_array.shape ): - self.pixel_array = converted_array + self.pixel_array: PixelArray = converted_array else: # Set in place self.pixel_array[:, :, :] = converted_array[:, :, :] def set_background( - self, pixel_array: np.ndarray | list | tuple, convert_from_floats: bool = False - ): + self, pixel_array: PixelArray | list | tuple, convert_from_floats: bool = False + ) -> None: """Sets the background to the passed pixel_array after converting to valid RGB values. @@ -363,7 +381,7 @@ def set_background( # TODO, this should live in utils, not as a method of Camera def make_background_from_func( self, coords_to_colors_func: Callable[[np.ndarray], np.ndarray] - ): + ) -> PixelArray: """ Makes a pixel array for the background by using coords_to_colors_func to determine each pixel's color. Each input pixel's color. Each input to coords_to_colors_func is an (x, y) pair in space (in ordinary space coordinates; not @@ -415,7 +433,7 @@ def reset(self) -> Self: self.set_pixel_array(self.background) return self - def set_frame_to_background(self, background): + def set_frame_to_background(self, background: PixelArray) -> None: self.set_pixel_array(background) #### @@ -425,7 +443,7 @@ def get_mobjects_to_display( mobjects: Iterable[Mobject], include_submobjects: bool = True, excluded_mobjects: list | None = None, - ): + ) -> list[Mobject]: """Used to get the list of mobjects to display with the camera. @@ -457,7 +475,7 @@ def get_mobjects_to_display( mobjects = list_difference_update(mobjects, all_excluded) return list(mobjects) - def is_in_frame(self, mobject: Mobject): + def is_in_frame(self, mobject: Mobject) -> bool: """Checks whether the passed mobject is in frame or not. @@ -484,7 +502,7 @@ def is_in_frame(self, mobject: Mobject): ], ) - def capture_mobject(self, mobject: Mobject, **kwargs: Any): + def capture_mobject(self, mobject: Mobject, **kwargs: Any) -> None: """Capture mobjects by storing it in :attr:`pixel_array`. This is a single-mobject version of :meth:`capture_mobjects`. @@ -500,7 +518,7 @@ def capture_mobject(self, mobject: Mobject, **kwargs: Any): """ return self.capture_mobjects([mobject], **kwargs) - def capture_mobjects(self, mobjects: Iterable[Mobject], **kwargs): + def capture_mobjects(self, mobjects: Iterable[Mobject], **kwargs: Any) -> None: """Capture mobjects by printing them on :attr:`pixel_array`. This is the essential function that converts the contents of a Scene @@ -535,7 +553,7 @@ def capture_mobjects(self, mobjects: Iterable[Mobject], **kwargs): # NOTE: None of the methods below have been mentioned outside of their definitions. Their DocStrings are not as # detailed as possible. - def get_cached_cairo_context(self, pixel_array: np.ndarray): + def get_cached_cairo_context(self, pixel_array: PixelArray) -> cairo.Context | None: """Returns the cached cairo context of the passed pixel array if it exists, and None if it doesn't. @@ -551,7 +569,7 @@ def get_cached_cairo_context(self, pixel_array: np.ndarray): """ return self.pixel_array_to_cairo_context.get(id(pixel_array), None) - def cache_cairo_context(self, pixel_array: np.ndarray, ctx: cairo.Context): + def cache_cairo_context(self, pixel_array: PixelArray, ctx: cairo.Context) -> None: """Caches the passed Pixel array into a Cairo Context Parameters @@ -563,7 +581,7 @@ def cache_cairo_context(self, pixel_array: np.ndarray, ctx: cairo.Context): """ self.pixel_array_to_cairo_context[id(pixel_array)] = ctx - def get_cairo_context(self, pixel_array: np.ndarray): + def get_cairo_context(self, pixel_array: PixelArray) -> cairo.Context: """Returns the cairo context for a pixel array after caching it to self.pixel_array_to_cairo_context If that array has already been cached, it returns the @@ -588,7 +606,7 @@ def get_cairo_context(self, pixel_array: np.ndarray): fh = self.frame_height fc = self.frame_center surface = cairo.ImageSurface.create_for_data( - pixel_array, + pixel_array.data, cairo.FORMAT_ARGB32, pw, ph, @@ -609,8 +627,8 @@ def get_cairo_context(self, pixel_array: np.ndarray): return ctx def display_multiple_vectorized_mobjects( - self, vmobjects: list, pixel_array: np.ndarray - ): + self, vmobjects: list[VMobject], pixel_array: PixelArray + ) -> None: """Displays multiple VMobjects in the pixel_array Parameters @@ -633,8 +651,8 @@ def display_multiple_vectorized_mobjects( ) def display_multiple_non_background_colored_vmobjects( - self, vmobjects: list, pixel_array: np.ndarray - ): + self, vmobjects: Iterable[VMobject], pixel_array: PixelArray + ) -> None: """Displays multiple VMobjects in the cairo context, as long as they don't have background colors. @@ -649,7 +667,7 @@ def display_multiple_non_background_colored_vmobjects( for vmobject in vmobjects: self.display_vectorized(vmobject, ctx) - def display_vectorized(self, vmobject: VMobject, ctx: cairo.Context): + def display_vectorized(self, vmobject: VMobject, ctx: cairo.Context) -> Self: """Displays a VMobject in the cairo context Parameters @@ -670,7 +688,7 @@ def display_vectorized(self, vmobject: VMobject, ctx: cairo.Context): self.apply_stroke(ctx, vmobject) return self - def set_cairo_context_path(self, ctx: cairo.Context, vmobject: VMobject): + def set_cairo_context_path(self, ctx: cairo.Context, vmobject: VMobject) -> Self: """Sets a path for the cairo context with the vmobject passed Parameters @@ -689,7 +707,7 @@ def set_cairo_context_path(self, ctx: cairo.Context, vmobject: VMobject): # TODO, shouldn't this be handled in transform_points_pre_display? # points = points - self.get_frame_center() if len(points) == 0: - return + return self ctx.new_path() subpaths = vmobject.gen_subpaths_from_points_2d(points) @@ -705,8 +723,8 @@ def set_cairo_context_path(self, ctx: cairo.Context, vmobject: VMobject): return self def set_cairo_context_color( - self, ctx: cairo.Context, rgbas: np.ndarray, vmobject: VMobject - ): + self, ctx: cairo.Context, rgbas: MatrixMN, vmobject: VMobject + ) -> Self: """Sets the color of the cairo context Parameters @@ -738,7 +756,7 @@ def set_cairo_context_color( ctx.set_source(pat) return self - def apply_fill(self, ctx: cairo.Context, vmobject: VMobject): + def apply_fill(self, ctx: cairo.Context, vmobject: VMobject) -> Self: """Fills the cairo context Parameters @@ -759,7 +777,7 @@ def apply_fill(self, ctx: cairo.Context, vmobject: VMobject): def apply_stroke( self, ctx: cairo.Context, vmobject: VMobject, background: bool = False - ): + ) -> Self: """Applies a stroke to the VMobject in the cairo context. Parameters @@ -798,7 +816,9 @@ def apply_stroke( ctx.stroke_preserve() return self - def get_stroke_rgbas(self, vmobject: VMobject, background: bool = False): + def get_stroke_rgbas( + self, vmobject: VMobject, background: bool = False + ) -> PixelArray: """Gets the RGBA array for the stroke of the passed VMobject. @@ -817,7 +837,7 @@ def get_stroke_rgbas(self, vmobject: VMobject, background: bool = False): """ return vmobject.get_stroke_rgbas(background) - def get_fill_rgbas(self, vmobject: VMobject): + def get_fill_rgbas(self, vmobject: VMobject) -> PixelArray: """Returns the RGBA array of the fill of the passed VMobject Parameters @@ -832,25 +852,27 @@ def get_fill_rgbas(self, vmobject: VMobject): """ return vmobject.get_fill_rgbas() - def get_background_colored_vmobject_displayer(self): + def get_background_colored_vmobject_displayer( + self, + ) -> BackgroundColoredVMobjectDisplayer: """Returns the background_colored_vmobject_displayer if it exists or makes one and returns it if not. Returns ------- - BackGroundColoredVMobjectDisplayer + BackgroundColoredVMobjectDisplayer Object that displays VMobjects that have the same color as the background. """ - # Quite wordy to type out a bunch - bcvd = "background_colored_vmobject_displayer" - if not hasattr(self, bcvd): - setattr(self, bcvd, BackgroundColoredVMobjectDisplayer(self)) - return getattr(self, bcvd) + if self.background_colored_vmobject_displayer is None: + self.background_colored_vmobject_displayer = ( + BackgroundColoredVMobjectDisplayer(self) + ) + return self.background_colored_vmobject_displayer def display_multiple_background_colored_vmobjects( - self, cvmobjects: list, pixel_array: np.ndarray - ): + self, cvmobjects: Iterable[VMobject], pixel_array: PixelArray + ) -> Self: """Displays multiple vmobjects that have the same color as the background. Parameters @@ -876,8 +898,8 @@ def display_multiple_background_colored_vmobjects( # As a result, the other methods do not have as detailed docstrings as would be preferred. def display_multiple_point_cloud_mobjects( - self, pmobjects: list, pixel_array: np.ndarray - ): + self, pmobjects: list, pixel_array: PixelArray + ) -> None: """Displays multiple PMobjects by modifying the passed pixel array. Parameters @@ -902,8 +924,8 @@ def display_point_cloud( points: list, rgbas: np.ndarray, thickness: float, - pixel_array: np.ndarray, - ): + pixel_array: PixelArray, + ) -> None: """Displays a PMobject by modifying the pixel array suitably. TODO: Write a description for the rgbas argument. @@ -951,7 +973,7 @@ def display_point_cloud( def display_multiple_image_mobjects( self, image_mobjects: list, pixel_array: np.ndarray - ): + ) -> None: """Displays multiple image mobjects by modifying the passed pixel_array. Parameters @@ -966,7 +988,7 @@ def display_multiple_image_mobjects( def display_image_mobject( self, image_mobject: AbstractImageMobject, pixel_array: np.ndarray - ): + ) -> None: """Displays an ImageMobject by changing the pixel_array suitably. Parameters @@ -1023,7 +1045,9 @@ def display_image_mobject( # Paint on top of existing pixel array self.overlay_PIL_image(pixel_array, full_image) - def overlay_rgba_array(self, pixel_array: np.ndarray, new_array: np.ndarray): + def overlay_rgba_array( + self, pixel_array: np.ndarray, new_array: np.ndarray + ) -> None: """Overlays an RGBA array on top of the given Pixel array. Parameters @@ -1035,7 +1059,7 @@ def overlay_rgba_array(self, pixel_array: np.ndarray, new_array: np.ndarray): """ self.overlay_PIL_image(pixel_array, self.get_image(new_array)) - def overlay_PIL_image(self, pixel_array: np.ndarray, image: Image): + def overlay_PIL_image(self, pixel_array: np.ndarray, image: Image) -> None: """Overlays a PIL image on the passed pixel array. Parameters @@ -1050,7 +1074,7 @@ def overlay_PIL_image(self, pixel_array: np.ndarray, image: Image): dtype="uint8", ) - def adjust_out_of_range_points(self, points: np.ndarray): + def adjust_out_of_range_points(self, points: np.ndarray) -> np.ndarray: """If any of the points in the passed array are out of the viable range, they are adjusted suitably. @@ -1083,7 +1107,7 @@ def transform_points_pre_display( self, mobject: Mobject, points: Point3D_Array, - ): # TODO: Write more detailed docstrings for this method. + ) -> Point3D_Array: # TODO: Write more detailed docstrings for this method. # NOTE: There seems to be an unused argument `mobject`. # Subclasses (like ThreeDCamera) may want to @@ -1096,9 +1120,9 @@ def transform_points_pre_display( def points_to_pixel_coords( self, - mobject, - points, - ): # TODO: Write more detailed docstrings for this method. + mobject: Mobject, + points: np.ndarray, + ) -> np.ndarray: # TODO: Write more detailed docstrings for this method. points = self.transform_points_pre_display(mobject, points) shifted_points = points - self.frame_center @@ -1118,7 +1142,7 @@ def points_to_pixel_coords( result[:, 1] = shifted_points[:, 1] * height_mult + height_add return result.astype("int") - def on_screen_pixels(self, pixel_coords: np.ndarray): + def on_screen_pixels(self, pixel_coords: np.ndarray) -> PixelArray: """Returns array of pixels that are on the screen from a given array of pixel_coordinates @@ -1157,12 +1181,12 @@ def adjusted_thickness(self, thickness: float) -> float: the camera. """ # TODO: This seems...unsystematic - big_sum = op.add(config["pixel_height"], config["pixel_width"]) - this_sum = op.add(self.pixel_height, self.pixel_width) + big_sum: float = op.add(config["pixel_height"], config["pixel_width"]) + this_sum: float = op.add(self.pixel_height, self.pixel_width) factor = big_sum / this_sum return 1 + (thickness - 1) * factor - def get_thickening_nudges(self, thickness: float): + def get_thickening_nudges(self, thickness: float) -> PixelArray: """Determine a list of vectors used to nudge two-dimensional pixel coordinates. @@ -1179,7 +1203,9 @@ def get_thickening_nudges(self, thickness: float): _range = list(range(-thickness // 2 + 1, thickness // 2 + 1)) return np.array(list(it.product(_range, _range))) - def thickened_coordinates(self, pixel_coords: np.ndarray, thickness: float): + def thickened_coordinates( + self, pixel_coords: np.ndarray, thickness: float + ) -> PixelArray: """Returns thickened coordinates for a passed array of pixel coords and a thickness to thicken by. @@ -1201,7 +1227,7 @@ def thickened_coordinates(self, pixel_coords: np.ndarray, thickness: float): return pixel_coords.reshape((size // 2, 2)) # TODO, reimplement using cairo matrix - def get_coords_of_all_pixels(self): + def get_coords_of_all_pixels(self) -> PixelArray: """Returns the cartesian coordinates of each pixel. Returns @@ -1249,20 +1275,20 @@ class BackgroundColoredVMobjectDisplayer: def __init__(self, camera: Camera): self.camera = camera - self.file_name_to_pixel_array_map = {} + self.file_name_to_pixel_array_map: dict[str, PixelArray] = {} self.pixel_array = np.array(camera.pixel_array) self.reset_pixel_array() - def reset_pixel_array(self): + def reset_pixel_array(self) -> None: self.pixel_array[:, :] = 0 def resize_background_array( self, - background_array: np.ndarray, + background_array: PixelArray, new_width: float, new_height: float, mode: str = "RGBA", - ): + ) -> PixelArray: """Resizes the pixel array representing the background. Parameters @@ -1287,8 +1313,8 @@ def resize_background_array( return np.array(resized_image) def resize_background_array_to_match( - self, background_array: np.ndarray, pixel_array: np.ndarray - ): + self, background_array: PixelArray, pixel_array: PixelArray + ) -> PixelArray: """Resizes the background array to match the passed pixel array. Parameters @@ -1307,7 +1333,9 @@ def resize_background_array_to_match( mode = "RGBA" if pixel_array.shape[2] == 4 else "RGB" return self.resize_background_array(background_array, width, height, mode) - def get_background_array(self, image: Image.Image | pathlib.Path | str): + def get_background_array( + self, image: Image.Image | pathlib.Path | str + ) -> PixelArray: """Gets the background array that has the passed file_name. Parameters @@ -1336,7 +1364,7 @@ def get_background_array(self, image: Image.Image | pathlib.Path | str): self.file_name_to_pixel_array_map[image_key] = back_array return back_array - def display(self, *cvmobjects: VMobject): + def display(self, *cvmobjects: VMobject) -> PixelArray | None: """Displays the colored VMobjects. Parameters diff --git a/manim/camera/moving_camera.py b/manim/camera/moving_camera.py index f171477656..deff555b85 100644 --- a/manim/camera/moving_camera.py +++ b/manim/camera/moving_camera.py @@ -9,6 +9,9 @@ __all__ = ["MovingCamera"] +from collections.abc import Iterable +from typing import Any + import numpy as np from .. import config @@ -16,7 +19,7 @@ from ..constants import DOWN, LEFT, RIGHT, UP from ..mobject.frame import ScreenRectangle from ..mobject.mobject import Mobject -from ..utils.color import WHITE +from ..utils.color import WHITE, ManimColor class MovingCamera(Camera): @@ -32,10 +35,10 @@ class MovingCamera(Camera): def __init__( self, frame=None, - fixed_dimension=0, # width - default_frame_stroke_color=WHITE, - default_frame_stroke_width=0, - **kwargs, + fixed_dimension: int = 0, # width + default_frame_stroke_color: ManimColor = WHITE, + default_frame_stroke_width: int = 0, + **kwargs: Any, ) -> None: """Frame is a Mobject, (should almost certainly be a rectangle) determining which region of space the camera displays @@ -121,7 +124,7 @@ def frame_center(self, frame_center: np.ndarray | list | tuple | Mobject): """ self.frame.move_to(frame_center) - def capture_mobjects(self, mobjects, **kwargs): + def capture_mobjects(self, mobjects: Iterable[Mobject], **kwargs: Any) -> None: # self.reset_frame_center() # self.realign_frame_shape() super().capture_mobjects(mobjects, **kwargs) diff --git a/manim/camera/multi_camera.py b/manim/camera/multi_camera.py index a5202135e9..f4bd18a47c 100644 --- a/manim/camera/multi_camera.py +++ b/manim/camera/multi_camera.py @@ -5,7 +5,13 @@ __all__ = ["MultiCamera"] -from manim.mobject.types.image_mobject import ImageMobject +from collections.abc import Iterable +from typing import Any + +from typing_extensions import Self + +from manim.mobject.mobject import Mobject +from manim.mobject.types.image_mobject import ImageMobjectFromCamera from ..camera.moving_camera import MovingCamera from ..utils.iterables import list_difference_update @@ -16,10 +22,10 @@ class MultiCamera(MovingCamera): def __init__( self, - image_mobjects_from_cameras: ImageMobject | None = None, - allow_cameras_to_capture_their_own_display=False, - **kwargs, - ): + image_mobjects_from_cameras: Iterable[ImageMobjectFromCamera] | None = None, + allow_cameras_to_capture_their_own_display: bool = False, + **kwargs: Any, + ) -> None: """Initialises the MultiCamera Parameters @@ -29,7 +35,7 @@ def __init__( kwargs Any valid keyword arguments of MovingCamera. """ - self.image_mobjects_from_cameras = [] + self.image_mobjects_from_cameras: list[ImageMobjectFromCamera] = [] if image_mobjects_from_cameras is not None: for imfc in image_mobjects_from_cameras: self.add_image_mobject_from_camera(imfc) @@ -38,7 +44,9 @@ def __init__( ) super().__init__(**kwargs) - def add_image_mobject_from_camera(self, image_mobject_from_camera: ImageMobject): + def add_image_mobject_from_camera( + self, image_mobject_from_camera: ImageMobjectFromCamera + ) -> None: """Adds an ImageMobject that's been obtained from the camera into the list ``self.image_mobject_from_cameras`` @@ -53,20 +61,20 @@ def add_image_mobject_from_camera(self, image_mobject_from_camera: ImageMobject) assert isinstance(imfc.camera, MovingCamera) self.image_mobjects_from_cameras.append(imfc) - def update_sub_cameras(self): + def update_sub_cameras(self) -> None: """Reshape sub_camera pixel_arrays""" for imfc in self.image_mobjects_from_cameras: pixel_height, pixel_width = self.pixel_array.shape[:2] - imfc.camera.frame_shape = ( - imfc.camera.frame.height, - imfc.camera.frame.width, - ) + # imfc.camera.frame_shape = ( + # imfc.camera.frame.height, + # imfc.camera.frame.width, + # ) imfc.camera.reset_pixel_shape( int(pixel_height * imfc.height / self.frame_height), int(pixel_width * imfc.width / self.frame_width), ) - def reset(self): + def reset(self) -> Self: """Resets the MultiCamera. Returns @@ -79,7 +87,7 @@ def reset(self): super().reset() return self - def capture_mobjects(self, mobjects, **kwargs): + def capture_mobjects(self, mobjects: Iterable[Mobject], **kwargs: Any) -> None: self.update_sub_cameras() for imfc in self.image_mobjects_from_cameras: to_add = list(mobjects) @@ -88,7 +96,7 @@ def capture_mobjects(self, mobjects, **kwargs): imfc.camera.capture_mobjects(to_add, **kwargs) super().capture_mobjects(mobjects, **kwargs) - def get_mobjects_indicating_movement(self): + def get_mobjects_indicating_movement(self) -> list[Mobject]: """Returns all mobjects whose movement implies that the camera should think of all other mobjects on the screen as moving diff --git a/manim/mobject/graphing/coordinate_systems.py b/manim/mobject/graphing/coordinate_systems.py index b21879b90b..811581f8fb 100644 --- a/manim/mobject/graphing/coordinate_systems.py +++ b/manim/mobject/graphing/coordinate_systems.py @@ -126,7 +126,7 @@ def __init__( x_length: float | None = None, y_length: float | None = None, dimension: int = 2, - ) -> None: + ): self.dimension = dimension default_step = 1 @@ -153,11 +153,14 @@ def __init__( self.x_length = x_length self.y_length = y_length self.num_sampled_graph_points_per_tick = 10 + self.x_axis: NumberLine - def coords_to_point(self, *coords: ManimFloat): + def coords_to_point(self, *coords: ManimFloat) -> Point3D: + # TODO: I think the method should be able to return more than just a single point. + # E.g. see the implementation of it on line 2065. raise NotImplementedError() - def point_to_coords(self, point: Point3DLike): + def point_to_coords(self, point: Point3DLike) -> list[ManimFloat]: raise NotImplementedError() def polar_to_point(self, radius: float, azimuth: float) -> Point2D: @@ -201,7 +204,7 @@ def point_to_polar(self, point: Point2DLike) -> Point2D: Returns ------- - Tuple[:class:`float`, :class:`float`] + Point2D The coordinate radius (:math:`r`) and the coordinate azimuth (:math:`\theta`). """ x, y = self.point_to_coords(point) @@ -213,7 +216,7 @@ def c2p( """Abbreviation for :meth:`coords_to_point`""" return self.coords_to_point(*coords) - def p2c(self, point: Point3DLike): + def p2c(self, point: Point3DLike) -> list[ManimFloat]: """Abbreviation for :meth:`point_to_coords`""" return self.point_to_coords(point) @@ -221,17 +224,18 @@ def pr2pt(self, radius: float, azimuth: float) -> np.ndarray: """Abbreviation for :meth:`polar_to_point`""" return self.polar_to_point(radius, azimuth) - def pt2pr(self, point: np.ndarray) -> tuple[float, float]: + def pt2pr(self, point: np.ndarray) -> Point2D: """Abbreviation for :meth:`point_to_polar`""" return self.point_to_polar(point) - def get_axes(self): + def get_axes(self) -> VGroup: raise NotImplementedError() - def get_axis(self, index: int) -> Mobject: - return self.get_axes()[index] + def get_axis(self, index: int) -> NumberLine: + val: NumberLine = self.get_axes()[index] + return val - def get_origin(self) -> np.ndarray: + def get_origin(self) -> Point3D: """Gets the origin of :class:`~.Axes`. Returns @@ -241,13 +245,13 @@ def get_origin(self) -> np.ndarray: """ return self.coords_to_point(0, 0) - def get_x_axis(self) -> Mobject: + def get_x_axis(self) -> NumberLine: return self.get_axis(0) - def get_y_axis(self) -> Mobject: + def get_y_axis(self) -> NumberLine: return self.get_axis(1) - def get_z_axis(self) -> Mobject: + def get_z_axis(self) -> NumberLine: return self.get_axis(2) def get_x_unit_size(self) -> float: @@ -258,11 +262,11 @@ def get_y_unit_size(self) -> float: def get_x_axis_label( self, - label: float | str | Mobject, - edge: Sequence[float] = UR, - direction: Sequence[float] = UR, + label: float | str | VMobject, + edge: Vector3D = UR, + direction: Vector3D = UR, buff: float = SMALL_BUFF, - **kwargs, + **kwargs: Any, ) -> Mobject: """Generate an x-axis label. @@ -301,11 +305,11 @@ def construct(self): def get_y_axis_label( self, - label: float | str | Mobject, - edge: Sequence[float] = UR, - direction: Sequence[float] = UP * 0.5 + RIGHT, + label: float | str | VMobject, + edge: Vector3D = UR, + direction: Vector3D = UP * 0.5 + RIGHT, buff: float = SMALL_BUFF, - **kwargs, + **kwargs: Any, ) -> Mobject: """Generate a y-axis label. @@ -347,10 +351,10 @@ def construct(self): def _get_axis_label( self, - label: float | str | Mobject, + label: float | str | VMobject, axis: Mobject, - edge: Sequence[float], - direction: Sequence[float], + edge: Vector3D, + direction: Vector3D, buff: float = SMALL_BUFF, ) -> Mobject: """Gets the label for an axis. @@ -373,12 +377,14 @@ def _get_axis_label( :class:`~.Mobject` The positioned label along the given axis. """ - label = self.x_axis._create_label_tex(label) - label.next_to(axis.get_edge_center(edge), direction=direction, buff=buff) - label.shift_onto_screen(buff=MED_SMALL_BUFF) - return label + label_mobject: Mobject = self.x_axis._create_label_tex(label) + label_mobject.next_to( + axis.get_edge_center(edge), direction=direction, buff=buff + ) + label_mobject.shift_onto_screen(buff=MED_SMALL_BUFF) + return label_mobject - def get_axis_labels(self): + def get_axis_labels(self) -> VGroup: raise NotImplementedError() def add_coordinates( @@ -453,7 +459,7 @@ def add_coordinates( def get_line_from_axis_to_point( self, index: int, - point: Sequence[float], + point: Point3DLike, line_config: dict | None = ..., color: ParsableManimColor | None = ..., stroke_width: float = ..., @@ -463,7 +469,7 @@ def get_line_from_axis_to_point( def get_line_from_axis_to_point( self, index: int, - point: Sequence[float], + point: Point3DLike, line_func: type[LineType], line_config: dict | None = ..., color: ParsableManimColor | None = ..., @@ -518,7 +524,7 @@ def get_line_from_axis_to_point( # type: ignore[no-untyped-def] line = line_func(axis.get_projection(point), point, **line_config) return line - def get_vertical_line(self, point: Sequence[float], **kwargs: Any) -> Line: + def get_vertical_line(self, point: Point3DLike, **kwargs: Any) -> Line: """A vertical line from the x-axis to a given point in the scene. Parameters @@ -552,7 +558,7 @@ def construct(self): """ return self.get_line_from_axis_to_point(0, point, **kwargs) - def get_horizontal_line(self, point: Sequence[float], **kwargs) -> Line: + def get_horizontal_line(self, point: Point3DLike, **kwargs: Any) -> Line: """A horizontal line from the y-axis to a given point in the scene. Parameters @@ -584,7 +590,7 @@ def construct(self): """ return self.get_line_from_axis_to_point(1, point, **kwargs) - def get_lines_to_point(self, point: Sequence[float], **kwargs) -> VGroup: + def get_lines_to_point(self, point: Point3DLike, **kwargs: Any) -> VGroup: """Generate both horizontal and vertical lines from the axis to a point. Parameters @@ -630,7 +636,9 @@ def plot( function: Callable[[float], float], x_range: Sequence[float] | None = None, use_vectorized: bool = False, - colorscale: Union[Iterable[Color], Iterable[Color, float]] | None = None, + colorscale: Iterable[ParsableManimColor] + | Iterable[ParsableManimColor, float] + | None = None, colorscale_axis: int = 1, **kwargs: Any, ) -> ParametricFunction: @@ -1093,7 +1101,7 @@ def i2gp(self, x: float, graph: ParametricFunction) -> np.ndarray: def get_graph_label( self, graph: ParametricFunction, - label: float | str | Mobject = "f(x)", + label: float | str | VMobject = "f(x)", x_val: float | None = None, direction: Sequence[float] = RIGHT, buff: float = MED_SMALL_BUFF, @@ -1150,7 +1158,7 @@ def construct(self): dot_config = {} if color is None: color = graph.get_color() - label = self.x_axis._create_label_tex(label).set_color(color) + label_object: Mobject = self.x_axis._create_label_tex(label).set_color(color) if x_val is None: # Search from right to left @@ -1161,14 +1169,14 @@ def construct(self): else: point = self.input_to_graph_point(x_val, graph) - label.next_to(point, direction, buff=buff) - label.shift_onto_screen() + label_object.next_to(point, direction, buff=buff) + label_object.shift_onto_screen() if dot: dot = Dot(point=point, **dot_config) - label.add(dot) - label.dot = dot - return label + label_object.add(dot) + label_object.dot = dot + return label_object # calculus @@ -1176,14 +1184,14 @@ def get_riemann_rectangles( self, graph: ParametricFunction, x_range: Sequence[float] | None = None, - dx: float | None = 0.1, + dx: float = 0.1, input_sample_type: str = "left", stroke_width: float = 1, stroke_color: ParsableManimColor = BLACK, fill_opacity: float = 1, color: Iterable[ParsableManimColor] | ParsableManimColor = (BLUE, GREEN), show_signed_area: bool = True, - bounded_graph: ParametricFunction = None, + bounded_graph: ParametricFunction | None = None, blend: bool = False, width_scale_factor: float = 1.001, ) -> VGroup: @@ -1277,16 +1285,16 @@ def construct(self): x_range = [*x_range[:2], dx] rectangles = VGroup() - x_range = np.arange(*x_range) + x_range_array = np.arange(*x_range) if isinstance(color, (list, tuple)): color = [ManimColor(c) for c in color] else: color = [ManimColor(color)] - colors = color_gradient(color, len(x_range)) + colors = color_gradient(color, len(x_range_array)) - for x, color in zip(x_range, colors): + for x, color in zip(x_range_array, colors): if input_sample_type == "left": sample_input = x elif input_sample_type == "right": @@ -1341,7 +1349,7 @@ def get_area( x_range: tuple[float, float] | None = None, color: ParsableManimColor | Iterable[ParsableManimColor] = (BLUE, GREEN), opacity: float = 0.3, - bounded_graph: ParametricFunction = None, + bounded_graph: ParametricFunction | None = None, **kwargs: Any, ) -> Polygon: """Returns a :class:`~.Polygon` representing the area under the graph passed. @@ -1485,10 +1493,14 @@ def slope_of_tangent( ax.slope_of_tangent(x=-2, graph=curve) # -3.5000000259052038 """ - return np.tan(self.angle_of_tangent(x, graph, **kwargs)) + val: float = np.tan(self.angle_of_tangent(x, graph, **kwargs)) + return val def plot_derivative_graph( - self, graph: ParametricFunction, color: ParsableManimColor = GREEN, **kwargs + self, + graph: ParametricFunction, + color: ParsableManimColor = GREEN, + **kwargs: Any, ) -> ParametricFunction: """Returns the curve of the derivative of the passed graph. @@ -1526,7 +1538,7 @@ def construct(self): self.add(ax, curves, labels) """ - def deriv(x): + def deriv(x: float) -> float: return self.slope_of_tangent(x, graph) return self.plot(deriv, color=color, **kwargs) @@ -1587,7 +1599,7 @@ def antideriv(x): x_vals = np.linspace(0, x, samples, axis=1 if use_vectorized else 0) f_vec = np.vectorize(graph.underlying_function) y_vals = f_vec(x_vals) - return np.trapz(y_vals, x_vals) + y_intercept + return np.trapezoid(y_vals, x_vals) + y_intercept return self.plot(antideriv, use_vectorized=use_vectorized, **kwargs) @@ -1843,14 +1855,17 @@ def construct(self): return T_label_group - def __matmul__(self, coord: Point3DLike | Mobject): + def __matmul__(self, coord: Point3DLike | Mobject) -> Point3DLike: if isinstance(coord, Mobject): coord = coord.get_center() return self.coords_to_point(*coord) - def __rmatmul__(self, point: Point3DLike): + def __rmatmul__(self, point: Point3DLike) -> Point3DLike: return self.point_to_coords(point) + @staticmethod + def _origin_shift(axis_range: Sequence[float]) -> float: ... + class Axes(VGroup, CoordinateSystem, metaclass=ConvertToOpenGL): """Creates a set of axes. @@ -1918,7 +1933,7 @@ def __init__( y_axis_config: dict | None = None, tips: bool = True, **kwargs: Any, - ) -> None: + ): VGroup.__init__(self, **kwargs) CoordinateSystem.__init__(self, x_range, y_range, x_length, y_length) @@ -1926,8 +1941,11 @@ def __init__( "include_tip": tips, "numbers_to_exclude": [0], } - self.x_axis_config = {} - self.y_axis_config = {"rotation": 90 * DEGREES, "label_direction": LEFT} + self.x_axis_config: dict[str, Any] = {} + self.y_axis_config: dict[str, Any] = { + "rotation": 90 * DEGREES, + "label_direction": LEFT, + } self._update_default_configs( (self.axis_config, self.x_axis_config, self.y_axis_config), @@ -2416,12 +2434,12 @@ def __init__( z_axis_config: dict[str, Any] | None = None, z_normal: Vector3D = DOWN, num_axis_pieces: int = 20, - light_source: Sequence[float] = 9 * DOWN + 7 * LEFT + 10 * OUT, + light_source: Point3DLike = 9 * DOWN + 7 * LEFT + 10 * OUT, # opengl stuff (?) - depth=None, - gloss=0.5, + depth: Any = None, + gloss: float = 0.5, **kwargs: dict[str, Any], - ) -> None: + ): super().__init__( x_range=x_range, x_length=x_length, @@ -2433,7 +2451,7 @@ def __init__( self.z_range = z_range self.z_length = z_length - self.z_axis_config = {} + self.z_axis_config: dict[str, Any] = {} self._update_default_configs((self.z_axis_config,), (z_axis_config,)) self.z_axis_config = merge_dicts_recursively( self.axis_config, @@ -2443,7 +2461,7 @@ def __init__( self.z_normal = z_normal self.num_axis_pieces = num_axis_pieces - self.light_source = light_source + self.light_source = np.array(light_source) self.dimension = 3 @@ -2500,13 +2518,13 @@ def make_func(axis): def get_y_axis_label( self, - label: float | str | Mobject, - edge: Sequence[float] = UR, - direction: Sequence[float] = UR, + label: float | str | VMobject, + edge: Vector3D = UR, + direction: Vector3D = UR, buff: float = SMALL_BUFF, rotation: float = PI / 2, rotation_axis: Vector3D = OUT, - **kwargs, + **kwargs: dict[str, Any], ) -> Mobject: """Generate a y-axis label. @@ -2550,7 +2568,7 @@ def construct(self): def get_z_axis_label( self, - label: float | str | Mobject, + label: float | str | VMobject, edge: Vector3D = OUT, direction: Vector3D = RIGHT, buff: float = SMALL_BUFF, @@ -2600,9 +2618,9 @@ def construct(self): def get_axis_labels( self, - x_label: float | str | Mobject = "x", - y_label: float | str | Mobject = "y", - z_label: float | str | Mobject = "z", + x_label: float | str | VMobject = "x", + y_label: float | str | VMobject = "y", + z_label: float | str | VMobject = "z", ) -> VGroup: """Defines labels for the x_axis and y_axis of the graph. @@ -2741,7 +2759,7 @@ def __init__( **kwargs: dict[str, Any], ): # configs - self.axis_config = { + self.axis_config: dict[str, Any] = { "stroke_width": 2, "include_ticks": False, "include_tip": False, @@ -2749,8 +2767,8 @@ def __init__( "label_direction": DR, "font_size": 24, } - self.y_axis_config = {"label_direction": DR} - self.background_line_style = { + self.y_axis_config: dict[str, Any] = {"label_direction": DR} + self.background_line_style: dict[str, Any] = { "stroke_color": BLUE_D, "stroke_width": 2, "stroke_opacity": 1, @@ -2997,7 +3015,7 @@ def __init__( size: float | None = None, radius_step: float = 1, azimuth_step: float | None = None, - azimuth_units: str | None = "PI radians", + azimuth_units: str = "PI radians", azimuth_compact_fraction: bool = True, azimuth_offset: float = 0, azimuth_direction: str = "CCW", @@ -3009,7 +3027,7 @@ def __init__( faded_line_ratio: int = 1, make_smooth_after_applying_functions: bool = True, **kwargs: Any, - ) -> None: + ): # error catching if azimuth_units in ["PI radians", "TAU radians", "degrees", "gradians", None]: self.azimuth_units = azimuth_units @@ -3130,11 +3148,11 @@ def _get_lines(self) -> tuple[VGroup, VGroup]: unit_vector = self.x_axis.get_unit_vector()[0] for k, x in enumerate(rinput): - new_line = Circle(radius=x * unit_vector) + new_circle = Circle(radius=x * unit_vector) if k % ratio_faded_lines == 0: - alines1.add(new_line) + alines1.add(new_circle) else: - alines2.add(new_line) + alines2.add(new_circle) line = Line(center, self.get_x_axis().get_end()) @@ -3292,7 +3310,9 @@ def add_coordinates( self.add(self.get_coordinate_labels(r_values, a_values)) return self - def get_radian_label(self, number, font_size: float = 24, **kwargs: Any) -> MathTex: + def get_radian_label( + self, number: float, font_size: float = 24, **kwargs: Any + ) -> MathTex: constant_label = {"PI radians": r"\pi", "TAU radians": r"\tau"}[ self.azimuth_units ] @@ -3361,7 +3381,7 @@ def construct(self): """ - def __init__(self, **kwargs: Any) -> None: + def __init__(self, **kwargs: Any): super().__init__( **kwargs, ) diff --git a/manim/mobject/graphing/functions.py b/manim/mobject/graphing/functions.py index 83c48b1092..d125f45b6b 100644 --- a/manim/mobject/graphing/functions.py +++ b/manim/mobject/graphing/functions.py @@ -17,9 +17,12 @@ from manim.mobject.types.vectorized_mobject import VMobject if TYPE_CHECKING: + from typing import Any + from typing_extensions import Self from manim.typing import Point3D, Point3DLike + from manim.utils.color import ParsableManimColor from manim.utils.color import YELLOW @@ -111,7 +114,7 @@ def __init__( discontinuities: Iterable[float] | None = None, use_smoothing: bool = True, use_vectorized: bool = False, - **kwargs, + **kwargs: Any, ): def internal_parametric_function(t: float) -> Point3D: """Wrap ``function``'s output inside a NumPy array.""" @@ -143,13 +146,13 @@ def generate_points(self) -> Self: lambda t: self.t_min <= t <= self.t_max, self.discontinuities, ) - discontinuities = np.array(list(discontinuities)) + discontinuities_array = np.array(list(discontinuities)) boundary_times = np.array( [ self.t_min, self.t_max, - *(discontinuities - self.dt), - *(discontinuities + self.dt), + *(discontinuities_array - self.dt), + *(discontinuities_array + self.dt), ], ) boundary_times.sort() @@ -211,19 +214,27 @@ def construct(self): self.add(cos_func, sin_func_1, sin_func_2) """ - def __init__(self, function, x_range=None, color=YELLOW, **kwargs): + def __init__( + self, + function: Callable[[float], Any], + x_range: tuple[float, float] | tuple[float, float, float] | None = None, + color: ParsableManimColor = YELLOW, + **kwargs: Any, + ) -> None: if x_range is None: - x_range = np.array([-config["frame_x_radius"], config["frame_x_radius"]]) + x_range = (-config["frame_x_radius"], config["frame_x_radius"]) self.x_range = x_range - self.parametric_function = lambda t: np.array([t, function(t), 0]) - self.function = function + self.parametric_function: Callable[[float], Point3D] = lambda t: np.array( + [t, function(t), 0] + ) + self.function: Callable[[float], Any] = function super().__init__(self.parametric_function, self.x_range, color=color, **kwargs) - def get_function(self): + def get_function(self) -> Callable[[float], Any]: return self.function - def get_point_from_function(self, x): + def get_point_from_function(self, x: float) -> Point3D: return self.parametric_function(x) @@ -236,7 +247,7 @@ def __init__( min_depth: int = 5, max_quads: int = 1500, use_smoothing: bool = True, - **kwargs, + **kwargs: Any, ): """An implicit function. @@ -295,7 +306,7 @@ def construct(self): super().__init__(**kwargs) - def generate_points(self): + def generate_points(self) -> Self: p_min, p_max = ( np.array([self.x_range[0], self.y_range[0]]), np.array([self.x_range[1], self.y_range[1]]), diff --git a/manim/mobject/graphing/number_line.py b/manim/mobject/graphing/number_line.py index 017fac5bcb..14964bffc3 100644 --- a/manim/mobject/graphing/number_line.py +++ b/manim/mobject/graphing/number_line.py @@ -12,8 +12,12 @@ from typing import TYPE_CHECKING, Callable if TYPE_CHECKING: + from typing import Any + + from typing_extensions import Self + from manim.mobject.geometry.tips import ArrowTip - from manim.typing import Point3DLike + from manim.typing import Point3D, Point3DLike, Vector3D import numpy as np @@ -21,8 +25,9 @@ from manim.constants import * from manim.mobject.geometry.line import Line from manim.mobject.graphing.scale import LinearBase, _ScaleBase -from manim.mobject.text.numbers import DecimalNumber +from manim.mobject.text.numbers import DecimalNumber, Integer from manim.mobject.text.tex_mobject import MathTex, Tex +from manim.mobject.text.text_mobject import Text from manim.mobject.types.vectorized_mobject import VGroup, VMobject from manim.utils.bezier import interpolate from manim.utils.config_ops import merge_dicts_recursively @@ -157,14 +162,14 @@ def __init__( # numbers/labels include_numbers: bool = False, font_size: float = 36, - label_direction: Sequence[float] = DOWN, - label_constructor: VMobject = MathTex, + label_direction: Point3DLike = DOWN, + label_constructor: type[MathTex] = MathTex, scaling: _ScaleBase = LinearBase(), line_to_number_buff: float = MED_SMALL_BUFF, decimal_number_config: dict | None = None, numbers_to_exclude: Iterable[float] | None = None, numbers_to_include: Iterable[float] | None = None, - **kwargs, + **kwargs: Any, ): # avoid mutable arguments in defaults if numbers_to_exclude is None: @@ -189,6 +194,9 @@ def __init__( # turn into a NumPy array to scale by just applying the function self.x_range = np.array(x_range, dtype=float) + self.x_min: float + self.x_max: float + self.x_step: float self.x_min, self.x_max, self.x_step = scaling.function(self.x_range) self.length = length self.unit_size = unit_size @@ -246,16 +254,16 @@ def __init__( if self.scaling.custom_labels: tick_range = self.get_tick_range() + custom_labels = self.scaling.get_custom_labels( + tick_range, + unit_decimal_places=decimal_number_config["num_decimal_places"], + ) + self.add_labels( dict( zip( tick_range, - self.scaling.get_custom_labels( - tick_range, - unit_decimal_places=decimal_number_config[ - "num_decimal_places" - ], - ), + custom_labels, ) ), ) @@ -267,21 +275,25 @@ def __init__( font_size=self.font_size, ) - def rotate_about_zero(self, angle: float, axis: Sequence[float] = OUT, **kwargs): + def rotate_about_zero( + self, angle: float, axis: Vector3D = OUT, **kwargs: Any + ) -> Self: return self.rotate_about_number(0, angle, axis, **kwargs) def rotate_about_number( - self, number: float, angle: float, axis: Sequence[float] = OUT, **kwargs - ): + self, number: float, angle: float, axis: Vector3D = OUT, **kwargs: Any + ) -> Self: return self.rotate(angle, axis, about_point=self.n2p(number), **kwargs) - def add_ticks(self): + def add_ticks(self) -> None: """Adds ticks to the number line. Ticks can be accessed after creation via ``self.ticks``. """ ticks = VGroup() elongated_tick_size = self.tick_size * self.longer_tick_multiple - elongated_tick_offsets = self.numbers_with_elongated_ticks - self.x_min + elongated_tick_offsets = ( + np.array(self.numbers_with_elongated_ticks) - self.x_min + ) for x in self.get_tick_range(): size = self.tick_size if np.any(np.isclose(x - self.x_min, elongated_tick_offsets)): @@ -413,31 +425,34 @@ def point_to_number(self, point: Sequence[float]) -> float: point = np.asarray(point) start, end = self.get_start_and_end() unit_vect = normalize(end - start) - proportion = np.dot(point - start, unit_vect) / np.dot(end - start, unit_vect) + proportion: float = np.dot(point - start, unit_vect) / np.dot( + end - start, unit_vect + ) return interpolate(self.x_min, self.x_max, proportion) - def n2p(self, number: float | np.ndarray) -> np.ndarray: + def n2p(self, number: float | np.ndarray) -> Point3D: """Abbreviation for :meth:`~.NumberLine.number_to_point`.""" return self.number_to_point(number) - def p2n(self, point: Sequence[float]) -> float: + def p2n(self, point: Point3DLike) -> float: """Abbreviation for :meth:`~.NumberLine.point_to_number`.""" return self.point_to_number(point) def get_unit_size(self) -> float: - return self.get_length() / (self.x_range[1] - self.x_range[0]) + val: float = self.get_length() / (self.x_range[1] - self.x_range[0]) + return val - def get_unit_vector(self) -> np.ndarray: + def get_unit_vector(self) -> Vector3D: return super().get_unit_vector() * self.unit_size def get_number_mobject( self, x: float, - direction: Sequence[float] | None = None, + direction: Vector3D | None = None, buff: float | None = None, font_size: float | None = None, - label_constructor: VMobject | None = None, - **number_config, + label_constructor: type[MathTex] | None = None, + **number_config: dict[str, Any], ) -> VMobject: """Generates a positioned :class:`~.DecimalNumber` mobject generated according to ``label_constructor``. @@ -462,7 +477,7 @@ def get_number_mobject( :class:`~.DecimalNumber` The positioned mobject. """ - number_config = merge_dicts_recursively( + number_config_merged = merge_dicts_recursively( self.decimal_number_config, number_config, ) @@ -476,7 +491,10 @@ def get_number_mobject( label_constructor = self.label_constructor num_mob = DecimalNumber( - x, font_size=font_size, mob_class=label_constructor, **number_config + x, + font_size=font_size, + mob_class=label_constructor, + **number_config_merged, ) num_mob.next_to(self.number_to_point(x), direction=direction, buff=buff) @@ -485,7 +503,7 @@ def get_number_mobject( num_mob.shift(num_mob[0].width * LEFT / 2) return num_mob - def get_number_mobjects(self, *numbers, **kwargs) -> VGroup: + def get_number_mobjects(self, *numbers: float, **kwargs: Any) -> VGroup: if len(numbers) == 0: numbers = self.default_numbers_to_display() return VGroup([self.get_number_mobject(number, **kwargs) for number in numbers]) @@ -498,9 +516,9 @@ def add_numbers( x_values: Iterable[float] | None = None, excluding: Iterable[float] | None = None, font_size: float | None = None, - label_constructor: VMobject | None = None, - **kwargs, - ): + label_constructor: type[MathTex] | None = None, + **kwargs: Any, + ) -> Self: """Adds :class:`~.DecimalNumber` mobjects representing their position at each tick of the number line. The numbers can be accessed after creation via ``self.numbers``. @@ -551,11 +569,11 @@ def add_numbers( def add_labels( self, dict_values: dict[float, str | float | VMobject], - direction: Sequence[float] = None, + direction: Point3DLike | None = None, buff: float | None = None, font_size: float | None = None, - label_constructor: VMobject | None = None, - ): + label_constructor: type[MathTex] | None = None, + ) -> Self: """Adds specifically positioned labels to the :class:`~.NumberLine` using a ``dict``. The labels can be accessed after creation via ``self.labels``. @@ -598,6 +616,7 @@ def add_labels( label = self._create_label_tex(label, label_constructor) if hasattr(label, "font_size"): + assert isinstance(label, (MathTex, Tex, Text, Integer)), label label.font_size = font_size else: raise AttributeError(f"{label} is not compatible with add_labels.") @@ -612,7 +631,7 @@ def _create_label_tex( self, label_tex: str | float | VMobject, label_constructor: Callable | None = None, - **kwargs, + **kwargs: Any, ) -> VMobject: """Checks if the label is a :class:`~.VMobject`, otherwise, creates a label by passing ``label_tex`` to ``label_constructor``. @@ -633,24 +652,25 @@ def _create_label_tex( :class:`~.VMobject` The label. """ - if label_constructor is None: - label_constructor = self.label_constructor if isinstance(label_tex, (VMobject, OpenGLVMobject)): return label_tex - else: + if label_constructor is None: + label_constructor = self.label_constructor + if isinstance(label_tex, str): return label_constructor(label_tex, **kwargs) + return label_constructor(str(label_tex), **kwargs) @staticmethod - def _decimal_places_from_step(step) -> int: - step = str(step) - if "." not in step: + def _decimal_places_from_step(step: float) -> int: + step_str = str(step) + if "." not in step_str: return 0 - return len(step.split(".")[-1]) + return len(step_str.split(".")[-1]) - def __matmul__(self, other: float): + def __matmul__(self, other: float) -> Point3D: return self.n2p(other) - def __rmatmul__(self, other: Point3DLike | Mobject): + def __rmatmul__(self, other: Point3DLike | Mobject) -> float: if isinstance(other, Mobject): other = other.get_center() return self.p2n(other) @@ -659,10 +679,10 @@ def __rmatmul__(self, other: Point3DLike | Mobject): class UnitInterval(NumberLine): def __init__( self, - unit_size=10, - numbers_with_elongated_ticks=None, - decimal_number_config=None, - **kwargs, + unit_size: float = 10, + numbers_with_elongated_ticks: list[float] | None = None, + decimal_number_config: dict[str, Any] | None = None, + **kwargs: Any, ): numbers_with_elongated_ticks = ( [0, 1] diff --git a/manim/mobject/graphing/probability.py b/manim/mobject/graphing/probability.py index 24134c0a7a..309e0b7ec2 100644 --- a/manim/mobject/graphing/probability.py +++ b/manim/mobject/graphing/probability.py @@ -6,6 +6,7 @@ from collections.abc import Iterable, MutableSequence, Sequence +from typing import Any import numpy as np @@ -13,11 +14,11 @@ from manim.constants import * from manim.mobject.geometry.polygram import Rectangle from manim.mobject.graphing.coordinate_systems import Axes -from manim.mobject.mobject import Mobject -from manim.mobject.opengl.opengl_mobject import OpenGLMobject +from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject from manim.mobject.svg.brace import Brace from manim.mobject.text.tex_mobject import MathTex, Tex from manim.mobject.types.vectorized_mobject import VGroup, VMobject +from manim.typing import Vector3D from manim.utils.color import ( BLUE_E, DARK_GREY, @@ -54,13 +55,13 @@ def construct(self): def __init__( self, - height=3, - width=3, - fill_color=DARK_GREY, - fill_opacity=1, - stroke_width=0.5, - stroke_color=LIGHT_GREY, - default_label_scale_val=1, + height: float = 3, + width: float = 3, + fill_color: ParsableManimColor = DARK_GREY, + fill_opacity: float = 1, + stroke_width: float = 0.5, + stroke_color: ParsableManimColor = LIGHT_GREY, + default_label_scale_val: float = 1, ): super().__init__( height=height, @@ -72,7 +73,9 @@ def __init__( ) self.default_label_scale_val = default_label_scale_val - def add_title(self, title="Sample space", buff=MED_SMALL_BUFF): + def add_title( + self, title: str = "Sample space", buff: float = MED_SMALL_BUFF + ) -> None: # TODO, should this really exist in SampleSpaceScene title_mob = Tex(title) if title_mob.width > self.width: @@ -81,23 +84,32 @@ def add_title(self, title="Sample space", buff=MED_SMALL_BUFF): self.title = title_mob self.add(title_mob) - def add_label(self, label): + def add_label(self, label: str) -> None: self.label = label - def complete_p_list(self, p_list): - new_p_list = list(tuplify(p_list)) + def complete_p_list(self, p_list: float | Iterable[float]) -> list[float]: + p_list_tuplified: tuple[float] = tuplify(p_list) + new_p_list = list(p_list_tuplified) remainder = 1.0 - sum(new_p_list) if abs(remainder) > EPSILON: new_p_list.append(remainder) return new_p_list - def get_division_along_dimension(self, p_list, dim, colors, vect): - p_list = self.complete_p_list(p_list) - colors = color_gradient(colors, len(p_list)) + def get_division_along_dimension( + self, + p_list: float | Iterable[float], + dim: int, + colors: Sequence[ParsableManimColor], + vect: Vector3D, + ) -> VGroup: + p_list_complete = self.complete_p_list(p_list) + colors_in_gradient = color_gradient(colors, len(p_list_complete)) + + assert isinstance(colors_in_gradient, list) last_point = self.get_edge_center(-vect) parts = VGroup() - for factor, color in zip(p_list, colors): + for factor, color in zip(p_list_complete, colors_in_gradient): part = SampleSpace() part.set_fill(color, 1) part.replace(self, stretch=True) @@ -107,33 +119,43 @@ def get_division_along_dimension(self, p_list, dim, colors, vect): parts.add(part) return parts - def get_horizontal_division(self, p_list, colors=[GREEN_E, BLUE_E], vect=DOWN): + def get_horizontal_division( + self, + p_list: float | Iterable[float], + colors: Sequence[ParsableManimColor] = [GREEN_E, BLUE_E], + vect: Vector3D = DOWN, + ) -> VGroup: return self.get_division_along_dimension(p_list, 1, colors, vect) - def get_vertical_division(self, p_list, colors=[MAROON_B, YELLOW], vect=RIGHT): + def get_vertical_division( + self, + p_list: float | Iterable[float], + colors: Sequence[ParsableManimColor] = [MAROON_B, YELLOW], + vect: Vector3D = RIGHT, + ) -> VGroup: return self.get_division_along_dimension(p_list, 0, colors, vect) - def divide_horizontally(self, *args, **kwargs): + def divide_horizontally(self, *args: Any, **kwargs: Any) -> None: self.horizontal_parts = self.get_horizontal_division(*args, **kwargs) self.add(self.horizontal_parts) - def divide_vertically(self, *args, **kwargs): + def divide_vertically(self, *args: Any, **kwargs: Any) -> None: self.vertical_parts = self.get_vertical_division(*args, **kwargs) self.add(self.vertical_parts) def get_subdivision_braces_and_labels( self, - parts, - labels, - direction, - buff=SMALL_BUFF, - min_num_quads=1, - ): + parts: VGroup, + labels: list[str | VMobject | OpenGLVMobject], + direction: Vector3D, + buff: float = SMALL_BUFF, + min_num_quads: int = 1, + ) -> VGroup: label_mobs = VGroup() braces = VGroup() for label, part in zip(labels, parts): brace = Brace(part, direction, min_num_quads=min_num_quads, buff=buff) - if isinstance(label, (Mobject, OpenGLMobject)): + if isinstance(label, (VMobject, OpenGLVMobject)): label_mob = label else: label_mob = MathTex(label) @@ -141,34 +163,44 @@ def get_subdivision_braces_and_labels( label_mob.next_to(brace, direction, buff) braces.add(brace) + assert isinstance(label_mob, VMobject) label_mobs.add(label_mob) - parts.braces = braces - parts.labels = label_mobs - parts.label_kwargs = { + parts.braces = braces # type: ignore[attr-defined] + parts.labels = label_mobs # type: ignore[attr-defined] + parts.label_kwargs = { # type: ignore[attr-defined] "labels": label_mobs.copy(), "direction": direction, "buff": buff, } return VGroup(parts.braces, parts.labels) - def get_side_braces_and_labels(self, labels, direction=LEFT, **kwargs): + def get_side_braces_and_labels( + self, + labels: list[str | VMobject | OpenGLVMobject], + direction: Vector3D = LEFT, + **kwargs: Any, + ) -> VGroup: assert hasattr(self, "horizontal_parts") parts = self.horizontal_parts return self.get_subdivision_braces_and_labels( parts, labels, direction, **kwargs ) - def get_top_braces_and_labels(self, labels, **kwargs): + def get_top_braces_and_labels( + self, labels: list[str | VMobject | OpenGLVMobject], **kwargs: Any + ) -> VGroup: assert hasattr(self, "vertical_parts") parts = self.vertical_parts return self.get_subdivision_braces_and_labels(parts, labels, UP, **kwargs) - def get_bottom_braces_and_labels(self, labels, **kwargs): + def get_bottom_braces_and_labels( + self, labels: list[str | VMobject | OpenGLVMobject], **kwargs: Any + ) -> VGroup: assert hasattr(self, "vertical_parts") parts = self.vertical_parts return self.get_subdivision_braces_and_labels(parts, labels, DOWN, **kwargs) - def add_braces_and_labels(self): + def add_braces_and_labels(self) -> None: for attr in "horizontal_parts", "vertical_parts": if not hasattr(self, attr): continue @@ -177,11 +209,13 @@ def add_braces_and_labels(self): if hasattr(parts, subattr): self.add(getattr(parts, subattr)) - def __getitem__(self, index): + def __getitem__(self, index: int) -> SampleSpace: if hasattr(self, "horizontal_parts"): - return self.horizontal_parts[index] + val: SampleSpace = self.horizontal_parts[index] + return val elif hasattr(self, "vertical_parts"): - return self.vertical_parts[index] + val = self.vertical_parts[index] + return val return self.split()[index] @@ -253,7 +287,7 @@ def __init__( bar_width: float = 0.6, bar_fill_opacity: float = 0.7, bar_stroke_width: float = 3, - **kwargs, + **kwargs: Any, ): if isinstance(bar_colors, str): logger.warning( @@ -311,7 +345,7 @@ def __init__( self.y_axis.add_numbers() - def _update_colors(self): + def _update_colors(self) -> None: """Initialize the colors of the bars of the chart. Sets the color of ``self.bars`` via ``self.bar_colors``. @@ -321,13 +355,14 @@ def _update_colors(self): """ self.bars.set_color_by_gradient(*self.bar_colors) - def _add_x_axis_labels(self): + def _add_x_axis_labels(self) -> None: """Essentially :meth`:~.NumberLine.add_labels`, but differs in that the direction of the label with respect to the x_axis changes to UP or DOWN depending on the value. UP for negative values and DOWN for positive values. """ + assert isinstance(self.bar_names, list) val_range = np.arange( 0.5, len(self.bar_names), 1 ) # 0.5 shifted so that labels are centered, not on ticks @@ -338,7 +373,7 @@ def _add_x_axis_labels(self): # to accommodate negative bars, the label may need to be # below or above the x_axis depending on the value of the bar direction = UP if self.values[i] < 0 else DOWN - bar_name_label = self.x_axis.label_constructor(bar_name) + bar_name_label: MathTex = self.x_axis.label_constructor(bar_name) bar_name_label.font_size = self.x_axis.font_size bar_name_label.next_to( @@ -398,8 +433,8 @@ def get_bar_labels( color: ParsableManimColor | None = None, font_size: float = 24, buff: float = MED_SMALL_BUFF, - label_constructor: type[VMobject] = Tex, - ): + label_constructor: type[MathTex] = Tex, + ) -> VGroup: """Annotates each bar with its corresponding value. Use ``self.bar_labels`` to access the labels after creation. @@ -431,7 +466,7 @@ def construct(self): """ bar_labels = VGroup() for bar, value in zip(self.bars, self.values): - bar_lbl = label_constructor(str(value)) + bar_lbl: MathTex = label_constructor(str(value)) if color is None: bar_lbl.set_color(bar.get_fill_color()) @@ -446,7 +481,9 @@ def construct(self): return bar_labels - def change_bar_values(self, values: Iterable[float], update_colors: bool = True): + def change_bar_values( + self, values: Iterable[float], update_colors: bool = True + ) -> None: """Updates the height of the bars of the chart. Parameters @@ -512,4 +549,4 @@ def construct(self): if update_colors: self._update_colors() - self.values[: len(values)] = values + self.values[: len(list(values))] = values diff --git a/manim/mobject/graphing/scale.py b/manim/mobject/graphing/scale.py index ceda56f0a2..b6ed2b4ce3 100644 --- a/manim/mobject/graphing/scale.py +++ b/manim/mobject/graphing/scale.py @@ -2,7 +2,7 @@ import math from collections.abc import Iterable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, overload import numpy as np @@ -11,7 +11,9 @@ from manim.mobject.text.numbers import Integer if TYPE_CHECKING: - from manim.mobject.mobject import Mobject + from typing import Callable + + from manim.mobject.types.vectorized_mobject import VMobject class _ScaleBase: @@ -26,6 +28,12 @@ class _ScaleBase: def __init__(self, custom_labels: bool = False): self.custom_labels = custom_labels + @overload + def function(self, value: float) -> float: ... + + @overload + def function(self, value: np.ndarray) -> np.ndarray: ... + def function(self, value: float) -> float: """The function that will be used to scale the values. @@ -59,7 +67,8 @@ def inverse_function(self, value: float) -> float: def get_custom_labels( self, val_range: Iterable[float], - ) -> Iterable[Mobject]: + **kw_args: Any, + ) -> Iterable[VMobject]: """Custom instructions for generating labels along an axis. Parameters @@ -147,12 +156,14 @@ def inverse_function(self, value: float) -> float: if isinstance(value, np.ndarray): condition = value.any() <= 0 + func: Callable[[float, float], float] + def func(value: float, base: float) -> float: return_value: float = np.log(value) / np.log(base) return return_value else: condition = value <= 0 - func = math.log # type: ignore[assignment] + func = math.log if condition: raise ValueError( @@ -179,7 +190,7 @@ def get_custom_labels( Additional arguments to be passed to :class:`~.Integer`. """ # uses `format` syntax to control the number of decimal places. - tex_labels = [ + tex_labels: list[Integer] = [ Integer( self.base, unit="^{%s}" % (f"{self.inverse_function(i):.{unit_decimal_places}f}"), # noqa: UP031 diff --git a/manim/mobject/types/image_mobject.py b/manim/mobject/types/image_mobject.py index f73e6a6475..baaa7d5a80 100644 --- a/manim/mobject/types/image_mobject.py +++ b/manim/mobject/types/image_mobject.py @@ -14,6 +14,7 @@ from manim.mobject.geometry.shape_matchers import SurroundingRectangle from ... import config +from ...camera.moving_camera import MovingCamera from ...constants import * from ...mobject.mobject import Mobject from ...utils.bezier import interpolate @@ -28,7 +29,9 @@ import numpy.typing as npt from typing_extensions import Self - from manim.typing import StrPath + from manim.typing import PixelArray, StrPath + + from ...camera.moving_camera import MovingCamera class AbstractImageMobject(Mobject): @@ -57,7 +60,7 @@ def __init__( self.set_resampling_algorithm(resampling_algorithm) super().__init__(**kwargs) - def get_pixel_array(self) -> None: + def get_pixel_array(self) -> PixelArray: raise NotImplementedError() def set_color(self, color, alpha=None, family=True): @@ -303,7 +306,7 @@ def get_style(self) -> dict[str, Any]: class ImageMobjectFromCamera(AbstractImageMobject): def __init__( self, - camera, + camera: MovingCamera, default_display_frame_config: dict[str, Any] | None = None, **kwargs: Any, ) -> None: diff --git a/mypy.ini b/mypy.ini index 19cd3671a2..eeb6747e11 100644 --- a/mypy.ini +++ b/mypy.ini @@ -84,30 +84,15 @@ ignore_errors = True [mypy-manim.animation.updaters.mobject_update_utils] ignore_errors = True -[mypy-manim.camera.camera] -ignore_errors = True - [mypy-manim.camera.mapping_camera] ignore_errors = True [mypy-manim.camera.moving_camera] ignore_errors = True -[mypy-manim.camera.multi_camera] -ignore_errors = True - [mypy-manim.mobject.graphing.coordinate_systems] ignore_errors = True -[mypy-manim.mobject.graphing.functions] -ignore_errors = True - -[mypy-manim.mobject.graphing.number_line] -ignore_errors = True - -[mypy-manim.mobject.graphing.probability] -ignore_errors = True - [mypy-manim.mobject.graph] ignore_errors = True