1use std::{env, fs::File, io::BufReader, path::PathBuf, time::Duration};
39
40use tonic::transport::{Channel, Endpoint};
41#[cfg(feature = "mtls")]
42use rustls::ClientConfig;
43#[cfg(feature = "mtls")]
44use rustls::RootCertStore;
45
46use crate::dev_log;
47
/// Address of the Mountain gRPC service used when no other address is configured.
///
/// This is a bare `host:port` authority on the IPv6 loopback interface; it
/// carries no URI scheme.
pub const DEFAULT_MOUNTAIN_ADDRESS:&'static str = "[::1]:50051";

/// Default limit, in seconds, for establishing the connection to Mountain.
pub const DEFAULT_CONNECTION_TIMEOUT_SECS:u64 = 5;

/// Default per-request timeout, in seconds.
pub const DEFAULT_REQUEST_TIMEOUT_SECS:u64 = 30;
61
/// TLS / mutual-TLS settings for the connection to Mountain.
///
/// All certificate paths are optional. Without `ca_cert_path`, the platform's
/// native root certificates are used to verify the server. Client
/// authentication is only set up when both `client_cert_path` and
/// `client_key_path` are provided.
#[cfg(feature = "mtls")]
#[derive(Clone, Debug)]
pub struct TlsConfig {
	/// PEM file holding the CA certificate(s) trusted for the server.
	pub ca_cert_path:Option<PathBuf>,

	/// PEM file holding the client certificate chain presented for mTLS.
	pub client_cert_path:Option<PathBuf>,

	/// PEM file holding the private key that matches `client_cert_path`.
	pub client_key_path:Option<PathBuf>,

	/// Server name used for SNI; `MountainClient::connect` falls back to
	/// `localhost` when this is `None`.
	pub server_name:Option<String>,

	/// Requested server-certificate verification mode.
	///
	/// NOTE: setting this to `false` does not disable verification; building the
	/// TLS configuration only logs a warning about it.
	pub verify_certs:bool,
}
85
#[cfg(feature = "mtls")]
impl Default for TlsConfig {
	/// Returns a configuration with no certificate paths, no server-name
	/// override, and certificate verification requested.
	fn default() -> Self {
		let verify_certs = true;

		Self {
			server_name:None,
			client_key_path:None,
			client_cert_path:None,
			ca_cert_path:None,
			verify_certs,
		}
	}
}
102
#[cfg(feature = "mtls")]
impl TlsConfig {
	/// Server-authentication-only TLS: the server is verified against the CA
	/// bundle at `ca_cert_path`, and no client certificate is presented.
	///
	/// The server name is set to `localhost` and verification is requested.
	pub fn server_auth(ca_cert_path:PathBuf) -> Self {
		Self {
			ca_cert_path:Some(ca_cert_path),
			server_name:Some(String::from("localhost")),
			verify_certs:true,
			client_cert_path:None,
			client_key_path:None,
		}
	}

