# A collection of tools to interface with manually traced and autosegmented
# data in FAFB.
#
# Copyright (C) 2019 Philipp Schlegel
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
import functools
import json
import navis
import os
import pytz
import time
import requests
import warnings
from caveclient import CAVEclient
from pathlib import Path
from importlib import reload
from zipfile import ZipFile
from io import BytesIO
import cloudvolume as cv
import datetime as dt
import trimesh as tm
import numpy as np
import pandas as pd
from .. import utils
# Names exported by a star-import of this module
__all__ = [
    'set_chunkedgraph_secret',
    'get_chunkedgraph_secret',
    'get_cave_client',
    'get_neuropil_volumes',
    'get_lr_position',
    'set_default_dataset',
    'find_mat_version',
]
# Dataset alias -> segmentation name (matches the last path component of the
# corresponding graphene URL below)
FLYWIRE_DATASETS = {
    'production': 'fly_v31',
    'sandbox': 'fly_v26',
    'public': 'flywire_public',
}

# Dataset alias -> source URL handed to CloudVolume (see `get_cloudvolume`)
# NOTE(review): 'flat_571' points at a bucket ending in "m526" - confirm that
# this mismatch is intentional.
FLYWIRE_URLS = {
    'production': 'graphene://https://prod.flywire-daf.com/segmentation/1.0/fly_v31',
    'sandbox': 'graphene://https://prod.flywire-daf.com/segmentation/1.0/fly_v26',
    'public': 'graphene://https://prodv1.flywire-daf.com/segmentation/1.0/flywire_public',
    'flat_630': 'precomputed://gs://flywire_v141_m630',
    'flat_571': 'precomputed://gs://flywire_v141_m526',
    'flat_783': 'precomputed://gs://flywire_v141_m783',
}

# Dataset alias -> CAVE datastack name (see `get_cave_client`)
CAVE_DATASETS = {
    'production': 'flywire_fafb_production',
    'flat_783': 'flywire_fafb_production',
    'flat_630': 'flywire_fafb_public',
    'flat_571': 'flywire_fafb_production',
    'sandbox': 'flywire_fafb_sandbox',
    'public': 'flywire_fafb_public',
}

# While True, `find_mat_version` does not print its status messages; toggled
# by the `silence_find_mat_version` context manager
SILENCE_FIND_MAT_VERSION = False
# Session-wide caches, filled on demand: CloudVolume objects (keyed by source
# URL) and CAVE clients (keyed by datastack name)
cloud_volumes = dict()
cave_clients = dict()

# Location of the data files shipped alongside this module's parent package
fp = Path(__file__).parent
data_path = fp.parent.joinpath('data')

# Populated on first use by `get_synapse_areas`
area_ids = None
vol_names = None

# Dataset used when a function is called without an explicit `dataset`; can be
# preset through the FLYWIRE_DEFAULT_DATASET environment variable
DEFAULT_DATASET = os.getenv('FLYWIRE_DEFAULT_DATASET', 'public')

# Groups of types treated as the same broad kind by `match_dtype`
INT_DTYPES = (np.int32, np.int64, int, np.uint32, np.uint64)
FLOAT_DTYPES = (np.float32, np.float64, float)
STR_DTYPES = (str, np.str_)
def match_dtype(x, target_dt):
    """Cast input to the same broad data type as a target.

    Only the broad kind (integer, float or string) is matched, not e.g. the
    exact precision.

    Parameters
    ----------
    x
                Input to be converted.
    target_dt
                The target data type. Compared against the module-level
                ``INT_DTYPES``, ``FLOAT_DTYPES`` and ``STR_DTYPES`` tuples.

    Returns
    -------
    x :
                Input with matching dtype. Lists and tuples will be converted
                to numpy arrays. If `target_dt` is in none of the three
                groups the value is returned without a cast.

    """
    is_sequence = isinstance(x, (list, tuple, np.ndarray))
    if is_sequence:
        x = np.asarray(x)

    # Arrays are cast with `.astype`, scalars with the matching constructor
    if target_dt in INT_DTYPES:
        return x.astype(np.int64) if is_sequence else np.int64(x)
    if target_dt in FLOAT_DTYPES:
        return x.astype(np.float64) if is_sequence else float(x)
    if target_dt in STR_DTYPES:
        return x.astype(str) if is_sequence else str(x)

    return x
