Code indexing in Gitaly is broken, which causes code to not be visible to users. We are working on the issue with the highest priority.

Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
main.py 8.85 KiB
import os
import json
import tomllib
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles

from app import ssl_heidi
from app.routers import (
    proposal,
    puck,
    spreadsheet,
    logistics,
    auth,
    sample,
)
from app.database import Base, engine, SessionLocal
from app.routers.protected_router import protected_router

# The "/images" static-files mount further down serves this directory, so make
# sure it exists before the application object is configured.
Path("images").mkdir(parents=True, exist_ok=True)


# Utility function to fetch metadata from pyproject.toml
def get_project_metadata():
    script_dir = Path(__file__).resolve().parent
    pyproject_path = script_dir / "pyproject.toml"  # Check current directory first

    if pyproject_path.exists():
        with open(pyproject_path, "rb") as f:
            pyproject = tomllib.load(f)
            name = pyproject["project"]["name"]
            version = pyproject["project"]["version"]
            return name, version

    # Search in parent directories
    for parent in script_dir.parents:
        pyproject_path = parent / "pyproject.toml"
        if pyproject_path.exists():
            with open(pyproject_path, "rb") as f:
                pyproject = tomllib.load(f)
                name = pyproject["project"]["name"]
                version = pyproject["project"]["version"]
                return name, version

    raise FileNotFoundError(
        f"pyproject.toml not found in any parent directory of {script_dir}"
    )


def run_server():
    """Start the uvicorn HTTPS server for ``app`` using module-level settings.

    Reads the module globals ``environment``, ``cert_path``, ``key_path`` and
    ``config``. The port is taken from the ``PORT`` key of ``config`` and, if
    that key is absent, from the ``PORT`` environment variable. The server
    binds to ``127.0.0.1`` in the ``dev`` and ``test`` environments and to
    ``0.0.0.0`` otherwise.

    Raises:
        SystemExit: With status 1 when no port is configured.
        ValueError: If the configured port cannot be converted to ``int``.
    """
    # Imported locally: ``sys`` is otherwise only bound inside the
    # ``__main__`` guard, so the abort path below raised NameError whenever
    # this module was imported instead of executed (for example in a
    # multiprocessing child started with the "spawn" method).
    import sys

    import uvicorn

    print(f"[INFO] Starting server in {environment} environment...")
    print(f"[INFO] SSL Certificate Path: {cert_path}")
    print(f"[INFO] SSL Key Path: {key_path}")
    port = config.get("PORT", os.getenv("PORT"))
    if not port:
        print("[ERROR] No port defined in config or environment variables. Aborting!")
        sys.exit(1)  # Exit if no port is defined
    port = int(port)
    print(f"[INFO] Running on port {port}")
    uvicorn.run(
        app,
        # Loopback only for local environments; all interfaces otherwise
        host="127.0.0.1" if environment in ["dev", "test"] else "0.0.0.0",
        port=port,
        log_level="debug",
        ssl_keyfile=key_path,
        ssl_certfile=cert_path,
    )


# Project name and version, as declared in pyproject.toml
project_name, project_version = get_project_metadata()

# The ENVIRONMENT variable (default "dev") selects config_<environment>.json,
# which is expected to sit next to this file
environment = os.environ.get("ENVIRONMENT", "dev")
config_file = Path(__file__).resolve().with_name(f"config_{environment}.json")

if not config_file.exists():
    raise FileNotFoundError(f"Config file '{config_file}' does not exist.")

# Parse the JSON configuration
config = json.loads(config_file.read_text())

# Set SSL paths based on environment
if environment in ["test", "dev"]:
    # Local environments fall back to files under ./ssl when the config does
    # not name explicit paths
    cert_path = config.get("ssl_cert_path", "ssl/cert.pem")
    key_path = config.get("ssl_key_path", "ssl/key.pem")
elif environment == "prod":
    # Production uses upper-case keys and has no defaults
    cert_path = config.get("SSL_CERT_PATH")
    key_path = config.get("SSL_KEY_PATH")
    # Validate production SSL paths
    if not cert_path or not key_path:
        raise ValueError(
            "SSL_CERT_PATH and SSL_KEY_PATH must be set in config_prod.json"
            " for production."
        )
    if not Path(cert_path).exists() or not Path(key_path).exists():
        raise FileNotFoundError(
            f"Missing SSL files in production. Ensure the following files exist:\n"
            f"SSL Certificate: {cert_path}\nSSL Key: {key_path}"
        )
else:
    raise ValueError(f"Unknown environment: {environment}")