	/// Mutual TLS: like [`TlsConfig::server_auth`], but the client also presents
	/// the certificate at `client_cert_path` with the key at `client_key_path`.
	pub fn mtls(ca_cert_path:PathBuf, client_cert_path:PathBuf, client_key_path:PathBuf) -> Self {
		Self {
			client_cert_path:Some(client_cert_path),
			client_key_path:Some(client_key_path),
			..Self::server_auth(ca_cert_path)
		}
	}
}
149
/// Builds a rustls [`ClientConfig`] from `tls_config`.
///
/// The server certificate is verified against the CA bundle at
/// `tls_config.ca_cert_path` when set, otherwise against the platform's native
/// root certificates. When both `client_cert_path` and `client_key_path` are
/// set, that certificate/key pair is presented for mutual TLS. ALPN is limited
/// to `h2`, because gRPC runs over HTTP/2.
///
/// `verify_certs = false` is not honoured: certificate verification stays
/// enabled, and only a warning is logged.
///
/// # Errors
///
/// Returns an error when:
/// - a certificate or key file cannot be opened or parsed;
/// - a PEM file contains no certificate or key;
/// - only one of the client certificate / private key is configured;
/// - rustls rejects the certificates or key.
#[cfg(feature = "mtls")]
pub fn create_tls_client_config(tls_config:&TlsConfig) -> Result<ClientConfig, Box<dyn std::error::Error>> {
	dev_log!("grpc", "Creating TLS client configuration");

	let root_store = load_root_store(tls_config)?;

	let mut config = match load_client_auth(tls_config)? {
		Some((certs, key)) => {
			let client_config = ClientConfig::builder()
				.with_root_certificates(root_store)
				.with_client_auth_cert(certs, key)
				.map_err(|e| format!("Failed to configure client authentication: {}", e))?;

			dev_log!("grpc", "Configured mTLS with client certificate");

			client_config
		},

		None => {
			let client_config = ClientConfig::builder().with_root_certificates(root_store).with_no_client_auth();

			dev_log!("grpc", "Configured TLS with server authentication only");

			client_config
		},
	};

	// gRPC requires HTTP/2, so advertise only `h2` during ALPN negotiation.
	config.alpn_protocols = vec![b"h2".to_vec()];

	if !tls_config.verify_certs {
		// No custom verifier is installed, so the default (verifying) one stays in
		// place; say so instead of claiming verification was turned off.
		dev_log!(
			"grpc",
			"warn: verify_certs=false was requested, but disabling certificate verification is not supported; server certificates will still be verified"
		);
	}

	dev_log!("grpc", "TLS client configuration created successfully");

	Ok(config)
}

/// Client certificate chain plus the matching private key, as consumed by
/// `ClientConfig::with_client_auth_cert`.
#[cfg(feature = "mtls")]
type ClientAuthMaterial = (
	Vec<rustls::pki_types::CertificateDer<'static>>,
	rustls::pki_types::PrivateKeyDer<'static>,
);

/// Builds the trust store used to verify the server certificate.
///
/// Uses the CA bundle from `tls_config.ca_cert_path` when set, otherwise the
/// platform's native roots.
#[cfg(feature = "mtls")]
fn load_root_store(tls_config:&TlsConfig) -> Result<RootCertStore, Box<dyn std::error::Error>> {
	let mut root_store = RootCertStore::empty();

	match &tls_config.ca_cert_path {
		Some(ca_path) => {
			dev_log!("grpc", "Loading CA certificate from {:?}", ca_path);

			let ca_file = File::open(ca_path).map_err(|e| format!("Failed to open CA certificate file: {}", e))?;

			let certs = rustls_pemfile::certs(&mut BufReader::new(ca_file))
				.collect::<Result<Vec<_>, _>>()
				.map_err(|e| format!("Failed to parse CA certificate: {}", e))?;

			if certs.is_empty() {
				return Err("No CA certificates found in file".into());
			}

			// An explicitly configured CA must load completely, so any certificate
			// rustls refuses is a hard error.
			for cert in certs {
				root_store
					.add(cert)
					.map_err(|e| format!("Failed to add CA certificate to root store: {}", e))?;
			}

			dev_log!("grpc", "Loaded CA certificate from {:?}", ca_path);
		},

		None => {
			dev_log!("grpc", "Loading system root certificates");

			let cert_result = rustls_native_certs::load_native_certs();

			if !cert_result.errors.is_empty() {
				dev_log!(
					"grpc",
					"warn: Encountered errors loading system certificates: {:?}",
					cert_result.errors
				);
			}

			if cert_result.certs.is_empty() {
				dev_log!("grpc", "warn: No system root certificates found");
			}

			// Platform stores can contain certificates rustls cannot parse. Skip
			// those instead of failing the whole TLS setup on a single entry.
			let (_, ignored) = root_store.add_parsable_certificates(cert_result.certs);

			if ignored > 0 {
				dev_log!("grpc", "warn: Skipped {} unparsable system root certificates", ignored);
			}

			dev_log!("grpc", "Loaded {} system root certificates", root_store.len());
		},
	}

	Ok(root_store)
}

