# Source code for ai.backend.client.config

import os
from pathlib import Path
import random
import re
from typing import (
    Any, Callable, Iterable, Union,
    List, Tuple, Sequence,
)

import appdirs
from yarl import URL

__all__ = [
    'parse_api_version',
    'get_config',
    'set_config',
    'APIConfig',
    'API_VERSION',
    'DEFAULT_CHUNK_SIZE',
]

# The process-wide configuration returned by get_config(); it stays None until
# get_config() creates one lazily or set_config() installs one.
_config = None
# Sentinel used by get_env() to tell "no default given" apart from an explicit
# ``None`` default.
_undefined = object()

# API protocol version as (major version, 8-digit date string); APIConfig
# formats it as ``"v<major>.<date>"`` for its default version.
API_VERSION = (5, '20191215')

DEFAULT_CHUNK_SIZE = 256 * 1024  # 256 KiB

# Per-user state/cache directories for the client, resolved via appdirs
# when this module is imported.
local_state_path = Path(appdirs.user_state_dir('backend.ai', 'Lablup'))
local_cache_path = Path(appdirs.user_cache_dir('backend.ai', 'Lablup'))


def parse_api_version(value: str) -> Tuple[int, str]:
    '''
    Parses an API version string of the form ``"v<major>.<YYYYMMDD>"``.

    :param value: The version string, e.g., ``"v5.20191215"``.
    :returns: A tuple of the integer major version and the 8-digit date string,
        e.g., ``(5, '20191215')``.
    :raises ValueError: If *value* does not match the format exactly.
    '''
    # fullmatch() anchors both ends of the string; ``^...$`` with search()
    # would also accept a trailing newline because ``$`` matches before it.
    # re.ASCII limits ``\d`` to 0-9 so non-ASCII digits are rejected.
    match = re.fullmatch(r'v(?P<major>\d+)\.(?P<date>\d{8})', value, re.ASCII)
    if match is None:
        raise ValueError('Could not parse the given API version string', value)
    return int(match.group('major')), match.group('date')


def get_env(key: str, default: Any = _undefined, *,
            clean: Callable[[str], Any] = lambda v: v):
    '''
    Looks up a configuration value in the process environment.

    The *key* is uppercased; ``BACKEND_<KEY>`` is consulted first and
    ``SORNA_<KEY>`` is consulted only when the former is not set.

    :param key: The key name (case-insensitive).
    :param default: The value used when neither environment variable is set.
        If omitted, a missing key raises :class:`KeyError`.
    :param clean: A single-argument function applied to whichever value is
        chosen (the environment string or *default*).  By default it returns
        the value unchanged.
    :returns: The chosen value after passing through *clean*.
    :raises KeyError: If no variable is set and no *default* is given.
    '''
    name = key.upper()
    for prefix in ('BACKEND_', 'SORNA_'):
        found = os.environ.get(prefix + name)
        if found is not None:
            return clean(found)
    if default is _undefined:
        raise KeyError(name)
    return clean(default)
def bool_env(v: str) -> bool:
    '''
    Converts a textual boolean flag (e.g., an environment variable) into ``bool``.

    Matching is case-insensitive and ignores surrounding whitespace.

    :param v: The raw string value.
    :returns: ``True`` for ``y``/``yes``/``t``/``true``/``1`` and ``False`` for
        ``n``/``no``/``f``/``false``/``0``.
    :raises ValueError: For any other value.
    '''
    normalized = v.strip().lower()
    if normalized in ('y', 'yes', 't', 'true', '1'):
        return True
    if normalized in ('n', 'no', 'f', 'false', '0'):
        return False
    raise ValueError('Unrecognized value of boolean environment variable', v)


def _clean_urls(v: 'Union[URL, str, Iterable[Union[URL, str]]]') -> 'List[URL]':
    '''
    Normalizes an endpoint specification into a list of absolute URLs.

    :param v: A single :class:`yarl.URL`, a comma-separated string of URLs,
        or an iterable of strings/URLs.  Whitespace around each string entry
        is ignored and blank entries (e.g., from a trailing comma) are skipped.
    :returns: The parsed URLs in the given order.
    :raises ValueError: If an entry is not an absolute URL or no URL remains.
    '''
    if isinstance(v, URL):
        return [v]
    # Besides a comma-separated string, accept any iterable of entries instead
    # of silently returning None for unexpected input types.
    entries = v.split(',') if isinstance(v, str) else list(v)
    urls = []
    for entry in entries:
        if isinstance(entry, URL):
            url = entry
        else:
            text = entry.strip()
            if not text:
                continue
            url = URL(text)
        if not url.is_absolute():
            raise ValueError('URL {} is not absolute.'.format(url))
        urls.append(url)
    if not urls:
        # Callers index the first endpoint, so an empty result is never valid.
        raise ValueError('No endpoint URL is given.')
    return urls


