"""Entry point of the backend: loads configuration, builds the FastAPI app,
and (when executed as a script) starts the server or generates openapi.json.
"""

import json
import os
import sys
import tomllib
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles

from app import ssl_heidi
from app.routers import (
    proposal,
    puck,
    spreadsheet,
    logistics,
    auth,
    sample,
)
from app.database import Base, engine, SessionLocal
from app.routers.protected_router import protected_router

# The "images" directory is mounted as static files below, so it must exist.
os.makedirs("images", exist_ok=True)


def get_project_metadata() -> tuple[str, str]:
    """Return ``(name, version)`` from the nearest ``pyproject.toml``.

    The directory containing this file is checked first, then each of its
    parent directories in order; the first ``pyproject.toml`` found wins.

    Returns:
        The ``[project].name`` and ``[project].version`` values.

    Raises:
        FileNotFoundError: If no ``pyproject.toml`` exists in this file's
            directory or any of its parents.
        KeyError: If the file found lacks ``[project]``, ``name`` or
            ``version``.
    """
    script_dir = Path(__file__).resolve().parent

    for directory in (script_dir, *script_dir.parents):
        pyproject_path = directory / "pyproject.toml"
        if pyproject_path.exists():
            with open(pyproject_path, "rb") as f:
                project = tomllib.load(f)["project"]
            return project["name"], project["version"]

    raise FileNotFoundError(
        f"pyproject.toml not found in any parent directory of {script_dir}"
    )


def run_server() -> None:
    """Start uvicorn with the module-level ``app`` and SSL settings.

    The port is taken from the ``PORT`` key of the loaded config, falling
    back to the ``PORT`` environment variable. If neither is set, the process
    exits with status 1.

    This function reads the module globals ``environment``, ``config``,
    ``cert_path`` and ``key_path``. It may be used as a ``multiprocessing``
    target, in which case the module can be re-imported under a name other
    than ``"__main__"``; everything it needs (including ``sys``) must
    therefore be available at module scope.
    """
    import uvicorn

    print(f"[INFO] Starting server in {environment} environment...")
    print(f"[INFO] SSL Certificate Path: {cert_path}")
    print(f"[INFO] SSL Key Path: {key_path}")

    port = config.get("PORT", os.getenv("PORT"))
    if not port:
        print(
            "[ERROR] No port defined in config or environment variables. "
            "Aborting!"
        )
        sys.exit(1)  # Exit if no port is defined
    port = int(port)

    print(f"[INFO] Running on port {port}")
    uvicorn.run(
        app,
        # Bind to loopback only for dev/test; listen on all interfaces otherwise.
        host="127.0.0.1" if environment in ["dev", "test"] else "0.0.0.0",
        port=port,
        log_level="debug",
        ssl_keyfile=key_path,
        ssl_certfile=cert_path,
    )


# Get project metadata from pyproject.toml
project_name, project_version = get_project_metadata()

# Determine environment and configuration file path
environment = os.getenv("ENVIRONMENT", "dev")
config_file = Path(__file__).resolve().parent / f"config_{environment}.json"

if not config_file.exists():
    raise FileNotFoundError(f"Config file '{config_file}' does not exist.")

# Load configuration
with open(config_file) as f:
    config = json.load(f)

# Set SSL paths based on environment
if environment in ["test", "dev"]:
    cert_path = config.get("ssl_cert_path", "ssl/cert.pem")
    key_path = config.get("ssl_key_path", "ssl/key.pem")
elif environment == "prod":
    cert_path = config.get("SSL_CERT_PATH")
    key_path = config.get("SSL_KEY_PATH")
    # Validate production SSL paths
    if not cert_path or not key_path:
        raise ValueError(
            "SSL_CERT_PATH and SSL_KEY_PATH must be set in config_prod.json"
            " for production."
        )
    if not Path(cert_path).exists() or not Path(key_path).exists():
        raise FileNotFoundError(
            f"Missing SSL files in production. Ensure the following files exist:\n"
            f"SSL Certificate: {cert_path}\nSSL Key: {key_path}"
        )
else:
    raise ValueError(f"Unknown environment: {environment}")

