Skip to content

Commit b52ba0e

Browse files
committed
refactor(grpc-server): move test auth logic to a single function
1 parent a186c57 commit b52ba0e

File tree

3 files changed

+34
-32
lines changed

3 files changed

+34
-32
lines changed

crates/grpc_server/src/constants.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ pub const ENV_LOGGING: &str = "GRPC_SERVER_LOGGING";
1212
pub const DEFAULT_PORT: &str = "8080";
1313

1414
pub const SIMILARITY_PROTOBUFF_MAP: [Similarity; 4] = [Euclidean, Manhattan, Hamming, Cosine];
15+
pub const AUTHORIZATION_HEADER_KEY: &str = "authorization";

crates/grpc_server/src/interceptors.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
use tonic::{Status, service::Interceptor};
22
use tracing::{Level, event};
33

4+
use crate::constants::AUTHORIZATION_HEADER_KEY;
5+
46
#[derive(Clone)]
57
pub struct AuthInterceptor {
68
root_password: String,
79
}
810

911
impl Interceptor for AuthInterceptor {
1012
fn call(&mut self, req: tonic::Request<()>) -> Result<tonic::Request<()>, Status> {
11-
let auth_token = match req.metadata().get("authorization") {
13+
let auth_token = match req.metadata().get(AUTHORIZATION_HEADER_KEY) {
1214
Some(t) => t,
1315
None => return Err(Status::unauthenticated("Invalid credentials")),
1416
};

crates/grpc_server/src/tests.rs

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::config::GRPCServerConfig;
2+
use crate::constants::AUTHORIZATION_HEADER_KEY;
23
use crate::service::vectordb::vector_db_client::VectorDbClient;
34
use crate::service::vectordb::{DenseVector, InsertVectorRequest, PointId, SearchRequest};
45
use crate::service::{VectorDBService, run_server};
@@ -15,6 +16,15 @@ use tonic::transport::Channel;
1516

1617
// Inspired from https://github.com/hyperium/tonic/discussions/924#discussioncomment-9854088
1718

19+
const TEST_AUTH_BEARER_TOKEN: &str = "123";
20+
21+
fn append_test_auth_header<T>(request: &mut tonic::Request<T>, token: &str) {
22+
let auth_value = format!("Bearer {}", token);
23+
request
24+
.metadata_mut()
25+
.insert(AUTHORIZATION_HEADER_KEY, auth_value.parse().unwrap());
26+
}
27+
1828
async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error>> {
1929
// using a temporary directory for db datapath
2030
let temp_dir = tempdir().unwrap();
@@ -28,7 +38,7 @@ async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error>> {
2838

2939
let config = GRPCServerConfig {
3040
addr: "127.0.0.1:0".parse()?,
31-
root_password: "123".to_string(),
41+
root_password: TEST_AUTH_BEARER_TOKEN.to_string(),
3242
logging: false,
3343
db_config,
3444
};
@@ -80,9 +90,8 @@ async fn test_grpc_server_start() {
8090
}),
8191
payload: Some(Struct::default()),
8292
});
83-
request
84-
.metadata_mut()
85-
.insert("authorization", "Bearer 123".parse().unwrap());
93+
append_test_auth_header(&mut request, TEST_AUTH_BEARER_TOKEN);
94+
8695
let status = client.insert_vector(request).await.is_ok();
8796
assert!(status);
8897
}
@@ -101,9 +110,8 @@ async fn test_insert_vector_rpc() {
101110
}),
102111
payload: Some(Struct::default()),
103112
});
104-
request
105-
.metadata_mut()
106-
.insert("authorization", "Bearer 123".parse().unwrap());
113+
append_test_auth_header(&mut request, TEST_AUTH_BEARER_TOKEN);
114+
107115
let resp = client.insert_vector(request).await;
108116

109117
// check if request is successful
@@ -113,9 +121,7 @@ async fn test_insert_vector_rpc() {
113121
let mut request = tonic::Request::new(PointId {
114122
id: resp.unwrap().into_inner().id,
115123
});
116-
request
117-
.metadata_mut()
118-
.insert("authorization", "Bearer 123".parse().unwrap());
124+
append_test_auth_header(&mut request, TEST_AUTH_BEARER_TOKEN);
119125
let resp = client.get_point(request).await;
120126

121127
// check if request is successful
@@ -130,9 +136,8 @@ async fn test_insert_vector_rpc() {
130136
}),
131137
payload: Some(Struct::default()),
132138
});
133-
request
134-
.metadata_mut()
135-
.insert("authorization", "Bearer 123".parse().unwrap());
139+
append_test_auth_header(&mut request, TEST_AUTH_BEARER_TOKEN);
140+
136141
let resp = client.insert_vector(request).await;
137142

138143
// request must fail
@@ -152,9 +157,8 @@ async fn test_delete_vector_rpc() {
152157
}),
153158
payload: Some(Struct::default()),
154159
});
155-
request
156-
.metadata_mut()
157-
.insert("authorization", "Bearer 123".parse().unwrap());
160+
append_test_auth_header(&mut request, TEST_AUTH_BEARER_TOKEN);
161+
158162
let resp = client.insert_vector(request).await;
159163

160164
// check if request is successful
@@ -163,19 +167,17 @@ async fn test_delete_vector_rpc() {
163167

164168
// delete the vector
165169
let mut request = tonic::Request::new(PointId { id: point.id });
166-
request
167-
.metadata_mut()
168-
.insert("authorization", "Bearer 123".parse().unwrap());
170+
append_test_auth_header(&mut request, TEST_AUTH_BEARER_TOKEN);
171+
169172
let resp = client.delete_point(request).await;
170173

171174
// check if request is successful
172175
assert!(resp.is_ok());
173176

174177
// verify that the vector is deleted
175178
let mut request = tonic::Request::new(PointId { id: point.id });
176-
request
177-
.metadata_mut()
178-
.insert("authorization", "Bearer 123".parse().unwrap());
179+
append_test_auth_header(&mut request, TEST_AUTH_BEARER_TOKEN);
180+
179181
let resp = client.get_point(request).await;
180182

181183
// request must fail since the vector is deleted
@@ -195,9 +197,8 @@ async fn test_search_vector_rpc() {
195197
}),
196198
payload: Some(Struct::default()),
197199
});
198-
request
199-
.metadata_mut()
200-
.insert("authorization", "Bearer 123".parse().unwrap());
200+
append_test_auth_header(&mut request, TEST_AUTH_BEARER_TOKEN);
201+
201202
let resp = client.insert_vector(request).await;
202203

203204
// check if request is successful
@@ -214,9 +215,8 @@ async fn test_search_vector_rpc() {
214215
similarity: 0, // euclidean distance
215216
limit: 1,
216217
});
217-
request
218-
.metadata_mut()
219-
.insert("authorization", "Bearer 123".parse().unwrap());
218+
append_test_auth_header(&mut request, TEST_AUTH_BEARER_TOKEN);
219+
220220
let resp = client.search_points(request).await;
221221

222222
// check if request is successful
@@ -243,9 +243,8 @@ async fn test_unauthorized_rpc() {
243243
}),
244244
payload: Some(Struct::default()),
245245
});
246-
request
247-
.metadata_mut()
248-
.insert("authorization", "Bearer 43121".parse().unwrap()); // wrong auth token
246+
247+
append_test_auth_header(&mut request, "43121");
249248
let resp = client.insert_vector(request).await;
250249

251250
// request must fail

0 commit comments

Comments
 (0)