/// Loads the client certificate chain and private key for mutual TLS.
///
/// Returns `Ok(None)` when neither path is configured. Returns an error when
/// exactly one of them is configured.
#[cfg(feature = "mtls")]
fn load_client_auth(tls_config:&TlsConfig) -> Result<Option<ClientAuthMaterial>, Box<dyn std::error::Error>> {
	let (cert_path, key_path) = match (&tls_config.client_cert_path, &tls_config.client_key_path) {
		(Some(cert_path), Some(key_path)) => (cert_path, key_path),

		(None, None) => return Ok(None),

		// A half-configured identity used to silently fall back to server-only
		// TLS; reject it so the missing piece is noticed.
		(Some(_), None) => {
			return Err("Client certificate is configured without a private key; both are required for mTLS".into());
		},

		(None, Some(_)) => {
			return Err("Client private key is configured without a certificate; both are required for mTLS".into());
		},
	};

	dev_log!("grpc", "Loading client certificate from {:?}", cert_path);

	let cert_file = File::open(cert_path).map_err(|e| format!("Failed to open client certificate file: {}", e))?;

	let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
		.collect::<Result<Vec<_>, _>>()
		.map_err(|e| format!("Failed to parse client certificate: {}", e))?;

	if certs.is_empty() {
		return Err("No client certificates found in file".into());
	}

	dev_log!("grpc", "Loading client private key from {:?}", key_path);

	let key_file = File::open(key_path).map_err(|e| format!("Failed to open private key file: {}", e))?;

	let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
		.map_err(|e| format!("Failed to parse private key: {}", e))?
		.ok_or("No private key found in file")?;

	Ok(Some((certs, key)))
}
299
/// Settings used to connect to Mountain.
///
/// Create one with `MountainClientConfig::new`, `MountainClientConfig::from_env`,
/// or `Default`, then refine it with the `with_*` builder methods.
#[derive(Clone, Debug)]
pub struct MountainClientConfig {
	/// Address of the Mountain gRPC endpoint.
	pub address:String,

	/// Connection-establishment timeout, in seconds.
	pub connection_timeout_secs:u64,

	/// Per-request timeout, in seconds.
	pub request_timeout_secs:u64,