[docs]
def set_default_dataset(dataset):
    """Set the default FlyWire dataset for this session.

    Alternatively, you can also use a FLYWIRE_DEFAULT_DATASET environment
    variable (must be set before starting Python).

    Parameters
    ----------
    dataset :   str
                Dataset to be used by default. Must be either one of the
                keys of ``FLYWIRE_URLS`` (e.g. "production", "public",
                "sandbox" or "flat_630") or the name of a CAVE datastack.

    Raises
    ------
    ValueError
                If `dataset` is not a known dataset.

    Examples
    --------
    >>> from fafbseg import flywire
    >>> flywire.set_default_dataset('public')
    Default dataset set to "public".

    """
    global DEFAULT_DATASET

    # The CAVE datastacks are only queried if the name is not a local alias
    is_known = dataset in FLYWIRE_URLS or dataset in get_cave_datastacks()
    if not is_known:
        datasets = np.unique(list(FLYWIRE_URLS) + get_cave_datastacks())
        raise ValueError(f'`dataset` must be one of: {", ".join(datasets)}.')

    DEFAULT_DATASET = dataset
    print(f'Default dataset set to "{dataset}".')
def inject_dataset(allowed=None, disallowed=None):
    """Decorator factory that injects and validates the `dataset` argument.

    The decorated function is called with ``dataset`` set to the current
    ``DEFAULT_DATASET`` whenever the caller did not provide one (or passed
    ``None``). The dataset may be passed by keyword or - if the signature of
    the decorated function permits it - positionally.

    Parameters
    ----------
    allowed :    str | list of str, optional
                 If given, only these datasets are accepted.
    disallowed : str | list of str, optional
                 If given, these datasets are rejected.

    Returns
    -------
    callable
                 The actual decorator.

    Raises
    ------
    ValueError
                 Raised by the decorated function at call time if the
                 dataset is not in `allowed` or is in `disallowed`.

    """
    # Local import: `inspect` is only needed by this decorator
    import inspect

    if isinstance(allowed, str):
        allowed = [allowed]
    if isinstance(disallowed, str):
        disallowed = [disallowed]

    def outer(func):
        # Find the index at which `dataset` can be passed positionally (if at
        # all). Without this a positional dataset would clash with the
        # injected `dataset=` keyword and bypass the validation below.
        dataset_pos = None
        try:
            params = inspect.signature(func).parameters.values()
        except (TypeError, ValueError):
            # No introspectable signature: fall back to keyword-only handling
            params = ()
        for i, param in enumerate(params):
            if param.kind not in (param.POSITIONAL_ONLY,
                                  param.POSITIONAL_OR_KEYWORD):
                break
            if param.name == 'dataset':
                dataset_pos = i
                break

        @functools.wraps(func)
        def inner(*args, **kwargs):
            if dataset_pos is not None and len(args) > dataset_pos:
                # Dataset was provided positionally
                ds = args[dataset_pos]
                if ds is None:
                    ds = DEFAULT_DATASET
                    args = args[:dataset_pos] + (ds, ) + args[dataset_pos + 1:]
            else:
                if kwargs.get('dataset', None) is None:
                    kwargs['dataset'] = DEFAULT_DATASET
                ds = kwargs['dataset']

            if allowed and ds not in allowed:
                raise ValueError(f'Dataset "{ds}" not allowed for function {func}. '
                                 f'Accepted datasets: {allowed}')
            if disallowed and ds in disallowed:
                raise ValueError(f'Dataset "{ds}" not allowed for function {func}.')
            return func(*args, **kwargs)
        return inner
    return outer
