# setup_logging.py
import logging
import os
import sys
import traceback
import warnings

import IPython
from pythonjsonlogger import jsonlogger

from cal_tools.warnings import CalibrationWarning

# Context stamped onto records written to the JSON log files; JOB_ID also
# forms part of the log file names. Both come from the environment.
NOTEBOOK_NAME = os.environ.get('CAL_NOTEBOOK_NAME', 'Unknown notebook')
JOB_ID = os.environ.get('SLURM_JOB_ID', 'local')


class ContextFilter(logging.Filter):
    """Pass only records emitted by this module's exception/warning hooks.

    A record is accepted only when it carries a truthy
    ``from_exception_handler`` attribute. Accepted records are annotated
    with ``notebook`` (``NOTEBOOK_NAME``) and ``job_id`` (``JOB_ID``);
    every other record is rejected.
    """

    def filter(self, record):
        if not getattr(record, 'from_exception_handler', False):
            return False
        record.notebook, record.job_id = NOTEBOOK_NAME, JOB_ID
        return True


def get_class_hierarchy(cls):
    """Return the ``__base__`` chain of *cls* as a dotted string, root first.

    Follows each class's primary base (``__base__``) upwards and stops at
    ``object``, which is not included. For example ``KeyError`` gives
    ``'BaseException.Exception.LookupError.KeyError'``. A falsy *cls*
    (e.g. ``None``) yields ``''``.
    """
    chain = []
    klass = cls
    while klass and klass != object:
        chain.insert(0, klass.__name__)
        klass = klass.__base__
    return '.'.join(chain)


class CustomJsonFormatter(jsonlogger.JsonFormatter):
    """JSON formatter adding timestamp, level, location and error context.

    Extends ``jsonlogger.JsonFormatter.add_fields`` in three ways:
    every record gets ``timestamp``, ``level``, ``filename`` and ``lineno``;
    records carrying a ``log_class`` attribute (set via ``extra`` in our
    warning/error handlers) keep it under the same key; and records with
    exception info get the formatted traceback (``exc_info``) and the
    exception's class chain (``class``, built by ``get_class_hierarchy``).
    """

    def add_fields(self, log_record, record, message_dict):
        super().add_fields(log_record, record, message_dict)
        log_record.update(
            timestamp=self.formatTime(record, self.datefmt),
            level=record.levelname,
            filename=record.filename,
            lineno=record.lineno,
        )

        if hasattr(record, 'log_class'):
            log_record['log_class'] = record.log_class

        exc_info = record.exc_info
        if not exc_info:
            return
        log_record['exc_info'] = self.formatException(exc_info)
        log_record['class'] = get_class_hierarchy(exc_info[0])


# Root logger (getLogger() with no name). This module attaches its handlers
# here, so they receive records propagated from every logger.
logger = logging.getLogger()

# JSON formatter shared by the job-specific file handlers.
formatter = CustomJsonFormatter(
    '%(timestamp)s %(level)s %(filename)s %(lineno)d '
    '%(notebook)s %(job_id)s %(log_class)s %(message)s'
)


# Factory for the per-job JSON log file handlers.
def create_job_specific_handler(log_level, file_suffix):
    """Return a lazily opened ``FileHandler`` writing JSON to a per-job file.

    The file is named ``<file_suffix>_<JOB_ID>.log`` (so ``file_suffix`` is
    really the leading part of the name), relative to the working directory
    at creation time. ``delay=True`` defers opening the file until the first
    record is emitted, so no empty file is created if nothing is logged.

    :param log_level: minimum level the handler accepts.
    :param file_suffix: leading part of the log file name, e.g. ``'errors'``.
    :return: the configured handler, using the shared JSON ``formatter``.
    """
    handler = logging.FileHandler(f'{file_suffix}_{JOB_ID}.log', delay=True)
    handler.setFormatter(formatter)
    handler.setLevel(log_level)
    return handler


# Create job-specific file handlers: ERROR and above go to
# errors_<JOB_ID>.log, WARNING and above to warnings_<JOB_ID>.log.
error_handler = create_job_specific_handler(logging.ERROR, 'errors')
warning_handler = create_job_specific_handler(logging.WARNING, 'warnings')

# Console handler for the notebook output and slurm.out.
# NOTE(review): StreamHandler() with no argument writes to sys.stderr, not
# stdout; pass sys.stdout explicitly if stdout is really intended.
# NOTE(review): the root logger's level is not set in this module, so the
# INFO level here only matters if the logger level is lowered elsewhere.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_formatter = logging.Formatter('%(levelname)s: %(message)s')
console_handler.setFormatter(console_formatter)

# The JSON files only accept records flagged by this module's exception and
# warning hooks (see ContextFilter). The warnings file also rejects ERROR and
# above, so errors are not duplicated there.
context_filter = ContextFilter()
error_handler.addFilter(context_filter)
warning_handler.addFilter(context_filter)
warning_handler.addFilter(lambda record: record.levelno < logging.ERROR)

