Skip to content

Commit 6b66da0

Browse files
author
Andres D. Molins
committed
Feature: To be able to clean existing issues on firewall rules, create an authenticated endpoint that allows to clean completely the network setting and re-create it from the existing running VMs.
1 parent 5e94bdb commit 6b66da0

File tree

3 files changed

+256
-0
lines changed

3 files changed

+256
-0
lines changed

src/aleph/vm/network/firewall.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,3 +762,127 @@ def check_nftables_redirections(port: int) -> bool:
762762
except Exception as e:
763763
logger.warning(f"Error checking NAT redirections: {e}")
764764
return False
765+
766+
767+
def get_all_aleph_chains() -> list[str]:
768+
"""Query nftables ruleset and return all chains created by aleph software.
769+
770+
This function scans the entire nftables ruleset and identifies all chains
771+
whose names start with the configured NFTABLES_CHAIN_PREFIX. This includes
772+
both supervisor chains (e.g., aleph-supervisor-nat, aleph-supervisor-filter,
773+
aleph-supervisor-prerouting) and VM-specific chains (e.g., aleph-vm-nat-123,
774+
aleph-vm-filter-123).
775+
776+
Returns:
777+
A list of chain names that belong to aleph software
778+
779+
Raises:
780+
Exception: If the nftables query fails
781+
"""
782+
logger.debug("Querying nftables for all aleph-related chains")
783+
nft_ruleset = get_existing_nftables_ruleset()
784+
aleph_chains = []
785+
786+
for entry in nft_ruleset:
787+
if isinstance(entry, dict) and "chain" in entry:
788+
chain_name = entry["chain"].get("name", "")
789+
# Find all chains created by aleph software
790+
if chain_name.startswith(settings.NFTABLES_CHAIN_PREFIX):
791+
aleph_chains.append(chain_name)
792+
logger.debug(f"Found aleph chain: {chain_name}")
793+
794+
logger.info(f"Found {len(aleph_chains)} aleph-related chains")
795+
return aleph_chains
796+
797+
798+
def remove_all_aleph_chains() -> tuple[list[str], list[tuple[str, str]]]:
799+
"""Remove all chains created by aleph software from the nftables ruleset.
800+
801+
This function queries the nftables ruleset to find all chains that start with
802+
the configured NFTABLES_CHAIN_PREFIX, then attempts to remove each one. This
803+
ensures a clean slate by removing both tracked and untracked chains that may
804+
have been left behind due to software crashes or inconsistent state.
805+
806+
The function uses the remove_chain() helper which handles:
807+
- Removing all rules that jump to the chain
808+
- Removing the chain itself
809+
810+
Returns:
811+
A tuple containing:
812+
- List of successfully removed chain names
813+
- List of tuples (chain_name, error_message) for failed removals
814+
815+
Example:
816+
removed, failed = remove_all_aleph_chains()
817+
if failed:
818+
logger.warning(f"Failed to remove {len(failed)} chains")
819+
"""
820+
logger.info("Removing all aleph-related chains from nftables")
821+
aleph_chains = get_all_aleph_chains()
822+
823+
removed_chains = []
824+
failed_chains = []
825+
826+
for chain_name in aleph_chains:
827+
try:
828+
remove_chain(chain_name)
829+
removed_chains.append(chain_name)
830+
logger.debug(f"Successfully removed chain: {chain_name}")
831+
except Exception as e:
832+
error_msg = str(e)
833+
failed_chains.append((chain_name, error_msg))
834+
logger.warning(f"Failed to remove chain {chain_name}: {error_msg}")
835+
836+
logger.info(f"Chain removal complete. Removed: {len(removed_chains)}, Failed: {len(failed_chains)}")
837+
return removed_chains, failed_chains
838+
839+
840+
def recreate_network_for_vms(vm_configurations: list[dict]) -> tuple[list[str], list[dict]]:
841+
"""Recreate network rules for a list of VMs.
842+
843+
This function sets up nftables chains and rules for each VM in the provided list.
844+
For each VM, it creates:
845+
- NAT chain and masquerading rules for outbound traffic
846+
- Filter chain and forwarding rules for traffic control
847+
- Port forwarding rules if the VM is an instance (handled by caller)
848+
849+
Args:
850+
vm_configurations: List of dictionaries, each containing:
851+
- vm_id: Integer ID of the VM
852+
- tap_interface: TapInterface object for the VM
853+
- vm_hash: ItemHash of the VM (for logging)
854+
855+
Returns:
856+
A tuple containing:
857+
- List of successfully recreated VM hashes (as strings)
858+
- List of dictionaries with failed VMs:
859+
[{"vm_hash": str, "error": str}, ...]
860+
861+
Example:
862+
vms = [
863+
{"vm_id": 1, "tap_interface": tap1, "vm_hash": hash1},
864+
{"vm_id": 2, "tap_interface": tap2, "vm_hash": hash2},
865+
]
866+
recreated, failed = recreate_network_for_vms(vms)
867+
"""
868+
logger.info(f"Recreating network rules for {len(vm_configurations)} VMs")
869+
recreated_vms = []
870+
failed_vms = []
871+
872+
for vm_config in vm_configurations:
873+
vm_id = vm_config["vm_id"]
874+
tap_interface = vm_config["tap_interface"]
875+
vm_hash = vm_config["vm_hash"]
876+
877+
try:
878+
# Recreate the basic VM network chains and rules
879+
setup_nftables_for_vm(vm_id, tap_interface)
880+
recreated_vms.append(str(vm_hash))
881+
logger.debug(f"Recreated nftables for VM {vm_hash} (vm_id={vm_id})")
882+
except Exception as e:
883+
error_msg = str(e)
884+
failed_vms.append({"vm_hash": str(vm_hash), "error": error_msg})
885+
logger.error(f"Failed to recreate network for VM {vm_hash}: {error_msg}")
886+
887+
logger.info(f"VM network recreation complete. Success: {len(recreated_vms)}, Failed: {len(failed_vms)}")
888+
return recreated_vms, failed_vms

