From 1e137a4f3a1343e4abbec07d43509b6d67ffcbc6 Mon Sep 17 00:00:00 2001 From: Jacob Windsor Date: Wed, 19 Feb 2025 12:49:42 +0100 Subject: [PATCH] Add db provider --- backend/app/providers/db_provider.py | 77 ++++------------------------ 1 file changed, 11 insertions(+), 66 deletions(-) diff --git a/backend/app/providers/db_provider.py b/backend/app/providers/db_provider.py index 5ac3d1c..6722302 100644 --- a/backend/app/providers/db_provider.py +++ b/backend/app/providers/db_provider.py @@ -1,13 +1,6 @@ -import os -import subprocess from collections.abc import AsyncGenerator from contextlib import asynccontextmanager - -import asyncpg -import pg8000 -import pg8000.dbapi -from google.cloud.sql.connector import Connector, IPTypes, create_async_connector from sqlalchemy import Engine from sqlalchemy.ext.asyncio import ( AsyncConnection, @@ -16,35 +9,19 @@ from sqlalchemy.ext.asyncio import ( async_sessionmaker, create_async_engine, ) -from sqlmodel import create_engine +from sqlmodel import create_engine, SQLModel -from ..config import get_settings -from ..db.seed import seed -from ..models.base_db_model import BaseDBModel + +from ..settings import get_settings + +settings = get_settings() async def _get_async_engine() -> AsyncEngine: - settings = get_settings() - if settings.app.environment == "development": - engine = create_async_engine( + return create_async_engine( f"postgresql+asyncpg://{settings.db.username}:{settings.db.password}@localhost/{settings.db.db_name}", future=True, ) - else: - connector = await create_async_connector() - - async def getconn() -> asyncpg.Connection: - return await connector.connect_async( - settings.db.connection_name, - "asyncpg", - user=settings.db.username, - password=settings.db.password, - db=settings.db.db_name, - ip_type=IPTypes.PUBLIC, - ) - - engine = create_async_engine("postgresql+asyncpg://", async_creator=getconn, future=True) - return engine async def get_session() -> AsyncGenerator[AsyncSession, None]: @@ -74,45 +51,13 @@ async def get_connection() -> AsyncGenerator[AsyncConnection, None]: def _get_engine() -> Engine: - settings = get_settings() - if settings.app.environment == "development": - engine = create_engine( - f"postgresql+pg8000://{settings.db.username}:{settings.db.password}@localhost/{settings.db.db_name}" - ) - else: - connector = Connector() - - def getconn() -> pg8000.dbapi.Connection: - conn: pg8000.dbapi.Connection = connector.connect( - settings.db.connection_name, - "pg8000", - user=settings.db.username, - password=settings.db.password, - db=settings.db.db_name, - ip_type=IPTypes.PUBLIC, - ) - return conn - - engine = create_engine("postgresql+pg8000://", creator=getconn) - return engine + return create_engine( + f"postgresql+pg8000://{settings.db.username}:{settings.db.password}@localhost/{settings.db.db_name}" + ) def create_db_and_tables(): - # TODO Move this to use asyncpg engine = _get_engine() - BaseDBModel.metadata.create_all(engine) + SQLModel.metadata.create_all(engine) - if get_settings().app.environment == "development": - seed(engine) - - -def startup_migrations(): - """Run Alembic migrations""" - print("Running Alembic migrations...") - - api_path = os.path.dirname(os.path.abspath(__file__)) + "/../.." - try: - subprocess.run(["alembic", "upgrade", "head"], check=True, cwd=api_path) - print("Migrations applied successfully!") - except subprocess.CalledProcessError as e: - print(f"Error applying migrations: {e}") + # TODO: add seeding