# Copyright Spack Project Developers. See COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)

"""All the logic for OCI fetching and authentication"""

import base64
import json
import re
import socket
import time
import urllib.error
import urllib.parse
import urllib.request
from enum import Enum, auto
from http.client import HTTPResponse
from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Tuple
from urllib.request import Request

import spack.config
import spack.llnl.util.lang
import spack.mirrors.mirror
import spack.tokenize
import spack.util.web

from .image import ImageReference


def _urlopen():
    """Build an OCI-aware opener and return a urlopen-style function bound to it.

    The returned function accepts ``(fullurl, data=None, timeout=None)``. A falsy ``timeout``
    is replaced by the ``config:connect_timeout`` setting (default 10).
    """
    oci_opener = create_opener()

    def dispatch_open(fullurl, data=None, timeout=None):
        if not timeout:
            # No explicit timeout given: fall back to the configured connect timeout
            timeout = spack.config.get("config:connect_timeout", 10)
        return oci_opener.open(fullurl, data, timeout)

    return dispatch_open


#: Type of a urlopen-like callable: arbitrary arguments, returns an HTTPResponse
OpenType = Callable[..., HTTPResponse]
#: Either a urlopen-like callable or None
MaybeOpen = Optional[OpenType]

#: Opener that automatically uses OCI authentication based on mirror config
# NOTE(review): wrapped in a Singleton around the ``_urlopen`` factory, presumably so that the
# opener is only constructed on first use -- confirm against spack.llnl.util.lang.Singleton.
urlopen: OpenType = spack.llnl.util.lang.Singleton(_urlopen)


# Regular expression fragments for tokenizing a WWW-Authenticate header value. The names follow
# the ABNF rule names of the HTTP grammar (token, quoted-string, quoted-pair, OWS, ...).
# Note that VCHAR and obs_text are bare character *ranges*, meant to be embedded in a ``[...]``
# class; SP and HTAB are single characters; the remaining fragments are complete patterns.
SP = r" "
OWS = r"[ \t]*"
BWS = OWS
HTAB = r"\t"
VCHAR = r"\x21-\x7E"
tchar = r"[!#$%&'*+\-.^_`|~0-9A-Za-z]"
token = rf"{tchar}+"
obs_text = r"\x80-\xFF"
# Any character allowed verbatim inside a quoted string (excludes '"' and '\')
qdtext = rf"[{HTAB}{SP}\x21\x23-\x5B\x5D-\x7E{obs_text}]"
# A backslash-escaped character; group 1 captures the escaped character itself
quoted_pair = rf"\\([{HTAB}{SP}{VCHAR}{obs_text}])"
# Consume exactly one qdtext character per loop iteration. Repeating ``qdtext*`` inside the
# outer ``*`` lets the regex engine split a run of characters in exponentially many ways, so an
# unterminated quoted string in a server-supplied header could hang the match. The matched
# language and the number of capture groups are the same as with the nested form.
quoted_string = rf'"(?:({qdtext})|{quoted_pair})*"'


class WwwAuthenticateTokens(spack.tokenize.TokenBase):
    """Token kinds for a ``WWW-Authenticate`` header value.

    Each member's value is the regular expression for that kind of token.
    """

    # NOTE(review): member order presumably determines match priority in the tokenizer (e.g.
    # AUTH_PARAM before TOKEN, ANY last as a catch-all) -- confirm against spack.tokenize.

    #: ``key = value`` where the value is either a token or a quoted string
    AUTH_PARAM = rf"({token}){BWS}={BWS}({token}|{quoted_string})"
    # TOKEN68 = r"([A-Za-z0-9\-._~+/]+=*)"  # todo... support this?
    #: A bare token, e.g. an auth scheme name
    TOKEN = rf"{tchar}+"
    #: An equals sign with optional surrounding whitespace
    EQUALS = rf"{BWS}={BWS}"
    #: A comma with optional surrounding whitespace
    COMMA = rf"{OWS},{OWS}"
    #: One or more space characters
    SPACE = r" +"
    #: End of input
    EOF = r"$"
    #: Any other single character
    ANY = r"."


#: Tokenizer built from the token kinds in WwwAuthenticateTokens
WWW_AUTHENTICATE_TOKENIZER = spack.tokenize.Tokenizer(WwwAuthenticateTokens)


