diff --git a/vmtest/vm.py b/vmtest/vm.py index e612fa93f..ec8e0e659 100644 --- a/vmtest/vm.py +++ b/vmtest/vm.py @@ -70,6 +70,14 @@ def _get_ssh_transport(self, user, password="", keyfile=None): continue if pkey is None: raise RuntimeError(f"cannot load {keyfile}, tried {paramiko.key_classes}") + try: + self._connect(client, user, password, pkey) + except: + # do one retry on connect failure + time.sleep(0.5) + self._connect(client, user, password, pkey) + + def _connect(self, client, user, password, pkey): client.connect( self._address, self._ssh_port, user, password, pkey=pkey, @@ -96,7 +104,8 @@ def run(self, cmd, user, password="", keyfile=None): return exit_status, output.getvalue() def scp(self, src, dst, user, password="", keyfile=None): - with SCPClient(self._get_ssh_transport(user, password, keyfile)) as scp: + trans = self._get_ssh_transport(user, password, keyfile) + with SCPClient(trans) as scp: scp.put(src, dst) @abc.abstractmethod