	/// TLS settings; `None` means a plaintext connection.
	#[cfg(feature = "mtls")]
	pub tls_config:Option<TlsConfig>,
}
316
317impl Default for MountainClientConfig {
318 fn default() -> Self {
319 Self {
320 address:DEFAULT_MOUNTAIN_ADDRESS.to_string(),
321
322 connection_timeout_secs:DEFAULT_CONNECTION_TIMEOUT_SECS,
323
324 request_timeout_secs:DEFAULT_REQUEST_TIMEOUT_SECS,
325
326 #[cfg(feature = "mtls")]
327 tls_config:None,
328 }
329 }
330}
331
332impl MountainClientConfig {
333 pub fn new(address:impl Into<String>) -> Self { Self { address:address.into(), ..Default::default() } }
341
342 pub fn from_env() -> Self {
362 let address = env::var("MOUNTAIN_ADDRESS").unwrap_or_else(|_| DEFAULT_MOUNTAIN_ADDRESS.to_string());
363
364 let connection_timeout_secs = env::var("MOUNTAIN_CONNECTION_TIMEOUT_SECS")
365 .ok()
366 .and_then(|s| s.parse().ok())
367 .unwrap_or(DEFAULT_CONNECTION_TIMEOUT_SECS);
368
369 let request_timeout_secs = env::var("MOUNTAIN_REQUEST_TIMEOUT_SECS")
370 .ok()
371 .and_then(|s| s.parse().ok())
372 .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SECS);
373
374 #[cfg(feature = "mtls")]
375 let tls_config = if env::var("MOUNTAIN_TLS_ENABLED")
376 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
377 .unwrap_or(false)
378 {
379 Some(TlsConfig {
380 ca_cert_path:env::var("MOUNTAIN_CA_CERT").ok().map(PathBuf::from),
381 client_cert_path:env::var("MOUNTAIN_CLIENT_CERT").ok().map(PathBuf::from),
382 client_key_path:env::var("MOUNTAIN_CLIENT_KEY").ok().map(PathBuf::from),
383 server_name:env::var("MOUNTAIN_SERVER_NAME").ok(),
384 verify_certs:env::var("MOUNTAIN_VERIFY_CERTS")
385 .map(|v| v != "0" && !v.eq_ignore_ascii_case("false"))
386 .unwrap_or(true),
387 })
388 } else {
389 None
390 };
391
392 #[cfg(not(feature = "mtls"))]
393 let tls_config = None;
394
395 Self {
396 address,
397
398 connection_timeout_secs,
399
400 request_timeout_secs,
401
402 #[cfg(feature = "mtls")]
403 tls_config,
404 }
405 }
406
407 pub fn with_connection_timeout(mut self, timeout_secs:u64) -> Self {
415 self.connection_timeout_secs = timeout_secs;
416
417 self
418 }
419
420 pub fn with_request_timeout(mut self, timeout_secs:u64) -> Self {
428 self.request_timeout_secs = timeout_secs;
429
430 self
431 }
432
433 #[cfg(feature = "mtls")]
441 pub fn with_tls(mut self, tls_config:TlsConfig) -> Self {
442 self.tls_config = Some(tls_config);
443
444 self
445 }
446}
447
448#[derive(Debug, Clone)]
454pub struct MountainClient {
455 channel:Channel,
457
458 config:MountainClientConfig,
460}
461
462impl MountainClient {
463 pub async fn connect(config:MountainClientConfig) -> Result<Self, Box<dyn std::error::Error>> {
474 dev_log!("grpc", "Connecting to Mountain at {}", config.address);
475
476 let endpoint = Endpoint::from_shared(config.address.clone())?
477 .connect_timeout(Duration::from_secs(config.connection_timeout_secs));
478
479 #[cfg(feature = "mtls")]
481 if let Some(tls_config) = &config.tls_config {
482 dev_log!("grpc", "TLS configuration provided, configuring secure connection");
483
484 let _client_config = create_tls_client_config(tls_config).map_err(|e| {
485 dev_log!("grpc", "error: Failed to create TLS client configuration: {}", e);
486 format!("TLS configuration error: {}", e)
487 })?;
488
489 let domain_name = tls_config.server_name.clone().unwrap_or_else(|| "localhost".to_string());
491
492 dev_log!("grpc", "Setting server name for SNI: {}", domain_name);
493
494 let tls = tonic::transport::ClientTlsConfig::new().domain_name(domain_name.clone());
496
497 let channel = endpoint
498 .tcp_keepalive(Some(Duration::from_secs(60)))
499 .tls_config(tls)?
500 .connect()
501 .await
502 .map_err(|e| format!("Failed to connect with TLS: {}", e))?;
503
504 dev_log!("grpc", "Successfully connected to Mountain at {} with TLS", config.address);
505
506 return Ok(Self { channel, config });
507 }
508
509 dev_log!("grpc", "Using unencrypted connection");
511
512 let channel = endpoint.connect().await?;
513
514 dev_log!("grpc", "Successfully connected to Mountain at {}", config.address);
515
516 Ok(Self { channel, config })
517 }
518
519 pub fn channel(&self) -> &Channel { &self.channel }
524
525 pub fn config(&self) -> &MountainClientConfig { &self.config }
530
531 pub async fn health_check(&self) -> Result<bool, Box<dyn std::error::Error>> {
538 dev_log!("grpc", "Checking Mountain health");
539
540 match tokio::time::timeout(Duration::from_secs(self.config.request_timeout_secs), async {
542 Ok::<(), Box<dyn std::error::Error>>(())
545 })
546 .await
547 {
548 Ok(Ok(())) => {
549 dev_log!("grpc", "Mountain health check: healthy");
550
551 Ok(true)
552 },
553
554 Ok(Err(e)) => {
555 dev_log!("grpc", "warn: Mountain health check: disconnected - {}", e);
556
557 Ok(false)
558 },
559
560 Err(_) => {
561 dev_log!("grpc", "warn: Mountain health check: timeout");
562
563 Ok(false)
564 },
565 }
566 }
567
568 pub async fn get_status(&self) -> Result<String, Box<dyn std::error::Error>> {
576 dev_log!("grpc", "Getting Mountain status");
577
578 Ok("connected".to_string())
581 }
582
583 pub async fn get_config(&self, key:&str) -> Result<Option<String>, Box<dyn std::error::Error>> {
594 dev_log!("grpc", "Getting Mountain config: {}", key);
595
596 Ok(None)
599 }
600
601 pub async fn set_config(&self, key:&str, value:&str) -> Result<(), Box<dyn std::error::Error>> {
613 dev_log!("grpc", "Setting Mountain config: {} = {}", key, value);
614
615 Ok(())
618 }
619}
620
621pub async fn connect_to_mountain() -> Result<MountainClient, Box<dyn std::error::Error>> {
626 MountainClient::connect(MountainClientConfig::default()).await
627}
628
629pub async fn connect_to_mountain_at(address:impl Into<String>) -> Result<MountainClient, Box<dyn std::error::Error>> {
637 MountainClient::connect(MountainClientConfig::new(address)).await
638}
639
#[cfg(test)]
mod tests {