class State(Enum):
    """Parser states used by parse_www_authenticate."""

    #: Expecting the auth-scheme token that starts a challenge
    CHALLENGE = auto()
    #: Just read a scheme: expecting end of input, a comma, or a space introducing parameters
    AUTH_PARAM_LIST_START = auto()
    #: Expecting an auth-param
    AUTH_PARAM = auto()
    #: Just read an auth-param: expecting end of input or a comma
    NEXT_IN_LIST = auto()
    #: After a comma: expecting another auth-param, or the scheme of a new challenge
    AUTH_PARAM_OR_SCHEME = auto()


class Challenge:
    """A single authentication challenge: a scheme name plus a list of (key, value) params."""

    __slots__ = ["scheme", "params"]

    def __init__(
        self, scheme: Optional[str] = None, params: Optional[List[Tuple[str, str]]] = None
    ) -> None:
        # Falsy arguments are normalized to an empty string / a fresh empty list
        self.scheme = scheme if scheme else ""
        self.params = params if params else []

    def __repr__(self) -> str:
        return f"Challenge({self.scheme}, {self.params})"

    def __eq__(self, other: object) -> bool:
        # Anything that is not a Challenge compares unequal
        if not isinstance(other, Challenge):
            return False
        return self.scheme == other.scheme and self.params == other.params


def parse_www_authenticate(input: str):
    """Very basic parsing of a ``WWW-Authenticate`` header value (RFC7235 section 4.1).

    Notice: this omits token68 support.

    Args:
        input: the header value to parse.

    Returns:
        The list of :class:`Challenge` objects in order of appearance. Parameter values given
        as quoted strings are returned without quotes and with escapes resolved.

    Raises:
        ValueError: if the sequence of tokens does not follow the grammar below.
    """

    # auth-scheme      = token
    # auth-param       = token BWS "=" BWS ( token / quoted-string )
    # challenge        = auth-scheme [ 1*SP ( token68 / #auth-param ) ]
    # WWW-Authenticate = 1#challenge

    challenges: List[Challenge] = []

    unescape = re.compile(quoted_pair).sub

    def unquote(quoted: str) -> str:
        # Drop the surrounding double quotes, then replace every quoted-pair by the character
        # it escapes.
        return unescape(r"\1", quoted[1:-1])

    def extract_auth_param(text: str) -> Tuple[str, str]:
        # Split ``key = value`` on the first "=" and strip the optional whitespace around it.
        key, value = text.split("=", 1)
        key = key.rstrip()
        value = value.lstrip()
        if value.startswith('"'):
            value = unquote(value)
        return key, value

    mode: State = State.CHALLENGE
    tokens = WWW_AUTHENTICATE_TOKENIZER.tokenize(input)

    current_challenge = Challenge()

    while True:
        # A token stream that runs dry before an EOF token is treated as malformed input, so
        # callers only ever have to handle ValueError (a bare next() would leak StopIteration).
        tok = next(tokens, None)
        if tok is None:
            raise ValueError("unexpected end of WWW-Authenticate header")
        kind = tok.kind

        if mode == State.CHALLENGE:
            # A challenge must start with a scheme; anything else (including EOF) is an error.
            if kind == WwwAuthenticateTokens.TOKEN:
                current_challenge.scheme = tok.value
                mode = State.AUTH_PARAM_LIST_START
            else:
                raise ValueError(tok)

        elif mode == State.AUTH_PARAM_LIST_START:
            if kind == WwwAuthenticateTokens.EOF:
                challenges.append(current_challenge)
                break
            elif kind == WwwAuthenticateTokens.COMMA:
                # Challenge without param list, followed by another challenge.
                challenges.append(current_challenge)
                current_challenge = Challenge()
                mode = State.CHALLENGE
            elif kind == WwwAuthenticateTokens.SPACE:
                # A space means it must be followed by param list
                mode = State.AUTH_PARAM
            else:
                raise ValueError(tok)

        elif mode == State.AUTH_PARAM:
            if kind == WwwAuthenticateTokens.AUTH_PARAM:
                current_challenge.params.append(extract_auth_param(tok.value))
                mode = State.NEXT_IN_LIST
            else:
                raise ValueError(tok)

        elif mode == State.NEXT_IN_LIST:
            if kind == WwwAuthenticateTokens.EOF:
                challenges.append(current_challenge)
                break
            elif kind == WwwAuthenticateTokens.COMMA:
                mode = State.AUTH_PARAM_OR_SCHEME
            else:
                raise ValueError(tok)

        elif mode == State.AUTH_PARAM_OR_SCHEME:
            # After a comma, a bare token starts a new challenge, whereas key=value continues
            # the parameter list of the current one.
            if kind == WwwAuthenticateTokens.TOKEN:
                challenges.append(current_challenge)
                current_challenge = Challenge(tok.value)
                mode = State.AUTH_PARAM_LIST_START
            elif kind == WwwAuthenticateTokens.AUTH_PARAM:
                current_challenge.params.append(extract_auth_param(tok.value))
                mode = State.NEXT_IN_LIST
            elif kind == WwwAuthenticateTokens.COMMA:
                # Empty list elements (consecutive commas) are tolerated and skipped.
                continue
            else:
                # Reject everything else (including EOF) instead of silently skipping it,
                # consistent with the other states.
                raise ValueError(tok)

    return challenges


