@@ -50,7 +50,7 @@ pub struct Cli {
5050 pub driver : Option < Driver > ,
5151}
5252
53- #[ derive( Parser , Debug , Clone , Copy , Deserialize ) ]
53+ #[ derive( Parser , Debug , Clone , Copy , Deserialize , PartialEq , Eq ) ]
5454pub enum Driver {
5555 #[ serde( alias = "postgres" , alias = "POSTGRES" ) ]
5656 Postgres ,
@@ -83,10 +83,14 @@ impl FromStr for Driver {
8383
8484pub fn extract_driver_from_url ( url : & str ) -> Result < Driver > {
8585 let url = url. trim ( ) ;
86- if let Some ( pos) = url. find ( "://" ) {
86+ if url. starts_with ( "jdbc:" ) {
87+ if let Some ( driver_part) = url. split ( ':' ) . nth ( 1 ) {
88+ driver_part. to_lowercase ( ) . parse ( )
89+ } else {
90+ Err ( eyre:: Report :: msg ( "Invalid connection URL format" ) )
91+ }
92+ } else if let Some ( pos) = url. find ( "://" ) {
8793 url[ ..pos] . to_lowercase ( ) . parse ( )
88- } else if url. starts_with ( "jdbc:oracle:thin" ) {
89- Ok ( Driver :: Oracle )
9094 } else if url. ends_with ( ".duckdb" ) || url. ends_with ( ".ddb" ) {
9195 #[ cfg( not( feature = "musl" ) ) ]
9296 {
@@ -137,3 +141,108 @@ pub fn prompt_for_database_selection(config: &Config) -> Result<Option<(Database
137141 } ,
138142 }
139143}
144+
145+ #[ cfg( test) ]
146+ mod tests {
147+ use super :: * ;
148+ #[ test]
149+ fn extracts_driver_from_standard_urls ( ) {
150+ let cases = [
151+ ( "postgres://username:password@localhost:5432/dbname" , Driver :: Postgres ) ,
152+ ( "postgresql://[email protected] /reporting?sslmode=require" , Driver :: Postgres ) , 153+ ( "postgres://user:pass@[2001:db8::1]:5432/app" , Driver :: Postgres ) ,
154+ ( "postgresql://user@/analytics?host=/var/run/postgresql" , Driver :: Postgres ) ,
155+ ( "POSTGRES://localhost/dbname" , Driver :: Postgres ) ,
156+ ( "mysql://localhost/dbname" , Driver :: MySql ) ,
157+ ( "mysql://app:[email protected] :3307/metrics?useSSL=false" , Driver :: MySql ) , 158+ ( "mysql://reader:[email protected] /app?charset=utf8mb4" , Driver :: MySql ) , 159+ ( "sqlite:///tmp/data.sqlite" , Driver :: Sqlite ) ,
160+ ( "sqlite:///var/lib/sqlite/app.sqlite3" , Driver :: Sqlite ) ,
161+ ( "sqlite://localhost/var/data.sqlite?mode=ro" , Driver :: Sqlite ) ,
162+ ( "oracle://scott:tiger@//prod-db.example.com:1521/ORCLPDB1" , Driver :: Oracle ) ,
163+ ( "oracle://user:pass@db-host/service_name" , Driver :: Oracle ) ,
164+ #[ cfg( not( feature = "musl" ) ) ]
165+ ( "duckdb:///var/tmp/cache.duckdb" , Driver :: DuckDb ) ,
166+ ] ;
167+
168+ for ( url, expected) in cases {
169+ let actual = extract_driver_from_url ( url) . unwrap_or_else ( |err| panic ! ( "url: {url}, err: {err}" ) ) ;
170+ assert_eq ! ( actual, expected, "url: {url}" ) ;
171+ }
172+ }
173+
174+ #[ test]
175+ fn extracts_driver_from_jdbc_urls ( ) {
176+ let cases = [
177+ ( "jdbc:postgresql://localhost:5432/dbname" , Driver :: Postgres ) ,
178+ ( "jdbc:postgresql://[email protected] :5432/reporting?sslmode=require" , Driver :: Postgres ) , 179+ ( "jdbc:mysql://localhost:3306/dbname" , Driver :: MySql ) ,
180+ ( "jdbc:mysql:loadbalance://db1.example.com:3306,db2.example.com:3306/app" , Driver :: MySql ) ,
181+ ( "jdbc:sqlite://localhost/path" , Driver :: Sqlite ) ,
182+ ( "jdbc:sqlite:/var/lib/sqlite/cache.sqlite3" , Driver :: Sqlite ) ,
183+ ( "jdbc:oracle:thin:@localhost:1521/dbname" , Driver :: Oracle ) ,
184+ ( "jdbc:oracle:oci:@//prod-host:1521/ORCLCDB.localdomain" , Driver :: Oracle ) ,
185+ #[ cfg( not( feature = "musl" ) ) ]
186+ ( "jdbc:duckdb:/var/lib/duckdb/cache.duckdb" , Driver :: DuckDb ) ,
187+ ] ;
188+
189+ for ( url, expected) in cases {
190+ let actual = extract_driver_from_url ( url) . unwrap_or_else ( |err| panic ! ( "url: {url}, err: {err}" ) ) ;
191+ assert_eq ! ( actual, expected, "url: {url}" ) ;
192+ }
193+ }
194+
195+ #[ test]
196+ fn extracts_driver_from_file_extensions ( ) {
197+ let sqlite_paths = [ "/tmp/app.sqlite" , "/tmp/app.sqlite3" , "./relative/state.sqlite" , r"C:\data\inventory.sqlite3" ] ;
198+ for path in sqlite_paths {
199+ assert_eq ! (
200+ extract_driver_from_url( path) . unwrap_or_else( |err| panic!( "url: {path}, err: {err}" ) ) ,
201+ Driver :: Sqlite ,
202+ "url: {path}"
203+ ) ;
204+ }
205+
206+ #[ cfg( not( feature = "musl" ) ) ]
207+ {
208+ let duckdb_paths = [ "/tmp/data.duckdb" , "/tmp/data.ddb" , "./var/cache/session.duckdb" ] ;
209+ for path in duckdb_paths {
210+ assert_eq ! (
211+ extract_driver_from_url( path) . unwrap_or_else( |err| panic!( "url: {path}, err: {err}" ) ) ,
212+ Driver :: DuckDb ,
213+ "url: {path}"
214+ ) ;
215+ }
216+ }
217+
218+ #[ cfg( feature = "musl" ) ]
219+ {
220+ assert ! ( extract_driver_from_url( "/tmp/data.duckdb" ) . is_err( ) ) ;
221+ }
222+
223+ let err = extract_driver_from_url ( "/tmp/unknown.db" ) . unwrap_err ( ) ;
224+ assert ! ( err. to_string( ) . contains( "ambiguous" ) ) ;
225+ }
226+
227+ #[ test]
228+ fn trims_whitespace_before_parsing ( ) {
229+ let cases = [
230+ ( " mysql://user@localhost/db " , Driver :: MySql ) ,
231+ ( "\t postgres://readonly@reports/db\n " , Driver :: Postgres ) ,
232+ ( " \n sqlite:///tmp/cache.sqlite3\t " , Driver :: Sqlite ) ,
233+ ] ;
234+
235+ for ( url, expected) in cases {
236+ let actual = extract_driver_from_url ( url) . unwrap_or_else ( |err| panic ! ( "url: {url:?}, err: {err}" ) ) ;
237+ assert_eq ! ( actual, expected, "url: {url:?}" ) ;
238+ }
239+ }
240+
241+ #[ test]
242+ fn errors_on_invalid_format ( ) {
243+ for url in [ "localhost:5432/db" , "postgresql:/localhost/db" , "oracle//prod-host:1521/service" ] {
244+ let err = extract_driver_from_url ( url) . unwrap_err ( ) ;
245+ assert ! ( err. to_string( ) . contains( "Invalid connection URL format" ) , "Unexpected error for {url}: {err}" ) ;
246+ }
247+ }
248+ }
0 commit comments