import os
import json
import tomllib
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles

from app import ssl_heidi
from app.routers import (
    proposal,
    puck,
    spreadsheet,
    logistics,
    auth,
    sample,
)
from app.database import Base, engine, SessionLocal
from app.routers.protected_router import protected_router

# The static "images" directory (relative to the working directory) must exist
# before the application is assembled further down in this module.
Path("images").mkdir(parents=True, exist_ok=True)


# Utility function to fetch metadata from pyproject.toml
def get_project_metadata():
    script_dir = Path(__file__).resolve().parent
    pyproject_path = script_dir / "pyproject.toml"  # Check current directory first

    if pyproject_path.exists():
        with open(pyproject_path, "rb") as f:
            pyproject = tomllib.load(f)
            name = pyproject["project"]["name"]
            version = pyproject["project"]["version"]
            return name, version

    # Search in parent directories
    for parent in script_dir.parents:
        pyproject_path = parent / "pyproject.toml"
        if pyproject_path.exists():
            with open(pyproject_path, "rb") as f:
                pyproject = tomllib.load(f)
                name = pyproject["project"]["name"]
                version = pyproject["project"]["version"]
                return name, version

    raise FileNotFoundError(
        f"pyproject.toml not found in any parent directory of {script_dir}"
    )


def run_server():
    """Serve ``app`` with uvicorn over TLS using the module-level configuration.

    The port is taken from the ``PORT`` key of ``config`` and falls back to
    the ``PORT`` environment variable. If neither is set the process exits
    with status 1. The server binds to ``127.0.0.1`` in the ``dev`` and
    ``test`` environments and to ``0.0.0.0`` otherwise.

    Raises:
        ValueError: the configured port cannot be converted to ``int``.
    """
    # ``sys`` is imported locally: the file's only other ``import sys`` sits in
    # the ``if __name__ == "__main__"`` block, which does not execute when this
    # module is imported (for example by a spawned multiprocessing child).
    import sys

    import uvicorn

    print(f"[INFO] Starting server in {environment} environment...")
    print(f"[INFO] SSL Certificate Path: {cert_path}")
    print(f"[INFO] SSL Key Path: {key_path}")
    port = config.get("PORT", os.getenv("PORT"))
    if not port:
        print("[ERROR] No port defined in config or environment variables. Aborting!")
        sys.exit(1)  # Exit if no port is defined
    port = int(port)
    print(f"[INFO] Running on port {port}")
    uvicorn.run(
        app,
        # Loopback only for local environments; all interfaces otherwise.
        host="127.0.0.1" if environment in ["dev", "test"] else "0.0.0.0",
        port=port,
        log_level="debug",
        ssl_keyfile=key_path,
        ssl_certfile=cert_path,
    )


# Project name and version are read from pyproject.toml
project_name, project_version = get_project_metadata()

# ENVIRONMENT (default "dev") selects which config_<environment>.json,
# located next to this module, is loaded
environment = os.environ.get("ENVIRONMENT", "dev")
config_file = Path(__file__).resolve().parent / f"config_{environment}.json"

# Fail early with an explicit message when the selected file is absent
if not config_file.exists():
    raise FileNotFoundError(f"Config file '{config_file}' does not exist.")

# Parse the JSON configuration into the module-level ``config`` mapping
with config_file.open() as f:
    config = json.load(f)

# Resolve the TLS certificate and key locations for the active environment
if environment == "prod":
    # Production has no defaults: both paths must be configured and present
    cert_path = config.get("SSL_CERT_PATH")
    key_path = config.get("SSL_KEY_PATH")
    if not (cert_path and key_path):
        raise ValueError(
            "SSL_CERT_PATH and SSL_KEY_PATH must be set in config_prod.json"
            " for production."
        )
    if not (Path(cert_path).exists() and Path(key_path).exists()):
        raise FileNotFoundError(
            f"Missing SSL files in production. Ensure the following files exist:\n"
            f"SSL Certificate: {cert_path}\nSSL Key: {key_path}"
        )
elif environment in ("dev", "test"):
    # Local environments fall back to files under ./ssl when not configured
    cert_path = config.get("ssl_cert_path", "ssl/cert.pem")
    key_path = config.get("ssl_key_path", "ssl/key.pem")
else:
    raise ValueError(f"Unknown environment: {environment}")