# Attach all handlers to the root logger. The console handler must be
# attached too: once the root logger has any handler, logging no longer uses
# its last-resort stderr output. The replaced excepthook, showwarning and
# IPython showtraceback send errors and warnings only through logging, so
# without this handler they would never reach the notebook or slurm.out.
# logging.Formatter appends the traceback for records carrying exc_info.
logger.addHandler(error_handler)
logger.addHandler(warning_handler)
logger.addHandler(console_handler)

# Re-entrancy guard for safe_handle_error. It is True while an error is
# being logged, so a failure inside the logging machinery cannot recurse.
handling_error = False


def safe_handle_error(exc_type, exc_value, exc_traceback):
    """Log an uncaught exception to the JSON error log.

    Installed as ``sys.excepthook`` and also called by the IPython
    ``showtraceback`` override. ``SystemExit`` and ``KeyboardInterrupt`` are
    handed to the interpreter's default hook unchanged. If this function is
    re-entered while already handling an error, or if logging raises, the
    traceback is printed to ``sys.stderr`` instead.

    :param exc_type: the exception class.
    :param exc_value: the exception instance.
    :param exc_traceback: the traceback object (may be ``None``).
    """
    global handling_error

    # Let sys.exit() and Ctrl-C behave as usual.
    if exc_type in (SystemExit, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    if handling_error:  # Avoid an infinite loop of errors.
        sys.stderr.write("Recursive error detected!\n")
        traceback.print_exception(
            exc_type, exc_value, exc_traceback, file=sys.stderr)
        return

    handling_error = True
    try:
        # from_exception_handler lets the record through ContextFilter, which
        # also adds the notebook name and job ID.
        logger.error(
            str(exc_value),
            extra={
                'log_class': get_class_hierarchy(exc_type),
                'from_exception_handler': True,
            },
            exc_info=(exc_type, exc_value, exc_traceback),
        )
    except Exception as log_error:
        sys.stderr.write(f"Logging failed: {log_error}\n")
        traceback.print_exception(
            exc_type, exc_value, exc_traceback, file=sys.stderr)
    finally:
        handling_error = False


def handle_warning(message, category, filename, lineno, file=None, line=None):
    """Log a displayed warning to the JSON warning log.

    Drop-in replacement for ``warnings.showwarning`` with the same
    signature. The record is flagged with ``from_exception_handler`` so
    ContextFilter lets it through to the job-specific log files. If logging
    raises, the error and the formatted warning are written to
    ``sys.stderr`` so the warning is not silently dropped.

    :param message: the warning instance (or message text).
    :param category: the warning class.
    :param filename: file in which the warning was issued.
    :param lineno: line number at which the warning was issued.
    :param file: unused; accepted for ``showwarning`` compatibility.
    :param line: source line, used only by the stderr fallback.
    """
    try:
        # NOTE(review): the record's own filename/lineno fields describe this
        # function's logger.warning call, not the warning's origin
        # (filename/lineno parameters).
        logger.warning(
            str(message),
            extra={
                'notebook': NOTEBOOK_NAME,
                'job_id': JOB_ID,
                'log_class': get_class_hierarchy(category),
                'from_exception_handler': True,
            },
        )
    except Exception as log_error:
        sys.stderr.write(f"Logging failed: {log_error}\n")
        sys.stderr.write(
            warnings.formatwarning(message, category, filename, lineno, line))


# Route displayed warnings and uncaught exceptions through this module's
# logging-based handlers instead of the default stderr printers.
warnings.showwarning = handle_warning
sys.excepthook = safe_handle_error

# Warning filters. Each call inserts its rule at the front of the filter
# list, so the CalibrationWarning rule (added second) is matched before the
# catch-all "ignore". Net effect: only CalibrationWarning and its subclasses
# are shown (first occurrence per location, the "default" action); all other
# warnings are dropped.
warnings.simplefilter("ignore")
warnings.simplefilter("default", CalibrationWarning)


# Override IPython's exception handling so notebook cell errors also go
# through safe_handle_error.
def custom_showtraceback(self, exc_tuple=None, *args, **kwargs):
    """Replacement for ``InteractiveShell.showtraceback``.

    Uses ``exc_tuple`` when IPython passes one explicitly and falls back to
    ``sys.exc_info()`` otherwise. This matters because the active exception
    is not necessarily the one IPython asked to show. Any remaining IPython
    arguments (``filename``, ``tb_offset``, ``exception_only``, ...) are
    accepted and ignored.
    """
    if exc_tuple is None:
        exc_tuple = sys.exc_info()
    if exc_tuple[0] is None:
        # No exception to report; do not log an empty "None" error record.
        sys.stderr.write("No traceback available to show.\n")
        return None
    return safe_handle_error(*exc_tuple)


IPython.core.interactiveshell.InteractiveShell.showtraceback = custom_showtraceback  # noqa