[docs]
def get_neuropil_volumes(neuropils):
    """Load FlyWire neuropil volumes.

    These meshes were originally created for the JFRC2 brain template
    (for citation and details see 10.5281/zenodo.10567). Here, we transformed
    them to FlyWire (FAFB14.1) space.

    Parameters
    ----------
    neuropils : str | list thereof | None
                Neuropil name(s) - e.g. 'LH_R' or ['LH_R', 'LH_L']. Use
                ``None`` to get an array of available neuropils.

    Returns
    -------
    meshes :    single navis.Volume or list thereof
                If `neuropils` is ``None`` (or otherwise falsy) a sorted
                array with the names of all available neuropils is returned
                instead.

    Raises
    ------
    ValueError
                If there is no mesh for the requested neuropil.

    Examples
    --------
    Load a single volume:

    >>> from fafbseg import flywire
    >>> al_r = flywire.get_neuropil_volumes('AL_R')
    >>> al_r
    <navis.Volume(name=AL_R, color=(0.85, 0.85, 0.85, 0.2), vertices.shape=(622, 3), faces.shape=(1240, 3))>

    Load multiple volumes:

    >>> from fafbseg import flywire
    >>> al_lr = flywire.get_neuropil_volumes(['AL_R', 'AL_L'])
    >>> al_lr
    [<navis.Volume(name=AL_R, color=(0.85, 0.85, 0.85, 0.2), vertices.shape=(622, 3), faces.shape=(1240, 3))>,
    <navis.Volume(name=AL_L, color=(0.85, 0.85, 0.85, 0.2), vertices.shape=(612, 3), faces.shape=(1228, 3))>]

    Get a list of available volumes:

    >>> from fafbseg import flywire
    >>> available = flywire.get_neuropil_volumes(None)

    """
    if navis.utils.is_iterable(neuropils):
        return [get_neuropil_volumes(n) for n in neuropils]

    with ZipFile(data_path / 'JFRC2NP.surf.fw.zip', 'r') as archive:
        try:
            # Read the mesh exactly once; a missing member raises KeyError
            raw = archive.read(f'{neuropils}.stl')
        except KeyError:
            # Collect the names of all meshes in the archive (ignoring
            # hidden files and anything that is not an STL)
            filenames = (info.filename.split('/')[-1] for info in archive.filelist)
            available = sorted(fname.replace('.stl', '')
                               for fname in filenames
                               if fname.endswith('.stl') and not fname.startswith('.'))

            if neuropils:
                raise ValueError(f'No mesh for neuropil "{neuropils}". Available '
                                 f'neuropils: {", ".join(available)}')
            return np.array(available)

    # The raw bytes are already in memory - the archive is no longer needed
    m = tm.load_mesh(BytesIO(raw), file_type='stl')

    return navis.Volume(m, name=neuropils)
def get_synapse_areas(ind):
    """Lazy-load synapse areas (neuropils).

    On first use this loads the area ID array and the volume name lookup
    from the package's data directory and caches both in the module-level
    ``area_ids`` and ``vol_names``.

    Parameters
    ----------
    ind :       (N, ) iterable
                Synapse indices (shows up as `id` in synapse table).

    Returns
    -------
    areas :     (N, ) array
                Array with neuropil name for each synapse. Unassigned synapses
                come back as "NA".

    """
    global area_ids, vol_names

    if area_ids is None or vol_names is None:
        # Load into locals first and only publish once both files were read
        # successfully - otherwise a failure half-way through would leave the
        # module in a state where `area_ids` is set but `vol_names` is not.
        with np.load(data_path / 'global_area_ids.npy.zip') as archive:
            ids = archive['global_area_ids']

        with open(data_path / 'volume_name_dict.json') as f:
            names = {int(k): v for k, v in json.load(f).items()}
        # Area ID -1 marks synapses without a neuropil
        names[-1] = 'NA'

        area_ids, vol_names = ids, names

    return np.array([vol_names[i] for i in area_ids[ind]])