class RealmServiceScope(NamedTuple):
    """The ``realm``, ``service`` and ``scope`` parameter values of a Bearer challenge."""

    # The URL that the token request is sent to
    realm: str
    # Passed on as the ``service`` query parameter of the token request
    service: str
    # Passed on as the ``scope`` query parameter of the token request
    scope: str


class UsernamePassword(NamedTuple):
    """A username / password pair."""

    username: str
    password: str

    @property
    def basic_auth_header(self) -> str:
        """The ``Authorization`` header value for Basic auth with these credentials."""
        user_pass = f"{self.username}:{self.password}".encode("utf-8")
        encoded = base64.b64encode(user_pass).decode("utf-8")
        return f"Basic {encoded}"


def _get_bearer_challenge(challenges: List[Challenge]) -> Optional[RealmServiceScope]:
    """Return the realm/service/scope for a Bearer auth challenge, or None if not found.

    The scheme name and the parameter names are matched case-insensitively, as both are
    case-insensitive in HTTP authentication (RFC 7235 section 2.1). Only the first Bearer
    challenge is considered; None is also returned when it lacks any of the three parameters.
    """
    challenge = next((c for c in challenges if c.scheme.lower() == "bearer"), None)

    if challenge is None:
        return None

    # Get realm / service / scope from challenge; the first occurrence of a parameter wins
    values = {}
    for key, value in challenge.params:
        values.setdefault(key.lower(), value)

    realm = values.get("realm")
    service = values.get("service")
    scope = values.get("scope")

    if realm is None or service is None or scope is None:
        return None

    return RealmServiceScope(realm, service, scope)


def _get_basic_challenge(challenges: List[Challenge]) -> Optional[str]:
    """Return the realm for a Basic auth challenge, or None if not found.

    The scheme name and the ``realm`` parameter name are matched case-insensitively, as both
    are case-insensitive in HTTP authentication (RFC 7235 section 2.1). Only the first Basic
    challenge is considered.
    """
    challenge = next((c for c in challenges if c.scheme.lower() == "basic"), None)

    if challenge is None:
        return None

    return next((v for k, v in challenge.params if k.lower() == "realm"), None)


