Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions sgpt/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def main(
max=1.0,
help="Limits highest probable tokens (words).",
),
use_litellm: str = typer.Option(
cfg.get("USE_LITELLM"),
"--use-litellm",
help="Use LiteLLM for completions.",
),
md: bool = typer.Option(
cfg.get("PRETTIFY_MARKDOWN") == "true",
help="Prettify markdown output.",
Expand Down Expand Up @@ -158,6 +163,8 @@ def main(
) -> None:
stdin_passed = not sys.stdin.isatty()

cfg["USE_LITELLM"] = use_litellm

if stdin_passed:
stdin = ""
# TODO: This is very hacky.
Expand Down
14 changes: 11 additions & 3 deletions sgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,13 @@
class Config(dict): # type: ignore
def __init__(self, config_path: Path, **defaults: Any):
self.config_path = config_path

if self._exists:
self._read()
# If OPENAI_API_KEY is missing and LiteLLM is not used, prompt for it.
if "OPENAI_API_KEY" not in self and defaults["USE_LITELLM"] == "false":
__api_key = getpass(prompt="Please enter your OpenAI API key: ")
self["OPENAI_API_KEY"] = __api_key
self._write()
has_new_config = False
for key, value in defaults.items():
if key not in self:
Expand All @@ -56,8 +60,12 @@ def __init__(self, config_path: Path, **defaults: Any):
self._write()
else:
config_path.parent.mkdir(parents=True, exist_ok=True)
# Don't write API key to config file if it is in the environment.
if not defaults.get("OPENAI_API_KEY") and not os.getenv("OPENAI_API_KEY"):
# Don't write API key to config file if it is in the environment or if LiteLLM is true.
if (
not defaults.get("OPENAI_API_KEY")
and not os.getenv("OPENAI_API_KEY")
and not defaults["USE_LITELLM"] == "true"
):
__api_key = getpass(prompt="Please enter your OpenAI API key: ")
defaults["OPENAI_API_KEY"] = __api_key
super().__init__(**defaults)
Expand Down
19 changes: 12 additions & 7 deletions sgpt/handlers/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,24 @@

base_url = cfg.get("API_BASE_URL")
use_litellm = cfg.get("USE_LITELLM") == "true"
additional_kwargs = {
"timeout": int(cfg.get("REQUEST_TIMEOUT")),
"api_key": cfg.get("OPENAI_API_KEY"),
"base_url": None if base_url == "default" else base_url,
}


if use_litellm:
import litellm # type: ignore
additional_kwargs = {
"timeout": int(cfg.get("REQUEST_TIMEOUT")),
"api_key": "",
"base_url": None if base_url == "default" else base_url,
}
import litellm # type: ignore

completion = litellm.completion
litellm.suppress_debug_info = True
additional_kwargs.pop("api_key")
else:
additional_kwargs = {
"timeout": int(cfg.get("REQUEST_TIMEOUT")),
"api_key": cfg.get("OPENAI_API_KEY"),
"base_url": None if base_url == "default" else base_url,
}
from openai import OpenAI

client = OpenAI(**additional_kwargs) # type: ignore
Expand Down