@functools.lru_cache
def get_cave_datastacks():
    """Return the datastacks listed by the CAVE info service.

    The result is memoized via ``functools.lru_cache``, so the service is
    only queried on the first call.
    """
    info_service = CAVEclient().info
    return info_service.get_datastacks()
@functools.lru_cache
def get_datastack_segmentation_source(datastack):
    """Return the ``segmentation_source`` entry for a given CAVE datastack.

    Results are memoized per datastack name via ``functools.lru_cache``.
    """
    datastack_info = CAVEclient().info.get_datastack_info(datastack_name=datastack)
    return datastack_info['segmentation_source']
@inject_dataset()
def get_cave_client(*, dataset=None, token=None, check_stale=True,
                    force_new=False):
    """Get CAVE client.

    Clients are cached per datastack. Currently, the CAVE client pulls the
    available materialization versions ONCE on initialization. This means
    that if the same client is used for over 24h it will be unaware of any
    new materialization versions which will slow down live queries
    substantially. We try to detect whether the client may have gone stale
    but this may not always work perfectly.

    Parameters
    ----------
    dataset :       str
                    Data set to create client for. Either one of the keys
                    of ``CAVE_DATASETS`` or the name of a CAVE datastack.
    token :         str, optional
                    Your chunked graph secret (i.e. "CAVE secret"). If not
                    provided will try reading via cloud-volume.
    check_stale :   bool
                    Check if any existing client has gone stale. A cached
                    client is replaced if any of its (cached) materialization
                    versions has expired. On Mondays (UTC) it is also
                    replaced if it is older than 30 minutes.
    force_new :     bool
                    If True, we force a re-initialization.

    Returns
    -------
    CAVEclient

    """
    if not token:
        token = get_chunkedgraph_secret()

    datastack = CAVE_DATASETS.get(dataset, dataset)

    if datastack in cave_clients and not force_new and check_stale:
        # Get the existing client
        client = cave_clients[datastack]
        # Get the (likely cached) materialization meta data
        mds = client.materialize.get_versions_metadata()
        # Timezone-aware "now" in UTC (`datetime.utcnow()` is deprecated)
        now = dt.datetime.now(dt.timezone.utc)
        # Check if any of the versions are expired
        force_new = any(v['expires_on'] <= now for v in mds)

        # Over the weekend no new versions are materialized. The last version
        # from Friday will persist into middle of the next week - i.e. not
        # expire on Monday. Therefore, on Mondays only, we will also
        # force an update if the client is older than 30 minutes
        if now.weekday() == 0 and not force_new:
            # `_created_at` is a naive local timestamp (set below), hence the
            # naive `now()` here
            client_age = dt.datetime.now() - client._created_at
            if client_age > dt.timedelta(minutes=30):
                force_new = True

    if datastack not in cave_clients or force_new:
        cave_clients[datastack] = CAVEclient(datastack, auth_token=token)
        cave_clients[datastack]._created_at = dt.datetime.now()

    # The public datastack configuration currently does not set the .synapse_table
    # That's intentional to avoid people using it - they are supposed to use the filtered view
    # However, we want to enable our users to do that if they want, so we will add it back
    client = cave_clients[datastack]
    if client.materialize.synapse_table is None:
        if "synapses_nt_v1" in client.materialize.get_tables():
            client.materialize.synapse_table = "synapses_nt_v1"

    return client
def get_chunkedgraph_secret(domain=('global.daf-apis.com', 'prod.flywire-daf.com')):
    """Get local FlyWire chunkedgraph/CAVE secret.

    Parameters
    ----------
    domain :    str | list thereof
                Domain(s) to get the secret for. Domains are tried in order
                and the first token found is returned.

    Returns
    -------
    token :     str

    Raises
    ------
    ValueError
                If no token is stored for any of the given domains.

    """
    if isinstance(domain, str):
        domain = [domain]

    for dom in domain:
        credentials = cv.secrets.cave_credentials(dom)
        token = credentials.get('token', None)
        if token:
            return token

    raise ValueError(f'No chunkedgraph/CAVE secret for domain(s) {domain} '
                     'found. Please see fafbseg.flywire.set_chunkedgraph_secret '
                     'to store your API token.')
