Skip to content

Commit 6c2c3ec

Browse files
parse jdbc urls (#229)
* parse jdbc, add tests * handle jdbc url connections * clippy * fmt
1 parent 290cc41 commit 6c2c3ec

File tree

6 files changed

+118
-9
lines changed

6 files changed

+118
-9
lines changed

src/cli.rs

Lines changed: 113 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ pub struct Cli {
5050
pub driver: Option<Driver>,
5151
}
5252

53-
#[derive(Parser, Debug, Clone, Copy, Deserialize)]
53+
#[derive(Parser, Debug, Clone, Copy, Deserialize, PartialEq, Eq)]
5454
pub enum Driver {
5555
#[serde(alias = "postgres", alias = "POSTGRES")]
5656
Postgres,
@@ -83,10 +83,14 @@ impl FromStr for Driver {
8383

8484
pub fn extract_driver_from_url(url: &str) -> Result<Driver> {
8585
let url = url.trim();
86-
if let Some(pos) = url.find("://") {
86+
if url.starts_with("jdbc:") {
87+
if let Some(driver_part) = url.split(':').nth(1) {
88+
driver_part.to_lowercase().parse()
89+
} else {
90+
Err(eyre::Report::msg("Invalid connection URL format"))
91+
}
92+
} else if let Some(pos) = url.find("://") {
8793
url[..pos].to_lowercase().parse()
88-
} else if url.starts_with("jdbc:oracle:thin") {
89-
Ok(Driver::Oracle)
9094
} else if url.ends_with(".duckdb") || url.ends_with(".ddb") {
9195
#[cfg(not(feature = "musl"))]
9296
{
@@ -137,3 +141,108 @@ pub fn prompt_for_database_selection(config: &Config) -> Result<Option<(Database
137141
},
138142
}
139143
}
144+
145+
#[cfg(test)]
146+
mod tests {
147+
use super::*;
148+
#[test]
149+
fn extracts_driver_from_standard_urls() {
150+
let cases = [
151+
("postgres://username:password@localhost:5432/dbname", Driver::Postgres),
152+
("postgresql://[email protected]/reporting?sslmode=require", Driver::Postgres),
153+
("postgres://user:pass@[2001:db8::1]:5432/app", Driver::Postgres),
154+
("postgresql://user@/analytics?host=/var/run/postgresql", Driver::Postgres),
155+
("POSTGRES://localhost/dbname", Driver::Postgres),
156+
("mysql://localhost/dbname", Driver::MySql),
157+
("mysql://app:[email protected]:3307/metrics?useSSL=false", Driver::MySql),
158+
("mysql://reader:[email protected]/app?charset=utf8mb4", Driver::MySql),
159+
("sqlite:///tmp/data.sqlite", Driver::Sqlite),
160+
("sqlite:///var/lib/sqlite/app.sqlite3", Driver::Sqlite),
161+
("sqlite://localhost/var/data.sqlite?mode=ro", Driver::Sqlite),
162+
("oracle://scott:tiger@//prod-db.example.com:1521/ORCLPDB1", Driver::Oracle),
163+
("oracle://user:pass@db-host/service_name", Driver::Oracle),
164+
#[cfg(not(feature = "musl"))]
165+
("duckdb:///var/tmp/cache.duckdb", Driver::DuckDb),
166+
];
167+
168+
for (url, expected) in cases {
169+
let actual = extract_driver_from_url(url).unwrap_or_else(|err| panic!("url: {url}, err: {err}"));
170+
assert_eq!(actual, expected, "url: {url}");
171+
}
172+
}
173+
174+
#[test]
175+
fn extracts_driver_from_jdbc_urls() {
176+
let cases = [
177+
("jdbc:postgresql://localhost:5432/dbname", Driver::Postgres),
178+
("jdbc:postgresql://[email protected]:5432/reporting?sslmode=require", Driver::Postgres),
179+
("jdbc:mysql://localhost:3306/dbname", Driver::MySql),
180+
("jdbc:mysql:loadbalance://db1.example.com:3306,db2.example.com:3306/app", Driver::MySql),
181+
("jdbc:sqlite://localhost/path", Driver::Sqlite),
182+
("jdbc:sqlite:/var/lib/sqlite/cache.sqlite3", Driver::Sqlite),
183+
("jdbc:oracle:thin:@localhost:1521/dbname", Driver::Oracle),
184+
("jdbc:oracle:oci:@//prod-host:1521/ORCLCDB.localdomain", Driver::Oracle),
185+
#[cfg(not(feature = "musl"))]
186+
("jdbc:duckdb:/var/lib/duckdb/cache.duckdb", Driver::DuckDb),
187+
];
188+
189+
for (url, expected) in cases {
190+
let actual = extract_driver_from_url(url).unwrap_or_else(|err| panic!("url: {url}, err: {err}"));
191+
assert_eq!(actual, expected, "url: {url}");
192+
}
193+
}
194+
195+
#[test]
196+
fn extracts_driver_from_file_extensions() {
197+
let sqlite_paths = ["/tmp/app.sqlite", "/tmp/app.sqlite3", "./relative/state.sqlite", r"C:\data\inventory.sqlite3"];
198+
for path in sqlite_paths {
199+
assert_eq!(
200+
extract_driver_from_url(path).unwrap_or_else(|err| panic!("url: {path}, err: {err}")),
201+
Driver::Sqlite,
202+
"url: {path}"
203+
);
204+
}
205+
206+
#[cfg(not(feature = "musl"))]
207+
{
208+
let duckdb_paths = ["/tmp/data.duckdb", "/tmp/data.ddb", "./var/cache/session.duckdb"];
209+
for path in duckdb_paths {
210+
assert_eq!(
211+
extract_driver_from_url(path).unwrap_or_else(|err| panic!("url: {path}, err: {err}")),
212+
Driver::DuckDb,
213+
"url: {path}"
214+
);
215+
}
216+
}
217+
218+
#[cfg(feature = "musl")]
219+
{
220+
assert!(extract_driver_from_url("/tmp/data.duckdb").is_err());
221+
}
222+
223+
let err = extract_driver_from_url("/tmp/unknown.db").unwrap_err();
224+
assert!(err.to_string().contains("ambiguous"));
225+
}
226+
227+
#[test]
228+
fn trims_whitespace_before_parsing() {
229+
let cases = [
230+
(" mysql://user@localhost/db ", Driver::MySql),
231+
("\tpostgres://readonly@reports/db\n", Driver::Postgres),
232+
(" \nsqlite:///tmp/cache.sqlite3\t", Driver::Sqlite),
233+
];
234+
235+
for (url, expected) in cases {
236+
let actual = extract_driver_from_url(url).unwrap_or_else(|err| panic!("url: {url:?}, err: {err}"));
237+
assert_eq!(actual, expected, "url: {url:?}");
238+
}
239+
}
240+
241+
#[test]
242+
fn errors_on_invalid_format() {
243+
for url in ["localhost:5432/db", "postgresql:/localhost/db", "oracle//prod-host:1521/service"] {
244+
let err = extract_driver_from_url(url).unwrap_err();
245+
assert!(err.to_string().contains("Invalid connection URL format"), "Unexpected error for {url}: {err}");
246+
}
247+
}
248+
}

src/database/duckdb.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ impl DuckDbDriver {
286286

287287
fn build_connection_opts(args: crate::cli::Cli) -> Result<(String, Config)> {
288288
match args.connection_url {
289-
Some(url) => Ok((url, Config::default())),
289+
Some(url) => Ok((url.trim().trim_start_matches("jdbc:").to_string(), Config::default())),
290290
None => {
291291
if let Some(database) = args.database {
292292
Ok((database, Config::default()))

src/database/mysql.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ impl MySqlDriver<'_> {
263263
args: crate::cli::Cli,
264264
) -> Result<<<sqlx::MySql as sqlx::Database>::Connection as sqlx::Connection>::Options> {
265265
match args.connection_url {
266-
Some(url) => Ok(MySqlConnectOptions::from_str(&url)?),
266+
Some(url) => Ok(MySqlConnectOptions::from_str(url.trim().trim_start_matches("jdbc:"))?),
267267
None => {
268268
let mut opts = MySqlConnectOptions::new();
269269

src/database/oracle/connect_options.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ impl FromStr for OracleConnectOptions {
146146
type Err = String;
147147

148148
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
149-
let s = s.trim().trim_start_matches("jdbc:oracle:thin:");
149+
let s = s.trim().trim_start_matches("jdbc:oracle:thin:").trim_start_matches("jdbc:");
150150
let (is_easy_connect, (auth_part, host_part)) = if s.contains("@//") {
151151
(true, s.split_once("@//").ok_or("Invalid Oracle Easy Connect connection string format".to_string())?)
152152
} else if s.contains("@") {

src/database/postgresql.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ impl PostgresDriver<'_> {
267267
args: crate::cli::Cli,
268268
) -> Result<<<sqlx::Postgres as sqlx::Database>::Connection as sqlx::Connection>::Options> {
269269
match args.connection_url {
270-
Some(url) => Ok(PgConnectOptions::from_str(&url)?),
270+
Some(url) => Ok(PgConnectOptions::from_str(url.trim().trim_start_matches("jdbc:"))?),
271271
None => {
272272
let mut opts = PgConnectOptions::new();
273273

src/database/sqlite.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ impl SqliteDriver<'_> {
213213
args: crate::cli::Cli,
214214
) -> Result<<<sqlx::Sqlite as sqlx::Database>::Connection as sqlx::Connection>::Options> {
215215
match args.connection_url {
216-
Some(url) => Ok(SqliteConnectOptions::from_str(&url)?),
216+
Some(url) => Ok(SqliteConnectOptions::from_str(url.trim().trim_start_matches("jdbc:"))?),
217217
None => {
218218
let filename = if let Some(database) = args.database {
219219
database

0 commit comments

Comments
 (0)