Skip to content

Commit b365264

Browse files
authored
Merge pull request #145 from olehermanse/linting
Error handling improvements and fixing pyright errors
2 parents 888d2cc + 590a47e commit b365264

File tree

11 files changed

+247
-122
lines changed

11 files changed

+247
-122
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ jobs:
3636
flake8 . --ignore=E203,W503,E722,E731 --max-complexity=100 --max-line-length=160
3737
- name: Lint with pyright (type checking)
3838
run: |
39-
echo TODO - fix pyright errors # pyright cf_remote
39+
pyright cf_remote
4040
- name: Lint with pyflakes
4141
run: |
4242
pyflakes cf_remote

cf_remote/commands.py

Lines changed: 66 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,11 @@
3131
write_json,
3232
whoami,
3333
get_package_name,
34-
user_error,
34+
CFRExitError,
3535
is_package_url,
3636
print_progress_dot,
37-
ChecksumError,
37+
CFRChecksumError,
38+
CFRUserError,
3839
)
3940
from cf_remote.spawn import VM, VMRequest, Providers, AWSCredentials, GCPCredentials
4041
from cf_remote.spawn import spawn_vms, destroy_vms, dump_vms_info, get_cloud_driver
@@ -71,6 +72,7 @@ def run(hosts, command, users=None, sudo=False, raw=False):
7172
continue
7273
cmd = command
7374
lines = lines.replace("\r", "")
75+
fill = ""
7476
for line in lines.split("\n"):
7577
if raw:
7678
print(line)
@@ -79,6 +81,7 @@ def run(hosts, command, users=None, sudo=False, raw=False):
7981
fill = " " * (len(cmd) + 7)
8082
cmd = None
8183
else:
84+
assert fill, "First iteration of loop should have set fill variable"
8285
print("{}{}'{}'".format(host_colon, fill, line))
8386
return errors
8487

@@ -123,7 +126,9 @@ def _download_urls(urls):
123126
paths.append(path)
124127

125128
if path in downloaded_paths and url not in downloaded_urls:
126-
user_error("2 packages with the same name '%s' from different URLs" % name)
129+
raise CFRExitError(
130+
"2 packages with the same name '%s' from different URLs" % name
131+
)
127132

128133
download_package(url, path)
129134
downloaded_urls.append(url)
@@ -143,7 +148,7 @@ def _verify_package_urls(urls):
143148
if is_package_url(package_url):
144149
verified_urls.append(package_url)
145150
else:
146-
user_error("Wrong package URL: {}".format(package_url))
151+
raise CFRExitError("Wrong package URL: {}".format(package_url))
147152

148153
return verified_urls
149154

