# Source code for ai.backend.client.config

import enum
import os
import random
import re
from pathlib import Path
from typing import (
    Callable,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
    cast,
)

import appdirs
from yarl import URL

# Public API of this module (reconstructed from the names defined below).
__all__ = [
    'parse_api_version',
    'get_env',
    'bool_env',
    'get_config',
    'set_config',
    'APIConfig',
    'API_VERSION',
    'MIN_API_VERSION',
    'DEFAULT_CHUNK_SIZE',
    'local_state_path',
    'local_cache_path',
]

class Undefined(enum.Enum):
    """
    Sentinel type for "no value was supplied".

    The lone member wraps a fresh ``object()``, so it is distinct from every
    real argument value, including ``None``.
    """

    token = object()

# Process-wide APIConfig instance: built lazily by get_config() or installed
# explicitly through set_config().
_config = None
# Shorthand for the sentinel get_env() uses to mean "no default given".
_undefined = Undefined.token

# API protocol versions as (major, release-date) pairs; MIN_API_VERSION is
# presumably the oldest server API this client supports -- confirm.
API_VERSION = (6, '20210815')
MIN_API_VERSION = (5, '20191215')

# 16 MiB; its consumers live outside this module.
DEFAULT_CHUNK_SIZE = 16 * 1024 * 1024

# Per-user state and cache directories resolved through appdirs.
# NOTE(review): with an empty appname, appdirs may not append an
# application-specific subdirectory -- confirm this is intended.
local_state_path = Path(appdirs.user_state_dir('', 'Lablup'))
local_cache_path = Path(appdirs.user_cache_dir('', 'Lablup'))

def parse_api_version(value: str) -> Tuple[int, str]:
    """
    Parse an API version string such as ``"v6.20210815"``.

    :param value: The version string in the form ``v<major>.<YYYYMMDD>``.
    :returns: A ``(major, date)`` tuple, e.g. ``(6, '20210815')``.
    :raises ValueError: If *value* does not match the expected format.
    """
    # fullmatch rejects a trailing newline, which a "^...$" pattern would
    # silently accept; re.ASCII keeps \d to the digits 0-9.
    match = re.fullmatch(r'v(?P<major>\d+)\.(?P<date>\d{8})', value, re.ASCII)
    if match is None:
        raise ValueError('Could not parse the given API version string', value)
    return int(match.group('major')), match.group('date')

# Generic result type shared by get_env() and its *clean* callbacks.
T = TypeVar('T')


def default_clean(v: str) -> T:
    """
    Identity cleaner: hand back the raw string unchanged.

    ``cast`` only informs the type checker; at runtime *v* is returned as-is.
    """
    passthrough = cast(T, v)
    return passthrough

def get_env(
    key: str,
    default: Union[str, Undefined] = _undefined,
    *,
    clean: Callable[[str], T] = default_clean,
) -> T:
    """
    Retrieves a configuration value from the environment variables.
    The given *key* is uppercased and prefixed by ``"BACKEND_"`` and then
    ``"SORNA_"`` if the former does not exist.

    :param key: The key name.
    :param default: The default value returned when there is no corresponding
        environment variable.
    :param clean: A single-argument function that is applied to the result of
        lookup (in both successes and the default value for failures).
        The default is returning the value as-is.
    :returns: The value processed by the *clean* function.
    :raises KeyError: If neither prefixed variable is set and no *default*
        was given.
    """
    key = key.upper()
    raw = os.environ.get('BACKEND_' + key)
    if raw is None:
        # Fall back to the SORNA_-prefixed variable.
        raw = os.environ.get('SORNA_' + key)
    if raw is None:
        if default is _undefined:
            raise KeyError(key)
        raw = default
    # The default also goes through *clean*, so callers get a uniform type.
    return clean(raw)


def bool_env(v: str) -> bool:
    """
    Interpret a boolean-like configuration string.

    Matching is case-insensitive and ignores surrounding whitespace.

    :param v: The raw string value.
    :returns: ``True`` for y/yes/t/true/1, ``False`` for n/no/f/false/0.
    :raises ValueError: For any other value.
    """
    v = v.strip().lower()
    if v in ('y', 'yes', 't', 'true', '1'):
        return True
    if v in ('n', 'no', 'f', 'false', '0'):
        return False
    raise ValueError('Unrecognized value of boolean environment variable', v)


def _clean_urls(v: Union[URL, str]) -> List[URL]:
    """
    Normalize an endpoint setting into a list of absolute URLs.

    :param v: A single :class:`URL`, or a string of comma-separated URLs.
        Whitespace around entries and empty entries (e.g. a trailing comma)
        are ignored.
    :returns: The parsed URLs; an empty list if *v* is neither URL nor str.
    :raises ValueError: If an entry is not an absolute URL, or a string
        contains no URL at all.
    """
    if isinstance(v, URL):
        return [v]
    urls = []
    if isinstance(v, str):
        for entry in v.split(','):
            entry = entry.strip()
            if not entry:
                continue
            url = URL(entry)
            if not url.is_absolute():
                raise ValueError('URL {} is not absolute.'.format(url))
            urls.append(url)
        if not urls:
            # Without this, an unset endpoint surfaced as "URL  is not absolute."
            raise ValueError('No endpoint URL is given.')
    return urls