def set_chunkedgraph_secret(token, overwrite=False, **kwargs):
    """Set FlyWire chunkedgraph/CAVE secret.

    This is just a thin wrapper around ``caveclient.CAVEclient.auth.save_token()``.
    After saving, cloudvolume is reloaded and the cached CloudVolumes and
    CAVE clients of this module are dropped so that subsequent calls pick up
    the new secret.

    Parameters
    ----------
    token :     str
                Get your token from
                https://global.daf-apis.com/auth/api/v1/user/token. If that URL
                returns an empty list ``[]`` you should visit
                https://global.daf-apis.com/auth/api/v1/create_token instead.
    overwrite : bool
                Whether to overwrite any existing secret.
    **kwargs
                Keyword arguments are passed through to
                ``caveclient.CAVEclient.auth.save_token()``.

    Raises
    ------
    TypeError
                If `token` is not a string.

    """
    global fw_vol

    # Explicit check instead of `assert`, which is stripped under `python -O`
    if not isinstance(token, str):
        raise TypeError(f'Token must be string, got "{type(token)}"')

    # Save token
    CAVEclient().auth.save_token(token, overwrite=overwrite, **kwargs)

    # We need to reload cloudvolume for changes to take effect
    reload(cv.secrets)
    reload(cv)

    # Volumes and clients created before the secret changed must not be
    # re-used: drop them so they get re-initialized on next request
    cloud_volumes.clear()
    cave_clients.clear()

    # Legacy global - not used in this part of the module but still reset in
    # case other code looks at it
    fw_vol = None

    print("Token succesfully stored.")
def parse_root_ids(x):
    """Parse root IDs.

    Accepts a single navis neuron, a NeuronList, a single integer or
    anything ``utils.make_iterable`` can handle.

    Always returns an array of integers.

    Raises
    ------
    ValueError
                If the IDs cannot be converted to 64-bit integers.

    """
    if isinstance(x, navis.BaseNeuron):
        ids = [x.id]
    elif isinstance(x, navis.NeuronList):
        ids = x.id
    elif isinstance(x, (int, np.int64)):
        ids = [x]
    else:
        ids = utils.make_iterable(x, force_type=np.int64)

    # Make sure we are working with proper numerical IDs
    try:
        parsed = np.asarray(ids, dtype=np.int64)
    except ValueError:
        raise ValueError(f'Unable to convert given root IDs to integer: {ids}')

    return parsed
def get_cloudvolume(dataset, **kwargs):
    """Get CloudVolume for given dataset.

    Parameters
    ----------
    dataset :   str | CloudVolume
                Either a dataset alias (key of ``FLYWIRE_URLS``), the name
                of a CAVE datastack or a segmentation URL. CloudVolume
                objects are passed through unchanged.
    **kwargs
                Override the defaults used to initialize the CloudVolume.
                Note that volumes are cached by source URL: keyword
                arguments only take effect when a volume is first created.

    Returns
    -------
    CloudVolume

    Raises
    ------
    ValueError
                If `dataset` is neither a string nor a CloudVolume.

    """
    # If this already is a CloudVolume just pass it through
    if "CloudVolume" in str(type(dataset)):
        return dataset

    if not isinstance(dataset, str):
        raise ValueError(f'Unable to initialize CloudVolume from "{type(dataset)}"')

    # Translate aliases into a URL: first via our own lookup table, then via
    # CAVE. Anything else is assumed to already be a segmentation URL.
    if not utils.is_url(dataset):
        if dataset in FLYWIRE_URLS:
            dataset = FLYWIRE_URLS[dataset]
        elif dataset in get_cave_datastacks():
            dataset = get_datastack_segmentation_source(dataset)

    # Re-use the cached volume if we have one
    if dataset in cloud_volumes:
        return cloud_volumes[dataset]

    # Defaults, updated with whatever the caller passed
    settings = {'mip': 0,
                'fill_missing': True,
                'cache': False,
                'use_https': True,  # this way google secret is not needed
                'progress': False}
    settings.update(kwargs)

    cloud_volumes[dataset] = cv.CloudVolume(dataset, **settings)
    cloud_volumes[dataset].path = dataset

    return cloud_volumes[dataset]