@@ -184,7 +189,7 @@ def install(
184189
else:
185190
try:
186191
package, hub_package, client_package = _download_urls(packages)
187-
except ChecksumError as ce:
192+
except CFRChecksumError as ce:
188193
log.error(ce)
189194
return 1
190195

@@ -288,14 +293,15 @@ def install(
288293
def _iterate_over_packages(
289294
tags=None, version=None, edition=None, download=False, output_dir=None
290295
):
296+
assert edition in ["enterprise", "community", None]
291297
releases = Releases(edition)
292298
print("Available releases: {}".format(releases))
293299

294300
release_versions = [rel.version for rel in releases.releases]
295301
if version and version not in release_versions:
296-
user_error("CFEngine version '%s' doesn't exist (yet)." % version)
302+
raise CFRExitError("CFEngine version '%s' doesn't exist (yet)." % version)
297303

298-
if not version:
304+
if tags and not version:
299305
for tag in tags:
300306
if tag in release_versions:
301307
version = tag
@@ -305,10 +311,11 @@ def _iterate_over_packages(
305311
release = releases.default
306312
if version:
307313
release = releases.pick_version(version)
314+
if not release:
315+
raise CFRExitError("Failed to find a release for version '%s'" % version)
308316
print("Using {}:".format(release))
309317
log.debug("Looking for a release based on host tags: {}".format(tags))
310318
artifacts = release.find(tags)
311-
312319
if len(artifacts) == 0:
313320
print("No suitable packages found")
314321
else:
@@ -318,14 +325,14 @@ def _iterate_over_packages(
318325
package_path = download_package(
319326
artifact.url, checksum=artifact.checksum
320327
)
321-
except ChecksumError as ce:
328+
except CFRChecksumError as ce:
322329
log.error(ce)
323330
return 1
324331
if output_dir:
325332
output_dir = os.path.abspath(os.path.expanduser(output_dir))
326333
parent = os.path.dirname(output_dir)
327334
if not os.path.exists(parent):
328-
user_error(
335+
raise CFRExitError(
329336
"'{}' doesn't exist. Make sure this path is correct and exists.".format(
330337
parent
331338
)
@@ -373,16 +380,17 @@ def spawn(
373380
public_ip=True,
374381
extend_group=False,
375382
):
383+
creds_data = None
376384
if os.path.exists(CLOUD_CONFIG_FPATH):
377385
creds_data = read_json(CLOUD_CONFIG_FPATH)
378-
else:
379-
print("Cloud configuration not found at %s" % CLOUD_CONFIG_FPATH)
380-
return 1
386+
if not creds_data:
387+
raise CFRUserError("Cloud configuration not found at %s" % CLOUD_CONFIG_FPATH)
381388

389+
vms_info = None
382390
if os.path.exists(CLOUD_STATE_FPATH):
383391
vms_info = read_json(CLOUD_STATE_FPATH)
384-
else:
385-
vms_info = dict()
392+
if not vms_info:
393+
vms_info = {}
386394

387395
group_key = "@%s" % group_name
388396
group_exists = group_key in vms_info
@@ -523,6 +531,8 @@ def destroy(group_name=None):
523531
return 1
524532

525533
vms_info = read_json(CLOUD_STATE_FPATH)
534+
if not vms_info:
535+
raise CFRUserError("No saved VMs found in '{}'".format(CLOUD_STATE_FPATH))
526536

527537
to_destroy = []
528538
if group_name:
@@ -541,16 +551,23 @@ def destroy(group_name=None):
541551

542552
region = vms_info[group_name]["meta"]["region"]
543553
provider = vms_info[group_name]["meta"]["provider"]
554+
if provider not in ["aws", "gcp"]:
555+
raise CFRUserError(
556+
"Unsupported provider '{}' encountered in '{}', only aws / gcp is supported".format(
557+
provider, CLOUD_STATE_FPATH
558+
)
559+
)
560+
561+
driver = None
544562
if provider == "aws":
545563
if aws_creds is None:
546-
user_error("Missing/incomplete AWS credentials")
547-
return 1
564+
raise CFRExitError("Missing/incomplete AWS credentials")
548565
driver = get_cloud_driver(Providers.AWS, aws_creds, region)
549566
if provider == "gcp":
550567
if gcp_creds is None:
551-
user_error("Missing/incomplete GCP credentials")
552-
return 1
568+
raise CFRExitError("Missing/incomplete GCP credentials")
553569
driver = get_cloud_driver(Providers.GCP, gcp_creds, region)
570+
assert driver is not None
554571

555572
nodes = driver.list_nodes()
556573
for name, vm_info in vms_info[group_name].items():
@@ -572,16 +589,23 @@ def destroy(group_name=None):
572589

573590
region = vms_info[group_name]["meta"]["region"]
574591
provider = vms_info[group_name]["meta"]["provider"]
592+
if provider not in ["aws", "gcp"]:
593+
raise CFRUserError(
594+
"Unsupported provider '{}' encountered in '{}', only aws / gcp is supported".format(
595+
provider, CLOUD_STATE_FPATH
596+
)
597+
)
598+
599+
driver = None
575600
if provider == "aws":
576601
if aws_creds is None:
577-
user_error("Missing/incomplete AWS credentials")
578-
return 1
602+
raise CFRExitError("Missing/incomplete AWS credentials")
579603
driver = get_cloud_driver(Providers.AWS, aws_creds, region)
580604
if provider == "gcp":
581605
if gcp_creds is None:
582-
user_error("Missing/incomplete GCP credentials")
583-
return 1
606+
raise CFRExitError("Missing/incomplete GCP credentials")
584607
driver = get_cloud_driver(Providers.GCP, gcp_creds, region)
608+
assert driver is not None
585609

586610
nodes = driver.list_nodes()
587611
for name, vm_info in vms_info[group_name].items():
@@ -673,11 +697,15 @@ def save(name, hosts, role):
673697

674698

675699
def _ansible_inventory():
676-
if not os.path.exists(CLOUD_STATE_FPATH):
677-
print("No saved cloud state info")
678-
return 1
679700

680-
vms_info = read_json(CLOUD_STATE_FPATH)
701+
vms_info = None
702+
if os.path.exists(CLOUD_STATE_FPATH):
703+
vms_info = read_json(CLOUD_STATE_FPATH)
704+
705+
if not vms_info:
706+
raise CFRUserError(
707+
"No saved cloud state info in '{}'".format(CLOUD_STATE_FPATH)
708+
)
681709
all_lines = []
682710
hub_lines = []
683711
client_lines = []
@@ -851,7 +879,7 @@ def deploy(hubs, masterfiles):
851879
print("Found saved/spawned hubs: " + ", ".join(hubs))
852880

853881
if not hubs:
854-
user_error(
882+
raise CFRExitError(
855883
"No hub to deploy to (Specify with --hub or use spawn/save commands to add to cf-remote)"
856884
)
857885

@@ -866,7 +894,7 @@ def deploy(hubs, masterfiles):
866894
urls = [masterfiles]
867895
try:
868896
paths = _download_urls(urls)
869-
except ChecksumError as ce:
897+
except CFRChecksumError as ce:
870898
log.error(ce)
871899
return 1
872900
assert len(paths) == 1
@@ -876,7 +904,7 @@ def deploy(hubs, masterfiles):
876904
if not masterfiles:
877905
masterfiles = "."
878906
if not (os.path.isfile("promises.cf") or os.path.isfile("promises.cf.in")):
879-
user_error("No cfbs or masterfiles policy set found")
907+
raise CFRExitError("No cfbs or masterfiles policy set found")
880908

881909
masterfiles = os.path.abspath(os.path.expanduser(masterfiles))
882910
print("Found masterfiles policy set: '{}'".format(masterfiles))
@@ -937,23 +965,23 @@ def deploy(hubs, masterfiles):
937965

938966

939967
def agent(hosts, bootstrap=None):
940-
941-
if len(bootstrap) > 1:
942-
user_error(
968+
if bootstrap and len(bootstrap) > 1:
969+
raise CFRExitError(
943970
"Cannot boostrap {} to {}. Cannot bootstrap to more than one host.".format(
944971
hosts, bootstrap
945972
)
946973
)
947974

948-
hub_host = bootstrap[0]
949-
950975
for host in hosts:
951976
data = get_info(host)
952977

953978
if not data["agent_location"]:
954-
user_error("CFEngine not installed on {}".format(host))
979+
raise CFRExitError("CFEngine not installed on {}".format(host))
980+
981+
command = "{}".format(data["agent_location"])
982+
if bootstrap:
983+
command += "--bootstrap {}".format(bootstrap[0])
955984

956-
command = "{} --bootstrap {}".format(data["agent_location"], hub_host)
957985
output = run_command(host, command, sudo=True)
958986
if output:
959987
print(output)
@@ -965,7 +993,7 @@ def connect_cmd(hosts):
965993
assert hosts and len(hosts) >= 1 # Ensured by argument parser
966994

967995
if len(hosts) > 1:
968-
user_error("You can only connect to one host at a time")
996+
raise CFRExitError("You can only connect to one host at a time")
969997

970998
print("Opening a SSH command shell...")
971999
r = subprocess.run(["ssh", hosts[0]])

0 commit comments

Comments
 (0)