class OCIAuthHandler(urllib.request.BaseHandler):
    """urllib handler that logs in to a registry when it answers with 401 Unauthorized.

    On a 401 response the ``WWW-Authenticate`` header is parsed, a Bearer token exchange is
    tried first and Basic auth second, and the request is retried with the resulting
    ``Authorization`` header. That header is cached per registry netloc and added eagerly to
    subsequent HTTPS requests to the same registry.
    """

    def __init__(self, credentials_provider: Callable[[str], Optional[UsernamePassword]]):
        """
        Args:
            credentials_provider: A function that takes a domain and may return a UsernamePassword.
        """
        self.credentials_provider = credentials_provider

        # Cached authorization headers for a given domain.
        self.cached_auth_headers: Dict[str, str] = {}

    def https_request(self, req: Request):
        """Add a cached ``Authorization`` header to ``req`` if it has none yet; return ``req``."""
        # Eagerly add the bearer token to the request if no
        # auth header is set yet, to avoid 401s in multiple
        # requests to the same registry.

        # Use has_header, not .headers, since there are two
        # types of headers (redirected and unredirected)
        if req.has_header("Authorization"):
            return req

        parsed = urllib.parse.urlparse(req.full_url)
        auth_header = self.cached_auth_headers.get(parsed.netloc)

        if not auth_header:
            return req

        req.add_unredirected_header("Authorization", auth_header)
        return req

    def _try_bearer_challenge(
        self,
        challenges: List[Challenge],
        credentials: Optional[UsernamePassword],
        timeout: Optional[float],
    ) -> Optional[str]:
        """Request a token from the realm of a Bearer challenge.

        Args:
            challenges: the parsed challenges of the ``WWW-Authenticate`` header.
            credentials: optional credentials, sent as Basic auth with the token request.
            timeout: timeout for the token request.

        Returns:
            The ``Authorization`` header value ``Bearer <token>``, or None if there is no usable
            Bearer challenge.

        Raises:
            ValueError: if the realm is not an https URL, or if the token response is malformed.
        """
        # Check whether a Bearer challenge is present in the WWW-Authenticate header
        challenge = _get_bearer_challenge(challenges)
        if not challenge:
            return None

        # Get the token from the auth handler
        query = urllib.parse.urlencode(
            {"service": challenge.service, "scope": challenge.scope, "client_id": "spack"}
        )
        parsed = urllib.parse.urlparse(challenge.realm)._replace(
            query=query, fragment="", params=""
        )

        # Don't send credentials over insecure transport.
        if parsed.scheme != "https":
            raise ValueError(f"Cannot login over insecure {parsed.scheme} connection")

        request = Request(urllib.parse.urlunparse(parsed), method="GET")

        if credentials is not None:
            request.add_unredirected_header("Authorization", credentials.basic_auth_header)

        # Do a GET request.
        response = self.parent.open(request, timeout=timeout)
        try:
            response_json = json.load(response)
            token = response_json.get("token")
            if token is None:
                token = response_json.get("access_token")
            # Validate explicitly instead of with ``assert``: assertions are stripped when
            # Python runs with -O, which would let a missing token through as "Bearer None".
            if not isinstance(token, str):
                raise TypeError(f"expected a string token, got {type(token).__name__}")
        except Exception as e:
            raise ValueError(f"Malformed token response from {challenge.realm}") from e
        return f"Bearer {token}"

    def _try_basic_challenge(
        self, challenges: List[Challenge], credentials: UsernamePassword
    ) -> Optional[str]:
        """Return the Basic ``Authorization`` header value for ``credentials``, or None if the
        challenges do not include a Basic challenge with a realm."""
        # Check whether a Basic challenge is present in the WWW-Authenticate header
        # A realm is required for Basic auth, although we don't use it here. Leave this as a
        # validation step.
        realm = _get_basic_challenge(challenges)
        if not realm:
            return None
        return credentials.basic_auth_header

    def http_error_401(self, req: Request, fp, code, msg, headers):
        """Handle a 401 response: log in and retry the request once.

        Returns:
            The response of the retried request.

        Raises:
            spack.util.web.DetailedHTTPError: if a login was already attempted for this request,
                the ``WWW-Authenticate`` header is missing or malformed, obtaining the
                authorization header fails, or no supported challenge can be answered.
        """
        # Login failed, avoid infinite recursion where we go back and
        # forth between auth server and registry
        if hasattr(req, "login_attempted"):
            raise spack.util.web.DetailedHTTPError(
                req, code, f"Failed to login: {msg}", headers, fp
            )

        # On 401 Unauthorized, parse the WWW-Authenticate header
        # to determine what authentication is required
        if "WWW-Authenticate" not in headers:
            raise spack.util.web.DetailedHTTPError(
                req, code, "Cannot login to registry, missing WWW-Authenticate header", headers, fp
            )

        www_auth_str = headers["WWW-Authenticate"]

        try:
            challenges = parse_www_authenticate(www_auth_str)
        except ValueError as e:
            raise spack.util.web.DetailedHTTPError(
                req,
                code,
                f"Cannot login to registry, malformed WWW-Authenticate header: {www_auth_str}",
                headers,
                fp,
            ) from e

        registry = urllib.parse.urlparse(req.get_full_url()).netloc

        credentials = self.credentials_provider(registry)

        # First try Bearer, then Basic
        try:
            auth_header = self._try_bearer_challenge(challenges, credentials, req.timeout)
            if not auth_header and credentials:
                auth_header = self._try_basic_challenge(challenges, credentials)
        except Exception as e:
            raise spack.util.web.DetailedHTTPError(
                req, code, f"Cannot login to registry: {e}", headers, fp
            ) from e

        if not auth_header:
            raise spack.util.web.DetailedHTTPError(
                req,
                code,
                f"Cannot login to registry, unsupported authentication scheme: {www_auth_str}",
                headers,
                fp,
            )

        self.cached_auth_headers[registry] = auth_header

        # Add the authorization header to the request
        req.add_unredirected_header("Authorization", auth_header)
        # Mark the request, so that another 401 on the retry is reported as a failed login
        setattr(req, "login_attempted", True)

        return self.parent.open(req, timeout=req.timeout)


