diff --git a/colabcode/code.py b/colabcode/code.py index 6432b83..ff9a8e3 100644 --- a/colabcode/code.py +++ b/colabcode/code.py @@ -4,7 +4,6 @@ try: from google.colab import drive - colab_env = True except ImportError: colab_env = False @@ -14,11 +13,17 @@ class ColabCode: - def __init__(self, port=10000, password=None, mount_drive=False): + def __init__(self, port=10000, password=None, mount_drive=False, add_extensions=None): self.port = port self.password = password self._mount = mount_drive self._install_code() + self.extensions = EXTENSIONS + if add_extensions is not None and add_extensions != []: + if isinstance(add_extensions, list) and isinstance(add_extensions[0], str): + self.extensions += add_extensions + else: + raise TypeError("You need to pass a list of string(s) e.g. ['ms-python.python']") self._install_extensions() self._start_server() self._run_code() @@ -30,7 +35,7 @@ def _install_code(self): subprocess.run(["sh", "install.sh"], stdout=subprocess.PIPE) def _install_extensions(self): - for ext in EXTENSIONS: + for ext in self.extensions: subprocess.run(["code-server", "--install-extension", f"{ext}"]) def _start_server(self):