"""
Gunicorn config extension and hooks. This config file adds some extra settings and memory management.
Gunicorn configuration should be managed by .ini files entries of RhodeCode or VCSServer
"""

import gc
import os
import sys
import math
import time
import threading
import traceback
import random
import socket
import dataclasses
import json
from gunicorn.glogging import Logger


def get_workers():
    import multiprocessing
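    # e.g. on a 4-core machine this returns 4 * 2 + 1 = 9 workers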
    return multiprocessing.cpu_count() * 2 + 1


bind = "127.0.0.1:10010"


# Error logging output for gunicorn (-) is stdout
errorlog = '-'

# Access logging output for gunicorn (-) is stdout
accesslog = '-'


# SERVER MECHANICS
# None == system temp dir
# worker_tmp_dir is recommended to be set to some tmpfs
worker_tmp_dir = None
tmp_upload_dir = None

# Use port re-use (SO_REUSEPORT) to let the Linux kernel load-balance requests across workers.
reuse_port = True

# Custom log format
#access_log_format = (
#    '%(t)s %(p)s INFO  [GNCRN] %(h)-15s rqt:%(L)s %(s)s %(b)-6s "%(m)s:%(U)s %(q)s" usr:%(u)s "%(f)s" "%(a)s"')

# Loki format, for easier parsing in Grafana
loki_access_log_format = (
    'time="%(t)s" pid=%(p)s level="INFO" type="[GNCRN]" ip="%(h)-15s" rqt="%(L)s" response_code="%(s)s" response_bytes="%(b)-6s" uri="%(m)s:%(U)s %(q)s" user=":%(u)s" user_agent="%(a)s"')

# JSON format
json_access_log_format = json.dumps({
    'time': r'%(t)s',
    'pid': r'%(p)s',
    'level': 'INFO',
    'ip': r'%(h)s',
    'request_time': r'%(L)s',
    'remote_address': r'%(h)s',
    'user_name': r'%(u)s',
    'status': r'%(s)s',
    'method': r'%(m)s',
    'url_path': r'%(U)s',
    'query_string': r'%(q)s',
    'protocol': r'%(H)s',
    'response_length': r'%(B)s',
    'referer': r'%(f)s',
    'user_agent': r'%(a)s',
})

access_log_format = loki_access_log_format
if os.environ.get('RC_LOGGING_FORMATTER') == 'json':
    access_log_format = json_access_log_format
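
# For example, exporting RC_LOGGING_FORMATTER=json before starting gunicorn
# (e.g. `RC_LOGGING_FORMATTER=json gunicorn ...`) selects the JSON access-log
# format defined above; any other value keeps the loki format.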

# Self-adjust the number of workers based on CPU count, to use the available CPUs without over-committing resources.
# workers = get_workers()

# Gunicorn log level (applies to the error log)
loglevel = 'info'

# Process name visible in a process list
proc_name = "rhodecode_vcsserver"

# Type of worker class, one of `sync`, `gevent` or `gthread`.
# Currently `sync` is the only option allowed for vcsserver; for rhodecode all three are allowed.
# gevent:
# In this case, the maximum number of concurrent requests is (N workers * X worker_connections)
# e.g. workers = 3, worker_connections = 10 => 3 * 10 = 30 concurrent requests can be handled
# gthread:
# In this case, the maximum number of concurrent requests is (N workers * X threads)
# e.g. workers = 3, threads = 3 => 3 * 3 = 9 concurrent requests can be handled
worker_class = 'sync'

# Sets the number of process workers. More workers means more concurrent connections
# RhodeCode can handle at the same time. Each additional worker also increases
# memory usage, as each has its own set of caches.
# The recommended value is (2 * NUMBER_OF_CPUS + 1), e.g. 2 CPUs = 5 workers, but no more
# than 8-10 unless for huge deployments, e.g. 700-1000 users.
# `instance_id = *` must be set in the [app:main] section below (which is the default)
# when using more than 1 worker.
workers = 2

# Number of threads for the gthread worker class
threads = 1

# The maximum number of simultaneous clients. Valid only for gevent.
# In this case, the maximum number of concurrent requests is (N workers * X worker_connections)
# e.g. workers = 3, worker_connections = 10 => 3 * 10 = 30 concurrent requests can be handled
worker_connections = 10

# Max number of requests that a worker will handle before being gracefully restarted.
# Prevents memory leaks; the jitter adds variability so that not all workers are restarted at once.
max_requests = 2000
max_requests_jitter = int(max_requests * 0.2)  # 20% of max_requests
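
# Note: gunicorn adds a random jitter in [0, max_requests_jitter] to each
# worker's limit, so with the values above a worker restarts after serving
# between 2000 and 2400 requests.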

# The maximum number of pending connections.
# Exceeding this number results in the client getting an error when attempting to connect.
backlog = 64

# The amount of time a worker can spend handling a request before it
# gets killed and restarted. By default, set to 21600 (6hrs)
# Examples: 1800 (30min), 3600 (1hr), 7200 (2hr), 43200 (12h)
timeout = 21600

# The maximum size of HTTP request line in bytes.
# 0 for unlimited
limit_request_line = 0

# Limit the number of HTTP header fields in a request.
# By default this value is 100 and can't be larger than 32768.
limit_request_fields = 32768

# Limit the allowed size of an HTTP request header field.
# Value is a positive number or 0.
# Setting it to 0 will allow unlimited header field sizes.
limit_request_field_size = 0

# Timeout for graceful workers restart.
# After receiving a restart signal, workers have this much time to finish
# serving requests. Workers still alive after the timeout (starting from the
# receipt of the restart signal) are force killed.
# Examples: 1800 (30min), 3600 (1hr), 7200 (2hr), 43200 (12h)
graceful_timeout = 21600

# The number of seconds to wait for requests on a Keep-Alive connection.
# Generally set in the 1-5 seconds range.
keepalive = 2

# Maximum memory usage that each worker can use before it will receive a
# graceful restart signal, 0 = memory monitoring is disabled.
# Examples: 268435456 (256MB), 536870912 (512MB)
# 1073741824 (1GB), 2147483648 (2GB), 4294967296 (4GB)
# Dynamic formula: 1024 * 1024 * 256 == 256MB
memory_max_usage = 0

# How often (in seconds) to check memory usage of each gunicorn worker
memory_usage_check_interval = 60

# Threshold value below which we don't recycle the worker if garbage collection
# frees up enough memory. Before each restart we try to run GC on the worker;
# if memory usage drops below this fraction of memory_max_usage afterwards, the
# restart is skipped.
memory_usage_recovery_threshold = 0.8
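
# Worked example: with memory_max_usage = 1073741824 (1GB) and the 0.8
# threshold above, a worker exceeding 1GB first gets a forced gc.collect();
# it is only restarted if its RSS stays above 1073741824 * 0.8 ~= 858993459 bytes (~819MB).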


@dataclasses.dataclass
class MemoryCheckConfig:
    max_usage: int
    check_interval: int
    recovery_threshold: float


def _get_process_rss(pid=None):
    """Return the RSS memory usage (in bytes) of the given or current process, or None if psutil is unavailable."""
    try:
        import psutil
        if pid:
            proc = psutil.Process(pid)
        else:
            proc = psutil.Process()
        return proc.memory_info().rss
    except Exception:
        return None


def _get_config(ini_path):
    """Parse the .ini file at ini_path, returning the parser or None on failure."""
    import configparser

    try:
        config = configparser.RawConfigParser()
        config.read(ini_path)
        return config
    except Exception:
        return None


def get_memory_usage_params(config=None):
    # memory spec defaults
    _memory_max_usage = memory_max_usage
    _memory_usage_check_interval = memory_usage_check_interval
    _memory_usage_recovery_threshold = memory_usage_recovery_threshold

    if config:
        ini_path = os.path.abspath(config)
        conf = _get_config(ini_path)

        section = 'server:main'
        if conf and conf.has_section(section):

            if conf.has_option(section, 'memory_max_usage'):
                _memory_max_usage = conf.getint(section, 'memory_max_usage')

            if conf.has_option(section, 'memory_usage_check_interval'):
                _memory_usage_check_interval = conf.getint(section, 'memory_usage_check_interval')

            if conf.has_option(section, 'memory_usage_recovery_threshold'):
                _memory_usage_recovery_threshold = conf.getfloat(section, 'memory_usage_recovery_threshold')

    _memory_max_usage = int(os.environ.get('RC_GUNICORN_MEMORY_MAX_USAGE', '')
                            or _memory_max_usage)
    _memory_usage_check_interval = int(os.environ.get('RC_GUNICORN_MEMORY_USAGE_CHECK_INTERVAL', '')
                                       or _memory_usage_check_interval)
    _memory_usage_recovery_threshold = float(os.environ.get('RC_GUNICORN_MEMORY_USAGE_RECOVERY_THRESHOLD', '')
                                             or _memory_usage_recovery_threshold)

    return MemoryCheckConfig(_memory_max_usage, _memory_usage_check_interval, _memory_usage_recovery_threshold)
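
# Note: the RC_GUNICORN_* environment variables take precedence over both the
# module defaults and the .ini values, e.g. (assumed invocation):
#   RC_GUNICORN_MEMORY_MAX_USAGE=1073741824 gunicorn -c <this config file> ...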


def _time_with_offset(check_interval):
    # Return "now" shifted back by a random offset of up to half the check
    # interval, so workers don't all hit their memory check at the same time.
    # random.randint requires int arguments, hence the explicit conversion.
    return time.time() - random.randint(0, int(check_interval / 2))


def pre_fork(server, worker):
    pass


def post_fork(server, worker):

    memory_conf = get_memory_usage_params()
    _memory_max_usage = memory_conf.max_usage
    _memory_usage_check_interval = memory_conf.check_interval
    _memory_usage_recovery_threshold = memory_conf.recovery_threshold

    # get_memory_usage_params() already applies the RC_GUNICORN_* environment
    # overrides, so the resolved values can be attached to the worker directly.
    worker._memory_max_usage = _memory_max_usage
    worker._memory_usage_check_interval = _memory_usage_check_interval
    worker._memory_usage_recovery_threshold = _memory_usage_recovery_threshold

    # register the last memory check time with some random offset, so we don't
    # recycle all workers at once
    worker._last_memory_check_time = _time_with_offset(_memory_usage_check_interval)

    if _memory_max_usage:
        server.log.info("pid=[%-10s] WORKER spawned with max memory set at %s", worker.pid,
                        _format_data_size(_memory_max_usage))
    else:
        server.log.info("pid=[%-10s] WORKER spawned", worker.pid)


def pre_exec(server):
    server.log.info("Forked child, re-executing.")


def on_starting(server):
    server_lbl = '{} {}'.format(server.proc_name, server.address)
    server.log.info("Server %s is starting.", server_lbl)
    server.log.info('Config:')
    server.log.info(f"\n{server.cfg}")
    server.log.info(get_memory_usage_params())


def when_ready(server):
    server.log.info("Server %s is ready. Spawning workers", server)


def on_reload(server):
    pass


def _format_data_size(size, unit="B", precision=1, binary=True):
    """Format a number using decimal (kilo, mega, etc.) or binary (kibi, mebi,
    etc.) prefixes.

    ``size``: The number as a float or int.

    ``unit``: The unit name in plural form. Examples: "bytes", "B".

    ``precision``: How many digits to the right of the decimal point. Default
    is 1. 0 suppresses the decimal point.

    ``binary``: If false, use base-10 decimal prefixes (kilo = k = 1000).
    If true (default), use base-2 binary prefixes (kibi = Ki = 1024).
    """

    if not binary:
        base = 1000
        multiples = ('', 'k', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y')
    else:
        base = 1024
        multiples = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi', 'Yi')

    sign = ""
    if size > 0:
        m = int(math.log(size, base))
    elif size < 0:
        sign = "-"
        size = -size
        m = int(math.log(size, base))
    else:
        m = 0
    if m > 8:
        m = 8

    if m == 0:
        precision = '%.0f'
    else:
        precision = '%%.%df' % precision

    size = precision % (size / math.pow(base, m))

    return '%s%s %s%s' % (sign, size.strip(), multiples[m], unit)


def _check_memory_usage(worker):
    _memory_max_usage = worker._memory_max_usage
    if not _memory_max_usage:
        return

    _memory_usage_check_interval = worker._memory_usage_check_interval
    # Use the worker-level limit read above, not the module-level default.
    _memory_usage_recovery_threshold = _memory_max_usage * worker._memory_usage_recovery_threshold

    elapsed = time.time() - worker._last_memory_check_time
    if elapsed > _memory_usage_check_interval:
        mem_usage = _get_process_rss()
        if mem_usage and mem_usage > _memory_max_usage:
            worker.log.info(
                "memory usage %s > %s, forcing gc",
                _format_data_size(mem_usage), _format_data_size(_memory_max_usage))
            # Try to clean it up by forcing a full collection.
            gc.collect()
            mem_usage = _get_process_rss()
            if mem_usage and mem_usage > _memory_usage_recovery_threshold:
                # Didn't clean up enough, we'll have to terminate.
                worker.log.warning(
                    "memory usage %s > %s after gc, quitting",
                    _format_data_size(mem_usage), _format_data_size(_memory_max_usage))
                # This will cause worker to auto-restart itself
                worker.alive = False
        worker._last_memory_check_time = time.time()


def worker_int(worker):
    worker.log.info("pid=[%-10s] worker received INT or QUIT signal", worker.pid)

    # get traceback info when a worker crashes
    def get_thread_id(t_id):
        id2name = {th.ident: th.name for th in threading.enumerate()}
        return id2name.get(t_id, "unknown_thread_id")

    code = []
    for thread_id, stack in sys._current_frames().items():  # noqa
        code.append(
            "\n# Thread: %s(%d)" % (get_thread_id(thread_id), thread_id))
        for fname, lineno, name, line in traceback.extract_stack(stack):
            code.append('File: "%s", line %d, in %s' % (fname, lineno, name))
            if line:
                code.append("  %s" % (line.strip()))
    worker.log.debug("\n".join(code))


def worker_abort(worker):
    worker.log.info("pid=[%-10s] worker received SIGABRT signal", worker.pid)


def worker_exit(server, worker):
    worker.log.info("pid=[%-10s] worker exit", worker.pid)


def child_exit(server, worker):
    worker.log.info("pid=[%-10s] worker child exit", worker.pid)


def pre_request(worker, req):
    worker.start_time = time.time()
    worker.log.debug(
        "GNCRN PRE  WORKER [cnt:%s]: %s %s", worker.nr, req.method, req.path)


def post_request(worker, req, environ, resp):
    total_time = time.time() - worker.start_time
    # Gunicorn sometimes has problems with reading the status_code
    status_code = getattr(resp, 'status_code', '')
    worker.log.debug(
        "GNCRN POST WORKER [cnt:%s]: %s %s resp: %s, Load Time: %.4fs",
        worker.nr, req.method, req.path, status_code, total_time)
    _check_memory_usage(worker)


def _filter_proxy(ip):
    """
    IP addresses passed in HEADERS can come in a special format of multiple
    IPs. Such comma-separated IPs are appended by the various proxies in the
    chain of request processing, with the left-most entry being the original
    client. We only care about that first IP.

    :param ip: ip string from headers
    """
    if ',' in ip:
        _ips = ip.split(',')
        _first_ip = _ips[0].strip()
        return _first_ip
    return ip
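
# For reference: _filter_proxy('203.0.113.5, 10.0.0.1, 10.0.0.2') returns '203.0.113.5'.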


def _filter_port(ip):
    """
    Removes a port from ip, there are 4 main cases to handle here.
    - ipv4 eg. 127.0.0.1
    - ipv6 eg. ::1
    - ipv4+port eg. 127.0.0.1:8080
    - ipv6+port eg. [::1]:8080

    :param ip:
    """
    def is_ipv6(ip_addr):
        if hasattr(socket, 'inet_pton'):
            try:
                socket.inet_pton(socket.AF_INET6, ip_addr)
            except socket.error:
                return False
        else:
            return False
        return True

    if ':' not in ip:  # must be a plain ipv4 address
        return ip

    if '[' in ip and ']' in ip:  # ipv6 with port
        return ip.split(']')[0][1:].lower()

    # must be ipv6 or ipv4 with port
    if is_ipv6(ip):
        return ip
    else:
        ip, _port = ip.split(':')[:2]  # means ipv4+port
        return ip
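
# For reference: _filter_port('127.0.0.1:8080') returns '127.0.0.1',
# _filter_port('[::1]:8080') returns '::1', and a bare '::1' is returned unchanged.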


def get_ip_addr(environ):
    proxy_key = 'HTTP_X_REAL_IP'
    proxy_key2 = 'HTTP_X_FORWARDED_FOR'
    def_key = 'REMOTE_ADDR'

    def _filters(x):
        return _filter_port(_filter_proxy(x))

    ip = environ.get(proxy_key)
    if ip:
        return _filters(ip)

    ip = environ.get(proxy_key2)
    if ip:
        return _filters(ip)

    ip = environ.get(def_key, '0.0.0.0')
    return _filters(ip)
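
# For reference: get_ip_addr({'HTTP_X_FORWARDED_FOR': '203.0.113.5, 10.0.0.1'})
# returns '203.0.113.5'; with no proxy headers set it falls back to REMOTE_ADDR.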


class RhodeCodeLogger(Logger):
    """
    Custom Logger that supports customizations which stock gunicorn does not provide
    """

    datefmt = r"%Y-%m-%d %H:%M:%S"

    def __init__(self, cfg):
        Logger.__init__(self, cfg)

    def now(self):
        """ return date in RhodeCode Log format """
        now = time.time()
        msecs = int((now - int(now)) * 1000)
        return time.strftime(self.datefmt, time.localtime(now)) + '.{0:03d}'.format(msecs)

    def atoms(self, resp, req, environ, request_time):
        """ Gets atoms for log formatting.
        """
        status = resp.status
        if isinstance(status, str):
            status = status.split(None, 1)[0]
        atoms = {
            'h': get_ip_addr(environ),
            'l': '-',
            'u': self._get_user(environ) or '-',
            't': self.now(),
            'r': "%s %s %s" % (environ['REQUEST_METHOD'],
                               environ['RAW_URI'],
                               environ["SERVER_PROTOCOL"]),
            's': status,
            'm': environ.get('REQUEST_METHOD'),
            'U': environ.get('PATH_INFO'),
            'q': environ.get('QUERY_STRING'),
            'H': environ.get('SERVER_PROTOCOL'),
            'b': getattr(resp, 'sent', None) is not None and str(resp.sent) or '-',
            'B': getattr(resp, 'sent', None),
            'f': environ.get('HTTP_REFERER', '-'),
            'a': environ.get('HTTP_USER_AGENT', '-'),
            'T': request_time.seconds,
            'D': (request_time.seconds * 1000000) + request_time.microseconds,
            'M': (request_time.seconds * 1000) + int(request_time.microseconds/1000),
            'L': "%d.%06d" % (request_time.seconds, request_time.microseconds),
            'p': "<%s>" % os.getpid()
        }

        # add request headers
        if hasattr(req, 'headers'):
            req_headers = req.headers
        else:
            req_headers = req

        if hasattr(req_headers, "items"):
            req_headers = req_headers.items()

        atoms.update({"{%s}i" % k.lower(): v for k, v in req_headers})

        resp_headers = resp.headers
        if hasattr(resp_headers, "items"):
            resp_headers = resp_headers.items()

        # add response headers
        atoms.update({"{%s}o" % k.lower(): v for k, v in resp_headers})

        # add environ variables
        environ_variables = environ.items()
        atoms.update({"{%s}e" % k.lower(): v for k, v in environ_variables})

        return atoms


logger_class = RhodeCodeLogger
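
# Usage sketch (assumed file names): point gunicorn at this file via -c/--config,
# e.g. `gunicorn --paste rhodecode.ini -c gunicorn_config.py`; the .ini file then
# drives the RhodeCode/VCSServer settings, as noted in the module docstring.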