# Generate SSL Key and Certificate if not exist (only for development)
if environment == "dev":
    # Create the directories that will actually hold the files. With the
    # default paths this is "ssl/"; a configured ssl_cert_path/ssl_key_path
    # may point elsewhere, and the __main__ block likewise derives the
    # directory from cert_path.
    Path(cert_path).parent.mkdir(parents=True, exist_ok=True)
    Path(key_path).parent.mkdir(parents=True, exist_ok=True)
    if not Path(cert_path).exists() or not Path(key_path).exists():
        ssl_heidi.generate_self_signed_cert(cert_path, key_path)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup tasks for the application and close the DB session on exit.

    A session is opened from ``SessionLocal`` and used for seeding:

    * ``prod``: if the database has no tables, create everything declared on
      ``Base.metadata``; then call ``load_slots_data``.
    * ``dev``: call ``load_sample_data``.
    * ``test``: call ``load_slots_data``.

    Control is then yielded to the running application. The session is closed
    when the context exits, including when a startup task raises.
    """
    print("[INFO] Running application startup tasks...")
    db = SessionLocal()
    try:
        if environment == "prod":
            # sqlalchemy.inspect() is the supported way to obtain an
            # Inspector; Inspector.from_engine() is deprecated.
            from sqlalchemy import inspect

            tables_exist = inspect(engine).get_table_names()

            # Ensure the production database is initialized
            if not tables_exist:
                print("Production database is empty. Initializing...")
                Base.metadata.create_all(bind=engine)

            # Seed the database (slots + proposals)
            from app.database import load_slots_data

            load_slots_data(db)
        else:  # dev or test environments
            # NOTE(review): despite the message, no tables are dropped or
            # created in this branch; only the seed loaders below run.
            print(f"{environment.capitalize()} environment: Regenerating database.")
            if environment == "dev":
                from app.database import load_sample_data

                load_sample_data(db)
            elif environment == "test":
                from app.database import load_slots_data

                load_slots_data(db)

        yield
    finally:
        db.close()


# Application object; title and version come from pyproject.toml
app = FastAPI(
    title=project_name,
    version=project_version,
    description="Backend for next-gen sample management system",
    servers=[
        {"url": "https://mx-aare-test.psi.ch:1492", "description": "Default server"}
    ],
    lifespan=lifespan,
)

# CORS: every origin, method and header is accepted, with credentials enabled.
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive policy - confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Routers are registered in the order listed; the spreadsheet router is
# included without a prefix.
for _router, _options in (
    (protected_router, {"prefix": "/protected"}),
    (auth.router, {"prefix": "/auth", "tags": ["auth"]}),
    (proposal.router, {"prefix": "/proposals", "tags": ["proposals"]}),
    (puck.router, {"prefix": "/pucks", "tags": ["pucks"]}),
    (spreadsheet.router, {"tags": ["spreadsheet"]}),
    (logistics.router, {"prefix": "/logistics", "tags": ["logistics"]}),
    (sample.router, {"prefix": "/samples", "tags": ["samples"]}),
):
    app.include_router(_router, **_options)

# Serve the contents of the local "images" directory under /images
app.mount("/images", StaticFiles(directory="images"), name="images")


if __name__ == "__main__":
    # Script entry point. Two modes:
    #   * "generate-openapi" as the first argument writes openapi.json and exits
    #   * otherwise the server is started; in CI / the "test" environment it is
    #     only run for a few seconds as a start-up check
    import sys
    from dotenv import load_dotenv
    from multiprocessing import Process
    from time import sleep

    # Load environment variables from .env file
    # NOTE(review): the module-level configuration above (config file, SSL
    # paths) was already resolved from the process environment before this
    # call, so an ENVIRONMENT value supplied only via .env affects the mode
    # selection below but not which config file was loaded.
    load_dotenv()

    # Check if `generate-openapi` option is passed
    if len(sys.argv) > 1 and sys.argv[1] == "generate-openapi":
        from fastapi.openapi.utils import get_openapi

        # Generate and save OpenAPI JSON file
        openapi_schema = get_openapi(
            title=app.title,
            version=app.version,
            description=app.description,
            routes=app.routes,
        )
        with open("openapi.json", "w") as f:
            json.dump(openapi_schema, f, indent=4)
        print("openapi.json generated successfully.")
        sys.exit(0)  # Exit after generating the file

    # Default behavior: Run the server based on the environment
    environment = os.getenv("ENVIRONMENT", "dev")
    is_ci = os.getenv("CI", "false").lower() == "true"

    if is_ci or environment == "test":
        # Test or CI Mode: Run server process temporarily for test validation
        ssl_dir = Path(cert_path).parent
        ssl_dir.mkdir(parents=True, exist_ok=True)
        # The key may be configured to live in a different directory
        Path(key_path).parent.mkdir(parents=True, exist_ok=True)

        # Generate self-signed certs if missing
        if not Path(cert_path).exists() or not Path(key_path).exists():
            print(f"[INFO] Generating self-signed SSL certificates at {ssl_dir}")
            ssl_heidi.generate_self_signed_cert(cert_path, key_path)

        # Start the server as a subprocess, wait, then terminate
        server_process = Process(target=run_server)
        server_process.start()
        sleep(5)  # Wait for 5 seconds to verify the server is running
        # Record whether the server survived the wait *before* terminating it;
        # a process that already exited means start-up failed.
        still_running = server_process.is_alive()
        server_process.terminate()  # Terminate the server process (for CI)
        server_process.join()  # Ensure proper cleanup
        if not still_running:
            print(
                "[ERROR] CI: Server process exited before validation finished "
                f"(exit code {server_process.exitcode})."
            )
            sys.exit(1)
        print("CI: Server started and terminated successfully for test validation.")
    else:
        # Dev or Prod: Start the server as usual
        run_server()