	use super::*;

	/// Every environment variable read by `MountainClientConfig::from_env`.
	const MOUNTAIN_ENV_VARS:&[&str] = &[
		"MOUNTAIN_ADDRESS",
		"MOUNTAIN_CONNECTION_TIMEOUT_SECS",
		"MOUNTAIN_REQUEST_TIMEOUT_SECS",
		"MOUNTAIN_TLS_ENABLED",
		"MOUNTAIN_CA_CERT",
		"MOUNTAIN_CLIENT_CERT",
		"MOUNTAIN_CLIENT_KEY",
		"MOUNTAIN_SERVER_NAME",
		"MOUNTAIN_VERIFY_CERTS",
	];

	/// Serializes the tests that touch the process environment.
	///
	/// The test harness runs tests on parallel threads, and these tests set and
	/// remove the same variables. Without the lock they race and fail
	/// intermittently.
	static ENV_LOCK:std::sync::Mutex<()> = std::sync::Mutex::new(());

	/// Exclusive access to the `MOUNTAIN_*` environment for one test.
	///
	/// Clears every variable when acquired and again when dropped, so a failing
	/// assertion cannot leak state into another test.
	struct MountainEnv {
		_lock:std::sync::MutexGuard<'static, ()>,
	}

	impl MountainEnv {
		fn acquire() -> Self {
			// A poisoned lock only means another env test panicked; the variables
			// are reset below, so it is safe to continue.
			let lock = ENV_LOCK.lock().unwrap_or_else(std::sync::PoisonError::into_inner);

			clear_mountain_env();

			Self { _lock:lock }
		}

		fn set(&self, name:&str, value:&str) {
			// SAFETY: `self` holds `ENV_LOCK`, so no other test in this module that
			// uses `MountainEnv` reads or writes the environment concurrently.
			// NOTE(review): code outside this module reading env vars in parallel
			// is not covered by the lock.
			unsafe { env::set_var(name, value) }
		}
	}

	impl Drop for MountainEnv {
		fn drop(&mut self) { clear_mountain_env(); }
	}

	/// Removes every variable in `MOUNTAIN_ENV_VARS`.
	fn clear_mountain_env() {
		for &name in MOUNTAIN_ENV_VARS {
			// SAFETY: only called while `ENV_LOCK` is held (from `MountainEnv`).
			unsafe { env::remove_var(name) }
		}
	}

	#[test]
	fn test_default_config() {
		let config = MountainClientConfig::default();

		assert_eq!(config.address, DEFAULT_MOUNTAIN_ADDRESS);

		assert_eq!(config.connection_timeout_secs, DEFAULT_CONNECTION_TIMEOUT_SECS);

		assert_eq!(config.request_timeout_secs, DEFAULT_REQUEST_TIMEOUT_SECS);
	}

