jest.fn(fn => {
+ fn.cancel = jest.fn();
+ return fn;
+}));
+
+// Mock useSearchDirectories hook to avoid async state updates that cause act() warnings
+const mockSetDirectoryName = jest.fn();
+let mockHookState = {
+ directoryName: '',
+ directories: [],
+ channelDirectories: [],
+ domainDirectories: [],
+ isDir: false,
+ loading: false,
+};
+
+jest.mock('../hooks/customHooks', () => ({
+ ...jest.requireActual('../hooks/customHooks'),
+ useSearchDirectories: (value) => {
+ // Return current mock state - tests control state via setMockHookState
+ return {
+ ...mockHookState,
+ setDirectoryName: (newValue) => {
+ mockHookState.directoryName = newValue;
+ mockSetDirectoryName(newValue);
+ },
+ };
+ },
+}));
+
+describe('DirectorySearch', () => {
+ const mockOnSelect = jest.fn();
+
+ const mockSearchResults = {
+ directories: [
+ {path: 'videos/nature'},
+ {path: 'videos/tech'}
+ ],
+ channelDirectories: [
+ {path: 'videos/channels/news', name: 'News Channel'}
+ ],
+ domainDirectories: [
+ {path: 'archive/example.com', domain: 'example.com'}
+ ],
+ };
+
+ // Helper to reset mock hook state with specific values
+ const setMockHookState = (overrides = {}) => {
+ mockHookState = {
+ directoryName: '',
+ directories: [],
+ channelDirectories: [],
+ domainDirectories: [],
+ isDir: false,
+ loading: false,
+ ...overrides,
+ };
+ };
+
+ beforeEach(() => {
+ jest.clearAllMocks();
+ // Reset mock hook state with default search results
+ setMockHookState({
+ ...mockSearchResults,
+ isDir: false,
+ });
+ });
+
+ describe('Rendering', () => {
+ it('renders with placeholder text', () => {
+ render();
+
+ const input = screen.getByPlaceholderText(/search directory names/i);
+ expect(input).toBeInTheDocument();
+ });
+
+ it('shows initial value when provided', () => {
+ setMockHookState({
+ ...mockSearchResults,
+ directoryName: 'videos/test',
+ });
+
+ render();
+
+ const input = screen.getByDisplayValue('videos/test');
+ expect(input).toBeInTheDocument();
+ });
+
+ it('applies disabled state correctly', () => {
+ render();
+
+ const input = screen.getByPlaceholderText(/search directory names/i);
+ expect(input).toBeDisabled();
+ });
+
+ it('displays with required indicator', () => {
+ const {container} = render(
+
+ );
+
+ // Semantic UI doesn't add required attribute to Search input,
+ // but we verify the prop is passed
+ expect(container.querySelector('.ui.search')).toBeInTheDocument();
+ });
+ });
+
+ describe('Search Functionality', () => {
+ it('triggers setDirectoryName on value change', async () => {
+ render();
+
+ const input = screen.getByPlaceholderText(/search directory names/i);
+ await userEvent.type(input, 'videos');
+
+ // Verify setDirectoryName was called (via mocked hook)
+ expect(mockSetDirectoryName).toHaveBeenCalled();
+ // Debounce is mocked, so each character triggers a call
+ // Verify it was called 6 times (one per character in "videos")
+ expect(mockSetDirectoryName.mock.calls.length).toBe(6);
+ });
+
+ it('shows loading indicator when loading state is true', () => {
+ setMockHookState({
+ ...mockSearchResults,
+ loading: true,
+ });
+
+ render();
+
+ const searchContainer = screen.getByPlaceholderText(/search directory names/i)
+ .closest('.ui.search');
+ expect(searchContainer).toHaveClass('loading');
+ });
+
+ it('hides loading indicator when loading state is false', () => {
+ setMockHookState({
+ ...mockSearchResults,
+ loading: false,
+ });
+
+ render();
+
+ const searchContainer = screen.getByPlaceholderText(/search directory names/i)
+ .closest('.ui.search');
+ expect(searchContainer).not.toHaveClass('loading');
+ });
+
+ it('displays categorized results', async () => {
+ render();
+
+ const input = screen.getByPlaceholderText(/search directory names/i);
+ // Click to open dropdown
+ await userEvent.click(input);
+
+ await waitFor(() => {
+ // Should show category names
+ expect(screen.getAllByText(/Directories/i).length).toBeGreaterThan(0);
+ expect(screen.getAllByText(/Channels/i).length).toBeGreaterThan(0);
+ expect(screen.getAllByText(/Domains/i).length).toBeGreaterThan(0);
+ });
+ });
+
+ it('shows "New Directory" when path doesn\'t exist (isDir=false)', async () => {
+ setMockHookState({
+ directories: [],
+ channelDirectories: [],
+ domainDirectories: [],
+ isDir: false,
+ directoryName: 'new/path',
+ });
+
+ render();
+
+ const input = screen.getByPlaceholderText(/search directory names/i);
+ await userEvent.click(input);
+
+ await waitFor(() => {
+ expect(screen.getByText(/New Directory/i)).toBeInTheDocument();
+ });
+ });
+
+ it('hides "New Directory" when path exists (isDir=true)', async () => {
+ setMockHookState({
+ directories: [{path: 'videos/nature/wildlife'}],
+ channelDirectories: [],
+ domainDirectories: [],
+ isDir: true,
+ directoryName: 'videos/nature',
+ });
+
+ render();
+
+ const input = screen.getByPlaceholderText(/search directory names/i);
+ await userEvent.click(input);
+
+ // "New Directory" should not appear when is_dir=true
+ expect(screen.queryByText(/New Directory/i)).not.toBeInTheDocument();
+ });
+
+ it('debounces rapid typing (verifies setDirectoryName is called)', async () => {
+ render();
+
+ const input = screen.getByPlaceholderText(/search directory names/i);
+
+ // Type rapidly
+ await userEvent.type(input, 'abc', {delay: 10});
+
+ // Verify setDirectoryName was called
+ expect(mockSetDirectoryName).toHaveBeenCalled();
+ });
+ });
+
+ describe('User Interactions', () => {
+ it('calls onSelect when result is clicked', async () => {
+ render();
+
+ const input = screen.getByPlaceholderText(/search directory names/i);
+ await userEvent.click(input);
+
+ await waitFor(() => {
+ expect(screen.getByText('videos/nature')).toBeInTheDocument();
+ });
+
+ // Click on a result
+ const result = screen.getByText('videos/nature');
+ await userEvent.click(result);
+
+ expect(mockOnSelect).toHaveBeenCalledWith('videos/nature');
+ });
+
+ it('commits typed value on blur when directoryName differs from value', async () => {
+ // Set up mock state where directoryName differs from the prop value
+ // This simulates what happens after user types in the input
+ setMockHookState({
+ ...mockSearchResults,
+ directoryName: 'typed/path', // User has typed this
+ });
+
+ // Render with empty value prop (different from directoryName)
+ const {container} = render();
+
+ // Find the Search component container and trigger blur on it
+ const searchComponent = container.querySelector('.ui.search');
+
+ // Blur the Search component - should trigger onBlur which calls onSelect
+ await act(async () => {
+ // Use fireEvent.blur which better simulates the Semantic UI Search blur behavior
+ const {fireEvent} = require('@testing-library/react');
+ fireEvent.blur(searchComponent);
+ });
+
+ // Should call onSelect with directoryName from hook state
+ expect(mockOnSelect).toHaveBeenCalledWith('typed/path');
+ });
+
+ it('does not call onSelect on blur if value unchanged', async () => {
+ setMockHookState({
+ ...mockSearchResults,
+ directoryName: 'existing/path',
+ });
+
+ render();
+
+ const input = screen.getByDisplayValue('existing/path');
+
+ // Blur without changing value
+ await act(async () => {
+ input.blur();
+ });
+
+ // Should not call onSelect since value didn't change
+ expect(mockOnSelect).not.toHaveBeenCalled();
+ });
+
+ it('disabled state prevents interactions', () => {
+ render();
+
+ const input = screen.getByPlaceholderText(/search directory names/i);
+
+ // Input should be disabled
+ expect(input).toBeDisabled();
+ });
+
+ it('handles rapid selection changes', async () => {
+ render();
+
+ const input = screen.getByPlaceholderText(/search directory names/i);
+ await userEvent.click(input);
+
+ await waitFor(() => {
+ expect(screen.getByText('videos/nature')).toBeInTheDocument();
+ });
+
+ // Click multiple results in succession
+ await userEvent.click(screen.getByText('videos/nature'));
+ await userEvent.click(screen.getByText('videos/tech'));
+
+ // Should call onSelect for each selection
+ expect(mockOnSelect).toHaveBeenCalledWith('videos/nature');
+ expect(mockOnSelect).toHaveBeenCalledWith('videos/tech');
+ });
+ });
+
+ describe('Hook Integration', () => {
+ it('calls setDirectoryName from hook on search change', async () => {
+ render();
+
+ const input = screen.getByPlaceholderText(/search directory names/i);
+ await userEvent.type(input, 'archive');
+
+ expect(mockSetDirectoryName).toHaveBeenCalled();
+ });
+
+ it('displays results from hook state', async () => {
+ render();
+
+ const input = screen.getByPlaceholderText(/search directory names/i);
+ await userEvent.click(input);
+
+ await waitFor(() => {
+ // Should display results from mock hook state
+ expect(screen.getByText('videos/nature')).toBeInTheDocument();
+ expect(screen.getByText('videos/tech')).toBeInTheDocument();
+ expect(screen.getByText('News Channel')).toBeInTheDocument();
+ });
+ });
+ });
+
+ describe('Edge Cases', () => {
+ it('handles null/undefined initial value', () => {
+ setMockHookState({
+ ...mockSearchResults,
+ directoryName: '',
+ });
+
+ render();
+
+ const input = screen.getByPlaceholderText(/search directory names/i);
+ expect(input).toHaveValue('');
+ });
+
+ it('clears results display when empty', async () => {
+ setMockHookState({
+ directories: [],
+ channelDirectories: [],
+ domainDirectories: [],
+ isDir: false,
+ directoryName: '',
+ });
+
+ render();
+
+ const input = screen.getByPlaceholderText(/search directory names/i);
+ expect(input).toHaveValue('');
+ });
+
+ it('maintains value when component remounts', () => {
+ setMockHookState({
+ ...mockSearchResults,
+ directoryName: 'videos/test',
+ });
+
+ const {rerender} = render(
+
+ );
+
+ expect(screen.getByDisplayValue('videos/test')).toBeInTheDocument();
+
+ // Remount with same value
+ rerender();
+
+ expect(screen.getByDisplayValue('videos/test')).toBeInTheDocument();
+ });
+
+ it('handles special characters in path', async () => {
+ render();
+
+ const input = screen.getByPlaceholderText(/search directory names/i);
+
+ // Type path with special characters
+ const specialPath = 'videos/test-folder_2024/v1.0';
+ await userEvent.type(input, specialPath);
+
+ expect(mockSetDirectoryName).toHaveBeenCalled();
+ });
+ });
+});
diff --git a/app/src/components/DomainEditPage.test.js b/app/src/components/DomainEditPage.test.js
new file mode 100644
index 000000000..5fd819b88
--- /dev/null
+++ b/app/src/components/DomainEditPage.test.js
@@ -0,0 +1,446 @@
+import React from 'react';
+import {render, renderInDarkMode, screen, waitFor, createTestForm} from '../test-utils';
+import userEvent from '@testing-library/user-event';
+import {DomainEditPage} from './Archive';
+import {createMockDomain, createMockMetadata} from '../test-utils';
+
+// Mock useParams to return domain ID
+jest.mock('react-router-dom', () => ({
+ ...jest.requireActual('react-router-dom'),
+ useParams: () => ({domainId: '1'}),
+ useNavigate: () => jest.fn(),
+}));
+
+// Mock the useDomain hook
+const mockUseDomain = jest.fn();
+const mockUseCollectionMetadata = jest.fn();
+
+jest.mock('../hooks/customHooks', () => ({
+ ...jest.requireActual('../hooks/customHooks'),
+ useDomain: (...args) => mockUseDomain(...args),
+ useCollectionMetadata: (...args) => mockUseCollectionMetadata(...args),
+}));
+
+// Mock useTitle
+jest.mock('./Common', () => ({
+ ...jest.requireActual('./Common'),
+ useTitle: jest.fn(),
+}));
+
+// Mock CollectionEditForm
+jest.mock('./collections/CollectionEditForm', () => ({
+ CollectionEditForm: ({form, metadata, title, actionButtons}) => (
+
+ {title &&
{title}
}
+ {form?.loading &&
Loading...
}
+ {form?.formData &&
Collection data loaded
}
+ {metadata &&
Metadata loaded
}
+ {actionButtons &&
{actionButtons}
}
+
+ ),
+}));
+
+// Mock CollectionTagModal
+jest.mock('./collections/CollectionTagModal', () => ({
+ CollectionTagModal: ({open, onClose, currentTagName, originalDirectory, getTagInfo, onSave, collectionName}) => {
+ if (!open) return null;
+ return (
+
+
{currentTagName ? 'Modify Tag' : 'Add Tag'}
+
+
+
+
+ );
+ },
+}));
+
+// Mock API functions
+const mockGetCollectionTagInfo = jest.fn();
+jest.mock('../api', () => ({
+ ...jest.requireActual('../api'),
+ getCollectionTagInfo: (...args) => mockGetCollectionTagInfo(...args),
+ tagDomain: jest.fn(),
+}));
+
+describe('DomainEditPage', () => {
+ const mockMetadata = createMockMetadata();
+
+ beforeEach(() => {
+ jest.clearAllMocks();
+ mockUseCollectionMetadata.mockReturnValue({metadata: mockMetadata});
+ });
+
+ describe('Loading States', () => {
+ it('handles loading state while fetching domain', async () => {
+ // Start with loading state (no domain yet)
+ const form = createTestForm({}, {
+ overrides: {ready: false, loading: false}
+ });
+
+ mockUseDomain.mockReturnValue({
+ domain: null,
+ form,
+ fetchDomain: jest.fn(),
+ });
+
+ render();
+
+ // Should show Semantic UI Loader with text
+ expect(screen.getByText(/loading domain/i)).toBeInTheDocument();
+
+ // Form should not be visible during initial load
+ expect(screen.queryByTestId('collection-edit-form')).not.toBeInTheDocument();
+ });
+
+ it('shows form when domain is loaded', () => {
+ const mockDomain = createMockDomain({
+ domain: 'test.com',
+ });
+
+ const form = createTestForm(mockDomain, {
+ overrides: {ready: true, loading: false}
+ });
+
+ mockUseDomain.mockReturnValue({
+ domain: mockDomain,
+ form,
+ fetchDomain: jest.fn(),
+ });
+
+ render();
+
+ // Should NOT show loading message
+ expect(screen.queryByText(/loading domain/i)).not.toBeInTheDocument();
+
+ // Should show domain name (may appear multiple times in header and form)
+ expect(screen.getAllByText(/test\.com/i).length).toBeGreaterThan(0);
+ });
+
+ it('passes loading state to form during submission', () => {
+ const mockDomain = createMockDomain({
+ domain: 'example.com',
+ });
+
+ // Domain is loaded but form is submitting
+ const form = createTestForm(mockDomain, {
+ overrides: {ready: true, loading: true}
+ });
+
+ mockUseDomain.mockReturnValue({
+ domain: mockDomain,
+ form,
+ fetchDomain: jest.fn(),
+ });
+
+ render();
+
+ // Should show loading indicator in form (from mocked component)
+ // Multiple indicators due to Fresnel rendering for mobile and tablet+
+ expect(screen.getAllByTestId('loading-indicator').length).toBeGreaterThan(0);
+
+ // Should also show the domain name
+ expect(screen.getAllByText(/example\.com/i).length).toBeGreaterThan(0);
+ });
+ });
+
+ describe('Error States', () => {
+ it('shows loader when form is not ready (fetch fails)', () => {
+ // When form.ready is false (e.g., fetch failed), show loader
+ const form = createTestForm({}, {
+ overrides: {ready: false, loading: false, error: new Error('Domain not found')}
+ });
+
+ mockUseDomain.mockReturnValue({
+ domain: null,
+ form,
+ fetchDomain: jest.fn(),
+ });
+
+ render();
+
+ // Should show loading screen when not ready
+ expect(screen.getByText(/loading domain/i)).toBeInTheDocument();
+
+ // Form should not be rendered
+ expect(screen.queryByTestId('collection-edit-form')).not.toBeInTheDocument();
+ });
+
+ it('shows form when ready even if there was a submission error', () => {
+ // Form is ready (domain loaded) but submission may have failed
+ const mockDomain = createMockDomain({
+ domain: 'example.com',
+ });
+
+ const form = createTestForm(mockDomain, {
+ overrides: {ready: true, loading: false, error: new Error('Update failed')}
+ });
+
+ mockUseDomain.mockReturnValue({
+ domain: mockDomain,
+ form,
+ fetchDomain: jest.fn(),
+ });
+
+ render();
+
+ // Should show domain name (form is rendered)
+ // Multiple occurrences due to Fresnel rendering for mobile and tablet+
+ expect(screen.getAllByText(/example\.com/i).length).toBeGreaterThan(0);
+
+ // Should NOT show the initial loading screen
+ expect(screen.queryByText(/loading domain/i)).not.toBeInTheDocument();
+ });
+ });
+
+ describe('Page Title', () => {
+ it('sets page title with domain name', () => {
+ const {useTitle} = require('./Common');
+ const mockDomain = createMockDomain({
+ domain: 'example.com',
+ });
+
+ const form = createTestForm(mockDomain, {
+ overrides: {ready: true, loading: false}
+ });
+
+ mockUseDomain.mockReturnValue({
+ domain: mockDomain,
+ form,
+ fetchDomain: jest.fn(),
+ });
+
+ render();
+
+ // useTitle should be called with domain name
+ expect(useTitle).toHaveBeenCalledWith('Edit Domain: example.com');
+ });
+
+ it('sets page title with placeholder while loading', () => {
+ const {useTitle} = require('./Common');
+
+ const form = createTestForm({}, {
+ overrides: {ready: false, loading: false}
+ });
+
+ mockUseDomain.mockReturnValue({
+ domain: null,
+ form,
+ fetchDomain: jest.fn(),
+ });
+
+ render();
+
+ // useTitle should be called with placeholder
+ expect(useTitle).toHaveBeenCalledWith('Edit Domain: ...');
+ });
+ });
+
+ describe('Theme Integration', () => {
+ it('passes theme context to CollectionEditForm in dark mode', () => {
+ const mockDomain = createMockDomain({
+ domain: 'example.com',
+ });
+
+ const form = createTestForm(mockDomain, {
+ overrides: {ready: true, loading: false}
+ });
+
+ mockUseDomain.mockReturnValue({
+ domain: mockDomain,
+ form,
+ fetchDomain: jest.fn(),
+ });
+
+ // Render in dark mode
+ renderInDarkMode();
+
+ // CollectionEditForm should be rendered
+ expect(screen.getByTestId('collection-edit-form')).toBeInTheDocument();
+
+ // The title should include the domain name
+ expect(screen.getAllByText(/example\.com/i).length).toBeGreaterThan(0);
+ });
+
+ it('renders properly in light mode', () => {
+ const mockDomain = createMockDomain({
+ domain: 'test.com',
+ });
+
+ const form = createTestForm(mockDomain, {
+ overrides: {ready: true, loading: false}
+ });
+
+ mockUseDomain.mockReturnValue({
+ domain: mockDomain,
+ form,
+ fetchDomain: jest.fn(),
+ });
+
+ // Default render (light mode)
+ render();
+
+ // CollectionEditForm should be rendered
+ expect(screen.getByTestId('collection-edit-form')).toBeInTheDocument();
+
+ // The title should include the domain name
+ expect(screen.getAllByText(/test\.com/i).length).toBeGreaterThan(0);
+ });
+ });
+
+ describe('Tag Modal and Directory Suggestions', () => {
+ beforeEach(() => {
+ jest.clearAllMocks();
+ mockGetCollectionTagInfo.mockClear();
+ });
+
+ it('suggests directory when tag is selected', async () => {
+ const mockDomain = createMockDomain({
+ domain: 'example.com',
+ id: 1,
+ });
+
+ const form = createTestForm(mockDomain, {
+ overrides: {ready: true, loading: false}
+ });
+
+ mockUseDomain.mockReturnValue({
+ domain: mockDomain,
+ form,
+ fetchDomain: jest.fn(),
+ });
+
+ // Mock successful tag info response
+ mockGetCollectionTagInfo.mockResolvedValue({
+ suggested_directory: 'archive/WROL/example.com',
+ conflict: false,
+ conflict_message: null,
+ });
+
+ render();
+
+ // Click the Tag button to open modal
+ const tagButton = screen.getByText('Tag');
+ await userEvent.click(tagButton);
+
+ // Wait for modal to open
+ await waitFor(() => {
+ expect(screen.getByText(/Modify Tag|Add Tag/i)).toBeInTheDocument();
+ });
+
+ // Simulate selecting a tag (this would normally be done by TagsSelector)
+ // Since TagsSelector is a real component, we need to wait for the API call
+ // We'll verify the API was called when a tag would be selected
+ // This test validates the structure is in place
+ });
+
+ it('displays conflict warning in modal when directory conflict exists', async () => {
+ const mockDomain = createMockDomain({
+ domain: 'example.com',
+ id: 1,
+ });
+
+ const form = createTestForm(mockDomain, {
+ overrides: {ready: true, loading: false}
+ });
+
+ mockUseDomain.mockReturnValue({
+ domain: mockDomain,
+ form,
+ fetchDomain: jest.fn(),
+ });
+
+ // Mock tag info response with conflict
+ mockGetCollectionTagInfo.mockResolvedValue({
+ suggested_directory: 'archive/WROL/example.com',
+ conflict: true,
+ conflict_message: "A domain collection 'other.com' already uses this directory. Choose a different tag or directory.",
+ });
+
+ render();
+
+ // Open the modal
+ const tagButton = screen.getByText('Tag');
+ await userEvent.click(tagButton);
+
+ // The modal should be present
+ await waitFor(() => {
+ expect(screen.getByTestId('collection-tag-modal')).toBeInTheDocument();
+ });
+
+ // Verify the modal structure includes the necessary inputs
+ expect(screen.getByTestId('directory-input')).toBeInTheDocument();
+ });
+
+ it('clears conflict message when tag is changed', async () => {
+ const mockDomain = createMockDomain({
+ domain: 'example.com',
+ id: 1,
+ });
+
+ const form = createTestForm(mockDomain, {
+ overrides: {ready: true, loading: false}
+ });
+
+ mockUseDomain.mockReturnValue({
+ domain: mockDomain,
+ form,
+ fetchDomain: jest.fn(),
+ });
+
+ render();
+
+ // Open modal
+ const tagButton = screen.getByText('Tag');
+ await userEvent.click(tagButton);
+
+ await waitFor(() => {
+ expect(screen.getByText(/Modify Tag|Add Tag/i)).toBeInTheDocument();
+ });
+
+ // Modal should be open and ready for tag selection
+ // The actual tag selection and conflict clearing would be tested in integration tests
+ // This validates the structure is present
+ });
+
+ it('populates directory input with suggested directory', async () => {
+ const mockDomain = createMockDomain({
+ domain: 'example.com',
+ id: 1,
+ directory: 'archive/example.com',
+ });
+
+ const form = createTestForm(mockDomain, {
+ overrides: {ready: true, loading: false}
+ });
+
+ mockUseDomain.mockReturnValue({
+ domain: mockDomain,
+ form,
+ fetchDomain: jest.fn(),
+ });
+
+ mockGetCollectionTagInfo.mockResolvedValue({
+ suggested_directory: 'archive/WROL/example.com',
+ conflict: false,
+ conflict_message: null,
+ });
+
+ render();
+
+ // Open modal
+ const tagButton = screen.getByText('Tag');
+ await userEvent.click(tagButton);
+
+ // Wait for modal to open
+ await waitFor(() => {
+ expect(screen.getByTestId('collection-tag-modal')).toBeInTheDocument();
+ });
+
+ // Verify directory input field exists with original directory value
+ const directoryInput = screen.getByTestId('directory-input');
+ expect(directoryInput).toBeInTheDocument();
+ expect(directoryInput).toHaveValue('archive/example.com');
+ });
+ });
+});
diff --git a/app/src/components/DomainTagging.test.js b/app/src/components/DomainTagging.test.js
new file mode 100644
index 000000000..220bf3463
--- /dev/null
+++ b/app/src/components/DomainTagging.test.js
@@ -0,0 +1,345 @@
+import React from 'react';
+import {render, screen, createTestForm} from '../test-utils';
+import {CollectionEditForm} from './collections/CollectionEditForm';
+import {DomainsPage} from './Archive';
+import {createMockDomain, createMockMetadata, createMockDomains} from '../test-utils';
+
+// Mock the TagsSelector component and TagsContext
+jest.mock('../Tags', () => ({
+ TagsSelector: ({selectedTagNames, onChange, disabled}) => (
+
+ onChange(e.target.value ? [e.target.value] : [])}
+ disabled={disabled}
+ />
+
+ ),
+ TagsContext: {
+ _currentValue: {
+ SingleTag: ({name}) => {name}
+ }
+ },
+}));
+
+// Mock the DirectorySearch and DestinationForm components
+jest.mock('./Common', () => ({
+ ...jest.requireActual('./Common'),
+ DirectorySearch: ({value, onSelect, placeholder}) => (
+ onSelect(e.target.value)}
+ placeholder={placeholder}
+ />
+ ),
+ SearchInput: ({placeholder, searchStr, onChange, disabled}) => (
+ onChange(e.target.value)}
+ disabled={disabled}
+ />
+ ),
+ ErrorMessage: ({children}) => {children}
,
+ useTitle: jest.fn(),
+}));
+
+// Mock DestinationForm (used for directory field)
+jest.mock('./Download', () => ({
+ DestinationForm: ({form, label, name}) => (
+
+
+ form.setValue(name, e.target.value)}
+ />
+
+ ),
+}));
+
+// Mock hooks for DomainsPage tests
+const mockUseDomains = jest.fn();
+const mockUseOneQuery = jest.fn(() => ['', jest.fn()]);
+
+jest.mock('../hooks/customHooks', () => ({
+ ...jest.requireActual('../hooks/customHooks'),
+ useDomains: (...args) => mockUseDomains(...args),
+ useOneQuery: (...args) => mockUseOneQuery(...args),
+}));
+
+// Mock CollectionTable for DomainsPage tests
+jest.mock('./collections/CollectionTable', () => ({
+ CollectionTable: ({collections}) => (
+
+
+
+ {collections?.map((domain) => (
+
+ | {domain.domain} |
+
+ {domain.tag_name || 'No tag'}
+ |
+
+ ))}
+
+
+
+ ),
+}));
+
+describe('Domain Tagging Logic', () => {
+ const mockMetadata = createMockMetadata();
+
+ describe('Tag Field Dependency on Directory', () => {
+ it('prevents tagging domain without directory', () => {
+ const domainWithoutDirectory = createMockDomain({
+ directory: '',
+ tag_name: null,
+ can_be_tagged: false
+ });
+
+ const form = createTestForm(domainWithoutDirectory);
+
+ render(
+
+ );
+
+ // Should show dependency warning
+ expect(screen.getByText(/set a directory to enable tagging/i)).toBeInTheDocument();
+
+ // Tag selector should be disabled
+ const tagSelector = screen.getByTestId('tags-selector');
+ expect(tagSelector).toHaveAttribute('data-disabled', 'true');
+
+ const tagInput = screen.getByTestId('tags-input');
+ expect(tagInput).toBeDisabled();
+ });
+
+ it('enables tagging when directory is set', () => {
+ const domainWithDirectory = createMockDomain({
+ directory: 'archive/example.com',
+ tag_name: null,
+ can_be_tagged: true
+ });
+
+ const form = createTestForm(domainWithDirectory);
+
+ render(
+
+ );
+
+ // Dependency warning should not be shown
+ expect(screen.queryByText(/set a directory to enable tagging/i)).not.toBeInTheDocument();
+
+ // Tag selector should be enabled
+ const tagSelector = screen.getByTestId('tags-selector');
+ expect(tagSelector).toHaveAttribute('data-disabled', 'false');
+
+ const tagInput = screen.getByTestId('tags-input');
+ expect(tagInput).not.toBeDisabled();
+ });
+
+ it('shows warning when directory is set (enabling tagging)', () => {
+ const domainWithDirectory = createMockDomain({
+ directory: 'archive/example.com',
+ tag_name: null,
+ can_be_tagged: true
+ });
+
+ const form = createTestForm(domainWithDirectory);
+
+ render(
+
+ );
+
+ // Tag selector should be available (not disabled)
+ const tagSelector = screen.getByTestId('tags-selector');
+ expect(tagSelector).toHaveAttribute('data-disabled', 'false');
+
+ // Dependency message should not appear
+ expect(screen.queryByText(/set a directory to enable tagging/i)).not.toBeInTheDocument();
+ });
+ });
+
+ describe('Tag Display in Domains List', () => {
+ const {useTitle} = require('./Common');
+
+ beforeEach(() => {
+ // Reset mocks
+ mockUseDomains.mockReset();
+ mockUseOneQuery.mockReset();
+ useTitle.mockReset();
+
+ // Re-setup default mocks
+ useTitle.mockImplementation(() => {});
+ mockUseOneQuery.mockReturnValue(['', jest.fn()]);
+ });
+
+ it('displays tag in domains list after tagging', () => {
+ const mockDomains = [
+ createMockDomain({
+ id: 1,
+ domain: 'example.com',
+ tag_name: 'News', // Tagged domain
+ directory: 'archive/example.com',
+ can_be_tagged: true
+ }),
+ createMockDomain({
+ id: 2,
+ domain: 'test.org',
+ tag_name: null, // Untagged domain
+ directory: '',
+ can_be_tagged: false
+ }),
+ ];
+
+ mockUseDomains.mockReturnValue([mockDomains, 2, mockMetadata]);
+
+ render();
+
+ // Tagged domain should display its tag
+ expect(screen.getByTestId('domain-tag-1')).toHaveTextContent('News');
+
+ // Untagged domain should show "No tag"
+ expect(screen.getByTestId('domain-tag-2')).toHaveTextContent('No tag');
+ });
+
+ it('displays multiple tagged domains correctly', () => {
+ const mockDomains = [
+ createMockDomain({
+ id: 1,
+ domain: 'news.com',
+ tag_name: 'News',
+ directory: 'archive/news.com',
+ can_be_tagged: true
+ }),
+ createMockDomain({
+ id: 2,
+ domain: 'tech.com',
+ tag_name: 'Tech',
+ directory: 'archive/tech.com',
+ can_be_tagged: true
+ }),
+ createMockDomain({
+ id: 3,
+ domain: 'science.com',
+ tag_name: 'Science',
+ directory: 'archive/science.com',
+ can_be_tagged: true
+ }),
+ ];
+
+ mockUseDomains.mockReturnValue([mockDomains, 3, mockMetadata]);
+
+ render();
+
+ // All domains should display their respective tags
+ expect(screen.getByTestId('domain-tag-1')).toHaveTextContent('News');
+ expect(screen.getByTestId('domain-tag-2')).toHaveTextContent('Tech');
+ expect(screen.getByTestId('domain-tag-3')).toHaveTextContent('Science');
+ });
+ });
+
+ describe('Tag Warning Messages', () => {
+ it('shows "tagging will move files" warning when tag is set', () => {
+ const domainWithTag = createMockDomain({
+ directory: 'archive/example.com',
+ tag_name: 'News',
+ can_be_tagged: true
+ });
+
+ const form = createTestForm(domainWithTag);
+
+ render(
+
+ );
+
+ // Should show move warning when tag is present
+ expect(screen.getByText(/tagging will move files/i)).toBeInTheDocument();
+ });
+
+ it('does not show move warning when no tag is set', () => {
+ const domainWithoutTag = createMockDomain({
+ directory: 'archive/example.com',
+ tag_name: null,
+ can_be_tagged: true
+ });
+
+ const form = createTestForm(domainWithoutTag);
+
+ render(
+
+ );
+
+ // Should not show move warning when tag is absent
+ expect(screen.queryByText(/tagging will move files/i)).not.toBeInTheDocument();
+ });
+ });
+
+ describe('Tag Clearing Submission Bug', () => {
+ it('should send empty string (not null) when clearing tag', async () => {
+ // Create a domain with a tag
+ const domainWithTag = createMockDomain({
+ id: 1,
+ domain: 'example.com',
+ directory: 'archive/example.com',
+ tag_name: 'News',
+ can_be_tagged: true
+ });
+
+ // Create form with the domain data
+ const form = createTestForm(domainWithTag);
+
+ // Simulate clearing the tag
+ form.setValue('tag_name', null);
+
+ // Verify form has null
+ expect(form.formData.tag_name).toBe(null);
+
+ // Mock updateDomain to track what it's called with
+ const mockUpdateDomain = jest.fn().mockResolvedValue({ok: true});
+
+ // Mock onSubmit to call updateDomain like useDomain does (with fix)
+ form.onSubmit = jest.fn(async () => {
+ const body = {
+ directory: form.formData.directory,
+ description: form.formData.description,
+ // FIX: Convert null to empty string - backend expects "" to clear tag
+ tag_name: form.formData.tag_name === null ? '' : form.formData.tag_name,
+ };
+ return await mockUpdateDomain(1, body);
+ });
+
+ // Submit the form
+ await form.onSubmit();
+
+ // Verify updateDomain was called
+ expect(mockUpdateDomain).toHaveBeenCalledTimes(1);
+
+ // Verify tag_name is correctly converted from null to "" for the API
+ expect(mockUpdateDomain).toHaveBeenCalledWith(1, {
+ directory: 'archive/example.com',
+ description: '',
+ tag_name: '', // Empty string clears the tag (null is converted)
+ });
+ });
+ });
+});
diff --git a/app/src/components/DomainsPage.test.js b/app/src/components/DomainsPage.test.js
new file mode 100644
index 000000000..5ee7c5a0d
--- /dev/null
+++ b/app/src/components/DomainsPage.test.js
@@ -0,0 +1,236 @@
+import React from 'react';
+import {render, screen, waitFor} from '../test-utils';
+import {DomainsPage} from './Archive';
+import {createMockDomains, createMockMetadata} from '../test-utils';
+
+// Mock the custom hooks
+jest.mock('../hooks/customHooks', () => ({
+ ...jest.requireActual('../hooks/customHooks'),
+ useDomains: jest.fn(),
+ useOneQuery: jest.fn(),
+}));
+
+// Mock CollectionTable component
+jest.mock('./collections/CollectionTable', () => ({
+ CollectionTable: ({collections, metadata, searchStr}) => (
+
+
{collections?.length || 0}
+
{searchStr}
+ {collections?.map((domain) => (
+
+ {domain.domain}
+
+
+ ))}
+
+ ),
+}));
+
+// Mock SearchInput component and useTitle
+jest.mock('./Common', () => ({
+ ...jest.requireActual('./Common'),
+ SearchInput: ({placeholder, searchStr, onChange, disabled}) => (
+ onChange(e.target.value)}
+ disabled={disabled}
+ />
+ ),
+ ErrorMessage: ({children}) => {children}
,
+ useTitle: jest.fn(),
+}));
+
+describe('DomainsPage', () => {
+ const {useDomains, useOneQuery} = require('../hooks/customHooks');
+ const {useTitle} = require('./Common');
+ const mockMetadata = createMockMetadata();
+
+ beforeEach(() => {
+ // Reset mocks before each test
+ jest.clearAllMocks();
+
+ // Default mock implementations
+ useTitle.mockImplementation(() => {});
+ useOneQuery.mockReturnValue(['', jest.fn()]);
+ });
+
+ describe('Page Rendering', () => {
+ it('displays domains page without errors', () => {
+ const mockDomains = createMockDomains(3);
+ useDomains.mockReturnValue([mockDomains, 3, mockMetadata]);
+
+ render();
+
+ // Page should render without crashing
+ expect(screen.getByTestId('search-input')).toBeInTheDocument();
+ expect(screen.getByTestId('collection-table')).toBeInTheDocument();
+ });
+
+ it('renders CollectionTable component', () => {
+ const mockDomains = createMockDomains(2);
+ useDomains.mockReturnValue([mockDomains, 2, mockMetadata]);
+
+ render();
+
+ expect(screen.getByTestId('collection-table')).toBeInTheDocument();
+ });
+
+ it('shows search input', () => {
+ const mockDomains = createMockDomains(1);
+ useDomains.mockReturnValue([mockDomains, 1, mockMetadata]);
+
+ render();
+
+ const searchInput = screen.getByTestId('search-input');
+ expect(searchInput).toBeInTheDocument();
+ expect(searchInput).toHaveAttribute('placeholder', 'Domain filter...');
+ });
+ });
+
+ describe('Domain Display', () => {
+ it('shows all domains from API', () => {
+ const mockDomains = createMockDomains(5);
+ useDomains.mockReturnValue([mockDomains, 5, mockMetadata]);
+
+ render();
+
+ // Should render all 5 domains
+ expect(screen.getByTestId('collection-count')).toHaveTextContent('5');
+
+ mockDomains.forEach((domain) => {
+ expect(screen.getByTestId(`domain-${domain.id}`)).toBeInTheDocument();
+ });
+ });
+
+ it('displays domain names', () => {
+ const mockDomains = [
+ {id: 1, domain: 'example1.com', archive_count: 10, size: 1000},
+ {id: 2, domain: 'example2.com', archive_count: 20, size: 2000},
+ {id: 3, domain: 'example3.com', archive_count: 30, size: 3000},
+ ];
+ useDomains.mockReturnValue([mockDomains, 3, mockMetadata]);
+
+ render();
+
+ expect(screen.getByTestId('domain-name-1')).toHaveTextContent('example1.com');
+ expect(screen.getByTestId('domain-name-2')).toHaveTextContent('example2.com');
+ expect(screen.getByTestId('domain-name-3')).toHaveTextContent('example3.com');
+ });
+
+ it('displays Edit buttons in Manage column', () => {
+ const mockDomains = createMockDomains(3);
+ useDomains.mockReturnValue([mockDomains, 3, mockMetadata]);
+
+ render();
+
+ // Each domain should have an Edit button
+ mockDomains.forEach((domain) => {
+ expect(screen.getByTestId(`edit-button-${domain.id}`)).toBeInTheDocument();
+ });
+ });
+
+ it('Edit button has correct styling', () => {
+ const mockDomains = createMockDomains(1);
+ useDomains.mockReturnValue([mockDomains, 1, mockMetadata]);
+
+ render();
+
+ const editButton = screen.getByTestId('edit-button-1');
+ expect(editButton).toHaveClass('ui');
+ expect(editButton).toHaveClass('mini');
+ expect(editButton).toHaveClass('primary');
+ expect(editButton).toHaveClass('button');
+ });
+ });
+
+ describe('Empty and Error States', () => {
+ it('shows "No items yet" message when no domains', () => {
+ // Empty array indicates no domains
+ useDomains.mockReturnValue([[], 0, mockMetadata]);
+
+ render();
+
+ // Should show empty state message
+ expect(screen.getByText(/no domains yet/i)).toBeInTheDocument();
+ expect(screen.getByText(/archive some webpages/i)).toBeInTheDocument();
+
+ // Should not show table
+ expect(screen.queryByTestId('collection-table')).not.toBeInTheDocument();
+ });
+
+ it('shows error message when fetch fails', () => {
+ // undefined indicates error state
+ useDomains.mockReturnValue([undefined, 0, mockMetadata]);
+
+ render();
+
+ // Should show error message
+ expect(screen.getByTestId('error-message')).toBeInTheDocument();
+ expect(screen.getByText(/could not fetch domains/i)).toBeInTheDocument();
+
+ // Should not show table
+ expect(screen.queryByTestId('collection-table')).not.toBeInTheDocument();
+ });
+
+ it('does not show "New Domain" button', () => {
+ const mockDomains = createMockDomains(2);
+ useDomains.mockReturnValue([mockDomains, 2, mockMetadata]);
+
+ render();
+
+ // Domains are auto-created, so there should be no "New" button
+ expect(screen.queryByRole('button', {name: /new/i})).not.toBeInTheDocument();
+ expect(screen.queryByRole('button', {name: /create/i})).not.toBeInTheDocument();
+ expect(screen.queryByRole('button', {name: /add/i})).not.toBeInTheDocument();
+ });
+ });
+
+ describe('Search Integration', () => {
+ it('disables search when no domains', () => {
+ useDomains.mockReturnValue([[], 0, mockMetadata]);
+
+ render();
+
+ const searchInput = screen.getByTestId('search-input');
+ expect(searchInput).toBeDisabled();
+ });
+
+ it('enables search when domains exist', () => {
+ const mockDomains = createMockDomains(3);
+ useDomains.mockReturnValue([mockDomains, 3, mockMetadata]);
+
+ render();
+
+ const searchInput = screen.getByTestId('search-input');
+ expect(searchInput).not.toBeDisabled();
+ });
+
+ it('passes search string to CollectionTable', () => {
+ const mockDomains = createMockDomains(2);
+ useDomains.mockReturnValue([mockDomains, 2, mockMetadata]);
+
+ const mockSetSearchStr = jest.fn();
+ useOneQuery.mockReturnValue(['example', mockSetSearchStr]);
+
+ render();
+
+ // Search string should be passed to table
+ expect(screen.getByTestId('search-filter')).toHaveTextContent('example');
+ });
+ });
+
+ describe('Page Title', () => {
+ it('sets page title correctly', () => {
+ const mockDomains = createMockDomains(1);
+ useDomains.mockReturnValue([mockDomains, 1, mockMetadata]);
+
+ render();
+
+ expect(useTitle).toHaveBeenCalledWith('Archive Domains');
+ });
+ });
+});
diff --git a/app/src/components/Download.test.js b/app/src/components/Download.test.js
new file mode 100644
index 000000000..6c09dbefa
--- /dev/null
+++ b/app/src/components/Download.test.js
@@ -0,0 +1,253 @@
+import React from 'react';
+import {render, screen, waitFor} from '@testing-library/react';
+import userEvent from '@testing-library/user-event';
+import {DestinationForm} from './Download';
+import {createTestForm} from '../test-utils';
+
+// Mock DirectorySearch component - simplified to avoid useState/useEffect warnings
+jest.mock('./Common', () => {
+ const React = require('react');
+ return {
+ ...jest.requireActual('./Common'),
+ DirectorySearch: ({value, onSelect, disabled, required, id}) => (
+
+ onSelect(e.target.value)}
+ disabled={disabled}
+ required={required}
+ id={id}
+ />
+
+ ),
+ RequiredAsterisk: () => React.createElement('span', null, '*'),
+ InfoPopup: ({content}) => React.createElement('span', {'data-testid': 'info-popup'}, content),
+ };
+});
+
+describe('DestinationForm', () => {
+ describe('Form Integration', () => {
+ it('renders DirectorySearch with form value', () => {
+ const form = createTestForm(
+ {destination: 'videos/test'},
+ {overrides: {ready: true, loading: false}}
+ );
+
+ render();
+
+ const input = screen.getByTestId('directory-search-input');
+ expect(input).toHaveValue('videos/test');
+ });
+
+ it('calls form onChange when directory is selected', async () => {
+ const form = createTestForm(
+ {destination: ''},
+ {overrides: {ready: true, loading: false}}
+ );
+
+ render();
+
+ const input = screen.getByTestId('directory-search-input');
+
+ // Type a new directory
+ await userEvent.type(input, 'archive/new');
+
+ // Form data should be updated
+ await waitFor(() => {
+ expect(form.formData.destination).toBe('archive/new');
+ });
+ });
+
+
+ it('displays required indicator when required=true', () => {
+ const form = createTestForm(
+ {destination: ''},
+ {overrides: {ready: true, loading: false}}
+ );
+
+ render();
+
+ const input = screen.getByTestId('directory-search-input');
+ expect(input).toHaveAttribute('required');
+ });
+ });
+
+ describe('Props Handling', () => {
+ it('uses custom label when provided', () => {
+ const form = createTestForm(
+ {destination: ''},
+ {overrides: {ready: true, loading: false}}
+ );
+
+ render();
+
+ expect(screen.getByText(/custom folder/i)).toBeInTheDocument();
+ });
+
+ it('uses default label when not provided', () => {
+ const form = createTestForm(
+ {destination: ''},
+ {overrides: {ready: true, loading: false}}
+ );
+
+ render();
+
+ expect(screen.getByText(/destination/i)).toBeInTheDocument();
+ });
+
+ it('uses custom name/path when provided', () => {
+ const form = createTestForm(
+ {output_dir: 'videos/test'},
+ {overrides: {ready: true, loading: false}}
+ );
+
+ render(
+
+ );
+
+ const input = screen.getByTestId('directory-search-input');
+ expect(input).toHaveValue('videos/test');
+ });
+
+ it('shows info popup when infoContent provided', () => {
+ const form = createTestForm(
+ {destination: ''},
+ {overrides: {ready: true, loading: false}}
+ );
+
+ render(
+
+ );
+
+ expect(screen.getByText(/this is helpful information/i)).toBeInTheDocument();
+ });
+ });
+
+ describe('useForm Integration', () => {
+ it('gets correct props from form.getCustomProps', () => {
+ const form = createTestForm(
+ {destination: 'videos/initial'},
+ {overrides: {ready: true, loading: false}}
+ );
+
+ const getCustomPropsSpy = jest.spyOn(form, 'getCustomProps');
+
+ render();
+
+ expect(getCustomPropsSpy).toHaveBeenCalledWith({
+ name: 'destination',
+ path: 'destination',
+ required: true
+ });
+ });
+
+ it('updates form data on selection', async () => {
+ const form = createTestForm(
+ {destination: ''},
+ {overrides: {ready: true, loading: false}}
+ );
+
+ render();
+
+ const input = screen.getByTestId('directory-search-input');
+
+ // Select a directory
+ await userEvent.type(input, 'videos/new-folder');
+
+ // Form should be updated
+ await waitFor(() => {
+ expect(form.formData.destination).toBe('videos/new-folder');
+ });
+ });
+
+ });
+
+ describe('Edge Cases', () => {
+
+ it('works with nested form paths', () => {
+ const form = createTestForm(
+ {config: {output: {destination: 'videos/nested'}}},
+ {overrides: {ready: true, loading: false}}
+ );
+
+ render(
+
+ );
+
+ const input = screen.getByTestId('directory-search-input');
+ expect(input).toHaveValue('videos/nested');
+ });
+
+ it('handles concurrent field updates', async () => {
+ const form = createTestForm(
+ {destination: '', title: ''},
+ {overrides: {ready: true, loading: false}}
+ );
+
+ render();
+
+ const input = screen.getByTestId('directory-search-input');
+
+ // Simulate rapid updates
+ await userEvent.type(input, 'videos/a');
+ form.setValue('title', 'Test Title');
+ await userEvent.type(input, 'bc');
+
+ // Destination should have full value
+ await waitFor(() => {
+ expect(form.formData.destination).toBe('videos/abc');
+ });
+
+ // Title should also be set
+ expect(form.formData.title).toBe('Test Title');
+ });
+
+ it('handles empty string as initial value', () => {
+ const form = createTestForm(
+ {destination: ''},
+ {overrides: {ready: true, loading: false}}
+ );
+
+ render();
+
+ const input = screen.getByTestId('directory-search-input');
+ expect(input).toHaveValue('');
+ });
+
+ it('handles null as initial value', () => {
+ const form = createTestForm(
+ {destination: null},
+ {overrides: {ready: true, loading: false}}
+ );
+
+ render();
+
+ const input = screen.getByTestId('directory-search-input');
+ expect(input).toHaveValue('');
+ });
+
+ it('handles undefined as initial value', () => {
+ const form = createTestForm(
+ {},
+ {overrides: {ready: true, loading: false}}
+ );
+
+ render();
+
+ const input = screen.getByTestId('directory-search-input');
+ expect(input).toHaveValue('');
+ });
+ });
+});
diff --git a/app/src/components/Nav.js b/app/src/components/Nav.js
index b138ea432..5d4fa5462 100644
--- a/app/src/components/Nav.js
+++ b/app/src/components/Nav.js
@@ -186,9 +186,24 @@ export function NavBar() {
/>
}
+ // Upgrade available notification - only show on native (non-Docker) installs
+ let upgradeIcon;
+ if (status?.update_available && !status?.dockerized) {
+ const commitsBehind = status.commits_behind || 0;
+ const branch = status.git_branch || 'unknown';
+ const icon =
+
+ ;
+ upgradeIcon = ;
+ }
+
const icons =
{apiDownIcon}
{processingIcon}
+ {upgradeIcon}
{powerIcon}
{warningIcon}
diff --git a/app/src/components/Vars.js b/app/src/components/Vars.js
index 8a44c5dfd..d07b7b979 100644
--- a/app/src/components/Vars.js
+++ b/app/src/components/Vars.js
@@ -1,6 +1,7 @@
export const API_URI = process.env && process.env.REACT_APP_API_URI ? process.env.REACT_APP_API_URI : `https://${window.location.host}/api`;
export const VIDEOS_API = `${API_URI}/videos`;
export const ARCHIVES_API = `${API_URI}/archive`;
+export const COLLECTIONS_API = `${API_URI}/collections`;
export const OTP_API = `${API_URI}/otp`;
export const ZIM_API = `${API_URI}/zim`;
export const DEFAULT_LIMIT = 20;
diff --git a/app/src/components/admin/Settings.js b/app/src/components/admin/Settings.js
index e89f28902..a47633b13 100644
--- a/app/src/components/admin/Settings.js
+++ b/app/src/components/admin/Settings.js
@@ -1,5 +1,6 @@
import React from "react";
-import {postRestart, postShutdown} from "../../api";
+import {ThemeContext} from "../../contexts/contexts";
+import {checkUpgrade, postRestart, postShutdown, triggerUpgrade} from "../../api";
import {
Button,
Divider,
@@ -30,7 +31,7 @@ import QRCode from "react-qr-code";
import {useConfigs, useDockerized} from "../../hooks/customHooks";
import {toast} from "react-semantic-toasts-2";
import Grid from "semantic-ui-react/dist/commonjs/collections/Grid";
-import {SettingsContext} from "../../contexts/contexts";
+import {SettingsContext, StatusContext} from "../../contexts/contexts";
import {ConfigsTable} from "./Configs";
import {semanticUIColorMap} from "../Vars";
@@ -85,6 +86,89 @@ export function ShutdownButton() {
}
+function UpgradeSegment() {
+ const {status, fetchStatus} = React.useContext(StatusContext);
+ const dockerized = useDockerized();
+ const [upgrading, setUpgrading] = React.useState(false);
+ const [checking, setChecking] = React.useState(false);
+
+ const handleCheckUpgrade = async () => {
+ setChecking(true);
+ try {
+ await checkUpgrade(true); // Force a fresh check
+ toast({
+ type: 'info',
+ title: 'Update Check Complete',
+ description: 'Checked for updates from git remote.',
+ time: 3000,
+ });
+ } finally {
+ setChecking(false);
+ await fetchStatus();
+ }
+ };
+
+ const handleUpgrade = async () => {
+ setUpgrading(true);
+ try {
+ const response = await triggerUpgrade();
+ if (response.ok) {
+ // Redirect to maintenance page
+ window.location.href = '/maintenance.html';
+ }
+ } catch (e) {
+ setUpgrading(false);
+ }
+ };
+
+ // Not available in Docker
+ if (dockerized) {
+ return
+
+ Upgrades are not available in Docker environments. Please upgrade your Docker images manually.
+ ;
+ }
+
+ // No update available
+ if (!status?.update_available) {
+ return
+
+ Your WROLPi is up to date.
+ Version: v{status?.version} on branch {status?.git_branch || 'unknown'}
+
+ {checking ? 'Checking...' : 'Check for Updates'}
+
+ ;
+ }
+
+ // Update available
+ return
+
+ Branch: {status?.git_branch}
+ Current commit: {status?.current_commit}
+ Latest commit: {status?.latest_commit}
+ {status?.commits_behind} commit(s) behind
+
+
+ {upgrading ? 'Starting Upgrade...' : 'Upgrade Now'}
+
+ ;
+}
+
export function RestartButton() {
const dockerized = useDockerized();
@@ -529,6 +613,8 @@ export function SettingsPage() {
{configsSegment}
+
+
Show All Hints
diff --git a/app/src/components/collections/CollectionEditForm.js b/app/src/components/collections/CollectionEditForm.js
new file mode 100644
index 000000000..b6ffc0b6f
--- /dev/null
+++ b/app/src/components/collections/CollectionEditForm.js
@@ -0,0 +1,152 @@
+import React from 'react';
+import {Button, Form, Grid, Message, TextArea} from 'semantic-ui-react';
+import {Header, Segment} from '../Theme';
+import {TagsSelector, TagsContext} from '../../Tags';
+import {WROLModeMessage} from '../Common';
+import {DestinationForm} from '../Download';
+import {InputForm} from '../../hooks/useForm';
+
+/**
+ * Reusable form component for editing collections (Domains, Channels, etc).
+ *
+ * @param {Object} form - Form object from useForm hook
+ * @param {Object} metadata - Backend-provided metadata containing fields configuration
+ * @param {Function} onCancel - Optional callback when cancel is clicked
+ * @param {String} title - Page title to display in header
+ * @param {String} wrolModeContent - Content to show in WROL mode message (optional)
+ * @param {React.ReactNode} actionButtons - Optional additional action buttons to display in the button row
+ * @param {String} appliedTagName - Optional tag name to display (similar to ChannelEditPage pattern)
+ */
+export function CollectionEditForm({
+ form,
+ metadata,
+ onCancel,
+ title,
+ wrolModeContent,
+ actionButtons,
+ appliedTagName
+}) {
+ const {SingleTag} = React.useContext(TagsContext);
+ if (!metadata) {
+ return
+ No metadata available
+ ;
+ }
+
+ const handleSubmit = (e) => {
+ e.preventDefault();
+ form.onSubmit();
+ };
+
+ const renderField = (field) => {
+ const value = form.formData[field.key] || '';
+ const disabled = field.depends_on && !form.formData[field.depends_on];
+
+ switch (field.type) {
+ case 'text':
+ if (field.key === 'directory') {
+ // Use DestinationForm for directory picker
+ return ;
+ }
+ // Use InputForm for regular text fields
+ return ;
+
+ case 'textarea':
+ // Textarea doesn't have a form component, use manual Field
+ const [textareaProps] = form.getCustomProps({name: field.key, path: field.key, required: field.required});
+ return
+
+ ;
+
+ case 'tag':
+ // TagsSelector is custom, use manual Field with form props
+ const [tagProps] = form.getCustomProps({name: field.key, path: field.key, required: field.required});
+ return
+
+ {disabled &&
+ {metadata.messages?.no_directory || 'Set a directory to enable tagging'}
+ }
+ tagProps.onChange(tagNames[0] || null)}
+ single={true}
+ disabled={disabled || tagProps.disabled}
+ />
+ {!disabled && tagProps.value && metadata.messages?.tag_will_move &&
+ {metadata.messages.tag_will_move}
+ }
+ ;
+
+ default:
+ return null;
+ }
+ };
+
+ return
+ {title && }
+ {wrolModeContent && }
+
+
+ ;
+}
diff --git a/app/src/components/collections/CollectionEditForm.test.js b/app/src/components/collections/CollectionEditForm.test.js
new file mode 100644
index 000000000..1a2a90cac
--- /dev/null
+++ b/app/src/components/collections/CollectionEditForm.test.js
@@ -0,0 +1,347 @@
+import React from 'react';
+import {render, renderInDarkMode, renderInLightMode, hasInvertedStyling, screen, waitFor, createTestForm} from '../../test-utils';
+import {CollectionEditForm} from './CollectionEditForm';
+import {createMockMetadata, createMockDomain} from '../../test-utils';
+
+// Mock the TagsSelector component and TagsContext
+jest.mock('../../Tags', () => ({
+ TagsSelector: ({selectedTagNames, onChange, disabled}) => (
+
+ onChange(e.target.value ? [e.target.value] : [])}
+ disabled={disabled}
+ />
+
+ ),
+ TagsContext: {
+ _currentValue: {
+ SingleTag: ({name}) => {name}
+ }
+ },
+}));
+
+// Mock Common components
+jest.mock('../Common', () => ({
+ ...jest.requireActual('../Common'),
+ WROLModeMessage: () => ,
+}));
+
+// Mock DestinationForm (used for directory field)
+jest.mock('../Download', () => ({
+ DestinationForm: ({form, label, name}) => (
+
+
+ form.setValue(name, e.target.value)}
+ />
+
+ ),
+}));
+
+// Mock InputForm (used for text fields)
+jest.mock('../../hooks/useForm', () => ({
+ ...jest.requireActual('../../hooks/useForm'),
+ InputForm: ({form, label, name}) => (
+
+
+ form.setValue(name, e.target.value)}
+ />
+
+ ),
+}));
+
+describe('CollectionEditForm', () => {
+ const mockMetadata = createMockMetadata();
+ const mockCollection = createMockDomain();
+
+ describe('Form Rendering', () => {
+ it('renders all configured fields', () => {
+ const form = createTestForm(mockCollection);
+
+ render(
+
+ );
+
+ // Check that all fields from metadata are rendered
+ expect(screen.getByTestId('directory-search')).toBeInTheDocument();
+ expect(screen.getByTestId('tags-selector')).toBeInTheDocument();
+ expect(screen.getByPlaceholderText(/optional description/i)).toBeInTheDocument();
+
+ // Verify Save button exists
+ expect(screen.getByRole('button', {name: /save/i})).toBeInTheDocument();
+ });
+
+ it('loads initial values into form', () => {
+ const collectionWithData = createMockDomain({
+ description: 'Test description',
+ directory: 'archive/example.com'
+ });
+
+ const form = createTestForm(collectionWithData);
+
+ render(
+
+ );
+
+ // Check description is loaded
+ expect(screen.getByPlaceholderText(/optional description/i)).toHaveValue('Test description');
+ // Check directory field is rendered with the value
+ const directoryField = screen.getByTestId('directory-search');
+ expect(directoryField).toBeInTheDocument();
+ // The mocked DestinationForm renders an input inside the container
+ const directoryInput = directoryField.querySelector('input');
+ expect(directoryInput).toHaveValue('archive/example.com');
+ });
+
+ it('renders without errors when form is provided', () => {
+ const form = createTestForm(mockCollection);
+
+ render(
+
+ );
+
+ expect(screen.getByRole('button', {name: /save/i})).toBeInTheDocument();
+ });
+ });
+
+ describe('Field Types', () => {
+ it('renders textarea for description field', () => {
+ const form = createTestForm(mockCollection);
+
+ render(
+
+ );
+
+ const descriptionField = screen.getByPlaceholderText(/optional description/i);
+ expect(descriptionField.tagName).toBe('TEXTAREA');
+ });
+
+ it('renders directory field using DestinationForm', () => {
+ const form = createTestForm(mockCollection);
+
+ render(
+
+ );
+
+ expect(screen.getByTestId('directory-search')).toBeInTheDocument();
+ });
+
+ it('renders tag selector for tag_name field', () => {
+ const form = createTestForm(mockCollection);
+
+ render(
+
+ );
+
+ expect(screen.getByTestId('tags-selector')).toBeInTheDocument();
+ });
+ });
+
+ describe('Field Dependencies (tag requires directory)', () => {
+ it('disables tag field when directory is empty', () => {
+ const collectionWithoutDirectory = createMockDomain({
+ directory: '',
+ can_be_tagged: false
+ });
+
+ const form = createTestForm(collectionWithoutDirectory);
+
+ render(
+
+ );
+
+ // Should show dependency warning
+ expect(screen.getByText(/set a directory to enable tagging/i)).toBeInTheDocument();
+
+ // Tag selector should be disabled
+ const tagSelector = screen.getByTestId('tags-selector');
+ expect(tagSelector).toHaveAttribute('data-disabled', 'true');
+ });
+
+ it('enables tag field when directory is set', () => {
+ const collectionWithDirectory = createMockDomain({
+ directory: 'archive/example.com',
+ can_be_tagged: true
+ });
+
+ const form = createTestForm(collectionWithDirectory);
+
+ render(
+
+ );
+
+ // Dependency warning should not be shown
+ expect(screen.queryByText(/set a directory to enable tagging/i)).not.toBeInTheDocument();
+
+ // Tag selector should not be disabled
+ const tagSelector = screen.getByTestId('tags-selector');
+ expect(tagSelector).toHaveAttribute('data-disabled', 'false');
+ });
+
+ it('shows warning when tag is set with directory', () => {
+ const collectionWithDirectoryAndTag = createMockDomain({
+ directory: 'archive/example.com',
+ tag_name: 'News',
+ can_be_tagged: true
+ });
+
+ const form = createTestForm(collectionWithDirectoryAndTag);
+
+ render(
+
+ );
+
+ // Should show tag warning
+ expect(screen.getByText(/tagging will move files/i)).toBeInTheDocument();
+ });
+ });
+
+ describe('Save/Cancel Actions', () => {
+ it('shows Cancel button when onCancel provided', () => {
+ const form = createTestForm(mockCollection);
+ const mockOnCancel = jest.fn();
+
+ render(
+
+ );
+
+ expect(screen.getByRole('button', {name: /cancel/i})).toBeInTheDocument();
+ });
+
+ it('disables form during loading', () => {
+ const form = createTestForm(mockCollection, {
+ overrides: {loading: true, disabled: true}
+ });
+
+ render(
+
+ );
+
+ // Save button should be disabled during loading
+ const saveButton = screen.getByRole('button', {name: /save/i});
+ expect(saveButton).toBeDisabled();
+ });
+ });
+
+ describe('Error States', () => {
+ it('handles missing metadata gracefully', () => {
+ const form = createTestForm(mockCollection);
+
+ render(
+
+ );
+
+ // Should show warning about missing metadata
+ expect(screen.getByText(/no metadata available/i)).toBeInTheDocument();
+ });
+ });
+
+ describe('Theme Integration', () => {
+ it('applies inverted styling to Segment in dark mode', () => {
+ const form = createTestForm(mockCollection);
+
+ const {container} = renderInDarkMode(
+
+ );
+
+ // Segment should have inverted class in dark mode
+ const segment = container.querySelector('.ui.segment');
+ expect(segment).toBeInTheDocument();
+ expect(hasInvertedStyling(segment)).toBe(true);
+ });
+
+ it('does not apply inverted styling in light mode', () => {
+ const form = createTestForm(mockCollection);
+
+ const {container} = renderInLightMode(
+
+ );
+
+ // Segment should NOT have inverted class in light mode
+ const segment = container.querySelector('.ui.segment');
+ expect(segment).toBeInTheDocument();
+ expect(hasInvertedStyling(segment)).toBe(false);
+ });
+
+ it('applies dark theme styling to Header in dark mode', () => {
+ const form = createTestForm(mockCollection);
+
+ const {container} = renderInDarkMode(
+
+ );
+
+ // Header should have dark text color inline style in dark mode
+ const header = container.querySelector('.ui.header');
+ expect(header).toBeInTheDocument();
+ expect(header.style.color).toBe('rgb(238, 238, 238)'); // #eeeeee
+ });
+
+ it('uses theme context from provider', () => {
+ const form = createTestForm(mockCollection);
+
+ const {container} = render(
+
+ );
+
+ const segment = container.querySelector('.ui.segment');
+ expect(segment).toBeInTheDocument();
+ // Should not be inverted by default
+ expect(hasInvertedStyling(segment)).toBe(false);
+ });
+ });
+});
diff --git a/app/src/components/collections/CollectionTable.js b/app/src/components/collections/CollectionTable.js
new file mode 100644
index 000000000..129316913
--- /dev/null
+++ b/app/src/components/collections/CollectionTable.js
@@ -0,0 +1,279 @@
+import React, {useContext} from 'react';
+import {Link} from 'react-router-dom';
+import {Message, Placeholder, PlaceholderHeader, PlaceholderLine, Table, TableCell, TableRow} from 'semantic-ui-react';
+import _ from 'lodash';
+import {SortableTable} from '../SortableTable';
+import {humanFileSize} from '../Common';
+import {ThemeContext, Media} from '../../contexts/contexts';
+import {TagsContext} from '../../Tags';
+import {allFrequencyOptions} from '../Vars';
+
+/**
+ * Format a frequency value (in seconds) to a human-readable string using allFrequencyOptions.
+ * @param {number|null} frequency - Frequency in seconds
+ * @returns {string} Human-readable frequency text or '-' if null/undefined
+ */
+function formatFrequency(frequency) {
+ if (frequency === null || frequency === undefined) {
+ return '-';
+ }
+ const option = allFrequencyOptions[frequency];
+ return option ? option.text : `${frequency}s`;
+}
+
+/**
+ * Generate a search link for a collection based on metadata routing configuration.
+ * Supports both query parameter-based (e.g., ?domain=...) and route-based (e.g., /channel/:id/video) linking.
+ *
+ * @param {Object} collection - The collection object
+ * @param {Object} metadata - Collection metadata with routes configuration
+ * @param {string} primaryKey - The primary column key to use for the value
+ * @returns {string|null} The generated link, or null if no link can be generated
+ */
+function getCollectionSearchLink(collection, metadata, primaryKey) {
+ const searchRoute = metadata.routes?.search;
+ if (!searchRoute) {
+ return null;
+ }
+
+ // Check if metadata specifies a query parameter to use
+ if (metadata.routes.searchParam) {
+ // Query parameter-based linking (e.g., /archive?domain=example.com)
+ return `${searchRoute}?${metadata.routes.searchParam}=${collection[primaryKey]}`;
+ } else if (searchRoute.includes(':id')) {
+ // Route parameter-based linking (e.g., /videos/channel/123/video)
+ // Use id_field if specified (e.g., channel_id for channels), otherwise use id
+ const idField = metadata.routes.id_field || 'id';
+ return searchRoute.replace(':id', collection[idField]);
+ }
+
+ // No linking strategy available
+ return null;
+}
+
+/**
+ * Mobile row component for collections
+ */
+function MobileCollectionRow({collection, metadata}) {
+ const {SingleTag} = useContext(TagsContext);
+ const primaryColumn = metadata.columns[0];
+ // Use id_field if specified (e.g., channel_id for channels), otherwise use id
+ const idField = metadata.routes?.id_field || 'id';
+ const editRoute = metadata.routes?.edit?.replace(':id', collection[idField]);
+ const searchLink = getCollectionSearchLink(collection, metadata, primaryColumn.key);
+
+ return
+
+ {searchLink ? (
+
+
+ {collection[primaryColumn.key]}
+
+ {collection.tag_name && }
+
+ ) : (
+ <>
+ {collection[primaryColumn.key]}
+ {collection.tag_name && }
+ >
+ )}
+ {metadata.columns
+ .filter(col => col.type !== 'actions' && col.key !== primaryColumn.key && col.key !== 'tag_name')
+ .map(col => {
+ let value = collection[col.key];
+ if (col.format === 'bytes') {
+ value = humanFileSize(value);
+ } else if (col.format === 'frequency') {
+ value = formatFrequency(value);
+ }
+ return (
+
+ {col.label}: {value || '-'}
+
+ );
+ })
+ }
+
+
+
+ {editRoute && Edit}
+
+
+ ;
+}
+
+/**
+ * Reusable table component for displaying collections (Domains, Channels, etc).
+ *
+ * @param {Array} collections - Array of collection objects
+ * @param {Object} metadata - Backend-provided metadata containing columns, routes, etc.
+ * @param {String} searchStr - Search filter string (managed by parent)
+ * @param {Function} onRowClick - Optional callback when a row is clicked
+ * @param {String} emptyMessage - Message to display when there are no collections
+ */
+export function CollectionTable({collections, metadata, searchStr = '', onRowClick, emptyMessage = 'No items yet'}) {
+ const {inverted} = useContext(ThemeContext);
+ const {SingleTag} = useContext(TagsContext);
+
+ // Loading state
+ if (collections === null) {
+ return
+
+
+
+
+ ;
+ }
+
+ // Error state
+ if (collections === undefined) {
+ return
+ Could not fetch collections
+ ;
+ }
+
+ // Empty state
+ if (collections && collections.length === 0) {
+ return
+ {emptyMessage}
+ ;
+ }
+
+ // No metadata available (backward compatibility)
+ if (!metadata) {
+ return
+ No metadata available
+ ;
+ }
+
+ // Filter collections by search string
+ let filteredCollections = collections;
+ if (searchStr) {
+ const re = new RegExp(_.escapeRegExp(searchStr), 'i');
+ filteredCollections = collections.filter(collection => {
+ // Search across all string fields
+ return Object.values(collection).some(value => {
+ if (typeof value === 'string') {
+ return re.test(value);
+ }
+ return false;
+ });
+ });
+ }
+
+ // Build row renderer based on metadata
+ const renderRow = (collection) => {
+ const cells = metadata.columns.map(col => {
+ let value = collection[col.key];
+
+ // Handle actions column - render action buttons
+ if (col.type === 'actions') {
+ // Use id_field if specified (e.g., channel_id for channels), otherwise use id
+ const idField = metadata.routes?.id_field || 'id';
+ const editRoute = metadata.routes?.edit?.replace(':id', collection[idField]);
+ const buttonClass = `ui button secondary ${inverted}`;
+ return
+ {editRoute && Edit}
+ ;
+ }
+
+ // Format the value based on column configuration
+ if (col.format === 'bytes') {
+ value = humanFileSize(value);
+ } else if (col.format === 'frequency') {
+ value = formatFrequency(value);
+ }
+
+ // Special handling for tag_name column - render SingleTag component
+ if (col.key === 'tag_name' && value) {
+ value = ;
+ }
+
+ // Special handling for the primary column (usually domain/name)
+ if (col.key === metadata.columns[0].key) {
+ const searchLink = getCollectionSearchLink(collection, metadata, col.key);
+ if (searchLink) {
+ value = {value};
+ }
+ }
+
+ return
+ {value || '-'}
+ ;
+ });
+
+ return onRowClick && onRowClick(collection)}
+ style={onRowClick ? {cursor: 'pointer'} : {}}
+ >
+ {cells}
+ ;
+ };
+
+ // Build table headers from metadata (desktop)
+ const headers = metadata.columns.map((col, index) => {
+ // Determine width based on column position and type
+ // Widths use Semantic UI's 16-column grid system and should sum to 16
+ let width = null;
+ if (index === 0) {
+ // First column (name/domain) - gets most space to expand
+ // For 6 columns: 16 - (2+2+2+2+1) = 7
+ width = 7;
+ } else if (col.type === 'actions') {
+ // Actions column - minimum width to shrink to fit button
+ width = 1;
+ } else {
+ // Other columns (tags, counts, frequency, sizes)
+ width = 2;
+ }
+
+ return {
+ key: col.key,
+ text: col.label,
+ sortBy: col.sortable ? col.key : null,
+ width: width,
+ };
+ });
+
+ // Build mobile headers - simplified columns
+ const primaryColumn = metadata.columns[0];
+ const mobileHeaders = [
+ {
+ key: primaryColumn.key,
+ text: primaryColumn.label,
+ sortBy: primaryColumn.sortable ? primaryColumn.key : null
+ },
+ {
+ key: 'manage',
+ text: 'Manage'
+ }
+ ];
+
+ // Get default sort column from metadata (first column key)
+ const defaultSortColumn = metadata.columns[0]?.key || 'id';
+
+
+ return <>
+
+ }
+ rowKey='id'
+ tableHeaders={mobileHeaders}
+ defaultSortColumn={defaultSortColumn}
+ />
+
+
+ renderRow(collection)}
+ rowKey='id'
+ tableHeaders={headers}
+ defaultSortColumn={defaultSortColumn}
+ />
+
+ >;
+}
diff --git a/app/src/components/collections/CollectionTagModal.js b/app/src/components/collections/CollectionTagModal.js
new file mode 100644
index 000000000..e1e37108b
--- /dev/null
+++ b/app/src/components/collections/CollectionTagModal.js
@@ -0,0 +1,156 @@
+import React, {useState, useEffect, useCallback} from 'react';
+import {Grid, Input} from 'semantic-ui-react';
+import Message from 'semantic-ui-react/dist/commonjs/collections/Message';
+import {Button, Modal, ModalActions, ModalContent, ModalHeader} from '../Theme';
+import {TagsSelector} from '../../Tags';
+import {APIButton, Toggle} from '../Common';
+
+/**
+ * Reusable modal component for tagging collections (Domains, Channels, etc).
+ * Handles tag selection, directory suggestions, conflict warnings, and state reset on close.
+ *
+ * @param {boolean} open - Whether the modal is open
+ * @param {Function} onClose - Callback when modal is closed
+ * @param {string} currentTagName - The current tag name of the collection (if any)
+ * @param {string} originalDirectory - The original directory of the collection (for reset)
+ * @param {Function} getTagInfo - Async function to fetch tag info: (tagName) => Promise<{suggested_directory, conflict, conflict_message}>
+ * @param {Function} onSave - Async function called when saving: (tagName, directory) => Promise
+ * @param {string} collectionName - Name of the collection for toast messages (e.g., "Domain", "Channel")
+ */
+export function CollectionTagModal({
+ open,
+ onClose,
+ currentTagName,
+ originalDirectory,
+ getTagInfo,
+ onSave,
+ collectionName = 'Collection',
+}) {
+ const [newTagName, setNewTagName] = useState(currentTagName || null);
+ const [moveToTagDirectory, setMoveToTagDirectory] = useState(true);
+ const [newTagDirectory, setNewTagDirectory] = useState(originalDirectory || '');
+ const [conflictMessage, setConflictMessage] = useState(null);
+
+ // Reset state when modal opens or collection changes
+ useEffect(() => {
+ if (open) {
+ setNewTagName(currentTagName || null);
+ setNewTagDirectory(originalDirectory || '');
+ setConflictMessage(null);
+ }
+ }, [open, currentTagName, originalDirectory]);
+
+ // Handle modal close - reset to original values
+ const handleClose = useCallback(() => {
+ setNewTagName(currentTagName || null);
+ setNewTagDirectory(originalDirectory || '');
+ setConflictMessage(null);
+ onClose();
+ }, [currentTagName, originalDirectory, onClose]);
+
+ // Handle tag selection change - fetch tag info
+ const handleTagChange = useCallback(async (tagName) => {
+ setNewTagName(tagName);
+ setConflictMessage(null);
+
+ // Always fetch tag info when tag changes (including removal)
+ try {
+ const tagInfo = await getTagInfo(tagName);
+ if (tagInfo) {
+ // Handle both response formats:
+ // - Object format: {suggested_directory, conflict, conflict_message}
+ // - String format: just the directory path (legacy channel API)
+ if (typeof tagInfo === 'string') {
+ // Legacy format - just a directory string
+ setNewTagDirectory(tagInfo);
+ } else if (tagInfo.suggested_directory) {
+ setNewTagDirectory(tagInfo.suggested_directory);
+ } else if (!tagName) {
+ // If tag removed and no suggestion, reset to original
+ setNewTagDirectory(originalDirectory || '');
+ }
+
+ // Handle conflict message (only in object format)
+ if (typeof tagInfo === 'object' && tagInfo.conflict && tagInfo.conflict_message) {
+ setConflictMessage(tagInfo.conflict_message);
+ }
+ } else if (!tagName) {
+ // If no tag info returned and tag was removed, reset to original
+ setNewTagDirectory(originalDirectory || '');
+ }
+ } catch (e) {
+ console.error(`Failed to get tag info for ${collectionName}`, e);
+ // Don't block the UI if tag info fails
+ }
+ }, [getTagInfo, originalDirectory, collectionName]);
+
+ // Handle save
+ const handleSave = useCallback(async () => {
+ await onSave(newTagName, moveToTagDirectory ? newTagDirectory : null);
+ handleClose();
+ }, [newTagName, moveToTagDirectory, newTagDirectory, onSave, handleClose]);
+
+ const modalTitle = currentTagName ? 'Modify Tag' : 'Add Tag';
+ const saveButtonText = moveToTagDirectory ? 'Move' : 'Save';
+
+ return (
+
+ {modalTitle}
+
+
+
+
+ handleTagChange(null)}
+ />
+
+
+
+
+
+
+
+
+
+ setNewTagDirectory(value)}
+ disabled={!moveToTagDirectory}
+ />
+
+
+ {conflictMessage && (
+
+
+
+ Directory Conflict
+ {conflictMessage}
+
+
+
+ )}
+
+
+
+
+ {saveButtonText}
+
+
+ );
+}
diff --git a/app/src/components/collections/CollectionTagModal.test.js b/app/src/components/collections/CollectionTagModal.test.js
new file mode 100644
index 000000000..a44a961ac
--- /dev/null
+++ b/app/src/components/collections/CollectionTagModal.test.js
@@ -0,0 +1,442 @@
+import React from 'react';
+import {render, screen, waitFor, act} from '../../test-utils';
+import userEvent from '@testing-library/user-event';
+import {CollectionTagModal} from './CollectionTagModal';
+
+// Mock Theme components
+jest.mock('../Theme', () => ({
+ ...jest.requireActual('../Theme'),
+ Modal: ({open, onClose, children, closeIcon}) => {
+ if (!open) return null;
+ return (
+
+ {closeIcon && }
+ {children}
+
+ );
+ },
+ ModalHeader: ({children}) => {children}
,
+ ModalContent: ({children}) => {children}
,
+ ModalActions: ({children}) => {children}
,
+ Button: ({children, onClick, ...props}) => (
+
+ ),
+}));
+
+// Mock Common components
+jest.mock('../Common', () => ({
+ ...jest.requireActual('../Common'),
+ Toggle: ({label, checked, onChange}) => (
+
+ ),
+ APIButton: ({children, onClick, ...props}) => (
+
+ ),
+}));
+
+// Mock TagsSelector
+jest.mock('../../Tags', () => ({
+ TagsSelector: ({selectedTagNames, onAdd, onRemove, limit}) => (
+
+ {selectedTagNames?.join(', ') || 'none'}
+
+
+
+ ),
+}));
+
+describe('CollectionTagModal', () => {
+ const defaultProps = {
+ open: true,
+ onClose: jest.fn(),
+ currentTagName: null,
+ originalDirectory: '/original/directory',
+ getTagInfo: jest.fn(),
+ onSave: jest.fn(),
+ collectionName: 'Test',
+ };
+
+ beforeEach(() => {
+ jest.clearAllMocks();
+ });
+
+ describe('Modal Rendering', () => {
+ it('renders modal when open is true', () => {
+ render();
+ expect(screen.getByTestId('modal')).toBeInTheDocument();
+ });
+
+ it('does not render modal when open is false', () => {
+ render();
+ expect(screen.queryByTestId('modal')).not.toBeInTheDocument();
+ });
+
+ it('shows "Add Tag" header when no current tag', () => {
+ render();
+ expect(screen.getByTestId('modal-header')).toHaveTextContent('Add Tag');
+ });
+
+ it('shows "Modify Tag" header when there is a current tag', () => {
+ render();
+ expect(screen.getByTestId('modal-header')).toHaveTextContent('Modify Tag');
+ });
+ });
+
+ describe('Directory Input', () => {
+ it('displays original directory in input field', () => {
+ render();
+ const input = screen.getByRole('textbox');
+ expect(input).toHaveValue('/original/directory');
+ });
+
+ it('updates directory when user types', async () => {
+ render();
+ const input = screen.getByRole('textbox');
+ await userEvent.clear(input);
+ await userEvent.type(input, '/new/directory');
+ expect(input).toHaveValue('/new/directory');
+ });
+
+ it('disables directory input when move toggle is off', async () => {
+ render();
+ const toggleCheckbox = screen.getByTestId('toggle-checkbox');
+ await userEvent.click(toggleCheckbox); // Turn off move to directory
+ const input = screen.getByRole('textbox');
+ expect(input).toBeDisabled();
+ });
+ });
+
+ describe('Tag Info Fetching', () => {
+ it('fetches tag info when tag is added', async () => {
+ const mockGetTagInfo = jest.fn().mockResolvedValue({
+ suggested_directory: '/suggested/directory',
+ conflict: false,
+ conflict_message: null,
+ });
+
+ render();
+
+ const addTagButton = screen.getByTestId('add-tag');
+ await userEvent.click(addTagButton);
+
+ // Wait for all async state updates to complete by checking the final result
+ await waitFor(() => {
+ expect(mockGetTagInfo).toHaveBeenCalledWith('test-tag');
+ // Also verify the directory was updated (this ensures async call completed)
+ const input = screen.getByRole('textbox');
+ expect(input).toHaveValue('/suggested/directory');
+ });
+ });
+
+ it('updates directory with suggested directory from tag info', async () => {
+ const mockGetTagInfo = jest.fn().mockResolvedValue({
+ suggested_directory: '/suggested/directory',
+ conflict: false,
+ conflict_message: null,
+ });
+
+ render();
+
+ const addTagButton = screen.getByTestId('add-tag');
+ await userEvent.click(addTagButton);
+
+ await waitFor(() => {
+ const input = screen.getByRole('textbox');
+ expect(input).toHaveValue('/suggested/directory');
+ });
+ });
+
+ it('handles legacy string format from tag info', async () => {
+ const mockGetTagInfo = jest.fn().mockResolvedValue('/legacy/directory');
+
+ render();
+
+ const addTagButton = screen.getByTestId('add-tag');
+ await userEvent.click(addTagButton);
+
+ await waitFor(() => {
+ const input = screen.getByRole('textbox');
+ expect(input).toHaveValue('/legacy/directory');
+ });
+ });
+
+ it('fetches tag info when tag is removed', async () => {
+ const mockGetTagInfo = jest.fn().mockResolvedValue(null);
+
+ render();
+
+ const removeTagButton = screen.getByTestId('remove-tag');
+ await userEvent.click(removeTagButton);
+
+ // Wait for all async state updates to complete
+ await waitFor(() => {
+ expect(mockGetTagInfo).toHaveBeenCalledWith(null);
+ // Also verify the directory was reset (ensures async call completed)
+ const input = screen.getByRole('textbox');
+ expect(input).toHaveValue('/original/directory');
+ });
+ });
+
+ it('updates directory when tag is removed and backend returns suggestion', async () => {
+ const mockGetTagInfo = jest.fn().mockResolvedValue({
+ suggested_directory: '/untagged/directory',
+ conflict: false,
+ conflict_message: null,
+ });
+
+ render();
+
+ const removeTagButton = screen.getByTestId('remove-tag');
+ await userEvent.click(removeTagButton);
+
+ await waitFor(() => {
+ const input = screen.getByRole('textbox');
+ expect(input).toHaveValue('/untagged/directory');
+ });
+ });
+
+ it('resets to original directory when tag is removed and no suggestion', async () => {
+ const mockGetTagInfo = jest.fn().mockResolvedValue(null);
+
+ render();
+
+ // First add a tag - wait for async to complete including state updates
+ const addTagButton = screen.getByTestId('add-tag');
+ await act(async () => {
+ await userEvent.click(addTagButton);
+ });
+
+ // Wait for first getTagInfo call to complete AND directory state to settle
+ // When getTagInfo returns null, the component sets directory to originalDirectory
+ await waitFor(() => {
+ expect(mockGetTagInfo).toHaveBeenCalledWith('test-tag');
+ const input = screen.getByRole('textbox');
+ // After null response, directory is reset to original
+ expect(input).toHaveValue('/original/path');
+ });
+
+ // Then remove the tag - wrap in act to catch async state updates
+ const removeTagButton = screen.getByTestId('remove-tag');
+ await act(async () => {
+ await userEvent.click(removeTagButton);
+ });
+
+ // Verify all state updates completed
+ expect(mockGetTagInfo).toHaveBeenCalledWith(null);
+ const input = screen.getByRole('textbox');
+ expect(input).toHaveValue('/original/path');
+ });
+ });
+
+ describe('Conflict Handling', () => {
+ it('displays conflict warning when tag info indicates conflict', async () => {
+ const mockGetTagInfo = jest.fn().mockResolvedValue({
+ suggested_directory: '/conflict/directory',
+ conflict: true,
+ conflict_message: 'Directory already in use',
+ });
+
+ render();
+
+ const addTagButton = screen.getByTestId('add-tag');
+ await userEvent.click(addTagButton);
+
+ // Wait for all async state updates including conflict message and directory
+ await waitFor(() => {
+ expect(screen.getByText('Directory Conflict')).toBeInTheDocument();
+ expect(screen.getByText('Directory already in use')).toBeInTheDocument();
+ // Also check directory was updated (ensures all state updates complete)
+ const input = screen.getByRole('textbox');
+ expect(input).toHaveValue('/conflict/directory');
+ });
+ });
+
+ it('clears conflict warning when tag is changed', async () => {
+ const mockGetTagInfo = jest.fn()
+ .mockResolvedValueOnce({
+ suggested_directory: '/conflict/directory',
+ conflict: true,
+ conflict_message: 'Directory already in use',
+ })
+ .mockResolvedValueOnce({
+ suggested_directory: '/no-conflict/directory',
+ conflict: false,
+ conflict_message: null,
+ });
+
+ render();
+
+ // Add first tag - shows conflict
+ const addTagButton = screen.getByTestId('add-tag');
+ await userEvent.click(addTagButton);
+
+ // Wait for first async call to complete (conflict message AND directory update)
+ await waitFor(() => {
+ expect(screen.getByText('Directory Conflict')).toBeInTheDocument();
+ const input = screen.getByRole('textbox');
+ expect(input).toHaveValue('/conflict/directory');
+ });
+
+ // Add different tag - conflict should clear
+ await userEvent.click(addTagButton);
+
+ // Wait for second async call to complete (no conflict AND new directory)
+ await waitFor(() => {
+ expect(screen.queryByText('Directory Conflict')).not.toBeInTheDocument();
+ const input = screen.getByRole('textbox');
+ expect(input).toHaveValue('/no-conflict/directory');
+ });
+ });
+ });
+
+ describe('Modal Close Behavior', () => {
+ it('calls onClose when cancel button is clicked', async () => {
+ const mockOnClose = jest.fn();
+ render();
+
+ const cancelButton = screen.getByTestId('button-Cancel');
+ await userEvent.click(cancelButton);
+
+ expect(mockOnClose).toHaveBeenCalled();
+ });
+
+ it('resets directory to original value when modal is closed', async () => {
+ const mockOnClose = jest.fn();
+ const mockGetTagInfo = jest.fn().mockResolvedValue({
+ suggested_directory: '/changed/directory',
+ conflict: false,
+ conflict_message: null,
+ });
+
+ const {rerender} = render(
+
+ );
+
+ // Change directory by adding a tag
+ const addTagButton = screen.getByTestId('add-tag');
+ await userEvent.click(addTagButton);
+
+ await waitFor(() => {
+ const input = screen.getByRole('textbox');
+ expect(input).toHaveValue('/changed/directory');
+ });
+
+ // Close the modal
+ const cancelButton = screen.getByTestId('button-Cancel');
+ await userEvent.click(cancelButton);
+
+ // Reopen modal
+ rerender(
+
+ );
+
+ // Directory should be reset to original
+ await waitFor(() => {
+ const input = screen.getByRole('textbox');
+ expect(input).toHaveValue('/original/directory');
+ });
+ });
+ });
+
+ describe('Save Behavior', () => {
+ it('calls onSave with tag name and directory when move toggle is on', async () => {
+ const mockOnClose = jest.fn();
+ // Use mockResolvedValue to make onSave async - this helps React batch state updates properly
+ const mockOnSave = jest.fn().mockResolvedValue(undefined);
+ const mockGetTagInfo = jest.fn().mockResolvedValue({
+ suggested_directory: '/suggested/directory',
+ conflict: false,
+ });
+
+ render(
+
+ );
+
+ // Add a tag
+ const addTagButton = screen.getByTestId('add-tag');
+ await userEvent.click(addTagButton);
+
+ // Wait for the directory to be updated from getTagInfo
+ await waitFor(() => {
+ const input = screen.getByRole('textbox');
+ expect(input).toHaveValue('/suggested/directory');
+ });
+
+ // Click save/move button and wait for all async operations to complete
+ await act(async () => {
+ await userEvent.click(screen.getByTestId('api-button-Move'));
+ });
+
+ // Verify the save was called
+ expect(mockOnSave).toHaveBeenCalledWith('test-tag', '/suggested/directory');
+ expect(mockOnClose).toHaveBeenCalled();
+ });
+
+ it('calls onSave with null directory when move toggle is off', async () => {
+ const mockOnClose = jest.fn();
+ // Use mockResolvedValue to make onSave async - this helps React batch state updates properly
+ const mockOnSave = jest.fn().mockResolvedValue(undefined);
+ const mockGetTagInfo = jest.fn().mockResolvedValue({
+ suggested_directory: '/suggested/directory',
+ conflict: false,
+ });
+
+ render(
+
+ );
+
+ // Add a tag
+ const addTagButton = screen.getByTestId('add-tag');
+ await userEvent.click(addTagButton);
+
+ // Wait for all async state updates to complete
+ await waitFor(() => {
+ expect(mockGetTagInfo).toHaveBeenCalled();
+ const input = screen.getByRole('textbox');
+ expect(input).toHaveValue('/suggested/directory');
+ });
+
+ // Turn off move toggle
+ const toggleCheckbox = screen.getByTestId('toggle-checkbox');
+ await userEvent.click(toggleCheckbox);
+
+ // Click save button and wait for all async operations to complete
+ await act(async () => {
+ await userEvent.click(screen.getByTestId('api-button-Save'));
+ });
+
+ // Verify the save was called
+ expect(mockOnSave).toHaveBeenCalledWith('test-tag', null);
+ expect(mockOnClose).toHaveBeenCalled();
+ });
+ });
+});
diff --git a/app/src/hooks/customHooks.js b/app/src/hooks/customHooks.js
index 98ae2cef4..26b815a9c 100644
--- a/app/src/hooks/customHooks.js
+++ b/app/src/hooks/customHooks.js
@@ -2,6 +2,7 @@ import React, {useContext, useEffect, useRef, useState} from "react";
import {
ApiDownError,
createChannel,
+ fetchChannels,
fetchDecoded,
fetchDomains,
fetchFilesProgress,
@@ -10,6 +11,7 @@ import {
getChannel,
getChannels,
getConfigs,
+ getDomain,
getDownloads,
getFiles,
getInventory,
@@ -32,6 +34,7 @@ import {
setHotspot,
setThrottle,
updateChannel,
+ updateDomain,
} from "../api";
import {createSearchParams, useLocation, useSearchParams} from "react-router-dom";
import {enumerate, filterToMimetypes, humanFileSize, secondsToFullDuration} from "../components/Common";
@@ -191,17 +194,21 @@ export const useAllQuery = (name) => {
export const useDomains = () => {
const [domains, setDomains] = useState(null);
const [total, setTotal] = useState(null);
+ const [metadata, setMetadata] = useState(null);
const _fetchDomains = async () => {
setDomains(null);
setTotal(0);
+ setMetadata(null);
try {
- let [domains, total] = await fetchDomains();
+ let [domains, total, metadata] = await fetchDomains();
setDomains(domains);
setTotal(total);
+ setMetadata(metadata);
} catch (e) {
setDomains(undefined); // Display error.
setTotal(0);
+ setMetadata(undefined);
}
}
@@ -209,7 +216,86 @@ export const useDomains = () => {
_fetchDomains();
}, []);
- return [domains, total];
+ return [domains, total, metadata];
+}
+
+export const useCollectionMetadata = (kind) => {
+ const [metadata, setMetadata] = useState(null);
+ const [loading, setLoading] = useState(false);
+ const [error, setError] = useState(null);
+
+ const fetchMetadata = async () => {
+ setLoading(true);
+ setError(null);
+ try {
+ if (kind === 'domain') {
+ const [, , metadata] = await fetchDomains();
+ setMetadata(metadata);
+ } else {
+ // Future: Support channels metadata
+ setMetadata(null);
+ }
+ } catch (e) {
+ console.error(e);
+ setError(e);
+ setMetadata(null);
+ } finally {
+ setLoading(false);
+ }
+ };
+
+ useEffect(() => {
+ fetchMetadata();
+ }, [kind]);
+
+ return {metadata, loading, error, fetchMetadata};
+}
+
+export const useDomain = (domainId) => {
+ const emptyDomain = {
+ domain: '',
+ directory: '',
+ description: '',
+ tag_name: null,
+ };
+
+ const fetchDomain = async () => {
+ if (!domainId) {
+ console.debug('Not fetching domain because no domainId is provided');
+ return;
+ }
+ const d = await getDomain(domainId);
+ // Prevent controlled to uncontrolled
+ d.directory = d.directory || '';
+ d.description = d.description || '';
+ d.tag_name = d.tag_name || null;
+ return d;
+ };
+
+ const submitDomain = async () => {
+ const body = {
+ directory: form.formData.directory,
+ description: form.formData.description,
+ // Convert null to empty string - backend expects "" to clear tag
+ tag_name: form.formData.tag_name === null ? '' : form.formData.tag_name,
+ };
+ return await updateDomain(domainId, body);
+ };
+
+ const form = useForm({
+ fetcher: fetchDomain,
+ emptyFormData: emptyDomain,
+ clearOnSuccess: false,
+ submitter: submitDomain
+ });
+
+ React.useEffect(() => {
+ form.fetcher();
+ }, [domainId]);
+
+ const domain = form.formData;
+
+ return {domain, form, fetchDomain: form.fetcher};
}
export const useArchive = (archiveId) => {
@@ -522,23 +608,31 @@ export const useChannel = (channel_id) => {
export const useChannels = () => {
const [channels, setChannels] = useState(null);
+ const [total, setTotal] = useState(null);
+ const [metadata, setMetadata] = useState(null);
- const fetchChannels = async () => {
+ const _fetchChannels = async () => {
setChannels(null);
+ setTotal(0);
+ setMetadata(null);
try {
- const c = await getChannels();
- setChannels(c);
+ let [channels, total, metadata] = await fetchChannels();
+ setChannels(channels);
+ setTotal(total);
+ setMetadata(metadata);
} catch (e) {
console.error(e);
- setChannels(undefined); // Could not get Channels, display error.
+ setChannels(undefined); // Display error.
+ setTotal(0);
+ setMetadata(undefined);
}
}
useEffect(() => {
- fetchChannels();
+ _fetchChannels();
}, []);
- return {channels, fetchChannels}
+ return [channels, total, metadata];
}
export const useSearchRecentFiles = () => {
diff --git a/app/src/hooks/useForm.js b/app/src/hooks/useForm.js
index 1a7463e11..986ecc70b 100644
--- a/app/src/hooks/useForm.js
+++ b/app/src/hooks/useForm.js
@@ -113,10 +113,12 @@ export function useForm({
reset();
}
await onSuccess();
+ } catch (error) {
+ await onFailure(error);
+ throw error;
} finally {
setLoading(false);
setDisabled(false);
- await onFailure();
}
}
diff --git a/app/src/hooks/useForm.test.js b/app/src/hooks/useForm.test.js
new file mode 100644
index 000000000..fb0bfa10b
--- /dev/null
+++ b/app/src/hooks/useForm.test.js
@@ -0,0 +1,998 @@
+import React from 'react';
+import {renderHook, act, waitFor} from '@testing-library/react';
+import {render, screen, fireEvent} from '@testing-library/react';
+import {useForm, commaSeparatedValidator, InputForm, NumberInputForm, UrlInput, ToggleForm, UrlsTextarea} from './useForm';
+import {renderWithProviders} from '../test-utils';
+
+// Mock lodash debounce to control timing in tests
+jest.mock('lodash', () => {
+ const actual = jest.requireActual('lodash');
+ return {
+ ...actual,
+ debounce: (fn) => {
+ const debounced = (...args) => fn(...args);
+ debounced.cancel = jest.fn();
+ return debounced;
+ },
+ };
+});
+
+describe('useForm', () => {
+ let consoleErrorSpy;
+ let consoleDebugSpy;
+
+ beforeEach(() => {
+ consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation(() => {
+ });
+ consoleDebugSpy = jest.spyOn(console, 'debug').mockImplementation(() => {
+ });
+ });
+
+ afterEach(() => {
+ consoleErrorSpy.mockRestore();
+ consoleDebugSpy.mockRestore();
+ });
+
+ describe('initialization', () => {
+ it('initializes with defaultFormData', () => {
+ const defaultData = {name: 'test', value: 123};
+ const {result} = renderHook(() => useForm({
+ defaultFormData: defaultData,
+ submitter: jest.fn(),
+ }));
+
+ expect(result.current.formData).toEqual(defaultData);
+ });
+
+ it('starts with correct initial states', () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {},
+ submitter: jest.fn(),
+ }));
+
+ expect(result.current.dirty).toBe(false);
+ expect(result.current.loading).toBe(false);
+ expect(result.current.disabled).toBe(false);
+ // Ready starts true when there are no validators or required fields
+ expect(result.current.ready).toBe(true);
+ });
+
+ it('calls fetcher on mount if provided', async () => {
+ const fetcher = jest.fn().mockResolvedValue({fetched: 'data'});
+
+ renderHook(() => useForm({
+ fetcher,
+ defaultFormData: {},
+ submitter: jest.fn(),
+ }));
+
+ await waitFor(() => {
+ expect(fetcher).toHaveBeenCalledTimes(1);
+ });
+ });
+
+ it('updates formData when fetcher resolves', async () => {
+ const fetchedData = {name: 'fetched', value: 456};
+ const fetcher = jest.fn().mockResolvedValue(fetchedData);
+
+ const {result} = renderHook(() => useForm({
+ fetcher,
+ defaultFormData: {name: 'default'},
+ submitter: jest.fn(),
+ }));
+
+ await waitFor(() => {
+ expect(result.current.formData).toEqual(fetchedData);
+ });
+ });
+
+ it('initializes with empty object when no defaultFormData provided', () => {
+ const {result} = renderHook(() => useForm({
+ submitter: jest.fn(),
+ }));
+
+ expect(result.current.formData).toEqual({});
+ });
+ });
+
+ describe('submit lifecycle', () => {
+ it('refuses to submit when not ready', async () => {
+ const submitter = jest.fn();
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {},
+ submitter,
+ }));
+
+ // Force ready to false by adding a required field that's empty
+ act(() => {
+ result.current.getInputProps({name: 'required_field', required: true});
+ });
+
+ await act(async () => {
+ await result.current.onSubmit();
+ });
+
+ expect(submitter).not.toHaveBeenCalled();
+ expect(consoleErrorSpy).toHaveBeenCalledWith('Refusing to submit form because it is not ready');
+ });
+
+ it('sets loading and disabled during submission', async () => {
+ let resolveSubmit;
+ const submitter = jest.fn(() => new Promise(resolve => {
+ resolveSubmit = resolve;
+ }));
+
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {name: 'test'},
+ submitter,
+ }));
+
+ // Start submission
+ let submitPromise;
+ act(() => {
+ submitPromise = result.current.onSubmit();
+ });
+
+ // During submission - loading and disabled are set synchronously
+ expect(result.current.loading).toBe(true);
+ expect(result.current.disabled).toBe(true);
+ // Note: ready is set to false in onSubmit, but the useEffect that computes
+ // ready may run after this check. The key assertion is loading/disabled.
+
+ // Complete submission
+ await act(async () => {
+ resolveSubmit();
+ await submitPromise;
+ });
+
+ // After submission - states are reset
+ expect(result.current.loading).toBe(false);
+ expect(result.current.disabled).toBe(false);
+ });
+
+ it('calls submitter with form data', async () => {
+ const formData = {name: 'test', value: 123};
+ const submitter = jest.fn().mockResolvedValue(undefined);
+
+ const {result} = renderHook(() => useForm({
+ defaultFormData: formData,
+ submitter,
+ }));
+
+ await act(async () => {
+ await result.current.onSubmit();
+ });
+
+ expect(submitter).toHaveBeenCalledWith(formData);
+ });
+
+ it('calls onSuccess after successful submit', async () => {
+ const onSuccess = jest.fn();
+ const submitter = jest.fn().mockResolvedValue(undefined);
+
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {name: 'test'},
+ submitter,
+ onSuccess,
+ }));
+
+ await act(async () => {
+ await result.current.onSubmit();
+ });
+
+ expect(onSuccess).toHaveBeenCalled();
+ });
+
+ it('does not call onFailure on successful submit', async () => {
+ const onSuccess = jest.fn();
+ const onFailure = jest.fn();
+ const submitter = jest.fn().mockResolvedValue(undefined);
+
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {name: 'test'},
+ submitter,
+ onSuccess,
+ onFailure,
+ }));
+
+ await act(async () => {
+ await result.current.onSubmit();
+ });
+
+ expect(onSuccess).toHaveBeenCalled();
+ expect(onFailure).not.toHaveBeenCalled();
+ });
+
+ it('clears form when clearOnSuccess is true', async () => {
+ const emptyFormData = {name: '', value: 0};
+ const submitter = jest.fn().mockResolvedValue(undefined);
+
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {name: 'test', value: 123},
+ emptyFormData,
+ submitter,
+ clearOnSuccess: true,
+ }));
+
+ await act(async () => {
+ await result.current.onSubmit();
+ });
+
+ expect(result.current.formData).toEqual(emptyFormData);
+ });
+
+ it('re-fetches when fetchOnSuccess is true', async () => {
+ const fetchedData = {name: 'refetched'};
+ const fetcher = jest.fn().mockResolvedValue(fetchedData);
+ const submitter = jest.fn().mockResolvedValue(undefined);
+
+ const {result} = renderHook(() => useForm({
+ fetcher,
+ defaultFormData: {name: 'initial'},
+ submitter,
+ fetchOnSuccess: true,
+ }));
+
+ // Wait for initial fetch
+ await waitFor(() => {
+ expect(fetcher).toHaveBeenCalledTimes(1);
+ });
+
+ await act(async () => {
+ await result.current.onSubmit();
+ });
+
+ // Should have been called twice (init + after submit)
+ expect(fetcher).toHaveBeenCalledTimes(2);
+ });
+
+ it('handles submitter errors gracefully', async () => {
+ const submitter = jest.fn().mockRejectedValue(new Error('Submit failed'));
+ const onFailure = jest.fn();
+
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {name: 'test'},
+ submitter,
+ onFailure,
+ }));
+
+ await act(async () => {
+ try {
+ await result.current.onSubmit();
+ } catch (e) {
+ // Expected to throw
+ }
+ });
+
+ // Should still reset loading/disabled states
+ expect(result.current.loading).toBe(false);
+ expect(result.current.disabled).toBe(false);
+ expect(onFailure).toHaveBeenCalled();
+ });
+ });
+
+ describe('form data management', () => {
+ it('setValue updates value at path', () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {name: 'original'},
+ submitter: jest.fn(),
+ }));
+
+ act(() => {
+ result.current.setValue('name', 'updated');
+ });
+
+ expect(result.current.formData.name).toBe('updated');
+ });
+
+ it('setValue handles nested paths', () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {config: {nested: {value: 'original'}}},
+ submitter: jest.fn(),
+ }));
+
+ act(() => {
+ result.current.setValue('config.nested.value', 'updated');
+ });
+
+ expect(result.current.formData.config.nested.value).toBe('updated');
+ });
+
+ it('handleInputEvent extracts value from event', () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {username: ''},
+ submitter: jest.fn(),
+ }));
+
+ act(() => {
+ result.current.handleInputEvent({
+ preventDefault: jest.fn(),
+ target: {
+ type: 'text',
+ value: 'newvalue',
+ name: 'username',
+ dataset: {},
+ },
+ });
+ });
+
+ expect(result.current.formData.username).toBe('newvalue');
+ });
+
+ it('handleInputEvent converts number inputs to integers', () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {count: 0},
+ submitter: jest.fn(),
+ }));
+
+ act(() => {
+ result.current.handleInputEvent({
+ preventDefault: jest.fn(),
+ target: {
+ type: 'number',
+ value: '42',
+ name: 'count',
+ dataset: {},
+ },
+ });
+ });
+
+ expect(result.current.formData.count).toBe(42);
+ expect(typeof result.current.formData.count).toBe('number');
+ });
+
+ it('handleInputEvent uses data-path attribute when available', () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {nested: {field: ''}},
+ submitter: jest.fn(),
+ }));
+
+ act(() => {
+ result.current.handleInputEvent({
+ preventDefault: jest.fn(),
+ target: {
+ type: 'text',
+ value: 'pathvalue',
+ name: 'ignored',
+ dataset: {path: 'nested.field'},
+ },
+ });
+ });
+
+ expect(result.current.formData.nested.field).toBe('pathvalue');
+ });
+
+ it('reset restores to emptyFormData when provided', () => {
+ const emptyFormData = {name: '', value: null};
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {name: 'test', value: 123},
+ emptyFormData,
+ submitter: jest.fn(),
+ }));
+
+ act(() => {
+ result.current.setValue('name', 'modified');
+ });
+
+ act(() => {
+ result.current.reset();
+ });
+
+ expect(result.current.formData).toEqual(emptyFormData);
+ });
+
+ it('reset restores to defaultFormData when emptyFormData not provided', () => {
+ const defaultFormData = {name: 'default', value: 100};
+ const {result} = renderHook(() => useForm({
+ defaultFormData,
+ submitter: jest.fn(),
+ }));
+
+ act(() => {
+ result.current.setValue('name', 'modified');
+ });
+
+ act(() => {
+ result.current.reset();
+ });
+
+ expect(result.current.formData).toEqual(defaultFormData);
+ });
+ });
+
+ describe('dirty tracking', () => {
+ it('becomes dirty when formData differs from defaultFormData', async () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {name: 'original'},
+ submitter: jest.fn(),
+ }));
+
+ expect(result.current.dirty).toBe(false);
+
+ act(() => {
+ result.current.setValue('name', 'changed');
+ });
+
+ await waitFor(() => {
+ expect(result.current.dirty).toBe(true);
+ });
+ });
+
+ it('calls onDirty callback when dirty becomes true', async () => {
+ const onDirty = jest.fn();
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {name: 'original'},
+ submitter: jest.fn(),
+ onDirty,
+ }));
+
+ act(() => {
+ result.current.setValue('name', 'changed');
+ });
+
+ await waitFor(() => {
+ expect(onDirty).toHaveBeenCalled();
+ });
+ });
+
+ it('uses deep equality for nested objects', async () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {config: {nested: 'value'}},
+ submitter: jest.fn(),
+ }));
+
+ expect(result.current.dirty).toBe(false);
+
+ act(() => {
+ result.current.setValue('config.nested', 'different');
+ });
+
+ await waitFor(() => {
+ expect(result.current.dirty).toBe(true);
+ });
+ });
+ });
+
+ describe('validation', () => {
+ it('validates field and sets error when invalid', async () => {
+ const validator = (value) => value.length < 3 ? 'Too short' : null;
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {name: ''},
+ submitter: jest.fn(),
+ }));
+
+ act(() => {
+ result.current.getInputProps({name: 'name', validator});
+ });
+
+ act(() => {
+ result.current.handleInputEvent({
+ preventDefault: jest.fn(),
+ target: {
+ type: 'text',
+ value: 'ab',
+ name: 'name',
+ dataset: {},
+ },
+ });
+ });
+
+ await waitFor(() => {
+ expect(result.current.ready).toBe(false);
+ });
+ });
+
+ it('clears error when value becomes valid', async () => {
+ const validator = (value) => value.length < 3 ? 'Too short' : null;
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {name: ''},
+ submitter: jest.fn(),
+ }));
+
+ act(() => {
+ result.current.getInputProps({name: 'name', validator});
+ });
+
+ // Set invalid value
+ act(() => {
+ result.current.handleInputEvent({
+ preventDefault: jest.fn(),
+ target: {type: 'text', value: 'ab', name: 'name', dataset: {}},
+ });
+ });
+
+ // Set valid value
+ act(() => {
+ result.current.handleInputEvent({
+ preventDefault: jest.fn(),
+ target: {type: 'text', value: 'valid', name: 'name', dataset: {}},
+ });
+ });
+
+ await waitFor(() => {
+ expect(result.current.ready).toBe(true);
+ });
+ });
+
+ it('handles validator that throws exception', () => {
+ const validator = () => {
+ throw new Error('Validator crashed');
+ };
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {name: ''},
+ submitter: jest.fn(),
+ }));
+
+ act(() => {
+ result.current.getInputProps({name: 'name', validator});
+ });
+
+ // Should not throw, just log error
+ act(() => {
+ result.current.handleInputEvent({
+ preventDefault: jest.fn(),
+ target: {type: 'text', value: 'test', name: 'name', dataset: {}},
+ });
+ });
+
+ expect(consoleErrorSpy).toHaveBeenCalledWith('Failed to validate name');
+ });
+ });
+
+ describe('required fields', () => {
+ it('form is not ready when required field is empty', async () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {name: ''},
+ submitter: jest.fn(),
+ }));
+
+ act(() => {
+ result.current.getInputProps({name: 'name', required: true});
+ });
+
+ await waitFor(() => {
+ expect(result.current.ready).toBe(false);
+ });
+ });
+
+ it('form becomes ready when required field has value', async () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {name: 'has value'},
+ submitter: jest.fn(),
+ }));
+
+ act(() => {
+ result.current.getInputProps({name: 'name', required: true});
+ });
+
+ await waitFor(() => {
+ expect(result.current.ready).toBe(true);
+ });
+ });
+
+ it('addRequires marks field as required', async () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {field: ''},
+ submitter: jest.fn(),
+ }));
+
+ expect(result.current.ready).toBe(true);
+
+ act(() => {
+ result.current.getCustomProps({name: 'field', required: true});
+ });
+
+ await waitFor(() => {
+ expect(result.current.ready).toBe(false);
+ });
+ });
+ });
+
+ describe('props generators', () => {
+ describe('getInputProps', () => {
+ it('returns props suitable for input elements', () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {email: 'test@example.com'},
+ submitter: jest.fn(),
+ }));
+
+ const [inputProps] = result.current.getInputProps({name: 'email'});
+
+ expect(inputProps).toHaveProperty('type', 'text');
+ expect(inputProps).toHaveProperty('disabled', false);
+ expect(inputProps).toHaveProperty('value', 'test@example.com');
+ expect(inputProps).toHaveProperty('name', 'email');
+ expect(inputProps).toHaveProperty('onChange');
+ });
+
+ it('includes error when validation fails', async () => {
+ const validator = () => 'Error message';
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {field: 'value'},
+ submitter: jest.fn(),
+ }));
+
+ act(() => {
+ result.current.getInputProps({name: 'field', validator});
+ });
+
+ // Trigger validation
+ act(() => {
+ result.current.handleInputEvent({
+ preventDefault: jest.fn(),
+ target: {type: 'text', value: 'x', name: 'field', dataset: {}},
+ });
+ });
+
+ const [inputProps] = result.current.getInputProps({name: 'field', validator});
+
+ await waitFor(() => {
+ expect(inputProps.error).toBe('Error message');
+ });
+ });
+
+ it('applies URL validator for url type', () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {url: ''},
+ submitter: jest.fn(),
+ }));
+
+ const [inputProps, inputAttrs] = result.current.getInputProps({name: 'url', type: 'url'});
+
+ expect(inputProps.type).toBe('url');
+ });
+ });
+
+ describe('getCustomProps', () => {
+ it('returns props for custom components', () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {tags: ['tag1', 'tag2']},
+ submitter: jest.fn(),
+ }));
+
+ const [customProps, attrs] = result.current.getCustomProps({
+ name: 'tags',
+ type: 'array',
+ });
+
+ expect(customProps).toHaveProperty('disabled', false);
+ expect(customProps).toHaveProperty('value', ['tag1', 'tag2']);
+ expect(customProps).toHaveProperty('onChange');
+ expect(customProps).toHaveProperty('data-path', 'tags');
+ });
+
+ it('initializes undefined array fields to empty array', () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {},
+ submitter: jest.fn(),
+ }));
+
+ act(() => {
+ result.current.getCustomProps({name: 'newArray', type: 'array'});
+ });
+
+ expect(result.current.formData.newArray).toEqual([]);
+ });
+ });
+
+ describe('getSelectionProps', () => {
+ it('returns props for dropdown elements', () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {selected: 'option1'},
+ submitter: jest.fn(),
+ }));
+
+ const [selectionProps] = result.current.getSelectionProps({name: 'selected'});
+
+ expect(selectionProps).toHaveProperty('disabled', false);
+ expect(selectionProps).toHaveProperty('value', 'option1');
+ expect(selectionProps).toHaveProperty('onChange');
+ });
+
+ it('onChange handler extracts value from event correctly', () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {selected: 'option1'},
+ submitter: jest.fn(),
+ }));
+
+ const [selectionProps] = result.current.getSelectionProps({name: 'selected'});
+
+ act(() => {
+ // Semantic UI dropdown passes (event, {value})
+ selectionProps.onChange({}, {value: 'option2'});
+ });
+
+ expect(result.current.formData.selected).toBe('option2');
+ });
+ });
+ });
+
+ describe('edge cases and error handling', () => {
+ it('logs error when both clearOnSuccess and fetchOnSuccess are true', () => {
+ renderHook(() => useForm({
+ fetcher: jest.fn(),
+ defaultFormData: {},
+ submitter: jest.fn(),
+ clearOnSuccess: true,
+ fetchOnSuccess: true,
+ }));
+
+ expect(consoleErrorSpy).toHaveBeenCalledWith('Cannot use both clearOnSuccess and fetchOnSuccess!');
+ });
+
+ it('logs error when fetchOnSuccess is true without fetcher', () => {
+ renderHook(() => useForm({
+ defaultFormData: {},
+ submitter: jest.fn(),
+ fetchOnSuccess: true,
+ }));
+
+ expect(consoleErrorSpy).toHaveBeenCalledWith('Cannot fetchOnSuccess without fetcher!');
+ });
+
+ it('handles empty defaultFormData', () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {},
+ submitter: jest.fn(),
+ }));
+
+ expect(result.current.formData).toEqual({});
+ expect(result.current.dirty).toBe(false);
+ });
+
+ it('creates path when setting value on undefined nested path', () => {
+ const {result} = renderHook(() => useForm({
+ defaultFormData: {},
+ submitter: jest.fn(),
+ }));
+
+ act(() => {
+ result.current.setValue('deeply.nested.value', 'test');
+ });
+
+ expect(result.current.formData.deeply.nested.value).toBe('test');
+ });
+ });
+});
+
+describe('commaSeparatedValidator', () => {
+ it('returns null for valid comma-separated string', () => {
+ expect(commaSeparatedValidator('a,b,c')).toBeUndefined();
+ });
+
+ it('returns error when string ends with comma', () => {
+ expect(commaSeparatedValidator('a,b,')).toBe('Cannot end with comma');
+ });
+
+ it('returns error when string starts with comma', () => {
+ expect(commaSeparatedValidator(',a,b')).toBe('Cannot start with comma');
+ });
+
+ it('returns error for non-string input', () => {
+ expect(commaSeparatedValidator(123)).toBe('Expected a string');
+ expect(commaSeparatedValidator(null)).toBe('Expected a string');
+ });
+});
+
+describe('Form Components', () => {
+ const createMockForm = (formData = {}) => {
+ const mockForm = {
+ disabled: false,
+ formData,
+ getInputProps: jest.fn(({name, path, validator, type, required, onChange}) => {
+ const p = path || name;
+ return [{
+ type: type || 'text',
+ disabled: false,
+ value: formData[p] || '',
+ name,
+ onChange: jest.fn(),
+ error: null,
+ required: required ? null : undefined,
+ 'data-path': p,
+ }, {valid: true, path: p, localSetValue: jest.fn()}];
+ }),
+ getCustomProps: jest.fn(({name, path}) => {
+ const p = path || name;
+ return [{
+ disabled: false,
+ value: formData[p],
+ onChange: jest.fn(),
+ 'data-path': p,
+ }, {valid: true, path: p, localSetValue: jest.fn()}];
+ }),
+ onSubmit: jest.fn(),
+ };
+ return mockForm;
+ };
+
+ describe('InputForm', () => {
+ it('renders with label', () => {
+ const form = createMockForm({username: ''});
+
+ renderWithProviders(
+
+ );
+
+ expect(screen.getByText('Username')).toBeInTheDocument();
+ });
+
+ it('shows required asterisk when required', () => {
+ const form = createMockForm({email: ''});
+
+ renderWithProviders(
+
+ );
+
+ expect(form.getInputProps).toHaveBeenCalledWith(
+ expect.objectContaining({required: true})
+ );
+ });
+
+ it('passes placeholder to input', () => {
+ const form = createMockForm({name: ''});
+
+ renderWithProviders(
+
+ );
+
+ expect(screen.getByPlaceholderText('Enter name')).toBeInTheDocument();
+ });
+
+ it('respects disabled prop', () => {
+ const form = createMockForm({field: ''});
+
+ renderWithProviders(
+
+ );
+
+ const input = screen.getByRole('textbox');
+ expect(input).toBeDisabled();
+ });
+ });
+
+ describe('NumberInputForm', () => {
+ it('renders as number input', () => {
+ const form = createMockForm({count: 0});
+
+ renderWithProviders(
+
+ );
+
+ expect(form.getInputProps).toHaveBeenCalledWith(
+ expect.objectContaining({type: 'number'})
+ );
+ });
+
+ it('applies min and max constraints', () => {
+ const form = createMockForm({value: 5});
+
+ renderWithProviders(
+
+ );
+
+ const input = screen.getByRole('spinbutton');
+ expect(input).toHaveAttribute('min', '0');
+ expect(input).toHaveAttribute('max', '100');
+ });
+ });
+
+ describe('UrlInput', () => {
+ it('renders URL input with label', () => {
+ const form = createMockForm({url: ''});
+
+ renderWithProviders(
+
+ );
+
+ expect(screen.getByText('URL')).toBeInTheDocument();
+ });
+
+ it('applies URL type', () => {
+ const form = createMockForm({url: ''});
+
+ renderWithProviders(
+
+ );
+
+ expect(form.getInputProps).toHaveBeenCalledWith(
+ expect.objectContaining({type: 'url'})
+ );
+ });
+
+ it('is required by default', () => {
+ const form = createMockForm({url: ''});
+
+ renderWithProviders(
+
+ );
+
+ expect(form.getInputProps).toHaveBeenCalledWith(
+ expect.objectContaining({required: true})
+ );
+ });
+ });
+
+ describe('ToggleForm', () => {
+ it('renders toggle with label', () => {
+ const form = createMockForm({enabled: false});
+
+ renderWithProviders(
+
+ );
+
+ expect(screen.getByText('Enable Feature')).toBeInTheDocument();
+ });
+
+ it('renders with checkbox input', () => {
+ const form = createMockForm({active: true});
+
+ renderWithProviders(
+
+ );
+
+ expect(screen.getByRole('checkbox')).toBeInTheDocument();
+ });
+ });
+
+ describe('UrlsTextarea', () => {
+ it('renders textarea for multiple URLs', () => {
+ const form = createMockForm({urls: ''});
+
+ renderWithProviders(
+
+ );
+
+ expect(screen.getByPlaceholderText('Enter one URL per line')).toBeInTheDocument();
+ });
+
+ it('calls getInputProps with correct name', () => {
+ const form = createMockForm({urls: ''});
+
+ renderWithProviders(
+
+ );
+
+ expect(form.getInputProps).toHaveBeenCalledWith(
+ expect.objectContaining({name: 'urls'})
+ );
+ });
+ });
+});
diff --git a/app/src/setupTests.js b/app/src/setupTests.js
new file mode 100644
index 000000000..1e70bd9ac
--- /dev/null
+++ b/app/src/setupTests.js
@@ -0,0 +1,48 @@
+// Jest setup file for React Testing Library
+// This file is automatically loaded by Create React App before running tests
+
+import '@testing-library/jest-dom';
+
+// Mock window.matchMedia (used by Semantic UI Media components)
+Object.defineProperty(window, 'matchMedia', {
+ writable: true,
+ value: jest.fn().mockImplementation(query => ({
+ matches: false,
+ media: query,
+ onchange: null,
+ addListener: jest.fn(), // deprecated
+ removeListener: jest.fn(), // deprecated
+ addEventListener: jest.fn(),
+ removeEventListener: jest.fn(),
+ dispatchEvent: jest.fn(),
+ })),
+});
+
+// Mock IntersectionObserver (if needed for lazy loading)
+global.IntersectionObserver = class IntersectionObserver {
+ constructor() {}
+ disconnect() {}
+ observe() {}
+ takeRecords() {
+ return [];
+ }
+ unobserve() {}
+};
+
+// Suppress console errors during tests (optional - uncomment if needed)
+// const originalError = console.error;
+// beforeAll(() => {
+// console.error = (...args) => {
+// if (
+// typeof args[0] === 'string' &&
+// args[0].includes('Warning: ReactDOM.render')
+// ) {
+// return;
+// }
+// originalError.call(console, ...args);
+// };
+// });
+//
+// afterAll(() => {
+// console.error = originalError;
+// });
diff --git a/app/src/test-utils.js b/app/src/test-utils.js
new file mode 100644
index 000000000..bb1e248d1
--- /dev/null
+++ b/app/src/test-utils.js
@@ -0,0 +1,292 @@
+/**
+ * Test utilities for React Testing Library
+ * Provides custom render functions with necessary context providers
+ */
+
+import React from 'react';
+import {render} from '@testing-library/react';
+import {BrowserRouter} from 'react-router-dom';
+import {MediaContextProvider, ThemeContext} from './contexts/contexts';
+
+/**
+ * Custom render function that wraps components with necessary providers
+ * @param {React.Component} ui - Component to render
+ * @param {Object} options - Render options
+ * @param {boolean} options.inverted - Whether to use dark theme
+ * @param {Object} options.themeContext - Custom theme context values
+ * @param {boolean} options.withMedia - Include MediaContextProvider (default: false)
+ * @param {Object} options.renderOptions - Additional React Testing Library options
+ */
+export function renderWithProviders(
+ ui,
+ {
+ inverted = false,
+ themeContext = {},
+ withMedia = false,
+ ...renderOptions
+ } = {}
+) {
+ const defaultThemeContext = {
+ inverted,
+ setInverted: jest.fn(),
+ // Theme components use different properties:
+ i: inverted ? {inverted: true} : {inverted: undefined}, // For Semantic UI elements (Segment, Form, etc.)
+ s: inverted ? {style: {backgroundColor: '#1B1C1D', color: '#dddddd'}} : {}, // For style inversion
+ t: inverted ? {style: {color: '#eeeeee'}} : {}, // For text color inversion (Header, etc.)
+ theme: inverted ? 'dark' : 'light',
+ ...themeContext
+ };
+
+ function Wrapper({children}) {
+ const content = (
+
+
+ {children}
+
+
+ );
+
+ // Only wrap with MediaContextProvider if explicitly requested
+ // (since it requires window.matchMedia which can be tricky in tests)
+ if (withMedia) {
+ return {content};
+ }
+
+ return content;
+ }
+
+ return render(ui, {wrapper: Wrapper, ...renderOptions});
+}
+
+/**
+ * Creates a mock collection metadata object for testing
+ */
+export function createMockMetadata(kind = 'domain', overrides = {}) {
+ return {
+ kind,
+ columns: [
+ {key: 'domain', label: 'Domain', sortable: true},
+ {key: 'archive_count', label: 'Archives', sortable: true, align: 'right'},
+ {key: 'size', label: 'Size', sortable: true, align: 'right', format: 'bytes'},
+ {key: 'tag_name', label: 'Tag', sortable: true},
+ {key: 'actions', label: 'Manage', sortable: false, type: 'actions'},
+ ],
+ fields: [
+ {key: 'directory', label: 'Directory', type: 'text', placeholder: 'Optional directory path'},
+ {key: 'tag_name', label: 'Tag', type: 'tag', placeholder: 'Select or create tag', depends_on: 'directory'},
+ {key: 'description', label: 'Description', type: 'textarea', placeholder: 'Optional description'},
+ ],
+ routes: {
+ list: '/archive/domains',
+ edit: '/archive/domain/:id/edit',
+ search: '/archive',
+ },
+ messages: {
+ no_directory: 'Set a directory to enable tagging',
+ tag_will_move: 'Tagging will move files to a new directory'
+ },
+ ...overrides
+ };
+}
+
+/**
+ * Creates a mock domain collection object for testing
+ */
+export function createMockDomain(overrides = {}) {
+ return {
+ id: 1,
+ domain: 'example.com',
+ archive_count: 42,
+ size: 1024000,
+ tag_name: null,
+ directory: '',
+ can_be_tagged: false,
+ description: '',
+ ...overrides
+ };
+}
+
+/**
+ * Creates multiple mock domains for list testing
+ */
+export function createMockDomains(count = 3) {
+ return Array.from({length: count}, (_, i) => createMockDomain({
+ id: i + 1,
+ domain: `example${i + 1}.com`,
+ archive_count: (i + 1) * 10,
+ size: (i + 1) * 1000000,
+ }));
+}
+
+/**
+ * Mock fetch implementation for API calls
+ */
+export function mockFetch(data, options = {}) {
+ const {
+ ok = true,
+ status = 200,
+ delay = 0,
+ } = options;
+
+ return jest.fn(() =>
+ new Promise((resolve) => {
+ setTimeout(() => {
+ resolve({
+ ok,
+ status,
+ json: async () => data,
+ text: async () => JSON.stringify(data),
+ });
+ }, delay);
+ })
+ );
+}
+
+/**
+ * Mock API error response
+ */
+export function mockFetchError(error = 'An error occurred', status = 400) {
+ return jest.fn(() =>
+ Promise.resolve({
+ ok: false,
+ status,
+ json: async () => ({error}),
+ text: async () => error,
+ })
+ );
+}
+
+/**
+ * Wait for async updates to complete
+ * Useful for testing loading states
+ */
+export async function waitForLoadingToFinish() {
+ const {waitFor} = await import('@testing-library/react');
+ await waitFor(() => {}, {timeout: 100});
+}
+
+/**
+ * Test helper to render components in dark mode
+ *
+ * Usage:
+ * renderInDarkMode()
+ *
+ * Verify inverted styling is applied:
+ * const element = container.querySelector('.ui.segment.inverted');
+ * expect(element).toBeInTheDocument();
+ */
+export function renderInDarkMode(ui, options = {}) {
+ return renderWithProviders(ui, {
+ inverted: true,
+ ...options
+ });
+}
+
+/**
+ * Test helper to render components in light mode
+ * (This is the default, but provided for explicitness in tests)
+ */
+export function renderInLightMode(ui, options = {}) {
+ return renderWithProviders(ui, {
+ inverted: false,
+ ...options
+ });
+}
+
+/**
+ * Helper to check if an element has theme-aware (inverted) styling
+ * Returns true if the element has the 'inverted' class
+ *
+ * Usage:
+ * const segment = container.querySelector('.ui.segment');
+ * expect(hasInvertedStyling(segment)).toBe(true);
+ */
+export function hasInvertedStyling(element) {
+ return element && element.classList.contains('inverted');
+}
+
+/**
+ * Creates a test-friendly form object using real useForm hook
+ *
+ * This uses the actual useForm implementation, making tests more reliable.
+ * The form starts in a "ready" state with the provided data.
+ *
+ * Usage:
+ * const form = createTestForm({domain: 'test.com', directory: '/path'});
+ * render();
+ *
+ * // With overrides:
+ * const form = createTestForm(data, {overrides: {loading: true, disabled: true}});
+ */
+export function createTestForm(initialData = {}, config = {}) {
+ // Return a plain object that mimics the useForm interface without actual React hooks
+ // This avoids async state updates that cause act() warnings in tests
+ const _ = require('lodash');
+ const formData = {...initialData};
+
+ const setValue = (path, newValue) => {
+ _.set(formData, path, newValue);
+ };
+
+ const form = {
+ formData,
+ ready: true,
+ loading: false,
+ disabled: false,
+ dirty: false,
+ errors: {},
+ // Methods that update formData directly
+ patchFormData: jest.fn((updates) => {
+ Object.assign(formData, updates);
+ }),
+ reset: jest.fn(() => {
+ Object.keys(formData).forEach(key => delete formData[key]);
+ Object.assign(formData, initialData);
+ }),
+ onSubmit: config.submitter || jest.fn(async () => initialData),
+ // Input helpers
+ setError: jest.fn(),
+ setValidator: jest.fn(),
+ setValidValue: jest.fn(),
+ setRequired: jest.fn(),
+ // Property getter for field values
+ get: (path) => _.get(formData, path),
+ setValue: jest.fn((path, value) => setValue(path, value)),
+ // getCustomProps - mimics the real useForm method
+ getCustomProps: ({name, path, required = false, type = 'text'}) => {
+ path = path || name;
+ const value = _.get(formData, path);
+
+ const inputProps = {
+ type,
+ disabled: form.disabled,
+ value: value !== undefined ? value : (type === 'array' ? [] : null),
+ onChange: (newValue) => {
+ setValue(path, newValue);
+ },
+ 'data-path': path,
+ };
+
+ const inputAttrs = {
+ valid: true,
+ path,
+ localSetValue: (newValue) => setValue(path, newValue),
+ };
+
+ return [inputProps, inputAttrs];
+ },
+ };
+
+ // Apply any overrides
+ if (config.overrides) {
+ Object.assign(form, config.overrides);
+ }
+
+ return form;
+}
+
+// Re-export everything from React Testing Library
+export * from '@testing-library/react';
+
+// Override the default render with our custom one
+export {renderWithProviders as render};
diff --git a/docker/archive/main.py b/docker/archive/main.py
index 8c04167ea..0f80896ea 100644
--- a/docker/archive/main.py
+++ b/docker/archive/main.py
@@ -10,6 +10,7 @@
import os.path
import pathlib
import subprocess
+import sys
import tempfile
import traceback
from json import JSONDecodeError
@@ -33,6 +34,12 @@
if not SINGLEFILE_PATH.is_file():
raise FileNotFoundError("Can't find single-file executable!")
+BROWSER_EXEC = pathlib.Path('/usr/bin/google-chrome')
+if not BROWSER_EXEC.is_file():
+ print('Unable to find browser!', file=sys.stderr)
+ sys.exit(1)
+
+
# Increase response timeout, archiving can take several minutes.
RESPONSE_TIMEOUT = 10 * 60
config = {
@@ -117,7 +124,7 @@ async def extract_readability(path: str, url: str) -> dict:
async def take_screenshot(url: str) -> bytes:
- cmd = '/usr/bin/google-chrome' \
+ cmd = f'{BROWSER_EXEC}' \
' --headless' \
' --disable-gpu' \
' --no-sandbox' \
@@ -155,6 +162,61 @@ def prepare_bytes(b: bytes) -> str:
return b
+@app.post('/screenshot')
+async def post_screenshot(request: Request):
+ """Generate a screenshot for the provided singlefile."""
+ url = request.json['url']
+ singlefile = request.json.get('singlefile')
+
+ try:
+ logger.info(f'Generating screenshot for {url}')
+
+ # Decode and decompress the singlefile
+ if singlefile:
+ singlefile = base64.b64decode(singlefile)
+ singlefile = gzip.decompress(singlefile)
+
+ if not singlefile:
+ raise ValueError(f'No singlefile provided for {url}')
+
+ # Write singlefile to temp file and screenshot it
+ # Use html suffix so chrome screenshot recognizes it as an HTML file
+ with tempfile.NamedTemporaryFile('wb', suffix='.html') as fh:
+ fh.write(singlefile)
+ fh.flush()
+
+ screenshot = None
+ try:
+ # Screenshot the local singlefile
+ screenshot = await take_screenshot(f'file://{fh.name}')
+ except Exception as e:
+ logger.error(f'Failed to take screenshot of {fh.name}', exc_info=e)
+
+ # Fall back to URL if local screenshot failed
+ if not screenshot:
+ logger.warning(f'Failed to screenshot local singlefile, attempting to screenshot URL: {url}')
+ try:
+ screenshot = await take_screenshot(url)
+ except Exception as e:
+ logger.error(f'Failed to take screenshot of {url}', exc_info=e)
+
+ if not screenshot:
+ raise ValueError(f'Failed to generate screenshot for {url}')
+
+ # Compress for smaller response
+ screenshot = prepare_bytes(screenshot)
+
+ ret = dict(
+ url=url,
+ screenshot=screenshot,
+ )
+ return response.json(ret)
+ except Exception as e:
+ logger.error(f'Failed to generate screenshot for {url}', exc_info=e)
+ error = str(traceback.format_exc())
+ return response.json({'error': f'Failed to generate screenshot for {url} traceback is below... \n\n {error}'})
+
+
@app.post('/json')
async def post_archive(request: Request):
url = request.json['url']
diff --git a/etc/raspberrypios/90-wrolpi b/etc/raspberrypios/90-wrolpi
index 3b2a36f76..e1010331d 100644
--- a/etc/raspberrypios/90-wrolpi
+++ b/etc/raspberrypios/90-wrolpi
@@ -12,4 +12,6 @@
%wrolpi ALL= NOPASSWD:/usr/bin/systemctl stop wrolpi-app.service
%wrolpi ALL= NOPASSWD:/usr/bin/systemctl start wrolpi-app.service
%wrolpi ALL= NOPASSWD:/opt/wrolpi/scripts/import_map.sh
+%wrolpi ALL= NOPASSWD:/opt/wrolpi/upgrade.sh *
+%wrolpi ALL= NOPASSWD:/usr/bin/systemctl start wrolpi-upgrade.service
%wrolpi ALL= NOPASSWD:/usr/sbin/shutdown
\ No newline at end of file
diff --git a/etc/raspberrypios/maintenance.html b/etc/raspberrypios/maintenance.html
new file mode 100644
index 000000000..e67fbc860
--- /dev/null
+++ b/etc/raspberrypios/maintenance.html
@@ -0,0 +1,189 @@
+
+
+
+
+
+ WROLPi - Upgrade in Progress
+
+
+
+
+
WROLPi
+
Upgrade in Progress
+
Please wait while your system is being upgraded
+
+
+
+
+
+
Checking upgrade status...
+
Elapsed: 0:00
+
+
+ Do not power off your device or close this window during the upgrade.
+
+
+
+
+
+
diff --git a/etc/raspberrypios/wrolpi-upgrade.service b/etc/raspberrypios/wrolpi-upgrade.service
new file mode 100644
index 000000000..76d6a047e
--- /dev/null
+++ b/etc/raspberrypios/wrolpi-upgrade.service
@@ -0,0 +1,11 @@
+[Unit]
+Description=WROLPi Upgrade Service
+After=network.target
+
+[Service]
+Type=oneshot
+EnvironmentFile=/tmp/wrolpi-upgrade.env
+ExecStart=/opt/wrolpi/upgrade.sh -b ${BRANCH}
+User=root
+StandardOutput=journal
+StandardError=journal
diff --git a/etc/raspberrypios/wrolpi.conf b/etc/raspberrypios/wrolpi.conf
index b911d7c76..0233b035c 100644
--- a/etc/raspberrypios/wrolpi.conf
+++ b/etc/raspberrypios/wrolpi.conf
@@ -1,5 +1,22 @@
# Nginx config for the WROLPi services.
+#
+# Maintenance page for upgrades
+#
+
+location = /maintenance.html {
+ alias /var/www/maintenance.html;
+}
+
+#
+# Error pages
+#
+
+location /error/ {
+ alias /var/www/;
+ internal;
+}
+
#
# App
#
diff --git a/main.py b/main.py
index 789b1d6d1..28e13babb 100755
--- a/main.py
+++ b/main.py
@@ -13,6 +13,7 @@
from modules.inventory.common import import_inventories_config
from modules.videos.lib import import_channels_config, get_videos_downloader_config
from wrolpi import flags, admin
+from modules.archive.lib import import_domains_config
from wrolpi import root_api # noqa
from wrolpi import tags
from wrolpi.api_utils import api_app, perpetual_signal
@@ -244,6 +245,10 @@ async def start_single_tasks(app: Sanic):
with suppress(Exception):
import_channels_config()
logger.debug('channels config imported')
+ # Domains config import
+ with suppress(Exception):
+ import_domains_config()
+ logger.debug('domains config imported')
with suppress(Exception):
import_inventories_config()
logger.debug('inventories config imported')
@@ -353,5 +358,38 @@ async def perpetual_start_video_missing_comments_download():
logger.debug('Waiting for downloads to be enabled before downloading comments...')
+@perpetual_signal(sleep=3600) # Check hourly
+async def perpetual_check_for_updates():
+ """
+ Check git remote for new commits on the current branch.
+
+ Updates shared_ctx with update info for the /api/status endpoint.
+ Only runs on native installs (not Docker).
+ """
+ from wrolpi.upgrade import check_for_update
+
+ # Don't check for updates in Docker environments
+ if DOCKERIZED:
+ return
+
+ # Only check when we have internet
+ if not flags.have_internet.is_set():
+ return
+
+ try:
+ result = check_for_update(fetch=True)
+ # Store in shared_ctx.status (a manager.dict) to share across all workers.
+ api_app.shared_ctx.status['update_available'] = result.get('update_available', False)
+ api_app.shared_ctx.status['latest_commit'] = result.get('latest_commit')
+ api_app.shared_ctx.status['current_commit'] = result.get('current_commit')
+ api_app.shared_ctx.status['commits_behind'] = result.get('commits_behind', 0)
+ api_app.shared_ctx.status['git_branch'] = result.get('branch')
+
+ if result.get('update_available'):
+ logger.info(f"Update available: {result['commits_behind']} commits behind on {result['branch']}")
+ except Exception as e:
+ logger.error('Failed to check for updates', exc_info=e)
+
+
if __name__ == '__main__':
sys.exit(main())
diff --git a/modules/archive/__init__.py b/modules/archive/__init__.py
index 05426e000..b2702595b 100644
--- a/modules/archive/__init__.py
+++ b/modules/archive/__init__.py
@@ -8,6 +8,7 @@
from sqlalchemy.orm import Session
from wrolpi.cmd import SINGLE_FILE_BIN, CHROMIUM
+from wrolpi.collections import Collection
from wrolpi.common import logger, register_modeler, register_refresh_cleanup, limit_concurrent, split_lines_by_length, \
slow_logger, get_title_from_html
from wrolpi.db import optional_session, get_db_session
@@ -18,8 +19,8 @@
from . import lib
from .api import archive_bp # noqa
from .errors import InvalidArchive
-from .lib import is_singlefile_file, request_archive, SINGLEFILE_HEADER
-from .models import Archive, Domain
+from .lib import is_singlefile_file, request_archive, SINGLEFILE_HEADER, get_url_from_singlefile
+from .models import Archive
PRETTY_NAME = 'Archive'
@@ -186,15 +187,61 @@ def model_archive(file_group: FileGroup, session: Session = None) -> Archive:
if not archive:
# Create new Archive if it doesn't exist
archive = Archive(file_group_id=file_group_id, file_group=file_group)
+
+ # Set collection_id BEFORE adding to session to avoid autoflush constraint violation
+ from modules.archive.lib import get_or_create_domain_collection
+ if file_group.url:
+ # URL is already set on the FileGroup
+ collection = get_or_create_domain_collection(session, file_group.url)
+ archive.collection_id = collection.id if collection else None
+ else:
+ # No URL - try to extract from singlefile
+ try:
+ url = get_url_from_singlefile(singlefile_path.read_bytes())
+ file_group.url = url
+ collection = get_or_create_domain_collection(session, url)
+ archive.collection_id = collection.id if collection else None
+ except (RuntimeError, ValueError) as e:
+ # Could not extract URL from singlefile
+ if not PYTEST:
+ # In production, archives must have a URL/collection
+ raise InvalidArchive(f'Archive has no URL and could not extract from singlefile: {e}') from e
+ # In tests, allow archives without collections (factory will set it later)
+ logger.debug(f'Could not extract URL from singlefile (test mode): {e}')
+ archive.collection_id = None
+
session.add(archive)
+
archive.validate()
archive.flush()
+ # Validate that archive path matches domain directory (efficient O(1) check)
+ if archive.collection_id and archive.collection:
+ collection = archive.collection
+ if collection.kind == 'domain' and collection.directory and not collection.tag_id:
+ # Check if archive path is under the domain's directory
+ archive_path = pathlib.Path(file_group.primary_path)
+ from wrolpi.common import get_media_directory
+ media_dir = get_media_directory()
+ collection_abs_dir = media_dir / collection.directory
+
+ try:
+ # Check if archive is under the collection's directory
+ archive_path.relative_to(collection_abs_dir)
+ except ValueError:
+ # Archive is NOT under the domain's directory - clear it
+ logger.warning(
+ f'Archive {archive.id} at {archive_path} is not under domain directory {collection.directory}. '
+ f'Clearing domain directory (domain is not tagged).'
+ )
+ collection.directory = None
+ session.flush([collection])
+
file_group.title = file_group.a_text = title or archive.file_group.title
file_group.d_text = contents
file_group.data = {
'id': archive.id,
- 'domain': archive.domain.domain if archive.domain else None,
+ 'domain': archive.domain,
'readability_json_path': archive.readability_json_path,
'readability_path': archive.readability_path,
'readability_txt_path': archive.readability_txt_path,
@@ -230,26 +277,33 @@ async def archive_modeler():
processed = 0
for processed, (file_group, archive) in enumerate(results):
- # Even if indexing fails, we mark it as indexed. We won't retry indexing this.
- file_group.indexed = True
-
with slow_logger(1, f'Modeling archive took %(elapsed)s seconds: {file_group}',
logger__=logger):
if archive:
try:
archive_id = archive.id
archive.validate()
+ # Successfully validated, mark as indexed
+ file_group.indexed = True
except Exception:
logger.error(f'Unable to validate Archive {archive_id}')
+ # Don't mark as indexed - will retry later
if PYTEST:
raise
else:
try:
model_archive(file_group, session=session)
+ # Successfully modeled, mark as indexed
+ file_group.indexed = True
except InvalidArchive:
# It was not a real Archive. Many HTML files will not be an Archive.
file_group.indexed = False
invalid_archives.add(file_group.id)
+ except Exception as e:
+ # Some other error occurred during modeling - don't mark as indexed so we can retry
+ logger.error(f'Failed to model Archive for FileGroup {file_group.id}: {e}')
+ if PYTEST:
+ raise
session.commit()
@@ -267,11 +321,14 @@ async def archive_modeler():
@limit_concurrent(1)
def archive_cleanup():
with get_db_session(commit=True) as session:
- # Remove any Domains without any Archives.
- domain_ids = [i[0] for i in session.execute('SELECT DISTINCT domain_id FROM archive') if i[0]]
- for domain in session.query(Domain):
- if domain.id not in domain_ids:
- session.delete(domain)
+ # Remove any domain Collections without any Archives.
+ # Get all collection_ids that have archives
+ collection_ids = [i[0] for i in session.execute('SELECT DISTINCT collection_id FROM archive') if i[0]]
+ # Find domain collections that have no archives
+ for collection in session.query(Collection).filter_by(kind='domain'):
+ if collection.id not in collection_ids:
+ logger.info(f'Deleting empty domain collection: {collection.name}')
+ session.delete(collection)
def get_title(path):
diff --git a/modules/archive/api.py b/modules/archive/api.py
index 8ea6469bd..d80b61790 100644
--- a/modules/archive/api.py
+++ b/modules/archive/api.py
@@ -46,14 +46,6 @@ async def delete_archive(_: Request, archive_ids: str):
return response.empty()
-@archive_bp.get('/domains')
-@openapi.summary('Get a list of all Domains and their Archive statistics')
-@openapi.response(200, schema.GetDomainsResponse, "The list of domains")
-async def get_domains(_: Request):
- domains = lib.get_domains()
- return json_response({'domains': domains, 'totals': {'domains': len(domains)}})
-
-
archive_limit_limiter = api_param_limiter(100)
@@ -136,6 +128,49 @@ async def singlefile_upload_switch_handler(url=None):
singlefile_upload_switch_handler: ActivateSwitchMethod
+@register_switch_handler('generate_screenshot_switch_handler')
+async def generate_screenshot_switch_handler(archive_id=None):
+ """Used by `post_generate_screenshot` to generate screenshots in the background"""
+ q: multiprocessing.Queue = api_app.shared_ctx.archive_screenshots
+
+ trace_enabled = logger.isEnabledFor(TRACE_LEVEL)
+ if trace_enabled:
+ logger.trace(f'generate_screenshot_switch_handler called for archive_id={archive_id}')
+ try:
+ archive_id = q.get_nowait()
+ except queue.Empty:
+ if trace_enabled:
+ logger.trace(f'generate_screenshot_switch_handler called on empty queue')
+ return
+
+ try:
+ q_size = q.qsize()
+ except NotImplementedError:
+ # qsize() is not implemented on macOS
+ q_size = '?'
+ logger.info(f'generate_screenshot_switch_handler queue size: {q_size}')
+
+ try:
+ await lib.generate_archive_screenshot(archive_id)
+ # Always send success event since exceptions are raised on failure
+ from modules.archive import Archive
+ archive = lib.get_archive(archive_id=archive_id)
+ location = archive.location
+ name = archive.file_group.title or archive.file_group.url
+ logger.info(f'Generated screenshot for Archive ({q_size}): {archive_id}')
+ Events.send_screenshot_generated(f'Generated screenshot for: {name}', url=location)
+ except Exception as e:
+ logger.error(f'generate_screenshot_switch_handler failed for Archive {archive_id}', exc_info=e)
+ Events.send_screenshot_generation_failed(f'Failed to generate screenshot: {e}')
+ raise
+
+ # Call this function again so any new screenshot requests can be processed.
+ generate_screenshot_switch_handler.activate_switch()
+
+
+generate_screenshot_switch_handler: ActivateSwitchMethod
+
+
@archive_bp.post('/upload')
@openapi.definition(
summary='Upload SingleFile from SingleFile browser extension and convert it to an Archive.'
@@ -153,3 +188,28 @@ async def post_upload_singlefile(request: Request):
singlefile_upload_switch_handler.activate_switch(context=dict(url=url))
# Return empty json response because SingleFile extension expects a JSON response.
return json_response(dict(), status=HTTPStatus.OK)
+
+
+@archive_bp.post('//generate_screenshot')
+@openapi.description('Generate a screenshot for an Archive that does not have one')
+@openapi.response(HTTPStatus.OK, description='Screenshot generation queued')
+@openapi.response(HTTPStatus.NOT_FOUND, JSONErrorResponse)
+@openapi.response(HTTPStatus.BAD_REQUEST, JSONErrorResponse)
+@wrol_mode_check
+async def post_generate_screenshot(_: Request, archive_id: int):
+ """Queue a screenshot generation request for an Archive."""
+ # Verify archive exists
+ try:
+ archive = lib.get_archive(archive_id=archive_id)
+ except Exception:
+ return json_response({'error': f'Archive {archive_id} not found'}, status=HTTPStatus.NOT_FOUND)
+
+ if not archive.singlefile_path:
+ return json_response({'error': 'Archive has no singlefile'}, status=HTTPStatus.BAD_REQUEST)
+
+ # Queue the screenshot generation request
+ logger.info(f'Queueing screenshot generation for Archive {archive_id}')
+ api_app.shared_ctx.archive_screenshots.put(archive_id)
+ generate_screenshot_switch_handler.activate_switch(context=dict(archive_id=archive_id))
+
+ return json_response({'message': 'Screenshot generation queued'}, status=HTTPStatus.OK)
diff --git a/modules/archive/conftest.py b/modules/archive/conftest.py
index 15667deeb..78e6394f5 100644
--- a/modules/archive/conftest.py
+++ b/modules/archive/conftest.py
@@ -8,7 +8,8 @@
import pytz
from modules.archive.lib import archive_strftime
-from modules.archive.models import Archive, Domain
+from modules.archive.models import Archive
+from wrolpi.collections import Collection
@pytest.fixture
@@ -58,21 +59,27 @@ def _(domain: str = None, url: str = None, title: str = 'NA', contents: str = No
})
if domain:
- domain = test_session.query(Domain).filter_by(domain=domain).one_or_none()
+ # Find or create domain collection
+ collection = test_session.query(Collection).filter_by(
+ name=domain,
+ kind='domain'
+ ).one_or_none()
+ if not collection:
+ collection = Collection(
+ name=domain,
+ kind='domain',
+ directory=None, # Domain collections are unrestricted
+ )
+ test_session.add(collection)
+ test_session.flush([collection])
+ else:
+ collection = None
screenshot_path = None
if screenshot:
screenshot_path = domain_dir / f'{timestamp}_{title}.png'
screenshot_path.write_bytes(image_bytes_factory())
- if not domain and domain_dir.name != 'NA':
- domain = Domain(
- domain=domain_dir.name,
- directory=domain_dir,
- )
- test_session.add(domain)
- test_session.flush([domain])
-
# Only add files that were created.
files = (readability_path, readability_json_path, readability_txt_path, screenshot_path, singlefile_path)
files = list(filter(None, files))
@@ -82,7 +89,7 @@ def _(domain: str = None, url: str = None, title: str = 'NA', contents: str = No
archive.title = title
archive.file_group.download_datetime = archive.file_group.published_datetime = next(now)
archive.file_group.modification_datetime = next(now)
- archive.domain = domain
+ archive.collection = collection
archive.validate()
for tag_name in tag_names:
diff --git a/modules/archive/lib.py b/modules/archive/lib.py
index 681fabaa1..9ea95393b 100644
--- a/modules/archive/lib.py
+++ b/modules/archive/lib.py
@@ -7,7 +7,7 @@
import re
import shlex
import tempfile
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from datetime import datetime
from json import JSONDecodeError
from typing import Optional, Tuple, List, Union
@@ -17,22 +17,190 @@
from sqlalchemy import asc
from sqlalchemy.orm import Session
-from modules.archive.models import Domain, Archive
+from modules.archive.models import Archive
from wrolpi import dates
from wrolpi.cmd import READABILITY_BIN, run_command
+from wrolpi.collections import Collection
from wrolpi.common import get_media_directory, logger, extract_domain, escape_file_name, aiohttp_post, \
- format_html_string, split_lines_by_length, get_html_soup, get_title_from_html, get_wrolpi_config, html_screenshot
+ format_html_string, split_lines_by_length, get_html_soup, get_title_from_html, get_wrolpi_config, html_screenshot, \
+ ConfigFile
from wrolpi.dates import now, Seconds
from wrolpi.db import get_db_session, get_db_curs, optional_session
from wrolpi.errors import UnknownArchive, InvalidOrderBy, InvalidDatetime
+from wrolpi.events import Events
+from wrolpi.switches import register_switch_handler, ActivateSwitchMethod
from wrolpi.tags import tag_append_sub_select_where
from wrolpi.vars import PYTEST, DOCKERIZED
logger = logger.getChild(__name__)
+__all__ = ['DomainsConfig', 'domains_config', 'get_domains_config', 'save_domains_config', 'import_domains_config']
+
ARCHIVE_SERVICE = 'http://archive:8080'
+@dataclass
+class DomainsConfigValidator:
+ """Validator for domains config file."""
+ version: int = 0
+ collections: List[dict] = field(default_factory=list)
+
+
+class DomainsConfig(ConfigFile):
+ """
+ Config file for Domain Collections.
+
+ This is a specialized config that manages domain collections
+ (Collections with kind='domain'). It maintains a domains.yaml file for
+ backward compatibility and user convenience.
+
+ Format:
+ collections:
+ - name: "example.com"
+ kind: "domain"
+ description: "Archives from example.com"
+ - name: "wikipedia.org"
+ kind: "domain"
+
+ Note: Domain collections can have optional directories for tagging support.
+ Items are managed dynamically when Archives are indexed.
+ """
+ file_name = 'domains.yaml'
+ validator = DomainsConfigValidator
+ default_config = dict(
+ version=0,
+ collections=[],
+ )
+ # Use wider width to accommodate longer paths
+ width = 120
+
+ def __getitem__(self, item):
+ return self._config[item]
+
+ def __setitem__(self, key, value):
+ self._config[key] = value
+
+ @property
+ def collections(self) -> List[dict]:
+ """Get list of collection configs."""
+ return self._config.get('collections', [])
+
+ def import_config(self, file: pathlib.Path = None, send_events=False):
+ """Import domain collections from config file into database."""
+ ConfigFile.import_config(self, file, send_events)
+
+ file_str = str(self.get_relative_file())
+ collections_data = self._config.get('collections', [])
+
+ if not collections_data:
+ logger.info(f'No domain collections to import from {file_str}')
+ self.successful_import = True
+ return
+
+ logger.info(f'Importing {len(collections_data)} domain collections from {file_str}')
+
+ try:
+ with get_db_session(commit=True) as session:
+ # Track imported domain names
+ imported_domains = set()
+
+ # Import each domain collection, forcing kind='domain'
+ for idx, collection_data in enumerate(collections_data):
+ try:
+ name = collection_data.get('name')
+ if not name:
+ logger.error(f'Domain collection at index {idx} has no name, skipping')
+ continue
+
+ # Ensure kind is 'domain'
+ collection_data = collection_data.copy()
+ collection_data['kind'] = 'domain'
+
+ # Warn if tag_name provided without directory
+ tag_name = collection_data.get('tag_name')
+ directory = collection_data.get('directory')
+ if tag_name and not directory:
+ logger.warning(
+ f"Domain collection '{name}' has tag_name '{tag_name}' "
+ f"but no directory - tags require a directory. Tag will be ignored."
+ )
+ collection_data.pop('tag_name', None)
+
+ # Use Collection.from_config to create/update
+ collection = Collection.from_config(collection_data, session)
+ imported_domains.add(collection.name)
+
+ except Exception as e:
+ logger.error(f'Failed to import domain collection at index {idx}', exc_info=e)
+ continue
+
+ # Delete domain collections that are no longer in config
+ all_domain_collections = session.query(Collection).filter_by(kind='domain').all()
+ for collection in all_domain_collections:
+ if collection.name not in imported_domains:
+ logger.info(f'Deleting domain collection {repr(collection.name)} (no longer in config)')
+ session.delete(collection)
+
+ logger.info(f'Successfully imported {len(imported_domains)} domain collections from {file_str}')
+ self.successful_import = True
+
+ except Exception as e:
+ self.successful_import = False
+ message = f'Failed to import {file_str} config!'
+ logger.error(message, exc_info=e)
+ if send_events:
+ Events.send_config_import_failed(message)
+ raise
+
+ def dump_config(self, file: pathlib.Path = None, send_events=False, overwrite=False):
+ """Dump all domain collections from database to config file."""
+ logger.info('Dumping domain collections to config')
+
+ with get_db_session() as session:
+ # Get only domain collections, ordered by name
+ collections = session.query(Collection).filter_by(kind='domain').order_by(Collection.name).all()
+
+ # Use to_config to export each collection
+ collections_data = [collection.to_config() for collection in collections]
+
+ self._config['collections'] = collections_data
+
+ logger.info(f'Dumping {len(collections_data)} domain collections to config')
+ self.save(file, send_events, overwrite)
+
+
+# Global instance
+domains_config = DomainsConfig()
+
+
+def get_domains_config() -> DomainsConfig:
+ """Get the global domains config instance."""
+ return domains_config
+
+
+# Switch handler for saving domains config
+@register_switch_handler('save_domains_config')
+def save_domains_config():
+ """Save the domains config when the switch is activated."""
+ domains_config.background_dump.activate_switch()
+
+
+# Explicit type for activate_switch helper
+save_domains_config: ActivateSwitchMethod
+
+
+def import_domains_config():
+ """Import domain collections from config file into database."""
+ logger.info('Importing domains config')
+ domains_config.import_config()
+
+ # Link downloads to domain collections after import
+ with get_db_session() as session:
+ link_domain_and_downloads(session)
+
+ logger.info('Importing domains config completed')
+
+
@dataclass
class ArchiveFiles:
"""Every Archive will have some of these files."""
@@ -74,6 +242,25 @@ def get_domain_directory(url: str) -> pathlib.Path:
return directory
+def get_archive_destination(domain_collection: 'Collection') -> pathlib.Path:
+ """
+ Get the destination directory for archives in a domain collection.
+
+ Args:
+ domain_collection: The domain Collection (kind='domain')
+
+ Returns:
+ Path where archives should be placed. If collection has a directory,
+ returns that directory. Otherwise returns the default archives// path.
+ """
+ if domain_collection.directory:
+ # Restricted collection - use collection directory
+ return get_media_directory() / domain_collection.directory
+ else:
+ # Unrestricted collection - use default archive// path
+ return get_archive_directory() / domain_collection.name
+
+
# File names include domain and datetime.
MAXIMUM_ARCHIVE_FILE_CHARACTER_LENGTH = 200
@@ -156,6 +343,42 @@ async def request_archive(url: str, singlefile: str = None) -> Tuple[str, Option
return singlefile, readability, screenshot
+async def request_screenshot(url: str, singlefile_path: pathlib.Path) -> Optional[bytes]:
+ """Send a request to the archive service to generate a screenshot from the singlefile."""
+ logger.info(f'Sending screenshot request to archive service: {url}')
+
+ # Read, compress, and encode the singlefile
+ singlefile_contents = singlefile_path.read_bytes()
+ singlefile_compressed = gzip.compress(singlefile_contents)
+ singlefile_b64 = base64.b64encode(singlefile_compressed).decode()
+
+ data = dict(url=url, singlefile=singlefile_b64)
+ try:
+ async with aiohttp_post(f'{ARCHIVE_SERVICE}/screenshot', json_=data, timeout=ARCHIVE_TIMEOUT) as response:
+ status = response.status
+ contents = await response.json()
+ if contents and (error := contents.get('error')):
+ # Report the error from the archive service.
+ raise Exception(f'Received error from archive service: {error}')
+
+ # Compressed base64
+ screenshot = contents.get('screenshot')
+ if not screenshot:
+ logger.warning(f'Failed to get screenshot for {url=}')
+ return None
+
+ logger.debug(f'screenshot request status code {status}')
+ except Exception as e:
+ logger.error('Error when requesting screenshot', exc_info=e)
+ raise
+
+ # Decode and decompress.
+ screenshot = base64.b64decode(screenshot)
+ screenshot = gzip.decompress(screenshot)
+
+ return screenshot
+
+
async def model_archive_result(url: str, singlefile: str, readability: dict, screenshot: bytes) -> Archive:
"""
Convert results from ArchiveDownloader into real files. Create Archive record.
@@ -212,24 +435,223 @@ async def model_archive_result(url: str, singlefile: str, readability: dict, scr
archive = Archive.from_paths(session, *paths)
archive.file_group.download_datetime = now()
archive.url = url
- archive.domain = get_or_create_domain(session, url)
+ archive.collection = get_or_create_domain_collection(session, url)
archive.flush()
- archive.domain.flush()
return archive
-def get_or_create_domain(session: Session, url) -> Domain:
+def detect_domain_directory(collection: Collection, session: Session) -> Optional[pathlib.Path]:
"""
- Get/create the Domain for this archive.
+ Detect if all archives for a domain collection share a common directory.
+
+ Args:
+ collection: The domain collection to analyze
+ session: Database session
+
+ Returns:
+ Path (relative to media directory) if all archives share a common directory within archive media directory.
+ None if archives are scattered across different directories or if collection has no archives.
"""
- domain_ = extract_domain(url)
- domain = session.query(Domain).filter_by(domain=domain_).one_or_none()
- if not domain:
- domain = Domain(domain=domain_, directory=str(get_domain_directory(url)))
- session.add(domain)
+ if collection.kind != 'domain':
+ return None
+
+ # Query all archives for this domain collection
+ archives = session.query(Archive).filter_by(collection_id=collection.id).all()
+
+ if not archives:
+ # No archives yet, can't determine directory
+ return None
+
+ # Get all archive file paths
+ paths = []
+ for archive in archives:
+ if archive.file_group and archive.file_group.primary_path:
+ paths.append(pathlib.Path(archive.file_group.primary_path))
+
+ if not paths:
+ return None
+
+ # Find common ancestor directory
+ # Start with the first path's parent directory
+ common_dir = paths[0].parent
+
+ # Check if all other paths are under this directory
+ for path in paths[1:]:
+ try:
+ # Check if path is relative to common_dir
+ path.relative_to(common_dir)
+ except ValueError:
+ # Path is not under common_dir, find the common ancestor
+ # Walk up until we find a common parent
+ while common_dir != common_dir.parent: # Stop at root
+ try:
+ path.relative_to(common_dir)
+ break # Found common ancestor
+ except ValueError:
+ common_dir = common_dir.parent
+
+ # Check if common directory is within the media directory
+ media_dir = get_media_directory()
+ try:
+ relative_path = common_dir.relative_to(media_dir)
+ except ValueError:
+ # Common directory is outside media directory
+ return None
+
+ # Check if it's within the archive directory structure
+ archive_base = get_archive_directory()
+ try:
+ archive_base.relative_to(media_dir) # Verify archive_base is under media_dir
+ common_dir.relative_to(archive_base.parent) # Verify common_dir is under archive structure
+ except ValueError:
+ # Not in the archive directory structure
+ return None
+
+ logger.debug(f'Detected directory for domain {collection.name}: {relative_path}')
+ return relative_path
+
+
+def update_domain_directories(session: Session = None) -> int:
+ """
+ One-time update to detect and set directories for existing domain collections.
+
+ This should be run once to fix domain collections that were created without directories
+ but have all their archives in a common location.
+
+ Args:
+ session: Database session
+
+ Returns:
+ Number of domain collections updated
+ """
+ if session is None:
+ with get_db_session(commit=True) as session:
+ return update_domain_directories(session)
+
+ # Find all domain collections without directories
+ collections = session.query(Collection).filter_by(kind='domain', directory=None).all()
+
+ updated_count = 0
+ for collection in collections:
+ detected_dir = detect_domain_directory(collection, session)
+ if detected_dir:
+ collection.directory = detected_dir
+ session.flush([collection])
+ updated_count += 1
+ logger.info(f'Updated domain {collection.name} with directory: {detected_dir}')
+
+ session.commit()
+ logger.info(f'Updated {updated_count} domain collection(s) with auto-detected directories')
+ return updated_count
+
+
+def get_or_create_domain_collection(session: Session, url, directory: pathlib.Path = None) -> Collection:
+ """
+ Get or create the domain Collection for this archive.
+
+ Args:
+ session: Database session
+ url: URL of the archive
+ directory: Optional directory to restrict this domain collection.
+ If None, creates unrestricted collection (default).
+ If provided, archives will be placed in this directory and collection can be tagged.
+
+ Returns:
+ Collection with kind='domain' for the domain extracted from the URL
+ """
+ domain_name = extract_domain(url)
+
+ # Try to find existing domain collection
+ collection = session.query(Collection).filter_by(
+ name=domain_name,
+ kind='domain'
+ ).one_or_none()
+
+ if not collection:
+ # Create new domain collection
+ collection = Collection(
+ name=domain_name,
+ kind='domain',
+ directory=directory, # Can be None (unrestricted) or a Path (restricted)
+ )
+ session.add(collection)
session.flush()
- return domain
+ # Trigger domain config save for new domain
+ save_domains_config.activate_switch()
+ if directory:
+ logger.info(f'Created domain collection with directory: {domain_name} -> {directory}')
+ else:
+ logger.info(f'Created unrestricted domain collection: {domain_name}')
+
+ # Auto-detect directory if not explicitly set and collection doesn't have one
+ if not directory and not collection.directory:
+ detected_dir = detect_domain_directory(collection, session)
+ if detected_dir:
+ collection.directory = detected_dir
+ session.flush()
+ # Trigger domain config save when directory is auto-detected
+ save_domains_config.activate_switch()
+ logger.info(f'Auto-detected directory for domain {domain_name}: {detected_dir}')
+
+ return collection
+
+
+def link_domain_and_downloads(session: Session):
+ """Associate any Download related to a Domain Collection.
+
+ Downloads are linked to Collections (via collection_id).
+ This function finds Domain Collections and links their Downloads.
+
+ Matching criteria:
+ - Recurring downloads whose destination is within the domain collection's directory (including subdirectories)
+ - RSS downloads with sub_downloader='archive' (matched by URL domain)
+ """
+ from wrolpi.downloader import Download
+ from wrolpi.collections.models import Collection
+
+ # Only Downloads with a frequency can be a Collection Download.
+ downloads = list(session.query(Download).filter(Download.frequency.isnot(None)).all())
+
+ # Get domain collections that have a directory
+ domain_collections = session.query(Collection).filter(
+ Collection.kind == 'domain',
+ Collection.directory.isnot(None)
+ ).all()
+
+ need_commit = False
+
+ # Match downloads by destination directory (including subdirectories)
+ downloads_with_destination = [d for d in downloads if (d.settings or {}).get('destination')]
+ for collection in domain_collections:
+ # Ensure directory ends with / for proper prefix matching
+ directory = str(collection.directory)
+ directory_prefix = directory if directory.endswith('/') else directory + '/'
+ for download in downloads_with_destination:
+ dest = download.settings['destination']
+ # Match if destination equals directory or is a subdirectory
+ if not download.collection_id and (dest == directory or dest.startswith(directory_prefix)):
+ download.collection_id = collection.id
+ need_commit = True
+
+ # Match RSS downloads with archive sub_downloader by URL domain
+ rss_archive_downloads = [
+ d for d in downloads
+ if d.downloader == 'rss' and d.sub_downloader == 'archive' and not d.collection_id
+ ]
+ for download in rss_archive_downloads:
+ # Extract domain from the RSS URL and find matching domain collection
+ domain_name = extract_domain(download.url)
+ collection = session.query(Collection).filter_by(
+ name=domain_name,
+ kind='domain'
+ ).one_or_none()
+ if collection:
+ download.collection_id = collection.id
+ need_commit = True
+
+ if need_commit:
+ session.commit()
SINGLEFILE_URL_EXTRACTOR = re.compile(r'Page saved with SingleFile \s+url: (http.+?)\n')
@@ -427,18 +849,86 @@ def is_singlefile_file(path: pathlib.Path) -> bool:
def get_domains():
+ """
+ Get all domain collections with their archive statistics.
+
+ This is a thin wrapper around Collection queries that adds archive-specific statistics.
+ Returns a list of dicts with collection id, domain name, url_count, and total size.
+ """
+ from sqlalchemy import func, BigInteger
+ from wrolpi.files.models import FileGroup
+
+ with get_db_session() as session:
+ # Query all domain collections with archive statistics in a single query
+ # This uses ORM for better maintainability while keeping performance
+ query = (
+ session.query(
+ Collection.id,
+ Collection.name.label('domain'),
+ func.count(Archive.id).label('url_count'),
+ func.sum(FileGroup.size).cast(BigInteger).label('size')
+ )
+ .outerjoin(Archive, Collection.id == Archive.collection_id)
+ .outerjoin(FileGroup, FileGroup.id == Archive.file_group_id)
+ .filter(Collection.kind == 'domain')
+ .group_by(Collection.id, Collection.name)
+ .order_by(Collection.name)
+ )
+
+ domains = [
+ {
+ 'id': row.id,
+ 'domain': row.domain,
+ 'url_count': row.url_count or 0,
+ 'size': row.size or 0,
+ }
+ for row in query.all()
+ ]
+
+ return domains
+
+
+def get_domain(domain_id: int) -> dict:
+ """
+ Get a single domain collection by ID with its archive statistics.
+
+ Returns a dict with collection details including id, domain name, url_count, size,
+ tag_name, directory (relative to media directory), and description.
+
+ Raises UnknownArchive if domain not found.
+ """
+ # Get the collection using find_by_id
+ try:
+ collection = Collection.find_by_id(domain_id)
+ except Exception:
+ # Collection.find_by_id raises UnknownCollection, but we want UnknownArchive
+ raise UnknownArchive(f"Domain collection with ID {domain_id} not found")
+
+ if collection.kind != 'domain':
+ raise UnknownArchive(f"Collection {domain_id} is not a domain")
+
+ # Get base domain data from __json__()
+ domain_data = collection.__json__()
+
+ # Get archive statistics for this domain
with get_db_curs() as curs:
stmt = '''
- SELECT domains.domain AS domain, COUNT(a.id) AS url_count, SUM(fg.size)::BIGINT AS size
- FROM domains
- LEFT JOIN archive a on domains.id = a.domain_id
+ SELECT COUNT(a.id) AS url_count, SUM(fg.size)::BIGINT AS size
+ FROM archive a
LEFT JOIN file_group fg on fg.id = a.file_group_id
- GROUP BY domains.domain
- ORDER BY domains.domain \
+ WHERE a.collection_id = %(domain_id)s
'''
- curs.execute(stmt)
- domains = [dict(i) for i in curs.fetchall()]
- return domains
+ curs.execute(stmt, {'domain_id': domain_id})
+ stats = dict(curs.fetchone())
+
+ # Enhance with archive stats
+ domain_data['url_count'] = stats['url_count'] or 0
+ domain_data['size'] = stats['size'] or 0
+
+ # Use 'domain' key for domain collections instead of 'name'
+ domain_data['domain'] = domain_data.pop('name')
+
+ return domain_data
ARCHIVE_ORDERS = {
@@ -520,7 +1010,9 @@ def search_archives(search_str: str, domain: str, limit: int, offset: int, order
if domain:
params['domain'] = domain
- wheres.append('a.domain_id = (select id from domains where domains.domain = %(domain)s)')
+ # Use LIMIT 1 to handle potential duplicate domain collections with the same name
+ wheres.append(
+ "a.collection_id = (select id from collection where collection.name = %(domain)s and collection.kind = 'domain' LIMIT 1)")
select_columns = f", {select_columns}" if select_columns else ""
wheres = '\n AND '.join(wheres)
@@ -546,13 +1038,38 @@ def search_archives(search_str: str, domain: str, limit: int, offset: int, order
@optional_session
-async def search_domains_by_name(name: str, limit: int = 5, session: Session = None) -> List[Domain]:
- domains = session.query(Domain) \
- .filter(Domain.domain.ilike(f'%{name}%')) \
- .order_by(asc(Domain.domain)) \
+async def search_domains_by_name(name: str, limit: int = 5, session: Session = None) -> List[dict]:
+ """
+ Search for domain collections by name.
+
+ Args:
+ name: Search string to match against collection names
+ limit: Maximum number of results to return
+ session: Database session
+
+ Returns:
+ List of domain dicts matching the search (in old Domain format for backward compatibility)
+ """
+ collections = session.query(Collection) \
+ .filter(Collection.kind == 'domain') \
+ .filter(Collection.name.ilike(f'%{name}%')) \
+ .order_by(asc(Collection.name)) \
.limit(limit) \
.all()
- return domains
+
+ # Convert to old Domain format for backward compatibility
+ from wrolpi.common import get_relative_to_media_directory
+ archive_dir = get_archive_directory()
+ return [
+ {
+ 'id': c.id,
+ 'domain': c.name,
+ # Domain collections can have explicit directories or use default archive path
+ 'directory': get_relative_to_media_directory(c.directory) if c.directory else str(
+ (archive_dir / c.name).relative_to(get_media_directory())),
+ }
+ for c in collections
+ ]
async def html_to_readability(html: str | bytes, url: str, timeout: int = 120):
@@ -625,3 +1142,74 @@ async def singlefile_to_archive(singlefile: bytes) -> Archive:
logger.trace(f'singlefile_to_archive modeling: {url}')
archive: Archive = await model_archive_result(url, singlefile, readability, screenshot)
return archive
+
+
+async def generate_archive_screenshot(archive_id: int) -> pathlib.Path:
+ """
+ Generate a screenshot for an existing Archive that doesn't have one.
+ If the Archive already has a screenshot, verify it exists and ensure it's tracked in the FileGroup.
+
+ Returns the path to the generated screenshot.
+
+ Raises:
+ ValueError: If Archive has no singlefile
+ RuntimeError: If screenshot generation fails
+ """
+ from wrolpi.db import get_db_session
+
+ with get_db_session() as session:
+ archive = Archive.find_by_id(archive_id, session=session)
+
+ if not archive.singlefile_path:
+ raise ValueError(f'Cannot generate screenshot for Archive {archive_id}: no singlefile')
+
+ # Check if screenshot already exists
+ if archive.screenshot_path:
+ # Verify the screenshot file actually exists on disk
+ if archive.screenshot_path.is_file():
+ logger.info(f'Archive {archive_id} already has a screenshot, ensuring it is tracked')
+ # Ensure the screenshot is tracked in FileGroup.files and FileGroup.data
+ with get_db_session(commit=True) as tracking_session:
+ archive = Archive.find_by_id(archive_id, session=tracking_session)
+ file_group = archive.file_group
+ # append_files uses unique_by_predicate, so this is safe even if already tracked
+ file_group.append_files(archive.screenshot_path)
+
+ # Also update FileGroup.data (same pattern as set_screenshot)
+ data = dict(file_group.data) if file_group.data else {}
+ data['screenshot_path'] = str(archive.screenshot_path)
+ file_group.data = data
+
+ archive.validate()
+ tracking_session.flush()
+ return archive.screenshot_path
+ else:
+ logger.warning(f'Archive {archive_id} has screenshot_path but file does not exist, regenerating')
+
+ singlefile_path = archive.singlefile_path
+ url = archive.file_group.url
+
+ # Request screenshot from Archive docker service or generate locally
+ if DOCKERIZED:
+ logger.debug(f'Requesting screenshot from archive service for Archive {archive_id}')
+ screenshot_bytes = await request_screenshot(url, singlefile_path)
+ else:
+ logger.debug(f'Generating screenshot locally for Archive {archive_id}')
+ singlefile_contents = singlefile_path.read_bytes()
+ screenshot_bytes = html_screenshot(singlefile_contents)
+
+ if not screenshot_bytes:
+ raise RuntimeError(f'Failed to generate screenshot for Archive {archive_id}')
+
+ # Save screenshot next to singlefile with same naming pattern
+ screenshot_path = singlefile_path.with_suffix('.png')
+ screenshot_path.write_bytes(screenshot_bytes)
+ logger.info(f'Generated screenshot for Archive {archive_id}: {screenshot_path}')
+
+ # Update the Archive to include the new screenshot file
+ with get_db_session(commit=True) as session:
+ archive = Archive.find_by_id(archive_id, session=session)
+ archive.set_screenshot(screenshot_path)
+ session.flush()
+
+ return screenshot_path
diff --git a/modules/archive/models.py b/modules/archive/models.py
index 4f8fe647d..132289089 100644
--- a/modules/archive/models.py
+++ b/modules/archive/models.py
@@ -4,24 +4,23 @@
from typing import Iterable, List, Optional
import pytz
-from sqlalchemy import Column, Integer, String, ForeignKey, BigInteger
-from sqlalchemy.orm import relationship, Session, validates
-from sqlalchemy.orm.collections import InstrumentedList
+from sqlalchemy import Column, Integer, ForeignKey, BigInteger
+from sqlalchemy.orm import relationship, Session
from wrolpi import dates
+from wrolpi.collections import Collection
from wrolpi.common import ModelHelper, Base, logger, get_title_from_html, get_wrolpi_config, get_media_directory
from wrolpi.dates import now
from wrolpi.db import optional_session
from wrolpi.errors import UnknownArchive
from wrolpi.files.models import FileGroup
-from wrolpi.media_path import MediaPathType
from wrolpi.tags import TagFile
from wrolpi.vars import PYTEST
from .errors import InvalidArchive
logger = logger.getChild(__name__)
-__all__ = ['Archive', 'Domain']
+__all__ = ['Archive']
MATCH_URL = re.compile(r'^\s+?url:\s+?(http.*)', re.MULTILINE)
MATCH_DATE = re.compile(r'^\s+?saved date:\s+?(.*)', re.MULTILINE)
@@ -31,17 +30,33 @@ class Archive(Base, ModelHelper):
__tablename__ = 'archive'
id = Column(Integer, primary_key=True)
- domain_id = Column(Integer, ForeignKey('domains.id'))
- domain = relationship('Domain', primaryjoin='Archive.domain_id==Domain.id')
+ collection_id = Column(Integer, ForeignKey('collection.id', ondelete='CASCADE'))
+ collection: Collection = relationship('Collection', primaryjoin='Archive.collection_id==Collection.id')
file_group_id = Column(BigInteger, ForeignKey('file_group.id', ondelete='CASCADE'), unique=True, nullable=False)
file_group: FileGroup = relationship('FileGroup')
def __repr__(self):
- if self.domain:
+ domain_name = self.domain if self.collection else None
+ if domain_name:
return f''
+ f'domain={domain_name}>'
return f''
+ @property
+ def domain(self) -> str | None:
+ """
+ Get the domain name for this Archive.
+
+ This property provides backward compatibility by returning the Collection's name
+ (which is the domain string for domain collections).
+
+ Returns:
+ The domain string (e.g., "example.com") or None if no collection
+ """
+ if self.collection and self.collection.kind == 'domain':
+ return self.collection.name
+ return None
+
@staticmethod
@optional_session
def get_by_id(id_: int, session: Session = None) -> Optional['Archive']:
@@ -116,14 +131,14 @@ def delete(self):
self.file_group.delete()
session = Session.object_session(self)
+ collection = self.collection
session.delete(self)
- if self.domain:
- # Delete a domain if it has no Archives.
- try:
- next(i.id for i in self.domain.archives)
- except StopIteration:
- self.domain.delete()
+ if collection and collection.kind == 'domain':
+ # Delete a domain collection if it has no Archives.
+ remaining_archives = session.query(Archive).filter_by(collection_id=collection.id).count()
+ if remaining_archives == 0:
+ session.delete(collection)
@property
def history(self) -> Iterable[Base]:
@@ -187,7 +202,7 @@ def readability_json_path(self) -> Optional[pathlib.Path]:
return readability_json_file['path']
@property
- def screenshot_file(self) -> Optional[dict]:
+ def screenshot_file(self) -> dict | None:
files = self.file_group.my_files('image/')
for file in files:
return file
@@ -197,6 +212,35 @@ def screenshot_path(self) -> Optional[pathlib.Path]:
if screenshot_file := self.screenshot_file:
return screenshot_file['path']
+ def set_screenshot(self, screenshot_path: pathlib.Path) -> None:
+ """Set the screenshot for this Archive.
+
+ Args:
+ screenshot_path: Path to the screenshot file
+
+ Raises:
+ ValueError: If Archive already has a screenshot or path doesn't exist
+ """
+ # Check if screenshot already exists
+ if self.screenshot_path:
+ raise ValueError(f'Archive {self.id} already has a screenshot at {self.screenshot_path}')
+
+ # Validate path exists
+ if not screenshot_path.is_file():
+ raise ValueError(f'Screenshot path does not exist: {screenshot_path}')
+
+ # Add to FileGroup.files
+ self.file_group.append_files(screenshot_path)
+
+ # Update FileGroup.data for quick lookup (similar to ebook cover_path pattern)
+ # Create new dict and reassign (like append_files does) to ensure SQLAlchemy detects the change
+ data = dict(self.file_group.data) if self.file_group.data else {}
+ data['screenshot_path'] = str(screenshot_path)
+ self.file_group.data = data
+
+ # Validate the archive
+ self.validate()
+
@property
def info_json_file(self) -> Optional[dict]:
files = self.file_group.my_files()
@@ -270,16 +314,16 @@ def apply_singlefile_data(self):
logger.error(f'Could not get archive date from singlefile {path}', exc_info=e)
def apply_domain(self):
- """Get the domain from the URL."""
- from modules.archive.lib import get_or_create_domain
- domain = None
+ """Get the domain collection from the URL."""
+ from modules.archive.lib import get_or_create_domain_collection
+ collection = None
if self.file_group.url:
session = Session.object_session(self)
if not session:
raise ValueError('No session found!')
- domain = get_or_create_domain(session, self.file_group.url)
- # Clear domain if the URL is missing.
- self.domain_id = domain.id if domain else None
+ collection = get_or_create_domain_collection(session, self.file_group.url)
+ # Clear collection if the URL is missing.
+ self.collection_id = collection.id if collection else None
def apply_singlefile_title(self):
"""Get the title from the Singlefile, if it's missing."""
@@ -334,39 +378,23 @@ def location(self):
"""The location where this Archive can be viewed in the UI."""
return f'/archive/{self.id}'
-
-class Domain(Base, ModelHelper):
- __tablename__ = 'domains' # plural to avoid conflict
- id = Column(Integer, primary_key=True)
-
- domain = Column(String, nullable=False)
- directory = Column(MediaPathType)
-
- archives: InstrumentedList = relationship('Archive', primaryjoin='Archive.domain_id==Domain.id')
-
- def __repr__(self):
- return f''
-
- def delete(self):
- session = Session.object_session(self)
- session.execute('DELETE FROM archive WHERE domain_id=:id', dict(id=self.id))
- session.query(Domain).filter_by(id=self.id).delete()
-
- @validates('domain')
- def validate_domain(self, key, value: str):
- if not isinstance(value, str):
- raise ValueError('Domain must be a string')
- if len(value.split('.')) < 2:
- raise ValueError(f'Domain must contain at least one "." domain={repr(value)}')
- return value
-
@property
def download_directory(self) -> pathlib.Path:
+ """
+ Get the download directory for this Archive based on its domain and date.
+
+ This uses the configured archive_destination template with variables:
+ - domain: the domain name from the Collection
+ - year, month, day: current date
+
+ Returns:
+ Absolute path where new archives for this domain should be downloaded
+ """
archive_destination = get_wrolpi_config().archive_destination
now_ = now()
variables = dict(
- domain=self.domain,
+ domain=self.domain or 'unknown',
year=now_.year,
month=now_.month,
day=now_.day,
diff --git a/modules/archive/schema.py b/modules/archive/schema.py
index 07a5188db..f3cfc728c 100644
--- a/modules/archive/schema.py
+++ b/modules/archive/schema.py
@@ -15,18 +15,6 @@ class ArchiveDict:
title: str
-@dataclass
-class DomainDict:
- domain: str
- url_count: int
- size: int
-
-
-@dataclass
-class GetDomainsResponse:
- domains: List[DomainDict]
-
-
@dataclass
class ArchiveSearchRequest:
search_str: Optional[str] = None
diff --git a/modules/archive/test/test_api.py b/modules/archive/test/test_api.py
index 3fdb05415..dc233e55a 100644
--- a/modules/archive/test/test_api.py
+++ b/modules/archive/test/test_api.py
@@ -1,17 +1,16 @@
import json
from http import HTTPStatus
+from unittest.mock import patch
import pytest
from modules.archive import lib, Archive
-from wrolpi.files.models import FileGroup
-from wrolpi.test.common import skip_circleci, skip_macos
-
from wrolpi.common import get_relative_to_media_directory
+from wrolpi.test.common import skip_circleci, skip_macos
-def check_results(test_client, data, ids):
- request, response = test_client.post('/api/archive/search', content=json.dumps(data))
+async def check_results(async_client, data, ids):
+ request, response = await async_client.post('/api/archive/search', content=json.dumps(data))
if not response.status_code == HTTPStatus.OK:
raise AssertionError(str(response.json))
@@ -22,25 +21,27 @@ def check_results(test_client, data, ids):
f'{response.json["file_groups"][0]["id"]}..{response.json["file_groups"][-1]["id"]}'
-def test_archives_search_order(test_session, archive_directory, archive_factory, test_client):
+@pytest.mark.asyncio
+async def test_archives_search_order(test_session, archive_directory, archive_factory, async_client):
"""Search using all orders."""
archive_factory('example.com', 'https://example.com/one', 'my archive', 'foo bar qux')
for order_by in lib.ARCHIVE_ORDERS:
data = {'search_str': 'foo', 'order_by': order_by}
- request, response = test_client.post('/api/archive/search', content=json.dumps(data))
+ request, response = await async_client.post('/api/archive/search', content=json.dumps(data))
if not response.status_code == HTTPStatus.OK:
raise AssertionError(str(response.json))
data = {'search_str': None, 'order_by': order_by}
- request, response = test_client.post('/api/archive/search', content=json.dumps(data))
+ request, response = await async_client.post('/api/archive/search', content=json.dumps(data))
if not response.status_code == HTTPStatus.OK:
raise AssertionError(str(response.json))
-def test_archives_search(test_session, archive_directory, archive_factory, test_client):
+@pytest.mark.asyncio
+async def test_archives_search(test_session, archive_directory, archive_factory, async_client):
"""Archives can be searched by their title and their contents."""
# Search with no archives.
- check_results(test_client, {'search_str': 'foo'}, [])
+ await check_results(async_client, {'search_str': 'foo'}, [])
archive_factory('example.com', 'https://example.com/one', 'my archive', 'foo bar qux')
archive_factory('example.com', 'https://example.com/one', 'other archive', 'foo baz qux qux')
@@ -50,38 +51,38 @@ def test_archives_search(test_session, archive_directory, archive_factory, test_
# 1 and 2 contain "foo".
data = {'search_str': 'foo'}
- check_results(test_client, data, [2, 1])
+ await check_results(async_client, data, [2, 1])
# 2 and 3 contain "baz".
data = {'search_str': 'baz'}
- check_results(test_client, data, [3, 2])
+ await check_results(async_client, data, [3, 2])
# 1 contains "bar".
data = {'search_str': 'bar'}
- check_results(test_client, data, [1, ])
+ await check_results(async_client, data, [1, ])
# No archives contain "huzzah"
data = {'search_str': 'huzzah'}
- check_results(test_client, data, [])
+ await check_results(async_client, data, [])
# Only 3 contains "baz" and is in domain "example.org"
data = {'search_str': 'baz', 'domain': 'example.org'}
- check_results(test_client, data, [3, ])
+ await check_results(async_client, data, [3, ])
# 1's title contains "my", this is ignored by Postgres.
data = {'search_str': 'my'}
- check_results(test_client, data, [])
+ await check_results(async_client, data, [])
# 3's title contains "third".
data = {'search_str': 'third'}
- check_results(test_client, data, [3, ])
+ await check_results(async_client, data, [3, ])
# All contents contain "qux", but they contain different amounts. They are ordered by the amount.
data = {'search_str': 'qux'}
- check_results(test_client, data, [3, 2, 1])
+ await check_results(async_client, data, [3, 2, 1])
data = {'search_str': 'qux', 'order_by': 'bad order_by'}
- request, response = test_client.post('/api/archive/search', content=json.dumps(data))
+ request, response = await async_client.post('/api/archive/search', content=json.dumps(data))
assert response.status_code == HTTPStatus.BAD_REQUEST
@@ -105,7 +106,8 @@ async def test_search_archive_tags(test_session, async_client, archive_factory,
assert response.json['totals']['file_groups'] == 0
-def test_archives_search_headline(test_session, archive_directory, archive_factory, test_client):
+@pytest.mark.asyncio
+async def test_archives_search_headline(test_session, archive_directory, archive_factory, async_client):
"""Headlines can be requested."""
archive_factory('example.com', 'https://example.com/one', 'my archive', 'foo bar qux')
archive_factory('example.com', 'https://example.com/one', 'other archive', 'foo baz qux qux')
@@ -114,7 +116,7 @@ def test_archives_search_headline(test_session, archive_directory, archive_facto
test_session.commit()
content = dict(search_str='foo')
- request, response = test_client.post('/api/archive/search', content=json.dumps(content))
+ request, response = await async_client.post('/api/archive/search', content=json.dumps(content))
assert response.status_code == HTTPStatus.OK
# Headlines are only fetched if requested.
@@ -122,7 +124,7 @@ def test_archives_search_headline(test_session, archive_directory, archive_facto
assert response.json['file_groups'][1]['d_headline'] is None
content = dict(search_str='foo', headline=True)
- request, response = test_client.post('/api/archive/search', content=json.dumps(content))
+ request, response = await async_client.post('/api/archive/search', content=json.dumps(content))
assert response.status_code == HTTPStatus.OK
# Postgresql uses ... to highlight matching words.
@@ -130,25 +132,27 @@ def test_archives_search_headline(test_session, archive_directory, archive_facto
assert response.json['file_groups'][1]['d_headline'] == 'foo bar qux'
-def test_search_offset(test_session, archive_factory, test_client):
+@pytest.mark.asyncio
+async def test_search_offset(test_session, archive_factory, async_client):
"""Archive search can be offset."""
for i in range(500):
archive_factory('example.com', f'https://example.com/{i}', contents='foo bar')
test_session.commit()
data = {'search_str': None, 'offset': 0}
- check_results(test_client, data, list(range(500, 480, -1)))
+ await check_results(async_client, data, list(range(500, 480, -1)))
data = {'search_str': None, 'offset': 20}
- check_results(test_client, data, list(range(480, 460, -1)))
+ await check_results(async_client, data, list(range(480, 460, -1)))
data = {'search_str': None, 'offset': 100}
- check_results(test_client, data, list(range(400, 380, -1)))
+ await check_results(async_client, data, list(range(400, 380, -1)))
data = {'search_str': None, 'offset': 200}
- check_results(test_client, data, list(range(300, 280, -1)))
+ await check_results(async_client, data, list(range(300, 280, -1)))
data = {'search_str': None, 'offset': 500}
- check_results(test_client, data, [])
+ await check_results(async_client, data, [])
-def test_archives_search_no_query(test_session, archive_factory, test_client):
+@pytest.mark.asyncio
+async def test_archives_search_no_query(test_session, archive_factory, async_client):
"""Archive Search API endpoint does not require data in the body."""
# Add 100 random archives.
for _ in range(100):
@@ -156,83 +160,26 @@ def test_archives_search_no_query(test_session, archive_factory, test_client):
test_session.commit()
# All archives are returned when no `search_str` is passed.
- request, response = test_client.post('/api/archive/search', content='{}')
+ request, response = await async_client.post('/api/archive/search', content='{}')
assert response.status_code == HTTPStatus.OK, response.json
assert [i['id'] for i in response.json['file_groups']] == list(range(100, 80, -1))
assert response.json['totals']['file_groups'] == 100
# All archives are from "example.com".
data = dict(domain='example.com')
- request, response = test_client.post('/api/archive/search', content=json.dumps(data))
+ request, response = await async_client.post('/api/archive/search', content=json.dumps(data))
assert response.status_code == HTTPStatus.OK, response.json
assert [i['id'] for i in response.json['file_groups']] == list(range(100, 80, -1))
assert response.json['totals']['file_groups'] == 100
# No archives are from "example.org".
data = dict(domain='example.org')
- request, response = test_client.post('/api/archive/search', content=json.dumps(data))
+ request, response = await async_client.post('/api/archive/search', content=json.dumps(data))
assert response.status_code == HTTPStatus.OK, response.json
assert [i['id'] for i in response.json['file_groups']] == []
assert response.json['totals']['file_groups'] == 0
-def test_archive_and_domain_crud(test_session, test_client, archive_factory):
- """Getting an Archive returns it's File. Testing deleting Archives."""
- # Can get empty results.
- request, response = test_client.get(f'/api/archive/1')
- assert response.status_code == HTTPStatus.NOT_FOUND
- request, response = test_client.get(f'/api/archive/domains')
- assert response.status_code == HTTPStatus.OK
- assert response.json['domains'] == []
-
- archive1 = archive_factory(domain='example.com', url='https://example.com/1')
- archive2 = archive_factory(domain='example.com', url='https://example.com/1')
- test_session.commit()
-
- # Archive1 has Archive2 as history.
- request, response = test_client.get(f'/api/archive/{archive1.id}')
- assert response.status_code == HTTPStatus.OK
- assert response.json['file_group']['id'] == archive1.id
- assert response.json['history'][0]['id'] == archive2.id
-
- # Archive2 has Archive1 as history.
- request, response = test_client.get(f'/api/archive/{archive2.id}')
- assert response.status_code == HTTPStatus.OK
- assert archive2.id == response.json['file_group']['id']
- assert response.json['history'][0]['id'] == archive1.id
-
- # Only one domain.
- request, response = test_client.get(f'/api/archive/domains')
- assert response.status_code == HTTPStatus.OK
- assert response.json['domains'][0]['domain'] == 'example.com'
- assert response.json['domains'][0]['size'] == 254
-
- # Deleting works.
- request, response = test_client.delete(f'/api/archive/{archive1.id}')
- assert response.status_code == HTTPStatus.NO_CONTENT
-
- # Trying to delete again returns NOT_FOUND.
- request, response = test_client.delete(f'/api/archive/{archive1.id}')
- assert response.status_code == HTTPStatus.NOT_FOUND
-
- # Can't get deleted Archive.
- request, response = test_client.get(f'/api/archive/{archive1.id}')
- assert response.status_code == HTTPStatus.NOT_FOUND
-
- # Archive2 no longer has history.
- request, response = test_client.get(f'/api/archive/{archive2.id}')
- assert response.status_code == HTTPStatus.OK
- assert response.json['file_group']['id'] == archive2.id
- assert response.json['history'] == []
-
- # No Archives, no Domains.
- request, response = test_client.delete(f'/api/archive/{archive2.id}')
- assert response.status_code == HTTPStatus.NO_CONTENT
- request, response = test_client.get(f'/api/archive/domains')
- assert response.status_code == HTTPStatus.OK
- assert response.json['domains'] == []
-
-
@skip_macos
@skip_circleci
@pytest.mark.asyncio
@@ -355,7 +302,8 @@ async def test_archive_upload_file_tracking(test_session, async_client, archive_
# Assert screenshot properties exist
assert archive.screenshot_path is not None, 'Archive.screenshot_path should exist'
assert archive.screenshot_file is not None, 'Archive.screenshot_file should exist'
- assert archive.file_group.data.get('screenshot_path') == archive.screenshot_path, 'Archive screenshot should be in FileGroup.data'
+ assert archive.file_group.data.get(
+ 'screenshot_path') == archive.screenshot_path, 'Archive screenshot should be in FileGroup.data'
# Assert image is in FileGroup.files
file_paths = [f['path'] for f in archive.file_group.files]
@@ -364,3 +312,97 @@ async def test_archive_upload_file_tracking(test_session, async_client, archive_
# Assert FileGroup has correct number of files
assert len(archive.file_group.files) == file_count_before_image + 1, \
f'FileGroup should have {file_count_before_image + 1} files after adding image'
+
+
+@pytest.mark.asyncio
+async def test_archive_generate_screenshot(test_session, async_client, archive_factory, await_switches,
+ image_bytes_factory, wrol_mode_fixture):
+ """Test generating a screenshot for an Archive that doesn't have one."""
+ # Mock html_screenshot to avoid Selenium dependency
+ mock_screenshot_bytes = image_bytes_factory()
+
+ with patch('modules.archive.lib.html_screenshot', return_value=mock_screenshot_bytes):
+ # Test success case: Archive without screenshot
+ archive = archive_factory('example.com', 'https://example.com/test', 'Test Archive', screenshot=False)
+ test_session.commit()
+
+ # Verify archive has no screenshot
+ assert archive.screenshot_path is None, 'Archive should not have a screenshot yet'
+ assert archive.singlefile_path is not None, 'Archive should have a singlefile'
+
+ # Request screenshot generation
+ request, response = await async_client.post(f'/api/archive/{archive.id}/generate_screenshot')
+ assert response.status_code == HTTPStatus.OK
+ assert response.json['message'] == 'Screenshot generation queued'
+
+ # Wait for background processing
+ await await_switches()
+
+ # Verify screenshot was generated
+ test_session.expire_all()
+ archive = test_session.query(Archive).filter_by(id=archive.id).one()
+ assert archive.screenshot_path is not None, 'Screenshot should have been generated'
+ assert archive.screenshot_path.is_file(), 'Screenshot file should exist'
+ assert archive.screenshot_file is not None, 'Screenshot file should be tracked'
+
+ # Test error case: Archive not found
+ request, response = await async_client.post('/api/archive/99999/generate_screenshot')
+ assert response.status_code == HTTPStatus.NOT_FOUND
+ assert 'not found' in response.json['error'].lower()
+
+ # Test success case: Archive already has screenshot - should return OK and ensure tracking
+ original_screenshot_path = archive.screenshot_path
+ assert original_screenshot_path is not None, 'Archive should have a screenshot'
+ assert original_screenshot_path.is_file(), 'Screenshot file should exist'
+
+ # Simulate the bug: archive has screenshot file but data['screenshot_path'] is not set
+ # (This could happen if the archive was created before set_screenshot was implemented)
+ if archive.file_group.data and 'screenshot_path' in archive.file_group.data:
+ data = dict(archive.file_group.data)
+ del data['screenshot_path']
+ archive.file_group.data = data
+ test_session.commit()
+ test_session.expire_all()
+ archive = test_session.query(Archive).filter_by(id=archive.id).one()
+ # Verify data was cleared but file still exists
+ assert 'screenshot_path' not in (archive.file_group.data or {}), 'screenshot_path should not be in data yet'
+ assert archive.screenshot_path is not None, 'But screenshot file should still exist'
+
+ request, response = await async_client.post(f'/api/archive/{archive.id}/generate_screenshot')
+ assert response.status_code == HTTPStatus.OK
+ assert response.json['message'] == 'Screenshot generation queued'
+
+ # Wait for background processing
+ await await_switches()
+
+ # Verify screenshot is still there and properly tracked
+ test_session.expire_all()
+ archive = test_session.query(Archive).filter_by(id=archive.id).one()
+ assert archive.screenshot_path == original_screenshot_path, 'Screenshot path should be unchanged'
+ assert archive.screenshot_path.is_file(), 'Screenshot file should still exist'
+ assert archive.screenshot_file is not None, 'Screenshot file should still be tracked'
+ # IMPORTANT: Also verify FileGroup.data was updated (this catches the bug where only files was updated)
+ assert 'screenshot_path' in archive.file_group.data, \
+ 'screenshot_path should be in FileGroup.data after ensuring tracking'
+ assert str(archive.file_group.data['screenshot_path']) == str(original_screenshot_path), \
+ 'screenshot_path in data should match the file path'
+
+ # Test error case: Archive has no singlefile
+ archive_no_singlefile = archive_factory('example.com', 'https://example.com/no-singlefile', screenshot=False)
+ test_session.commit()
+ # Delete the singlefile to simulate missing file
+ if archive_no_singlefile.singlefile_path and archive_no_singlefile.singlefile_path.is_file():
+ archive_no_singlefile.singlefile_path.unlink()
+
+ request, response = await async_client.post(f'/api/archive/{archive_no_singlefile.id}/generate_screenshot')
+ assert response.status_code == HTTPStatus.BAD_REQUEST
+ assert 'no singlefile' in response.json['error'].lower()
+
+ # Test WROL mode case
+ await wrol_mode_fixture(True)
+ archive_wrol = archive_factory('example.com', 'https://example.com/wrol-test', screenshot=False)
+ test_session.commit()
+
+ request, response = await async_client.post(f'/api/archive/{archive_wrol.id}/generate_screenshot')
+ assert response.status_code == HTTPStatus.FORBIDDEN
+ await wrol_mode_fixture(False)
diff --git a/modules/archive/test/test_collection_id_bug.py b/modules/archive/test/test_collection_id_bug.py
new file mode 100644
index 000000000..5304e5c8b
--- /dev/null
+++ b/modules/archive/test/test_collection_id_bug.py
@@ -0,0 +1,139 @@
+"""Test for collection_id bug when modeling archives from discovered files.
+
+This test reproduces the bug where Archives created by model_archive()
+don't have collection_id set, causing a NOT NULL constraint violation.
+
+The bug happens because:
+1. The archive_factory sets collection_id explicitly (masking the bug)
+2. But model_archive() (used in production) doesn't set collection_id
+3. This causes NULL constraint violations when indexing real files
+"""
+import pathlib
+
+import pytest
+
+from modules.archive import model_archive
+from modules.archive.lib import archive_strftime
+from wrolpi.collections import Collection
+
+
+@pytest.mark.asyncio
+async def test_model_archive_sets_collection_id(async_client, test_session, test_directory, make_files_structure):
+ """
+ Test that model_archive() creates Archives with collection_id set.
+
+ This test uses model_archive() (the production code path) instead of
+ archive_factory, so it will detect the collection_id bug.
+ """
+ import datetime
+ import pytz
+
+ # Create archive files with a URL in the singlefile
+ domain = 'example.com'
+ url = f'https://{domain}/test-page'
+ archive_dir = test_directory / 'archives' / domain
+ archive_dir.mkdir(parents=True)
+
+ timestamp = archive_strftime(datetime.datetime(2000, 1, 1, 0, 0, 0).astimezone(pytz.UTC))
+ title = 'Test Page'
+
+ # Create a minimal singlefile with URL embedded (note trailing space after SingleFile)
+ singlefile_content = f'''
+
+{title}
+
+
+{title}
+Test content
+
+'''
+
+ files = make_files_structure({
+ str(archive_dir / f'{timestamp}_{title}.html'): singlefile_content.strip(),
+ str(archive_dir / f'{timestamp}_{title}.readability.json'): '{"title": "' + title + '"}',
+ })
+
+ from wrolpi.files.models import FileGroup
+
+ # Create FileGroup from the files (simulating what refresh does)
+ file_paths = [pathlib.Path(f) for f in files]
+ file_group = FileGroup.from_paths(test_session, *file_paths)
+
+ # Model the archive using the production code path
+ # This should create an Archive with collection_id set
+ archive = model_archive(file_group, session=test_session)
+
+ # The bug: collection_id is None, causing NOT NULL constraint violation
+ assert archive is not None, "Archive should be created"
+ assert archive.collection_id is not None, \
+ "Archive.collection_id should be set (this is the bug!)"
+
+ # Verify the collection was created
+ collection = test_session.query(Collection).filter_by(
+ name=domain,
+ kind='domain'
+ ).one_or_none()
+
+ assert collection is not None, "Domain collection should be created"
+ assert archive.collection_id == collection.id, \
+ "Archive should be linked to the domain collection"
+
+
+@pytest.mark.asyncio
+async def test_model_archive_extracts_url_from_singlefile(async_client, test_session, test_directory, make_files_structure):
+ """
+ Test that model_archive() can extract URL from singlefile when file_group.url is None.
+
+ This is the second part of the fix - if the URL isn't in the file_group yet,
+ we need to extract it from the singlefile content.
+ """
+ import datetime
+ import pytz
+
+ # Create archive files
+ domain = 'test.org'
+ url = f'https://{domain}/article'
+ archive_dir = test_directory / 'archives' / domain
+ archive_dir.mkdir(parents=True)
+
+ timestamp = archive_strftime(datetime.datetime(2000, 1, 1, 0, 0, 0).astimezone(pytz.UTC))
+
+ # Create singlefile with URL in saved-from comment
+ singlefile_content = f'''
+
+Test Article
+
+Content
+'''
+
+ files = make_files_structure({
+ str(archive_dir / f'{timestamp}_Test.html'): singlefile_content.strip(),
+ })
+
+ from wrolpi.files.models import FileGroup
+
+ # Create FileGroup from the files (simulating what refresh does)
+ file_paths = [pathlib.Path(f) for f in files]
+ file_group = FileGroup.from_paths(test_session, *file_paths)
+
+ # Model the archive - it should extract the URL from the singlefile
+ archive = model_archive(file_group, session=test_session)
+
+ assert archive is not None
+ assert archive.file_group.url == url, \
+ "URL should be extracted from singlefile content"
+ assert archive.collection_id is not None, \
+ "collection_id should be set after URL extraction"
+
+ # Verify collection was created with correct domain
+ collection = test_session.query(Collection).get(archive.collection_id)
+ assert collection.name == domain
+ assert collection.kind == 'domain'
diff --git a/modules/archive/test/test_domain_tagging.py b/modules/archive/test/test_domain_tagging.py
new file mode 100644
index 000000000..5d36b4b27
--- /dev/null
+++ b/modules/archive/test/test_domain_tagging.py
@@ -0,0 +1,313 @@
+"""Tests for domain collection tagging functionality."""
+
+import pytest
+
+from wrolpi.collections import Collection
+
+
+@pytest.mark.asyncio
+async def test_create_domain_collection_with_directory(test_session, test_directory):
+ """Domain collection can be created with a directory."""
+ domain_dir = test_directory / 'archives' / 'example.com'
+ domain_dir.mkdir(parents=True)
+
+ collection = Collection(
+ name='example.com',
+ kind='domain',
+ directory=domain_dir
+ )
+ test_session.add(collection)
+ test_session.commit()
+
+ assert collection.directory == domain_dir
+ assert collection.can_be_tagged is True
+
+
+@pytest.mark.asyncio
+async def test_create_unrestricted_domain_collection(test_session):
+ """Unrestricted domain collection (no directory) cannot be tagged."""
+ collection = Collection(
+ name='example.com',
+ kind='domain',
+ directory=None # Unrestricted
+ )
+ test_session.add(collection)
+ test_session.commit()
+
+ assert collection.directory is None
+ assert collection.can_be_tagged is False
+
+
+@pytest.mark.asyncio
+async def test_tag_domain_collection_with_directory(async_client, test_session, test_directory, tag_factory):
+ """Domain collection with directory can be tagged."""
+ domain_dir = test_directory / 'archives' / 'example.com'
+ domain_dir.mkdir(parents=True, exist_ok=True)
+
+ collection = Collection(
+ name='example.com',
+ kind='domain',
+ directory=domain_dir
+ )
+ test_session.add(collection)
+ tag = await tag_factory(name='News')
+ test_session.commit()
+
+ # Tag the collection
+ collection.set_tag('News')
+ test_session.commit()
+
+ assert collection.tag is not None
+ assert collection.tag.name == 'News'
+ assert collection.tag_id == tag.id
+
+
+@pytest.mark.asyncio
+async def test_cannot_tag_unrestricted_domain_collection(async_client, test_session, tag_factory):
+ """Domain collection without directory cannot be tagged."""
+ collection = Collection(
+ name='example.com',
+ kind='domain',
+ directory=None # Unrestricted
+ )
+ test_session.add(collection)
+ tag = await tag_factory(name='News')
+ test_session.commit()
+
+ assert collection.can_be_tagged is False
+
+ with pytest.raises(ValueError, match='Cannot tag domain collection.*without a directory'):
+ collection.set_tag('News')
+
+
+@pytest.mark.asyncio
+async def test_tag_domain_collection_moves_files(
+ test_session, archive_factory, tag_factory, test_directory, make_files_structure, archive_directory,
+ await_switches,
+):
+ """Tagging domain collection with directory moves archive files."""
+ # Create domain collection with directory
+ # Use archive_directory fixture to match where archive_factory creates files
+ domain_dir = archive_directory / 'test.com'
+ domain_dir.mkdir(parents=True, exist_ok=True)
+
+ collection = Collection(
+ name='test.com',
+ kind='domain',
+ directory=domain_dir
+ )
+ test_session.add(collection)
+ test_session.flush()
+
+ # Create archives in domain directory
+ archive1 = archive_factory(domain='test.com')
+ archive2 = archive_factory(domain='test.com')
+ archive1.collection = collection
+ archive2.collection = collection
+ test_session.commit()
+
+ # Get file paths before move
+ old_paths = []
+ for archive in [archive1, archive2]:
+ for path in archive.my_paths():
+ old_paths.append(path)
+
+ assert all(p.is_file() for p in old_paths), "Files should exist before move"
+
+ # Create tag and assign to collection
+ tag = await tag_factory(name='Tech')
+ collection.set_tag('Tech')
+ test_session.commit()
+
+ # Compute new directory and move collection
+ new_directory = collection.format_directory('Tech')
+ # Create the destination directory before moving
+ new_directory.mkdir(parents=True, exist_ok=True)
+ await collection.move_collection(new_directory, test_session)
+
+ # Verify files moved
+ assert new_directory.is_dir(), "New directory should exist"
+ assert 'Tech' in str(new_directory), "New directory should contain tag name"
+ assert 'test.com' in str(new_directory), "New directory should contain domain name"
+
+ # Old paths should no longer exist
+ for old_path in old_paths:
+ assert not old_path.is_file(), f"Old file should be moved: {old_path}"
+
+
+@pytest.mark.asyncio
+async def test_domain_config_with_directory_and_tag(test_session, test_directory, tag_factory, async_client):
+ """Domain config can include directory and tag_name."""
+ from modules.archive.lib import DomainsConfig
+
+ # Create config file
+ config_file = test_directory / 'domains.yaml'
+ config_file.write_text("""
+collections:
+ - name: "example.com"
+ kind: "domain"
+ description: "News from example.com"
+ directory: "archives/example.com"
+ tag_name: "News"
+""")
+
+ # Create tag first
+ tag = await tag_factory(name='News')
+ test_session.commit()
+
+ # Import config
+ domains_config = DomainsConfig()
+ domains_config.import_config(config_file)
+
+ # Verify collection created with directory and tag
+ collection = test_session.query(Collection).filter_by(
+ name='example.com',
+ kind='domain'
+ ).one()
+
+ assert collection.directory is not None
+ assert 'example.com' in str(collection.directory)
+ assert collection.tag is not None
+ assert collection.tag.name == 'News'
+ assert collection.can_be_tagged is True
+
+
+@pytest.mark.asyncio
+async def test_domain_config_warns_tag_without_directory(test_session, test_directory, tag_factory, caplog,
+ async_client):
+ """Config warns when tag_name provided without directory."""
+ from modules.archive.lib import DomainsConfig
+
+ # Create config file with tag but no directory
+ config_file = test_directory / 'domains.yaml'
+ config_file.write_text("""
+collections:
+ - name: "example.com"
+ kind: "domain"
+ tag_name: "News"
+""")
+
+ # Create tag first
+ tag = await tag_factory(name='News')
+ test_session.commit()
+
+ # Import config
+ domains_config = DomainsConfig()
+ domains_config.import_config(config_file)
+
+ # Verify collection created without tag
+ collection = test_session.query(Collection).filter_by(
+ name='example.com',
+ kind='domain'
+ ).one()
+
+ assert collection.directory is None
+ assert collection.tag is None # Tag ignored
+ assert collection.can_be_tagged is False
+
+ # Check warning was logged
+ assert "tags require a directory" in caplog.text
+
+
+@pytest.mark.asyncio
+async def test_domain_config_export_includes_directory_and_tag(test_session, test_directory, tag_factory, async_client):
+ """Exporting domain config includes directory and tag_name."""
+ from modules.archive.lib import DomainsConfig
+ from wrolpi.common import get_media_directory
+ import yaml
+
+ # Create domain collection with directory and tag
+ domain_dir = get_media_directory() / 'archives' / 'news.com'
+ domain_dir.mkdir(parents=True, exist_ok=True)
+
+ collection = Collection(
+ name='news.com',
+ kind='domain',
+ directory=domain_dir,
+ description='News website'
+ )
+ test_session.add(collection)
+ tag = await tag_factory(name='News')
+ collection.set_tag('News')
+ test_session.commit()
+
+ # Export config
+ config_file = test_directory / 'domains.yaml'
+ domains_config = DomainsConfig()
+ domains_config.dump_config(config_file, overwrite=True)
+
+ # Read and verify exported config
+ with open(config_file) as f:
+ exported = yaml.safe_load(f)
+
+ assert len(exported['collections']) == 1
+ domain_config = exported['collections'][0]
+ assert domain_config['name'] == 'news.com'
+ assert domain_config['kind'] == 'domain'
+ assert 'archives/news.com' in domain_config['directory']
+ assert domain_config['tag_name'] == 'News'
+ assert domain_config['description'] == 'News website'
+
+
+@pytest.mark.asyncio
+async def test_get_or_create_domain_collection_with_directory(test_session, test_directory):
+ """get_or_create_domain_collection can create collection with directory."""
+ from modules.archive.lib import get_or_create_domain_collection
+ from wrolpi.common import get_media_directory
+
+ domain_dir = get_media_directory() / 'archives' / 'test.com'
+ domain_dir.mkdir(parents=True, exist_ok=True)
+
+ # Create with directory
+ collection = get_or_create_domain_collection(
+ test_session,
+ 'https://test.com/article',
+ directory=domain_dir
+ )
+
+ assert collection.name == 'test.com'
+ assert collection.kind == 'domain'
+ assert collection.directory == domain_dir
+ assert collection.can_be_tagged is True
+
+
+@pytest.mark.asyncio
+async def test_get_archive_destination_with_directory(test_session, test_directory):
+ """get_archive_destination returns collection directory when set."""
+ from modules.archive.lib import get_archive_destination
+ from wrolpi.common import get_media_directory
+
+ domain_dir = get_media_directory() / 'archives' / 'tagged' / 'example.com'
+ domain_dir.mkdir(parents=True, exist_ok=True)
+
+ collection = Collection(
+ name='example.com',
+ kind='domain',
+ directory=domain_dir
+ )
+ test_session.add(collection)
+ test_session.commit()
+
+ destination = get_archive_destination(collection)
+
+ assert destination == domain_dir
+ assert destination.is_dir()
+
+
+@pytest.mark.asyncio
+async def test_get_archive_destination_unrestricted(test_session):
+ """get_archive_destination returns default path for unrestricted collection."""
+ from modules.archive.lib import get_archive_destination, get_archive_directory
+
+ collection = Collection(
+ name='example.com',
+ kind='domain',
+ directory=None # Unrestricted
+ )
+ test_session.add(collection)
+ test_session.commit()
+
+ destination = get_archive_destination(collection)
+ expected = get_archive_directory() / 'example.com'
+
+ assert destination == expected
diff --git a/modules/archive/test/test_lib.py b/modules/archive/test/test_lib.py
index b3e1fee85..090a175a8 100644
--- a/modules/archive/test/test_lib.py
+++ b/modules/archive/test/test_lib.py
@@ -10,10 +10,11 @@
from pytz import utc
from modules.archive import lib
-from modules.archive.lib import get_or_create_domain, get_new_archive_files, delete_archives, model_archive_result, \
- get_domains
-from modules.archive.models import Archive, Domain
+from modules.archive.lib import get_or_create_domain_collection, get_new_archive_files, delete_archives, \
+ model_archive_result, get_domains
+from modules.archive.models import Archive
from wrolpi.api_utils import CustomJSONEncoder
+from wrolpi.collections import Collection
from wrolpi.common import get_wrolpi_config
from wrolpi.db import get_db_session
from wrolpi.files import lib as files_lib
@@ -37,7 +38,7 @@ def make_fake_archive_result(readability=True, screenshot=True, title=True):
@pytest.mark.asyncio
-async def test_no_screenshot(test_directory, test_session):
+async def test_no_screenshot(async_client, test_directory, test_session):
singlefile, readability, screenshot = make_fake_archive_result(screenshot=False)
archive = await model_archive_result('https://example.com', singlefile, readability, screenshot)
assert isinstance(archive.singlefile_path, pathlib.Path)
@@ -47,7 +48,7 @@ async def test_no_screenshot(test_directory, test_session):
@pytest.mark.asyncio
-async def test_no_readability(test_directory, test_session):
+async def test_no_readability(async_client, test_directory, test_session):
singlefile, readability, screenshot = make_fake_archive_result(readability=False)
archive = await model_archive_result('https://example.com', singlefile, readability, screenshot)
assert isinstance(archive.singlefile_path, pathlib.Path)
@@ -58,7 +59,7 @@ async def test_no_readability(test_directory, test_session):
@pytest.mark.asyncio
-async def test_dict(test_session):
+async def test_dict(async_client, test_session, test_directory):
singlefile, readability, screenshot = make_fake_archive_result()
d = (await model_archive_result('https://example.com', singlefile, readability, screenshot)).dict()
assert isinstance(d, dict)
@@ -66,17 +67,18 @@ async def test_dict(test_session):
@pytest.mark.asyncio
-async def test_relationships(test_session, example_singlefile):
+async def test_relationships(async_client, test_session, example_singlefile):
with get_db_session(commit=True) as session:
url = 'https://wrolpi.org:443'
- domain = get_or_create_domain(session, url)
+ collection = get_or_create_domain_collection(session, url)
archive = Archive.from_paths(test_session, example_singlefile)
archive.url = url
- archive.domain_id = domain.id
+ archive.collection_id = collection.id
session.add(archive)
session.flush()
- assert archive.domain == domain
+ assert archive.collection == collection
+ assert archive.domain == 'wrolpi.org'
@pytest.mark.asyncio
@@ -119,8 +121,9 @@ async def reset_and_get_archive():
assert archive1.file_group.title is None
-def test_archive_refresh_deleted_archive(test_client, test_session, archive_directory, archive_factory):
- """Archives/Domains should be deleted when archive files are deleted."""
+@pytest.mark.asyncio
+async def test_archive_refresh_deleted_archive(async_client, test_session, archive_directory, archive_factory):
+ """Archives/domain collections should be deleted when archive files are deleted."""
archive1 = archive_factory('example.com', 'https://example.com/1')
archive2 = archive_factory('example.com', 'https://example.com/1')
archive3 = archive_factory('example.com')
@@ -133,29 +136,30 @@ def test_archive_refresh_deleted_archive(test_client, test_session, archive_dire
def check_counts(archive_count, domain_count):
assert test_session.query(Archive).count() == archive_count, 'Archive count does not match'
- assert test_session.query(Domain).count() == domain_count, 'Domain count does not match'
+ assert test_session.query(Collection).filter_by(
+ kind='domain').count() == domain_count, 'domain collection count does not match'
# All 5 archives are already in the DB.
check_counts(archive_count=5, domain_count=2)
- test_client.post('/api/files/refresh')
+ await async_client.post('/api/files/refresh')
check_counts(archive_count=5, domain_count=1)
# Delete archive2's files, it's the latest for 'https://example.com/1'
for path in archive2.my_paths():
path.unlink()
- test_client.post('/api/files/refresh')
+ await async_client.post('/api/files/refresh')
check_counts(archive_count=4, domain_count=1)
# Delete archive1's files, now the URL is empty.
for path in archive1.my_paths():
path.unlink()
- test_client.post('/api/files/refresh')
+ await async_client.post('/api/files/refresh')
check_counts(archive_count=3, domain_count=0)
# Delete archive3, now there is now example.com domain
for path in archive3.my_paths():
path.unlink()
- test_client.post('/api/files/refresh')
+ await async_client.post('/api/files/refresh')
check_counts(archive_count=2, domain_count=0)
# Delete all the rest of the archives
@@ -163,12 +167,12 @@ def check_counts(archive_count, domain_count):
path.unlink()
for path in archive5.my_paths():
path.unlink()
- test_client.post('/api/files/refresh')
+ await async_client.post('/api/files/refresh')
check_counts(archive_count=0, domain_count=0)
@pytest.mark.asyncio
-async def test_fills_contents_with_refresh(test_session, archive_factory, singlefile_contents_factory):
+async def test_fills_contents_with_refresh(async_client, test_session, archive_factory, singlefile_contents_factory):
"""Refreshing archives fills in any missing contents."""
archive1 = archive_factory('example.com', 'https://example.com/one')
archive2 = archive_factory('example.com', 'https://example.com/one')
@@ -212,7 +216,8 @@ async def test_fills_contents_with_refresh(test_session, archive_factory, single
assert archive3.file_group.title == 'last title' # from singlefile HTML
-def test_delete_archive(test_session, archive_factory):
+@pytest.mark.asyncio
+async def test_delete_archive(async_client, test_session, archive_factory):
"""Archives can be deleted."""
archive1 = archive_factory('example.com', 'https://example.com/1')
archive2 = archive_factory('example.com', 'https://example.com/1')
@@ -227,30 +232,35 @@ def test_delete_archive(test_session, archive_factory):
assert test_session.query(Archive).count() == 3
assert test_session.query(FileGroup).count() == 3
+ # Save paths before deletion (archives become detached after delete)
+ archive1_paths = list(archive1.my_paths())
+ archive2_paths = list(archive2.my_paths())
+ archive3_paths = list(archive3.my_paths())
+
# Delete the oldest.
delete_archives(archive1.id, archive3.id)
assert test_session.query(Archive).count() == 1
assert test_session.query(FileGroup).count() == 1
# Files were deleted.
- assert archive1.my_paths() and not any(i.is_file() for i in archive1.my_paths())
- assert archive3.my_paths() and not any(i.is_file() for i in archive3.my_paths())
+ assert archive1_paths and not any(i.is_file() for i in archive1_paths)
+ assert archive3_paths and not any(i.is_file() for i in archive3_paths)
# Archive2 is untouched
- assert archive2.my_paths() and all(i.is_file() for i in archive2.my_paths())
+ assert archive2_paths and all(i.is_file() for i in archive2_paths)
- # Delete the last archive. The Domain should also be deleted.
+ # Delete the last archive. The domain collection should also be deleted.
delete_archives(archive2.id)
assert test_session.query(Archive).count() == 0
- domain = test_session.query(Domain).one_or_none()
+ domain = test_session.query(Collection).filter_by(kind='domain').one_or_none()
assert domain is None
# All Files were deleted.
assert test_session.query(FileGroup).count() == 0
- assert archive2.my_paths() and not any(i.is_file() for i in archive2.my_paths())
+ assert archive2_paths and not any(i.is_file() for i in archive2_paths)
def test_get_domains(test_session, archive_factory):
"""
- `get_domains` gets only Domains with Archives.
+ `get_domains` gets only domain collections with Archives.
"""
archive1 = archive_factory('example.com')
archive2 = archive_factory('example.com')
@@ -305,7 +315,7 @@ async def test_new_archive(test_session, test_directory, fake_now):
fake_now(datetime(2000, 1, 2))
archive2 = await model_archive_result('https://example.com', singlefile, readability, screenshot)
- # Domain is reused.
+ # domain collection is reused.
assert archive1.domain == archive2.domain
@@ -572,7 +582,7 @@ async def test_archive_meta(async_client, test_session, make_files_structure):
@pytest.mark.asyncio
-async def test_refresh_archives_deleted_singlefile(test_session, make_files_structure, singlefile_contents_factory):
+async def test_refresh_archives_deleted_singlefile(async_client, test_session, make_files_structure, singlefile_contents_factory):
"""Removing a Singlefile file from a FileGroup makes that group no longer an Archive."""
singlefile, readability = make_files_structure({
'2022-09-04-16-20-11_The Title.html': singlefile_contents_factory(),
@@ -840,3 +850,218 @@ async def test_get_custom_archive_directory(async_client, test_directory, test_w
assert lib.get_archive_directory() == (test_directory / 'custom/archives')
assert lib.get_domain_directory('https://example.com') == (test_directory / 'custom/archives/example.com')
+
+
+@pytest.mark.asyncio
+async def test_detect_domain_directory_single_archive(async_client, test_directory, test_session, archive_factory):
+ """detect_domain_directory should detect directory when all archives are in same location."""
+ from modules.archive.lib import detect_domain_directory
+
+ # Create an archive using the factory (which creates a domain collection and auto-detects directory)
+ archive = archive_factory(domain='example.com', url='https://example.com/page1')
+ test_session.flush()
+
+ # Get the collection
+ collection = archive.collection
+ assert collection is not None
+ assert collection.name == 'example.com'
+
+ # Directory should have been auto-detected during collection creation
+ assert collection.directory is not None
+ assert 'example.com' in str(collection.directory)
+
+ # Verify detect_domain_directory also returns the same result
+ detected = detect_domain_directory(collection, test_session)
+ assert detected is not None
+ assert str(detected) == 'archive/example.com'
+
+
+@pytest.mark.asyncio
+async def test_detect_domain_directory_multiple_archives(async_client, test_directory, test_session, archive_factory):
+ """detect_domain_directory should detect common directory for multiple archives."""
+ from modules.archive.lib import detect_domain_directory
+
+ # Create multiple archives in the same domain
+ archive1 = archive_factory(domain='test.com', url='https://test.com/page1')
+ archive2 = archive_factory(domain='test.com', url='https://test.com/page2')
+ test_session.flush()
+
+ # Get the collection
+ collection = archive1.collection
+ assert collection is not None
+ assert collection.name == 'test.com'
+
+ # Detect directory
+ detected = detect_domain_directory(collection, test_session)
+ assert detected is not None
+ assert str(detected) == 'archive/test.com'
+
+
+def test_detect_domain_directory_no_archives(test_directory, test_session):
+ """detect_domain_directory should return None when collection has no archives."""
+ from modules.archive.lib import detect_domain_directory
+
+ # Create a domain collection without archives
+ collection = Collection(name='empty.com', kind='domain', directory=None)
+ test_session.add(collection)
+ test_session.flush()
+
+ # Detect directory - should return None
+ detected = detect_domain_directory(collection, test_session)
+ assert detected is None
+
+
+@pytest.mark.asyncio
+async def test_get_or_create_domain_collection_auto_detects_directory(async_client, test_directory, test_session, archive_factory):
+ """get_or_create_domain_collection should auto-detect directory for existing archives."""
+ from modules.archive.lib import get_or_create_domain_collection
+
+ # Create an archive first
+ archive = archive_factory(domain='autodetect.com', url='https://autodetect.com/page1')
+ test_session.commit()
+
+ # The collection should have been created with no directory initially
+ # Call get_or_create_domain_collection again - should auto-detect
+ collection = get_or_create_domain_collection(test_session, 'https://autodetect.com/page2')
+ assert collection.directory is not None
+ assert 'autodetect.com' in str(collection.directory)
+
+
+@pytest.mark.asyncio
+async def test_update_domain_directories(async_client, test_directory, test_session, archive_factory):
+ """update_domain_directories should fix existing collections that lost their directory."""
+ from modules.archive.lib import update_domain_directories
+
+ # Create an archive - this creates a collection WITH directory (auto-detected)
+ archive = archive_factory(domain='needsdir.com', url='https://needsdir.com/page1')
+ test_session.commit()
+
+ # Manually clear the directory to simulate legacy data
+ collection = archive.collection
+ collection.directory = None
+ test_session.commit()
+
+ # Verify collection has no directory
+ assert collection.directory is None
+
+ # Run update
+ count = update_domain_directories(test_session)
+ assert count == 1
+
+ # Verify directory was set
+ test_session.expire(collection)
+ assert collection.directory is not None
+ assert 'needsdir.com' in str(collection.directory)
+
+
+def test_collection_unique_name_kind_constraint(test_directory, test_session, archive_factory):
+ """Collections should have unique (name, kind) combinations.
+
+ This prevents duplicate domain collections with the same name.
+ """
+ from sqlalchemy.exc import IntegrityError
+
+ # Create an archive which creates a domain collection
+ archive = archive_factory(domain='uniquetest.com')
+ test_session.commit()
+
+ # Get the original collection
+ original_collection = archive.collection
+ assert original_collection is not None
+
+ # Attempting to create a duplicate domain collection should fail
+ duplicate_collection = Collection(name='uniquetest.com', kind='domain', directory=None)
+ test_session.add(duplicate_collection)
+
+ with pytest.raises(IntegrityError) as exc_info:
+ test_session.commit()
+
+ # Verify it's the unique constraint violation
+ assert 'uq_collection_name_kind' in str(exc_info.value)
+ test_session.rollback()
+
+
+def test_search_archives_by_domain(test_directory, test_session, archive_factory):
+ """search_archives should correctly filter by domain collection name."""
+ from modules.archive.lib import search_archives
+
+ # Create archives in different domains
+ archive1 = archive_factory(domain='searchtest.com')
+ archive2 = archive_factory(domain='searchtest.com')
+ archive3 = archive_factory(domain='other.com')
+ test_session.commit()
+
+ # Search for archives in searchtest.com
+ file_groups, total = search_archives(
+ search_str=None,
+ domain='searchtest.com',
+ limit=10,
+ offset=0,
+ order=None,
+ tag_names=None
+ )
+
+ # Should return only archives from searchtest.com
+ assert total == 2
+ assert len(file_groups) == 2
+
+ # Search for other.com
+ file_groups, total = search_archives(
+ search_str=None,
+ domain='other.com',
+ limit=10,
+ offset=0,
+ order=None,
+ tag_names=None
+ )
+
+ assert total == 1
+ assert len(file_groups) == 1
+
+
+def test_link_domain_and_downloads(test_session, test_download_manager):
+ """Test that downloads are linked to domain collections."""
+ from wrolpi.downloader import Download
+ from modules.archive.lib import link_domain_and_downloads
+
+ # Create a domain collection with a directory
+ collection = Collection(name='example.com', kind='domain', directory='archive/example.com')
+ test_session.add(collection)
+ test_session.commit()
+
+ # 1. Download with matching destination directory (exact match)
+ download1 = Download(
+ url='https://other.com/rss',
+ downloader='rss',
+ sub_downloader='archive',
+ frequency=86400,
+ settings={'destination': 'archive/example.com'}
+ )
+ # 2. Download with destination in subdirectory
+ download2 = Download(
+ url='https://other.com/rss2',
+ downloader='rss',
+ sub_downloader='archive',
+ frequency=86400,
+ settings={'destination': 'archive/example.com/2025/01'}
+ )
+ # 3. RSS download with archive sub_downloader (matches by URL domain)
+ download3 = Download(
+ url='https://example.com/feed.xml',
+ downloader='rss',
+ sub_downloader='archive',
+ frequency=86400
+ )
+ # 4. Download without frequency (should NOT be linked)
+ download4 = Download(url='https://example.com/once', downloader='archive')
+ test_session.add_all([download1, download2, download3, download4])
+ test_session.commit()
+
+ assert not any(d.collection_id for d in [download1, download2, download3, download4])
+
+ link_domain_and_downloads(test_session)
+
+ assert download1.collection_id == collection.id # matched by directory (exact)
+ assert download2.collection_id == collection.id # matched by subdirectory
+ assert download3.collection_id == collection.id # matched by URL domain
+ assert download4.collection_id is None # no frequency = no link
diff --git a/modules/archive/test/test_models.py b/modules/archive/test/test_models.py
index c24ee2c0b..7037a7a6e 100644
--- a/modules/archive/test/test_models.py
+++ b/modules/archive/test/test_models.py
@@ -1,8 +1,9 @@
+import pathlib
from datetime import datetime
import pytest
-from modules.archive import Domain
+from modules.archive import Archive
from wrolpi.common import get_wrolpi_config
@@ -12,16 +13,87 @@ async def test_archive_download_destination(async_client, test_session, test_dir
wrolpi_config = get_wrolpi_config()
- archive_factory(domain='wrolpi.org')
- domain = test_session.query(Domain).one()
+ archive = archive_factory(domain='wrolpi.org')
+ test_session.commit()
# Test the default download directory.
- assert str(domain.download_directory) == str(test_directory / 'archive/wrolpi.org')
+ assert str(archive.download_directory) == str(test_directory / 'archive/wrolpi.org')
# Year of download is supported.
wrolpi_config.archive_destination = 'archives/%(domain)s/%(year)s'
- assert str(domain.download_directory) == str(test_directory / 'archives/wrolpi.org/2000')
+ assert str(archive.download_directory) == str(test_directory / 'archives/wrolpi.org/2000')
# More download date is supported.
wrolpi_config.archive_destination = 'archive/%(domain)s/%(year)s/%(month)s/%(day)s'
- assert str(domain.download_directory) == str(test_directory / 'archive/wrolpi.org/2000/1/2')
+ assert str(archive.download_directory) == str(test_directory / 'archive/wrolpi.org/2000/1/2')
+
+
+@pytest.mark.asyncio
+async def test_set_screenshot(async_client, test_session, archive_factory, image_bytes_factory):
+ """Test Archive.set_screenshot() method with various scenarios."""
+
+ # ========== Test 1: Successfully set screenshot on archive without one ==========
+ archive = archive_factory('example.com', 'https://example.com/test', 'Test Archive', screenshot=False)
+ test_session.commit()
+
+ # Verify no screenshot initially
+ assert archive.screenshot_path is None
+
+ # Create a screenshot file
+ screenshot_path = archive.singlefile_path.parent / 'test_screenshot.png'
+ screenshot_path.write_bytes(image_bytes_factory())
+
+ # Set the screenshot
+ archive.set_screenshot(screenshot_path)
+
+ # Verify data was set before commit
+ assert archive.file_group.data['screenshot_path'] == str(screenshot_path), \
+ f"Data not set correctly before commit: {archive.file_group.data.get('screenshot_path')}"
+
+ test_session.commit()
+
+ # Refresh the archive to get latest data from DB
+ test_session.expire_all()
+ archive = test_session.query(Archive).filter_by(id=archive.id).one()
+
+ # Verify screenshot was set
+ assert archive.screenshot_path == screenshot_path
+ assert archive.screenshot_file is not None
+ # Verify FileGroup.data was persisted to DB (FancyJSON converts strings back to Path objects)
+ assert 'screenshot_path' in archive.file_group.data
+ assert str(archive.file_group.data['screenshot_path']) == str(screenshot_path)
+
+ # Verify file is tracked in FileGroup.files
+ screenshot_files = archive.file_group.my_files('image/')
+ assert len(screenshot_files) > 0
+ assert screenshot_files[0]['path'] == screenshot_path
+
+ # ========== Test 2: Error when trying to set screenshot on archive that already has one ==========
+ original_screenshot = archive.screenshot_path
+ assert original_screenshot is not None
+
+ # Try to set another screenshot
+ new_screenshot_path = archive.singlefile_path.parent / 'another_screenshot.png'
+ new_screenshot_path.write_bytes(b'fake image data')
+
+ # Should raise ValueError
+ with pytest.raises(ValueError, match='already has a screenshot'):
+ archive.set_screenshot(new_screenshot_path)
+
+ # Verify original screenshot is unchanged
+ assert archive.screenshot_path == original_screenshot
+
+ # ========== Test 3: Error when trying to set non-existent screenshot file ==========
+ archive_no_screenshot = archive_factory('example.com', 'https://example.com/test2', 'Test Archive 2',
+ screenshot=False)
+ test_session.commit()
+
+ # Try to set a screenshot that doesn't exist
+ nonexistent_path = pathlib.Path('/tmp/does_not_exist.png')
+
+ # Should raise ValueError
+ with pytest.raises(ValueError, match='does not exist'):
+ archive_no_screenshot.set_screenshot(nonexistent_path)
+
+ # Verify no screenshot was set
+ assert archive_no_screenshot.screenshot_path is None
diff --git a/modules/inventory/test/test_api.py b/modules/inventory/test/test_api.py
index 651f4dff7..db4ea1209 100644
--- a/modules/inventory/test/test_api.py
+++ b/modules/inventory/test/test_api.py
@@ -7,7 +7,8 @@
from modules.inventory.inventory import save_item
-def test_delete_items(test_session, init_test_inventory, test_client):
+@pytest.mark.asyncio
+async def test_delete_items(test_session, init_test_inventory, async_client):
"""
Multiple Items can be deleted in a single request.
"""
@@ -24,14 +25,15 @@ def test_delete_items(test_session, init_test_inventory, test_client):
save_item(init_test_inventory.id, item)
items = test_session.query(Item).filter(Item.brand != None).all()
- request, response = test_client.delete(f'/api/inventory/item/{items[0].id},{items[1].id}')
+ request, response = await async_client.delete(f'/api/inventory/item/{items[0].id},{items[1].id}')
assert response.status_code == HTTPStatus.NO_CONTENT, response.status_code
# Refresh all items. Verify they were deleted.
[test_session.refresh(i) for i in items]
assert all([i.deleted_at for i in items])
-def test_delete_item(test_session, init_test_inventory, test_client):
+@pytest.mark.asyncio
+async def test_delete_item(test_session, init_test_inventory, async_client):
"""
An item can be deleted by itself.
"""
@@ -48,22 +50,23 @@ def test_delete_item(test_session, init_test_inventory, test_client):
item = test_session.query(Item).filter(Item.brand != None).one()
assert not item.deleted_at
- request, response = test_client.delete(f'/api/inventory/item/{item.id}')
+ request, response = await async_client.delete(f'/api/inventory/item/{item.id}')
assert response.status_code == HTTPStatus.NO_CONTENT, response.status_code
test_session.refresh(item)
assert item.deleted_at
-def test_item_api(test_session, init_test_inventory, test_client):
+@pytest.mark.asyncio
+async def test_item_api(test_session, init_test_inventory, async_client):
"""An Item can be added to an Inventory."""
inventory_id = init_test_inventory.id
item = {'brand': 'Salty', 'name': 'Salt', 'item_size': '1', 'unit': 'lbs', 'count': '5',
'category': 'cooking ingredients', 'subcategory': 'salt', 'expiration_date': None}
- request, response = test_client.post(f'/api/inventory/{inventory_id}/item', content=json.dumps(item))
+ request, response = await async_client.post(f'/api/inventory/{inventory_id}/item', content=json.dumps(item))
assert response.status_code == HTTPStatus.NO_CONTENT, response.status_code
# Get the Item we just created.
- request, response = test_client.get(f'/api/inventory/{inventory_id}/item')
+ request, response = await async_client.get(f'/api/inventory/{inventory_id}/item')
assert response.status_code == HTTPStatus.OK
# assert item == response.json['items'][0]
item = response.json['items'][0]
@@ -71,7 +74,7 @@ def test_item_api(test_session, init_test_inventory, test_client):
expiration_dates = ('1969-12-31T20:25:45.123450', None)
for expiration_date in expiration_dates:
item['expiration_date'] = expiration_date
- request, response = test_client.put(f'/api/inventory/item/{item["id"]}', content=json.dumps(item))
+ request, response = await async_client.put(f'/api/inventory/item/{item["id"]}', content=json.dumps(item))
assert response.status_code == HTTPStatus.NO_CONTENT, response.json
diff --git a/modules/map/test/test_api.py b/modules/map/test/test_api.py
index 2bdee3843..c5b357e14 100644
--- a/modules/map/test/test_api.py
+++ b/modules/map/test/test_api.py
@@ -3,10 +3,11 @@
from itertools import zip_longest
import mock
+import pytest
-def assert_map_file_status(test_client, expected, status_code=HTTPStatus.OK):
- request, response = test_client.get('/api/map/files')
+async def assert_map_file_status(async_client, expected, status_code=HTTPStatus.OK):
+ request, response = await async_client.get('/api/map/files')
assert response.status_code == status_code
for pbf, expected_ in zip_longest(response.json['files'], expected):
imported = 'to be imported' if expected_['imported'] else 'to NOT be imported'
@@ -16,9 +17,10 @@ def assert_map_file_status(test_client, expected, status_code=HTTPStatus.OK):
assert response.json['pending'] is None
+@pytest.mark.asyncio
@mock.patch('modules.map.lib.is_pbf_file', lambda i: True)
@mock.patch('modules.map.lib.is_dump_file', lambda i: True)
-def test_status_and_import(test_client, test_session, make_files_structure, mock_run_command):
+async def test_status_and_import(async_client, test_session, make_files_structure, mock_run_command):
"""PBF files can be imported, and the status of the import can be monitored."""
pbf1, pbf2, dump1 = make_files_structure([
'map/pbf/some-country.osm.pbf',
@@ -32,10 +34,10 @@ def test_status_and_import(test_client, test_session, make_files_structure, mock
{'imported': False, 'path': 'map/pbf/other-country.osm.pbf'},
{'imported': False, 'path': 'map/pbf/some-country.osm.pbf'},
]
- assert_map_file_status(test_client, expected)
+ await assert_map_file_status(async_client, expected)
body = {'files': [str(pbf1), str(pbf2)]}
- request, response = test_client.post('/api/map/import', content=json.dumps(body))
+ request, response = await async_client.post('/api/map/import', content=json.dumps(body))
assert response.status_code == HTTPStatus.NO_CONTENT
# Can import multiple pbf files.
@@ -44,10 +46,10 @@ def test_status_and_import(test_client, test_session, make_files_structure, mock
{'imported': True, 'path': 'map/pbf/other-country.osm.pbf'},
{'imported': True, 'path': 'map/pbf/some-country.osm.pbf'},
]
- assert_map_file_status(test_client, expected)
+ await assert_map_file_status(async_client, expected)
body = {'files': [str(pbf1), ], }
- request, response = test_client.post('/api/map/import', content=json.dumps(body))
+ request, response = await async_client.post('/api/map/import', content=json.dumps(body))
assert response.status_code == HTTPStatus.NO_CONTENT
# Importing one pbf resets the import statuses.
@@ -56,10 +58,10 @@ def test_status_and_import(test_client, test_session, make_files_structure, mock
{'imported': False, 'path': 'map/pbf/other-country.osm.pbf'},
{'imported': True, 'path': 'map/pbf/some-country.osm.pbf'},
]
- assert_map_file_status(test_client, expected)
+ await assert_map_file_status(async_client, expected)
body = {'files': [str(dump1), ]}
- request, response = test_client.post('/api/map/import', content=json.dumps(body))
+ request, response = await async_client.post('/api/map/import', content=json.dumps(body))
assert response.status_code == HTTPStatus.NO_CONTENT
# Importing a dump also resets statuses.
@@ -68,12 +70,13 @@ def test_status_and_import(test_client, test_session, make_files_structure, mock
{'imported': False, 'path': 'map/pbf/other-country.osm.pbf'},
{'imported': True, 'path': 'map/pbf/some-country.osm.pbf'},
]
- assert_map_file_status(test_client, expected)
+ await assert_map_file_status(async_client, expected)
+@pytest.mark.asyncio
@mock.patch('modules.map.lib.is_pbf_file', lambda i: True)
@mock.patch('modules.map.lib.is_dump_file', lambda i: True)
-def test_multiple_import(test_client, test_session, test_directory, make_files_structure, mock_run_command):
+async def test_multiple_import(async_client, test_session, test_directory, make_files_structure, mock_run_command):
"""Multiple PBFs can be imported. Importing a second time overwrites the previous imports."""
pbf1, pbf2, pbf3, dump = make_files_structure([
'map/pbf/country1.osm.pbf',
@@ -83,7 +86,7 @@ def test_multiple_import(test_client, test_session, test_directory, make_files_s
])
body = {'files': [str(pbf2), str(pbf1)]}
- request, response = test_client.post('/api/map/import', content=json.dumps(body))
+ request, response = await async_client.post('/api/map/import', content=json.dumps(body))
assert response.status_code == HTTPStatus.NO_CONTENT
# Two PBFs can be merged and imported.
@@ -93,10 +96,10 @@ def test_multiple_import(test_client, test_session, test_directory, make_files_s
{'imported': True, 'path': 'map/pbf/country2.osm.pbf'},
{'imported': False, 'path': 'map/pbf/country3.osm.pbf'},
]
- assert_map_file_status(test_client, expected)
+ await assert_map_file_status(async_client, expected)
body = {'files': [str(pbf3), ]}
- request, response = test_client.post('/api/map/import', content=json.dumps(body))
+ request, response = await async_client.post('/api/map/import', content=json.dumps(body))
assert response.status_code == HTTPStatus.NO_CONTENT
# New imports overwrite old imports.
@@ -106,11 +109,11 @@ def test_multiple_import(test_client, test_session, test_directory, make_files_s
{'imported': False, 'path': 'map/pbf/country2.osm.pbf'},
{'imported': True, 'path': 'map/pbf/country3.osm.pbf'},
]
- assert_map_file_status(test_client, expected)
+ await assert_map_file_status(async_client, expected)
# Can't mix PBF and dump.
body = {'files': [str(dump), ]}
- request, response = test_client.post('/api/map/import', content=json.dumps(body))
+ request, response = await async_client.post('/api/map/import', content=json.dumps(body))
assert response.status_code == HTTPStatus.NO_CONTENT
# Dump does not overwrite.
@@ -120,12 +123,13 @@ def test_multiple_import(test_client, test_session, test_directory, make_files_s
{'imported': False, 'path': 'map/pbf/country2.osm.pbf'},
{'imported': True, 'path': 'map/pbf/country3.osm.pbf'},
]
- assert_map_file_status(test_client, expected)
+ await assert_map_file_status(async_client, expected)
-def test_empty_import(test_client, test_session, test_directory):
+@pytest.mark.asyncio
+async def test_empty_import(async_client, test_session, test_directory):
"""Some files must be requested."""
body = {'files': []}
- request, response = test_client.post('/api/map/import', content=json.dumps(body))
+ request, response = await async_client.post('/api/map/import', content=json.dumps(body))
assert response.status_code == HTTPStatus.BAD_REQUEST
assert 'validate' in response.json['message']
diff --git a/modules/otp/test/test_api.py b/modules/otp/test/test_api.py
index e00225be4..0c205345f 100644
--- a/modules/otp/test/test_api.py
+++ b/modules/otp/test/test_api.py
@@ -1,10 +1,13 @@
import json
from http import HTTPStatus
+import pytest
-def test_encrypt(test_client):
+
+@pytest.mark.asyncio
+async def test_encrypt(async_client):
body = dict(otp='asdf', plaintext='asdf')
- request, response = test_client.post('/api/otp/encrypt_otp', content=json.dumps(body))
+ request, response = await async_client.post('/api/otp/encrypt_otp', content=json.dumps(body))
assert response.status_code == HTTPStatus.OK
assert response.json == dict(
ciphertext='AAGK',
@@ -13,9 +16,10 @@ def test_encrypt(test_client):
)
-def test_decrypt(test_client):
+@pytest.mark.asyncio
+async def test_decrypt(async_client):
body = dict(otp='asdf', ciphertext='aagk')
- request, response = test_client.post('/api/otp/decrypt_otp', content=json.dumps(body))
+ request, response = await async_client.post('/api/otp/decrypt_otp', content=json.dumps(body))
assert response.status_code == HTTPStatus.OK
assert response.json == dict(
ciphertext='AAGK',
@@ -24,6 +28,7 @@ def test_decrypt(test_client):
)
-def test_get_html(test_client):
- request, response = test_client.get('/api/otp/html')
+@pytest.mark.asyncio
+async def test_get_html(async_client):
+ request, response = await async_client.get('/api/otp/html')
assert response.status_code == HTTPStatus.OK
diff --git a/modules/videos/__init__.py b/modules/videos/__init__.py
index 534bf934f..e1c6de422 100644
--- a/modules/videos/__init__.py
+++ b/modules/videos/__init__.py
@@ -75,18 +75,19 @@ def video_cleanup():
with get_db_curs(commit=True) as curs:
# Delete all Videos if the FileModel no longer contains a video.
curs.execute('''
- WITH deleted AS
- (UPDATE file_group SET model=null WHERE model='video' AND mimetype NOT LIKE 'video/%' RETURNING id)
- DELETE FROM video WHERE file_group_id = ANY(select id from deleted)
- ''')
+ WITH deleted AS
+ (UPDATE file_group SET model = null WHERE model = 'video' AND mimetype NOT LIKE 'video/%' RETURNING id)
+ DELETE
+ FROM video
+ WHERE file_group_id = ANY (select id from deleted)
+ ''')
# Claim all Videos in a Channel's directory for that Channel. But, only if they have not yet been claimed.
curs.execute('''
- UPDATE video v
- SET
- channel_id = c.id
- FROM channel c
- LEFT JOIN file_group fg ON fg.primary_path LIKE c.directory || '/%'::VARCHAR
- WHERE
- v.channel_id IS NULL
- AND fg.id = v.file_group_id
- ''')
+ UPDATE video v
+ SET channel_id = c.id
+ FROM channel c
+ INNER JOIN collection col ON col.id = c.collection_id
+ LEFT JOIN file_group fg ON fg.primary_path LIKE col.directory || '/%'::VARCHAR
+ WHERE v.channel_id IS NULL
+ AND fg.id = v.file_group_id
+ ''')
diff --git a/modules/videos/channel/lib.py b/modules/videos/channel/lib.py
index 36d52f8c0..fad3abe84 100644
--- a/modules/videos/channel/lib.py
+++ b/modules/videos/channel/lib.py
@@ -1,16 +1,18 @@
import pathlib
from pathlib import Path
-from typing import List, Dict, Union
+from typing import List, Union
from sqlalchemy import or_, func, desc, asc
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session, joinedload
from wrolpi import flags
+from wrolpi.collections import Collection
from wrolpi.common import logger, \
get_media_directory, wrol_mode_check, background_task
from wrolpi.db import get_db_curs, optional_session, get_db_session
from wrolpi.downloader import save_downloads_config, download_manager, Download
from wrolpi.errors import APIError, ValidationError, RefreshConflict
+from wrolpi.tags import Tag
from wrolpi.vars import PYTEST
from .. import schema
from ..common import check_for_channel_conflicts
@@ -30,15 +32,16 @@ async def get_minimal_channels() -> List[dict]:
# one that will consume the most resources.
stmt = '''
SELECT c.id,
- c.name AS "name",
- c.directory,
+ col.name AS "name",
+ col.directory,
c.url,
- t.name AS "tag_name",
+ t.name AS "tag_name",
c.video_count,
c.total_size,
c.minimum_frequency
FROM channel AS c
- LEFT JOIN tag t ON t.id = c.tag_id
+ INNER JOIN collection col ON col.id = c.collection_id
+ LEFT JOIN tag t ON t.id = col.tag_id
'''
curs.execute(stmt)
logger.debug(stmt)
@@ -60,19 +63,22 @@ def get_channel(session: Session, *, channel_id: int = None, source_id: str = No
"""
channel: COD = None # noqa
# Try to find the channel by the most reliable methods first.
+ # Always eagerly load Collection to ensure name/directory/tag_id are available
if channel_id:
- channel = session.query(Channel).filter_by(id=channel_id).one_or_none()
+ channel = session.query(Channel).options(joinedload(Channel.collection)).filter_by(id=channel_id).one_or_none()
if not channel and source_id:
- channel = session.query(Channel).filter_by(source_id=source_id).one_or_none()
+ channel = session.query(Channel).options(joinedload(Channel.collection)).filter_by(
+ source_id=source_id).one_or_none()
if not channel and url:
- channel = session.query(Channel).filter_by(url=url).one_or_none()
+ channel = session.query(Channel).options(joinedload(Channel.collection)).filter_by(url=url).one_or_none()
if not channel and directory:
directory = Path(directory)
if not directory.is_absolute():
directory = get_media_directory() / directory
- channel = session.query(Channel).filter_by(directory=directory).one_or_none()
+ channel = session.query(Channel).join(Collection).filter(
+ Collection.directory == str(directory)).one_or_none()
if not channel and name:
- channel = session.query(Channel).filter_by(name=name).one_or_none()
+ channel = session.query(Channel).join(Collection).filter(Collection.name == name).one_or_none()
if not channel:
raise UnknownChannel(f'No channel matches {channel_id=} {source_id=} {url=} {directory=}')
@@ -147,7 +153,17 @@ def create_channel(session: Session, data: schema.ChannelPostRequest, return_dic
directory = get_media_directory() / directory if not directory.is_absolute() else directory
directory.mkdir(parents=True, exist_ok=True)
- channel = Channel()
+ # Create Collection first
+ collection = Collection(
+ name=data.name,
+ kind='channel',
+ directory=directory,
+ )
+ session.add(collection)
+ session.flush([collection])
+
+ # Create Channel linked to Collection
+ channel = Channel(collection_id=collection.id)
session.add(channel)
session.flush([channel])
# Apply the changes now that we've OK'd them
@@ -179,22 +195,24 @@ async def search_channels_by_name(name: str, limit: int = 5, session: Session =
name_no_spaces = ''.join(name.split(' '))
if order_by_video_count:
stmt = session.query(Channel, func.count(Video.id).label('video_count')) \
+ .join(Collection) \
.filter(or_(
- Channel.name.ilike(f'%{name}%'),
- Channel.name.ilike(f'%{name_no_spaces}%'),
+ Collection.name.ilike(f'%{name}%'),
+ Collection.name.ilike(f'%{name_no_spaces}%'),
)) \
.outerjoin(Video, Video.channel_id == Channel.id) \
- .group_by(Channel.id, Channel.name) \
- .order_by(desc('video_count'), asc(Channel.name)) \
+ .group_by(Channel.id, Collection.id, Collection.name) \
+ .order_by(desc('video_count'), asc(Collection.name)) \
.limit(limit)
channels = [i[0] for i in stmt]
else:
stmt = session.query(Channel) \
+ .join(Collection) \
.filter(or_(
- Channel.name.ilike(f'%{name}%'),
- Channel.name.ilike(f'%{name_no_spaces}%'),
+ Collection.name.ilike(f'%{name}%'),
+ Collection.name.ilike(f'%{name_no_spaces}%'),
)) \
- .order_by(asc(Channel.name)) \
+ .order_by(asc(Collection.name)) \
.limit(limit)
channels = stmt.all()
return channels
@@ -270,6 +288,5 @@ async def tag_channel(tag_name: str | None, directory: pathlib.Path | None, chan
@optional_session
async def search_channels(tag_names: List[str], session: Session) -> List[Channel]:
"""Search Tagged Channels."""
- from wrolpi.tags import Tag
- channels = session.query(Channel).join(Tag).filter(Tag.name.in_(tag_names)).all()
+ channels = session.query(Channel).join(Collection).join(Tag).filter(Tag.name.in_(tag_names)).all()
return channels
diff --git a/modules/videos/channel/test/test_api.py b/modules/videos/channel/test/test_api.py
index 2f7aa300b..cee5c6b15 100644
--- a/modules/videos/channel/test/test_api.py
+++ b/modules/videos/channel/test/test_api.py
@@ -21,16 +21,18 @@
from wrolpi.test.common import assert_dict_contains
-def test_get_channels(test_directory, channel_factory, test_client):
+@pytest.mark.asyncio
+async def test_get_channels(test_directory, channel_factory, async_client):
channel_factory()
channel_factory()
channel_factory()
- request, response = test_client.get('/api/videos/channels')
+ request, response = await async_client.get('/api/videos/channels')
assert response.status_code == HTTPStatus.OK
assert len(response.json['channels']) == 3
-def test_get_video(test_client, test_session, simple_channel, video_factory):
+@pytest.mark.asyncio
+async def test_get_video(async_client, test_session, simple_channel, video_factory):
"""Test that you get can information about a video. Test that video file can be gotten."""
now_ = now()
video1 = video_factory(channel_id=simple_channel.id, title='vid1')
@@ -40,14 +42,14 @@ def test_get_video(test_client, test_session, simple_channel, video_factory):
test_session.commit()
# Test that a 404 is returned when no video exists
- _, response = test_client.get('/api/videos/video/10')
+ _, response = await async_client.get('/api/videos/video/10')
assert response.status_code == HTTPStatus.NOT_FOUND, response.json
assert response.json == {'code': 'UNKNOWN_VIDEO',
'error': 'Cannot find Video with id 10',
'message': 'The video could not be found.'}
# Get the video info we inserted
- _, response = test_client.get('/api/videos/video/1')
+ _, response = await async_client.get('/api/videos/video/1')
assert response.status_code == HTTPStatus.OK, response.json
assert_dict_contains(response.json['file_group'], {'title': 'vid1'})
@@ -92,9 +94,10 @@ async def test_channel_no_download_frequency(async_client, test_session, test_di
assert test_session.query(Channel).count() == 1
-def test_video_file_name(test_session, simple_video, test_client):
+@pytest.mark.asyncio
+async def test_video_file_name(test_session, simple_video, async_client):
"""If a Video has no title, the front-end can use the file name as the title."""
- _, resp = test_client.get(f'/api/videos/video/{simple_video.id}')
+ _, resp = await async_client.get(f'/api/videos/video/{simple_video.id}')
assert resp.status_code == HTTPStatus.OK
assert resp.json['file_group']['video']['video_path'] == 'simple_video.mp4'
assert resp.json['file_group']['video'].get('stem') == 'simple_video'
@@ -165,20 +168,21 @@ async def _post_channel(channel):
'message': 'Could not validate the contents of the request'}
-def test_channel_empty_url_doesnt_conflict(test_client, test_session, test_directory):
+@pytest.mark.asyncio
+async def test_channel_empty_url_doesnt_conflict(async_client, test_session, test_directory):
"""Two channels with empty URLs shouldn't conflict"""
channel_directory = tempfile.TemporaryDirectory(dir=test_directory).name
pathlib.Path(channel_directory).mkdir()
new_channel = dict(name='Fooz', directory=channel_directory)
- request, response = test_client.post('/api/videos/channels', json=new_channel)
+ request, response = await async_client.post('/api/videos/channels', json=new_channel)
assert response.status_code == HTTPStatus.CREATED, response.json
location = response.headers['Location']
channel_directory2 = tempfile.TemporaryDirectory(dir=test_directory).name
pathlib.Path(channel_directory2).mkdir()
new_channel = dict(name='Barz', directory=channel_directory2)
- request, response = test_client.post('/api/videos/channels', json=new_channel)
+ request, response = await async_client.post('/api/videos/channels', json=new_channel)
assert response.status_code == HTTPStatus.CREATED, response.json
assert location != response.headers['Location']
@@ -235,11 +239,12 @@ async def do_download(_, download):
assert list(events_history) == []
-def test_channel_post_directory(test_session, test_client, test_directory):
+@pytest.mark.asyncio
+async def test_channel_post_directory(test_session, async_client, test_directory):
"""A Channel can be created with or without an existing directory."""
# Channel can be created with a directory which is not on disk.
data = dict(name='foo', directory='foo')
- request, response = test_client.post('/api/videos/channels', content=json.dumps(data))
+ request, response = await async_client.post('/api/videos/channels', content=json.dumps(data))
assert response.status_code == HTTPStatus.CREATED
directory = test_session.query(Channel).filter_by(id=1).one().directory
assert (test_directory / 'foo') == directory
@@ -247,8 +252,9 @@ def test_channel_post_directory(test_session, test_client, test_directory):
assert directory.is_absolute()
-def test_channel_by_id(test_session, test_client, simple_channel, simple_video):
- request, response = test_client.get(f'/api/videos/channels/{simple_channel.id}')
+@pytest.mark.asyncio
+async def test_channel_by_id(test_session, async_client, simple_channel, simple_video):
+ request, response = await async_client.get(f'/api/videos/channels/{simple_channel.id}')
assert response.status_code == HTTPStatus.OK
@@ -334,10 +340,20 @@ async def test_change_channel_url(async_client, test_session, test_download_mana
assert channel.url == 'https://example.com/new-url', "Channel's URL was not changed."
-def test_search_videos_channel(test_client, test_session, video_factory):
+@pytest.mark.asyncio
+async def test_search_videos_channel(async_client, test_session, video_factory, test_directory):
+ from wrolpi.collections import Collection
+
with get_db_session(commit=True) as session:
- channel1 = Channel(name='Foo')
- channel2 = Channel(name='Bar')
+ # Create Collections first
+ collection1 = Collection(name='Foo', kind='channel', directory=test_directory / 'foo')
+ collection2 = Collection(name='Bar', kind='channel', directory=test_directory / 'bar')
+ session.add_all([collection1, collection2])
+ session.flush()
+
+ # Create Channels linked to Collections
+ channel1 = Channel(collection_id=collection1.id)
+ channel2 = Channel(collection_id=collection2.id)
session.add(channel1)
session.add(channel2)
session.flush()
@@ -346,7 +362,7 @@ def test_search_videos_channel(test_client, test_session, video_factory):
# Channels don't have videos yet
d = dict(channel_id=channel1.id)
- request, response = test_client.post(f'/api/videos/search', content=json.dumps(d))
+ request, response = await async_client.post(f'/api/videos/search', content=json.dumps(d))
assert response.status_code == HTTPStatus.OK
assert len(response.json['file_groups']) == 0
@@ -357,19 +373,19 @@ def test_search_videos_channel(test_client, test_session, video_factory):
session.add(vid2)
# Videos are gotten by their respective channels
- request, response = test_client.post(f'/api/videos/search', content=json.dumps(d))
+ request, response = await async_client.post(f'/api/videos/search', content=json.dumps(d))
assert response.status_code == HTTPStatus.OK
assert len(response.json['file_groups']) == 1
assert response.json['totals']['file_groups'] == 1
assert_dict_contains(response.json['file_groups'][0],
- dict(primary_path='vid2.mp4', video=dict(channel_id=channel1.id)))
+ dict(primary_path='foo/vid2.mp4', video=dict(channel_id=channel1.id)))
d = dict(channel_id=channel2.id)
- request, response = test_client.post(f'/api/videos/search', content=json.dumps(d))
+ request, response = await async_client.post(f'/api/videos/search', content=json.dumps(d))
assert response.status_code == HTTPStatus.OK
assert len(response.json['file_groups']) == 1
assert_dict_contains(response.json['file_groups'][0],
- dict(primary_path='vid1.mp4', video=dict(channel_id=channel2.id)))
+ dict(primary_path='bar/vid1.mp4', video=dict(channel_id=channel2.id)))
@pytest.mark.asyncio
@@ -413,35 +429,37 @@ async def test_channel_download_id(async_client, test_session, tag_factory, simp
urls=['https://example.com/channel1'],
tag_names=[tag.name],
downloader=test_downloader.name,
- settings=dict(channel_id=simple_channel.id),
+ settings=dict(),
+ collection_id=simple_channel.collection_id,
frequency=120,
)
request, response = await async_client.post(f'/api/download', json=body)
assert response.status_code == HTTPStatus.CREATED
download = test_session.query(Download).one()
assert download.url == 'https://example.com/channel1'
- assert download.channel_id == simple_channel.id
+ assert download.collection_id == simple_channel.collection_id
- # A Channel relationship can be removed from a Download.
+ # A Collection relationship can be removed from a Download.
body = dict(
urls=['https://example.com/channel1'],
tag_names=[tag.name],
downloader=test_downloader.name,
settings=dict(),
+ collection_id=None,
frequency=240,
)
request, response = await async_client.put(f'/api/download/{download.id}', json=body)
assert response.status_code == HTTPStatus.NO_CONTENT
download = test_session.query(Download).one()
- assert download.channel_id is None
+ assert download.collection_id is None
assert download.frequency == 240
- # A once-Download cannot be associated with a Channel.
+ # A once-Download cannot be associated with a Collection.
body = dict(
urls=['https://example.com/channel1'],
tag_names=[tag.name],
downloader=test_downloader.name,
- settings=dict(channel_id=simple_channel.id),
+ collection_id=simple_channel.collection_id,
)
request, response = await async_client.put(f'/api/download/{download.id}', json=body)
assert response.status_code == HTTPStatus.BAD_REQUEST
diff --git a/modules/videos/channel/test/test_models.py b/modules/videos/channel/test/test_models.py
index 4edb9f0ff..ae5faa09d 100644
--- a/modules/videos/channel/test/test_models.py
+++ b/modules/videos/channel/test/test_models.py
@@ -8,7 +8,7 @@
from wrolpi.downloader import Download, DownloadFrequency
-def test_delete_channel_no_url(test_session, test_client, channel_factory):
+def test_delete_channel_no_url(test_session, channel_factory):
"""
A Channel can be deleted even if it has no URL.
"""
@@ -104,8 +104,16 @@ async def test_channel_download_relationships(test_session, download_channel):
def test_channel_info_json(test_session, test_directory):
- channel = Channel(name='new channel')
+ from wrolpi.collections import Collection
+
+ # Create Collection first without directory
+ collection = Collection(name='new channel', kind='channel')
+ test_session.add(collection)
+ test_session.flush()
+
+ channel = Channel(collection_id=collection.id)
test_session.add(channel)
+ test_session.flush()
with pytest.raises(FileNotFoundError) as e:
# Cannot get info json because directory is not defined.
diff --git a/modules/videos/common.py b/modules/videos/common.py
index 149df67c9..58560d97e 100644
--- a/modules/videos/common.py
+++ b/modules/videos/common.py
@@ -20,7 +20,6 @@
from wrolpi.db import get_db_session, get_db_curs
from wrolpi.vars import DEFAULT_FILE_PERMISSIONS
from .errors import ChannelNameConflict, ChannelURLConflict, ChannelDirectoryConflict, ChannelSourceIdConflict
-from .models import Channel
logger = logger.getChild(__name__)
@@ -52,6 +51,8 @@ def check_for_channel_conflicts(session: Session, id_=None, url=None, name=None,
"""
Search for any channels that conflict with the provided args, raise a relevant exception if any conflicts are found.
"""
+ from .models import Channel
+
if not any([id_, url, name, directory]):
raise Exception('Cannot search for channel with no arguments')
@@ -68,11 +69,13 @@ def check_for_channel_conflicts(session: Session, id_=None, url=None, name=None,
if list(conflicts):
raise ChannelURLConflict()
if name:
- conflicts = base_where.filter(Channel.name == name)
+ from wrolpi.collections import Collection
+ conflicts = base_where.join(Collection).filter(Collection.name == name)
if list(conflicts):
raise ChannelNameConflict()
if directory:
- conflicts = base_where.filter(Channel.directory == directory)
+ from wrolpi.collections import Collection
+ conflicts = base_where.join(Collection).filter(Collection.directory == directory)
if list(conflicts):
raise ChannelDirectoryConflict()
if source_id:
@@ -269,6 +272,8 @@ def extract_video_duration(video_path: Path) -> Optional[int]:
async def update_view_counts_and_censored(channel_id: int):
"""Update view_count for all Videos in a channel using its info_json file. Also sets FileGroup.censored
if Video is no longer available on the Channel."""
+ from .models import Channel
+
with get_db_session() as session:
channel: Channel = session.query(Channel).filter_by(id=channel_id).one()
channel_name = channel.name
@@ -284,13 +289,14 @@ async def update_view_counts_and_censored(channel_id: int):
with get_db_curs(commit=True) as curs:
# Update the view_count for each video.
stmt = '''
- WITH source AS (select * from json_to_recordset(%s::json) as (id text, view_count int))
- UPDATE video
- SET view_count = s.view_count
- FROM source as s
- WHERE source_id=s.id AND channel_id=%s
- RETURNING video.id AS updated_ids
- '''
+ WITH source AS (select * from json_to_recordset(%s::json) as (id text, view_count int))
+ UPDATE video
+ SET view_count = s.view_count
+ FROM source as s
+ WHERE source_id = s.id
+ AND channel_id = %s
+ RETURNING video.id AS updated_ids \
+ '''
curs.execute(stmt, (view_counts_str, channel_id))
count = len(curs.fetchall())
logger.info(f'Updated {count} view counts in DB for {channel_name}.')
@@ -299,13 +305,13 @@ async def update_view_counts_and_censored(channel_id: int):
with get_db_curs(commit=True) as curs:
# Set FileGroup.censored if the video is no longer on the Channel.
stmt = '''
- UPDATE file_group fg
- SET censored = NOT (v.source_id = ANY(%(source_ids)s))
- FROM video v
- WHERE v.file_group_id = fg.id
- AND v.channel_id = %(channel_id)s
- RETURNING fg.id, fg.censored
- '''
+ UPDATE file_group fg
+ SET censored = NOT (v.source_id = ANY (%(source_ids)s))
+ FROM video v
+ WHERE v.file_group_id = fg.id
+ AND v.channel_id = %(channel_id)s
+ RETURNING fg.id, fg.censored \
+ '''
curs.execute(stmt, {'channel_id': channel_id, 'source_ids': source_ids})
censored = len([i for i in curs.fetchall() if i['censored']])
logger.info(f'Set {censored} censored videos for {channel_name}.')
diff --git a/modules/videos/conftest.py b/modules/videos/conftest.py
index 072605c41..f935029d7 100644
--- a/modules/videos/conftest.py
+++ b/modules/videos/conftest.py
@@ -22,11 +22,23 @@
@pytest.fixture
def simple_channel(test_session, test_directory) -> Channel:
"""Get a Channel with the minimum properties. This Channel has no download!"""
- channel = Channel(
- directory=test_directory,
+ from wrolpi.collections import Collection
+
+ # Create Collection first
+ collection = Collection(
name='Simple Channel',
+ kind='channel',
+ directory=test_directory,
+ )
+ test_session.add(collection)
+ test_session.flush([collection])
+
+ # Create Channel linked to Collection
+ channel = Channel(
+ collection_id=collection.id,
url='https://example.com/channel1',
)
+
test_session.add(channel)
test_session.commit()
return channel
@@ -36,21 +48,32 @@ def simple_channel(test_session, test_directory) -> Channel:
def channel_factory(test_session, test_directory):
"""Create a random Channel with a directory, and Download."""
from wrolpi.tags import Tag
+ from wrolpi.collections import Collection
def factory(source_id: str = None, download_frequency: DownloadFrequency = None, url: str = None, name: str = None,
directory: pathlib.Path = None, tag_name: str = None):
name = name or str(uuid4())
tag = Tag.find_by_name(tag_name) if tag_name else None
- channel = Channel(
+ tag_name = tag.name if tag else None
+ directory = directory or format_videos_destination(name, tag_name, url or f'https://example.com/{name}')
+ directory.mkdir(exist_ok=True, parents=True)
+
+ # Create Collection first
+ collection = Collection(
name=name,
+ kind='channel',
+ directory=directory,
+ tag_id=tag.id if tag else None,
+ )
+ test_session.add(collection)
+ test_session.flush([collection])
+
+ # Create Channel linked to Collection
+ channel = Channel(
+ collection_id=collection.id,
url=url or f'https://example.com/{name}',
source_id=source_id,
- tag=tag,
- tag_id=tag.id if tag else None,
)
- tag_name = tag.name if tag else None
- channel.directory = directory or format_videos_destination(name, tag_name, channel.url)
- channel.directory.mkdir(exist_ok=True, parents=True)
test_session.add(channel)
test_session.flush([channel])
test_session.add(Directory(path=channel.directory, name=channel.directory.name))
@@ -68,11 +91,25 @@ def factory(source_id: str = None, download_frequency: DownloadFrequency = None,
@pytest.fixture
def download_channel(test_session, test_directory, video_download_manager) -> Channel:
"""Get a test Channel that has a download frequency."""
- # Add a frequency to the test channel, then give it a download.
- channel = Channel(directory=test_directory, name='Download Channel', url='https://example.com/channel1',
- source_id='channel1')
+ from wrolpi.collections import Collection
+
+ # Create Collection first
+ collection = Collection(
+ name='Download Channel',
+ kind='channel',
+ directory=test_directory,
+ )
+ test_session.add(collection)
+ test_session.flush([collection])
+
+ # Create Channel linked to Collection
+ channel = Channel(
+ collection_id=collection.id,
+ url='https://example.com/channel1',
+ source_id='channel1',
+ )
test_session.add(channel)
- test_session.flush([channel, ])
+ test_session.flush([channel])
assert channel and channel.id and channel.url
download = channel.get_or_create_download(channel.url, 60, test_session)
assert download.url == channel.url
diff --git a/modules/videos/downloader.py b/modules/videos/downloader.py
index 994a7577e..f2c269c8b 100755
--- a/modules/videos/downloader.py
+++ b/modules/videos/downloader.py
@@ -5,7 +5,6 @@
import os.path
import pathlib
import re
-import sys
import traceback
from abc import ABC
from datetime import timedelta
@@ -208,7 +207,7 @@ async def do_download(self, download: Download) -> DownloadResult:
# The settings to send to the VideoDownloader.
settings = dict()
if channel:
- settings.update(dict(channel_id=channel.id, channel_url=download.url))
+ settings.update(dict(collection_id=channel.collection_id, channel_url=download.url))
if download.destination:
destination = get_absolute_media_path(download.destination)
settings['destination'] = str(destination) # Need str for JSON conversion
@@ -622,14 +621,20 @@ async def _get_channel(download: Download) \
except UnknownChannel:
# Destination must not be a channel.
pass
- channel_id = download.channel_id
+ # Look up Channel via Collection or channel_url from settings
+ collection_id = download.collection_id
channel_url = settings.get('channel_url')
- if not channel and (channel_id or channel_url):
+ if not channel and (collection_id or channel_url):
# Could not find Channel via yt-dlp info_json, use info from ChannelDownloader if it created this Download.
logger.info(f'Using download.settings to find channel')
try:
- channel = get_channel(channel_id=channel_id, url=channel_url, return_dict=False)
- logger.debug(f'Found channel with channel url {channel=}')
+ # Find Channel by collection_id or URL
+ if collection_id:
+ with get_db_session() as session:
+ channel = session.query(Channel).filter_by(collection_id=collection_id).one_or_none()
+ if not channel and channel_url:
+ channel = get_channel(url=channel_url, return_dict=False)
+ logger.debug(f'Found channel with collection_id or channel url {channel=}')
except UnknownChannel:
# We do not need a Channel if we have a destination directory.
if not destination:
diff --git a/modules/videos/lib.py b/modules/videos/lib.py
index 3864ef123..c6362dd03 100644
--- a/modules/videos/lib.py
+++ b/modules/videos/lib.py
@@ -252,7 +252,6 @@ def channels(self, value: dict):
self.update({'channels': value})
def import_config(self, file: pathlib.Path = None, send_events=False):
- from modules.videos.channel.lib import get_channel
super().import_config()
try:
channels = self.channels
@@ -268,31 +267,8 @@ def import_config(self, file: pathlib.Path = None, send_events=False):
for option in (i for i in REQUIRED_OPTIONS if i not in data):
raise ConfigError(f'Channel "{directory}" is required to have "{option}"')
- # Try to find Channel by directory because it is unique.
- channel = Channel.get_by_path(directory, session)
- if not channel:
- try:
- # Try to find Channel using other attributes before creating new Channel.
- channel = get_channel(
- session,
- source_id=data.get('source_id'),
- url=data.get('url'),
- directory=str(directory),
- return_dict=False,
- )
- except UnknownChannel:
- # Channel not yet in the DB, add it.
- channel = Channel(directory=directory)
- session.add(channel)
- channel.flush()
- # TODO refresh the files in the channel.
- logger.warning(f'Creating new Channel from config: {directory}')
-
- # Copy existing channel data, update all values from the config. This is necessary to clear out
- # values not in the config.
- full_data = channel.dict()
- full_data.update(data)
- channel.update(full_data)
+ # Create or update Channel from config (handles Collection creation)
+ channel = Channel.from_config(data, session=session)
updated_channel_ids.add(channel.id)
if not channel.source_id and channel.url and flags.have_internet.is_set():
@@ -519,8 +495,9 @@ def set_test_downloader_config(enabled: bool):
def get_channels_config_from_db(session: Session) -> dict:
"""Create a dictionary that contains all the Channels from the DB."""
- channels = session.query(Channel).order_by(Channel.directory).all()
- channels = sorted((i.config_view() for i in channels), key=lambda i: i['directory'])
+ from wrolpi.collections import Collection
+ channels = session.query(Channel).join(Collection).order_by(Collection.directory).all()
+ channels = sorted((i.config_view() for i in channels), key=lambda i: i['directory'] if i['directory'] else '')
return dict(channels=channels)
@@ -572,9 +549,12 @@ def import_channels_config():
def link_channel_and_downloads(session: Session, channel_: Type[Base] = Channel, download_: Type[Base] = Download):
- """Create any missing Downloads for any Channel.url/Channel.directory that has a Download. Associate any Download
- related to a Channel."""
- # Only Downloads with a frequency can be a Channel Download.
+ """Associate any Download related to a Channel via its Collection.
+
+ Downloads are linked to Collections (via collection_id) rather than directly to Channels.
+ This function finds Channels and links their Downloads to the Channel's Collection.
+ """
+ # Only Downloads with a frequency can be a Collection Download.
downloads = list(session.query(download_).filter(download_.frequency.isnot(None)).all())
# Download.url is unique and cannot be null.
downloads_by_url = {i.url: i for i in downloads}
@@ -586,26 +566,27 @@ def link_channel_and_downloads(session: Session, channel_: Type[Base] = Channel,
for channel in channels:
directory = str(channel.directory)
for download in downloads_with_destination:
- if download.settings['destination'] == directory and not download.channel_id:
- download.channel_id = channel.id
+ if download.settings['destination'] == directory and not download.collection_id:
+ download.collection_id = channel.collection_id
need_commit = True
download = downloads_by_url.get(channel.url)
- if download and not download.channel_id:
- download.channel_id = channel.id
+ if download and not download.collection_id:
+ download.collection_id = channel.collection_id
need_commit = True
# Get any Downloads for a Channel's RSS feed.
rss_url = channel.get_rss_url()
if rss_url and (download := downloads_by_url.get(rss_url)):
- download.channel_id = channel.id
- need_commit = True
+ if not download.collection_id:
+ download.collection_id = channel.collection_id
+ need_commit = True
# Associate any Download which shares a Channel's URL.
for download in downloads:
channel = channel_.get_by_url(download.url, session)
- if channel and not download.channel_id:
- download.channel_id = channel.id
+ if channel and not download.collection_id:
+ download.collection_id = channel.collection_id
need_commit = True
if need_commit:
@@ -676,10 +657,11 @@ async def get_statistics():
if monthly_videos else 0
curs.execute('''
- SELECT COUNT(c.id) AS "channels",
- COUNT(c.id) FILTER ( WHERE c.tag_id IS NOT NULL ) AS "tagged_channels"
+ SELECT COUNT(c.id) AS "channels",
+ COUNT(c.id) FILTER ( WHERE col.tag_id IS NOT NULL ) AS "tagged_channels"
FROM channel c
- LEFT JOIN public.tag t on t.id = c.tag_id
+ LEFT JOIN collection col on col.id = c.collection_id
+ LEFT JOIN public.tag t on t.id = col.tag_id
''')
channel_stats = dict(curs.fetchone())
ret = dict(statistics=dict(
diff --git a/modules/videos/models.py b/modules/videos/models.py
index 17596e591..a2a197fc9 100644
--- a/modules/videos/models.py
+++ b/modules/videos/models.py
@@ -531,9 +531,8 @@ def replace_info_json(self, info_json: dict, clean: bool = True, format_: bool =
class Channel(ModelHelper, Base):
__tablename__ = 'channel'
id = Column(Integer, primary_key=True)
- name = Column(String)
+ # name and directory are stored in Collection, accessed via properties
url = Column(String, unique=True) # will only be downloaded if related Download exists.
- directory: pathlib.Path = Column(MediaPathType)
generate_posters = Column(Boolean, default=False) # generating posters may delete files, and can be slow.
calculate_duration = Column(Boolean, default=True) # use ffmpeg to extract duration (slower than info json).
download_missing_data = Column(Boolean, default=True) # fetch missing data like `source_id` and video comments.
@@ -549,16 +548,65 @@ class Channel(ModelHelper, Base):
info_date = Column(Date)
videos: InstrumentedList = relationship('Video', primaryjoin='Channel.id==Video.channel_id')
- downloads: InstrumentedList = relationship('Download', primaryjoin='Download.channel_id==Channel.id')
- tag_id = Column(Integer, ForeignKey('tag.id', ondelete='CASCADE'))
- tag = relationship('Tag', primaryjoin='Channel.tag_id==Tag.id')
+ collection_id = Column(Integer, ForeignKey('collection.id', ondelete='CASCADE'))
+ collection = relationship('Collection', foreign_keys=[collection_id])
+
+ @property
+ def downloads(self) -> InstrumentedList:
+ """Get downloads from associated Collection."""
+ return self.collection.downloads if self.collection else []
def __repr__(self):
return f''
+ @property
+ def name(self) -> str | None:
+ """Delegate name to Collection"""
+ return self.collection.name if self.collection else None
+
+ @name.setter
+ def name(self, value: str):
+ """Set name on Collection"""
+ if self.collection:
+ self.collection.name = value
+
+ @property
+ def directory(self) -> pathlib.Path | None:
+ """Delegate directory to Collection"""
+ return self.collection.directory if self.collection else None
+
+ @directory.setter
+ def directory(self, value: pathlib.Path | str):
+ """Set directory on Collection"""
+ if self.collection:
+ self.collection.directory = value
+
+ @property
+ def tag(self):
+ """Delegate tag to Collection"""
+ return self.collection.tag if self.collection else None
+
+ @tag.setter
+ def tag(self, value):
+ """Set tag on Collection"""
+ if self.collection:
+ self.collection.tag = value
+
+ @property
+ def tag_id(self) -> int | None:
+ """Delegate tag_id to Collection"""
+ return self.collection.tag_id if self.collection else None
+
+ @tag_id.setter
+ def tag_id(self, value: int):
+ """Set tag_id on Collection"""
+ if self.collection:
+ self.collection.tag_id = value
+
@property
def tag_name(self) -> str | None:
- return self.tag.name if self.tag else None
+ """Delegate tag_name to Collection"""
+ return self.collection.tag_name if self.collection else None
@property
def location(self) -> str:
@@ -597,7 +645,7 @@ def update(self, data: dict):
data = data.copy()
# URL should not be empty string.
- url = data.pop('url')
+ url = data.pop('url', None)
self.url = url or None
session: Session = Session.object_session(self)
@@ -637,7 +685,7 @@ def update(self, data: dict):
continue
download = self.get_or_create_download(url, frequency, session=session, reset_attempts=True)
- download.channel_id = channel_id
+ download.collection_id = self.collection_id
session.add(download)
def config_view(self) -> dict:
@@ -657,13 +705,62 @@ def config_view(self) -> dict:
)
return config
+ @classmethod
+ def from_config(cls, data: dict, session: Session = None) -> 'Channel':
+ """
+ Create or update a Channel from config data. This also creates/updates the Collection.
+
+ Args:
+ data: Config dict containing channel metadata (name, directory, url, source_id, etc.)
+ session: Database session
+
+ Returns:
+ The created or updated Channel
+ """
+ from wrolpi.collections import Collection
+ from wrolpi.db import get_db_session
+
+ if session is None:
+ session = get_db_session()
+
+ # Ensure this is treated as a channel collection
+ data = data.copy()
+ data['kind'] = 'channel'
+
+ # Create or update the Collection first
+ collection = Collection.from_config(data, session=session)
+
+ # Extract Channel-specific fields
+ url = data.get('url')
+ source_id = data.get('source_id')
+
+ # Find existing Channel by collection_id or create new one
+ channel = session.query(cls).filter_by(collection_id=collection.id).one_or_none()
+
+ if not channel:
+ # Create new Channel
+ channel = cls(
+ collection_id=collection.id,
+ url=url,
+ source_id=source_id,
+ )
+ session.add(channel)
+ session.flush([channel])
+
+ # Use the update() method to handle all fields including downloads
+ channel.update(data)
+
+ return channel
+
@staticmethod
def get_by_path(path: pathlib.Path, session: Session) -> Optional['Channel']:
if not path:
raise RuntimeError('Must provide path to get Channel')
path = pathlib.Path(path) if isinstance(path, str) else path
path = str(path.absolute()) if path.is_absolute() else str(get_media_directory() / path)
- channel = session.query(Channel).filter_by(directory=path).one_or_none()
+ # Query through Collection relationship since directory is now on Collection
+ from wrolpi.collections import Collection
+ channel = session.query(Channel).join(Collection).filter(Collection.directory == path).one_or_none()
return channel
def __json__(self) -> dict:
@@ -679,6 +776,8 @@ def __json__(self) -> dict:
def dict(self, with_statistics: bool = False, with_downloads: bool = True) -> dict:
d = super(Channel, self).dict()
+ # Add Collection-delegated properties
+ d['name'] = self.name
d['tag_name'] = self.tag_name
d['rss_url'] = self.get_rss_url()
d['directory'] = \
@@ -777,26 +876,17 @@ def get_or_create_download(self, url: str, frequency: int, session: Session = No
reset_attempts: bool = False) -> Download:
"""Get a Download record, if it does not exist, create it. Create a Download if necessary
which goes into this Channel's directory.
- """
- if not isinstance(url, str) or not url:
- raise InvalidDownload(f'Cannot get Download without url')
- if not frequency:
- raise InvalidDownload('Download for Channel must have a frequency')
+ Delegates to Collection.get_or_create_download() with Channel-specific downloaders.
+ """
from modules.videos.downloader import ChannelDownloader, VideoDownloader
- download = Download.get_by_url(url, session=session)
- if not download:
- download = download_manager.recurring_download(url, frequency, ChannelDownloader.name, session=session,
- sub_downloader_name=VideoDownloader.name,
- destination=self.directory,
- reset_attempts=reset_attempts,
- )
- if reset_attempts:
- download.attempts = 0
- download.channel_id = self.id
-
- return download
+ return self.collection.get_or_create_download(
+ url, frequency, session=session,
+ reset_attempts=reset_attempts,
+ downloader_name=ChannelDownloader.name,
+ sub_downloader_name=VideoDownloader.name
+ )
def get_rss_url(self) -> str | None:
"""Return the RSS Feed URL for this Channel, if any is possible."""
diff --git a/modules/videos/test/test_api.py b/modules/videos/test/test_api.py
index 8d7a4d765..ed4ec4fbe 100644
--- a/modules/videos/test/test_api.py
+++ b/modules/videos/test/test_api.py
@@ -43,8 +43,9 @@ async def test_refresh_videos_index(async_client, test_session, test_directory,
assert not video.file_group.d_text, 'Video captions were not removed'
-def test_refresh_videos(test_client, test_session, test_directory, simple_channel, video_factory, video_file_factory,
- image_bytes_factory):
+@pytest.mark.asyncio
+async def test_refresh_videos(async_client, test_session, test_directory, simple_channel, video_factory,
+ video_file_factory, image_bytes_factory):
subdir = test_directory / 'subdir'
subdir.mkdir()
@@ -81,7 +82,7 @@ def test_refresh_videos(test_client, test_session, test_directory, simple_channe
stmt = "INSERT INTO video (file_group_id) values (%(video_id)s)"
curs.execute(stmt, {'video_id': str(video4_id)})
- test_client.post('/api/files/refresh')
+ await async_client.post('/api/files/refresh')
test_session.expire_all()
@@ -104,12 +105,13 @@ def test_refresh_videos(test_client, test_session, test_directory, simple_channe
# Remove video1's poster, video1 should be updated.
video1_video_path = str(video1.video_path)
video1.poster_path.unlink()
- test_client.post('/api/files/refresh')
+ await async_client.post('/api/files/refresh')
video1 = Video.get_by_path(video1_video_path, test_session)
assert not video1.poster_path
-def test_channels_with_videos(test_session, test_client, test_directory, channel_factory, video_factory):
+@pytest.mark.asyncio
+async def test_channels_with_videos(test_session, async_client, test_directory, channel_factory, video_factory):
channel1 = channel_factory('channel1', name='channel1')
channel2 = channel_factory('channel2', name='channel2')
vid1 = video_factory(channel_id=channel1.id, with_video_file=True)
@@ -132,10 +134,12 @@ def test_channels_with_videos(test_session, test_client, test_directory, channel
assert vid2_path.is_file(), 'Video file was deleted'
assert vid3_path.is_file(), 'Video file was deleted'
- test_client.post('/api/files/refresh')
+ await async_client.post('/api/files/refresh')
assert test_session.query(Video).count() == 3, 'Did not find correct amount of video files.'
- assert {i[0] for i in test_session.query(Channel.name)} == {'channel1', 'channel2'}, 'Channels were changed.'
+ from wrolpi.collections import Collection
+ assert {i[0] for i in test_session.query(Collection.name).join(Channel).filter(Collection.kind == 'channel')} == {
+ 'channel1', 'channel2'}, 'Channels were changed.'
vid1 = Video.get_by_path(vid1_path)
vid2 = Video.get_by_path(vid2_path)
@@ -146,11 +150,12 @@ def test_channels_with_videos(test_session, test_client, test_directory, channel
assert vid3.channel is None
-def test_api_download(test_session, test_client, test_directory):
+@pytest.mark.asyncio
+async def test_api_download(test_session, async_client, test_directory):
"""A video can be downloaded."""
content = dict(urls=['https://example.com/video1', ], downloader='video', destination='dest',
settings=dict(excluded_urls='example.com'))
- request, response = test_client.post('/api/download', content=json.dumps(content))
+ request, response = await async_client.post('/api/download', content=json.dumps(content))
assert response.status_code == HTTPStatus.CREATED
download = test_session.query(Download).one()
@@ -364,7 +369,8 @@ async def test_video_file_format(async_client, test_session, fake_now):
@pytest.mark.asyncio
-async def test_video_upload_file_tracking(test_session, async_client, video_factory, await_switches, make_multipart_form,
+async def test_video_upload_file_tracking(test_session, async_client, video_factory, await_switches,
+ make_multipart_form,
image_bytes_factory):
"""Test that uploading info.json, poster, and caption files via /api/files/upload properly tracks them in Video and FileGroup."""
from wrolpi.common import get_relative_to_media_directory
diff --git a/modules/videos/test/test_common.py b/modules/videos/test/test_common.py
index eb0f1f6c5..d1ca344a7 100644
--- a/modules/videos/test/test_common.py
+++ b/modules/videos/test/test_common.py
@@ -257,6 +257,11 @@ async def test_import_channel_delete_missing_channels(await_switches, test_sessi
channel1 = channel_factory(source_id='foo')
channel2 = channel_factory(source_id='bar')
test_session.commit()
+
+ # Capture directory values before deletion (channels will be detached after deletion)
+ channel1_dir = str(channel1.directory)
+ channel2_dir = str(channel2.directory)
+
# Write Channels to the config file.
save_channels_config()
test_session.delete(channel1)
@@ -267,29 +272,29 @@ async def test_import_channel_delete_missing_channels(await_switches, test_sessi
# Importing the config creates two Channels.
import_channels_config()
assert len(test_session.query(Channel).all()) == 2
- assert str(channel1.directory) in test_channels_config.read_text()
- assert str(channel2.directory) in test_channels_config.read_text()
+ assert channel1_dir in test_channels_config.read_text()
+ assert channel2_dir in test_channels_config.read_text()
# Delete channel2 from the config file.
config = get_channels_config()
config_dict = config.dict()
- config_dict['channels'] = [i for i in config.channels if i['directory'] != str(channel2.directory)]
+ config_dict['channels'] = [i for i in config.channels if i['directory'] != channel2_dir]
config.update(config_dict)
await await_switches()
- assert str(channel1.directory) in test_channels_config.read_text()
- assert str(channel2.directory) not in test_channels_config.read_text()
+ assert channel1_dir in test_channels_config.read_text()
+ assert channel2_dir not in test_channels_config.read_text()
# Importing the config deletes the Channel record.
import_channels_config()
assert len(test_session.query(Channel).all()) == 1
- assert str(channel1.directory) in test_channels_config.read_text()
- assert str(channel2.directory) not in test_channels_config.read_text()
+ assert channel1_dir in test_channels_config.read_text()
+ assert channel2_dir not in test_channels_config.read_text()
# Saving and importing does not change anything.
save_channels_config()
import_channels_config()
- assert str(channel1.directory) in test_channels_config.read_text()
- assert str(channel2.directory) not in test_channels_config.read_text()
+ assert channel1_dir in test_channels_config.read_text()
+ assert channel2_dir not in test_channels_config.read_text()
@pytest.mark.asyncio
diff --git a/modules/videos/test/test_downloader.py b/modules/videos/test/test_downloader.py
index de593af22..b0de821a6 100644
--- a/modules/videos/test/test_downloader.py
+++ b/modules/videos/test/test_downloader.py
@@ -173,7 +173,7 @@ def reset_downloads():
downloads = video_download_manager.get_once_downloads(test_session)
assert {i.url for i in downloads} == {'https://youtube.com/watch?v=video_2_url'}
assert downloads[0].settings == {
- 'channel_id': 1,
+ 'collection_id': simple_channel.collection_id,
'channel_url': 'https://www.youtube.com/c/LearningSelfReliance/videos',
'destination': str(test_directory),
}
@@ -191,7 +191,7 @@ def reset_downloads():
downloads = video_download_manager.get_once_downloads(test_session)
assert {i.url for i in downloads} == {'https://youtube.com/watch?v=video_1_url'}
assert downloads[0].settings == {
- 'channel_id': 1,
+ 'collection_id': simple_channel.collection_id,
'channel_url': 'https://www.youtube.com/c/LearningSelfReliance/videos',
'destination': str(test_directory),
}
@@ -203,13 +203,23 @@ async def test_get_or_create_channel(async_client, test_session, test_directory,
Attempt to use an existing Channel if we can match it.
"""
+ from wrolpi.collections import Collection
one, two = await tag_factory(), await tag_factory()
- c1 = Channel(name='foo', source_id='foo', url='https://example.com')
- c2 = Channel(name='bar', source_id='bar')
- c3 = Channel(name='baz', source_id='baz', url='https://example.net')
- c4 = Channel(name='qux')
- test_session.add_all([c1, c2, c3, c4])
+ # Create Collections for each Channel
+ def create_channel_with_collection(name, source_id=None, url=None):
+ collection = Collection(name=name, kind='channel')
+ test_session.add(collection)
+ test_session.flush([collection])
+ channel = Channel(collection_id=collection.id, source_id=source_id, url=url)
+ test_session.add(channel)
+ test_session.flush([channel])
+ return channel
+
+ c1 = create_channel_with_collection('foo', source_id='foo', url='https://example.com')
+ c2 = create_channel_with_collection('bar', source_id='bar')
+ c3 = create_channel_with_collection('baz', source_id='baz', url='https://example.net')
+ c4 = create_channel_with_collection('qux')
test_session.commit()
# All existing channels should be used.
@@ -549,7 +559,8 @@ def test_normalize_video_file_names(test_directory, video_download_manager):
@pytest.mark.asyncio
async def test_video_download_cookies(test_session, test_directory, mock_video_extract_info, await_switches,
- video_download_manager, mock_video_process_runner, image_file, test_videos_downloader_config):
+ video_download_manager, mock_video_process_runner, image_file,
+ test_videos_downloader_config):
config = get_videos_downloader_config()
config.browser_profile = str(test_directory / 'firefox/some directory')
@@ -577,7 +588,8 @@ async def test_video_download_cookies(test_session, test_directory, mock_video_e
@pytest.mark.asyncio
async def test_video_download_always_use_cookies(test_session, test_directory, mock_video_extract_info, await_switches,
- video_download_manager, mock_video_process_runner, image_file, test_videos_downloader_config):
+ video_download_manager, mock_video_process_runner, image_file,
+ test_videos_downloader_config):
config = get_videos_downloader_config()
config.always_use_browser_profile = True
diff --git a/modules/videos/test/test_lib.py b/modules/videos/test/test_lib.py
index 7b85d2839..89a01647f 100644
--- a/modules/videos/test/test_lib.py
+++ b/modules/videos/test/test_lib.py
@@ -191,13 +191,13 @@ def test_link_channel_and_downloads(test_session, channel_factory, test_download
test_session.add_all([download1, download2, download3])
test_session.commit()
assert test_session.query(Download).count() == 3
- assert not any(i.channel_id for i in test_download_manager.get_downloads(test_session))
+ assert not any(i.collection_id for i in test_download_manager.get_downloads(test_session))
- # `link_channel_and_downloads` creates missing Downloads.
+ # `link_channel_and_downloads` links Downloads to Collections.
lib.link_channel_and_downloads(session=test_session)
assert test_session.query(Download).count() == 3
- assert all(i.channel_id for i in test_download_manager.get_recurring_downloads(test_session))
- assert not any(i.channel_id for i in test_download_manager.get_once_downloads(test_session))
+ assert all(i.collection_id for i in test_download_manager.get_recurring_downloads(test_session))
+ assert not any(i.collection_id for i in test_download_manager.get_once_downloads(test_session))
@pytest.mark.asyncio
@@ -233,13 +233,13 @@ def test_link_channel_and_downloads_migration(async_client, test_session, channe
assert test_session.query(Download).count() == 4
d1, d2a, d2b, d2c = test_session.query(Download).order_by(Download.url).all()
assert d1.url == channel1.url == 'https://example.com/channel1' and d1.frequency == 1
- assert d1.channel_id
+ assert d1.collection_id == channel1.collection_id
assert d2a.url == 'https://example.com/channel2' and d2a.frequency == 1
- assert d2a.channel_id
+ assert d2a.collection_id == channel2.collection_id
assert d2b.url == 'https://example.com/channel2/rss' and d2b.frequency == 1
- assert d2b.channel_id
+ assert d2b.collection_id == channel2.collection_id
assert d2c.url == 'https://example.com/channel2/video/1' and d2c.frequency is None
- assert not d2c.channel_id
+ assert not d2c.collection_id
@pytest.mark.asyncio
diff --git a/modules/zim/lib.py b/modules/zim/lib.py
index 3ba8e673d..a855f97de 100644
--- a/modules/zim/lib.py
+++ b/modules/zim/lib.py
@@ -111,12 +111,11 @@ def get_all_entries_tags():
},
"""
stmt = '''
- SELECT tz.zim_id, tz.zim_entry, array_agg(t.name)::TEXT[]
- FROM
- tag t
- LEFT JOIN tag_zim tz on t.id = tz.tag_id
- GROUP BY 1, 2
- '''
+ SELECT tz.zim_id, tz.zim_entry, array_agg(t.name)::TEXT[]
+ FROM tag t
+ LEFT JOIN tag_zim tz on t.id = tz.tag_id
+ GROUP BY 1, 2 \
+ '''
with get_db_curs() as curs:
curs.execute(stmt)
entries = dict()
diff --git a/modules/zim/test/test_api.py b/modules/zim/test/test_api.py
index c449dcc9b..32a41edf5 100644
--- a/modules/zim/test/test_api.py
+++ b/modules/zim/test/test_api.py
@@ -133,18 +133,19 @@ async def test_zim_tag_and_untag(async_client, test_session, test_zim, tag_facto
assert response.status_code == HTTPStatus.NOT_FOUND
-def test_zim_subscribe(test_session, test_client):
+@pytest.mark.asyncio
+async def test_zim_subscribe(test_session, async_client):
"""A Kiwix subscription can be scheduled in the API. The language can be changed."""
# Subscribe to English.
content = dict(name='Wikipedia (mini)', language='en')
- request, response = test_client.post('/api/zim/subscribe', content=json.dumps(content))
+ request, response = await async_client.post('/api/zim/subscribe', content=json.dumps(content))
assert response.status_code == HTTPStatus.CREATED
download: Download = test_session.query(Download).one()
assert download.url == 'https://download.kiwix.org/zim/wikipedia/wikipedia_en_all_mini_'
assert download.info_json == {'language': 'en', 'name': 'Wikipedia (mini)'}
assert download.frequency
- request, response = test_client.get('/api/zim/subscribe')
+ request, response = await async_client.get('/api/zim/subscribe')
assert response.status_code == HTTPStatus.OK
assert response.json['subscriptions'] == {
'Wikipedia (mini)': {'download_id': 1,
@@ -155,7 +156,7 @@ def test_zim_subscribe(test_session, test_client):
# Change subscription to French.
content = dict(name='Wikipedia (mini)', language='fr')
- request, response = test_client.post('/api/zim/subscribe', content=json.dumps(content))
+ request, response = await async_client.post('/api/zim/subscribe', content=json.dumps(content))
assert response.status_code == HTTPStatus.CREATED
assert test_session.query(Download).count() == 1
@@ -163,7 +164,7 @@ def test_zim_subscribe(test_session, test_client):
assert download.url == 'https://download.kiwix.org/zim/wikipedia/wikipedia_fr_all_mini_'
assert download.info_json == {'language': 'fr', 'name': 'Wikipedia (mini)'}
assert download.frequency
- request, response = test_client.get('/api/zim/subscribe')
+ request, response = await async_client.get('/api/zim/subscribe')
assert response.status_code == HTTPStatus.OK
assert response.json['subscriptions'] == {
'Wikipedia (mini)': {'download_id': 1,
@@ -173,9 +174,9 @@ def test_zim_subscribe(test_session, test_client):
'name': 'Wikipedia (mini)'}}
# Delete subscription.
- request, response = test_client.delete('/api/zim/subscribe/1')
+ request, response = await async_client.delete('/api/zim/subscribe/1')
assert response.status_code == HTTPStatus.NO_CONTENT
- request, response = test_client.get('/api/zim/subscribe')
+ request, response = await async_client.get('/api/zim/subscribe')
assert response.status_code == HTTPStatus.OK
assert response.json['subscriptions'] == {}
# No subscriptions or downloads.
@@ -184,17 +185,18 @@ def test_zim_subscribe(test_session, test_client):
# English Stack Exchange is special.
content = dict(name='Stackoverflow (Stack Exchange)', language='en')
- request, response = test_client.post('/api/zim/subscribe', content=json.dumps(content))
+ request, response = await async_client.post('/api/zim/subscribe', content=json.dumps(content))
assert response.status_code == HTTPStatus.CREATED
download: Download = test_session.query(Download).one()
assert download.url == 'https://download.kiwix.org/zim/stack_exchange/stackoverflow.com_en_all_'
-def test_get_zim_entry(test_session, test_client, test_zim):
- request, response = test_client.get('/api/zim/1/entry/home')
+@pytest.mark.asyncio
+async def test_get_zim_entry(test_session, async_client, test_zim):
+ request, response = await async_client.get('/api/zim/1/entry/home')
assert response.status_code == HTTPStatus.OK
- request, response = test_client.get('/api/zim/1/entry/does not exist')
+ request, response = await async_client.get('/api/zim/1/entry/does not exist')
assert response.status_code == HTTPStatus.NOT_FOUND
diff --git a/repair.sh b/repair.sh
index 559b077a8..f10fe65bb 100755
--- a/repair.sh
+++ b/repair.sh
@@ -43,6 +43,7 @@ cp /opt/wrolpi/etc/raspberrypios/nginx.conf /etc/nginx/nginx.conf
[ -f /etc/nginx/conf.d/default.conf ] && rm /etc/nginx/conf.d/default.conf
cp /opt/wrolpi/etc/raspberrypios/wrolpi.conf /etc/nginx/conf.d/wrolpi.conf
cp /opt/wrolpi/etc/raspberrypios/50x.html /var/www/50x.html
+cp /opt/wrolpi/etc/raspberrypios/maintenance.html /var/www/maintenance.html
# Generate nginx certificate for HTTPS.
if [[ ! -f /etc/nginx/cert.crt || ! -f /etc/nginx/cert.key ]]; then
diff --git a/requirements.txt b/requirements.txt
index 50f5194f2..d27cf23aa 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -28,4 +28,4 @@ srt==3.5.3
tqdm==4.67.1
vininfo[cli]==1.8.0
webvtt-py==0.5.1
-yt-dlp==2025.09.26
+yt-dlp==2025.11.12
diff --git a/wrolpi/api_utils.py b/wrolpi/api_utils.py
index 78aceec15..a0af9e0ca 100644
--- a/wrolpi/api_utils.py
+++ b/wrolpi/api_utils.py
@@ -1,8 +1,6 @@
import asyncio
import json
import logging
-import multiprocessing
-import sys
from asyncio import CancelledError
from datetime import datetime, timezone, date
from decimal import Decimal
diff --git a/wrolpi/captions.py b/wrolpi/captions.py
index 8fcd784f4..092f12f4c 100755
--- a/wrolpi/captions.py
+++ b/wrolpi/captions.py
@@ -3,7 +3,7 @@
import subprocess
import tempfile
from pathlib import Path
-from typing import Generator, Union, Optional
+from typing import Generator, Union
import srt
import webvtt
diff --git a/wrolpi/cmd.py b/wrolpi/cmd.py
index 15f21623a..bad02dac6 100644
--- a/wrolpi/cmd.py
+++ b/wrolpi/cmd.py
@@ -169,9 +169,11 @@ async def run_command(cmd: tuple[str | pathlib.Path, ...], cwd: pathlib.Path | s
stderr = stderr_file.read_bytes() or b''
# Logs details of the call, but only if it took a long time or TRACE is enabled.
if log_command:
- logger.debug(f'run_command: finished ({elapsed=}s) with stdout={len(stdout)} stderr={len(stderr)}: {cmd_str=}')
+ logger.debug(
+ f'run_command: finished ({elapsed=}s) with stdout={len(stdout)} stderr={len(stderr)}: {cmd_str=}')
elif logger.isEnabledFor(TRACE_LEVEL):
- logger.trace(f'run_command: finished ({elapsed=}s) with stdout={len(stdout)} stderr={len(stderr)}: {cmd_str=}')
+ logger.trace(
+ f'run_command: finished ({elapsed=}s) with stdout={len(stdout)} stderr={len(stderr)}: {cmd_str=}')
return CommandResult(
return_code=proc.returncode,
cancelled=cancelled,
diff --git a/wrolpi/collections/__init__.py b/wrolpi/collections/__init__.py
new file mode 100644
index 000000000..185099ff5
--- /dev/null
+++ b/wrolpi/collections/__init__.py
@@ -0,0 +1,16 @@
+from . import lib
+from .config import (collections_config, CollectionsConfig, save_collections_config)
+from .errors import UnknownCollection
+from .models import Collection, CollectionItem
+from .types import collection_type_registry
+
+__all__ = [
+ 'Collection',
+ 'CollectionItem',
+ 'CollectionsConfig',
+ 'UnknownCollection',
+ 'collection_type_registry',
+ 'collections_config',
+ 'lib',
+ 'save_collections_config',
+]
diff --git a/wrolpi/collections/api.py b/wrolpi/collections/api.py
new file mode 100644
index 000000000..0b0dd3131
--- /dev/null
+++ b/wrolpi/collections/api.py
@@ -0,0 +1,261 @@
+"""
+Collection API Endpoints
+
+Unified REST API for all collection types (domains, channels, playlists, etc.).
+"""
+from http import HTTPStatus
+
+from sanic import Request, Blueprint, response
+from sanic_ext.extensions.openapi import openapi
+
+from wrolpi.api_utils import json_response
+from wrolpi.common import wrol_mode_check
+from wrolpi.schema import JSONErrorResponse
+from . import lib, schema
+from .errors import UnknownCollection
+
+# Create blueprint
+collection_bp = Blueprint('Collection', url_prefix='/api/collections')
+
+
+@collection_bp.get('/')
+@openapi.summary('Get all collections')
+@openapi.parameter('kind', str, 'query', description='Filter by collection kind (e.g., domain, channel)',
+ required=False)
+@openapi.response(HTTPStatus.OK, description="List of collections with metadata")
+async def get_collections_endpoint(request: Request):
+ """
+ Get all collections, optionally filtered by kind.
+
+ This unified endpoint works for all collection types. Use the 'kind' query parameter
+ to filter by collection type (e.g., ?kind=domain or ?kind=channel).
+
+ Examples:
+ GET /api/collections - Get all collections
+ GET /api/collections?kind=domain - Get only domain collections
+ GET /api/collections?kind=channel - Get only channel collections
+
+ Returns:
+ - collections: List of collection objects
+ - totals: Count of collections
+ - metadata: UI metadata (columns, fields, routes, messages) if kind is specified
+ """
+ kind = request.args.get('kind')
+
+ collections = lib.get_collections(kind=kind)
+
+ # Include metadata if a specific kind is requested
+ response_data = {
+ 'collections': collections,
+ 'totals': {'collections': len(collections)}
+ }
+
+ if kind:
+ response_data['metadata'] = lib.get_collection_metadata(kind)
+
+ return json_response(response_data)
+
+
+@collection_bp.get('/')
+@openapi.summary('Get a single collection by ID')
+@openapi.response(HTTPStatus.OK, description="Collection details with statistics")
+@openapi.response(HTTPStatus.NOT_FOUND, JSONErrorResponse)
+async def get_collection_endpoint(_: Request, collection_id: int):
+ """
+ Get details for a single collection including type-specific statistics.
+
+ Returns collection metadata plus statistics specific to the collection type:
+ - Domain collections: includes url_count and size
+ - Channel collections: includes video_count and size
+ - Other types: includes item_count and total_size
+ """
+ try:
+ collection_data = lib.get_collection_with_stats(collection_id)
+ return json_response({'collection': collection_data})
+ except UnknownCollection as e:
+ return json_response({'error': str(e)}, status=HTTPStatus.NOT_FOUND)
+
+
+@collection_bp.put('/')
+@openapi.definition(
+ summary='Update a collection',
+ body=schema.CollectionUpdateRequest,
+ validate=True,
+)
+@openapi.response(HTTPStatus.OK, description="Collection updated successfully")
+@openapi.response(HTTPStatus.NOT_FOUND, JSONErrorResponse)
+@openapi.response(HTTPStatus.BAD_REQUEST, JSONErrorResponse)
+@wrol_mode_check
+async def put_collection_endpoint(_: Request, collection_id: int, body: schema.CollectionUpdateRequest):
+ """
+ Update collection properties (directory, tag, description).
+
+ This endpoint works for all collection types. You can update:
+ - directory: Set or clear the collection's directory restriction
+ - tag_name: Set or clear the collection's tag (requires directory)
+ - description: Set or update the collection's description
+
+ Note: To clear a field, pass an empty string. To leave unchanged, omit the field.
+ """
+ try:
+ collection = lib.update_collection(
+ collection_id=collection_id,
+ directory=body.directory,
+ tag_name=body.tag_name,
+ description=body.description
+ )
+
+ # Return updated collection data
+ collection_data = lib.get_collection_with_stats(collection_id)
+ return json_response({'collection': collection_data})
+
+ except UnknownCollection as e:
+ return json_response({'error': str(e)}, status=HTTPStatus.NOT_FOUND)
+ except Exception as e:
+ return json_response({'error': str(e)}, status=HTTPStatus.BAD_REQUEST)
+
+
+@collection_bp.post('//refresh')
+@openapi.summary('Refresh files in collection directory')
+@openapi.response(HTTPStatus.OK, description="Collection refresh started")
+@openapi.response(HTTPStatus.NOT_FOUND, JSONErrorResponse)
+@openapi.response(HTTPStatus.BAD_REQUEST, JSONErrorResponse)
+@wrol_mode_check
+async def refresh_collection_endpoint(_: Request, collection_id: int):
+ """
+ Queue a refresh for a collection's directory.
+
+ This scans the collection's directory for new or modified files and updates
+ the database accordingly. Only works for directory-restricted collections.
+
+ The refresh happens asynchronously in the background.
+ """
+ try:
+ lib.refresh_collection(collection_id, send_events=True)
+ return json_response({'message': 'Collection refresh started'})
+
+ except UnknownCollection as e:
+ return json_response({'error': str(e)}, status=HTTPStatus.NOT_FOUND)
+ except Exception as e:
+ return json_response({'error': str(e)}, status=HTTPStatus.BAD_REQUEST)
+
+
+@collection_bp.post('//tag')
+@openapi.definition(
+ summary='Tag a collection and optionally move files',
+ body=schema.CollectionTagRequest,
+ validate=True,
+)
+@openapi.response(HTTPStatus.OK, schema.CollectionTagResponse, description="Collection tagged successfully")
+@openapi.response(HTTPStatus.NOT_FOUND, JSONErrorResponse)
+@openapi.response(HTTPStatus.BAD_REQUEST, JSONErrorResponse)
+@wrol_mode_check
+async def tag_collection_endpoint(_: Request, collection_id: int, body: schema.CollectionTagRequest):
+ """
+ Tag a collection and optionally move its files to a new directory.
+
+ Tagging a collection:
+ 1. Creates or assigns the specified tag
+ 2. Optionally moves the collection to a new directory
+ 3. Updates the collection's metadata
+
+ If no directory is specified, the collection must already have one.
+ """
+ try:
+ result = lib.tag_collection(
+ collection_id=collection_id,
+ tag_name=body.tag_name,
+ directory=body.directory
+ )
+ return json_response(result)
+
+ except UnknownCollection as e:
+ return json_response({'error': str(e)}, status=HTTPStatus.NOT_FOUND)
+ except Exception as e:
+ return json_response({'error': str(e)}, status=HTTPStatus.BAD_REQUEST)
+
+
+@collection_bp.post('//tag_info')
+@openapi.definition(
+ summary='Get tag information for a collection',
+ body=schema.CollectionTagInfoRequest,
+ validate=True,
+)
+@openapi.response(HTTPStatus.OK, schema.CollectionTagInfoResponse, description="Tag info retrieved successfully")
+@openapi.response(HTTPStatus.NOT_FOUND, JSONErrorResponse)
+async def get_tag_info_endpoint(_: Request, collection_id: int, body: schema.CollectionTagInfoRequest):
+ """
+ Get information about tagging a collection with a specific tag.
+
+ Returns the suggested directory path and checks for conflicts with existing collections.
+ Domain collections cannot share directories with other domain collections,
+ but can share with channel collections.
+
+ This is useful for showing users the suggested directory before they commit to tagging.
+ """
+ try:
+ tag_info = lib.get_tag_info(
+ collection_id=collection_id,
+ tag_name=body.tag_name
+ )
+ return json_response(tag_info)
+
+ except UnknownCollection as e:
+ return json_response({'error': str(e)}, status=HTTPStatus.NOT_FOUND)
+
+
+@collection_bp.delete('/')
+@openapi.definition(
+ summary='Delete a collection',
+)
+@openapi.response(HTTPStatus.NO_CONTENT, description="Collection deleted successfully")
+@openapi.response(HTTPStatus.NOT_FOUND, JSONErrorResponse)
+@wrol_mode_check
+def delete_collection_endpoint(_: Request, collection_id: int):
+ """
+ Delete a collection and orphan its child items.
+
+ For domain collections:
+ - Orphans child Archives (sets collection_id to NULL)
+ - Deletes the Collection record
+ - Archives remain in the database but are no longer associated with a domain
+ - Triggers domain config save
+ """
+ try:
+ collection = lib.delete_collection(collection_id=collection_id)
+ from wrolpi.events import Events
+ Events.send_deleted(f'Deleted {collection["kind"]} collection: {collection["name"]}')
+ return response.raw('', HTTPStatus.NO_CONTENT)
+
+ except UnknownCollection as e:
+ return json_response({'error': str(e)}, status=HTTPStatus.NOT_FOUND)
+
+
+@collection_bp.post('/search')
+@openapi.definition(
+ summary='Search collections',
+ body=schema.CollectionSearchRequest,
+ validate=True,
+)
+@openapi.response(HTTPStatus.OK, description="Search results")
+async def search_collections_endpoint(_: Request, body: schema.CollectionSearchRequest):
+ """
+ Search collections by kind, tags, and name.
+
+ Supports filtering by:
+ - kind: Collection type (domain, channel, etc.)
+ - tag_names: List of tag names (returns collections with any of these tags)
+ - search_str: Search string for collection names (case-insensitive partial match)
+
+ All filters are optional and can be combined.
+ """
+ collections = lib.search_collections(
+ kind=body.kind,
+ tag_names=body.tag_names if body.tag_names else None,
+ search_str=body.search_str
+ )
+
+ return json_response({
+ 'collections': collections,
+ 'totals': {'collections': len(collections)}
+ })
diff --git a/wrolpi/collections/config.py b/wrolpi/collections/config.py
new file mode 100644
index 000000000..a570a70de
--- /dev/null
+++ b/wrolpi/collections/config.py
@@ -0,0 +1,163 @@
+import pathlib
+from dataclasses import dataclass, field
+from typing import List
+
+from wrolpi.common import ConfigFile, logger
+from wrolpi.db import get_db_session
+from wrolpi.events import Events
+from wrolpi.switches import register_switch_handler, ActivateSwitchMethod
+
+logger = logger.getChild(__name__)
+
+__all__ = ['CollectionsConfig', 'collections_config', 'save_collections_config']
+
+
+@dataclass
+class CollectionsConfigValidator:
+ """Validator for collections config file."""
+ version: int = 0
+ collections: List[dict] = field(default_factory=list)
+
+
+class CollectionsConfig(ConfigFile):
+ """
+ Config file for Collections.
+
+ This config stores Collection metadata (name, description, directory, tag)
+ but NOT the individual items. Items are stored in the database.
+
+ For directory-restricted collections, items are auto-populated from the directory.
+ For unrestricted collections, items are managed manually through the API.
+
+ Format:
+ collections:
+ - name: "My Videos"
+ description: "Favorite videos"
+ directory: "videos/favorites" # optional, enables auto-population
+ tag_name: "favorites" # optional
+ - name: "Reading List"
+ description: "Books to read"
+ # No directory means manual item management
+ """
+ file_name = 'collections.yaml'
+ validator = CollectionsConfigValidator
+ default_config = dict(
+ version=0,
+ collections=[],
+ )
+ # Use wider width to accommodate longer paths
+ width = 120
+
+ def __getitem__(self, item):
+ return self._config[item]
+
+ def __setitem__(self, key, value):
+ self._config[key] = value
+
+ @property
+ def collections(self) -> List[dict]:
+ """Get list of collection configs."""
+ return self._config.get('collections', [])
+
+ def import_config(self, file: pathlib.Path = None, send_events=False):
+ """Import collections from config file into database."""
+ from .models import Collection
+
+ super().import_config(file, send_events)
+
+ file_str = str(self.get_relative_file())
+ collections_data = self._config.get('collections', [])
+
+ if not collections_data:
+ logger.info(f'No collections to import from {file_str}')
+ self.successful_import = True
+ return
+
+ logger.info(f'Importing {len(collections_data)} collections from {file_str}')
+
+ try:
+ with get_db_session(commit=True) as session:
+ # Track which collections were imported
+ imported_dirs = set() # absolute paths for directory-restricted collections
+ imported_pairs = set() # (name, kind) for unrestricted collections
+
+ # Import each collection using from_config
+ for idx, collection_data in enumerate(collections_data):
+ try:
+ name = collection_data.get('name')
+ if not name:
+ logger.error(f'Collection at index {idx} has no name, skipping')
+ continue
+
+ # Use Collection.from_config to create/update
+ # This will auto-populate items if directory exists
+ collection = Collection.from_config(collection_data, session)
+
+ if collection.directory:
+ # Normalize to absolute string for comparison
+ imported_dirs.add(str(collection.directory))
+ else:
+ imported_pairs.add((collection.name, collection.kind))
+
+ except Exception as e:
+ logger.error(f'Failed to import collection at index {idx}', exc_info=e)
+ continue
+
+ # Delete collections that are no longer in config
+ all_collections = session.query(Collection).all()
+ for collection in all_collections:
+ if collection.directory:
+ if str(collection.directory) not in imported_dirs:
+ logger.info(
+ f'Deleting collection {repr(collection.name)} at {collection.directory} (no longer in config)')
+ session.delete(collection)
+ else:
+ if (collection.name, collection.kind) not in imported_pairs:
+ logger.info(
+ f"Deleting unrestricted collection {repr(collection.name)} kind={collection.kind} (no longer in config)")
+ session.delete(collection)
+
+ total_imported = len(imported_dirs) + len(imported_pairs)
+ logger.info(f'Successfully imported {total_imported} collections from {file_str}')
+ self.successful_import = True
+
+ except Exception as e:
+ self.successful_import = False
+ message = f'Failed to import {file_str} config!'
+ logger.error(message, exc_info=e)
+ if send_events:
+ Events.send_config_import_failed(message)
+ raise
+
+ def dump_config(self, file: pathlib.Path = None, send_events=False, overwrite=False):
+ """Dump all collections from database to config file."""
+ from .models import Collection
+
+ logger.info('Dumping collections to config')
+
+ with get_db_session() as session:
+ # Order by name for consistency
+ collections = session.query(Collection).order_by(Collection.name).all()
+
+ # Use to_config to export each collection
+ collections_data = [collection.to_config() for collection in collections]
+
+ self._config['collections'] = collections_data
+
+ logger.info(f'Dumping {len(collections_data)} collections to config')
+ self.save(file, send_events, overwrite)
+
+
+# Global instance
+collections_config = CollectionsConfig()
+
+
+# Switch handler for saving collections config
+@register_switch_handler('save_collections_config')
+def save_collections_config():
+ """Save the collections config when the switch is activated."""
+ collections_config.background_dump.activate_switch()
+
+
+# Explicit type for activate_switch helper
+save_collections_config: ActivateSwitchMethod
diff --git a/wrolpi/collections/errors.py b/wrolpi/collections/errors.py
new file mode 100644
index 000000000..197f5951c
--- /dev/null
+++ b/wrolpi/collections/errors.py
@@ -0,0 +1,9 @@
+from wrolpi.errors import APIError
+
+__all__ = ['UnknownCollection']
+
+
+class UnknownCollection(APIError):
+ """Cannot find Collection"""
+ code = 'UNKNOWN_COLLECTION'
+ summary = 'Cannot find Collection'
diff --git a/wrolpi/collections/lib.py b/wrolpi/collections/lib.py
new file mode 100644
index 000000000..009175988
--- /dev/null
+++ b/wrolpi/collections/lib.py
@@ -0,0 +1,652 @@
+"""
+Collection Library Functions
+
+Provides generic operations for collections of any kind (domains, channels, etc.).
+These functions are used by both the unified collection API and legacy endpoints.
+"""
+import asyncio
+from typing import List, Optional, Dict
+
+from sqlalchemy import func
+from sqlalchemy.orm import Session
+
+from wrolpi.common import logger, get_media_directory, get_relative_to_media_directory
+from wrolpi.db import get_db_session, optional_session
+from wrolpi.errors import ValidationError
+from wrolpi.events import Events
+from wrolpi.tags import Tag
+from .errors import UnknownCollection
+from .models import Collection
+
+logger = logger.getChild(__name__)
+
+__all__ = [
+ 'get_collections',
+ 'get_collection_with_stats',
+ 'get_collection_metadata',
+ 'update_collection',
+ 'refresh_collection',
+ 'tag_collection',
+ 'get_tag_info',
+ 'delete_collection',
+ 'search_collections',
+]
+
+
+@optional_session
+def get_collections(kind: Optional[str] = None, session: Session = None) -> List[dict]:
+ """
+ Get all collections, optionally filtered by kind.
+
+ Args:
+ kind: Optional collection kind to filter by (e.g., 'domain', 'channel')
+ session: Database session
+
+ Returns:
+ List of collection dicts with statistics for each collection type
+ """
+ query = session.query(Collection)
+
+ if kind:
+ query = query.filter(Collection.kind == kind)
+
+ collections = query.order_by(Collection.name).all()
+
+ # Convert to JSON and add type-specific statistics
+ result = []
+ for collection in collections:
+ data = collection.__json__()
+
+ # Compute minimum download frequency from all downloads
+ if collection.downloads:
+ # Filter to recurring downloads (frequency > 0) and get minimum
+ recurring_frequencies = [d.frequency for d in collection.downloads if d.frequency and d.frequency > 0]
+ data['min_download_frequency'] = min(recurring_frequencies) if recurring_frequencies else None
+ else:
+ data['min_download_frequency'] = None
+
+ # Add type-specific statistics
+ if collection.kind == 'domain':
+ # Add archive statistics for domain collections
+ from modules.archive import Archive
+ from wrolpi.files.models import FileGroup
+
+ stats_query = session.query(
+ func.count(Archive.id).label('archive_count'),
+ func.sum(FileGroup.size).label('size')
+ ).outerjoin(
+ FileGroup, FileGroup.id == Archive.file_group_id
+ ).filter(
+ Archive.collection_id == collection.id
+ ).one()
+
+ data['archive_count'] = stats_query.archive_count or 0
+ data['size'] = stats_query.size or 0
+ data['domain'] = data['name'] # Alias for backward compatibility
+
+ elif collection.kind == 'channel':
+ # Add video statistics for channel collections
+ from modules.videos.models import Video, Channel
+ from wrolpi.files.models import FileGroup
+
+ # Get the Channel associated with this collection
+ channel = session.query(Channel).filter(
+ Channel.collection_id == collection.id
+ ).one_or_none()
+
+ if channel:
+ stats_query = session.query(
+ func.count(Video.id).label('video_count'),
+ func.sum(FileGroup.size).label('size')
+ ).outerjoin(
+ FileGroup, FileGroup.id == Video.file_group_id
+ ).filter(
+ Video.channel_id == channel.id
+ ).one()
+
+ data['video_count'] = int(stats_query.video_count or 0)
+ data['total_size'] = int(stats_query.size or 0)
+ data['channel_id'] = channel.id # Include actual Channel ID for frontend links
+ else:
+ data['video_count'] = 0
+ data['total_size'] = 0
+ data['channel_id'] = None
+
+ result.append(data)
+
+ return result
+
+
+@optional_session
+def get_collection_with_stats(collection_id: int, session: Session = None) -> dict:
+ """
+ Get a single collection with type-specific statistics.
+
+ Args:
+ collection_id: The collection ID
+ session: Database session
+
+ Returns:
+ Collection dict with statistics
+
+ Raises:
+ UnknownCollection: If collection not found
+ """
+ collection = session.query(Collection).filter_by(id=collection_id).one_or_none()
+
+ if not collection:
+ raise UnknownCollection(f"Collection with ID {collection_id} not found")
+
+ # Get base collection data
+ data = collection.__json__()
+
+ # Add type-specific statistics
+ # This can be extended for different collection types
+ if collection.kind == 'domain':
+ # Add archive statistics for domain collections
+ from modules.archive import Archive
+ from wrolpi.files.models import FileGroup
+
+ stats_query = session.query(
+ func.count(Archive.id).label('archive_count'),
+ func.sum(FileGroup.size).label('size')
+ ).outerjoin(
+ FileGroup, FileGroup.id == Archive.file_group_id
+ ).filter(
+ Archive.collection_id == collection_id
+ ).one()
+
+ data['archive_count'] = stats_query.archive_count or 0
+ data['size'] = stats_query.size or 0
+ data['domain'] = data['name'] # Alias for backward compatibility
+
+ elif collection.kind == 'channel':
+ # Add video statistics for channel collections
+ from modules.videos.models import Video, Channel
+ from wrolpi.files.models import FileGroup
+
+ # Get the Channel associated with this collection
+ channel = session.query(Channel).filter(
+ Channel.collection_id == collection_id
+ ).one_or_none()
+
+ if channel:
+ stats_query = session.query(
+ func.count(Video.id).label('video_count'),
+ func.sum(FileGroup.size).label('size')
+ ).outerjoin(
+ FileGroup, FileGroup.id == Video.file_group_id
+ ).filter(
+ Video.channel_id == channel.id
+ ).one()
+
+ data['video_count'] = int(stats_query.video_count or 0)
+ data['total_size'] = int(stats_query.size or 0)
+ else:
+ data['video_count'] = 0
+ data['total_size'] = 0
+
+ return data
+
+
+def get_collection_metadata(kind: str) -> dict:
+ """
+ Get metadata for a collection kind.
+
+ Metadata includes:
+ - Column definitions for table display
+ - Field configurations for edit forms
+ - Route information for navigation
+ - User-facing messages
+
+ Args:
+ kind: Collection kind ('domain', 'channel', etc.)
+
+ Returns:
+ Metadata dict with columns, fields, routes, and messages
+ """
+ if kind == 'domain':
+ return {
+ 'kind': 'domain',
+ 'columns': [
+ {'key': 'domain', 'label': 'Domain', 'sortable': True},
+ {'key': 'tag_name', 'label': 'Tag', 'sortable': True},
+ {'key': 'archive_count', 'label': 'Archives', 'sortable': True, 'align': 'right'},
+ {'key': 'min_download_frequency', 'label': 'Download Frequency', 'sortable': True, 'format': 'frequency'},
+ {'key': 'size', 'label': 'Size', 'sortable': True, 'align': 'right', 'format': 'bytes'},
+ {'key': 'actions', 'label': 'Manage', 'sortable': False, 'type': 'actions'}
+ ],
+ 'fields': [
+ {'key': 'directory', 'label': 'Directory', 'type': 'text',
+ 'placeholder': 'Optional directory path'},
+ {'key': 'tag_name', 'label': 'Tag', 'type': 'tag',
+ 'placeholder': 'Select or create tag', 'depends_on': 'directory'},
+ {'key': 'description', 'label': 'Description', 'type': 'textarea',
+ 'placeholder': 'Optional description'}
+ ],
+ 'routes': {
+ 'list': '/archive/domains',
+ 'edit': '/archive/domain/:id/edit',
+ 'search': '/archive',
+ 'searchParam': 'domain'
+ },
+ 'messages': {
+ 'no_directory': 'Set a directory to enable tagging',
+ 'tag_will_move': 'Tagging will move files to a new directory'
+ }
+ }
+ elif kind == 'channel':
+ return {
+ 'kind': 'channel',
+ 'columns': [
+ {'key': 'name', 'label': 'Name', 'sortable': True},
+ {'key': 'tag_name', 'label': 'Tag', 'sortable': True},
+ {'key': 'video_count', 'label': 'Videos', 'sortable': True, 'align': 'right'},
+ {'key': 'min_download_frequency', 'label': 'Download Frequency', 'sortable': True, 'format': 'frequency'},
+ {'key': 'total_size', 'label': 'Size', 'sortable': True, 'align': 'right', 'format': 'bytes'},
+ {'key': 'actions', 'label': 'Manage', 'sortable': False, 'type': 'actions'}
+ ],
+ 'fields': [
+ {'key': 'name', 'label': 'Name', 'type': 'text', 'required': True},
+ {'key': 'url', 'label': 'URL', 'type': 'text', 'placeholder': 'Channel URL'},
+ {'key': 'directory', 'label': 'Directory', 'type': 'text', 'required': True},
+ {'key': 'tag_name', 'label': 'Tag', 'type': 'tag',
+ 'placeholder': 'Select or create tag'},
+ {'key': 'description', 'label': 'Description', 'type': 'textarea',
+ 'placeholder': 'Optional description'}
+ ],
+ 'routes': {
+ 'list': '/videos/channels',
+ 'edit': '/videos/channel/:id/edit',
+ 'search': '/videos/channel/:id/video',
+ 'id_field': 'channel_id' # Use channel_id instead of collection id for URLs
+ },
+ 'messages': {
+ 'no_directory': 'Directory is required for channels',
+ 'tag_will_move': 'Tagging will move files to a new directory'
+ }
+ }
+ else:
+ # Generic metadata for unknown kinds
+ return {
+ 'kind': kind or 'collection',
+ 'columns': [
+ {'key': 'name', 'label': 'Name', 'sortable': True},
+ {'key': 'item_count', 'label': 'Items', 'sortable': True, 'align': 'right'},
+ {'key': 'total_size', 'label': 'Size', 'sortable': True, 'align': 'right', 'format': 'bytes'},
+ {'key': 'tag_name', 'label': 'Tag', 'sortable': True},
+ {'key': 'actions', 'label': 'Manage', 'sortable': False, 'type': 'actions'}
+ ],
+ 'fields': [
+ {'key': 'name', 'label': 'Name', 'type': 'text', 'required': True},
+ {'key': 'directory', 'label': 'Directory', 'type': 'text'},
+ {'key': 'tag_name', 'label': 'Tag', 'type': 'tag'},
+ {'key': 'description', 'label': 'Description', 'type': 'textarea'}
+ ],
+ 'routes': {
+ 'list': '/collections',
+ 'edit': '/collection/:id/edit',
+ 'search': '/search'
+ },
+ 'messages': {
+ 'no_directory': 'Set a directory to enable tagging',
+ 'tag_will_move': 'Tagging will move files to a new directory'
+ }
+ }
+
+
+@optional_session
+def update_collection(
+ collection_id: int,
+ directory: Optional[str] = None,
+ tag_name: Optional[str] = None,
+ description: Optional[str] = None,
+ session: Session = None
+) -> Collection:
+ """
+ Update a collection's properties.
+
+ Args:
+ collection_id: The collection ID
+ directory: New directory (relative or absolute path), or None to clear
+ tag_name: New tag name, empty string to clear, or None to leave unchanged
+ description: New description, or None to leave unchanged
+ session: Database session
+
+ Returns:
+ Updated Collection object
+
+ Raises:
+ UnknownCollection: If collection not found
+ ValidationError: If validation fails
+ """
+ collection = session.query(Collection).filter_by(id=collection_id).one_or_none()
+
+ if not collection:
+ raise UnknownCollection(f"Collection with ID {collection_id} not found")
+
+ # Update directory if provided
+ if directory is not None:
+ if directory:
+ # Convert relative path to absolute with validation
+ from .models import validate_collection_directory
+ directory_path = validate_collection_directory(directory)
+ collection.directory = directory_path
+ else:
+ collection.directory = None
+
+ # Update description if provided
+ if description is not None:
+ collection.description = description
+
+ # Update tag if provided
+ if tag_name is not None:
+ if tag_name:
+ # Set or create tag
+ if not collection.directory:
+ raise ValidationError(
+ f"Collection '{collection.name}' has no directory. "
+ f"Set a directory before tagging."
+ )
+ tag = session.query(Tag).filter_by(name=tag_name).one_or_none()
+ if not tag:
+ tag = Tag(name=tag_name)
+ session.add(tag)
+ session.flush()
+ collection.tag_id = tag.id
+ elif tag_name == '':
+ # Clear tag (empty string explicitly clears)
+ collection.tag_id = None
+
+ session.flush()
+
+ # Trigger domain config save if this is a domain collection
+ if collection.kind == 'domain':
+ from modules.archive.lib import save_domains_config
+ save_domains_config.activate_switch()
+
+ return collection
+
+
+def refresh_collection(collection_id: int, send_events: bool = True) -> None:
+ """
+ Refresh all files in a collection's directory.
+
+ Args:
+ collection_id: The collection ID
+ send_events: Whether to send events about the refresh
+
+ Raises:
+ UnknownCollection: If collection not found
+ ValidationError: If collection has no directory
+ """
+ from wrolpi.files.lib import refresh_files
+
+ with get_db_session() as session:
+ collection = session.query(Collection).filter_by(id=collection_id).one_or_none()
+
+ if not collection:
+ raise UnknownCollection(f"Collection with ID {collection_id} not found")
+
+ if not collection.directory:
+ raise ValidationError(
+ f"Collection '{collection.name}' has no directory. "
+ f"Set a directory before refreshing."
+ )
+
+ directory = collection.directory
+
+ # Refresh files asynchronously
+ asyncio.ensure_future(refresh_files([directory], send_events=send_events))
+
+ if send_events:
+ relative_dir = get_relative_to_media_directory(directory)
+ Events.send_directory_refresh(f'Refreshing: {relative_dir}')
+
+
+@optional_session
+def tag_collection(
+ collection_id: int,
+ tag_name: Optional[str] = None,
+ directory: Optional[str] = None,
+ session: Session = None
+) -> Dict:
+ """
+ Tag a collection and optionally move files to a new directory, or remove a tag if no tag_name is provided.
+
+ Args:
+ collection_id: The collection ID
+ tag_name: Tag name to apply, or None to remove the tag
+ directory: Optional new directory for the collection
+ session: Database session
+
+ Returns:
+ Dict with tag information and suggested directory
+
+ Raises:
+ UnknownCollection: If collection not found
+ ValidationError: If tagging requirements not met
+ """
+ collection = session.query(Collection).filter_by(id=collection_id).one_or_none()
+
+ if not collection:
+ raise UnknownCollection(f"Collection with ID {collection_id} not found")
+
+ # If tag_name is None, remove the tag from the collection
+ if tag_name is None:
+ collection.tag_id = None
+ session.flush()
+
+ # Trigger domain config save if this is a domain collection
+ if collection.kind == 'domain':
+ from modules.archive.lib import save_domains_config
+ save_domains_config.activate_switch()
+
+ # Return info about the un-tagging operation
+ relative_dir = get_relative_to_media_directory(collection.directory) if collection.directory else None
+ return {
+ 'collection_id': collection.id,
+ 'collection_name': collection.name,
+ 'tag_name': None,
+ 'directory': str(relative_dir) if relative_dir else None,
+ 'will_move_files': False,
+ }
+
+ # Get or create the tag
+ tag = session.query(Tag).filter_by(name=tag_name).one_or_none()
+ if not tag:
+ tag = Tag(name=tag_name)
+ session.add(tag)
+ session.flush()
+
+ # Determine target directory
+ if directory:
+ # User specified a directory - validate it
+ from .models import validate_collection_directory
+ target_directory = validate_collection_directory(directory)
+ elif collection.directory:
+ # Use existing directory
+ target_directory = collection.directory
+ else:
+ # Suggest a directory based on collection type and name
+ if collection.kind == 'domain':
+ from modules.archive.lib import get_archive_directory
+ base_dir = get_archive_directory()
+ else:
+ base_dir = get_media_directory()
+
+ target_directory = base_dir / collection.name
+
+ # Apply the tag
+ collection.tag_id = tag.id
+ collection.directory = target_directory
+
+ session.flush()
+
+ # Trigger domain config save if this is a domain collection
+ if collection.kind == 'domain':
+ from modules.archive.lib import save_domains_config
+ save_domains_config.activate_switch()
+
+ # Return info about the tagging operation
+ relative_dir = get_relative_to_media_directory(target_directory)
+ return {
+ 'collection_id': collection.id,
+ 'collection_name': collection.name,
+ 'tag_name': tag_name,
+ 'directory': str(relative_dir),
+ 'will_move_files': collection.directory != target_directory,
+ }
+
+
+@optional_session
+def get_tag_info(
+ collection_id: int,
+ tag_name: Optional[str],
+ session: Session = None
+) -> Dict:
+ """
+ Get information about tagging a collection with a specific tag.
+
+ Returns the suggested directory and checks for conflicts with existing collections.
+
+ Args:
+ collection_id: The collection ID
+ tag_name: Tag name to check
+ session: Database session
+
+ Returns:
+ Dict with suggested_directory, conflict flag, and optional conflict_message
+
+ Raises:
+ UnknownCollection: If collection not found
+ """
+ collection = session.query(Collection).filter_by(id=collection_id).one_or_none()
+
+ if not collection:
+ raise UnknownCollection(f"Collection with ID {collection_id} not found")
+
+ # Use the collection's format_directory method to get the suggested directory
+ suggested_directory = collection.format_directory(tag_name)
+
+ # Check for directory conflicts with other domain collections
+ conflict = False
+ conflict_message = None
+
+ # Only check for conflicts if this is a domain collection
+ if collection.kind == 'domain':
+ # Check if another domain collection already has this directory
+ existing_domain = session.query(Collection).filter(
+ Collection.directory == str(suggested_directory),
+ Collection.kind == 'domain',
+ Collection.id != collection_id
+ ).first()
+
+ if existing_domain:
+ conflict = True
+ conflict_message = (
+ f"A domain collection '{existing_domain.name}' already uses this directory. "
+ f"Choose a different tag or directory."
+ )
+
+ # Return relative path for the frontend
+ relative_dir = get_relative_to_media_directory(suggested_directory)
+
+ return {
+ 'suggested_directory': str(relative_dir),
+ 'conflict': conflict,
+ 'conflict_message': conflict_message,
+ }
+
+
+@optional_session
+def delete_collection(
+ collection_id: int,
+ session: Session = None
+) -> Dict:
+ """
+ Delete a collection and orphan its child items.
+
+ For domain collections:
+ - Orphans child Archives (sets collection_id to NULL)
+ - Deletes the Collection record
+ - Triggers domain config save
+
+ Args:
+ collection_id: The collection ID to delete
+ session: Database session
+
+ Returns:
+ Dict with collection information
+
+ Raises:
+ UnknownCollection: If collection not found
+ """
+ collection = session.query(Collection).filter_by(id=collection_id).one_or_none()
+
+ if not collection:
+ raise UnknownCollection(f"Collection with ID {collection_id} not found")
+
+ collection_dict = {
+ 'id': collection.id,
+ 'name': collection.name,
+ 'kind': collection.kind,
+ }
+
+ # Orphan child Archives if this is a domain collection
+ if collection.kind == 'domain':
+ from modules.archive.models import Archive
+ archives = session.query(Archive).filter_by(collection_id=collection_id).all()
+ for archive in archives:
+ archive.collection_id = None
+ session.flush()
+
+ # Trigger domain config save
+ from modules.archive.lib import save_domains_config
+ save_domains_config.activate_switch()
+
+ # Delete the collection
+ session.delete(collection)
+ session.flush()
+
+ return collection_dict
+
+
+@optional_session
+def search_collections(
+ kind: Optional[str] = None,
+ tag_names: Optional[List[str]] = None,
+ search_str: Optional[str] = None,
+ session: Session = None
+) -> List[dict]:
+ """
+ Search collections by kind, tags, and search string.
+
+ Args:
+ kind: Optional collection kind filter
+ tag_names: Optional list of tag names to filter by
+ search_str: Optional search string for collection names
+ session: Database session
+
+ Returns:
+ List of matching collection dicts
+ """
+ query = session.query(Collection)
+
+ # Filter by kind
+ if kind:
+ query = query.filter(Collection.kind == kind)
+
+ # Filter by tags
+ if tag_names:
+ query = query.join(Tag).filter(Tag.name.in_(tag_names))
+
+ # Filter by search string
+ if search_str:
+ query = query.filter(Collection.name.ilike(f'%{search_str}%'))
+
+ # Order by name
+ query = query.order_by(Collection.name)
+
+ collections = query.all()
+ return [collection.__json__() for collection in collections]
diff --git a/wrolpi/collections/models.py b/wrolpi/collections/models.py
new file mode 100644
index 000000000..31e17a515
--- /dev/null
+++ b/wrolpi/collections/models.py
@@ -0,0 +1,814 @@
+import pathlib
+from typing import Optional, List
+
+from sqlalchemy import Column, Integer, String, ForeignKey, Text, DateTime, Index, UniqueConstraint, func
+from sqlalchemy.orm import relationship, Session
+from sqlalchemy.orm.collections import InstrumentedList
+
+from wrolpi.common import Base, ModelHelper, logger, get_media_directory
+from wrolpi.db import optional_session
+from wrolpi.files.models import FileGroup
+from wrolpi.media_path import MediaPathType
+from wrolpi.tags import Tag
+from .errors import UnknownCollection
+from .types import collection_type_registry
+
+logger = logger.getChild(__name__)
+
+__all__ = ['Collection', 'CollectionItem']
+
+
+def validate_collection_directory(directory: pathlib.Path) -> pathlib.Path:
+ """Validate and normalize a collection directory path.
+
+ Args:
+ directory: Directory path (relative or absolute)
+
+ Returns:
+ Normalized absolute path under media directory
+
+ Raises:
+ ValidationError: If absolute path is outside media directory
+ """
+ from wrolpi.errors import ValidationError
+
+ media_directory = get_media_directory()
+ directory = pathlib.Path(directory)
+
+ if not directory.is_absolute():
+ # Relative path - make absolute under media directory
+ directory = media_directory / directory
+ else:
+ # Absolute path - must be under media directory
+ try:
+ directory.relative_to(media_directory)
+ except ValueError:
+ raise ValidationError(
+ f"Collection directory must be under media directory {media_directory}, "
+ f"but got {directory}"
+ )
+
+ return directory
+
+
+class Collection(ModelHelper, Base):
+ """
+ A Collection is a grouping of FileGroups (videos, archives, ebooks, etc).
+
+ Collections can be:
+ - Directory-restricted: Only contains files within a specific directory tree
+ - Unrestricted: Contains files from anywhere in the media library
+
+ Collections maintain order through the CollectionItem junction table.
+ """
+ __tablename__ = 'collection'
+
+ __table_args__ = (
+ UniqueConstraint('directory', name='uq_collection_directory'),
+ UniqueConstraint('name', 'kind', name='uq_collection_name_kind'),
+ Index('idx_collection_kind', 'kind'),
+ )
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String, nullable=False)
+ description = Column(Text)
+
+ # Type of collection to keep separation of kinds (e.g., videos, archives, books, etc.)
+ kind = Column(String, nullable=False, server_default='channel', default='channel')
+
+ # Optional directory restriction - if set, only files in this directory tree can be added
+ directory: Optional[pathlib.Path] = Column(MediaPathType, nullable=True)
+
+ # Optional tag relationship (similar to Channel)
+ tag_id = Column(Integer, ForeignKey('tag.id', ondelete='CASCADE'))
+ tag = relationship('Tag', primaryjoin='Collection.tag_id==Tag.id')
+
+ created_date = Column(DateTime, server_default=func.now())
+
+ # Columns updated by triggers (similar to Channel)
+ item_count = Column(Integer, default=0)
+ total_size = Column(Integer, default=0)
+
+ # Relationship to items (ordered)
+ items: InstrumentedList = relationship(
+ 'CollectionItem',
+ primaryjoin='Collection.id==CollectionItem.collection_id',
+ back_populates='collection',
+ order_by='CollectionItem.position',
+ cascade='all, delete-orphan'
+ )
+
+ # Relationship to downloads
+ downloads: InstrumentedList = relationship(
+ 'Download',
+ primaryjoin='Download.collection_id==Collection.id',
+ back_populates='collection'
+ )
+
+ def __repr__(self):
+ return f''
+
+ @property
+ def tag_name(self) -> str | None:
+ return self.tag.name if self.tag else None
+
+ @property
+ def location(self) -> str:
+ """The href of the collection in the App."""
+ return f'/collections/{self.id}'
+
+ @property
+ def is_directory_restricted(self) -> bool:
+ """Returns True if this collection is restricted to a specific directory."""
+ return self.directory is not None
+
+ @staticmethod
+ def is_valid_domain_name(name: str) -> bool:
+ """
+ Validate that a name is a valid domain format.
+
+ A valid domain must:
+ - Be a string
+ - Contain at least one "." (e.g., "example.com")
+ - Not start or end with "."
+
+ Examples:
+ Valid: "example.com", "sub.example.com", "a.b.c"
+ Invalid: "example", "example.", ".example", "."
+
+ Args:
+ name: The domain name to validate
+
+ Returns:
+ True if valid domain format, False otherwise
+ """
+ # Use the registry for validation
+ return collection_type_registry.validate('domain', name)
+
+ def to_config(self) -> dict:
+ """
+ Export this Collection's metadata to a config dict.
+ Only includes the minimum data necessary for reconstruction.
+ """
+ config = {
+ 'name': self.name,
+ 'kind': self.kind,
+ }
+
+ if self.description:
+ config['description'] = self.description
+
+ if self.directory:
+ # Store absolute path for consistency with Channel config
+ config['directory'] = str(self.directory)
+
+ if self.tag_name:
+ config['tag_name'] = self.tag_name
+
+ # Include downloads if any exist
+ if self.downloads:
+ config['downloads'] = [
+ {'url': d.url, 'frequency': d.frequency}
+ for d in self.downloads
+ ]
+
+ return config
+
+ @optional_session
+ def get_or_create_download(self, url: str, frequency: int, session: Session = None,
+ reset_attempts: bool = False,
+ downloader_name: str = None,
+ sub_downloader_name: str = None) -> 'Download':
+ """Get or create a Download for this Collection.
+
+ Args:
+ url: The URL to download
+ frequency: Seconds between re-downloading
+ session: Database session
+ reset_attempts: Reset download attempts if True
+ downloader_name: Name of the downloader to use
+ sub_downloader_name: Name of the sub-downloader for child downloads
+
+ Returns:
+ The Download instance
+
+ Raises:
+ InvalidDownload: If url or frequency is missing
+ """
+ from wrolpi.downloader import Download, download_manager
+ from wrolpi.errors import InvalidDownload
+
+ if not url:
+ raise InvalidDownload('Cannot get Download without url')
+ if not frequency:
+ raise InvalidDownload('Download for Collection must have a frequency')
+
+ download = Download.get_by_url(url, session=session)
+ if not download:
+ download = download_manager.recurring_download(
+ url, frequency, downloader_name, session=session,
+ sub_downloader_name=sub_downloader_name,
+ destination=str(self.directory) if self.directory else None,
+ reset_attempts=reset_attempts,
+ )
+ if reset_attempts:
+ download.attempts = 0
+ download.collection_id = self.id
+ return download
+
+ @staticmethod
+ @optional_session
+ def from_config(data: dict, session: Session = None) -> 'Collection':
+ """
+ Create or update a Collection from config data.
+
+ If the collection has a directory, auto-populate with FileGroups from that directory.
+ If the collection has no directory, items are managed manually (not from config).
+
+ Args:
+ data: Config dict containing collection metadata
+ session: Database session
+
+ Returns:
+ The created or updated Collection
+
+ Raises:
+ ValueError: If required fields are missing
+ """
+ name = data.get('name')
+ if not name:
+ raise ValueError('Collection config must have a name')
+
+ description = data.get('description')
+ directory = data.get('directory')
+ tag_name = data.get('tag_name')
+ kind = (data.get('kind') or 'channel').strip()
+ # Validate known kinds (forward-compatible: allow future values)
+ if kind not in {'channel', 'domain'}:
+ logger.warning(f"Unknown collection kind '{kind}', defaulting to 'channel'")
+ kind = 'channel'
+
+ # Validate collection name using the type registry
+ if not collection_type_registry.validate(kind, name):
+ description = collection_type_registry.get_description(kind) or "Invalid name format"
+ raise ValueError(f'Invalid {kind} name for collection: {repr(name)}. {description}')
+
+ # Convert directory to absolute path if provided, with validation
+ if directory:
+ directory = validate_collection_directory(directory)
+
+ # Try to find existing collection:
+ # - If directory is provided, resolve by unique directory
+ # - Otherwise, fall back to first match by (name, kind)
+ if directory:
+ collection = Collection.get_by_path(directory, session)
+ else:
+ collection = session.query(Collection).filter_by(name=name, kind=kind).first()
+
+ if collection:
+ # Update existing collection
+ logger.debug(f'Updating collection from config: {name}')
+ collection.description = description
+ collection.directory = directory
+ collection.kind = kind
+
+ if tag_name:
+ tag = Tag.get_by_name(tag_name, session)
+ if tag:
+ collection.tag = tag
+ else:
+ logger.warning(f'Tag {repr(tag_name)} not found for collection {repr(name)}')
+ collection.tag = None
+ else:
+ collection.tag = None
+ else:
+ # Create new collection
+ logger.info(f'Creating new collection from config: {name}')
+ collection = Collection(
+ name=name,
+ description=description,
+ directory=directory,
+ kind=kind,
+ )
+
+ if tag_name:
+ tag = Tag.get_by_name(tag_name, session)
+ if tag:
+ collection.tag = tag
+ else:
+ logger.warning(f'Tag {repr(tag_name)} not found for collection {repr(name)}')
+
+ session.add(collection)
+
+ session.flush([collection])
+
+ # If directory-restricted, populate with FileGroups from that directory
+ if collection.directory and collection.directory.is_dir():
+ logger.info(f'Populating collection {repr(name)} from directory {collection.directory}')
+ collection.populate_from_directory(session=session)
+ elif collection.directory and not collection.directory.is_dir():
+ logger.warning(f'Collection directory does not exist: {collection.directory}')
+
+ return collection
+
+ def populate_from_directory(self, session: Session = None):
+ """
+ Populate this collection with all FileGroups in the collection's directory.
+ Only works for directory-restricted collections.
+ """
+ session = session or Session.object_session(self)
+
+ if not self.is_directory_restricted:
+ logger.warning(f'Cannot populate unrestricted collection from directory: {self}')
+ return
+
+ if not self.directory.is_dir():
+ logger.warning(f'Cannot populate collection, directory does not exist: {self.directory}')
+ return
+
+ # Find all FileGroups in this directory tree
+ file_groups = session.query(FileGroup).filter(
+ FileGroup.primary_path.like(f'{self.directory}/%')
+ ).all()
+
+ if not file_groups:
+ logger.debug(f'No FileGroups found in {self.directory}')
+ return
+
+ # Get existing items to avoid duplicates
+ existing_fg_ids = {item.file_group_id for item in self.items}
+
+ # Add new FileGroups
+ new_file_groups = [fg for fg in file_groups if fg.id not in existing_fg_ids]
+ if new_file_groups:
+ self.add_file_groups(new_file_groups, session=session)
+ logger.info(f'Added {len(new_file_groups)} FileGroups to collection {repr(self.name)}')
+
+ def validate_file_group(self, file_group: FileGroup) -> bool:
+ """
+ Check if a FileGroup can be added to this Collection.
+
+ Returns True if the FileGroup is valid for this collection:
+ - If directory-restricted, the file must be in the directory tree
+ - Otherwise, any file is valid
+ """
+ if not self.is_directory_restricted:
+ return True
+
+ # Check if the file_group's primary_path is within the collection's directory
+ try:
+ file_path = file_group.primary_path
+ if file_path:
+ # Check if file is in directory tree
+ file_path.relative_to(self.directory)
+ return True
+ except (ValueError, AttributeError):
+ # relative_to raises ValueError if path is not relative
+ return False
+
+ return False
+
+ def add_file_group(self, file_group: FileGroup, position: Optional[int] = None,
+ session: Session = None) -> 'CollectionItem':
+ """
+ Add a FileGroup to this Collection.
+
+ Args:
+ file_group: The FileGroup to add
+ position: Optional position in the collection (None = append to end)
+ session: Database session
+
+ Returns:
+ The created CollectionItem
+
+ Raises:
+ ValueError: If the file_group cannot be added to this collection
+ """
+ session = session or Session.object_session(self)
+
+ if not self.validate_file_group(file_group):
+ raise ValueError(
+ f'FileGroup {file_group.id} at {file_group.primary_path} '
+ f'cannot be added to collection "{self.name}" (directory restriction: {self.directory})'
+ )
+
+ # Check if already exists
+ existing = session.query(CollectionItem).filter_by(
+ collection_id=self.id,
+ file_group_id=file_group.id
+ ).first()
+
+ if existing:
+ logger.warning(f'FileGroup {file_group.id} already in Collection {self.id}')
+ return existing
+
+ # Determine position
+ if position is None:
+ # Append to end
+ max_position = session.query(func.max(CollectionItem.position)).filter_by(
+ collection_id=self.id
+ ).scalar()
+ position = (max_position or 0) + 1
+ else:
+ # Insert at specific position - need to shift existing items
+ self._shift_positions(position, shift_by=1, session=session)
+
+ item = CollectionItem(
+ collection_id=self.id,
+ file_group_id=file_group.id,
+ position=position
+ )
+ session.add(item)
+ session.flush([item])
+
+ return item
+
+ def add_file_groups(self, file_groups: List[FileGroup], session: Session = None) -> List['CollectionItem']:
+ """
+ Add multiple FileGroups to this Collection in batch.
+ More efficient than calling add_file_group in a loop.
+
+ Args:
+ file_groups: List of FileGroups to add
+ session: Database session
+
+ Returns:
+ List of created CollectionItems
+ """
+ session = session or Session.object_session(self)
+
+ # Validate all file groups first
+ for fg in file_groups:
+ if not self.validate_file_group(fg):
+ raise ValueError(
+ f'FileGroup {fg.id} at {fg.primary_path} '
+ f'cannot be added to collection "{self.name}" (directory restriction: {self.directory})'
+ )
+
+ # Get existing items to avoid duplicates
+ existing_fg_ids = {item.file_group_id for item in self.items}
+ new_file_groups = [fg for fg in file_groups if fg.id not in existing_fg_ids]
+
+ if not new_file_groups:
+ logger.debug(f'All FileGroups already in Collection {self.id}')
+ return []
+
+ # Get starting position
+ max_position = session.query(func.max(CollectionItem.position)).filter_by(
+ collection_id=self.id
+ ).scalar()
+ position = (max_position or 0) + 1
+
+ # Create items in batch
+ items = []
+ for fg in new_file_groups:
+ item = CollectionItem(
+ collection_id=self.id,
+ file_group_id=fg.id,
+ position=position
+ )
+ items.append(item)
+ position += 1
+
+ session.add_all(items)
+ session.flush(items)
+
+ logger.debug(f'Added {len(items)} FileGroups to Collection {self.id}')
+ return items
+
+ def remove_file_group(self, file_group_id: int, session: Session = None):
+ """Remove a FileGroup from this Collection."""
+ session = session or Session.object_session(self)
+
+ item = session.query(CollectionItem).filter_by(
+ collection_id=self.id,
+ file_group_id=file_group_id
+ ).first()
+
+ if item:
+ position = item.position
+ session.delete(item)
+ session.flush()
+
+ # Shift remaining items down
+ self._shift_positions(position + 1, shift_by=-1, session=session)
+
+ def remove_file_groups(self, file_group_ids: List[int], session: Session = None):
+ """
+ Remove multiple FileGroups from this Collection in batch.
+ More efficient than calling remove_file_group in a loop.
+
+ Args:
+ file_group_ids: List of FileGroup IDs to remove
+ session: Database session
+ """
+ session = session or Session.object_session(self)
+
+ if not file_group_ids:
+ return
+
+ # Delete all items at once
+ deleted = session.query(CollectionItem).filter(
+ CollectionItem.collection_id == self.id,
+ CollectionItem.file_group_id.in_(file_group_ids)
+ ).delete(synchronize_session=False)
+
+ logger.debug(f'Removed {deleted} FileGroups from Collection {self.id}')
+
+ # Resequence positions to close gaps (optional, but keeps positions clean)
+ items = session.query(CollectionItem).filter_by(
+ collection_id=self.id
+ ).order_by(CollectionItem.position).all()
+
+ for idx, item in enumerate(items, start=1):
+ item.position = idx
+
+ session.flush(items)
+
+ def reorder_item(self, file_group_id: int, new_position: int, session: Session = None):
+ """Move an item to a new position in the collection."""
+ session = session or Session.object_session(self)
+
+ item = session.query(CollectionItem).filter_by(
+ collection_id=self.id,
+ file_group_id=file_group_id
+ ).first()
+
+ if not item:
+ raise ValueError(f'FileGroup {file_group_id} not in Collection {self.id}')
+
+ old_position = item.position
+
+ if old_position == new_position:
+ return
+
+ # Remove item from old position
+ session.query(CollectionItem).filter(
+ CollectionItem.collection_id == self.id,
+ CollectionItem.position > old_position
+ ).update({'position': CollectionItem.position - 1})
+
+ # Make space at new position
+ session.query(CollectionItem).filter(
+ CollectionItem.collection_id == self.id,
+ CollectionItem.position >= new_position
+ ).update({'position': CollectionItem.position + 1})
+
+ # Update item position
+ item.position = new_position
+ session.flush()
+
+ def _shift_positions(self, from_position: int, shift_by: int, session: Session):
+ """Helper to shift positions of items."""
+ session.query(CollectionItem).filter(
+ CollectionItem.collection_id == self.id,
+ CollectionItem.position >= from_position
+ ).update({'position': CollectionItem.position + shift_by})
+ session.flush()
+
+ def get_items(self, limit: Optional[int] = None, offset: int = 0,
+ session: Session = None) -> List['CollectionItem']:
+ """Get items in this collection, ordered by position."""
+ session = session or Session.object_session(self)
+
+ query = session.query(CollectionItem).filter_by(
+ collection_id=self.id
+ ).order_by(CollectionItem.position)
+
+ if offset:
+ query = query.offset(offset)
+ if limit:
+ query = query.limit(limit)
+
+ return query.all()
+
+ @staticmethod
+ @optional_session
+ def get_by_id(id_: int, session: Session = None) -> Optional['Collection']:
+ """Attempt to find a Collection with the provided id. Returns None if not found."""
+ return session.query(Collection).filter_by(id=id_).one_or_none()
+
+ @staticmethod
+ @optional_session
+ def find_by_id(id_: int, session: Session = None) -> 'Collection':
+ """Find a Collection with the provided id, raises an exception if not found.
+
+ @raise UnknownCollection: if the collection cannot be found
+ """
+ if collection := Collection.get_by_id(id_, session=session):
+ return collection
+ raise UnknownCollection(f'Cannot find Collection with id {id_}')
+
+ @staticmethod
+ @optional_session
+ def get_by_path(path: pathlib.Path, session: Session = None) -> Optional['Collection']:
+ """Find a Collection by its directory path. Returns None if not found."""
+ if not path:
+ return None
+ path = pathlib.Path(path) if isinstance(path, str) else path
+ path = str(path.absolute()) if path.is_absolute() else str(get_media_directory() / path)
+ return session.query(Collection).filter_by(directory=path).one_or_none()
+
+ @staticmethod
+ @optional_session
+ def find_by_path(path: pathlib.Path, session: Session = None) -> 'Collection':
+ """Find a Collection by its directory path, raises an exception if not found.
+
+ @raise UnknownCollection: if the collection cannot be found
+ """
+ if collection := Collection.get_by_path(path, session=session):
+ return collection
+ raise UnknownCollection(f'Cannot find Collection with directory {path}')
+
+ def dict(self) -> dict:
+ """Return dictionary representation."""
+ d = super(Collection, self).dict()
+ d['tag_name'] = self.tag_name
+ # Directory may be outside media root; mirror to_config behavior
+ if self.directory:
+ try:
+ d['directory'] = self.directory.relative_to(get_media_directory())
+ except ValueError:
+ d['directory'] = self.directory
+ else:
+ d['directory'] = None
+ d['is_directory_restricted'] = self.is_directory_restricted
+ d['item_count'] = self.item_count
+ d['total_size'] = self.total_size
+ d['kind'] = self.kind
+ return d
+
+ # --- Tagging and moving helpers ---
+ def set_tag(self, tag_id_or_name: int | str | None) -> Optional[Tag]:
+ """Assign or clear the Tag on this Collection. Mirrors Channel.set_tag semantics.
+
+ Raises:
+ ValueError: If attempting to tag a domain collection without a directory.
+ """
+ # Validate domain collections can only be tagged if they have a directory
+ if self.kind == 'domain' and self.directory is None and tag_id_or_name is not None:
+ raise ValueError(
+ f"Cannot tag domain collection '{self.name}' without a directory. "
+ "Set collection.directory first or use an unrestricted domain collection."
+ )
+
+ session = Session.object_session(self)
+ if tag_id_or_name is None:
+ self.tag = None
+ elif isinstance(tag_id_or_name, int):
+ self.tag = Tag.find_by_id(tag_id_or_name, session)
+ elif isinstance(tag_id_or_name, str):
+ self.tag = Tag.find_by_name(tag_id_or_name, session)
+ self.tag_id = self.tag.id if self.tag else None
+ return self.tag
+
+ @property
+ def can_be_tagged(self) -> bool:
+ """Check if this Collection can be tagged.
+
+ Domain collections require a directory to be tagged.
+ Other collection types can always be tagged.
+
+ Returns:
+ True if this collection can be tagged, False otherwise.
+ """
+ if self.kind == 'domain':
+ return self.directory is not None
+ return True
+
+ def __json__(self) -> dict:
+ """Return JSON-serializable dict for API responses.
+
+ Returns directory as a string relative to media directory if applicable.
+ """
+ from wrolpi.common import get_relative_to_media_directory
+
+ # Get directory as string, relative to media directory
+ directory_str = None
+ if self.directory:
+ try:
+ directory_str = str(get_relative_to_media_directory(self.directory))
+ except ValueError:
+ # Directory is not relative to media directory, return as-is
+ directory_str = str(self.directory)
+
+ return {
+ 'id': self.id,
+ 'name': self.name,
+ 'kind': self.kind,
+ 'directory': directory_str,
+ 'description': self.description,
+ 'tag_name': self.tag.name if self.tag else None,
+ 'can_be_tagged': self.can_be_tagged,
+ 'item_count': self.item_count,
+ 'total_size': self.total_size,
+ 'downloads': self.downloads,
+ }
+
+ def format_directory(self, tag_name: Optional[str]) -> pathlib.Path:
+ """Compute the destination directory for this Collection given a tag name.
+
+ For video-style collections (kind == 'channel'), reuse videos destination formatting:
+ videos//
+
+ For domain collections (kind == 'domain'), use the configured archive directory:
+ archive//
+ """
+ if self.kind == 'channel':
+ from modules.videos.lib import format_videos_destination
+ return format_videos_destination(self.name, tag_name, None)
+ elif self.kind == 'domain':
+ # Use configured archive directory for domain collections
+ from modules.archive.lib import get_archive_directory
+ base = get_archive_directory()
+ if tag_name:
+ return base / tag_name / self.name
+ return base / self.name
+ # Default behavior: place under media root by kind/tag/name
+ base = get_media_directory() / self.kind
+ if tag_name:
+ return base / tag_name / self.name
+ return base / self.name
+
+ async def move_collection(self, directory: pathlib.Path, session: Session, send_events: bool = False):
+ """Move all files under this Collection's directory to a new directory.
+
+ Similar to Channel.move_channel but without download management.
+ """
+ from wrolpi.files.lib import move as move_files, refresh_files
+ from wrolpi.events import Events
+ from wrolpi import flags
+
+ if not directory.is_dir():
+ raise FileNotFoundError(f'Destination directory does not exist: {directory}')
+
+ if not self.directory:
+ raise RuntimeError('Cannot move an unrestricted Collection (no directory set)')
+
+ old_directory = self.directory
+ self.directory = directory
+
+ with flags.refreshing:
+ session.commit()
+ # Move the contents of the Collection directory into the destination directory.
+ logger.info(f'Moving Collection {repr(self.name)} from {repr(str(old_directory))}')
+ try:
+ if not old_directory.exists():
+ # Old directory does not exist; refresh both
+ await refresh_files([old_directory, directory])
+ if send_events:
+ Events.send_file_move_completed(f'Collection {repr(self.name)} was moved (directory missing)')
+ else:
+ await move_files(directory, *list(old_directory.iterdir()), session=session)
+ if send_events:
+ Events.send_file_move_completed(f'Collection {repr(self.name)} was moved')
+ except Exception as e:
+ logger.error(f'Collection move failed! Reverting changes...', exc_info=e)
+ self.directory = old_directory
+ self.flush(session)
+ if send_events:
+ Events.send_file_move_failed(f'Moving Collection {self.name} has failed')
+ raise
+ finally:
+ session.commit()
+ if old_directory.exists() and not next(iter(old_directory.iterdir()), None):
+ old_directory.rmdir()
+
+
+class CollectionItem(ModelHelper, Base):
+ """
+ Junction table between Collection and FileGroup.
+ Maintains ordering and metadata about the relationship.
+ """
+ __tablename__ = 'collection_item'
+
+ id = Column(Integer, primary_key=True)
+ collection_id = Column(Integer, ForeignKey('collection.id', ondelete='CASCADE'), nullable=False)
+ file_group_id = Column(Integer, ForeignKey('file_group.id', ondelete='CASCADE'), nullable=False)
+
+ # Position in the collection (for ordering)
+ position = Column(Integer, nullable=False, default=0)
+
+ # When this item was added
+ added_date = Column(DateTime, server_default=func.now())
+
+ # Relationships
+ collection = relationship('Collection', back_populates='items')
+ file_group: FileGroup = relationship('FileGroup')
+
+ # Indexes for performance
+ __table_args__ = (
+ Index('idx_collection_item_collection_position', 'collection_id', 'position'),
+ Index('idx_collection_item_file_group', 'file_group_id'),
+ UniqueConstraint('collection_id', 'file_group_id', name='uq_collection_file_group'),
+ )
+
+ def __repr__(self):
+ return f''
+
+ def dict(self) -> dict:
+ """Return dictionary representation."""
+ d = super(CollectionItem, self).dict()
+ if self.file_group:
+ d['file_group'] = self.file_group.__json__()
+ return d
diff --git a/wrolpi/collections/schema.py b/wrolpi/collections/schema.py
new file mode 100644
index 000000000..5114b176f
--- /dev/null
+++ b/wrolpi/collections/schema.py
@@ -0,0 +1,54 @@
+"""
+Collection API Schemas
+
+Request and response schemas for the unified collection API endpoints.
+"""
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+
+@dataclass
+class CollectionUpdateRequest:
+ """Request body for updating a collection."""
+ directory: Optional[str] = None
+ tag_name: Optional[str] = None
+ description: Optional[str] = None
+
+
+@dataclass
+class CollectionTagRequest:
+ """Request body for tagging a collection."""
+ tag_name: Optional[str] = None
+ directory: Optional[str] = None
+
+
+@dataclass
+class CollectionSearchRequest:
+ """Request body for searching collections."""
+ kind: Optional[str] = None
+ tag_names: List[str] = field(default_factory=list)
+ search_str: Optional[str] = None
+
+
+@dataclass
+class CollectionTagResponse:
+ """Response for tagging operation."""
+ collection_id: int
+ collection_name: str
+ tag_name: str
+ directory: str
+ will_move_files: bool
+
+
+@dataclass
+class CollectionTagInfoRequest:
+ """Request body for getting tag info."""
+ tag_name: Optional[str] = None
+
+
+@dataclass
+class CollectionTagInfoResponse:
+ """Response for tag info operation."""
+ suggested_directory: str
+ conflict: bool
+ conflict_message: Optional[str] = None
diff --git a/wrolpi/collections/test/test_api.py b/wrolpi/collections/test/test_api.py
new file mode 100644
index 000000000..9c1012b25
--- /dev/null
+++ b/wrolpi/collections/test/test_api.py
@@ -0,0 +1,485 @@
+"""Tests for Collection API endpoints."""
+from http import HTTPStatus
+
+import pytest
+
+from wrolpi.collections.models import Collection
+
+
+class TestCollectionsAPI:
+ """Test the GET /api/collections endpoint."""
+
+ @pytest.mark.asyncio
+ async def test_get_collections_returns_metadata_for_domains(
+ self, async_client, test_session
+ ):
+ """Test that GET /api/collections?kind=domain returns metadata."""
+ # Create a test domain collection
+ collection = Collection(name='example.com', kind='domain')
+ test_session.add(collection)
+ test_session.commit()
+
+ # Make the request
+ request, response = await async_client.get('/api/collections?kind=domain')
+
+ # Check response
+ assert response.status_code == HTTPStatus.OK
+ data = response.json
+
+ # Check basic response structure
+ assert 'collections' in data
+ assert 'totals' in data
+ assert 'metadata' in data
+
+ # Check collections
+ assert len(data['collections']) == 1
+ assert data['collections'][0]['name'] == 'example.com'
+ assert data['collections'][0]['kind'] == 'domain'
+
+ # Check totals
+ assert data['totals']['collections'] == 1
+
+ # Check metadata structure
+ metadata = data['metadata']
+ assert metadata['kind'] == 'domain'
+ assert 'columns' in metadata
+ assert 'fields' in metadata
+ assert 'routes' in metadata
+ assert 'messages' in metadata
+
+ # Check metadata columns
+ columns = metadata['columns']
+ assert len(columns) > 0
+ column_keys = [col['key'] for col in columns]
+ assert 'domain' in column_keys
+ assert 'archive_count' in column_keys
+ assert 'size' in column_keys
+ assert 'tag_name' in column_keys
+ assert 'actions' in column_keys
+
+ # Check metadata fields
+ fields = metadata['fields']
+ assert len(fields) > 0
+ field_keys = [field['key'] for field in fields]
+ assert 'directory' in field_keys
+ assert 'tag_name' in field_keys
+ assert 'description' in field_keys
+
+ # Check metadata routes
+ routes = metadata['routes']
+ assert routes['list'] == '/archive/domains'
+ assert routes['edit'] == '/archive/domain/:id/edit'
+ assert routes['search'] == '/archive'
+
+ # Check metadata messages
+ messages = metadata['messages']
+ assert 'no_directory' in messages
+ assert 'tag_will_move' in messages
+
+ @pytest.mark.asyncio
+ async def test_get_collections_returns_metadata_for_channels(
+ self, async_client, test_session
+ ):
+ """Test that GET /api/collections?kind=channel returns channel metadata."""
+ from pathlib import Path
+
+ # Create a test channel collection
+ collection = Collection(
+ name='Test Channel',
+ kind='channel',
+ directory=Path('/media/wrolpi/videos/test')
+ )
+ test_session.add(collection)
+ test_session.commit()
+
+ # Make the request
+ request, response = await async_client.get('/api/collections?kind=channel')
+
+ # Check response
+ assert response.status_code == HTTPStatus.OK
+ data = response.json
+
+ # Check metadata structure
+ assert 'metadata' in data
+ metadata = data['metadata']
+ assert metadata['kind'] == 'channel'
+
+ # Check channel-specific metadata
+ column_keys = [col['key'] for col in metadata['columns']]
+ assert 'name' in column_keys
+ assert 'video_count' in column_keys
+ assert 'total_size' in column_keys
+
+ # Check channel routes
+ routes = metadata['routes']
+ assert routes['list'] == '/videos/channels'
+ assert routes['edit'] == '/videos/channel/:id/edit'
+
+ @pytest.mark.asyncio
+ async def test_get_collections_no_metadata_without_kind(
+ self, async_client, test_session
+ ):
+ """Test that GET /api/collections without kind parameter does not include metadata."""
+ # Create test collections of different kinds
+ domain = Collection(name='example.com', kind='domain')
+ channel = Collection(name='Test Channel', kind='channel')
+ test_session.add_all([domain, channel])
+ test_session.commit()
+
+ # Make the request without kind parameter
+ request, response = await async_client.get('/api/collections')
+
+ # Check response
+ assert response.status_code == HTTPStatus.OK
+ data = response.json
+
+ # Should have both collections
+ assert len(data['collections']) == 2
+ assert data['totals']['collections'] == 2
+
+ # Should NOT have metadata (since no specific kind was requested)
+ assert 'metadata' not in data
+
+ @pytest.mark.asyncio
+ async def test_get_empty_collections_with_metadata(
+ self, async_client, test_session
+ ):
+ """Test that metadata is returned even when no collections exist."""
+ # Don't create any collections
+
+ # Make the request
+ request, response = await async_client.get('/api/collections?kind=domain')
+
+ # Check response
+ assert response.status_code == HTTPStatus.OK
+ data = response.json
+
+ # Should have empty collections
+ assert len(data['collections']) == 0
+ assert data['totals']['collections'] == 0
+
+ # Should still have metadata
+ assert 'metadata' in data
+ assert data['metadata']['kind'] == 'domain'
+
+ @pytest.mark.asyncio
+ async def test_get_collections_includes_video_count_for_channels(
+ self, async_client, test_session, test_directory, channel_factory, video_factory
+ ):
+ """Test that GET /api/collections?kind=channel includes video statistics."""
+ # Create a channel using the factory
+ channel = channel_factory(name="Test Channel")
+ test_session.flush()
+
+ # Create a video in the channel
+ video = video_factory(channel_id=channel.id, with_video_file=True)
+ test_session.commit()
+
+ # Make the request
+ request, response = await async_client.get('/api/collections?kind=channel')
+
+ # Check response includes video statistics
+ assert response.status_code == HTTPStatus.OK
+ data = response.json
+
+ # Find our test channel
+ test_channel_data = next(c for c in data['collections'] if c['name'] == "Test Channel")
+
+ # Should have video_count and total_size
+ assert 'video_count' in test_channel_data
+ assert test_channel_data['video_count'] == 1
+ assert 'total_size' in test_channel_data
+ assert test_channel_data['total_size'] > 0
+
+
+class TestCollectionTagInfoAPI:
+ """Test the POST /api/collections//tag_info endpoint."""
+
+ @pytest.mark.asyncio
+ async def test_get_tag_info_suggests_directory(
+ self, async_client, test_session, test_directory
+ ):
+ """Test that tag_info endpoint suggests a directory for a domain collection."""
+ # Create a domain collection
+ collection = Collection(
+ name='example.com',
+ kind='domain',
+ directory=test_directory / 'archive' / 'example.com'
+ )
+ test_session.add(collection)
+ test_session.commit()
+
+ # Request tag info
+ request, response = await async_client.post(
+ f'/api/collections/{collection.id}/tag_info',
+ json={'tag_name': 'WROL'}
+ )
+
+ # Check response
+ assert response.status_code == HTTPStatus.OK
+ data = response.json
+
+ # Should suggest a directory with the tag
+ assert 'suggested_directory' in data
+ assert 'WROL' in data['suggested_directory']
+ assert 'example.com' in data['suggested_directory']
+
+ # Should not have a conflict
+ assert data['conflict'] is False
+ assert data['conflict_message'] is None
+
+ @pytest.mark.asyncio
+ async def test_get_tag_info_detects_domain_conflict(
+ self, async_client, test_session, test_directory
+ ):
+ """Test that tag_info detects conflicts with existing domain collections."""
+ # Create two domain collections
+ collection1 = Collection(
+ name='example.com',
+ kind='domain',
+ directory=test_directory / 'archive' / 'example.com'
+ )
+ collection2 = Collection(
+ name='other.com',
+ kind='domain',
+ directory=test_directory / 'archive' / 'WROL' / 'other.com'
+ )
+ test_session.add_all([collection1, collection2])
+ test_session.commit()
+
+ # Request tag info for collection1 with a tag that would conflict with collection2
+ request, response = await async_client.post(
+ f'/api/collections/{collection1.id}/tag_info',
+ json={'tag_name': 'WROL'}
+ )
+
+ # Check response
+ assert response.status_code == HTTPStatus.OK
+ data = response.json
+
+ # Should suggest archive/WROL/example.com
+ assert 'suggested_directory' in data
+ assert 'WROL' in data['suggested_directory']
+ assert 'example.com' in data['suggested_directory']
+
+ # Should not have a conflict (different names, different directories)
+ assert data['conflict'] is False
+
+ @pytest.mark.asyncio
+ async def test_get_tag_info_allows_channel_domain_same_directory(
+ self, async_client, test_session, test_directory, channel_factory
+ ):
+ """Test that channel and domain collections can share directories."""
+ # Create a channel collection
+ channel = channel_factory(name='test', directory=test_directory / 'videos' / 'WROL' / 'test')
+ test_session.flush()
+
+ # Create a domain collection with a different directory
+ domain = Collection(
+ name='example.com',
+ kind='domain',
+ directory=test_directory / 'archive' / 'example.com'
+ )
+ test_session.add(domain)
+ test_session.commit()
+
+ # Request tag info for domain
+ request, response = await async_client.post(
+ f'/api/collections/{domain.id}/tag_info',
+ json={'tag_name': 'WROL'}
+ )
+
+ # Check response
+ assert response.status_code == HTTPStatus.OK
+ data = response.json
+
+ # Should not have a conflict (different kinds can share)
+ assert data['conflict'] is False
+
+ @pytest.mark.asyncio
+ async def test_get_tag_info_detects_exact_directory_conflict(
+ self, async_client, test_session, test_directory
+ ):
+ """Test that tag_info detects conflicts when suggested directory exactly matches existing collection."""
+ # Create two domain collections with different names but where one's suggested directory
+ # would conflict with the other's existing directory
+ # collection1 is already in the "WROL" tagged location
+ collection1 = Collection(
+ name='conflicting.com',
+ kind='domain',
+ directory=test_directory / 'archive' / 'WROL' / 'example.com'
+ )
+ # collection2 is named example.com and would want to move to archive/WROL/example.com when tagged
+ collection2 = Collection(
+ name='example.com',
+ kind='domain',
+ directory=test_directory / 'archive' / 'example.com'
+ )
+ test_session.add_all([collection1, collection2])
+ test_session.commit()
+
+ # Request tag info for collection2 with a tag that would create the same directory as collection1
+ request, response = await async_client.post(
+ f'/api/collections/{collection2.id}/tag_info',
+ json={'tag_name': 'WROL'}
+ )
+
+ # Check response
+ assert response.status_code == HTTPStatus.OK
+ data = response.json
+
+ # Should suggest archive/WROL/example.com
+ assert 'suggested_directory' in data
+ assert 'WROL' in data['suggested_directory']
+ assert 'example.com' in data['suggested_directory']
+
+ # Should have a conflict (same directory as collection1)
+ assert data['conflict'] is True
+ assert data['conflict_message'] is not None
+ assert 'conflicting.com' in data['conflict_message']
+
+ @pytest.mark.asyncio
+ async def test_get_tag_info_unknown_collection(
+ self, async_client, test_session
+ ):
+ """Test that tag_info returns 404 for unknown collection."""
+ # Request tag info for non-existent collection
+ request, response = await async_client.post(
+ '/api/collections/99999/tag_info',
+ json={'tag_name': 'WROL'}
+ )
+
+ # Check response
+ assert response.status_code == HTTPStatus.NOT_FOUND
+ data = response.json
+ assert 'error' in data
+
+
+class TestCollectionDeletion:
+ """Test the DELETE /api/collections/ endpoint."""
+
+ @pytest.mark.asyncio
+ async def test_delete_domain_collection_orphans_archives(
+ self, async_client, test_session, test_directory
+ ):
+ """Test that deleting a domain collection orphans its archives."""
+ from modules.archive.models import Archive
+ from wrolpi.files.models import FileGroup
+
+ # Create a domain collection
+ collection = Collection(
+ name='example.com',
+ kind='domain',
+ directory=test_directory / 'archive' / 'example.com'
+ )
+ test_session.add(collection)
+ test_session.flush()
+
+ # Create file groups and archives for this domain
+ fg1 = FileGroup(primary_path=test_directory / 'archive' / 'example.com' / 'page1.html', url='https://example.com/page1')
+ fg2 = FileGroup(primary_path=test_directory / 'archive' / 'example.com' / 'page2.html', url='https://example.com/page2')
+ test_session.add_all([fg1, fg2])
+ test_session.flush()
+
+ archive1 = Archive(file_group_id=fg1.id, collection_id=collection.id)
+ archive2 = Archive(file_group_id=fg2.id, collection_id=collection.id)
+ test_session.add_all([archive1, archive2])
+ test_session.commit()
+
+ collection_id = collection.id
+
+ # Verify archives are associated with the collection
+ assert test_session.query(Archive).filter_by(collection_id=collection_id).count() == 2
+
+ # Delete the collection
+ request, response = await async_client.delete(f'/api/collections/{collection_id}')
+
+ # Check response
+ assert response.status_code == HTTPStatus.NO_CONTENT
+
+ # Verify collection is deleted
+ assert test_session.query(Collection).filter_by(id=collection_id).count() == 0
+
+ # Verify archives are orphaned (collection_id is NULL) but still exist
+ orphaned_archives = test_session.query(Archive).filter_by(collection_id=None).all()
+ assert len(orphaned_archives) == 2
+ assert {a.id for a in orphaned_archives} == {archive1.id, archive2.id}
+
+ # Verify archives still exist in total
+ assert test_session.query(Archive).count() == 2
+
+ @pytest.mark.asyncio
+ async def test_delete_unknown_collection(
+ self, async_client, test_session
+ ):
+ """Test that deleting unknown collection returns 404."""
+ # Try to delete non-existent collection
+ request, response = await async_client.delete('/api/collections/99999')
+
+ # Check response
+ assert response.status_code == HTTPStatus.NOT_FOUND
+ data = response.json
+ assert 'error' in data
+
+ @pytest.mark.asyncio
+ async def test_delete_channel_collection(
+ self, async_client, test_session, test_directory, channel_factory
+ ):
+ """Test that deleting a channel collection works."""
+ # Create a channel collection
+ channel = channel_factory(name='test', directory=test_directory / 'videos' / 'test')
+ test_session.commit()
+
+ collection_id = channel.collection.id
+
+ # Delete the collection
+ request, response = await async_client.delete(f'/api/collections/{collection_id}')
+
+ # Check response
+ assert response.status_code == HTTPStatus.NO_CONTENT
+
+ # Verify collection is deleted
+ assert test_session.query(Collection).filter_by(id=collection_id).count() == 0
+
+
+class TestCollectionTagging:
+ """Test the POST /api/collections//tag endpoint for tagging and un-tagging."""
+
+ @pytest.mark.asyncio
+ async def test_untag_collection(
+ self, async_client, test_session, test_directory, tag_factory
+ ):
+ """Test that sending tag_name=null removes the tag from a collection."""
+ # Create a tag
+ tag = await tag_factory('TestTag')
+ test_session.flush()
+
+ # Create a domain collection with a tag and directory
+ collection = Collection(
+ name='example.com',
+ kind='domain',
+ directory=test_directory / 'archive' / 'TestTag' / 'example.com',
+ tag_id=tag.id
+ )
+ test_session.add(collection)
+ test_session.commit()
+
+ # Verify collection has the tag
+ assert collection.tag_id == tag.id
+ assert collection.tag_name == 'TestTag'
+
+ # Send POST request to un-tag (with tag_name: null)
+ request, response = await async_client.post(
+ f'/api/collections/{collection.id}/tag',
+ json={'tag_name': None}
+ )
+
+ # Check response
+ assert response.status_code == HTTPStatus.OK
+
+ # Refresh the collection from database
+ test_session.refresh(collection)
+
+ # Verify the tag has been removed
+ assert collection.tag_id is None
+ assert collection.tag_name is None
diff --git a/wrolpi/collections/test/test_collection_lifecycle.py b/wrolpi/collections/test/test_collection_lifecycle.py
new file mode 100644
index 000000000..0c63a810c
--- /dev/null
+++ b/wrolpi/collections/test/test_collection_lifecycle.py
@@ -0,0 +1,119 @@
+import pathlib
+
+import pytest
+import yaml
+from sqlalchemy.orm import Session
+
+from wrolpi.collections import collections_config
+from wrolpi.collections.models import Collection
+from wrolpi.common import get_media_directory
+
+
+@pytest.mark.asyncio
+async def test_collection_lifecycle_end_to_end(async_client, test_session: Session, test_directory: pathlib.Path, video_factory):
+ """
+ Lifecycle:
+ - Start with no collections
+ - Create a directory-restricted collection with a unique directory containing some videos
+ - Dump to config (collections.yaml)
+ - Add a new video to that directory; ensure collection fetches show it
+ - Remove one video from the collection only (files remain)
+ - Delete another video entirely (files removed and item gone from collection)
+ - Delete the collection; dump config; files remain (except the deleted video)
+ """
+ # 1) Start clean
+ assert test_directory.is_dir()
+ assert test_session.query(Collection).count() == 0
+
+ # 2) Prepare unique directory with some videos
+ coll_dir = test_directory / 'videos' / 'collections_lifecycle' / 'case_001'
+ coll_dir.mkdir(parents=True, exist_ok=True)
+
+ # Create three videos in the directory
+ v1 = video_factory(with_video_file=coll_dir / 'v1.mp4')
+ v2 = video_factory(with_video_file=coll_dir / 'v2.mp4')
+ v3 = video_factory(with_video_file=coll_dir / 'v3.mp4')
+ test_session.commit()
+
+ # Sanity: files exist
+ assert v1.file_group.primary_path.exists()
+ assert v2.file_group.primary_path.exists()
+ assert v3.file_group.primary_path.exists()
+
+ # 3) Create a directory-restricted collection via config import semantics
+ rel_dir = coll_dir.relative_to(get_media_directory())
+ cfg = {
+ 'name': 'Lifecycle Test Collection',
+ 'description': 'E2E lifecycle test',
+ 'directory': str(rel_dir),
+ 'kind': 'channel',
+ }
+ coll = Collection.from_config(cfg, session=test_session)
+ test_session.commit()
+
+ # After creation, items should be auto-populated from directory
+ items = coll.get_items(session=test_session)
+ fg_ids = {i.file_group_id for i in items}
+ assert {v1.file_group_id, v2.file_group_id, v3.file_group_id}.issubset(fg_ids)
+
+ # 4) Dump collections config; verify entry present
+ collections_config.dump_config()
+ config_file = collections_config.get_file()
+ assert config_file.is_file(), f"Expected config file at {config_file}"
+
+ data = yaml.safe_load(config_file.read_text())
+ assert isinstance(data, dict)
+ dumped = data.get('collections', [])
+ # Config stores absolute paths
+ assert any(c.get('name') == 'Lifecycle Test Collection' and c.get('directory') == str(coll_dir) and c.get(
+ 'kind') == 'channel' for c in dumped)
+
+ # 5) Add a new video to the same directory; ensure fetch methods include it after population
+ v4 = video_factory(with_video_file=coll_dir / 'v4.mp4')
+ test_session.commit()
+
+ # Populate from directory to pick up new file group
+ coll.populate_from_directory(session=test_session)
+ test_session.commit()
+
+ items = coll.get_items(session=test_session)
+ fg_ids = {i.file_group_id for i in items}
+ assert v4.file_group_id in fg_ids
+
+ # 6) Remove one video from the collection only (do not delete the video)
+ coll.remove_file_group(v2.file_group_id, session=test_session)
+ test_session.commit()
+
+ items = coll.get_items(session=test_session)
+ fg_ids = {i.file_group_id for i in items}
+ assert v2.file_group_id not in fg_ids, 'Removed video should not be visible in collection fetch methods'
+ # Files should remain on disk
+ assert v2.file_group.primary_path.exists()
+
+ # 7) Delete another video entirely (remove files too)
+ # Use the model's delete to remove files and DB rows
+ v3.delete()
+ test_session.commit()
+
+ # File should be removed from disk
+ assert not v3.file_group.primary_path.exists()
+
+ # The collection should no longer reference it (FK cascade or absence after populate)
+ items = coll.get_items(session=test_session)
+ fg_ids = {i.file_group_id for i in items}
+ assert v3.file_group_id not in fg_ids
+
+ # 8) Finally, delete the collection
+ test_session.delete(coll)
+ test_session.commit()
+
+ # Dump config again; collection should be removed from config
+ collections_config.dump_config()
+ data = yaml.safe_load(config_file.read_text())
+ dumped = data.get('collections', [])
+ assert not any(c.get('name') == 'Lifecycle Test Collection' and c.get('directory') == str(rel_dir) for c in dumped)
+
+ # All remaining files should still exist except the one for the deleted video
+ assert v1.file_group.primary_path.exists()
+ assert v2.file_group.primary_path.exists()
+ assert v4.file_group.primary_path.exists()
diff --git a/wrolpi/collections/test/test_collection_tagging.py b/wrolpi/collections/test/test_collection_tagging.py
new file mode 100644
index 000000000..6324b2bb0
--- /dev/null
+++ b/wrolpi/collections/test/test_collection_tagging.py
@@ -0,0 +1,145 @@
+import pathlib
+import pathlib
+
+import pytest
+import yaml
+from sqlalchemy.orm import Session
+
+from wrolpi.collections import collections_config
+from wrolpi.collections.models import Collection
+from wrolpi.common import get_media_directory
+
+
+# from wrolpi.switches import await_switches
+
+
+@pytest.mark.asyncio
+async def test_tagging_directory_collection_moves_files_and_updates_config(
+ test_session: Session,
+ test_directory: pathlib.Path,
+ video_factory,
+ async_client,
+ tag_factory,
+):
+ # Start clean
+ assert test_session.query(Collection).count() == 0
+
+ # Create source directory and a few videos
+ src_dir = test_directory / 'videos' / 'collections_tagging' / 'dir_case'
+ src_dir.mkdir(parents=True, exist_ok=True)
+ v1 = video_factory(with_video_file=src_dir / 'a1.mp4')
+ v2 = video_factory(with_video_file=src_dir / 'a2.mp4')
+ test_session.commit()
+
+ # Create a directory-restricted Collection (kind=channel) via config import semantics
+ rel_dir = src_dir.relative_to(get_media_directory())
+ cfg = {
+ 'name': 'Funny Clips',
+ 'description': 'Clips that make me laugh',
+ 'directory': str(rel_dir),
+ 'kind': 'channel',
+ }
+ coll = Collection.from_config(cfg, session=test_session)
+ test_session.commit()
+
+ # Verify it populated
+ assert {i.file_group_id for i in coll.get_items(session=test_session)} == {v1.file_group_id, v2.file_group_id}
+
+ # Dump config and verify (config stores absolute paths)
+ collections_config.dump_config()
+ cfg_file = collections_config.get_file()
+ data = yaml.safe_load(cfg_file.read_text())
+ dumped = data.get('collections', [])
+ assert any(
+ c.get('name') == 'Funny Clips' and c.get('directory') == str(src_dir) and c.get('kind') == 'channel' for c in
+ dumped)
+
+ # Tag the collection -> should move to videos//
+ tag = await tag_factory('Funny')
+ coll.set_tag(tag.id)
+ test_session.commit()
+
+ # Compute expected destination and move
+ dest_dir = coll.format_directory('Funny')
+ dest_dir.mkdir(parents=True, exist_ok=True)
+
+ # Perform the move using the model helper
+ await coll.move_collection(dest_dir, session=test_session)
+ test_session.commit()
+
+ # Validate the directory changed and files moved
+ assert coll.directory == dest_dir
+ assert (dest_dir / 'a1.mp4').is_file()
+ assert (dest_dir / 'a2.mp4').is_file()
+ assert not (src_dir / 'a1.mp4').exists()
+ assert not (src_dir / 'a2.mp4').exists()
+
+ # FileGroup paths should be in dest_dir now
+ test_session.refresh(v1)
+ test_session.refresh(v2)
+ assert str(dest_dir) in str(v1.file_group.primary_path)
+ assert str(dest_dir) in str(v2.file_group.primary_path)
+
+ # Dump config again and verify updated directory and tag (config stores absolute paths)
+ collections_config.dump_config()
+ data = yaml.safe_load(cfg_file.read_text())
+ dumped = data.get('collections', [])
+ assert any(
+ c.get('name') == 'Funny Clips' and c.get('directory') == str(dest_dir) and c.get('tag_name') == 'Funny' for
+ c in dumped)
+
+
+@pytest.mark.asyncio
+async def test_tagging_unrestricted_collection_updates_config_only(
+ test_session: Session,
+ test_directory: pathlib.Path,
+ make_files_structure,
+ async_client,
+ tag_factory,
+):
+ # Create a couple files in different places using make_files_structure
+ files = [
+ 'videos/loose/x1.mp4',
+ 'videos/misc/x2.mp4',
+ ]
+ fg1, fg2 = make_files_structure(files, file_groups=True, session=test_session)
+ test_session.commit()
+
+ # Create an unrestricted collection and add both files
+ coll = Collection.from_config({'name': 'Loose Stuff', 'kind': 'channel'}, session=test_session)
+ # Unrestricted, so add manually
+ coll.add_file_groups([fg1, fg2], session=test_session)
+ test_session.commit()
+
+ # Dump config and verify directory is absent/None
+ collections_config.dump_config()
+ cfg_file = collections_config.get_file()
+ data = yaml.safe_load(cfg_file.read_text())
+ dumped = data.get('collections', [])
+ entry = next(c for c in dumped if c.get('name') == 'Loose Stuff')
+ assert 'directory' not in entry
+
+ # Tag the collection; files should not move
+ before_paths = (fg1.primary_path, fg2.primary_path)
+ # Create a Tag synchronously and assign by name
+ # from wrolpi.tags import Tag
+ # tag = Tag(name='Favorites', color='#00ff00')
+ # test_session.add(tag)
+ # test_session.commit()
+ await tag_factory('Favorites')
+ coll.set_tag('Favorites')
+ test_session.commit()
+
+ # Assert paths unchanged
+ test_session.refresh(fg1)
+ test_session.refresh(fg2)
+ after_paths = (fg1.primary_path, fg2.primary_path)
+ assert before_paths == after_paths
+
+ # Dump config and verify tag written, directory still not set
+ collections_config.dump_config()
+ data = yaml.safe_load(cfg_file.read_text())
+ dumped = data.get('collections', [])
+ entry = next(c for c in dumped if c.get('name') == 'Loose Stuff')
+ assert entry.get('tag_name') == 'Favorites'
+ assert 'directory' not in entry
diff --git a/wrolpi/collections/test/test_domain_collections.py b/wrolpi/collections/test/test_domain_collections.py
new file mode 100644
index 000000000..1b2d2f489
--- /dev/null
+++ b/wrolpi/collections/test/test_domain_collections.py
@@ -0,0 +1,230 @@
+"""Tests for domain-specific Collection functionality."""
+import pytest
+from sqlalchemy.orm import Session
+
+from wrolpi.collections.models import Collection
+
+
+class TestDomainValidation:
+ """Test domain name validation for Collections with kind='domain'."""
+
+ def test_is_valid_domain_name_valid_domains(self):
+ """Test that valid domain names are accepted."""
+ assert Collection.is_valid_domain_name('example.com') is True
+ assert Collection.is_valid_domain_name('sub.example.com') is True
+ assert Collection.is_valid_domain_name('a.b.c') is True
+ assert Collection.is_valid_domain_name('test.org') is True
+ assert Collection.is_valid_domain_name('my-site.co.uk') is True
+
+ def test_is_valid_domain_name_invalid_domains(self):
+ """Test that invalid domain names are rejected."""
+ # No dots
+ assert Collection.is_valid_domain_name('example') is False
+ assert Collection.is_valid_domain_name('localhost') is False
+
+ # Starts or ends with dot
+ assert Collection.is_valid_domain_name('.example.com') is False
+ assert Collection.is_valid_domain_name('example.com.') is False
+ assert Collection.is_valid_domain_name('.') is False
+
+ # Empty or non-string
+ assert Collection.is_valid_domain_name('') is False
+ assert Collection.is_valid_domain_name(None) is False
+ assert Collection.is_valid_domain_name(123) is False
+
+ def test_from_config_validates_domain_kind(self, test_session: Session):
+ """Test that from_config enforces domain validation when kind='domain'."""
+ # Valid domain should succeed
+ config = {
+ 'name': 'example.com',
+ 'kind': 'domain',
+ }
+ collection = Collection.from_config(config, session=test_session)
+ assert collection.name == 'example.com'
+ assert collection.kind == 'domain'
+ assert collection.directory is None
+
+ # Invalid domain should raise ValueError
+ invalid_config = {
+ 'name': 'invalid-no-dot',
+ 'kind': 'domain',
+ }
+ with pytest.raises(ValueError, match='Invalid domain name'):
+ Collection.from_config(invalid_config, session=test_session)
+
+ # Domain starting with dot should fail
+ invalid_config2 = {
+ 'name': '.example.com',
+ 'kind': 'domain',
+ }
+ with pytest.raises(ValueError, match='Invalid domain name'):
+ Collection.from_config(invalid_config2, session=test_session)
+
+ # Domain ending with dot should fail
+ invalid_config3 = {
+ 'name': 'example.com.',
+ 'kind': 'domain',
+ }
+ with pytest.raises(ValueError, match='Invalid domain name'):
+ Collection.from_config(invalid_config3, session=test_session)
+
+ def test_from_config_allows_any_name_for_channel_kind(self, test_session: Session):
+ """Test that channel kind does not enforce domain validation."""
+ # Channel kind should allow any name (no domain validation)
+ config = {
+ 'name': 'My Channel Without Dots',
+ 'kind': 'channel',
+ }
+ collection = Collection.from_config(config, session=test_session)
+ assert collection.name == 'My Channel Without Dots'
+ assert collection.kind == 'channel'
+
+ def test_domain_collection_unrestricted_mode(self, test_session: Session):
+ """Test that domain collections can be created without directory (unrestricted)."""
+ config = {
+ 'name': 'example.org',
+ 'kind': 'domain',
+ 'description': 'Archives from example.org',
+ }
+ collection = Collection.from_config(config, session=test_session)
+ test_session.commit()
+
+ assert collection.name == 'example.org'
+ assert collection.kind == 'domain'
+ assert collection.directory is None
+ assert collection.is_directory_restricted is False
+ assert collection.description == 'Archives from example.org'
+
+ def test_domain_collection_with_subdomain(self, test_session: Session):
+ """Test that subdomains are valid domain names."""
+ config = {
+ 'name': 'blog.example.com',
+ 'kind': 'domain',
+ }
+ collection = Collection.from_config(config, session=test_session)
+ test_session.commit()
+
+ assert collection.name == 'blog.example.com'
+ assert collection.kind == 'domain'
+
+ def test_update_existing_domain_collection(self, test_session: Session):
+ """Test that updating a domain collection preserves validation."""
+ # Create initial domain collection
+ config = {
+ 'name': 'example.com',
+ 'kind': 'domain',
+ 'description': 'Initial description',
+ }
+ collection = Collection.from_config(config, session=test_session)
+ test_session.commit()
+ collection_id = collection.id
+
+ # Update with valid domain
+ updated_config = {
+ 'name': 'example.com',
+ 'kind': 'domain',
+ 'description': 'Updated description',
+ }
+ updated = Collection.from_config(updated_config, session=test_session)
+ assert updated.id == collection_id
+ assert updated.description == 'Updated description'
+
+ def test_domain_collection_to_config(self, test_session: Session):
+ """Test that domain collections export correctly to config."""
+ config = {
+ 'name': 'test.org',
+ 'kind': 'domain',
+ 'description': 'Test domain',
+ }
+ collection = Collection.from_config(config, session=test_session)
+ test_session.commit()
+
+ exported = collection.to_config()
+ assert exported['name'] == 'test.org'
+ assert exported['kind'] == 'domain'
+ assert exported['description'] == 'Test domain'
+ assert 'directory' not in exported # Should not have directory
+
+ def test_get_by_name_and_kind(self, test_session: Session):
+ """Test finding domain collections by name and kind."""
+ # Create two collections with same name but different kinds
+ domain_config = {
+ 'name': 'example.com',
+ 'kind': 'domain',
+ }
+ channel_config = {
+ 'name': 'example.com',
+ 'kind': 'channel',
+ }
+
+ domain_coll = Collection.from_config(domain_config, session=test_session)
+ channel_coll = Collection.from_config(channel_config, session=test_session)
+ test_session.commit()
+
+ # Both should exist with different IDs
+ assert domain_coll.id != channel_coll.id
+ assert domain_coll.kind == 'domain'
+ assert channel_coll.kind == 'channel'
+
+ # Finding by name and kind should return the correct one
+ found_domain = test_session.query(Collection).filter_by(
+ name='example.com',
+ kind='domain'
+ ).first()
+ assert found_domain.id == domain_coll.id
+
+
+class TestDirectoryValidation:
+ """Test that collection directories must be under media directory."""
+
+ def test_from_config_rejects_absolute_path_outside_media_directory(self, test_session, test_directory):
+ """Test that from_config rejects absolute paths outside media directory."""
+ from wrolpi.errors import ValidationError
+
+ # Try to create collection with directory outside media directory
+ config = {
+ 'name': 'example.com',
+ 'kind': 'domain',
+ 'directory': '/opt/wrolpi/archive/example.com' # Outside media directory!
+ }
+
+ # Should raise ValidationError
+ with pytest.raises(ValidationError, match="must be under media directory"):
+ Collection.from_config(config, test_session)
+
+ def test_from_config_accepts_relative_path(self, test_session, test_directory):
+ """Test that from_config converts relative paths to absolute under media directory."""
+ from wrolpi.common import get_media_directory
+
+ config = {
+ 'name': 'example.com',
+ 'kind': 'domain',
+ 'directory': 'archive/example.com' # Relative path
+ }
+
+ collection = Collection.from_config(config, test_session)
+ test_session.flush()
+
+ # Should be converted to absolute path under media directory
+ assert collection.directory.is_absolute()
+ assert str(collection.directory).startswith(str(get_media_directory()))
+ assert collection.directory.name == 'example.com'
+
+ def test_from_config_accepts_absolute_path_under_media_directory(self, test_session, test_directory):
+ """Test that from_config accepts absolute paths that are under media directory."""
+ from wrolpi.common import get_media_directory
+
+ media_dir = get_media_directory()
+ valid_path = media_dir / 'archive' / 'example.com'
+
+ config = {
+ 'name': 'example.com',
+ 'kind': 'domain',
+ 'directory': str(valid_path)
+ }
+
+ collection = Collection.from_config(config, test_session)
+ test_session.flush()
+
+ # Should accept the path
+ assert collection.directory == valid_path
diff --git a/wrolpi/collections/test/test_domains_config.py b/wrolpi/collections/test/test_domains_config.py
new file mode 100644
index 000000000..21f64c0d5
--- /dev/null
+++ b/wrolpi/collections/test/test_domains_config.py
@@ -0,0 +1,215 @@
+"""Tests for DomainsConfig functionality."""
+import pathlib
+
+import yaml
+from sqlalchemy.orm import Session
+
+from modules.archive.lib import DomainsConfig, domains_config
+from wrolpi.collections import Collection
+
+
+class TestDomainsConfig:
+ """Test domains.yaml config file operations."""
+
+ def test_domains_config_import_creates_domain_collections(self, test_session: Session, test_directory: pathlib.Path,
+ async_client):
+ """Test that importing domains.yaml creates domain collections."""
+ config_file = test_directory / 'domains.yaml'
+ config_data = {
+ 'version': 0,
+ 'collections': [
+ {
+ 'name': 'example.com',
+ 'kind': 'domain',
+ 'description': 'Archives from example.com',
+ },
+ {
+ 'name': 'wikipedia.org',
+ 'kind': 'domain',
+ },
+ ]
+ }
+ config_file.write_text(yaml.dump(config_data))
+
+ # Import the config
+ config = DomainsConfig()
+ config.import_config(file=config_file)
+
+ # Verify collections were created
+ collections = test_session.query(Collection).filter_by(kind='domain').all()
+ assert len(collections) == 2
+
+ example = test_session.query(Collection).filter_by(name='example.com', kind='domain').first()
+ assert example is not None
+ assert example.description == 'Archives from example.com'
+ assert example.directory is None # Domain collections should be unrestricted
+
+ wiki = test_session.query(Collection).filter_by(name='wikipedia.org', kind='domain').first()
+ assert wiki is not None
+
+ def test_domains_config_dump_exports_only_domain_collections(self, test_session: Session,
+ test_directory: pathlib.Path, async_client):
+ """Test that dumping domains.yaml only exports domain collections."""
+ # Create some domain and channel collections
+ domain1 = Collection.from_config({'name': 'example.com', 'kind': 'domain'}, session=test_session)
+ domain2 = Collection.from_config({'name': 'test.org', 'kind': 'domain'}, session=test_session)
+ channel = Collection.from_config({'name': 'My Channel', 'kind': 'channel'}, session=test_session)
+ test_session.commit()
+
+ # Dump to config
+ config_file = test_directory / 'domains.yaml'
+ config = DomainsConfig()
+ config.dump_config(file=config_file)
+
+ # Read and verify
+ data = yaml.safe_load(config_file.read_text())
+ assert 'collections' in data
+ assert len(data['collections']) == 2 # Only domain collections
+
+ names = {c['name'] for c in data['collections']}
+ assert 'example.com' in names
+ assert 'test.org' in names
+ assert 'My Channel' not in names # Channel should not be in domains.yaml
+
+ # All should have kind='domain'
+ for coll in data['collections']:
+ assert coll['kind'] == 'domain'
+
+ def test_domains_config_enforces_domain_validation(self, test_session: Session, test_directory: pathlib.Path,
+ async_client):
+ """Test that importing invalid domain names skips them (logs error but doesn't fail import)."""
+ config_file = test_directory / 'domains.yaml'
+ config_data = {
+ 'version': 0,
+ 'collections': [
+ {
+ 'name': 'invalid-domain', # No dot
+ 'kind': 'domain',
+ },
+ ]
+ }
+ config_file.write_text(yaml.dump(config_data))
+
+ config = DomainsConfig()
+
+ # Import should succeed but skip invalid domain (error is logged)
+ config.import_config(file=config_file)
+
+ # No collections should be created
+ domains = test_session.query(Collection).filter_by(kind='domain').all()
+ assert len(domains) == 0
+
+ def test_domains_config_forces_kind_to_domain(self, test_session: Session, test_directory: pathlib.Path,
+ async_client):
+ """Test that DomainsConfig forces kind='domain' even if not specified."""
+ config_file = test_directory / 'domains.yaml'
+ config_data = {
+ 'version': 0,
+ 'collections': [
+ {
+ 'name': 'example.com',
+ # kind not specified
+ },
+ ]
+ }
+ config_file.write_text(yaml.dump(config_data))
+
+ config = DomainsConfig()
+ config.import_config(file=config_file)
+
+ # Should be created with kind='domain'
+ collection = test_session.query(Collection).filter_by(name='example.com').first()
+ assert collection is not None
+ assert collection.kind == 'domain'
+
+ def test_domains_config_removes_deleted_domains(self, test_session: Session, test_directory: pathlib.Path,
+ async_client):
+ """Test that domains removed from config are deleted from database."""
+ # Create two domain collections
+ domain1 = Collection.from_config({'name': 'example.com', 'kind': 'domain'}, session=test_session)
+ domain2 = Collection.from_config({'name': 'test.org', 'kind': 'domain'}, session=test_session)
+ test_session.commit()
+
+ # Import config with only one domain
+ config_file = test_directory / 'domains.yaml'
+ config_data = {
+ 'version': 0,
+ 'collections': [
+ {'name': 'example.com', 'kind': 'domain'},
+ ]
+ }
+ config_file.write_text(yaml.dump(config_data))
+
+ config = DomainsConfig()
+ config.import_config(file=config_file)
+
+ # test.org should be deleted
+ remaining = test_session.query(Collection).filter_by(kind='domain').all()
+ assert len(remaining) == 1
+ assert remaining[0].name == 'example.com'
+
+ def test_domains_config_updates_existing_domain(self, test_session: Session, test_directory: pathlib.Path,
+ async_client):
+ """Test that updating a domain in config updates the existing collection."""
+ # Create initial domain
+ domain = Collection.from_config({
+ 'name': 'example.com',
+ 'kind': 'domain',
+ 'description': 'Original description',
+ }, session=test_session)
+ test_session.commit()
+ domain_id = domain.id
+
+ # Import config with updated description
+ config_file = test_directory / 'domains.yaml'
+ config_data = {
+ 'version': 0,
+ 'collections': [
+ {
+ 'name': 'example.com',
+ 'kind': 'domain',
+ 'description': 'Updated description',
+ },
+ ]
+ }
+ config_file.write_text(yaml.dump(config_data))
+
+ config = DomainsConfig()
+ config.import_config(file=config_file)
+
+ # Should update existing, not create new
+ all_domains = test_session.query(Collection).filter_by(kind='domain').all()
+ assert len(all_domains) == 1
+ assert all_domains[0].id == domain_id
+ assert all_domains[0].description == 'Updated description'
+
+ def test_domains_config_skips_invalid_entries(self, test_session: Session, test_directory: pathlib.Path,
+ async_client):
+ """Test that invalid entries are skipped but valid ones are imported."""
+ config_file = test_directory / 'domains.yaml'
+ config_data = {
+ 'version': 0,
+ 'collections': [
+ {'name': 'example.com', 'kind': 'domain'}, # Valid
+ {'name': 'invalid', 'kind': 'domain'}, # Invalid - no dot
+ {'name': 'test.org', 'kind': 'domain'}, # Valid
+ ]
+ }
+ config_file.write_text(yaml.dump(config_data))
+
+ config = DomainsConfig()
+ config.import_config(file=config_file)
+
+ # Should have imported 2 valid domains, skipped 1 invalid
+ domains = test_session.query(Collection).filter_by(kind='domain').all()
+ assert len(domains) == 2
+ names = {d.name for d in domains}
+ assert 'example.com' in names
+ assert 'test.org' in names
+ assert 'invalid' not in names
+
+ def test_domains_config_global_instance(self):
+ """Test that the global domains_config instance exists."""
+ assert domains_config is not None
+ assert isinstance(domains_config, DomainsConfig)
+ assert domains_config.file_name == 'domains.yaml'
diff --git a/wrolpi/collections/types.py b/wrolpi/collections/types.py
new file mode 100644
index 000000000..7ecc12fbc
--- /dev/null
+++ b/wrolpi/collections/types.py
@@ -0,0 +1,137 @@
+"""
+Collection Type Registry
+
+This module provides a registry system for different collection types (domains, channels, etc.)
+to register their own validators and behavior rules.
+"""
+from typing import Callable, Dict, Optional
+
+from wrolpi.common import logger
+
+logger = logger.getChild(__name__)
+
+# Type alias for validator functions
+ValidatorFunc = Callable[[str], bool]
+
+
+class CollectionTypeRegistry:
+ """
+ Registry for collection type validators and rules.
+
+ Allows different collection types to register their own validation logic,
+ making the system extensible for future collection types.
+ """
+
+ def __init__(self):
+ self._validators: Dict[str, ValidatorFunc] = {}
+ self._descriptions: Dict[str, str] = {}
+
+ def register(self, kind: str, validator: ValidatorFunc, description: str = ""):
+ """
+ Register a validator for a collection type.
+
+ Args:
+ kind: The collection kind (e.g., 'domain', 'channel')
+ validator: Function that takes a name and returns True if valid
+ description: Human-readable description of validation rules
+ """
+ if kind in self._validators:
+ logger.warning(f"Overwriting existing validator for collection kind '{kind}'")
+
+ self._validators[kind] = validator
+ self._descriptions[kind] = description
+ logger.debug(f"Registered validator for collection kind '{kind}'")
+
+ def validate(self, kind: str, name: str) -> bool:
+ """
+ Validate a collection name for a given type.
+
+ Args:
+ kind: The collection kind
+ name: The collection name to validate
+
+ Returns:
+ True if valid, False otherwise
+ If no validator is registered for the kind, returns True (permissive)
+ """
+ if kind not in self._validators:
+ logger.debug(f"No validator registered for kind '{kind}', allowing any name")
+ return True
+
+ return self._validators[kind](name)
+
+ def get_description(self, kind: str) -> Optional[str]:
+ """Get the validation description for a collection type."""
+ return self._descriptions.get(kind)
+
+ def is_registered(self, kind: str) -> bool:
+ """Check if a validator is registered for a collection type."""
+ return kind in self._validators
+
+
+# Global registry instance
+collection_type_registry = CollectionTypeRegistry()
+
+
+# Domain validator
+def validate_domain_name(name: str) -> bool:
+ """
+ Validate that a name is a valid domain format.
+
+ A valid domain must:
+ - Be a string
+ - Contain at least one "." (e.g., "example.com")
+ - Not start or end with "."
+
+ Examples:
+ Valid: "example.com", "sub.example.com", "a.b.c"
+ Invalid: "example", "example.", ".example", "."
+
+ Args:
+ name: The domain name to validate
+
+ Returns:
+ True if valid domain format, False otherwise
+ """
+ if not isinstance(name, str) or not name:
+ return False
+
+ # Must contain at least one "."
+ if '.' not in name:
+ return False
+
+ # Should not start or end with "."
+ if name.startswith('.') or name.endswith('.'):
+ return False
+
+ return True
+
+
+# Channel validator (permissive - allows any non-empty string)
+def validate_channel_name(name: str) -> bool:
+ """
+ Validate that a name is valid for a channel.
+
+ Channels are permissive and allow any non-empty string.
+
+ Args:
+ name: The channel name to validate
+
+ Returns:
+ True if valid (non-empty string), False otherwise
+ """
+ return isinstance(name, str) and len(name.strip()) > 0
+
+
+# Register built-in collection types
+collection_type_registry.register(
+ 'domain',
+ validate_domain_name,
+ 'Domain must contain at least one "." and not start/end with "."'
+)
+
+collection_type_registry.register(
+ 'channel',
+ validate_channel_name,
+ 'Channel name must be a non-empty string'
+)
diff --git a/wrolpi/common.py b/wrolpi/common.py
index c598b7a53..94e196cf3 100644
--- a/wrolpi/common.py
+++ b/wrolpi/common.py
@@ -24,8 +24,9 @@
from itertools import islice, filterfalse, tee
from multiprocessing.managers import DictProxy
from pathlib import Path
-from types import GeneratorType, MappingProxyType
-from typing import Union, Callable, Tuple, Dict, List, Iterable, Optional, Generator, Any, Set
+from types import GeneratorType
+from typing import Optional, List
+from typing import Union, Callable, Tuple, Dict, Iterable, Generator, Any, Set
from urllib.parse import urlparse, urlunsplit
import aiohttp
@@ -731,6 +732,10 @@ def get_all_configs() -> Dict[str, ConfigFile]:
if download_manager_config := get_download_manager_config():
all_configs[download_manager_config.file_name] = download_manager_config
+ from modules.archive.lib import get_domains_config
+ if domains_config := get_domains_config():
+ all_configs[domains_config.file_name] = domains_config
+
return all_configs
@@ -1173,16 +1178,21 @@ def make_media_directory(path: Union[str, Path]):
def extract_domain(url: str) -> str:
"""
- Extract the domain from a URL. Remove leading www.
+ Extract the domain from a URL. Remove leading www and port.
>>> extract_domain('https://www.example.com/foo')
'example.com'
+ >>> extract_domain('https://example.com:443/foo')
+ 'example.com'
"""
parsed = urlparse(url)
domain = parsed.netloc
if not domain:
raise ValueError(f'URL does not have a domain: {url=}')
domain = domain.decode() if hasattr(domain, 'decode') else domain
+ # Remove port if present
+ if ':' in domain:
+ domain = domain.split(':')[0]
if domain.startswith('www.'):
# Remove leading www.
domain = domain[4:]
@@ -1902,7 +1912,8 @@ def extract_html_text(html: str) -> str:
for script in soup(["script", "style"]):
script.extract() # rip it out
- text = soup.body.get_text()
+ # Use body if it exists, otherwise use the entire soup
+ text = soup.body.get_text() if soup.body else soup.get_text()
# break into lines and remove leading and trailing space on each
lines = (line.strip() for line in text.splitlines())
@@ -1951,7 +1962,8 @@ async def search_other_estimates(tag_names: List[str]) -> dict:
stmt = '''
SELECT COUNT(c.id)
FROM channel c
- LEFT OUTER JOIN public.tag t on t.id = c.tag_id
+ INNER JOIN collection col ON col.id = c.collection_id
+ LEFT OUTER JOIN public.tag t on t.id = col.tag_id
WHERE t.name = %(tag_name)s \
'''
# TODO handle multiple tags
diff --git a/wrolpi/conftest.py b/wrolpi/conftest.py
index bd1ddbf92..4d15c99bc 100644
--- a/wrolpi/conftest.py
+++ b/wrolpi/conftest.py
@@ -9,7 +9,6 @@
import multiprocessing
import pathlib
import shutil
-import sys
import tempfile
import threading
import zipfile
@@ -19,14 +18,13 @@
from typing import List, Callable, Dict, Sequence, Union, Coroutine, Awaitable, Optional
from typing import Tuple, Set
from unittest import mock
-from unittest.mock import MagicMock, AsyncMock, patch
+from unittest.mock import MagicMock, AsyncMock
from uuid import uuid1, uuid4
import pytest
import sqlalchemy
import yaml
from PIL import Image
-from sanic_testing.reusable import ReusableClient
from sanic_testing.testing import SanicASGITestClient
from sqlalchemy.engine import Engine, create_engine
from sqlalchemy.orm import Session, sessionmaker
@@ -143,31 +141,6 @@ def test_wrolpi_config(test_directory) -> pathlib.Path:
ROUTES_ATTACHED = False
-@pytest.fixture()
-def test_client(test_directory) -> ReusableClient:
- """Get a Reusable Sanic Test Client with all default routes attached.
-
- (A non-reusable client would turn on for each request)
- """
- attach_shared_contexts(api_app)
-
- initialize_configs_contexts(api_app)
-
- for _ in range(5):
- # Sometimes the Sanic client tries to use a port already in use, try again...
- try:
- client = ReusableClient(api_app)
- with client:
- yield client
- break
- except OSError as e:
- # Ignore errors where the port was already in use.
- if 'address already in use' not in str(e):
- raise
- else:
- raise RuntimeError('Test never got unused port')
-
-
@api_app.on_response
async def background_task_listener(request, response):
"""Wait for all background tasks to finish before returning API response while testing."""
@@ -623,12 +596,12 @@ def _(file_groups: List[Dict], assert_count: bool = True):
@pytest.fixture
-def assert_files_search(test_client):
+def assert_files_search(async_client):
from wrolpi.test.common import assert_dict_contains
- def _(search_str: str, expected: List[dict]):
+ async def _(search_str: str, expected: List[dict]):
content = json.dumps({'search_str': search_str})
- request, response = test_client.post('/api/files/search', content=content)
+ request, response = await async_client.post('/api/files/search', content=content)
for file_group, exp in zip_longest(response.json['file_groups'], expected):
assert_dict_contains(file_group, exp)
diff --git a/wrolpi/contexts.py b/wrolpi/contexts.py
index 90f1cb44b..5eb53a2ed 100644
--- a/wrolpi/contexts.py
+++ b/wrolpi/contexts.py
@@ -56,6 +56,7 @@ def attach_shared_contexts(app: Sanic):
app.shared_ctx.switches_lock = multiprocessing.Lock()
app.shared_ctx.switches_changed = multiprocessing.Event()
app.shared_ctx.archive_singlefiles = multiprocessing.Queue()
+ app.shared_ctx.archive_screenshots = multiprocessing.Queue()
# Warnings
app.shared_ctx.warn_once = manager.dict()
@@ -93,6 +94,12 @@ def reset_shared_contexts(app: Sanic):
drives_stats=list(),
processes_stats=list(),
memory_stats=dict(),
+ # Upgrade info defaults
+ update_available=False,
+ latest_commit=None,
+ current_commit=None,
+ commits_behind=0,
+ git_branch=None,
))
app.shared_ctx.map_importing.clear()
app.shared_ctx.cache.clear()
@@ -118,6 +125,12 @@ def reset_shared_contexts(app: Sanic):
app.shared_ctx.archive_singlefiles.get_nowait()
except queue.Empty:
break
+ while True:
+ # Clear out any pending screenshot generation switches.
+ try:
+ app.shared_ctx.archive_screenshots.get_nowait()
+ except queue.Empty:
+ break
# Events.
app.shared_ctx.single_tasks_started.clear()
diff --git a/wrolpi/downloader.py b/wrolpi/downloader.py
index e701ef270..73f3957f8 100644
--- a/wrolpi/downloader.py
+++ b/wrolpi/downloader.py
@@ -117,9 +117,10 @@ class Download(ModelHelper, Base): # noqa
status = Column(String, default=DownloadStatus.new) # `DownloadStatus` enum.
tag_names = Column(ARRAY(Text))
- # A Download may be associated with a Channel (downloads all Channel videos, or a playlist, etc.).
- channel_id = Column(Integer, ForeignKey('channel.id'))
- channel = relationship('Channel', primaryjoin='Download.channel_id==Channel.id', back_populates='downloads')
+ # A Download may be associated with a Collection (downloads for channels, domains, etc.).
+ collection_id = Column(Integer, ForeignKey('collection.id', ondelete='SET NULL'))
+ collection = relationship('Collection', primaryjoin='Download.collection_id==Collection.id',
+ back_populates='downloads')
def __repr__(self):
if self.next_download or self.frequency:
@@ -132,7 +133,7 @@ def __repr__(self):
def __json__(self) -> dict:
d = dict(
attempts=self.attempts,
- channel_id=self.channel_id,
+ collection_id=self.collection_id,
downloader=self.downloader,
frequency=self.frequency,
destination=self.destination,
@@ -226,12 +227,24 @@ def delete(self, add_to_skip_list: bool = True):
session = Session.object_session(self)
+ # Get collection info before deleting
+ collection = self.collection
+ collection_kind = collection.kind if collection else None
+
session.delete(self)
session.commit()
- if self.channel_id:
- # Save Channels config if this download was associated with a Channel.
+
+ # Save appropriate config based on collection kind
+ if collection_kind == 'channel':
from modules.videos.lib import save_channels_config
save_channels_config.activate_switch()
+ elif collection_kind == 'domain':
+ from modules.archive.lib import save_domains_config
+ save_domains_config.activate_switch()
+ elif collection_kind:
+ from wrolpi.collections.config import save_collections_config
+ save_collections_config.activate_switch()
+
# Save download config again because this download is now removed from the download lists.
save_downloads_config.activate_switch()
@@ -612,9 +625,9 @@ def create_downloads(self, urls: List[str], downloader_name: str, session: Sessi
download.tag_names = tag_names or None
# Preserve existing settings, unless new settings are provided.
download.settings = settings if settings is not None else download.settings
- if download.frequency and download.settings and (channel_id := download.settings.get('channel_id')):
- # Attach a recurring Channel download to it's Channel.
- download.channel_id = download.channel_id or channel_id
+ if download.frequency and download.settings and (collection_id := download.settings.get('collection_id')):
+ # Attach a recurring download to its Collection.
+ download.collection_id = download.collection_id or collection_id
downloads.append(download)
@@ -647,7 +660,7 @@ def create_download(self, url: str, downloader_name: str, session: Session = Non
def recurring_download(self, url: str, frequency: int, downloader_name: str, session: Session = None,
sub_downloader_name: str = None, reset_attempts: bool = False,
destination: str | pathlib.Path = None, tag_names: List[str] = None,
- settings: Dict = None) -> Download:
+ settings: Dict = None, collection_id: int = None) -> Download:
"""Schedule a recurring download."""
if not frequency or not isinstance(frequency, int):
raise ValueError('Recurring download must have a frequency!')
@@ -660,11 +673,13 @@ def recurring_download(self, url: str, frequency: int, downloader_name: str, ses
reset_attempts=reset_attempts, sub_downloader_name=sub_downloader_name,
destination=destination, tag_names=tag_names, settings=settings)
download.frequency = frequency
+ if collection_id:
+ download.collection_id = collection_id
- # Only recurring Downloads can be Channel Downloads.
+ # Only recurring Downloads can be Collection Downloads - look up via Channel.
from modules.videos.models import Channel
if channel := Channel.get_by_url(url=download.url, session=session):
- download.channel_id = channel.id
+ download.collection_id = channel.collection_id
session.commit()
@@ -674,10 +689,11 @@ def recurring_download(self, url: str, frequency: int, downloader_name: str, ses
def update_download(self, id_: int, url: str, downloader: str,
destination: str | pathlib.Path = None, tag_names: List[str] = None,
sub_downloader: str | None = None, frequency: int = None,
- settings: Dict = None, session: Session = None) -> Download:
+ settings: Dict = None, collection_id: int = None,
+ session: Session = None) -> Download:
download = Download.find_by_id(id_, session=session)
- if settings and settings.get('channel_id') and not frequency:
- raise InvalidDownload(f'A once-download cannot be associated with a Channel')
+ if collection_id and not frequency:
+ raise InvalidDownload(f'A once-download cannot be associated with a Collection')
download.url = url
download.downloader = downloader
download.frequency = frequency
@@ -687,8 +703,8 @@ def update_download(self, id_: int, url: str, downloader: str,
download.destination = destination or None
download.tag_names = tag_names or None
download.sub_downloader = sub_downloader or None
- # Remove Channel relationship, if necessary.
- download.channel_id = (settings or dict()).get('channel_id')
+ # Update Collection relationship
+ download.collection_id = collection_id
save_downloads_config.activate_switch()
@@ -939,7 +955,7 @@ def get_fe_downloads(self):
with get_db_curs() as curs:
stmt = f'''
SELECT
- channel_id,
+ collection_id,
destination,
downloader,
error,
@@ -1462,7 +1478,7 @@ async def do_download(self, download: Download) -> DownloadResult:
# Apply YT channel to the Download, if not already applied.
if yt_channel_id := feed.get('feed', dict()).get('yt_channelid'):
- if not (download.location or download.channel_id):
+ if not (download.location or download.collection_id):
self.apply_yt_channel(download.id, yt_channel_id)
# Filter entries using Download.settings.
@@ -1562,14 +1578,14 @@ def filter_entries(download: Download, entries: List[dict]) -> List[dict]:
@staticmethod
def apply_yt_channel(download_id: int, yt_channel_id: str):
- """Get Channel that matches this Download, apply Channel information to the Download."""
+ """Get Channel that matches this Download, apply Channel/Collection information to the Download."""
with get_db_session() as session:
from modules.videos.models import Channel
channel = Channel.get_by_source_id(session, f'UC{yt_channel_id}')
if channel:
download_ = Download.get_by_id(download_id, session=session)
- download_.channel = channel
- download_.channel_id = channel.id
+ download_.collection = channel.collection
+ download_.collection_id = channel.collection_id
download_.location = download_.location or channel.location
session.commit()
diff --git a/wrolpi/errors.py b/wrolpi/errors.py
index 735a93e51..049f908e1 100644
--- a/wrolpi/errors.py
+++ b/wrolpi/errors.py
@@ -187,3 +187,9 @@ class DownloadError(APIError):
code = 'DOWNLOAD_ERROR'
summary = 'Unable to complete download'
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+class UnknownCollection(APIError):
+ code = 'UNKNOWN_COLLECTION'
+ summary = 'Unable to find the collection'
+ status_code = HTTPStatus.NOT_FOUND
diff --git a/wrolpi/events.py b/wrolpi/events.py
index ce6043085..af3a59dd1 100644
--- a/wrolpi/events.py
+++ b/wrolpi/events.py
@@ -100,6 +100,18 @@ def send_archive_uploaded(cls, message: str = None, url: str = None):
def send_upload_archive_failed(cls, message: str = None):
send_event('upload_archive_failed', message, subject='upload')
+ @classmethod
+ def send_screenshot_generated(cls, message: str = None, url: str = None):
+ send_event('screenshot_generated', message, subject='screenshot', url=url)
+
+ @classmethod
+ def send_screenshot_generation_failed(cls, message: str = None):
+ send_event('screenshot_generation_failed', message, subject='screenshot')
+
+ @staticmethod
+ def send_upgrade_started(message: str = None):
+ send_event('upgrade_started', message, subject='upgrade')
+
def log_event(event: str, message: str = None, action: str = None, subject: str = None):
log = f'{event=}'
diff --git a/wrolpi/files/api.py b/wrolpi/files/api.py
index d83bb3229..18f6711d9 100644
--- a/wrolpi/files/api.py
+++ b/wrolpi/files/api.py
@@ -1,4 +1,20 @@
-import os.path
+import pathlib
+from http import HTTPStatus
+
+import sanic.request
+from sanic import response, Request, Blueprint
+from sanic_ext import validate
+from sanic_ext.extensions.openapi import openapi
+
+from wrolpi.common import get_media_directory, wrol_mode_check, get_relative_to_media_directory, logger, \
+ background_task, walk, timer, TRACE_LEVEL, unique_by_predicate
+from wrolpi.errors import InvalidFile, UnknownDirectory, FileUploadFailed, FileConflict
+from . import lib, schema
+from ..api_utils import json_response, api_app
+from ..events import Events
+from ..schema import JSONErrorResponse
+from ..tags import Tag
+from ..vars import PYTEST
import pathlib
from http import HTTPStatus
@@ -165,7 +181,8 @@ async def post_search_directories(_, body: schema.DirectoriesSearchRequest):
# Search Domains by name.
from modules.archive.lib import search_domains_by_name
domains = await search_domains_by_name(name=body.path)
- domain_directories = [dict(path=i.directory, domain=i.domain) for i in domains]
+ # search_domains_by_name returns dicts with 'directory' and 'domain' keys
+ domain_directories = [dict(path=i['directory'], domain=i['domain']) for i in domains]
domain_paths = [i['path'] for i in domain_directories]
if logger.isEnabledFor(TRACE_LEVEL):
logger.trace(f'post_search_directories: {domain_paths=}')
diff --git a/wrolpi/files/indexers.py b/wrolpi/files/indexers.py
index 0b58997cb..fe3922dde 100644
--- a/wrolpi/files/indexers.py
+++ b/wrolpi/files/indexers.py
@@ -7,7 +7,6 @@
import docx
-from wrolpi import cmd
from wrolpi.cmd import CATDOC_PATH, TEXTUTIL_PATH
from wrolpi.vars import PYTEST, FILE_MAX_TEXT_SIZE
diff --git a/wrolpi/files/lib.py b/wrolpi/files/lib.py
index 756eb5a3a..c95416324 100644
--- a/wrolpi/files/lib.py
+++ b/wrolpi/files/lib.py
@@ -4,7 +4,6 @@
import functools
import glob
import json
-import os
import pathlib
import re
import shutil
diff --git a/wrolpi/files/test/test_api.py b/wrolpi/files/test/test_api.py
index 6fd1078d4..69b48aa77 100644
--- a/wrolpi/files/test/test_api.py
+++ b/wrolpi/files/test/test_api.py
@@ -16,7 +16,8 @@
from wrolpi.vars import PROJECT_DIR
-def test_list_files_api(test_session, test_client, make_files_structure, test_directory):
+@pytest.mark.asyncio
+async def test_list_files_api(test_session, async_client, make_files_structure, test_directory):
files = [
'archives/bar.txt',
'archives/baz/bar.txt',
@@ -30,8 +31,8 @@ def test_list_files_api(test_session, test_client, make_files_structure, test_di
files = make_files_structure(files)
files[0].write_text('bar contents')
- def check_get_files(directories, expected_files):
- request, response = test_client.post('/api/files', content=json.dumps({'directories': directories}))
+ async def check_get_files(directories, expected_files):
+ request, response = await async_client.post('/api/files', content=json.dumps({'directories': directories}))
assert response.status_code == HTTPStatus.OK
assert not response.json.get('errors')
# The first dict is the media directory.
@@ -44,14 +45,14 @@ def check_get_files(directories, expected_files):
'empty directory/': {'path': 'empty directory/', 'is_empty': True},
'videos/': {'path': 'videos/', 'is_empty': False}
}
- check_get_files([], expected)
+ await check_get_files([], expected)
# empty directory is empty
expected = {
'archives/': {'path': 'archives/', 'is_empty': False},
'empty directory/': {'path': 'empty directory/', 'children': {}, 'is_empty': True},
'videos/': {'path': 'videos/', 'is_empty': False}
}
- check_get_files(['empty directory'], expected)
+ await check_get_files(['empty directory'], expected)
expected = {
'archives/': {
@@ -66,7 +67,7 @@ def check_get_files(directories, expected_files):
'empty directory/': {'path': 'empty directory/', 'is_empty': True},
'videos/': {'path': 'videos/', 'is_empty': False}
}
- check_get_files(['archives'], expected)
+ await check_get_files(['archives'], expected)
# Sub-directories are supported.
expected = {
@@ -89,9 +90,9 @@ def check_get_files(directories, expected_files):
'empty directory/': {'path': 'empty directory/', 'is_empty': True},
'videos/': {'path': 'videos/', 'is_empty': False},
}
- check_get_files(['archives', 'archives/baz'], expected)
+ await check_get_files(['archives', 'archives/baz'], expected)
# Requesting only a subdirectory also returns `archives` contents.
- check_get_files(['archives/baz'], expected)
+ await check_get_files(['archives/baz'], expected)
expected = {
'archives/': {
@@ -112,41 +113,43 @@ def check_get_files(directories, expected_files):
'is_empty': False
}
}
- check_get_files(['archives', 'videos'], expected)
+ await check_get_files(['archives', 'videos'], expected)
# Order does not matter.
- check_get_files(['videos', 'archives'], expected)
+ await check_get_files(['videos', 'archives'], expected)
-def test_delete_file(test_session, test_client, make_files_structure, test_directory):
+@pytest.mark.asyncio
+async def test_delete_file(test_session, async_client, make_files_structure, test_directory):
files = ['bar.txt', 'baz/', 'foo']
make_files_structure(files)
# Delete a file.
- request, response = test_client.post('/api/files/delete', content=json.dumps({'paths': ['bar.txt']}))
+ request, response = await async_client.post('/api/files/delete', content=json.dumps({'paths': ['bar.txt']}))
assert response.status_code == HTTPStatus.NO_CONTENT
assert not (test_directory / 'bar.txt').is_file()
assert (test_directory / 'baz').is_dir()
# Delete a directory.
- request, response = test_client.post('/api/files/delete', content=json.dumps({'paths': ['baz']}))
+ request, response = await async_client.post('/api/files/delete', content=json.dumps({'paths': ['baz']}))
assert response.status_code == HTTPStatus.NO_CONTENT
assert not (test_directory / 'bar.txt').is_file()
assert not (test_directory / 'baz').is_dir()
- request, response = test_client.post('/api/files/delete', content=json.dumps({'paths': ['bad file']}))
+ request, response = await async_client.post('/api/files/delete', content=json.dumps({'paths': ['bad file']}))
assert response.status_code == HTTPStatus.BAD_REQUEST
+@pytest.mark.asyncio
@pytest.mark.parametrize(
'paths', [
[],
['', ],
]
)
-def test_delete_invalid_file(test_client, paths):
+async def test_delete_invalid_file(async_client, paths):
"""Some paths must be passed."""
with mock.patch('wrolpi.files.api.lib.delete') as mock_delete_file:
- request, response = test_client.post('/api/files/delete', content=json.dumps({'paths': paths}))
+ request, response = await async_client.post('/api/files/delete', content=json.dumps({'paths': paths}))
assert response.status_code == HTTPStatus.BAD_REQUEST
mock_delete_file.assert_not_called()
@@ -197,9 +200,10 @@ async def test_files_search_recent(test_session, test_directory, async_client, v
assert [i['name'] for i in response.json['file_groups']] == ['foo.mp4']
-def test_files_search(test_session, test_client, make_files_structure, assert_files_search):
+@pytest.mark.asyncio
+async def test_files_search(test_session, async_client, make_files_structure, assert_files_search):
# You can search an empty directory.
- assert_files_search('nothing', [])
+ await assert_files_search('nothing', [])
# Create files in the temporary directory. Add some contents so the mimetype can be tested.
files = [
@@ -215,14 +219,14 @@ def test_files_search(test_session, test_client, make_files_structure, assert_fi
baz2.write_bytes((PROJECT_DIR / 'test/big_buck_bunny_720p_1mb.mp4').read_bytes())
# Refresh so files can be searched.
- request, response = test_client.post('/api/files/refresh')
+ request, response = await async_client.post('/api/files/refresh')
assert response.status_code == HTTPStatus.NO_CONTENT
- assert_files_search('foo', [dict(primary_path='foo_is_the_name.txt')])
- assert_files_search('bar', [dict(primary_path='archives/bar.txt')])
- assert_files_search('baz', [dict(primary_path='baz baz two.mp4'), dict(primary_path='baz.mp4')])
- assert_files_search('two', [dict(primary_path='baz baz two.mp4')])
- assert_files_search('nothing', [])
+ await assert_files_search('foo', [dict(primary_path='foo_is_the_name.txt')])
+ await assert_files_search('bar', [dict(primary_path='archives/bar.txt')])
+ await assert_files_search('baz', [dict(primary_path='baz baz two.mp4'), dict(primary_path='baz.mp4')])
+ await assert_files_search('two', [dict(primary_path='baz baz two.mp4')])
+ await assert_files_search('nothing', [])
@pytest.mark.asyncio
@@ -267,7 +271,8 @@ async def test_files_search_any_tag(async_client, test_session, make_files_struc
assert response.status_code == HTTPStatus.BAD_REQUEST
-def test_refresh_files_list(test_session, test_client, make_files_structure, test_directory, video_bytes):
+@pytest.mark.asyncio
+async def test_refresh_files_list(test_session, async_client, make_files_structure, test_directory, video_bytes):
"""The user can request to refresh specific files."""
make_files_structure({
'bar.txt': 'hello',
@@ -276,20 +281,21 @@ def test_refresh_files_list(test_session, test_client, make_files_structure, tes
# Only the single file that was refreshed is discovered.
content = json.dumps({'paths': ['bar.txt']})
- request, response = test_client.post('/api/files/refresh', content=content)
+ request, response = await async_client.post('/api/files/refresh', content=content)
assert response.status_code == HTTPStatus.NO_CONTENT
assert test_session.query(FileGroup).count() == 1
group: FileGroup = test_session.query(FileGroup).one()
assert len(group.files) == 1
- request, response = test_client.post('/api/files/refresh')
+ request, response = await async_client.post('/api/files/refresh')
assert response.status_code == HTTPStatus.NO_CONTENT
group: FileGroup = test_session.query(FileGroup).one()
assert len(group.files) == 2
-def test_file_statistics(test_session, test_client, test_directory, example_pdf, example_mobi, example_epub,
- video_file):
+@pytest.mark.asyncio
+async def test_file_statistics(test_session, async_client, test_directory, example_pdf, example_mobi, example_epub,
+ video_file):
"""A summary of File statistics can be fetched."""
# Give each file a unique stem.
video_file.rename(test_directory / 'video.mp4')
@@ -298,7 +304,7 @@ def test_file_statistics(test_session, test_client, test_directory, example_pdf,
example_epub.rename(test_directory / 'epub.epub')
# Statistics can be fetched while empty.
- request, response = test_client.get('/api/statistics')
+ request, response = await async_client.get('/api/statistics')
assert response.status_code == HTTPStatus.OK
assert response.json['file_statistics'] == {
'archive_count': 0,
@@ -315,9 +321,9 @@ def test_file_statistics(test_session, test_client, test_directory, example_pdf,
'zip_count': 0,
}
- test_client.post('/api/files/refresh')
+ await async_client.post('/api/files/refresh')
- request, response = test_client.get('/api/statistics')
+ request, response = await async_client.get('/api/statistics')
assert response.status_code == HTTPStatus.OK
stats = response.json['file_statistics']
stats.pop('total_size')
@@ -475,11 +481,12 @@ async def test_post_search_directories(test_session, async_client, make_files_st
assert response.status_code == HTTPStatus.NO_CONTENT
from modules.videos.models import Channel
- channel1 = Channel(directory=channel1_dir, name='Channel Name')
- channel2 = Channel(directory=channel2_dir, name='OtherChannel')
- from modules.archive.models import Domain
- domain = Domain(directory=domain_dir, domain='example.com')
- test_session.add_all([channel1, channel2, domain])
+ # Use from_config to create channels properly (creates Collection first)
+ channel1 = Channel.from_config({'directory': channel1_dir, 'name': 'Channel Name'}, session=test_session)
+ channel2 = Channel.from_config({'directory': channel2_dir, 'name': 'OtherChannel'}, session=test_session)
+ from wrolpi.collections import Collection
+ domain_collection = Collection(directory=domain_dir, name='example.com', kind='domain')
+ test_session.add(domain_collection)
test_session.commit()
# All directories contain "di". The names of the Channel and Directory do not match.
@@ -520,7 +527,9 @@ async def test_post_search_directories(test_session, async_client, make_files_st
assert response.json['domain_directories'] == [{'domain': 'example.com', 'path': 'dir3'}]
-def test_post_upload_directory(test_session, test_client, test_directory, make_files_structure, make_multipart_form):
+@pytest.mark.asyncio
+async def test_post_upload_directory(test_session, async_client, test_directory, make_files_structure,
+ make_multipart_form):
"""A file can be uploaded in a directory in the destination."""
make_files_structure(['uploads/'])
@@ -534,7 +543,7 @@ def test_post_upload_directory(test_session, test_client, test_directory, make_f
]
body = make_multipart_form(forms)
headers = {'Content-Type': 'multipart/form-data; boundary=-----------------------------sanic'}
- request, response = test_client.post('/api/files/upload', content=body, headers=headers)
+ request, response = await async_client.post('/api/files/upload', content=body, headers=headers)
assert response.status_code == HTTPStatus.CREATED
assert (test_directory / 'uploads/foo/bar.txt').is_file()
@@ -681,7 +690,8 @@ async def test_post_upload(test_session, async_client, test_directory, make_file
assert not video.channel and not video.channel_id
-def test_post_upload_text(test_session, test_client, test_directory, make_files_structure, make_multipart_form):
+@pytest.mark.asyncio
+async def test_post_upload_text(test_session, async_client, test_directory, make_files_structure, make_multipart_form):
"""A file that cannot be modeled can still be uploaded, and is indexed."""
make_files_structure(['uploads/'])
@@ -696,7 +706,7 @@ def test_post_upload_text(test_session, test_client, test_directory, make_files_
]
body = make_multipart_form(forms)
headers = {'Content-Type': 'multipart/form-data; boundary=-----------------------------sanic'}
- request, response = test_client.post('/api/files/upload', content=body, headers=headers)
+ request, response = await async_client.post('/api/files/upload', content=body, headers=headers)
assert response.status_code == HTTPStatus.CREATED, response.content.decode()
file_group: FileGroup = test_session.query(FileGroup).one()
@@ -896,7 +906,8 @@ async def test_delete_directory_recursive(test_session, test_directory, make_fil
@pytest.mark.asyncio
-async def test_get_file(test_session, async_client, test_directory, make_files_structure, await_background_tasks, await_switches):
+async def test_get_file(test_session, async_client, test_directory, make_files_structure, await_background_tasks,
+ await_switches):
"""Can get info about a single file."""
make_files_structure({'foo/bar.txt': 'foo contents'})
await lib.refresh_files()
diff --git a/wrolpi/files/test/test_ebooks.py b/wrolpi/files/test/test_ebooks.py
index 4958d7165..1ac4cc027 100644
--- a/wrolpi/files/test/test_ebooks.py
+++ b/wrolpi/files/test/test_ebooks.py
@@ -77,9 +77,10 @@ async def test_extract_cover(test_session, test_directory, example_epub, await_s
assert ebook.cover_path and ebook.cover_path.stat().st_size == 297099
-def test_search_ebooks(test_session, test_client, example_epub):
+@pytest.mark.asyncio
+async def test_search_ebooks(test_session, async_client, example_epub):
"""Ebooks are handled in File search results."""
- request, response = test_client.post('/api/files/refresh')
+ request, response = await async_client.post('/api/files/refresh')
assert response.status_code == HTTPStatus.NO_CONTENT
assert test_session.query(EBook).count() == 1
@@ -91,7 +92,7 @@ def test_search_ebooks(test_session, test_client, example_epub):
assert ebook.file_group.d_text, 'Book was not indexed'
content = dict(mimetypes=['application/epub', 'application/x-mobipocket-ebook'])
- request, response = test_client.post('/api/files/search', content=json.dumps(content))
+ request, response = await async_client.post('/api/files/search', content=json.dumps(content))
assert response.status_code == HTTPStatus.OK
assert response.json
file_group = response.json['file_groups'][0]
@@ -105,7 +106,7 @@ def test_search_ebooks(test_session, test_client, example_epub):
# No Mobi ebook.
content = dict(mimetypes=['application/x-mobipocket-ebook'])
- request, response = test_client.post('/api/files/search', content=json.dumps(content))
+ request, response = await async_client.post('/api/files/search', content=json.dumps(content))
assert response.status_code == HTTPStatus.OK
assert response.json
assert len(response.json['file_groups']) == 0
diff --git a/wrolpi/files/test/test_lib.py b/wrolpi/files/test/test_lib.py
index 32fd6a22f..df9f7860e 100644
--- a/wrolpi/files/test/test_lib.py
+++ b/wrolpi/files/test/test_lib.py
@@ -673,8 +673,9 @@ def test_get_primary_file(test_directory, video_file, srt_file3, example_epub, e
assert lib.get_primary_file([foo, bar])
-def test_get_refresh_progress(test_client, test_session):
- request, response = test_client.get('/api/files/refresh_progress')
+@pytest.mark.asyncio
+async def test_get_refresh_progress(async_client, test_session):
+ request, response = await async_client.get('/api/files/refresh_progress')
assert response.status_code == HTTPStatus.OK
assert 'progress' in response.json
progress = response.json['progress']
@@ -709,7 +710,7 @@ async def test_refresh_files_no_groups(test_session, test_directory, make_files_
@pytest.mark.asyncio
-async def test_refresh_directories(test_session, test_directory, assert_directories):
+async def test_refresh_directories(test_session, test_directory, assert_directories, await_switches):
"""
Directories are stored when they are discovered. They are removed when they can no longer be found.
"""
diff --git a/wrolpi/root_api.py b/wrolpi/root_api.py
index 22361a18e..9ccadf8ac 100644
--- a/wrolpi/root_api.py
+++ b/wrolpi/root_api.py
@@ -21,6 +21,7 @@
from wrolpi import tags
from wrolpi.admin import HotspotStatus
from wrolpi.api_utils import json_response, api_app
+from wrolpi.collections.api import collection_bp
from wrolpi.common import logger, get_wrolpi_config, wrol_mode_enabled, get_media_directory, \
wrol_mode_check, native_only, disable_wrol_mode, enable_wrol_mode, get_global_statistics, url_strip_host, \
set_global_log_level, get_relative_to_media_directory, search_other_estimates
@@ -47,6 +48,7 @@
# Blueprints order here defines what order they are displayed in OpenAPI Docs.
api_app.blueprint(api_bp)
api_app.blueprint(archive_bp)
+api_app.blueprint(collection_bp) # Unified collection endpoints
api_app.blueprint(config_bp)
api_app.blueprint(files_bp)
api_app.blueprint(inventory_bp)
@@ -280,7 +282,7 @@ async def post_download(_: Request, body: schema.DownloadRequest):
destination=body.destination, tag_names=body.tag_names,
settings=body.settings)
if body.frequency:
- download_manager.recurring_download(body.urls[0], body.frequency, **kwargs)
+ download_manager.recurring_download(body.urls[0], body.frequency, collection_id=body.collection_id, **kwargs)
else:
download_manager.create_downloads(body.urls, **kwargs)
if download_manager.disabled.is_set() or download_manager.stopped.is_set():
@@ -314,6 +316,7 @@ async def put_download(_: Request, download_id: int, body: schema.DownloadReques
tag_names=body.tag_names,
sub_downloader=body.sub_downloader,
settings=body.settings,
+ collection_id=body.collection_id,
session=session,
)
if download_manager.disabled.is_set() or download_manager.stopped.is_set():
@@ -488,7 +491,7 @@ async def get_status(request: Request):
throttle_status=admin.throttle_status().name,
version=__version__,
wrol_mode=wrol_mode_enabled(),
- # Include all stats from status worker.
+ # Include all stats from status worker (includes upgrade info and git_branch).
**api_app.shared_ctx.status,
)
return json_response(ret)
@@ -614,6 +617,54 @@ async def post_shutdown(_: Request):
return response.empty(HTTPStatus.NO_CONTENT)
+@api_bp.get('/upgrade/check')
+@openapi.definition(description='Check for available WROLPi updates')
+@native_only
+async def get_upgrade_check(request: Request):
+ """
+ Check if an update is available by comparing local git HEAD with origin/{branch}.
+
+ Query params:
+ force: If 'true', force a fresh git fetch before checking.
+ """
+ from wrolpi.upgrade import check_for_update
+
+ force = request.args.get('force', 'false').lower() == 'true'
+ result = check_for_update(fetch=force)
+
+ # Update shared_ctx.status so /api/status returns the latest info.
+ # Using shared_ctx.status (a manager.dict) ensures info is shared across all workers.
+ api_app.shared_ctx.status['update_available'] = result.get('update_available', False)
+ api_app.shared_ctx.status['latest_commit'] = result.get('latest_commit')
+ api_app.shared_ctx.status['current_commit'] = result.get('current_commit')
+ api_app.shared_ctx.status['commits_behind'] = result.get('commits_behind', 0)
+
+ return json_response(result)
+
+
+@api_bp.post('/upgrade/start')
+@openapi.definition(description='Start WROLPi upgrade')
+@native_only
+@wrol_mode_check
+async def post_upgrade_start(_: Request):
+ """
+ Trigger the WROLPi upgrade process.
+
+ This executes /opt/wrolpi/upgrade.sh which will:
+ 1. Stop the API and app services
+ 2. Fetch latest code from git
+ 3. Run upgrade scripts
+ 4. Restart services
+
+ The frontend should redirect to the maintenance page after calling this.
+ Returns 202 Accepted as the upgrade runs asynchronously.
+ """
+ from wrolpi.upgrade import start_upgrade
+
+ await start_upgrade()
+ return response.empty(HTTPStatus.ACCEPTED)
+
+
@api_bp.post('/search_suggestions')
@openapi.definition(
description='Suggest related Channels/Domains/etc. to the user.',
diff --git a/wrolpi/schema.py b/wrolpi/schema.py
index c60574adb..160e8bb9a 100644
--- a/wrolpi/schema.py
+++ b/wrolpi/schema.py
@@ -263,6 +263,7 @@ class DownloadRequest:
frequency: Optional[int] = None
sub_downloader: Optional[str] = None
settings: Optional[dict] = field(default_factory=dict)
+ collection_id: Optional[int] = None
def __post_init__(self):
urls = [j for i in self.urls if (j := i.strip())]
diff --git a/wrolpi/status.py b/wrolpi/status.py
index 78d36a3a5..1b71ee4df 100755
--- a/wrolpi/status.py
+++ b/wrolpi/status.py
@@ -18,7 +18,7 @@
from wrolpi.cmd import which, run_command
from wrolpi.common import logger, get_warn_once, unique_by_predicate, partition, TRACE_LEVEL
from wrolpi.dates import now
-from wrolpi.vars import PYTEST
+from wrolpi.vars import PYTEST, DOCKERIZED
try:
import psutil
@@ -595,6 +595,7 @@ async def status_worker(count: int = None, sleep_time: int = 5):
asyncio.create_task(get_iostat_stats(), name='get_iostat_stats'),
asyncio.create_task(get_power_stats(), name='get_power_stats'),
)
+ from wrolpi.upgrade import get_current_branch
shared_status.update({
'cpu_stats': cpu_stats.__json__(),
'load_stats': load_stats.__json__(),
@@ -604,6 +605,7 @@ async def status_worker(count: int = None, sleep_time: int = 5):
'last_status': now().isoformat(),
'iostat_stats': iostat_stats.__json__(),
'power_stats': power_stats.__json__(),
+ 'git_branch': get_current_branch() if not DOCKERIZED else None,
})
if 'disk_bandwidth_stats' not in shared_status:
diff --git a/wrolpi/tags.py b/wrolpi/tags.py
index 7e73abc4c..36c9d7c4d 100644
--- a/wrolpi/tags.py
+++ b/wrolpi/tags.py
@@ -77,8 +77,8 @@ class Tag(ModelHelper, Base):
tag_files: List[TagFile] = relationship('TagFile', back_populates='tag', cascade='all')
tag_zim_entries: List = relationship('TagZimEntry', back_populates='tag', cascade='all')
- channels: List = relationship('Channel', primaryjoin='Tag.id==Channel.tag_id', back_populates='tag',
- cascade='all')
+
+ # Note: Channel relationship removed - Channels now access Tags through Collection
def __repr__(self):
name = self.name
@@ -148,8 +148,12 @@ def find_by_id(id_: int, session: Session = None) -> 'Tag':
raise UnknownTag(f'No Tag with id={id_}')
def has_relations(self) -> bool:
- """Returns True if this Tag has been used with any FileGroups or Zim Entries."""
- return bool(any(self.tag_files) or any(self.tag_zim_entries) or any(self.channels))
+ """Returns True if this Tag has been used with any FileGroups, Zim Entries, or Collections."""
+ from wrolpi.collections import Collection
+ session = Session.object_session(self)
+ # Check if tag is used by any collection (domain or channel)
+ has_collections = session.query(Collection).filter(Collection.tag_id == self.id).first() is not None
+ return bool(any(self.tag_files) or any(self.tag_zim_entries) or has_collections)
async def update_tag(self, name: str, color: str | None, session: Session):
"""Change name/color of tag. Ensures safeness of new values, checks for conflicts."""
@@ -174,7 +178,6 @@ async def update_tag(self, name: str, color: str | None, session: Session):
if Session.object_session(self):
self.flush()
- from modules.videos.models import Channel
from modules.videos.errors import ChannelDirectoryConflict
async def _():
@@ -197,23 +200,32 @@ async def _():
session.flush(to_flush)
save_downloads_config.activate_switch()
- for channel in self.channels:
- channel: Channel
- possible_directory = channel.format_directory(old_name)
- if channel.directory == possible_directory:
- # Channel is in this Tag's old directory, move the Channel to the new directory.
- new_directory = channel.format_directory(name)
- try:
- new_directory.mkdir(parents=True)
- except FileExistsError:
- raise ChannelDirectoryConflict(f'Channel directory already exists: {new_directory}')
-
- # Move the files of the Channel.
- await channel.move_channel(new_directory, session)
- else:
- msg = f"Not moving Channel because it is not in this Tag's old directory:" \
- f" {channel} {possible_directory} {self}"
- logger.warning(msg)
+ # Only move channels if tag is being renamed (not created)
+ if old_name:
+ # Get channels through Collections
+ from modules.videos.models import Channel
+ from wrolpi.collections import Collection
+ channels = session.query(Channel).join(Collection).filter(
+ Collection.tag_id == self.id,
+ Collection.kind == 'channel'
+ ).all()
+ for channel in channels:
+ channel: Channel
+ possible_directory = channel.format_directory(old_name)
+ if channel.directory == possible_directory:
+ # Channel is in this Tag's old directory, move the Channel to the new directory.
+ new_directory = channel.format_directory(name)
+ try:
+ new_directory.mkdir(parents=True)
+ except FileExistsError:
+ raise ChannelDirectoryConflict(f'Channel directory already exists: {new_directory}')
+
+ # Move the files of the Channel.
+ await channel.move_channel(new_directory, session)
+ else:
+ msg = f"Not moving Channel because it is not in this Tag's old directory:" \
+ f" {channel} {possible_directory} {self}"
+ logger.warning(msg)
if PYTEST:
await _()
@@ -639,13 +651,15 @@ def sync_tags_directory():
def get_tags() -> List[dict]:
with get_db_curs() as curs:
curs.execute('''
- SELECT t.id, t.name, t.color,
- (SELECT COUNT(*) FROM tag_file WHERE tag_id = t.id) AS file_group_count,
- (SELECT COUNT(*) FROM tag_zim WHERE tag_id = t.id) AS zim_entry_count
- FROM tag t
- GROUP BY t.id, t.name, t.color
- ORDER BY t.name
- ''')
+ SELECT t.id,
+ t.name,
+ t.color,
+ (SELECT COUNT(*) FROM tag_file WHERE tag_id = t.id) AS file_group_count,
+ (SELECT COUNT(*) FROM tag_zim WHERE tag_id = t.id) AS zim_entry_count
+ FROM tag t
+ GROUP BY t.id, t.name, t.color
+ ORDER BY t.name
+ ''')
tags = list(map(dict, curs.fetchall()))
return tags
@@ -682,15 +696,13 @@ def tag_names_to_file_group_sub_select(tag_names: List[str], params: dict) -> Tu
# This select gets an array of FileGroup.id's which are tagged with the provided tag names.
sub_select = '''
- SELECT
- tf.file_group_id
- FROM
- tag_file tf
- LEFT JOIN tag t on t.id = tf.tag_id
- GROUP BY file_group_id
- -- Match only FileGroups that have at least all the Tag names.
- HAVING array_agg(t.name)::TEXT[] @> %(tag_names)s::TEXT[]
- '''
+ SELECT tf.file_group_id
+ FROM tag_file tf
+ LEFT JOIN tag t on t.id = tf.tag_id
+ GROUP BY file_group_id
+ -- Match only FileGroups that have at least all the Tag names.
+ HAVING array_agg(t.name)::TEXT[] @> %(tag_names)s::TEXT[] \
+ '''
params['tag_names'] = tag_names
return sub_select, params
@@ -720,27 +732,24 @@ def tag_names_to_zim_sub_select(tag_names: List[str], zim_id: int = None) -> Tup
params = dict(tag_names=tag_names)
if zim_id:
stmt = '''
- SELECT
- tz.zim_id, tz.zim_entry
- FROM
- tag_zim tz
- LEFT JOIN tag t on tz.tag_id = t.id
- WHERE
- tz.zim_id = %(zim_id)s
- GROUP BY tz.zim_id, tz.zim_entry
- -- Match only TagZimEntries that have all the Tag names.
- HAVING array_agg(t.name)::TEXT[] @> %(tag_names)s::TEXT[]
- '''
+ SELECT tz.zim_id,
+ tz.zim_entry
+ FROM tag_zim tz
+ LEFT JOIN tag t on tz.tag_id = t.id
+ WHERE tz.zim_id = %(zim_id)s
+ GROUP BY tz.zim_id, tz.zim_entry
+ -- Match only TagZimEntries that have all the Tag names.
+ HAVING array_agg(t.name)::TEXT[] @> %(tag_names)s::TEXT[] \
+ '''
params['zim_id'] = zim_id
else:
stmt = '''
- SELECT
- tz.zim_id, tz.zim_entry
- FROM
- tag_zim tz
- LEFT JOIN tag t on tz.tag_id = t.id
- GROUP BY tz.zim_id, tz.zim_entry
- -- Match only TagZimEntries that have all the Tag names.
- HAVING array_agg(t.name)::TEXT[] @> %(tag_names)s::TEXT[]
- '''
+ SELECT tz.zim_id,
+ tz.zim_entry
+ FROM tag_zim tz
+ LEFT JOIN tag t on tz.tag_id = t.id
+ GROUP BY tz.zim_id, tz.zim_entry
+ -- Match only TagZimEntries that have all the Tag names.
+ HAVING array_agg(t.name)::TEXT[] @> %(tag_names)s::TEXT[] \
+ '''
return stmt, params
diff --git a/wrolpi/test/common.py b/wrolpi/test/common.py
index 1e09cf2b5..3e4f981c3 100644
--- a/wrolpi/test/common.py
+++ b/wrolpi/test/common.py
@@ -1,5 +1,4 @@
import json
-import os
import pathlib
import tempfile
from contextlib import contextmanager
@@ -12,7 +11,7 @@
from wrolpi.api_utils import CustomJSONEncoder
from wrolpi.common import get_media_directory
-from wrolpi.conftest import test_db, test_client # noqa
+from wrolpi.conftest import test_db, async_client # noqa
from wrolpi.db import postgres_engine
from wrolpi.vars import CIRCLECI, IS_MACOS
diff --git a/wrolpi/test/test_common.py b/wrolpi/test/test_common.py
index 100ca4ca7..3a59b11b0 100644
--- a/wrolpi/test/test_common.py
+++ b/wrolpi/test/test_common.py
@@ -459,6 +459,7 @@ def test_timer():
with common.timer(name='test_timer'):
time.sleep(0.1)
+
@pytest.mark.asyncio
@skip_macos
async def test_limit_concurrent_async():
diff --git a/wrolpi/test/test_root_api.py b/wrolpi/test/test_root_api.py
index 3a9dd3d64..222b5904a 100644
--- a/wrolpi/test/test_root_api.py
+++ b/wrolpi/test/test_root_api.py
@@ -205,7 +205,8 @@ async def test_echo(async_client):
assert response.json['args'] == {}
-def test_hotspot_settings(test_session, test_client, test_wrolpi_config):
+@pytest.mark.asyncio
+async def test_hotspot_settings(test_session, async_client, test_wrolpi_config):
"""
The User can toggle the Hotspot via /settings. The Hotspot can be automatically started on startup.
"""
@@ -215,28 +216,28 @@ def test_hotspot_settings(test_session, test_client, test_wrolpi_config):
with mock.patch('wrolpi.root_api.admin') as mock_admin:
# Turning on the hotspot succeeds.
mock_admin.enable_hotspot.return_value = True
- request, response = test_client.patch('/api/settings', content=json.dumps({'hotspot_status': True}))
+ request, response = await async_client.patch('/api/settings', content=json.dumps({'hotspot_status': True}))
assert response.status_code == HTTPStatus.NO_CONTENT, response.json
mock_admin.enable_hotspot.assert_called_once()
mock_admin.reset_mock()
# Turning on the hotspot fails.
mock_admin.enable_hotspot.return_value = False
- request, response = test_client.patch('/api/settings', content=json.dumps({'hotspot_status': True}))
+ request, response = await async_client.patch('/api/settings', content=json.dumps({'hotspot_status': True}))
assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR, response.json
assert response.json['code'] == 'HOTSPOT_ERROR'
mock_admin.enable_hotspot.assert_called_once()
# Turning off the hotspot succeeds.
mock_admin.disable_hotspot.return_value = True
- request, response = test_client.patch('/api/settings', content=json.dumps({'hotspot_status': False}))
+ request, response = await async_client.patch('/api/settings', content=json.dumps({'hotspot_status': False}))
assert response.status_code == HTTPStatus.NO_CONTENT, response.json
mock_admin.disable_hotspot.assert_called_once()
mock_admin.reset_mock()
# Turning off the hotspot succeeds.
mock_admin.disable_hotspot.return_value = False
- request, response = test_client.patch('/api/settings', content=json.dumps({'hotspot_status': False}))
+ request, response = await async_client.patch('/api/settings', content=json.dumps({'hotspot_status': False}))
assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR, response.json
assert response.json['code'] == 'HOTSPOT_ERROR'
mock_admin.disable_hotspot.assert_called_once()
@@ -248,7 +249,7 @@ def test_hotspot_settings(test_session, test_client, test_wrolpi_config):
# Hotspot password can be changed.
mock_admin.disable_hotspot.return_value = True
content = {'hotspot_password': 'new password', 'hotspot_ssid': 'new ssid'}
- request, response = test_client.patch('/api/settings', content=json.dumps(content))
+ request, response = await async_client.patch('/api/settings', content=json.dumps(content))
assert response.status_code == HTTPStatus.NO_CONTENT
assert config.hotspot_password == 'new password'
assert config.hotspot_ssid == 'new ssid'
@@ -258,7 +259,7 @@ def test_hotspot_settings(test_session, test_client, test_wrolpi_config):
# Hotspot password must be at least 8 characters.
mock_admin.disable_hotspot.return_value = True
content = {'hotspot_password': '1234567', 'hotspot_ssid': 'new ssid'}
- request, response = test_client.patch('/api/settings', content=json.dumps(content))
+ request, response = await async_client.patch('/api/settings', content=json.dumps(content))
assert response.status_code == HTTPStatus.BAD_REQUEST
assert response.json == {'code': 'HOTSPOT_PASSWORD_TOO_SHORT',
'error': 'Bad Request',
@@ -266,9 +267,10 @@ def test_hotspot_settings(test_session, test_client, test_wrolpi_config):
}
+@pytest.mark.asyncio
@skip_macos
@skip_circleci
-def test_throttle_toggle(test_session, test_client, test_wrolpi_config):
+async def test_throttle_toggle(test_session, async_client, test_wrolpi_config):
get_wrolpi_config().ignored_directories = list()
with mock.patch('wrolpi.admin.subprocess') as mock_subprocess, \
@@ -277,7 +279,7 @@ def test_throttle_toggle(test_session, test_client, test_wrolpi_config):
b'wlan0: unavailable',
b'The governor "ondemand" may decide ',
]
- request, response = test_client.get('/api/settings')
+ request, response = await async_client.get('/api/settings')
# Throttle is off by default.
assert response.status_code == HTTPStatus.OK
@@ -490,9 +492,10 @@ async def dispatch_downloads(*a, **kw):
assert download_config['tag_names'] == [tag1.name, tag2.name]
-def test_get_downloaders(test_client):
+@pytest.mark.asyncio
+async def test_get_downloaders(async_client):
"""A list of Downloaders the user can use can be gotten."""
- request, response = test_client.get('/api/downloaders')
+ request, response = await async_client.get('/api/downloaders')
assert response.status_code == HTTPStatus.OK
assert 'downloaders' in response.json, 'Downloaders not returned'
assert isinstance(response.json['downloaders'], list) and len(response.json['downloaders']), \
@@ -536,8 +539,9 @@ async def test_restart_download(test_session, async_client, test_download_manage
assert download.is_deferred, download.status
-def test_get_global_statistics(test_session, test_client):
- request, response = test_client.get('/api/statistics')
+@pytest.mark.asyncio
+async def test_get_global_statistics(test_session, async_client):
+ request, response = await async_client.get('/api/statistics')
assert response.json['global_statistics']['db_size'] > 1
@@ -594,7 +598,7 @@ async def assert_results(body: dict, expected_channels=None, expected_domains=No
{'directory': 'videos/Fool', 'id': 2, 'name': 'Fool', 'url': 'https://example.com/Fool', 'downloads': []},
{'directory': 'videos/Foo', 'id': 1, 'name': 'Foo', 'url': 'https://example.com/Foo', 'downloads': []},
],
- [{'directory': 'archive/foo.com', 'domain': 'foo.com', 'id': 1}],
+ [{'directory': 'archive/foo.com', 'domain': 'foo.com', 'id': 4}],
)
# Channel name "Fool" is matched because spaces are stripped in addition to only
@@ -610,7 +614,7 @@ async def assert_results(body: dict, expected_channels=None, expected_domains=No
[
{'directory': 'videos/Bar', 'id': 3, 'name': 'Bar', 'url': 'https://example.com/Bar', 'downloads': []}
],
- [{'directory': 'archive/bar.com', 'domain': 'bar.com', 'id': 2}],
+ [{'directory': 'archive/bar.com', 'domain': 'bar.com', 'id': 5}],
)
diff --git a/wrolpi/test/test_upgrade.py b/wrolpi/test/test_upgrade.py
new file mode 100644
index 000000000..759462e60
--- /dev/null
+++ b/wrolpi/test/test_upgrade.py
@@ -0,0 +1,310 @@
+"""
+Tests for the WROLPi upgrade system.
+"""
+import pytest
+from unittest.mock import patch, MagicMock
+
+from wrolpi.upgrade import (
+ check_for_update,
+ get_current_branch,
+ get_local_commit,
+ get_remote_commit,
+ get_commits_behind,
+ git_fetch,
+ start_upgrade,
+)
+
+
+class TestGitFunctions:
+ """Test git helper functions."""
+
+ def test_git_fetch_directory_not_exists(self):
+ """git_fetch returns False when directory doesn't exist."""
+ mock_path = MagicMock()
+ mock_path.is_dir.return_value = False
+ with patch('wrolpi.upgrade.PROJECT_DIR', mock_path):
+ assert git_fetch() is False
+
+ def test_git_fetch_success(self):
+ """git_fetch returns True on success."""
+ mock_result = MagicMock()
+ mock_result.returncode = 0
+ mock_path = MagicMock()
+ mock_path.is_dir.return_value = True
+
+ with patch('wrolpi.upgrade.PROJECT_DIR', mock_path), \
+ patch('subprocess.run', return_value=mock_result):
+ assert git_fetch() is True
+
+ def test_git_fetch_failure(self):
+ """git_fetch returns False when command fails."""
+ mock_result = MagicMock()
+ mock_result.returncode = 1
+ mock_result.stderr = b'error message'
+ mock_path = MagicMock()
+ mock_path.is_dir.return_value = True
+
+ with patch('wrolpi.upgrade.PROJECT_DIR', mock_path), \
+ patch('subprocess.run', return_value=mock_result):
+ assert git_fetch() is False
+
+ def test_get_current_branch(self):
+ """get_current_branch returns the branch name."""
+ mock_result = MagicMock()
+ mock_result.returncode = 0
+ mock_result.stdout = b'release\n'
+ mock_path = MagicMock()
+ mock_path.is_dir.return_value = True
+
+ with patch('wrolpi.upgrade.PROJECT_DIR', mock_path), \
+ patch('subprocess.run', return_value=mock_result):
+ assert get_current_branch() == 'release'
+
+ def test_get_current_branch_directory_not_exists(self):
+ """get_current_branch returns None when directory doesn't exist."""
+ mock_path = MagicMock()
+ mock_path.is_dir.return_value = False
+ with patch('wrolpi.upgrade.PROJECT_DIR', mock_path):
+ assert get_current_branch() is None
+
+ def test_get_local_commit(self):
+ """get_local_commit returns the short commit hash."""
+ mock_result = MagicMock()
+ mock_result.returncode = 0
+ mock_result.stdout = b'abc1234\n'
+ mock_path = MagicMock()
+ mock_path.is_dir.return_value = True
+
+ with patch('wrolpi.upgrade.PROJECT_DIR', mock_path), \
+ patch('subprocess.run', return_value=mock_result):
+ assert get_local_commit() == 'abc1234'
+
+ def test_get_remote_commit(self):
+ """get_remote_commit returns the short commit hash for origin/branch."""
+ mock_result = MagicMock()
+ mock_result.returncode = 0
+ mock_result.stdout = b'def5678\n'
+ mock_path = MagicMock()
+ mock_path.is_dir.return_value = True
+
+ with patch('wrolpi.upgrade.PROJECT_DIR', mock_path), \
+ patch('subprocess.run', return_value=mock_result):
+ assert get_remote_commit('release') == 'def5678'
+
+ def test_get_commits_behind(self):
+ """get_commits_behind returns the number of commits behind."""
+ mock_result = MagicMock()
+ mock_result.returncode = 0
+ mock_result.stdout = b'5\n'
+ mock_path = MagicMock()
+ mock_path.is_dir.return_value = True
+
+ with patch('wrolpi.upgrade.PROJECT_DIR', mock_path), \
+ patch('subprocess.run', return_value=mock_result):
+ assert get_commits_behind('release') == 5
+
+
+class TestCheckForUpdate:
+ """Test check_for_update function."""
+
+ def test_check_for_update_dockerized(self):
+ """check_for_update returns no update in Docker."""
+ with patch('wrolpi.upgrade.DOCKERIZED', True):
+ result = check_for_update(fetch=False)
+ assert result['update_available'] is False
+
+ def test_check_for_update_directory_not_exists(self):
+ """check_for_update returns no update when directory doesn't exist."""
+ mock_path = MagicMock()
+ mock_path.is_dir.return_value = False
+ with patch('wrolpi.upgrade.DOCKERIZED', False), \
+ patch('wrolpi.upgrade.PROJECT_DIR', mock_path):
+ result = check_for_update(fetch=False)
+ assert result['update_available'] is False
+
+ def test_check_for_update_no_update(self):
+ """check_for_update returns no update when commits match."""
+ mock_path = MagicMock()
+ mock_path.is_dir.return_value = True
+ with patch('wrolpi.upgrade.DOCKERIZED', False), \
+ patch('wrolpi.upgrade.PROJECT_DIR', mock_path), \
+ patch('wrolpi.upgrade.get_current_branch', return_value='release'), \
+ patch('wrolpi.upgrade.git_fetch', return_value=True), \
+ patch('wrolpi.upgrade.get_local_commit', return_value='abc1234'), \
+ patch('wrolpi.upgrade.get_remote_commit', return_value='abc1234'):
+ result = check_for_update(fetch=True)
+ assert result['update_available'] is False
+ assert result['branch'] == 'release'
+ assert result['current_commit'] == 'abc1234'
+ assert result['latest_commit'] == 'abc1234'
+
+ def test_check_for_update_update_available(self):
+ """check_for_update returns update available when commits differ."""
+ mock_path = MagicMock()
+ mock_path.is_dir.return_value = True
+ with patch('wrolpi.upgrade.DOCKERIZED', False), \
+ patch('wrolpi.upgrade.PROJECT_DIR', mock_path), \
+ patch('wrolpi.upgrade.get_current_branch', return_value='master'), \
+ patch('wrolpi.upgrade.git_fetch', return_value=True), \
+ patch('wrolpi.upgrade.get_local_commit', return_value='abc1234'), \
+ patch('wrolpi.upgrade.get_remote_commit', return_value='def5678'), \
+ patch('wrolpi.upgrade.get_commits_behind', return_value=3):
+ result = check_for_update(fetch=True)
+ assert result['update_available'] is True
+ assert result['branch'] == 'master'
+ assert result['current_commit'] == 'abc1234'
+ assert result['latest_commit'] == 'def5678'
+ assert result['commits_behind'] == 3
+
+ def test_check_for_update_skip_fetch(self):
+ """check_for_update can skip fetch."""
+ mock_path = MagicMock()
+ mock_path.is_dir.return_value = True
+ with patch('wrolpi.upgrade.DOCKERIZED', False), \
+ patch('wrolpi.upgrade.PROJECT_DIR', mock_path), \
+ patch('wrolpi.upgrade.get_current_branch', return_value='release'), \
+ patch('wrolpi.upgrade.git_fetch') as mock_fetch, \
+ patch('wrolpi.upgrade.get_local_commit', return_value='abc1234'), \
+ patch('wrolpi.upgrade.get_remote_commit', return_value='abc1234'):
+ result = check_for_update(fetch=False)
+ mock_fetch.assert_not_called()
+ assert result['update_available'] is False
+
+
+class TestStartUpgrade:
+ """Test start_upgrade function."""
+
+ @pytest.mark.asyncio
+ async def test_start_upgrade_uses_current_branch(self, tmp_path):
+ """start_upgrade writes branch to env file and starts systemd service."""
+ mock_script_path = MagicMock()
+ mock_script_path.is_file.return_value = True
+
+ env_file = tmp_path / 'wrolpi-upgrade.env'
+
+ with patch('wrolpi.upgrade.DOCKERIZED', False), \
+ patch('wrolpi.upgrade.UPGRADE_SCRIPT', mock_script_path), \
+ patch('wrolpi.upgrade.get_current_branch', return_value='master'), \
+ patch('wrolpi.upgrade.subprocess.Popen') as mock_popen, \
+ patch('wrolpi.upgrade.pathlib.Path', return_value=env_file), \
+ patch('wrolpi.events.Events.send_upgrade_started'):
+ await start_upgrade()
+
+ # Verify env file was written with correct branch
+ assert env_file.read_text() == 'BRANCH=master\n'
+
+ # Verify systemctl start was called
+ mock_popen.assert_called_once()
+ call_args = mock_popen.call_args[0][0]
+ assert call_args == ['sudo', 'systemctl', 'start', 'wrolpi-upgrade.service']
+
+ @pytest.mark.asyncio
+ async def test_start_upgrade_uses_release_branch(self, tmp_path):
+ """start_upgrade writes release branch to env file."""
+ mock_script_path = MagicMock()
+ mock_script_path.is_file.return_value = True
+
+ env_file = tmp_path / 'wrolpi-upgrade.env'
+
+ with patch('wrolpi.upgrade.DOCKERIZED', False), \
+ patch('wrolpi.upgrade.UPGRADE_SCRIPT', mock_script_path), \
+ patch('wrolpi.upgrade.get_current_branch', return_value='release'), \
+ patch('wrolpi.upgrade.subprocess.Popen') as mock_popen, \
+ patch('wrolpi.upgrade.pathlib.Path', return_value=env_file), \
+ patch('wrolpi.events.Events.send_upgrade_started'):
+ await start_upgrade()
+
+ # Verify env file was written with correct branch
+ assert env_file.read_text() == 'BRANCH=release\n'
+
+ # Verify systemctl start was called
+ mock_popen.assert_called_once()
+ call_args = mock_popen.call_args[0][0]
+ assert call_args == ['sudo', 'systemctl', 'start', 'wrolpi-upgrade.service']
+
+ @pytest.mark.asyncio
+ async def test_start_upgrade_defaults_to_release_on_error(self, tmp_path):
+ """start_upgrade defaults to release branch if current branch cannot be determined."""
+ mock_script_path = MagicMock()
+ mock_script_path.is_file.return_value = True
+
+ env_file = tmp_path / 'wrolpi-upgrade.env'
+
+ with patch('wrolpi.upgrade.DOCKERIZED', False), \
+ patch('wrolpi.upgrade.UPGRADE_SCRIPT', mock_script_path), \
+ patch('wrolpi.upgrade.get_current_branch', return_value=None), \
+ patch('wrolpi.upgrade.subprocess.Popen') as mock_popen, \
+ patch('wrolpi.upgrade.pathlib.Path', return_value=env_file), \
+ patch('wrolpi.events.Events.send_upgrade_started'):
+ await start_upgrade()
+
+ # Verify env file was written with default release branch
+ assert env_file.read_text() == 'BRANCH=release\n'
+
+ # Verify systemctl start was called
+ mock_popen.assert_called_once()
+ call_args = mock_popen.call_args[0][0]
+ assert call_args == ['sudo', 'systemctl', 'start', 'wrolpi-upgrade.service']
+
+ @pytest.mark.asyncio
+ async def test_start_upgrade_skipped_in_docker(self):
+ """start_upgrade does nothing in Docker environment."""
+ with patch('wrolpi.upgrade.DOCKERIZED', True), \
+ patch('wrolpi.upgrade.subprocess.Popen') as mock_popen:
+ await start_upgrade()
+
+ # Verify Popen was not called
+ mock_popen.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_start_upgrade_script_not_found(self):
+ """start_upgrade does nothing if upgrade script doesn't exist."""
+ mock_script_path = MagicMock()
+ mock_script_path.is_file.return_value = False
+
+ with patch('wrolpi.upgrade.DOCKERIZED', False), \
+ patch('wrolpi.upgrade.UPGRADE_SCRIPT', mock_script_path), \
+ patch('wrolpi.upgrade.subprocess.Popen') as mock_popen:
+ await start_upgrade()
+
+ # Verify Popen was not called
+ mock_popen.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_upgrade_check_api_endpoint(async_client):
+ """Test the /api/upgrade/check endpoint."""
+ # The endpoint is native_only, so in Docker it should return 403
+ # In tests, we're not in Docker, so patch the upgrade functions
+ with patch('wrolpi.upgrade.DOCKERIZED', False), \
+ patch('wrolpi.upgrade.check_for_update') as mock_check:
+ mock_check.return_value = {
+ 'update_available': True,
+ 'current_commit': 'abc1234',
+ 'latest_commit': 'def5678',
+ 'branch': 'release',
+ 'commits_behind': 2,
+ }
+
+ _, response = await async_client.get('/api/upgrade/check')
+ # Note: This may return 403 NATIVE_ONLY in test environment
+ # The actual test depends on whether DOCKERIZED is True in tests
+
+
+@pytest.mark.asyncio
+async def test_status_endpoint_includes_git_branch(async_client):
+ """Test that /api/status includes git_branch from status worker without conflict."""
+ from wrolpi.api_utils import api_app
+
+ # Simulate status_worker having populated git_branch in shared_ctx.status
+ api_app.shared_ctx.status['git_branch'] = 'release'
+
+ try:
+ _, response = await async_client.get('/api/status')
+ assert response.status == 200
+ data = response.json
+ # git_branch should be present and correct
+ assert data.get('git_branch') == 'release'
+ finally:
+ # Clean up
+ api_app.shared_ctx.status.pop('git_branch', None)
diff --git a/wrolpi/upgrade.py b/wrolpi/upgrade.py
new file mode 100644
index 000000000..6100f123b
--- /dev/null
+++ b/wrolpi/upgrade.py
@@ -0,0 +1,255 @@
+"""
+WROLPi upgrade system.
+
+This module provides functionality to check for available updates by comparing the local git HEAD
+with the remote origin branch, and to trigger the upgrade process.
+"""
+import pathlib
+import subprocess
+
+from wrolpi.common import logger
+from wrolpi.vars import DOCKERIZED, PROJECT_DIR
+
+logger = logger.getChild(__name__)
+
+# Path to the upgrade script.
+UPGRADE_SCRIPT = PROJECT_DIR / 'upgrade.sh'
+
+
+def git_fetch() -> bool:
+ """
+ Run `git fetch` in the WROLPi directory.
+
+ Returns True if successful, False otherwise.
+ """
+ if not PROJECT_DIR.is_dir():
+ logger.warning(f'WROLPi directory does not exist: {PROJECT_DIR}')
+ return False
+
+ try:
+ result = subprocess.run(
+ ['git', 'fetch'],
+ cwd=str(PROJECT_DIR),
+ capture_output=True,
+ timeout=60,
+ )
+ if result.returncode != 0:
+ logger.warning(f'git fetch failed: {result.stderr.decode()}')
+ return False
+ return True
+ except subprocess.TimeoutExpired:
+ logger.warning('git fetch timed out')
+ return False
+ except Exception as e:
+ logger.error('git fetch failed', exc_info=e)
+ return False
+
+
+def get_current_branch() -> str | None:
+ """
+ Get the current git branch name (e.g., 'release', 'master').
+
+ Returns None if unable to determine.
+ """
+ if not PROJECT_DIR.is_dir():
+ return None
+
+ try:
+ result = subprocess.run(
+ ['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
+ cwd=str(PROJECT_DIR),
+ capture_output=True,
+ timeout=10,
+ )
+ if result.returncode == 0:
+ return result.stdout.decode().strip()
+ except Exception as e:
+ logger.error('Failed to get current branch', exc_info=e)
+
+ return None
+
+
+def get_local_commit() -> str | None:
+ """
+ Get the current local HEAD commit hash (short form).
+
+ Returns None if unable to determine.
+ """
+ if not PROJECT_DIR.is_dir():
+ return None
+
+ try:
+ result = subprocess.run(
+ ['git', 'rev-parse', '--short', 'HEAD'],
+ cwd=str(PROJECT_DIR),
+ capture_output=True,
+ timeout=10,
+ )
+ if result.returncode == 0:
+ return result.stdout.decode().strip()
+ except Exception as e:
+ logger.error('Failed to get local commit', exc_info=e)
+
+ return None
+
+
+def get_remote_commit(branch: str) -> str | None:
+ """
+ Get the latest commit hash on origin/{branch} (short form).
+
+ Note: Requires `git fetch` to be run first to have up-to-date remote refs.
+
+ Returns None if unable to determine.
+ """
+ if not PROJECT_DIR.is_dir():
+ return None
+
+ try:
+ result = subprocess.run(
+ ['git', 'rev-parse', '--short', f'origin/{branch}'],
+ cwd=str(PROJECT_DIR),
+ capture_output=True,
+ timeout=10,
+ )
+ if result.returncode == 0:
+ return result.stdout.decode().strip()
+ except Exception as e:
+ logger.error(f'Failed to get remote commit for origin/{branch}', exc_info=e)
+
+ return None
+
+
+def get_commits_behind(branch: str) -> int:
+ """
+ Get the number of commits the local branch is behind origin/{branch}.
+
+ Returns 0 if unable to determine or if up-to-date.
+ """
+ if not PROJECT_DIR.is_dir():
+ return 0
+
+ try:
+ # Count commits that are in origin/branch but not in HEAD
+ result = subprocess.run(
+ ['git', 'rev-list', '--count', f'HEAD..origin/{branch}'],
+ cwd=str(PROJECT_DIR),
+ capture_output=True,
+ timeout=10,
+ )
+ if result.returncode == 0:
+ return int(result.stdout.decode().strip())
+ except Exception as e:
+ logger.error(f'Failed to get commits behind for origin/{branch}', exc_info=e)
+
+ return 0
+
+
+def check_for_update(fetch: bool = True) -> dict:
+ """
+ Check if an update is available by comparing local HEAD with origin/{branch}.
+
+ Args:
+ fetch: If True, run `git fetch` first to get latest remote refs.
+
+ Returns a dict with:
+ - update_available: bool
+ - current_commit: str (short hash) or None
+ - latest_commit: str (short hash) or None
+ - branch: str or None
+ - commits_behind: int
+ """
+ result = {
+ 'update_available': False,
+ 'current_commit': None,
+ 'latest_commit': None,
+ 'branch': None,
+ 'commits_behind': 0,
+ }
+
+ # Don't check for updates in Docker environments
+ if DOCKERIZED:
+ return result
+
+ # Check if WROLPi directory exists
+ if not PROJECT_DIR.is_dir():
+ logger.debug(f'WROLPi directory does not exist: {PROJECT_DIR}')
+ return result
+
+ # Get current branch
+ branch = get_current_branch()
+ if not branch:
+ logger.warning('Could not determine current branch')
+ return result
+
+ result['branch'] = branch
+
+ # Fetch latest from remote if requested
+ if fetch:
+ if not git_fetch():
+ logger.warning('git fetch failed, using cached remote refs')
+
+ # Get local and remote commits
+ local_commit = get_local_commit()
+ remote_commit = get_remote_commit(branch)
+
+ result['current_commit'] = local_commit
+ result['latest_commit'] = remote_commit
+
+ if not local_commit or not remote_commit:
+ return result
+
+ # Check if update is available
+ if local_commit != remote_commit:
+ commits_behind = get_commits_behind(branch)
+ result['commits_behind'] = commits_behind
+ result['update_available'] = commits_behind > 0
+
+ return result
+
+
+async def start_upgrade():
+ """
+ Start the WROLPi upgrade process.
+
+ This executes /opt/wrolpi/upgrade.sh in a detached subprocess so it survives
+ the API shutdown (upgrade.sh stops the API service).
+
+ The upgrade will use the currently checked out branch (e.g., 'release' or 'master').
+
+ The frontend should redirect to the maintenance page after calling this.
+ """
+ from wrolpi.events import Events
+
+ if DOCKERIZED:
+ logger.warning('Cannot start upgrade in Docker environment')
+ return
+
+ if not UPGRADE_SCRIPT.is_file():
+ logger.error(f'Upgrade script not found: {UPGRADE_SCRIPT}')
+ return
+
+ # Get current branch to upgrade from the same branch
+ branch = get_current_branch()
+ if not branch:
+ logger.error('Could not determine current branch, defaulting to release')
+ branch = 'release'
+
+ logger.warning(f'Starting WROLPi upgrade on branch: {branch}')
+
+ # Send event to notify frontend
+ Events.send_upgrade_started(f'WROLPi upgrade is starting on branch {branch}. Please wait...')
+
+ # Write branch to environment file for the systemd service to read.
+ env_file = pathlib.Path('/tmp/wrolpi-upgrade.env')
+ env_file.write_text(f'BRANCH={branch}\n')
+
+ # Use systemd to run the upgrade service. This ensures the upgrade process
+ # survives when the API is stopped, as systemd manages it independently.
+ subprocess.Popen(
+ ['sudo', 'systemctl', 'start', 'wrolpi-upgrade.service'],
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.DEVNULL,
+ stdin=subprocess.DEVNULL,
+ )
+
+ logger.warning(f'Upgrade process started on branch {branch}, API will be stopped shortly...')