@@ -64,14 +64,16 @@ pub type PostgresPlaintextBackend = PostgresBackend<NoTls>;
6464/// A postgres backend with TLS connections to the database
6565pub type PostgresTlsBackend = PostgresBackend < MakeTlsConnector > ;
6666
67- async fn make_postgres_db_connection < T > ( postgres_endpoint : & str , tls : T ) -> Result < Client , Error >
67+ async fn make_db_connection < T > (
68+ postgres_endpoint : & str , db_name : & str , tls : T ,
69+ ) -> Result < Client , Error >
6870where
6971 T : MakeTlsConnect < Socket > + Clone + Send + Sync + ' static ,
7072 T :: Stream : Send + Sync ,
7173 T :: TlsConnect : Send ,
7274 <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
7375{
74- let dsn = format ! ( "{}/{}" , postgres_endpoint, "postgres" ) ;
76+ let dsn = format ! ( "{}/{}" , postgres_endpoint, db_name ) ;
7577 let ( client, connection) = tokio_postgres:: connect ( & dsn, tls)
7678 . await
7779 . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Connection error: {}" , e) ) ) ?;
@@ -84,16 +86,16 @@ where
8486 Ok ( client)
8587}
8688
87- async fn initialize_vss_database < T > (
88- postgres_endpoint : & str , db_name : & str , tls : T ,
89+ async fn create_database < T > (
90+ postgres_endpoint : & str , default_db : & str , db_name : & str , tls : T ,
8991) -> Result < ( ) , Error >
9092where
9193 T : MakeTlsConnect < Socket > + Clone + Send + Sync + ' static ,
9294 T :: Stream : Send + Sync ,
9395 T :: TlsConnect : Send ,
9496 <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
9597{
96- let client = make_postgres_db_connection ( & postgres_endpoint, tls) . await ?;
98+ let client = make_db_connection ( postgres_endpoint, default_db , tls) . await ?;
9799
98100 let num_rows = client. execute ( CHECK_DB_STMT , & [ & db_name] ) . await . map_err ( |e| {
99101 Error :: new (
@@ -113,14 +115,16 @@ where
113115}
114116
115117#[ cfg( test) ]
116- async fn drop_database < T > ( postgres_endpoint : & str , db_name : & str , tls : T ) -> Result < ( ) , Error >
118+ async fn drop_database < T > (
119+ postgres_endpoint : & str , default_db : & str , db_name : & str , tls : T ,
120+ ) -> Result < ( ) , Error >
117121where
118122 T : MakeTlsConnect < Socket > + Clone + Send + Sync + ' static ,
119123 T :: Stream : Send + Sync ,
120124 T :: TlsConnect : Send ,
121125 <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
122126{
123- let client = make_postgres_db_connection ( & postgres_endpoint, tls) . await ?;
127+ let client = make_db_connection ( postgres_endpoint, default_db , tls) . await ?;
124128
125129 let drop_database_statement = format ! ( "{} {};" , DROP_DB_CMD , db_name) ;
126130 let num_rows = client. execute ( & drop_database_statement, & [ ] ) . await . map_err ( |e| {
@@ -133,25 +137,38 @@ where
133137
134138impl PostgresPlaintextBackend {
135139 /// Constructs a [`PostgresPlaintextBackend`] using `postgres_endpoint` for PostgreSQL connection information.
136- pub async fn new ( postgres_endpoint : & str , db_name : & str ) -> Result < Self , Error > {
137- PostgresBackend :: new_internal ( postgres_endpoint, db_name, NoTls ) . await
140+ pub async fn new (
141+ postgres_endpoint : & str , default_db : & str , vss_db : & str ,
142+ ) -> Result < Self , Error > {
143+ PostgresBackend :: new_internal ( postgres_endpoint, default_db, vss_db, NoTls ) . await
138144 }
139145}
140146
141147impl PostgresTlsBackend {
142148 /// Constructs a [`PostgresTlsBackend`] using `postgres_endpoint` for PostgreSQL connection information.
143149 pub async fn new (
144- postgres_endpoint : & str , db_name : & str , additional_certificate : Option < Certificate > ,
150+ postgres_endpoint : & str , default_db : & str , vss_db : & str , crt_pem : Option < & str > ,
145151 ) -> Result < Self , Error > {
146152 let mut builder = TlsConnector :: builder ( ) ;
147- if let Some ( cert) = additional_certificate {
148- builder. add_root_certificate ( cert) ;
153+ if let Some ( pem) = crt_pem {
154+ let crt = Certificate :: from_pem ( pem. as_bytes ( ) ) . map_err ( |e| {
155+ Error :: new (
156+ ErrorKind :: Other ,
157+ format ! ( "Failed to parse the PEM formatted certificate: {}" , e) ,
158+ )
159+ } ) ?;
160+ builder. add_root_certificate ( crt) ;
149161 }
150162 let connector = builder. build ( ) . map_err ( |e| {
151163 Error :: new ( ErrorKind :: Other , format ! ( "Error building tls connector: {}" , e) )
152164 } ) ?;
153- PostgresBackend :: new_internal ( postgres_endpoint, db_name, MakeTlsConnector :: new ( connector) )
154- . await
165+ PostgresBackend :: new_internal (
166+ postgres_endpoint,
167+ default_db,
168+ vss_db,
169+ MakeTlsConnector :: new ( connector) ,
170+ )
171+ . await
155172 }
156173}
157174
@@ -162,9 +179,11 @@ where
162179 T :: TlsConnect : Send ,
163180 <<T as MakeTlsConnect < Socket > >:: TlsConnect as TlsConnect < Socket > >:: Future : Send ,
164181{
165- async fn new_internal ( postgres_endpoint : & str , db_name : & str , tls : T ) -> Result < Self , Error > {
166- initialize_vss_database ( postgres_endpoint, db_name, tls. clone ( ) ) . await ?;
167- let vss_dsn = format ! ( "{}/{}" , postgres_endpoint, db_name) ;
182+ async fn new_internal (
183+ postgres_endpoint : & str , default_db : & str , vss_db : & str , tls : T ,
184+ ) -> Result < Self , Error > {
185+ create_database ( postgres_endpoint, default_db, vss_db, tls. clone ( ) ) . await ?;
186+ let vss_dsn = format ! ( "{}/{}" , postgres_endpoint, vss_db) ;
168187 let manager =
169188 PostgresConnectionManager :: new_from_stringlike ( vss_dsn, tls) . map_err ( |e| {
170189 Error :: new (
@@ -649,24 +668,27 @@ mod tests {
649668 use tokio_postgres:: NoTls ;
650669
651670 const POSTGRES_ENDPOINT : & str = "postgresql://postgres:postgres@localhost:5432" ;
671+ const DEFAULT_DB : & str = "postgres" ;
652672 const MIGRATIONS_START : usize = 0 ;
653673 const MIGRATIONS_END : usize = MIGRATIONS . len ( ) ;
654674
655675 static START : OnceCell < ( ) > = OnceCell :: const_new ( ) ;
656676
657677 define_kv_store_tests ! ( PostgresKvStoreTest , PostgresPlaintextBackend , {
658- let db_name = "postgres_kv_store_tests" ;
678+ let vss_db = "postgres_kv_store_tests" ;
659679 START
660680 . get_or_init( || async {
661- let _ = drop_database( POSTGRES_ENDPOINT , db_name, NoTls ) . await ;
662- let store =
663- PostgresPlaintextBackend :: new( POSTGRES_ENDPOINT , db_name) . await . unwrap( ) ;
681+ let _ = drop_database( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db, NoTls ) . await ;
682+ let store = PostgresPlaintextBackend :: new( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db)
683+ . await
684+ . unwrap( ) ;
664685 let ( start, end) = store. migrate_vss_database( MIGRATIONS ) . await . unwrap( ) ;
665686 assert_eq!( start, MIGRATIONS_START ) ;
666687 assert_eq!( end, MIGRATIONS_END ) ;
667688 } )
668689 . await ;
669- let store = PostgresPlaintextBackend :: new( POSTGRES_ENDPOINT , db_name) . await . unwrap( ) ;
690+ let store =
691+ PostgresPlaintextBackend :: new( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap( ) ;
670692 let ( start, end) = store. migrate_vss_database( MIGRATIONS ) . await . unwrap( ) ;
671693 assert_eq!( start, MIGRATIONS_END ) ;
672694 assert_eq!( end, MIGRATIONS_END ) ;
@@ -678,36 +700,40 @@ mod tests {
678700 #[ tokio:: test]
679701 #[ should_panic( expected = "We do not allow downgrades" ) ]
680702 async fn panic_on_downgrade ( ) {
681- let db_name = "panic_on_downgrade_test" ;
682- let _ = drop_database ( POSTGRES_ENDPOINT , db_name , NoTls ) . await ;
703+ let vss_db = "panic_on_downgrade_test" ;
704+ let _ = drop_database ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db , NoTls ) . await ;
683705 {
684706 let mut migrations = MIGRATIONS . to_vec ( ) ;
685707 migrations. push ( DUMMY_MIGRATION ) ;
686- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
708+ let store =
709+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
687710 let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
688711 assert_eq ! ( start, MIGRATIONS_START ) ;
689712 assert_eq ! ( end, MIGRATIONS_END + 1 ) ;
690713 } ;
691714 {
692- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
715+ let store =
716+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
693717 let _ = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
694718 } ;
695719 }
696720
697721 #[ tokio:: test]
698722 async fn new_migrations_increments_upgrades ( ) {
699- let db_name = "new_migrations_increments_upgrades_test" ;
700- let _ = drop_database ( POSTGRES_ENDPOINT , db_name , NoTls ) . await ;
723+ let vss_db = "new_migrations_increments_upgrades_test" ;
724+ let _ = drop_database ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db , NoTls ) . await ;
701725 {
702- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
726+ let store =
727+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
703728 let ( start, end) = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
704729 assert_eq ! ( start, MIGRATIONS_START ) ;
705730 assert_eq ! ( end, MIGRATIONS_END ) ;
706731 assert_eq ! ( store. get_upgrades_list( ) . await , [ MIGRATIONS_START ] ) ;
707732 assert_eq ! ( store. get_schema_version( ) . await , MIGRATIONS_END ) ;
708733 } ;
709734 {
710- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
735+ let store =
736+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
711737 let ( start, end) = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
712738 assert_eq ! ( start, MIGRATIONS_END ) ;
713739 assert_eq ! ( end, MIGRATIONS_END ) ;
@@ -718,7 +744,8 @@ mod tests {
718744 let mut migrations = MIGRATIONS . to_vec ( ) ;
719745 migrations. push ( DUMMY_MIGRATION ) ;
720746 {
721- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
747+ let store =
748+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
722749 let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
723750 assert_eq ! ( start, MIGRATIONS_END ) ;
724751 assert_eq ! ( end, MIGRATIONS_END + 1 ) ;
@@ -729,7 +756,8 @@ mod tests {
729756 migrations. push ( DUMMY_MIGRATION ) ;
730757 migrations. push ( DUMMY_MIGRATION ) ;
731758 {
732- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
759+ let store =
760+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
733761 let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
734762 assert_eq ! ( start, MIGRATIONS_END + 1 ) ;
735763 assert_eq ! ( end, MIGRATIONS_END + 3 ) ;
@@ -741,13 +769,14 @@ mod tests {
741769 } ;
742770
743771 {
744- let store = PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
772+ let store =
773+ PostgresPlaintextBackend :: new ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db) . await . unwrap ( ) ;
745774 let list = store. get_upgrades_list ( ) . await ;
746775 assert_eq ! ( list, [ MIGRATIONS_START , MIGRATIONS_END , MIGRATIONS_END + 1 ] ) ;
747776 let version = store. get_schema_version ( ) . await ;
748777 assert_eq ! ( version, MIGRATIONS_END + 3 ) ;
749778 }
750779
751- drop_database ( POSTGRES_ENDPOINT , db_name , NoTls ) . await . unwrap ( ) ;
780+ drop_database ( POSTGRES_ENDPOINT , DEFAULT_DB , vss_db , NoTls ) . await . unwrap ( ) ;
752781 }
753782}
0 commit comments