def retry(func, retries=5, cooldown=2):
    """Retry function on HTTPError.

    Only ``requests.RequestException`` triggers a retry; any other exception
    propagates immediately. This also suppresses warnings while the function
    runs (because we typically use this for stuff like the l2 Cache).

    Parameters
    ----------
    func :      callable
                Function to wrap.
    retries :   int
                Total number of attempts before we give up. The function is
                always attempted at least once, even if `retries` is less
                than 1.
    cooldown :  int | float
                Cooldown period in seconds between attempts. Every subsequent
                retry will delay by an additional `cooldown`.

    Returns
    -------
    callable
                Wrapped function.

    """
    # Guarantee at least one attempt: with `retries <= 0` the loop below would
    # otherwise never call `func` and silently return None
    attempts = max(1, retries)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        for attempt in range(1, attempts + 1):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                try:
                    return func(*args, **kwargs)
                except requests.RequestException:
                    # Out of attempts: hand the error to the caller
                    if attempt >= attempts:
                        raise
            # Linearly increasing back-off before the next attempt
            time.sleep(cooldown * attempt)
    return wrapper
def parse_bounds(x):
    """Parse bounds.

    Parameters
    ----------
    x :         (3, 2) array | (2, 3) array | None
                Bounding box. ``None`` is passed through.

    Returns
    -------
    bounds :    (3, 2) np.array
                One row per axis with the smaller value in the first and
                the larger value in the second column.

    Raises
    ------
    ValueError
                If `x` is neither of shape (3, 2) nor (2, 3).

    """
    if x is None:
        return x

    x = np.asarray(x)

    if x.shape == (2, 3):
        # Bring into (3, 2) orientation
        x = x.T
    elif x.shape != (3, 2):
        raise ValueError('Must provide bounding box as (3, 2) or (2, 3) array, '
                         f'got {x.shape}')

    # Sort each row so that it reads (min, max)
    lower = x.min(axis=1)
    upper = x.max(axis=1)
    return np.vstack((lower, upper)).T