	#[test]
	fn test_config_builder() {
		let config = MountainClientConfig::new("[::1]:50060")
			.with_connection_timeout(10)
			.with_request_timeout(60);

		assert_eq!(config.address, "[::1]:50060");

		assert_eq!(config.connection_timeout_secs, 10);

		assert_eq!(config.request_timeout_secs, 60);
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_tls_config_server_auth() {
		let tls = TlsConfig::server_auth(std::path::PathBuf::from("/path/to/ca.pem"));

		assert_eq!(tls.server_name, Some("localhost".to_string()));

		assert!(tls.client_cert_path.is_none());

		assert!(tls.client_key_path.is_none());

		assert!(tls.ca_cert_path.is_some());

		assert!(tls.verify_certs);
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_tls_config_mtls() {
		let tls = TlsConfig::mtls(
			std::path::PathBuf::from("/path/to/ca.pem"),
			std::path::PathBuf::from("/path/to/cert.pem"),
			std::path::PathBuf::from("/path/to/key.pem"),
		);

		assert!(tls.client_cert_path.is_some());

		assert!(tls.client_key_path.is_some());

		assert!(tls.ca_cert_path.is_some());

		assert!(tls.verify_certs);

		assert_eq!(tls.server_name, Some("localhost".to_string()));
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_tls_config_default() {
		let tls = TlsConfig::default();

		assert!(tls.ca_cert_path.is_none());

		assert!(tls.client_cert_path.is_none());

		assert!(tls.client_key_path.is_none());

		assert!(tls.server_name.is_none());

		assert!(tls.verify_certs);
	}

	#[test]
	fn test_from_env_default() {
		let _env = MountainEnv::acquire();

		let config = MountainClientConfig::from_env();

		assert_eq!(config.address, DEFAULT_MOUNTAIN_ADDRESS);

		assert_eq!(config.connection_timeout_secs, DEFAULT_CONNECTION_TIMEOUT_SECS);

		assert_eq!(config.request_timeout_secs, DEFAULT_REQUEST_TIMEOUT_SECS);

		#[cfg(feature = "mtls")]
		assert!(config.tls_config.is_none());
	}

	#[test]
	fn test_from_env_custom() {
		let env = MountainEnv::acquire();

		env.set("MOUNTAIN_ADDRESS", "[::1]:50060");

		env.set("MOUNTAIN_CONNECTION_TIMEOUT_SECS", "10");

		env.set("MOUNTAIN_REQUEST_TIMEOUT_SECS", "60");

		let config = MountainClientConfig::from_env();

		assert_eq!(config.address, "[::1]:50060");

		assert_eq!(config.connection_timeout_secs, 10);

		assert_eq!(config.request_timeout_secs, 60);
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_from_env_tls() {
		let env = MountainEnv::acquire();

		env.set("MOUNTAIN_TLS_ENABLED", "1");

		env.set("MOUNTAIN_CA_CERT", "/path/to/ca.pem");

		env.set("MOUNTAIN_SERVER_NAME", "mymountain.com");

		let config = MountainClientConfig::from_env();

		assert!(config.tls_config.is_some());

		let tls = config.tls_config.unwrap();

		assert_eq!(tls.ca_cert_path, Some(std::path::PathBuf::from("/path/to/ca.pem")));

		assert_eq!(tls.server_name, Some("mymountain.com".to_string()));

		assert!(tls.client_cert_path.is_none());

		assert!(tls.client_key_path.is_none());

		assert!(tls.verify_certs);
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_from_env_mtls() {
		let env = MountainEnv::acquire();

		env.set("MOUNTAIN_TLS_ENABLED", "true");

		env.set("MOUNTAIN_CA_CERT", "/path/to/ca.pem");

		env.set("MOUNTAIN_CLIENT_CERT", "/path/to/cert.pem");

		env.set("MOUNTAIN_CLIENT_KEY", "/path/to/key.pem");

		let config = MountainClientConfig::from_env();

		assert!(config.tls_config.is_some());

		let tls = config.tls_config.unwrap();

		assert_eq!(tls.ca_cert_path, Some(std::path::PathBuf::from("/path/to/ca.pem")));

		assert_eq!(tls.client_cert_path, Some(std::path::PathBuf::from("/path/to/cert.pem")));

		assert_eq!(tls.client_key_path, Some(std::path::PathBuf::from("/path/to/key.pem")));

		assert!(tls.verify_certs);
	}
}