from __future__ import annotations
import json
import os
from pathlib import Path
import secrets
import tarfile
import tempfile
from typing import (
Any,
AsyncIterator,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Union,
cast,
)
from typing_extensions import Literal # for Python 3.7
from uuid import UUID
import aiohttp
from aiohttp import hdrs
from tqdm import tqdm
from .base import api_function, BaseFunction
from ..compat import current_loop
from ..config import DEFAULT_CHUNK_SIZE
from ..exceptions import BackendClientError
from ..pagination import generate_paginated_results
from ..request import (
Request, AttachedFile,
WebSocketResponse,
SSEContextManager,
WebSocketContextManager,
)
from ..session import api_session
from ..utils import ProgressReportingReader
from ..types import Undefined, undefined
from ..versioning import get_naming, get_id_or_name
# Public names exported by ``from <this module> import *``.
__all__ = ('ComputeSession',)
def drop(d: Mapping[str, Any], value_to_drop: Any) -> Mapping[str, Any]:
    """
    Build a new dict from *d* that omits every entry whose value compares
    equal to *value_to_drop*.

    Nested mappings are always kept (even if they end up empty), but are
    themselves rebuilt recursively with the same rule applied.
    """
    def is_nested(value: Any) -> bool:
        return isinstance(value, (Mapping, dict))

    return {
        key: drop(value, value_to_drop) if is_nested(value) else value
        for key, value in d.items()
        # Nested mappings short-circuit here, so they are never compared
        # against value_to_drop (same as the recursive-first original).
        if is_nested(value) or value != value_to_drop
    }
