diff --git a/Cargo.toml b/Cargo.toml index 06cec64..eceadab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ serde = { version = "1.0.188", features = ["derive"] } serde_json = "1.0.105" async-std = { version = "1.12.0", features = ["attributes", "tokio1"] } async-recursion = "1.0.4" +os_info = "3.7.0" [dependencies.reqwest] version = "0.11.20" diff --git a/README.md b/README.md index f2491d8..7af1b25 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ async fn main() -> Result<(), Box> { None, ); - let mut client = Client::new(base_url.clone(), Some(static_provider)); + let mut client = Client::new(base_url.clone(), Some(&static_provider), None, None).unwrap(); let bucket_name = "asiatrip"; diff --git a/src/s3/client.rs b/src/s3/client.rs index 8ef5ac4..0544c7c 100644 --- a/src/s3/client.rs +++ b/src/s3/client.rs @@ -32,6 +32,7 @@ use async_recursion::async_recursion; use bytes::{Buf, Bytes}; use dashmap::DashMap; use hyper::http::Method; +use os_info; use reqwest::header::HeaderMap; use std::collections::{HashMap, VecDeque}; use std::fs::File; @@ -201,24 +202,48 @@ fn parse_list_objects_common_prefixes( #[derive(Clone, Debug, Default)] pub struct Client<'a> { + client: reqwest::Client, base_url: BaseUrl, provider: Option<&'a (dyn Provider + Send + Sync)>, - pub ssl_cert_file: String, - pub ignore_cert_check: bool, - pub user_agent: String, region_map: DashMap, } impl<'a> Client<'a> { - pub fn new(base_url: BaseUrl, provider: Option<&(dyn Provider + Send + Sync)>) -> Client { - Client { + pub fn new( + base_url: BaseUrl, + provider: Option<&(dyn Provider + Send + Sync)>, + ssl_cert_file: Option, + ignore_cert_check: Option, + ) -> Result { + let info = os_info::get(); + let user_agent = String::from("MinIO (") + + &info.os_type().to_string() + + "; " + + info.architecture().unwrap_or("unknown") + + ") minio-rs/" + + env!("CARGO_PKG_VERSION"); + + let mut builder = reqwest::Client::builder() + .no_gzip() + .user_agent(user_agent.to_string()); + if let Some(v) = ignore_cert_check { + builder = builder.danger_accept_invalid_certs(v); + } + if let Some(v) = ssl_cert_file { + let mut buf = Vec::new(); + File::open(v.to_string())?.read_to_end(&mut buf)?; + let cert = reqwest::Certificate::from_pem(&buf)?; + builder = builder.add_root_certificate(cert); + } + + let client = builder.build()?; + + Ok(Client { + client, base_url, provider, - ssl_cert_file: String::new(), - ignore_cert_check: false, - user_agent: String::new(), region_map: DashMap::new(), - } + }) } fn build_headers( @@ -452,20 +477,7 @@ impl<'a> Client<'a> { .build_url(&method, region, query_params, bucket_name, object_name)?; self.build_headers(headers, query_params, region, &url, &method, body); - let mut builder = reqwest::Client::builder().no_gzip(); - if self.ignore_cert_check { - builder = builder.danger_accept_invalid_certs(self.ignore_cert_check); - } - if !self.ssl_cert_file.is_empty() { - let mut buf = Vec::new(); - File::open(&self.ssl_cert_file)?.read_to_end(&mut buf)?; - let cert = reqwest::Certificate::from_pem(&buf)?; - builder = builder.add_root_certificate(cert); - } - - let client = builder.build()?; - - let mut req = client.request(method.clone(), url.to_string()); + let mut req = self.client.request(method.clone(), url.to_string()); for (key, values) in headers.iter_all() { for value in values { diff --git a/tests/tests.rs b/tests/tests.rs index 05f210d..9166dd4 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -77,8 +77,8 @@ struct ClientTest<'a> { base_url: BaseUrl, access_key: String, secret_key: String, - ignore_cert_check: bool, - ssl_cert_file: String, + ignore_cert_check: Option, + ssl_cert_file: Option, client: Client<'a>, test_bucket: String, } @@ -91,12 +91,16 @@ impl<'a> ClientTest<'_> { access_key: String, secret_key: String, static_provider: &'a StaticProvider, - ignore_cert_check: bool, - ssl_cert_file: String, + ignore_cert_check: Option, + ssl_cert_file: Option, ) -> ClientTest<'a> { - let mut client = Client::new(base_url.clone(), Some(static_provider)); - client.ignore_cert_check = ignore_cert_check; - client.ssl_cert_file = ssl_cert_file.to_string(); + let client = Client::new( + base_url.clone(), + Some(static_provider), + ssl_cert_file.as_ref().cloned(), + ignore_cert_check, + ) + .unwrap(); ClientTest { base_url, @@ -533,9 +537,13 @@ impl<'a> ClientTest<'_> { let listen_task = move || async move { let static_provider = StaticProvider::new(&access_key, &secret_key, None); - let mut client = Client::new(base_url, Some(&static_provider)); - client.ignore_cert_check = ignore_cert_check; - client.ssl_cert_file = ssl_cert_file; + let client = Client::new( + base_url, + Some(&static_provider), + ssl_cert_file, + ignore_cert_check, + ) + .unwrap(); let event_fn = |event: NotificationRecords| { for record in event.records.iter() { @@ -1135,7 +1143,11 @@ async fn s3_tests() -> Result<(), Box> { let access_key = std::env::var("ACCESS_KEY")?; let secret_key = std::env::var("SECRET_KEY")?; let secure = std::env::var("ENABLE_HTTPS").is_ok(); - let ssl_cert_file = std::env::var("SSL_CERT_FILE")?; + let value = std::env::var("SSL_CERT_FILE")?; + let mut ssl_cert_file = None; + if !value.is_empty() { + ssl_cert_file = Some(value); + } let ignore_cert_check = std::env::var("IGNORE_CERT_CHECK").is_ok(); let region = std::env::var("SERVER_REGION").ok(); @@ -1151,7 +1163,7 @@ async fn s3_tests() -> Result<(), Box> { access_key, secret_key, &static_provider, - ignore_cert_check, + Some(ignore_cert_check), ssl_cert_file, ); ctest.init().await;