# Generate SSL Key and Certificate if not exist (only for development)
if environment == "dev":
    Path("ssl").mkdir(parents=True, exist_ok=True)
    if not Path(cert_path).exists() or not Path(key_path).exists():
        ssl_heidi.generate_self_signed_cert(cert_path, key_path)


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Run database start-up tasks, then close the session on shutdown.

    * ``prod``: if the database has no tables, create the schema and load
      the slots data.
    * ``dev``: load the sample data.
    * ``test``: load the slots data.

    The database session opened here is always closed when the application
    shuts down (or if start-up fails).
    """
    print("[INFO] Running application startup tasks...")
    db = SessionLocal()
    try:
        if environment == "prod":
            from sqlalchemy import inspect

            tables_exist = inspect(engine).get_table_names()

            # Ensure the production database is initialized
            if not tables_exist:
                print("Production database is empty. Initializing...")
                Base.metadata.create_all(bind=engine)
                # Seed the database (slots + proposals)
                from app.database import load_slots_data

                load_slots_data(db)
        else:  # dev or test environments
            print(f"{environment.capitalize()} environment: Regenerating database.")
            # NOTE(review): despite the message above, the schema is not
            # regenerated here; the drop/create calls are disabled:
            # Base.metadata.drop_all(bind=engine)
            # Base.metadata.create_all(bind=engine)
            if environment == "dev":
                from app.database import load_sample_data

                load_sample_data(db)
            elif environment == "test":
                from app.database import load_slots_data

                load_slots_data(db)

        yield
    finally:
        db.close()


app = FastAPI(
    lifespan=lifespan,
    title=project_name,
    description="Backend for next-gen sample management system",
    version=project_version,
    servers=[
        {"url": "https://mx-aare-test.psi.ch:1492", "description": "Default server"}
    ],
)

# Apply CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include routers with correct configuration
app.include_router(protected_router, prefix="/protected")
app.include_router(auth.router, prefix="/auth", tags=["auth"])
app.include_router(proposal.router, prefix="/proposals", tags=["proposals"])
app.include_router(puck.router, prefix="/pucks", tags=["pucks"])
app.include_router(spreadsheet.router, tags=["spreadsheet"])
app.include_router(logistics.router, prefix="/logistics", tags=["logistics"])
app.include_router(sample.router, prefix="/samples", tags=["samples"])

app.mount("/images", StaticFiles(directory="images"), name="images")


if __name__ == "__main__":
    from dotenv import load_dotenv
    from multiprocessing import Process
    from time import sleep

    # Load environment variables from .env file
    load_dotenv()

    # Check if `generate-openapi` option is passed
    if len(sys.argv) > 1 and sys.argv[1] == "generate-openapi":
        from fastapi.openapi.utils import get_openapi

        # Generate and save OpenAPI JSON file
        openapi_schema = get_openapi(
            title=app.title,
            version=app.version,
            description=app.description,
            routes=app.routes,
        )
        with open("openapi.json", "w") as f:
            json.dump(openapi_schema, f, indent=4)
        print("openapi.json generated successfully.")
        sys.exit(0)  # Exit after generating the file

    # Default behavior: Run the server based on the environment
    # NOTE(review): `config` and the SSL paths were already resolved above,
    # before load_dotenv() ran; an ENVIRONMENT value that comes only from
    # .env changes `environment` here but not the loaded config — confirm
    # this is intended.
    environment = os.getenv("ENVIRONMENT", "dev")
    # NOTE(review): this value is not read by run_server(), which resolves
    # its own port from the config / PORT variable.
    port = int(os.getenv("PORT", 8000))
    is_ci = os.getenv("CI", "false").lower() == "true"

    if is_ci or environment == "test":
        # Test or CI Mode: Run server process temporarily for test validation
        ssl_dir = Path(cert_path).parent
        ssl_dir.mkdir(parents=True, exist_ok=True)

        # Generate self-signed certs if missing
        if not Path(cert_path).exists() or not Path(key_path).exists():
            print(f"[INFO] Generating self-signed SSL certificates at {ssl_dir}")
            ssl_heidi.generate_self_signed_cert(cert_path, key_path)

        # Start the server as a subprocess, wait, then terminate
        server_process = Process(target=run_server)
        server_process.start()
        sleep(5)  # Wait for 5 seconds to verify the server is running
        server_process.terminate()  # Terminate the server process (for CI)
        server_process.join()  # Ensure proper cleanup
        print("CI: Server started and terminated successfully for test validation.")
    else:
        # Dev or Prod: Start the server as usual
        run_server()