class ComputeSession(BaseFunction):
    """
    Provides various interactions with compute sessions in Backend.AI.

    The term 'kernel' is now deprecated and we prefer 'compute sessions'.
    However, for historical reasons and to avoid confusion with client sessions, we
    keep the backward compatibility with the naming of this API function class.

    For multi-container sessions, all methods take effect on the master container
    only, except the :func:`~ComputeSession.destroy` and
    :func:`~ComputeSession.restart` methods.
    It is therefore the user's responsibility to distribute uploaded files to
    multiple containers using explicit copies or virtual folders which are commonly
    mounted to all containers belonging to the same compute session.
    """

    # Session UUID; filled by the create methods on API v5+ or by
    # from_session_id(), otherwise None.
    id: Optional[UUID]
    # Client-side session name; None when built via from_session_id().
    name: Optional[str]
    # Access key that owns the session, forwarded as a query parameter if set.
    owner_access_key: Optional[str]
    # The attributes below are only populated by the create methods.
    created: bool
    status: str
    service_ports: List[str]
    domain: str
    group: str

    @api_function
    @classmethod
    async def paginated_list(
        cls,
        status: Optional[str] = None,
        access_key: Optional[str] = None,
        *,
        fields: Optional[Sequence[str]] = None,
        page_size: int = 20,
        filter: Optional[str] = None,
        order: Optional[str] = None,
    ) -> AsyncIterator[dict]:
        """
        Fetches the list of compute sessions page by page using the
        ``compute_session_list`` query.

        :param status: Passed as the ``status`` query variable to narrow down
            the list; ``None`` leaves it unset.
        :param access_key: Passed as the ``access_key`` query variable to narrow
            down the list to a specific access key.
        :param fields: The per-session fields to fetch. Defaults to ``task_id``,
            ``image``, ``type``, ``status``, ``status_info``, ``status_changed``
            and ``result``.
        :param page_size: The number of items requested per page.
        :param filter: An optional filter expression passed as-is to the server.
        :param order: An optional ordering expression passed as-is to the server.
        :returns: An async iterator yielding one dict per compute session.
        """
        if fields is None:
            fields = [
                'task_id',
                'image',
                'type',
                'status',
                'status_info',
                'status_changed',
                'result',
            ]
        async for item in generate_paginated_results(
            'compute_session_list',
            {
                'status': (status, 'String'),
                'access_key': (access_key, 'String'),
                'filter': (filter, 'String'),
                'order': (order, 'String'),
            },
            fields,
            page_size=page_size,
        ):
            yield item

    @api_function
    @classmethod
    async def hello(cls) -> str:
        """
        Sends ``GET /`` to the API endpoint and returns the decoded JSON body.
        """
        rqst = Request('GET', '/')
        async with rqst.fetch() as resp:
            return await resp.json()

    @api_function
    @classmethod
    async def get_task_logs(
        cls, task_id: str, *,
        chunk_size: int = 8192
    ) -> AsyncIterator[bytes]:
        """
        Streams the logs of the background task *task_id* as raw byte chunks.

        :param task_id: The ID of the task whose logs are fetched.
        :param chunk_size: The maximum size of each yielded chunk in bytes.
        """
        prefix = get_naming(api_session.get().api_version, 'path')
        rqst = Request('GET', f'/{prefix}/_/logs', params={
            'taskId': task_id,
        })
        async with rqst.fetch() as resp:
            while True:
                chunk = await resp.raw_response.content.read(chunk_size)
                if not chunk:
                    break
                yield chunk

    @api_function
    @classmethod
    async def get_or_create(
        cls,
        image: str, *,
        name: str = None,
        type_: str = 'interactive',
        starts_at: str = None,
        enqueue_only: bool = False,
        max_wait: int = 0,
        no_reuse: bool = False,
        mounts: List[str] = None,
        mount_map: Mapping[str, str] = None,
        envs: Mapping[str, str] = None,
        startup_command: str = None,
        resources: Mapping[str, int] = None,
        resource_opts: Mapping[str, int] = None,
        cluster_size: int = 1,
        cluster_mode: Literal['single-node', 'multi-node'] = 'single-node',
        domain_name: str = None,
        group_name: str = None,
        bootstrap_script: str = None,
        tag: str = None,
        scaling_group: str = None,
        owner_access_key: str = None,
        preopen_ports: List[int] = None,
    ) -> ComputeSession:
        """
        Get-or-creates a compute session.

        If *name* is ``None``, a random name is generated and a new compute session
        is created as long as the server has enough resources and your API key has
        remaining quota.
        If *name* is a valid string and there is an existing compute session
        with the same name and the same *image*, then it returns the
        :class:`ComputeSession` instance representing the existing session.

        :param image: The image name and tag for the compute session.
            Example: ``python:3.6-ubuntu``.
            Check out the full list of available images in your server using (TODO:
            new API).
        :param name: A client-side (user-defined) identifier to distinguish the session
            among currently running sessions (4 to 64 characters).
            It may be used to seamlessly reuse the session already created.

            .. versionchanged:: 19.12.0

               Renamed from ``clientSessionToken``.
        :param type_: Either ``"interactive"`` (default) or ``"batch"``.

            .. versionadded:: 19.09.0
        :param starts_at: Sent as ``starts_at`` (API v5+ only).
        :param enqueue_only: Just enqueue the session creation request and return
            immediately, without waiting for its startup. (default: ``false`` to
            preserve the legacy behavior)

            .. versionadded:: 19.09.0
        :param max_wait: The time to wait for session startup. If the cluster resource
            is being fully utilized, this waiting time can be arbitrarily long due to
            job queueing. If the timeout reaches, the returned *status* field becomes
            ``"TIMEOUT"``. Still in this case, the session may start in the future.

            .. versionadded:: 19.09.0
        :param no_reuse: Raises an explicit error if a session with the same *image*
            and the same *name* already exists instead of returning the information
            of it.

            .. versionadded:: 19.09.0
        :param mounts: The list of vfolder names that belong to the current API
            access key. The vfolder mounts from the client configuration are
            appended to a copy of this list; the caller's list is not modified.
        :param mount_map: Mapping which contains custom path to mount vfolder.
            Key and value of this map should be vfolder name and custom path.
            All custom mounts should be under /home/work.
            vFolders which have a dot(.) prefix in their names are not affected.
        :param envs: The environment variables which always bypass the jail policy.
        :param startup_command: Sent as ``startupCommand`` (API v4+ only).
        :param resources: The resource specification. (TODO: details)
        :param resource_opts: Sent as ``resource_opts`` in the session config.
        :param cluster_size: The number of containers in this compute session.
            Must be at least 1.

            .. versionadded:: 19.09.0
            .. versionchanged:: 20.09.0
        :param cluster_mode: Set the clustering mode whether to use distributed
            nodes or a single node to spawn multiple containers for the new session.

            .. versionadded:: 20.09.0
        :param domain_name: The domain name; defaults to the client configuration's
            domain.
        :param group_name: The group name; defaults to the client configuration's
            group.
        :param bootstrap_script: Sent as ``bootstrap_script`` (API v5+ only).
        :param tag: An optional string to annotate extra information.
        :param scaling_group: Sent as ``scalingGroup`` in the session config.
        :param owner_access_key: An optional access key that owns the created
            session. (Only available to administrators)
        :param preopen_ports: Sent as ``preopen_ports`` in the session config
            (API v5+ only).
        :returns: The :class:`ComputeSession` instance.
        """
        if name is not None:
            assert 4 <= len(name) <= 64, \
                'Client session token should be 4 to 64 characters long.'
        else:
            name = f'pysdk-{secrets.token_hex(5)}'
        # Always work on a fresh list so that appending the configured vfolder
        # mounts below never mutates the caller's list.
        mounts = [] if mounts is None else list(mounts)
        if mount_map is None:
            mount_map = {}
        if resources is None:
            resources = {}
        if resource_opts is None:
            resource_opts = {}
        if domain_name is None:
            # Even if config.domain is None, it can be guessed in the manager by user information.
            domain_name = api_session.get().config.domain
        if group_name is None:
            group_name = api_session.get().config.group

        config_vfolder_mounts = api_session.get().config.vfolder_mounts
        if config_vfolder_mounts:
            mounts.extend(config_vfolder_mounts)
        prefix = get_naming(api_session.get().api_version, 'path')
        rqst = Request('POST', f'/{prefix}')
        params: Dict[str, Any] = {
            'tag': tag,
            get_naming(api_session.get().api_version, 'name_arg'): name,
            'config': {
                'mounts': mounts,
                'environ': envs,
                'resources': resources,
                'resource_opts': resource_opts,
                'scalingGroup': scaling_group,
            },
        }
        if api_session.get().api_version >= (6, '20200815'):
            params['clusterSize'] = cluster_size
            params['clusterMode'] = cluster_mode
        else:
            params['config']['clusterSize'] = cluster_size
        if api_session.get().api_version >= (5, '20191215'):
            params['config'].update({
                'mount_map': mount_map,
                'preopen_ports': preopen_ports,
            })
            params.update({
                'starts_at': starts_at,
                'bootstrap_script': bootstrap_script,
            })
        if api_session.get().api_version >= (4, '20190615'):
            params.update({
                'owner_access_key': owner_access_key,
                'domain': domain_name,
                'group': group_name,
                'type': type_,
                'enqueueOnly': enqueue_only,
                'maxWaitSeconds': max_wait,
                'reuseIfExists': not no_reuse,
                'startupCommand': startup_command,
            })
        if api_session.get().api_version > (4, '20181215'):
            params['image'] = image
        else:
            params['lang'] = image
        rqst.set_json(params)
        async with rqst.fetch() as resp:
            data = await resp.json()
            o = cls(name, owner_access_key)  # type: ignore
            if api_session.get().api_version[0] >= 5:
                o.id = UUID(data['sessionId'])
            o.created = data.get('created', True)  # True is for legacy
            o.status = data.get('status', 'RUNNING')
            o.service_ports = data.get('servicePorts', [])
            o.domain = domain_name
            o.group = group_name
            return o

    @api_function
    @classmethod
    async def create_from_template(
        cls,
        template_id: str, *,
        name: Union[str, Undefined] = undefined,
        type_: Union[str, Undefined] = undefined,
        starts_at: str = None,
        enqueue_only: Union[bool, Undefined] = undefined,
        max_wait: Union[int, Undefined] = undefined,
        no_reuse: Union[bool, Undefined] = undefined,
        image: Union[str, Undefined] = undefined,
        mounts: Union[List[str], Undefined] = undefined,
        mount_map: Union[Mapping[str, str], Undefined] = undefined,
        envs: Union[Mapping[str, str], Undefined] = undefined,
        startup_command: Union[str, Undefined] = undefined,
        resources: Union[Mapping[str, int], Undefined] = undefined,
        resource_opts: Union[Mapping[str, int], Undefined] = undefined,
        cluster_size: Union[int, Undefined] = undefined,
        cluster_mode: Union[Literal['single-node', 'multi-node'], Undefined] = undefined,
        domain_name: Union[str, Undefined] = undefined,
        group_name: Union[str, Undefined] = undefined,
        bootstrap_script: Union[str, Undefined] = undefined,
        tag: Union[str, Undefined] = undefined,
        scaling_group: Union[str, Undefined] = undefined,
        owner_access_key: Union[str, Undefined] = undefined,
    ) -> ComputeSession:
        """
        Get-or-creates a compute session from a template.

        Every parameter that is explicitly given overrides the corresponding
        template value, including vfolder mounts (they replace the template's
        mounts instead of being appended to them). Parameters left as
        ``undefined`` are not sent at all.

        If *name* is not given, a random name is generated and a new compute
        session is created as long as the server has enough resources and your
        API key has remaining quota.
        If *name* is a valid string and there is an existing compute session
        with the same name and the same *image*, then it returns the
        :class:`ComputeSession` instance representing the existing session.

        :param template_id: Task template to apply to compute session.
        :param image: The image name and tag for the compute session.
            Example: ``python:3.6-ubuntu``.
            Check out the full list of available images in your server using (TODO:
            new API).
        :param name: A client-side (user-defined) identifier to distinguish the session
            among currently running sessions (4 to 64 characters).
            It may be used to seamlessly reuse the session already created.

            .. versionchanged:: 19.12.0

               Renamed from ``clientSessionToken``.
        :param type_: Either ``"interactive"`` (default) or ``"batch"``.

            .. versionadded:: 19.09.0
        :param enqueue_only: Just enqueue the session creation request and return
            immediately, without waiting for its startup. (default: ``false`` to
            preserve the legacy behavior)

            .. versionadded:: 19.09.0
        :param max_wait: The time to wait for session startup. If the cluster resource
            is being fully utilized, this waiting time can be arbitrarily long due to
            job queueing. If the timeout reaches, the returned *status* field becomes
            ``"TIMEOUT"``. Still in this case, the session may start in the future.

            .. versionadded:: 19.09.0
        :param no_reuse: Raises an explicit error if a session with the same *image*
            and the same *name* already exists instead of returning the information
            of it. When not given, ``reuseIfExists`` is not sent at all.

            .. versionadded:: 19.09.0
        :param mounts: The list of vfolder names that belong to the current API
            access key. The vfolder mounts from the client configuration are
            appended to a copy of this list; the caller's list is not modified.
        :param mount_map: Mapping which contains custom path to mount vfolder.
            Key and value of this map should be vfolder name and custom path.
            All custom mounts should be under /home/work.
            vFolders which have a dot(.) prefix in their names are not affected.
        :param envs: The environment variables which always bypass the jail policy.
        :param resources: The resource specification. (TODO: details)
        :param cluster_size: The number of containers in this compute session.
            Must be at least 1.

            .. versionadded:: 19.09.0
            .. versionchanged:: 20.09.0
        :param cluster_mode: Set the clustering mode whether to use distributed
            nodes or a single node to spawn multiple containers for the new session.

            .. versionadded:: 20.09.0
        :param tag: An optional string to annotate extra information.
        :param owner_access_key: An optional access key that owns the created
            session. (Only available to administrators)
        :returns: The :class:`ComputeSession` instance.
        """
        if name is not undefined:
            assert 4 <= len(name) <= 64, \
                'Client session token should be 4 to 64 characters long.'
        else:
            name = f'pysdk-{secrets.token_urlsafe(8)}'

        if domain_name is undefined:
            # Even if config.domain is None, it can be guessed in the manager by user information.
            domain_name = api_session.get().config.domain
        if group_name is undefined:
            group_name = api_session.get().config.group
        # NOTE(review): an omitted *mounts* still becomes an empty list, which
        # drop() below keeps, so the template's mounts may get overridden --
        # confirm that this is intended.
        if mounts is undefined:
            mounts = []
        else:
            # Copy so that appending the configured mounts never mutates the
            # caller's list.
            mounts = list(cast(List[str], mounts))
        config_vfolder_mounts = api_session.get().config.vfolder_mounts
        if config_vfolder_mounts:
            mounts.extend(config_vfolder_mounts)
        prefix = get_naming(api_session.get().api_version, 'path')
        rqst = Request('POST', f'/{prefix}/_/create-from-template')
        params: Dict[str, Any]
        params = {
            'template_id': template_id,
            'tag': tag,
            'image': image,
            'domain': domain_name,
            'group': group_name,
            get_naming(api_session.get().api_version, 'name_arg'): name,
            'bootstrap_script': bootstrap_script,
            'enqueueOnly': enqueue_only,
            'maxWaitSeconds': max_wait,
            # `not undefined` would evaluate to False and survive drop(),
            # silently forcing "no reuse" and overriding the template.
            'reuseIfExists': undefined if no_reuse is undefined else not no_reuse,
            'startupCommand': startup_command,
            'owner_access_key': owner_access_key,
            'type': type_,
            'starts_at': starts_at,
            'config': {
                'mounts': mounts,
                'mount_map': mount_map,
                'environ': envs,
                'resources': resources,
                'resource_opts': resource_opts,
                'scalingGroup': scaling_group,
            },
        }
        if api_session.get().api_version >= (6, '20200815'):
            params['clusterSize'] = cluster_size
            params['clusterMode'] = cluster_mode
        else:
            params['config']['clusterSize'] = cluster_size
        # Remove every key still holding the `undefined` sentinel so that the
        # template values apply for them.
        params = cast(Dict[str, Any], drop(params, undefined))
        rqst.set_json(params)
        async with rqst.fetch() as resp:
            data = await resp.json()
            o = cls(name, owner_access_key if owner_access_key is not undefined else None)
            if api_session.get().api_version[0] >= 5:
                o.id = UUID(data['sessionId'])
            o.created = data.get('created', True)  # True is for legacy
            o.status = data.get('status', 'RUNNING')
            o.service_ports = data.get('servicePorts', [])
            o.domain = domain_name
            o.group = group_name
            return o

    def __init__(self, name: str, owner_access_key: str = None) -> None:
        self.id = None
        self.name = name
        self.owner_access_key = owner_access_key

    @classmethod
    def from_session_id(cls, session_id: UUID) -> ComputeSession:
        """
        Creates an instance that refers to an existing session by its UUID only
        (``name`` and ``owner_access_key`` are left as ``None``).
        """
        o = cls(None, None)  # type: ignore
        o.id = session_id
        return o

    def get_session_identity_params(self) -> Mapping[str, str]:
        """
        Builds the query parameters identifying this session: ``sessionId`` if
        the UUID is known, else ``sessionName``, plus ``owner_access_key`` if set.
        """
        if self.id:
            identity_params = {
                'sessionId': str(self.id),
            }
        else:
            assert self.name is not None
            identity_params = {
                'sessionName': self.name,
            }
        if self.owner_access_key:
            identity_params['owner_access_key'] = self.owner_access_key
        return identity_params

    @api_function
    async def destroy(self, *, forced: bool = False):
        """
        Destroys the compute session.

        Since the server literally kills the container(s), all ongoing executions are
        forcibly interrupted.

        :param forced: Sends ``forced=true`` as a query parameter when set.
        :returns: The decoded JSON body if the server responds with HTTP 200,
            otherwise ``None``.
        """
        params = {}
        if self.owner_access_key is not None:
            params['owner_access_key'] = self.owner_access_key
        prefix = get_naming(api_session.get().api_version, 'path')
        if forced:
            params['forced'] = 'true'
        rqst = Request(
            'DELETE', f'/{prefix}/{self.name}',
            params=params,
        )
        async with rqst.fetch() as resp:
            if resp.status == 200:
                return await resp.json()

    @api_function
    async def restart(self):
        """
        Restarts the compute session.

        The server force-destroys the current running container(s), but keeps their
        temporary scratch directories intact.
        """
        params = {}
        if self.owner_access_key:
            params['owner_access_key'] = self.owner_access_key
        prefix = get_naming(api_session.get().api_version, 'path')
        rqst = Request(
            'PATCH', f'/{prefix}/{self.name}',
            params=params,
        )
        async with rqst.fetch():
            pass

    @api_function
    async def interrupt(self):
        """
        Tries to interrupt the current ongoing code execution.

        This may fail without any explicit errors depending on the code being
        executed.
        """
        params = {}
        if self.owner_access_key:
            params['owner_access_key'] = self.owner_access_key
        prefix = get_naming(api_session.get().api_version, 'path')
        rqst = Request(
            'POST', f'/{prefix}/{self.name}/interrupt',
            params=params,
        )
        async with rqst.fetch():
            pass

    @api_function
    async def complete(self, code: str, opts: dict = None) -> Iterable[str]:
        """
        Gets the auto-completion candidates from the given code string,
        as if a user has pressed the tab key just after the code in
        IDEs.

        Depending on the language of the compute session, this feature
        may not be supported. Unsupported sessions return an empty list.

        :param code: An (incomplete) code text.
        :param opts: Additional information about the current cursor position,
            such as row, col, line and the remainder text.
        :returns: An ordered list of strings.
        """
        opts = {} if opts is None else opts
        params = {}
        if self.owner_access_key:
            params['owner_access_key'] = self.owner_access_key
        prefix = get_naming(api_session.get().api_version, 'path')
        rqst = Request(
            'POST', f'/{prefix}/{self.name}/complete',
            params=params,
        )
        rqst.set_json({
            'code': code,
            'options': {
                'row': int(opts.get('row', 0)),
                'col': int(opts.get('col', 0)),
                'line': opts.get('line', ''),
                'post': opts.get('post', ''),
            },
        })
        async with rqst.fetch() as resp:
            return await resp.json()

    @api_function
    async def get_info(self):
        """
        Retrieves brief information about the compute session.
        """
        params = {}
        if self.owner_access_key:
            params['owner_access_key'] = self.owner_access_key
        prefix = get_naming(api_session.get().api_version, 'path')
        rqst = Request(
            'GET', f'/{prefix}/{self.name}',
            params=params,
        )
        async with rqst.fetch() as resp:
            return await resp.json()

    @api_function
    async def get_logs(self):
        """
        Retrieves the console log of the compute session container.
        """
        params = {}
        if self.owner_access_key:
            params['owner_access_key'] = self.owner_access_key
        prefix = get_naming(api_session.get().api_version, 'path')
        rqst = Request(
            'GET', f'/{prefix}/{self.name}/logs',
            params=params,
        )
        async with rqst.fetch() as resp:
            return await resp.json()

    @api_function
    async def execute(self, run_id: str = None,
                      code: str = None,
                      mode: str = 'query',
                      opts: dict = None):
        """
        Executes a code snippet directly in the compute session or sends a set of
        build/clean/execute commands to the compute session.

        For more details about using this API, please refer :doc:`the official API
        documentation <user-api/intro>`.

        :param run_id: A unique identifier for a particular run loop. In the
            first call, it may be ``None`` so that the server auto-assigns one.
            Subsequent calls must use the returned ``runId`` value to request
            continuation or to send user inputs.
        :param code: A code snippet as string. In the continuation requests, it
            must be an empty string. When sending user inputs, this is where the
            user input string is stored.
        :param mode: A constant string which is one of ``"query"``, ``"batch"``,
            ``"continue"``, ``"input"`` (sending user inputs), and ``"complete"``.
        :param opts: A dict for specifying additional options. Mainly used in the
            batch mode to specify build/clean/execution commands, and in the
            complete mode to specify the cursor position.
            See :ref:`the API object reference <batch-execution-query-object>`
            for details.
        :returns: :ref:`An execution result object <execution-result-object>`
        :raises BackendClientError: If *mode* is not one of the values above.
        """
        opts = opts if opts is not None else {}
        params = {}
        if self.owner_access_key:
            params['owner_access_key'] = self.owner_access_key
        prefix = get_naming(api_session.get().api_version, 'path')
        payload: Dict[str, Any]
        if mode in {'query', 'continue', 'input'}:
            assert code is not None, \
                'The code argument must be a valid string even when empty.'
            payload = {
                'mode': mode,
                'code': code,
                'runId': run_id,
            }
        elif mode == 'batch':
            payload = {
                'mode': mode,
                'code': code,
                'runId': run_id,
                'options': {
                    'clean': opts.get('clean', None),
                    'build': opts.get('build', None),
                    'buildLog': bool(opts.get('buildLog', False)),
                    'exec': opts.get('exec', None),
                },
            }
        elif mode == 'complete':
            # Send the mode and run ID like the other modes do; previously this
            # payload omitted both keys.
            payload = {
                'mode': mode,
                'code': code,
                'runId': run_id,
                'options': {
                    'row': int(opts.get('row', 0)),
                    'col': int(opts.get('col', 0)),
                    'line': opts.get('line', ''),
                    'post': opts.get('post', ''),
                },
            }
        else:
            raise BackendClientError('Invalid execution mode: {0}'.format(mode))
        rqst = Request(
            'POST', f'/{prefix}/{self.name}',
            params=params,
        )
        rqst.set_json(payload)
        async with rqst.fetch() as resp:
            return (await resp.json())['result']

    @api_function
    async def upload(self, files: Sequence[Union[str, Path]],
                     basedir: Union[str, Path] = None,
                     show_progress: bool = False):
        """
        Uploads the given list of files to the compute session.
        You may refer them in the batch-mode execution or from the code
        executed in the server afterwards.

        :param files: The list of file paths in the client-side.
            If the paths include directories, the location of them in the compute
            session is calculated from the relative path to *basedir* and all
            intermediate parent directories are automatically created if not exists.

            For example, if a file path is ``/home/user/test/data.txt`` (or
            ``test/data.txt``) where *basedir* is ``/home/user`` (or the current
            working directory is ``/home/user``), the uploaded file is located at
            ``/home/work/test/data.txt`` in the compute session container.
        :param basedir: The directory prefix where the files reside.
            The default value is the current working directory.
        :param show_progress: Displays a progress bar during uploads.
        :raises ValueError: If a file is outside of *basedir*.
        """
        params = {}
        if self.owner_access_key:
            params['owner_access_key'] = self.owner_access_key
        prefix = get_naming(api_session.get().api_version, 'path')
        # Resolve the base path the same way as the file paths so that
        # relative_to() compares like with like.
        base_path = (
            Path.cwd().resolve() if basedir is None
            else Path(basedir).resolve()
        )
        file_paths = [Path(file).resolve() for file in files]
        total_size = sum(file_path.stat().st_size for file_path in file_paths)
        tqdm_obj = tqdm(desc='Uploading files',
                        unit='bytes', unit_scale=True,
                        total=total_size,
                        disable=not show_progress)
        # Every opened reader is tracked so it is closed on all exit paths,
        # including the "outside of the base directory" error below.
        readers = []
        try:
            with tqdm_obj:
                attachments = []
                for file_path in file_paths:
                    try:
                        relative_path = file_path.relative_to(base_path)
                    except ValueError:
                        msg = 'File "{0}" is outside of the base directory "{1}".' \
                              .format(file_path, base_path)
                        raise ValueError(msg) from None
                    reader = ProgressReportingReader(str(file_path),
                                                     tqdm_instance=tqdm_obj)
                    readers.append(reader)
                    attachments.append(AttachedFile(
                        str(relative_path),
                        reader,
                        'application/octet-stream',
                    ))
                rqst = Request(
                    'POST', f'/{prefix}/{self.name}/upload',
                    params=params,
                )
                rqst.attach_files(attachments)
                async with rqst.fetch() as resp:
                    return resp
        finally:
            for reader in readers:
                reader.close()

    @api_function
    async def download(self, files: Sequence[Union[str, Path]],
                       dest: Union[str, Path] = '.',
                       show_progress: bool = False):
        """
        Downloads the given list of files from the compute session.

        The server responds with a multipart body whose parts are tar archives;
        each one is spooled to a temporary file and extracted into *dest*.

        :param files: The list of file paths in the compute session.
            If they are relative paths, the path is calculated from
            ``/home/work`` in the compute session container.
        :param dest: The destination directory in the client-side.
        :param show_progress: Displays a progress bar during downloads.
        :returns: A dict with the key ``file_names`` listing all extracted
            archive member names.
        :raises BackendClientError: If an archive member (or link target) would
            be placed outside of *dest*.
        """
        dest_root = Path(dest).resolve()

        def _is_within_dest(target: Path) -> bool:
            return target == dest_root or dest_root in target.parents

        def _check_members(tarf: tarfile.TarFile) -> None:
            # The archive is produced inside the (user-controlled) container, so
            # reject absolute paths, ``..`` escapes, and links pointing outside.
            for member in tarf.getmembers():
                member_path = dest_root / member.name
                if not _is_within_dest(member_path.resolve()):
                    raise BackendClientError(
                        'Refusing to extract "{0}" outside of "{1}".'
                        .format(member.name, dest_root))
                if member.issym():
                    link_target = (member_path.parent / member.linkname).resolve()
                elif member.islnk():
                    link_target = (dest_root / member.linkname).resolve()
                else:
                    continue
                if not _is_within_dest(link_target):
                    raise BackendClientError(
                        'Refusing to extract link "{0}" pointing outside of "{1}".'
                        .format(member.name, dest_root))

        params = {}
        if self.owner_access_key:
            params['owner_access_key'] = self.owner_access_key
        prefix = get_naming(api_session.get().api_version, 'path')
        rqst = Request(
            'GET', f'/{prefix}/{self.name}/download',
            params=params,
        )
        rqst.set_json({
            'files': [*map(str, files)],
        })
        file_names = []
        async with rqst.fetch() as resp:
            loop = current_loop()
            tqdm_obj = tqdm(desc='Downloading files',
                            unit='bytes', unit_scale=True,
                            total=resp.content.total_bytes,
                            disable=not show_progress)
            reader = aiohttp.MultipartReader.from_response(resp.raw_response)
            with tqdm_obj as pbar:
                while True:
                    part = cast(aiohttp.BodyPartReader, await reader.next())
                    if part is None:
                        break
                    assert part.headers.get(hdrs.CONTENT_ENCODING, 'identity').lower() == 'identity'
                    assert part.headers.get(hdrs.CONTENT_TRANSFER_ENCODING, 'binary').lower() in (
                        'binary', '8bit', '7bit',
                    )
                    fp = tempfile.NamedTemporaryFile(suffix='.tar',
                                                     delete=False)
                    # Remove the spooled archive even if reading or extraction fails.
                    try:
                        try:
                            while True:
                                chunk = await part.read_chunk(DEFAULT_CHUNK_SIZE)
                                if not chunk:
                                    break
                                # Blocking file write is offloaded to the default executor.
                                await loop.run_in_executor(None, fp.write, chunk)
                                pbar.update(len(chunk))
                        finally:
                            fp.close()
                        with tarfile.open(fp.name) as tarf:
                            _check_members(tarf)
                            tarf.extractall(path=dest)
                            file_names.extend(tarf.getnames())
                    finally:
                        os.unlink(fp.name)
        return {'file_names': file_names}

    @api_function
    async def list_files(self, path: Union[str, Path] = '.'):
        """
        Gets the list of files in the given path inside the compute session
        container.

        :param path: The directory path in the compute session.
        """
        params = {}
        if self.owner_access_key:
            params['owner_access_key'] = self.owner_access_key
        prefix = get_naming(api_session.get().api_version, 'path')
        rqst = Request(
            'GET', f'/{prefix}/{self.name}/files',
            params=params,
        )
        rqst.set_json({
            # Path objects are allowed by the signature but are not JSON-serializable.
            'path': str(path),
        })
        async with rqst.fetch() as resp:
            return await resp.json()

    @api_function
    async def stream_app_info(self):
        """
        Fetches ``/stream/<prefix>/<id-or-name>/apps`` and returns the decoded
        JSON body.
        """
        params = {}
        if self.owner_access_key:
            params['owner_access_key'] = self.owner_access_key
        prefix = get_naming(api_session.get().api_version, 'path')
        id_or_name = get_id_or_name(api_session.get().api_version, self)
        api_rqst = Request(
            'GET', f'/stream/{prefix}/{id_or_name}/apps',
            params=params,
        )
        async with api_rqst.fetch() as resp:
            return await resp.json()

    # only supported in AsyncAPISession
    def listen_events(self, scope: Literal['*', 'session', 'kernel'] = '*') -> SSEContextManager:
        """
        Opens the stream of the kernel lifecycle events.
        Only the master kernel of each session is monitored.

        :param scope: The event scope (``"*"``, ``"session"`` or ``"kernel"``);
            only sent on API v6+.
        :returns: an :class:`SSEContextManager` for the server-sent event stream.
        """
        if api_session.get().api_version[0] >= 6:
            request = Request(
                'GET', '/events/session',
                params={
                    **self.get_session_identity_params(),
                    'scope': scope,
                }
            )
        else:
            assert self.name is not None
            params = {
                get_naming(api_session.get().api_version, 'event_name_arg'): self.name,
            }
            if self.owner_access_key:
                params['owner_access_key'] = self.owner_access_key
            path = get_naming(api_session.get().api_version, 'session_events_path')
            request = Request(
                'GET', path,
                params=params,
            )
        return request.connect_events()

    stream_events = listen_events  # legacy alias

    # only supported in AsyncAPISession
    def stream_pty(self) -> WebSocketContextManager:
        """
        Opens a pseudo-terminal of the kernel (if supported) streamed via
        websockets.

        :returns: a websocket context manager whose response object is a
            :class:`StreamPty`.
        """
        params = {}
        if self.owner_access_key:
            params['owner_access_key'] = self.owner_access_key
        prefix = get_naming(api_session.get().api_version, 'path')
        id_or_name = get_id_or_name(api_session.get().api_version, self)
        request = Request(
            'GET', f'/stream/{prefix}/{id_or_name}/pty',
            params=params,
        )
        return request.connect_websocket(response_cls=StreamPty)

    # only supported in AsyncAPISession
    def stream_execute(self, code: str = '', *,
                       mode: str = 'query',
                       opts: dict = None) -> WebSocketContextManager:
        """
        Executes a code snippet in the streaming mode.
        Since the returned websocket represents a run loop, there is no need to
        specify *run_id* explicitly.

        :param code: The code snippet sent right after the websocket is opened.
        :param mode: Either ``"query"`` or ``"batch"``.
        :param opts: Batch-mode options (``clean``, ``build``, ``buildLog``,
            ``exec``); ignored in the query mode.
        :raises BackendClientError: If *mode* is not supported.
        """
        params = {}
        if self.owner_access_key:
            params['owner_access_key'] = self.owner_access_key
        prefix = get_naming(api_session.get().api_version, 'path')
        id_or_name = get_id_or_name(api_session.get().api_version, self)
        opts = {} if opts is None else opts
        if mode == 'query':
            opts = {}
        elif mode == 'batch':
            opts = {
                'clean': opts.get('clean', None),
                'build': opts.get('build', None),
                'buildLog': bool(opts.get('buildLog', False)),
                'exec': opts.get('exec', None),
            }
        else:
            msg = 'Invalid stream-execution mode: {0}'.format(mode)
            raise BackendClientError(msg)
        request = Request(
            'GET', f'/stream/{prefix}/{id_or_name}/execute',
            params=params,
        )

        async def send_code(ws):
            await ws.send_json({
                'code': code,
                'mode': mode,
                'options': opts,
            })

        return request.connect_websocket(on_enter=send_code)
class StreamPty(WebSocketResponse):
    """
    A :class:`~ai.backend.client.request.WebSocketResponse` subclass returned for
    pseudo-terminal streams; it adds helpers that send JSON control messages
    over the underlying websocket.
    """

    __slots__ = ('ws', )

    async def resize(self, rows, cols):
        """Sends a ``resize`` control message carrying the new *rows* and *cols*."""
        message = {'type': 'resize', 'rows': rows, 'cols': cols}
        await self.ws.send_str(json.dumps(message))

    async def restart(self):
        """Sends a ``restart`` control message."""
        message = {'type': 'restart'}
        await self.ws.send_str(json.dumps(message))