[docs]
def get_lr_position(x, coordinates='nm'):
    """Find out if given xyz positions are on the fly's left or right.

    This works by:
     1. Mirroring positions from one side to the other (requires `flybrains`)
     2. Subtracting the original from the mirrored x-coordinate

    Parameters
    ----------
    x :             (N, 3) array | DataFrame | TreeNeuron | MeshNeuron | Dotprops
                    Array of xyz coordinates or a neuron. If a navis neuron,
                    will use nodes, vertex or point coordinates for TreeNeurons,
                    MeshNeurons and Dotprops, respectively. DataFrames must
                    either have exactly three columns or contain `x`, `y`
                    and `z` columns.
    coordinates :   "nm" | "voxel"
                    Whether coordinates are in nm or voxel space. Voxel
                    coordinates are scaled by 4 x 4 x 40.

    Returns
    -------
    xm :            (N, ) array
                    A vector of point displacements in nanometers where 0 is
                    at the midline and positive values are to the fly's right.

    Raises
    ------
    ImportError
                    If `flybrains` is missing or does not provide the
                    FLYWIRE template.
    TypeError
                    If `x` cannot be turned into an (N, 3) array.

    Examples
    --------
    >>> from fafbseg import flywire
    >>> # Three example points: right, left, ~center
    >>> flywire.get_lr_position([[104904, 47464, 5461],
    ...                          [140648, 49064, 2262],
    ...                          [131256, 29984, 2358]],
    ...                         coordinates='voxel')
    array([110501.5, -39480. ,    306.5])

    """
    # Importing flybrains is done for its side effect only
    try:
        import flybrains  # noqa: F401
    except ImportError:
        raise ImportError('This function requires `flybrains` to be '
                          'installed:\n pip3 install flybrains')

    # The FlyWire mirror registration is only part of the most recent version
    try:
        navis.transforms.registry.find_template('FLYWIRE')
    except ValueError:
        raise ImportError('Looks like your version of `flybrains` is outdated. '
                          'Please update:\n pip3 install flybrains -U')

    navis.utils.eval_param(coordinates, name='coordinates',
                           allowed_values=('nm', 'nanometer', 'nanometers',
                                           'voxel', 'voxels'))

    # Extract the coordinates from whatever we were given
    if navis.utils.is_iterable(x):
        x = np.asarray(x)
    elif isinstance(x, pd.DataFrame):
        if x.shape[1] == 3:
            x = x.values
        elif all(col in x.columns for col in ('x', 'y', 'z')):
            x = x[['x', 'y', 'z']].values
    elif isinstance(x, navis.TreeNeuron):
        x = x.nodes[['x', 'y', 'z']].values
    elif isinstance(x, navis.MeshNeuron):
        x = x.vertices
    elif isinstance(x, navis.Dotprops):
        x = x.points

    if not isinstance(x, np.ndarray):
        raise TypeError(f'Expected numpy array or neuron, got "{type(x)}"')
    if x.ndim != 2 or x.shape[1] != 3:
        raise TypeError(f'Expected (N, 3) numpy array, got {x.shape}')

    # Scale if required
    if coordinates in ('voxel', 'voxels'):
        voxel_size = [4, 4, 40]
        x = x * voxel_size

    # Mirror -> this should be using the landmark-based transform in flybrains
    mirrored = navis.mirror_brain(x, template='FLYWIRE')

    # Half the distance between a point and its mirror image
    offset = mirrored[:, 0] - x[:, 0]
    return offset / 2
[docs]
@inject_dataset()
def find_mat_version(ids,
                     verbose=True,
                     allow_multiple=False,
                     raise_missing=True,
                     dataset=None):
    """Find a materialization version (or live) for given IDs.

    Parameters
    ----------
    ids :           iterable
                    Root IDs to check.
    verbose :       bool
                    Whether to print results of search. See also the
                    `flywire.utils.silence_find_mat_version` context manager to
                    silence output.
    allow_multiple : bool
                    If True, will track if IDs can be found spread across multiple
                    materialization versions if there is no single one containing
                    all.
    raise_missing : bool
                    Only relevant if `allow_multiple=True`. If False, will return
                    versions even if some IDs could not be found.
    dataset :       str, optional
                    Dataset to search. Defaults to the current default
                    dataset.

    Returns
    -------
    version :       int | "live"
                    A single version (including "live") that contains all given
                    root IDs.
    versions :      np.ndarray
                    If no single version was found and `allow_multiple=True` will
                    return a vector of `len(ids)` with the latest version at which
                    the respective ID can be found.
                    Important: "live" version will be returned as -1!
                    If `raise_missing=False` and one or more root IDs could not
                    be found in any of the available materialization versions
                    these IDs will be returned as version 0.

    Raises
    ------
    MaterializationMatchError
                    If no suitable version(s) could be found.

    """
    # If dataset is the flat segmentation we can take a shortcut
    if dataset == 'flat_630':
        return 630
    elif dataset == 'flat_571':
        return 571
    # NOTE(review): there is no shortcut for 'flat_783' even though it is a
    # known dataset alias - confirm whether that is intentional.

    ids = np.asarray(ids)

    client = get_cave_client(dataset=dataset)

    # Whether to print anything at all
    verbose = verbose and not SILENCE_FIND_MAT_VERSION

    # For each ID track the most recent valid version
    latest_valid = np.zeros(len(ids), dtype=np.int32)

    # Go over each version (start with the most recent)
    for version in sorted(client.materialize.get_versions(), reverse=True):
        ts_m = client.materialize.get_timestamp(version)

        # Check which root IDs were valid at the time
        is_valid = client.chunkedgraph.is_latest_roots(ids, timestamp=ts_m)

        # Update latest valid versions
        latest_valid[(latest_valid == 0) & is_valid] = version

        if all(is_valid):
            if verbose:
                print(f'Using materialization version {version}.')
            return version

    # If no single materialized version can be found, see if we can get
    # by with the live materialization
    is_latest = client.chunkedgraph.is_latest_roots(ids, timestamp=None)
    latest_valid[(latest_valid == 0) & is_latest] = -1  # track "live" as -1
    if all(is_latest) and dataset != 'public':  # public does not have live
        if verbose:
            print('Using live materialization')
        return 'live'

    if allow_multiple and any(latest_valid != 0):
        if all(latest_valid != 0):
            if verbose:
                print(f"Found root IDs spread across {len(np.unique(latest_valid))} "
                      "materialization versions.")
            return latest_valid

        # The `- 1` accounts for the 0 that marks IDs we could not find
        msg = (f"Found root IDs spread across {len(np.unique(latest_valid)) - 1} "
               f"materialization versions but {(latest_valid == 0).sum()} IDs "
               "do not exist in any of the materialized tables.")

        if not raise_missing:
            if verbose:
                print(msg)
            return latest_valid
        else:
            raise MaterializationMatchError(msg)

    # Plain comparison: the previous `not in ('public, ')` was a substring
    # test against a string, not a membership test against a tuple
    if dataset != 'public':
        raise MaterializationMatchError(
            'Given root IDs do not (co-)exist in any of the available '
            'materialization versions (including live). Try updating '
            'root IDs and rerun your query.')
    else:
        raise MaterializationMatchError(
            'Given root IDs do not (co-)exist in any of the available '
            'public materialization versions. Please make sure that '
            'the root IDs do exist and rerun your query.')