src/aleph/vm/orchestrator/supervisor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
notify_allocation,
3939
operate_reserve_resources,
4040
operate_update,
41+
recreate_network,
4142
run_code_from_hostname,
4243
run_code_from_path,
4344
status_check_fastapi,
@@ -164,6 +165,7 @@ def setup_webapp(pool: VmPool | None):
164165
other_routes = [
165166
# /control APIs are used to control the VMs and access their logs
166167
web.post("/control/allocations", update_allocations),
168+
web.post("/control/network/recreate", recreate_network),
167169
# Raise an HTTP Error 404 if attempting to access an unknown URL within these paths.
168170
web.get("/about/{suffix:.*}", http_not_found),
169171
web.get("/control/{suffix:.*}", http_not_found),

src/aleph/vm/orchestrator/views/__init__.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@
2727
from aleph.vm.controllers.firecracker.program import FileTooLargeError
2828
from aleph.vm.hypervisors.firecracker.microvm import MicroVMFailedInitError
2929
from aleph.vm.models import VmExecution
30+
from aleph.vm.network.firewall import (
31+
initialize_nftables,
32+
recreate_network_for_vms,
33+
remove_all_aleph_chains,
34+
)
3035
from aleph.vm.orchestrator import payment, status
3136
from aleph.vm.orchestrator.chain import STREAM_CHAINS
3237
from aleph.vm.orchestrator.custom_logs import set_vm_for_logging
@@ -429,6 +434,7 @@ def authenticate_api_request(request: web.Request) -> bool:
429434

430435

431436
allocation_lock = None
437+
network_recreation_lock = None
432438

433439

434440
async def update_allocations(request: web.Request):
@@ -547,6 +553,130 @@ async def update_allocations(request: web.Request):
547553
)
548554

549555