def _clean_tokens(v: 'Union[str, Iterable[str]]') -> Tuple[str, ...]:
    '''
    Converts a comma-separated string (or an iterable) into a tuple of tokens.

    For strings, whitespace around each token is stripped and empty tokens
    are dropped, so ``"a, b,"`` yields ``('a', 'b')``.  Non-string iterables
    are converted to a tuple as-is.
    '''
    if isinstance(v, str):
        stripped = (token.strip() for token in v.split(','))
        return tuple(token for token in stripped if token)
    return tuple(v)
class APIConfig:
    '''
    Represents a set of API client configurations.

    Each option is taken from the explicit constructor argument when one is
    given (for most options: when it is truthy).  Otherwise most options fall
    back to the environment variable of the same name, prefixed with
    ``BACKEND_`` (or ``SORNA_``), and then to :attr:`DEFAULTS`.

    The access key and secret key should be set either in environment
    variables or as explicit arguments.  If neither is given, they become
    empty strings and API requests are made anonymously
    (see :attr:`is_anonymous`).

    :param endpoint: The URL prefix to make API requests via HTTP/HTTPS.
        If this is given as ``str`` and contains multiple URLs separated by
        commas, the underlying HTTP request-response facility will perform
        client-side load balancing and automatic fail-over using them,
        assuming that all those URLs point to the same single cluster.
        The users of the API and CLI will get network connection errors only
        when all of the given endpoints fail -- intermittent failures of a
        subset of endpoints will be hidden at the cost of slightly increased
        latency.

    :param endpoint_type: Either ``"api"`` or ``"session"``.
        If the endpoint type is ``"api"`` (the default if unspecified), it uses
        the access key and secret key in the configuration to access the
        manager API server directly.
        If the endpoint type is ``"session"``, it assumes the endpoint is a
        Backend.AI console server which provides cookie-based authentication
        with username and password.  In this case, users need to use
        ``backend.ai login`` and ``backend.ai logout`` to manage their sign-in
        status, or the API equivalent in
        :meth:`~ai.backend.client.auth.Auth.login` and
        :meth:`~ai.backend.client.auth.Auth.logout` methods.

    :param domain: The domain name (``"default"`` if unspecified).

    :param group: The group name (``"default"`` if unspecified).

    :param version: The API protocol version.  It is not read from the
        environment; the default is derived from ``API_VERSION``.

    :param user_agent: A custom user-agent string which is sent to the API
        server as a ``User-Agent`` HTTP header.

    :param access_key: The API access key.  If deliberately set to an empty
        string, the API requests will be made without signatures (anonymously).
        Ignored for the ``"session"`` endpoint type.

    :param secret_key: The API secret key.  Ignored for the ``"session"``
        endpoint type.

    :param hash_type: The hash type to generate per-request authentication
        signatures.  It is lower-cased and not read from the environment.

    :param vfolder_mounts: A list of vfolder names (that must belong to the
        given access key) to be automatically mounted upon any
        :func:`Kernel.get_or_create()
        <ai.backend.client.kernel.Kernel.get_or_create>` calls.
        Names from the ``VFOLDER_MOUNTS`` environment variable
        (comma-separated) are merged in, without duplicates.

    :param skip_sslcert_validation: Whether to skip SSL certificate validation.
        If not truthy, the ``SKIP_SSLCERT_VALIDATION`` environment variable
        (a boolean flag such as ``yes``/``no``) is consulted instead.

    :param connection_timeout: The maximum allowed duration for making TCP
        connections to the server.  Falls back to the ``CONNECTION_TIMEOUT``
        environment variable and then to ``10.0``.

    :param read_timeout: The maximum allowed waiting time for the first byte of
        the response.  Falls back to the ``READ_TIMEOUT`` environment variable
        and then to ``None``.
    '''

    # The default values except the access and secret keys.
    DEFAULTS = {
        'endpoint': 'https://api.backend.ai',
        'endpoint_type': 'api',
        'version': f'v{API_VERSION[0]}.{API_VERSION[1]}',
        'hash_type': 'sha256',
        'domain': 'default',
        'group': 'default',
        'connection_timeout': 10.0,
        'read_timeout': None,
    }

    def __init__(self, *,
                 endpoint: Union[URL, str] = None,
                 endpoint_type: str = None,
                 domain: str = None,
                 group: str = None,
                 version: str = None,
                 user_agent: str = None,
                 access_key: str = None,
                 secret_key: str = None,
                 hash_type: str = None,
                 vfolder_mounts: Iterable[str] = None,
                 skip_sslcert_validation: bool = None,
                 connection_timeout: float = None,
                 read_timeout: float = None) -> None:
        from . import get_user_agent  # noqa; to avoid circular imports
        self._endpoints = (
            _clean_urls(endpoint) if endpoint
            else get_env('ENDPOINT', self.DEFAULTS['endpoint'], clean=_clean_urls))
        # Randomize the initial endpoint order for client-side load balancing.
        random.shuffle(self._endpoints)
        self._endpoint_type = (
            endpoint_type if endpoint_type
            else get_env('ENDPOINT_TYPE', self.DEFAULTS['endpoint_type']))
        self._domain = domain if domain else get_env('DOMAIN', self.DEFAULTS['domain'])
        self._group = group if group else get_env('GROUP', self.DEFAULTS['group'])
        self._version = version if version else self.DEFAULTS['version']
        self._user_agent = user_agent if user_agent else get_user_agent()
        if self._endpoint_type == 'api':
            # ``is not None`` so that an explicit empty string selects
            # anonymous access instead of falling back to the environment.
            self._access_key = (
                access_key if access_key is not None
                else get_env('ACCESS_KEY', ''))
            self._secret_key = (
                secret_key if secret_key is not None
                else get_env('SECRET_KEY', ''))
        else:
            # Non-empty placeholders so that is_anonymous is False.
            self._access_key = 'dummy'
            self._secret_key = 'dummy'
        self._hash_type = (
            hash_type.lower() if hash_type else self.DEFAULTS['hash_type'])
        # Merge argument and environment vfolders, dropping duplicates while
        # keeping a deterministic order (arguments first), unlike a set union.
        env_vfolders = get_env('VFOLDER_MOUNTS', [], clean=_clean_tokens)
        self._vfolder_mounts = list(dict.fromkeys(
            [*(vfolder_mounts if vfolder_mounts else ()), *env_vfolders]))
        # prefer the argument flag and fallback to env if the flag is not set.
        self._skip_sslcert_validation = (
            skip_sslcert_validation if skip_sslcert_validation
            else get_env('SKIP_SSLCERT_VALIDATION', 'no', clean=bool_env))
        # Environment values are strings; convert them so that the timeout
        # properties always yield a number (or None).
        self._connection_timeout = (
            connection_timeout if connection_timeout
            else get_env('CONNECTION_TIMEOUT', self.DEFAULTS['connection_timeout'],
                         clean=self._clean_timeout))
        self._read_timeout = (
            read_timeout if read_timeout
            else get_env('READ_TIMEOUT', self.DEFAULTS['read_timeout'],
                         clean=self._clean_timeout))

    @staticmethod
    def _clean_timeout(value: Any) -> Union[float, None]:
        '''Converts a timeout value (e.g., an env string) to ``float``, passing ``None`` through.'''
        if value is None:
            return None
        return float(value)

    @property
    def is_anonymous(self) -> bool:
        '''Whether API requests are made without signatures (empty access key).'''
        return self._access_key == ''

    @property
    def endpoint(self) -> URL:
        '''
        The currently active endpoint URL.
        This may change if there are multiple configured endpoints
        and the current one is not accessible.
        '''
        return self._endpoints[0]

    @property
    def endpoints(self) -> Sequence[URL]:
        '''All configured endpoint URLs.'''
        return self._endpoints

    def rotate_endpoints(self):
        '''Moves the currently active endpoint to the end of the list.'''
        if len(self._endpoints) > 1:
            item = self._endpoints.pop(0)
            self._endpoints.append(item)

    @property
    def endpoint_type(self) -> str:
        '''The configured endpoint type.'''
        return self._endpoint_type

    @property
    def domain(self) -> str:
        '''The configured domain.'''
        return self._domain

    @property
    def group(self) -> str:
        '''The configured group.'''
        return self._group

    @property
    def user_agent(self) -> str:
        '''The configured user agent string.'''
        return self._user_agent

    @property
    def access_key(self) -> str:
        '''The configured API access key.'''
        return self._access_key

    @property
    def secret_key(self) -> str:
        '''The configured API secret key.'''
        return self._secret_key

    @property
    def version(self) -> str:
        '''The configured API protocol version.'''
        return self._version

    @property
    def hash_type(self) -> str:
        '''The configured hash algorithm for API authentication signatures.'''
        return self._hash_type

    @property
    def vfolder_mounts(self) -> List[str]:
        '''The configured auto-mounted vfolder list.'''
        return self._vfolder_mounts

    @property
    def skip_sslcert_validation(self) -> bool:
        '''Whether to skip SSL certificate validation for the API gateway.'''
        return self._skip_sslcert_validation

    @property
    def connection_timeout(self) -> float:
        '''The maximum allowed duration for making TCP connections to the server.'''
        return self._connection_timeout

    @property
    def read_timeout(self) -> Union[float, None]:
        '''The maximum allowed waiting time for the first byte of the response from the server.'''
        return self._read_timeout
def get_config():
    '''
    Returns the process-wide :class:`APIConfig`.

    If no configuration has been installed yet, a new :class:`APIConfig` is
    built from the current environment variables and defaults, stored, and
    reused by subsequent calls.
    '''
    global _config
    if _config is not None:
        return _config
    _config = APIConfig()
    return _config
def set_config(conf: APIConfig):
    '''
    Installs *conf* as the configuration for the current process.

    :param conf: The configuration returned by later :func:`get_config` calls.
    '''
    global _config
    _config = conf