def _is_valid_version(ids, version, dataset):
    """Test if materialization version is valid for given root IDs.

    Returns True only if `version` exists for `dataset` and every ID was a
    latest root at that version's timestamp.
    """
    client = get_cave_client(dataset=dataset)

    # If this is not even a valid version (for this dataset) return False
    known_versions = client.materialize.get_versions()
    if version not in known_versions:
        return False

    # Check which root IDs were valid at the time of materialization
    timestamp = client.materialize.get_timestamp(version)
    is_valid = client.chunkedgraph.is_latest_roots(ids, timestamp=timestamp)

    return all(is_valid)
def package_timestamp(timestamp, name="timestamp"):
    """Turn a timestamp into a query dictionary.

    Parameters
    ----------
    timestamp : datetime.datetime | None
                Naive timestamps are interpreted as UTC; timezone-aware ones
                are converted to UTC.
    name :      str
                Key under which to store the timestamp.

    Returns
    -------
    dict
                ``{name: <POSIX timestamp>}`` or an empty dictionary if
                `timestamp` is None.

    """
    # Copied from caveclient
    if timestamp is None:
        return {}

    if timestamp.tzinfo is None:
        utc_time = pytz.UTC.localize(timestamp)
    else:
        utc_time = timestamp.astimezone(dt.timezone.utc)

    return {name: utc_time.timestamp()}
class silence_find_mat_version:
    """Context manager that silences the output of `find_mat_version`.

    While active, the module-level ``SILENCE_FIND_MAT_VERSION`` flag is set
    to True. On exit the flag is restored to whatever it was on entry, so
    nested use does not switch the output back on prematurely.
    """

    def __init__(self):
        # Values to restore on exit; a stack so that the same instance can be
        # entered more than once
        self._previous = []

    def __enter__(self):
        global SILENCE_FIND_MAT_VERSION
        self._previous.append(SILENCE_FIND_MAT_VERSION)
        SILENCE_FIND_MAT_VERSION = True
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        global SILENCE_FIND_MAT_VERSION
        # Restore instead of unconditionally resetting to False
        SILENCE_FIND_MAT_VERSION = self._previous.pop()
class MaterializationMatchError(Exception):
    """Raised by `find_mat_version` if no matching materialization is found."""