# Generate SSL Key and Certificate if not exist (only for development)
if environment == "dev":
    Path("ssl").mkdir(parents=True, exist_ok=True)
    if not Path(cert_path).exists() or not Path(key_path).exists():
        # The configured paths may point outside ./ssl, so make sure the
        # directories that will actually receive the files exist as well.
        for _ssl_file in (cert_path, key_path):
            Path(_ssl_file).parent.mkdir(parents=True, exist_ok=True)
        ssl_heidi.generate_self_signed_cert(cert_path, key_path)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run database startup tasks, then close the session on shutdown.

    Behaviour depends on the module-level ``environment``:

    * ``prod``: creates all tables from ``Base.metadata`` when the database
      has no tables at all, then calls ``load_slots_data``.
    * ``dev``: calls ``load_sample_data``.
    * ``test``: calls ``load_slots_data``.

    The session opened for these tasks is closed when the context exits,
    whether startup succeeded or raised.

    Args:
        app: The FastAPI application (required by the lifespan protocol; not
            used here).
    """
    print("[INFO] Running application startup tasks...")
    db = SessionLocal()
    try:
        if environment == "prod":
            # sqlalchemy.inspect() is the supported way to obtain an
            # Inspector; Inspector.from_engine() is deprecated.
            from sqlalchemy import inspect as sa_inspect

            table_names = sa_inspect(engine).get_table_names()

            # Ensure the production database is initialized.
            # NOTE: tables are only created when the database is completely
            # empty; a model added to an already-populated database is not
            # created by this hook.
            if not table_names:
                print("Production database is empty. Initializing...")
                Base.metadata.create_all(bind=engine)

            # Seed the database (slots + proposals)
            from app.database import load_slots_data

            load_slots_data(db)
        else:  # dev or test environments
            print(f"{environment.capitalize()} environment: Regenerating database.")
            # NOTE(review): despite the message above, the schema reset is
            # currently disabled; only the seeding below runs.
            # Base.metadata.drop_all(bind=engine)
            # Base.metadata.create_all(bind=engine)
            if environment == "dev":
                from app.database import load_sample_data

                load_sample_data(db)
            elif environment == "test":
                from app.database import load_slots_data

                load_slots_data(db)

        yield
    finally:
        db.close()


# Application object; title and version are the values read from pyproject.toml
app = FastAPI(
    title=project_name,
    version=project_version,
    description="Backend for next-gen sample management system",
    servers=[
        {"url": "https://mx-aare-test.psi.ch:1492", "description": "Default server"}
    ],
    lifespan=lifespan,
)

# CORS: every origin, method and header is accepted, with credentials enabled.
# NOTE(review): a wildcard origin together with allow_credentials=True is very
# permissive - confirm this is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Routers and the keyword arguments each one is registered with; registration
# happens in exactly this order.
_ROUTER_TABLE = (
    (protected_router, {"prefix": "/protected"}),
    (auth.router, {"prefix": "/auth", "tags": ["auth"]}),
    (proposal.router, {"prefix": "/proposals", "tags": ["proposals"]}),
    (puck.router, {"prefix": "/pucks", "tags": ["pucks"]}),
    (spreadsheet.router, {"tags": ["spreadsheet"]}),  # no prefix
    (logistics.router, {"prefix": "/logistics", "tags": ["logistics"]}),
    (sample.router, {"prefix": "/samples", "tags": ["samples"]}),
)
for _router, _include_kwargs in _ROUTER_TABLE:
    app.include_router(_router, **_include_kwargs)

# Serve files from the local "images" directory under /images
app.mount("/images", StaticFiles(directory="images"), name="images")


if __name__ == "__main__":
    import sys
    from dotenv import load_dotenv
    from multiprocessing import Process
    from time import sleep

    # Load environment variables from .env file.
    # NOTE(review): the module-level configuration and SSL setup above has
    # already run by this point, so values that exist only in .env did not
    # influence it - confirm that this ordering is intended.
    load_dotenv()

    # Check if `generate-openapi` option is passed
    if len(sys.argv) > 1 and sys.argv[1] == "generate-openapi":
        from fastapi.openapi.utils import get_openapi

        # Generate and save OpenAPI JSON file
        openapi_schema = get_openapi(
            title=app.title,
            version=app.version,
            description=app.description,
            routes=app.routes,
        )
        with open("openapi.json", "w") as f:
            json.dump(openapi_schema, f, indent=4)
        print("openapi.json generated successfully.")
        sys.exit(0)  # Exit after generating the file

    # Default behavior: Run the server based on the environment.
    # The port is resolved inside run_server(), so it is not read here.
    environment = os.getenv("ENVIRONMENT", "dev")
    is_ci = os.getenv("CI", "false").lower() == "true"

    if is_ci or environment == "test":
        # Test or CI Mode: Run server process temporarily for test validation
        ssl_dir = Path(cert_path).parent
        ssl_dir.mkdir(parents=True, exist_ok=True)

        # Generate self-signed certs if missing
        if not Path(cert_path).exists() or not Path(key_path).exists():
            print(f"[INFO] Generating self-signed SSL certificates at {ssl_dir}")
            ssl_heidi.generate_self_signed_cert(cert_path, key_path)

        # Start the server as a subprocess, wait, then terminate
        server_process = Process(target=run_server)
        server_process.start()
        sleep(5)  # Wait for 5 seconds to verify the server is running

        # A server that failed during startup has already exited by now;
        # report that as a failure instead of printing the success message.
        if not server_process.is_alive():
            server_process.join()
            print(
                "[ERROR] CI: Server exited prematurely with code "
                f"{server_process.exitcode}."
            )
            sys.exit(1)

        server_process.terminate()  # Terminate the server process (for CI)
        server_process.join()  # Ensure proper cleanup
        print("CI: Server started and terminated successfully for test validation.")
    else:
        # Dev or Prod: Start the server as usual
        run_server()