diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index e7a2ba17..e28f5edd 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -645,11 +645,14 @@ def create_connect_args(self, url): The given URL follows the style: `spanner:///projects/{project-id}/instances/{instance-id}/databases/{database-id}` + or `spanner:///projects/{project-id}/instances/{instance-id}`. For the latter, + database operatopns will be not be possible and if required a new engine with + database-id set will need to be created. """ match = re.match( ( r"^projects/(?P.+?)/instances/" - "(?P.+?)/databases/(?P.+?)$" + "(?P.+?)(/databases/(?P.+)|$)" ), url.database, ) @@ -1346,17 +1349,29 @@ def do_rollback(self, dbapi_connection): ): pass else: - trace_attributes = {"db.instance": dbapi_connection.database.name} + trace_attributes = { + "db.instance": dbapi_connection.database.name + if dbapi_connection.database + else "" + } with trace_call("SpannerSqlAlchemy.Rollback", trace_attributes): dbapi_connection.rollback() def do_commit(self, dbapi_connection): - trace_attributes = {"db.instance": dbapi_connection.database.name} + trace_attributes = { + "db.instance": dbapi_connection.database.name + if dbapi_connection.database + else "" + } with trace_call("SpannerSqlAlchemy.Commit", trace_attributes): dbapi_connection.commit() def do_close(self, dbapi_connection): - trace_attributes = {"db.instance": dbapi_connection.database.name} + trace_attributes = { + "db.instance": dbapi_connection.database.name + if dbapi_connection.database + else "" + } with trace_call("SpannerSqlAlchemy.Close", trace_attributes): dbapi_connection.close() diff --git a/test/test_suite_13.py b/test/test_suite_13.py index a5b2e3bb..0561de5d 100644 --- a/test/test_suite_13.py +++ b/test/test_suite_13.py @@ -2045,3 +2045,17 @@ def test_create_engine_w_invalid_client_object(self): with pytest.raises(ValueError): engine.connect() + + +class CreateEngineWithoutDatabaseTest(fixtures.TestBase): + def test_create_engine_wo_database(self): + """ + SPANNER TEST: + + Check that we can connect to SqlAlchemy + without passing database id in the + connection URL. + """ + engine = create_engine(get_db_url().split("/database")[0]) + with engine.connect() as connection: + assert connection.connection.database is None diff --git a/test/test_suite_14.py b/test/test_suite_14.py index a13477b3..3ff069b2 100644 --- a/test/test_suite_14.py +++ b/test/test_suite_14.py @@ -2378,3 +2378,17 @@ def test_create_engine_w_invalid_client_object(self): with pytest.raises(ValueError): engine.connect() + + +class CreateEngineWithoutDatabaseTest(fixtures.TestBase): + def test_create_engine_wo_database(self): + """ + SPANNER TEST: + + Check that we can connect to SqlAlchemy + without passing database id in the + connection URL. + """ + engine = create_engine(get_db_url().split("/database")[0]) + with engine.connect() as connection: + assert connection.connection.database is None diff --git a/test/test_suite_20.py b/test/test_suite_20.py index fb59b725..b4bf26fa 100644 --- a/test/test_suite_20.py +++ b/test/test_suite_20.py @@ -24,7 +24,7 @@ import time from unittest import mock -from google.cloud.spanner_v1 import RequestOptions +from google.cloud.spanner_v1 import RequestOptions, Client import sqlalchemy from sqlalchemy import create_engine from sqlalchemy.engine import Inspector @@ -144,7 +144,7 @@ UnicodeTextTest as _UnicodeTextTest, _UnicodeFixture as __UnicodeFixture, ) # noqa: F401, F403 -from test._helpers import get_db_url +from test._helpers import get_db_url, get_project config.test_schema = "" @@ -3000,3 +3000,44 @@ def test_request_priority(self): engine = create_engine("sqlite:///database") with engine.connect() as connection: pass + + +class CreateEngineWithClientObjectTest(fixtures.TestBase): + def test_create_engine_w_valid_client_object(self): + """ + SPANNER TEST: + + Check that we can connect to SqlAlchemy + by passing custom Client object. + """ + client = Client(project=get_project()) + engine = create_engine(get_db_url(), connect_args={"client": client}) + with engine.connect() as connection: + assert connection.connection.instance._client == client + + def test_create_engine_w_invalid_client_object(self): + """ + SPANNER TEST: + + Check that if project id in url and custom Client + Object passed to enginer mismatch, error is thrown. + """ + client = Client(project="project_id") + engine = create_engine(get_db_url(), connect_args={"client": client}) + + with pytest.raises(ValueError): + engine.connect() + + +class CreateEngineWithoutDatabaseTest(fixtures.TestBase): + def test_create_engine_wo_database(self): + """ + SPANNER TEST: + + Check that we can connect to SqlAlchemy + without passing database id in the + connection URL. + """ + engine = create_engine(get_db_url().split("/database")[0]) + with engine.connect() as connection: + assert connection.connection.database is None