def credentials_from_mirrors(
    domain: str, *, mirrors: Optional[Iterable[spack.mirrors.mirror.Mirror]] = None
) -> Optional[UsernamePassword]:
    """Look up the credentials configured for an OCI registry domain.

    Args:
        domain: registry domain to find credentials for.
        mirrors: mirrors to search; when falsy, the configured mirror collection is used.

    Returns:
        The first access pair of a mirror whose URL is an image reference on ``domain``, or None
        if there is no such mirror.
    """
    if not mirrors:
        mirrors = spack.mirrors.mirror.MirrorCollection().values()

    for mirror in mirrors:
        # Prefer push credentials over fetch. Unlikely that those are different
        # but our config format allows it.
        for direction in ("push", "fetch"):
            access_pair = mirror.get_credentials(direction).get("access_pair")
            if access_pair:
                mirror_url = mirror.get_url(direction)
                try:
                    image_ref = ImageReference.from_url(mirror_url)
                except ValueError:
                    # URLs that are not valid image references are skipped
                    continue
                if image_ref.domain == domain:
                    return UsernamePassword(*access_pair)

    return None


def create_opener():
    """Create an opener that can handle OCI authentication."""
    opener = urllib.request.OpenerDirector()

    # The handlers are registered in exactly this order
    handlers = (
        urllib.request.ProxyHandler(),
        urllib.request.UnknownHandler(),
        urllib.request.HTTPHandler(),
        spack.util.web.SpackHTTPSHandler(context=spack.util.web.ssl_create_default_context()),
        spack.util.web.SpackHTTPDefaultErrorHandler(),
        urllib.request.HTTPRedirectHandler(),
        urllib.request.HTTPErrorProcessor(),
        OCIAuthHandler(credentials_from_mirrors),
    )
    for handler in handlers:
        opener.add_handler(handler)

    return opener


def ensure_status(request: urllib.request.Request, response: HTTPResponse, status: int):
    """Raise an error if the response status is not the expected one.

    Args:
        request: the request that produced ``response``.
        response: the response to check.
        status: the expected status code.

    Raises:
        spack.util.web.DetailedHTTPError: if ``response.status`` differs from ``status``.
    """
    actual = response.status
    if actual != status:
        raise spack.util.web.DetailedHTTPError(
            request, actual, response.reason, response.info(), None
        )


def default_retry(f, retries: int = 5, sleep=None):
    """Wrap ``f`` so that it is retried on transient errors, with exponential backoff.

    Args:
        f: the function to wrap.
        retries: the total number of attempts. At least one attempt is always made, also when
            this is zero or negative.
        sleep: function called with the number of seconds to wait between attempts; defaults to
            ``time.sleep``.

    Returns:
        A function that forwards its arguments to ``f`` and returns its result. Errors that are
        not considered transient, and the error of the last attempt, are re-raised.
    """
    sleep = sleep or time.sleep
    # With a non-positive count the loop below would never run, and the wrapper would return
    # None without calling f at all.
    attempts = max(retries, 1)

    def is_transient(e: OSError) -> bool:
        # Retry on internal server errors, and rate limit errors
        # Potentially this could take into account the Retry-After header
        # if registries support it
        if isinstance(e, urllib.error.HTTPError):
            if 500 <= e.code < 600 or e.code == 429:
                return True
        # Timeouts, either wrapped in a URLError or raised directly
        if isinstance(e, urllib.error.URLError) and isinstance(e.reason, socket.timeout):
            return True
        return isinstance(e, socket.timeout)

    def wrapper(*args, **kwargs):
        for i in range(attempts):
            try:
                return f(*args, **kwargs)
            except OSError as e:
                if i + 1 == attempts or not is_transient(e):
                    raise
                # Exponential backoff
                sleep(2**i)

    return wrapper
