diff --git a/Cargo.toml b/Cargo.toml index 3135f61..40b675e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,12 +25,12 @@ base64 = { version = "0.21.0", default-features = false } rand_core = { version = "0.6", default-features = true } log = { version = "0.4", optional = true } defmt = { version = "0.3", optional = true } -embedded-tls = { version = "0.17", default-features = false, optional = true } +embedded-tls = { git = "https://github.com/Frostie314159/embedded-tls.git", default-features = false, optional = true } rand_chacha = { version = "0.3", default-features = false } -nourl = "0.1.2" esp-mbedtls = { version = "0.1", git = "https://github.com/esp-rs/esp-mbedtls.git", features = [ "async", ], optional = true } +nourl = "0.1.4" [dev-dependencies] hyper = { version = "0.14.23", features = ["full"] } diff --git a/src/client.rs b/src/client.rs index 4361406..047e3fc 100644 --- a/src/client.rs +++ b/src/client.rs @@ -113,25 +113,25 @@ where &'conn mut self, url: &Url<'_>, ) -> Result>, Error> { - let host = url.host(); - let port = url.port_or_default(); - - let remote = self - .dns - .get_host_by_name(host, embedded_nal_async::AddrType::Either) - .await - .map_err(|_| Error::Dns)?; + // If the host is an IP, we skip the DNS lookup. + let socket_address = if let Some(socket_address) = url.host_socket_address() { + socket_address + } else { + SocketAddr::new( + self.dns + .get_host_by_name(url.host(), embedded_nal_async::AddrType::Either) + .await + .map_err(|_| Error::Dns)?, + url.port_or_default(), + ) + }; - let conn = self - .client - .connect(SocketAddr::new(remote, port)) - .await - .map_err(|e| e.kind())?; + let conn = self.client.connect(socket_address).await.map_err(|e| e.kind())?; if url.scheme() == UrlScheme::HTTPS { #[cfg(feature = "esp-mbedtls")] if let Some(tls) = self.tls.as_mut() { - let mut servername = host.as_bytes().to_vec(); + let mut servername = url.host().as_bytes().to_vec(); servername.push(0); let mut session = esp_mbedtls::asynch::Session::new( conn, @@ -162,7 +162,7 @@ where } let mut conn: embedded_tls::TlsConnection<'conn, T::Connection<'conn>, embedded_tls::Aes128GcmSha256> = embedded_tls::TlsConnection::new(conn, tls.read_buffer, tls.write_buffer); - conn.open::<_, embedded_tls::NoVerify>(TlsContext::new(&config, &mut rng)) + conn.open(TlsContext::new(&config, embedded_tls::UnsecureProvider::new(rng))) .await?; Ok(HttpConnection::Tls(conn)) } else { @@ -724,7 +724,7 @@ mod tests { let mut buffer = VecBuffer::default(); let mut conn = HttpConnection::Plain(&mut buffer); - static CHUNKS: [&'static [u8]; 2] = [b"PART1", b"PART2"]; + static CHUNKS: [&[u8]; 2] = [b"PART1", b"PART2"]; let request = Request::new(Method::POST, "/").body(ChunkedBody(&CHUNKS)).build(); conn.write_request(&request).await.unwrap(); @@ -740,7 +740,7 @@ mod tests { let mut tx_buf = [0; 1024]; let mut conn = HttpConnection::Plain(&mut buffer).into_buffered(&mut tx_buf); - static CHUNKS: [&'static [u8]; 2] = [b"PART1", b"PART2"]; + static CHUNKS: [&[u8]; 2] = [b"PART1", b"PART2"]; let request = Request::new(Method::POST, "/").body(ChunkedBody(&CHUNKS)).build(); conn.write_request(&request).await.unwrap(); diff --git a/src/fmt.rs b/src/fmt.rs index 0669708..3b8e8d5 100644 --- a/src/fmt.rs +++ b/src/fmt.rs @@ -201,6 +201,7 @@ pub struct NoneError; pub trait Try { type Ok; type Error; + #[allow(dead_code)] fn into_result(self) -> Result; } diff --git a/tests/client.rs b/tests/client.rs index e8f6c74..01801d4 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -8,7 +8,7 @@ use reqwless::client::HttpClient; use reqwless::headers::ContentType; use reqwless::request::{Method, RequestBody, RequestBuilder}; use reqwless::response::Status; -use std::net::SocketAddr; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::sync::Once; use tokio::net::TcpListener; use tokio::sync::oneshot; @@ -31,11 +31,7 @@ static TCP: TokioTcp = TokioTcp; static LOOPBACK_DNS: LoopbackDns = LoopbackDns; static PUBLIC_DNS: StdDns = StdDns; -#[tokio::test] -async fn test_request_response_notls() { - setup(); - let addr = ([127, 0, 0, 1], 0).into(); - +async fn request_response_notls(addr: SocketAddr) { let service = make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(echo)) }); let server = Server::bind(&addr).serve(service); @@ -49,7 +45,7 @@ async fn test_request_response_notls() { } }); - let url = format!("http://127.0.0.1:{}", addr.port()); + let url = format!("http://{addr}"); let mut client = HttpClient::new(&TCP, &LOOPBACK_DNS); let mut rx_buf = [0; 4096]; for _ in 0..2 { @@ -68,6 +64,13 @@ async fn test_request_response_notls() { t.await.unwrap(); } +#[tokio::test] +async fn test_request_response_notls() { + setup(); + request_response_notls(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0))).await; + request_response_notls(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 0, 0, 0))).await; +} + #[tokio::test] async fn test_resource_notls() { setup();