diff --git a/src/lib.rs b/src/lib.rs index faa7ce2..486b59a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,11 @@ use serde::{Deserialize, Serialize}; use std::error::Error; use std::fmt; -const URL_COMPLETION: &str = "https://api.openai.com/v1/chat/completions"; +const DEFAULT_URL_COMPLETION: &str = "https://api.openai.com/v1/chat/completions"; + +fn default_url_completion() -> String { + DEFAULT_URL_COMPLETION.to_string() +} #[derive(Debug, Deserialize, Serialize)] struct Config { @@ -13,6 +17,8 @@ struct Config { struct OpenAI { model: String, access_key: String, + #[serde(default = "default_url_completion")] + url_completion: String, } #[derive(Debug, Deserialize, Serialize)] @@ -117,7 +123,7 @@ impl Gptcli { let client = reqwest::blocking::Client::new(); let res = client - .post(URL_COMPLETION) + .post(&self.config.openai.url_completion) .header( "Authorization", format!("Bearer {}", self.config.openai.access_key), @@ -138,3 +144,31 @@ impl Gptcli { Ok(res) } } + +#[cfg(test)] +mod tests { + use super::{Config, default_url_completion}; + + #[test] + fn parse_config_with_custom_url() { + let toml_str = r#" + [openai] + model = "gpt-3.5-turbo" + access_key = "key" + url_completion = "http://example.com" + "#; + let config: Config = toml::from_str(toml_str).unwrap(); + assert_eq!(config.openai.url_completion, "http://example.com"); + } + + #[test] + fn parse_config_without_url_uses_default() { + let toml_str = r#" + [openai] + model = "gpt-3.5-turbo" + access_key = "key" + "#; + let config: Config = toml::from_str(toml_str).unwrap(); + assert_eq!(config.openai.url_completion, default_url_completion()); + } +} diff --git a/src/main.rs b/src/main.rs index 05a2cdc..9ab2060 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,11 @@ use std::io::Write; fn get_config_path() -> String { let homepath = home_dir().unwrap(); - format!("{}/.gptcli.toml", homepath.as_path().to_str().unwrap()) + homepath + .join(".gptcli.toml") + .to_str() + .unwrap() + .to_string() } fn main() { @@ -35,3 +39,26 @@ fn main() { println!("\n{}{}", message, choice.message.content); } } + +#[cfg(test)] +mod tests { + use super::get_config_path; + use std::env; + use std::path::PathBuf; + + #[test] + fn config_path_uses_home_directory() { + let original = env::var("HOME").ok(); + let temp_home = env::temp_dir().join("gptcli_test_home"); + env::set_var("HOME", &temp_home); + + let path = get_config_path(); + assert_eq!(PathBuf::from(path), temp_home.join(".gptcli.toml")); + + if let Some(val) = original { + env::set_var("HOME", val); + } else { + env::remove_var("HOME"); + } + } +}