556+
async def recreate_network(request: web.Request):
557+
"""Recreate network settings for the CRN and all running VMs.
558+
559+
This endpoint performs a complete network reconfiguration by:
560+
1. Querying the nftables ruleset to find all aleph-related chains
561+
2. Removing ALL chains created by aleph software (both tracked and untracked)
562+
including VM-specific chains and supervisor chains
563+
3. Re-initializing the base network setup with nftables (creating fresh
564+
supervisor chains: aleph-supervisor-nat, aleph-supervisor-filter,
565+
aleph-supervisor-prerouting)
566+
4. Recreating VM-specific chains and rules for each currently running VM
567+
5. Restoring port forwarding rules for all running instances
568+
569+
This method is designed to handle cases where:
570+
- Network rules have become duplicated or inconsistent
571+
- Chains exist on the host that are no longer tracked by the software
572+
- The firewall state needs to be reset to match the current VM pool
573+
574+
The operation is atomic and uses a lock to prevent concurrent modifications.
575+
576+
Returns:
577+
JSON response with:
578+
- success: Boolean indicating if all VMs were successfully recreated
579+
- removed_chains_count: Number of chains that were removed
580+
- removed_chains: List of chain names that were removed
581+
- recreated_count: Number of VMs that were successfully recreated
582+
- failed_count: Number of VMs that failed to recreate
583+
- recreated_vms: List of VM hashes that were recreated
584+
- failed_vms: List of VM hashes and errors for failed recreations
585+
"""
586+
if not authenticate_api_request(request):
587+
return web.HTTPUnauthorized(text="Authentication token received is invalid")
588+
589+
global network_recreation_lock
590+
if network_recreation_lock is None:
591+
network_recreation_lock = asyncio.Lock()
592+
593+
pool: VmPool = request.app["vm_pool"]
594+
595+
async with network_recreation_lock:
596+
logger.info("Starting network recreation process")
597+
598+
# Step 1: Collect all running VMs and their network configuration
599+
running_vms = []
600+
for vm_hash, execution in pool.executions.items():
601+
if execution.is_running and execution.vm and execution.vm.tap_interface:
602+
running_vms.append(
603+
{
604+
"vm_hash": vm_hash,
605+
"vm_id": execution.vm.vm_id,
606+
"tap_interface": execution.vm.tap_interface,
607+
"execution": execution,
608+
}
609+
)
610+
logger.debug(f"Found running VM {vm_hash} with vm_id={execution.vm.vm_id}")
611+
612+
logger.info(f"Found {len(running_vms)} running VMs to recreate network rules for")
613+
614+
# Step 2: Remove all aleph-related chains (VM-specific and supervisor chains)
615+
try:
616+
removed_chains, failed_removals = remove_all_aleph_chains()
617+
if failed_removals:
618+
logger.warning(f"Failed to remove {len(failed_removals)} chains")
619+
for chain_name, error in failed_removals:
620+
logger.warning(f" - {chain_name}: {error}")
621+
except Exception as e:
622+
logger.error(f"Error removing aleph chains: {e}")
623+
return web.json_response(
624+
{"success": False, "error": f"Failed to remove existing chains: {str(e)}"},
625+
status=500,
626+
)
627+
628+
# Step 3: Re-initialize the base network setup
629+
logger.info("Re-initializing nftables")
630+
try:
631+
initialize_nftables()
632+
except Exception as e:
633+
logger.error(f"Error initializing nftables: {e}")
634+
return web.json_response(
635+
{"success": False, "error": f"Failed to initialize network: {str(e)}"},
636+
status=500,
637+
)
638+
639+
# Step 4: Recreate VM-specific chains and rules
640+
try:
641+
recreated_vms, failed_vms = recreate_network_for_vms(running_vms)
642+
except Exception as e:
643+
logger.error(f"Error recreating VM networks: {e}")
644+
return web.json_response(
645+
{"success": False, "error": f"Failed to recreate VM networks: {str(e)}"},
646+
status=500,
647+
)
648+
649+
# Step 5: Recreate port forwarding rules for instances
650+
logger.info("Recreating port forwarding rules for instances")
651+
for vm_info in running_vms:
652+
execution = vm_info["execution"]
653+
if execution.is_instance and str(vm_info["vm_hash"]) in recreated_vms:
654+
try:
655+
await execution.fetch_port_redirect_config_and_setup()
656+
logger.debug(f"Recreated port redirects for instance {vm_info['vm_hash']}")
657+
except Exception as e:
658+
logger.error(f"Error recreating port redirects for VM {vm_info['vm_hash']}: {e}")
659+
# Don't add to failed_vms as the VM network itself was created successfully
660+
661+
logger.info(
662+
f"Network recreation complete. Removed chains: {len(removed_chains)}, "
663+
f"Recreated VMs: {len(recreated_vms)}, Failed: {len(failed_vms)}"
664+
)
665+
666+
return web.json_response(
667+
{
668+
"success": len(failed_vms) == 0,
669+
"removed_chains_count": len(removed_chains),
670+
"removed_chains": removed_chains,
671+
"recreated_count": len(recreated_vms),
672+
"failed_count": len(failed_vms),
673+
"recreated_vms": recreated_vms,
674+
"failed_vms": failed_vms,
675+
},
676+
status=200 if len(failed_vms) == 0 else 207,
677+
)
678+
679+
550680
@cors_allow_all
551681
async def notify_allocation(request: web.Request):
552682
"""Notify instance allocation, only used for Pay as you Go feature"""

0 commit comments

Comments
 (0)