def _clean_tokens(v: str) -> Tuple[str, ...]:
    """
    Split a comma-separated string into a tuple of non-empty, stripped tokens.

    :param v: The raw string; an empty value yields an empty tuple.
    """
    if not v:
        return tuple()
    stripped = (token.strip() for token in v.split(','))
    return tuple(token for token in stripped if token)


class APIConfig:
    """
    Represents a set of API client configurations.

    The access key and secret key are read from the explicit arguments or
    from environment variables. If neither provides them, they default to
    empty strings, which makes the client anonymous (see :attr:`is_anonymous`).

    :param endpoint: The URL prefix to make API requests via HTTP/HTTPS.
        If this is given as ``str`` and contains multiple URLs separated by comma,
        the underlying HTTP request-response facility will perform client-side
        load balancing and automatic fail-over using them, assuming that all
        those URLs indicate a single, same cluster.
        The users of the API and CLI will get network connection errors only
        when all of the given endpoints fail -- intermittent failures of a
        subset of endpoints will be hidden with a little increased latency.
    :param endpoint_type: Either ``"api"`` or ``"session"``.
        If the endpoint type is ``"api"`` (the default if unspecified), it uses
        the access key and secret key in the configuration to access the manager
        API server directly.
        If the endpoint type is ``"session"``, it assumes the endpoint is a
        Backend.AI console server which provides cookie-based authentication with
        username and password. In the latter, users need to use the ``login``
        and ``logout`` CLI commands to manage their sign-in status, or the API
        equivalent in :meth:`~ai.backend.client.auth.Auth.login` and
        :meth:`~ai.backend.client.auth.Auth.logout` methods.
    :param domain: The domain name; falls back to the ``DOMAIN`` setting read
        via :func:`get_env` and then to ``"default"``.
    :param group: The group name; falls back to the ``GROUP`` setting read via
        :func:`get_env` and then to ``"default"``.
    :param version: The API protocol version. Defaults to the string built
        from :data:`API_VERSION` (not read from the environment).
    :param user_agent: A custom user-agent string which is sent to the API
        server as a ``User-Agent`` HTTP header.
    :param access_key: The API access key.
        If deliberately set to an empty string, the API requests will be made
        without signatures (anonymously).
    :param secret_key: The API secret key.
    :param hash_type: The hash type to generate per-request authentication
        signatures. It is lower-cased; defaults to ``"sha256"``.
    :param vfolder_mounts: A list of vfolder names (that must belong to the given
        access key) to be automatically mounted upon any
        :func:`Kernel.get_or_create()
        <ai.backend.client.kernel.Kernel.get_or_create>` calls.
        It is merged with the ``VFOLDER_MOUNTS`` setting.
    :param skip_sslcert_validation: Whether to skip SSL certificate validation.
        A falsy value falls back to the ``SKIP_SSLCERT_VALIDATION`` setting.
    :param connection_timeout: Timeout for establishing TCP connections.
        A falsy value falls back to the ``CONNECTION_TIMEOUT`` setting.
    :param read_timeout: Timeout for the first response byte.
        A falsy value falls back to the ``READ_TIMEOUT`` setting.
    :param announcement_handler: A callback receiving server-set announcement
        strings.
    """

    DEFAULTS: Mapping[str, str] = {
        'endpoint': '',
        'endpoint_type': 'api',
        'version': f'v{API_VERSION[0]}.{API_VERSION[1]}',
        'hash_type': 'sha256',
        'domain': 'default',
        'group': 'default',
        'connection_timeout': '10.0',
        'read_timeout': '0',
    }
    """
    The default values for config parameters except the access and secret keys.
    Note that ``version`` and ``hash_type`` are not read from environment variables.
    """

    _endpoints: List[URL]
    _group: str
    _hash_type: str

    def __init__(
        self,
        *,
        endpoint: Optional[Union[URL, str]] = None,
        endpoint_type: Optional[str] = None,
        domain: Optional[str] = None,
        group: Optional[str] = None,
        version: Optional[str] = None,
        user_agent: Optional[str] = None,
        access_key: Optional[str] = None,
        secret_key: Optional[str] = None,
        hash_type: Optional[str] = None,
        vfolder_mounts: Optional[Iterable[str]] = None,
        skip_sslcert_validation: Optional[bool] = None,
        connection_timeout: Optional[float] = None,
        read_timeout: Optional[float] = None,
        announcement_handler: Optional[Callable[[str], None]] = None,
    ) -> None:
        from . import get_user_agent
        self._endpoints = (
            _clean_urls(endpoint) if endpoint
            else get_env('ENDPOINT', self.DEFAULTS['endpoint'], clean=_clean_urls)
        )
        # Randomize the starting endpoint (client-side load balancing).
        random.shuffle(self._endpoints)
        self._endpoint_type = (
            endpoint_type if endpoint_type is not None
            else get_env('ENDPOINT_TYPE', self.DEFAULTS['endpoint_type'], clean=str)
        )
        self._domain = (
            domain if domain is not None
            else get_env('DOMAIN', self.DEFAULTS['domain'], clean=str)
        )
        self._group = (
            group if group is not None
            else get_env('GROUP', self.DEFAULTS['group'], clean=str)
        )
        self._version = version if version is not None else self.DEFAULTS['version']
        self._user_agent = user_agent if user_agent is not None else get_user_agent()
        if self._endpoint_type == 'api':
            self._access_key = (
                access_key if access_key is not None
                else get_env('ACCESS_KEY', '')
            )
            self._secret_key = (
                secret_key if secret_key is not None
                else get_env('SECRET_KEY', '')
            )
        else:
            # Session endpoints use cookie-based login, so keys are placeholders.
            self._access_key = 'dummy'
            self._secret_key = 'dummy'
        self._hash_type = (
            hash_type.lower() if hash_type is not None
            else cast(str, self.DEFAULTS['hash_type'])
        )
        arg_vfolders = set(vfolder_mounts) if vfolder_mounts else set()
        env_vfolders = set(get_env('VFOLDER_MOUNTS', '', clean=_clean_tokens))
        # Sorted so the merged list has a deterministic order.
        self._vfolder_mounts = sorted(arg_vfolders | env_vfolders)
        # prefer the argument flag and fallback to env if the flag is not set.
        self._skip_sslcert_validation = (
            skip_sslcert_validation if skip_sslcert_validation
            else get_env('SKIP_SSLCERT_VALIDATION', 'no', clean=bool_env)
        )
        self._connection_timeout = (
            connection_timeout if connection_timeout
            else get_env('CONNECTION_TIMEOUT', self.DEFAULTS['connection_timeout'],
                         clean=float)
        )
        self._read_timeout = (
            read_timeout if read_timeout
            else get_env('READ_TIMEOUT', self.DEFAULTS['read_timeout'], clean=float)
        )
        self._announcement_handler = announcement_handler

    @property
    def is_anonymous(self) -> bool:
        """Whether the access key is empty (requests are made anonymously)."""
        return self._access_key == ''

    @property
    def endpoint(self) -> URL:
        """
        The currently active endpoint URL.
        This may change if there are multiple configured endpoints
        and the current one is not accessible.
        """
        return self._endpoints[0]

    @property
    def endpoints(self) -> Sequence[URL]:
        """All configured endpoint URLs."""
        return self._endpoints

    def rotate_endpoints(self):
        """Move the active endpoint to the back so the next one becomes active."""
        if len(self._endpoints) > 1:
            item = self._endpoints.pop(0)
            self._endpoints.append(item)

    def load_balance_endpoints(self):
        """Currently a no-op."""
        pass

    @property
    def endpoint_type(self) -> str:
        """
        The configured endpoint type.
        """
        return self._endpoint_type

    @property
    def domain(self) -> str:
        """The configured domain."""
        return self._domain

    @property
    def group(self) -> str:
        """The configured group."""
        return self._group

    @property
    def user_agent(self) -> str:
        """The configured user agent string."""
        return self._user_agent

    @property
    def access_key(self) -> str:
        """The configured API access key."""
        return self._access_key

    @property
    def secret_key(self) -> str:
        """The configured API secret key."""
        return self._secret_key

    @property
    def version(self) -> str:
        """The configured API protocol version."""
        return self._version

    @property
    def hash_type(self) -> str:
        """The configured hash algorithm for API authentication signatures."""
        return self._hash_type

    @property
    def vfolder_mounts(self) -> Sequence[str]:
        """The configured auto-mounted vfolder list."""
        return self._vfolder_mounts

    @property
    def skip_sslcert_validation(self) -> bool:
        """Whether to skip SSL certificate validation for the API gateway."""
        return self._skip_sslcert_validation

    @property
    def connection_timeout(self) -> float:
        """The maximum allowed duration for making TCP connections to the server."""
        return self._connection_timeout

    @property
    def read_timeout(self) -> float:
        """The maximum allowed waiting time for the first byte of the response from the server."""
        return self._read_timeout

    @property
    def announcement_handler(self) -> Optional[Callable[[str], None]]:
        """The announcement handler to display server-set announcements."""
        return self._announcement_handler


def get_config() -> APIConfig:
    """
    Returns the configuration for the current process.
    If there is no explicitly set :class:`APIConfig` instance,
    it will generate a new one from the current environment variables
    and defaults.
    """
    global _config
    if _config is None:
        # Built once and cached for subsequent calls.
        _config = APIConfig()
    return _config


def set_config(conf: APIConfig) -> None:
    """
    Sets the configuration used throughout the current process.

    :param conf: The instance that subsequent :func:`get_config` calls return.
    """
    global _config
    _config = conf