# Source code for fafbseg.flywire.l2

#    A collection of tools to interface with manually traced and autosegmented
#    data in FAFB.
#
#    Copyright (C) 2019 Philipp Schlegel
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#    GNU General Public License for more details.

"""Functions to extract skeletons from L2 graphs.

Heavily borrows from code from Casey Schneider-Mizell's "pcg_skel"
(https://github.com/AllenInstitute/pcg_skel).

"""

import navis
import fastremap

import networkx as nx
import numpy as np
import pandas as pd
import skeletor as sk
import trimesh as tm

from concurrent.futures import ThreadPoolExecutor
from functools import partial

from .annotations import parse_neuroncriteria
from .utils import get_cloudvolume, get_cave_client, retry, inject_dataset

__all__ = ['get_l2_skeleton', 'get_l2_dotprops', 'get_l2_graph', 'get_l2_info',
           'find_anchor_loc']


@parse_neuroncriteria()
@inject_dataset()
def get_l2_info(root_ids, progress=True, max_threads=4, *, dataset=None):
    """Fetch basic info for given neuron(s) using the L2 cache.

    Parameters
    ----------
    root_ids :      int | list of ints | NeuronCriteria
                    FlyWire root ID(s) for which to fetch L2 infos.
    progress :      bool
                    Whether to show a progress bar.
    max_threads :   int
                    Number of parallel requests to make.
    dataset :       "public" | "production" | "sandbox" | "flat_630", optional
                    Against which FlyWire dataset to query. If ``None`` will
                    fall back to the default dataset (see
                    :func:`~fafbseg.flywire.set_default_dataset`).

    Returns
    -------
    pandas.DataFrame
                    DataFrame with basic info (also see Examples):
                      - `length_um` is the sum of the max diameter across all
                        L2 chunks; note that this severely underestimates the
                        actual length (factor >10) but is still useful for
                        relative comparisons
                      - `bounds_nm` is a very rough bounding box based on the
                        representative coordinates of the L2 chunks
                      - `chunks_missing` is the number of L2 chunks not
                        present in the L2 cache

    Examples
    --------
    >>> from fafbseg import flywire
    >>> info = flywire.get_l2_info(720575940614131061)
    >>> info                                            # doctest: +SKIP
                  root_id  l2_chunks  chunks_missing    area_um2    size_um3  length_um ...
    0  720575940614131061        286               0  2378.16384  163.876526     60.666 ...

    """
    # For multiple root IDs: recurse (one thread per ID) and concatenate
    if navis.utils.is_iterable(root_ids):
        root_ids = np.unique(root_ids)
        info = []
        with ThreadPoolExecutor(max_workers=max_threads) as pool:
            func = partial(get_l2_info, dataset=dataset)
            futures = pool.map(func, root_ids)
            info = [
                f
                for f in navis.config.tqdm(
                    futures,
                    desc="Fetching L2 info",
                    total=len(root_ids),
                    disable=not progress or len(root_ids) == 1,
                    leave=False,
                )
            ]
        return pd.concat(info, axis=0).reset_index(drop=True)

    # Get/Initialize the CAVE client
    client = get_cave_client(dataset=dataset)

    # Fetch the IDs of this neuron's L2 chunks (with automatic retries)
    get_l2_ids = partial(retry(client.chunkedgraph.get_leaves), stop_layer=2)
    l2_ids = get_l2_ids(root_ids)

    # Query the L2 cache for per-chunk attributes
    attributes = ["area_nm2", "size_nm3", "max_dt_nm", "rep_coord_nm"]
    get_l2data = retry(client.l2cache.get_l2data)
    info = get_l2data(l2_ids.tolist(), attributes=attributes)
    # Chunks missing from the cache come back as empty dicts
    n_miss = len([v for v in info.values() if not v])

    row = [root_ids, len(l2_ids), n_miss]
    info_df = pd.DataFrame([row], columns=["root_id", "l2_chunks", "chunks_missing"])

    # Collect L2 attributes: sum each over all chunks and convert nm -> um
    # based on the unit's dimensionality (encoded in the attribute suffix)
    for at in attributes:
        if at in ("rep_coord_nm",):
            continue
        summed = sum([v.get(at, 0) for v in info.values()])
        if at.endswith("3"):
            summed /= 1000**3
        elif at.endswith("2"):
            summed /= 1000**2
        else:
            summed /= 1000
        info_df[at.replace("_nm", "_um")] = [summed]

    # Check bounding box (based on the chunks' representative coordinates)
    pts = np.array([v["rep_coord_nm"] for v in info.values() if v])
    if len(pts) > 1:
        # Interleave per-axis min/max: [x0, x1, y0, y1, z0, z1]
        bounds = [v for l in zip(pts.min(axis=0), pts.max(axis=0)) for v in l]
    elif len(pts) == 1:
        # Single chunk: use its max diameter to fake a bounding box
        pt = pts[0]
        rad = [v["max_dt_nm"] for v in info.values() if v][0] / 2
        bounds = [
            pt[0] - rad,
            pt[0] + rad,
            pt[1] - rad,
            pt[1] + rad,
            pt[2] - rad,
            pt[2] + rad,
        ]
        bounds = [int(co) for co in bounds]
    else:
        # No L2 info at all -> no bounding box
        bounds = None
    info_df["bounds_nm"] = [bounds]

    # The summed max diameters serve as a (rough) proxy for cable length
    info_df.rename({"max_dt_um": "length_um"}, axis=1, inplace=True)

    return info_df
@parse_neuroncriteria()
@inject_dataset()
def get_l2_chunk_info(l2_ids, progress=True, chunk_size=2000, *, dataset=None):
    """Fetch info for given L2 chunks.

    Parameters
    ----------
    l2_ids :        int | list of ints | NeuronCriteria
                    L2 chunk ID(s) for which to fetch info.
    progress :      bool
                    Whether to show a progress bar.
    chunk_size :    int
                    Number of L2 IDs per query.
    dataset :       "public" | "production" | "sandbox" | "flat_630", optional
                    Against which FlyWire dataset to query. If ``None`` will
                    fall back to the default dataset (see
                    :func:`~fafbseg.flywire.set_default_dataset`).

    Returns
    -------
    pandas.DataFrame

    """
    # Get/Initialize the CAVE client
    client = get_cave_client(dataset=dataset)

    # We ask the L2 cache for representative coordinates, principal
    # components and chunk volumes
    attributes = ['rep_coord_nm', 'pca', 'size_nm3']

    # Query the cache in batches to keep individual requests small
    raw = {}
    fetch = retry(client.l2cache.get_l2data)
    with navis.config.tqdm(desc='Fetching L2 info',
                           disable=not progress,
                           total=len(l2_ids),
                           leave=False) as pbar:
        for start in range(0, len(l2_ids), chunk_size):
            batch = l2_ids[start: start + chunk_size]
            raw.update(fetch(batch.tolist(), attributes=attributes))
            pbar.update(len(batch))

    # L2 chunks without info will show as empty dictionaries
    # Let's drop them to make our life easier (speeds up indexing too)
    raw = {k: v for k, v in raw.items() if v}

    # Without any info there is nothing to tabulate
    if not raw:
        return pd.DataFrame([], columns=['id', 'x', 'y', 'z',
                                         'vec_x', 'vec_y', 'vec_z',
                                         'size_nm3'])

    pts = np.vstack([v['rep_coord_nm'] for v in raw.values()])
    # Very small chunks may lack a `pca` entry -> fall back to None vectors
    vec = np.vstack([v.get('pca', [[None, None, None]])[0]
                     for v in raw.values()])
    sizes = np.array([v['size_nm3'] for v in raw.values()])

    info_df = pd.DataFrame()
    info_df['id'] = list(raw.keys())
    # Convert coordinates from nm to 4x4x40nm voxel space
    info_df['x'] = (pts[:, 0] / 4).astype(int)
    info_df['y'] = (pts[:, 1] / 4).astype(int)
    info_df['z'] = (pts[:, 2] / 40).astype(int)
    info_df['vec_x'] = vec[:, 0]
    info_df['vec_y'] = vec[:, 1]
    info_df['vec_z'] = vec[:, 2]
    info_df['size_nm3'] = sizes

    return info_df
@parse_neuroncriteria()
@inject_dataset()
def find_anchor_loc(root_ids, validate=False, max_threads=4, progress=True, *,
                    dataset=None):
    """Find a representative coordinate.

    This works by querying the L2 cache and using the representative
    coordinate for the largest L2 chunk.

    Parameters
    ----------
    root_ids :      int | list thereof | NeuronCriteria
                    Root ID(s) to get coordinate for.
    validate :      bool
                    If True, will validate the x/y/z position. I have yet to
                    encounter a representative coordinate that wasn't mapping
                    to the correct L2 chunk - therefore this parameter is
                    False by default.
    max_threads :   int
                    Number of parallel threads to use.
    progress :      bool
                    Whether to show a progress bar.
    dataset :       "public" | "production" | "sandbox" | "flat_630", optional
                    Against which FlyWire dataset to query. If ``None`` will
                    fall back to the default dataset (see
                    :func:`~fafbseg.flywire.set_default_dataset`).

    Returns
    -------
    pandas.DataFrame

    """
    # For multiple root IDs: recurse over the unique IDs (validation is done
    # once, in bulk, afterwards)
    if navis.utils.is_iterable(root_ids):
        root_ids = np.asarray(root_ids).astype(np.int64)
        root_ids_unique = np.unique(root_ids)
        info = []
        with ThreadPoolExecutor(max_workers=max_threads) as pool:
            func = partial(find_anchor_loc, dataset=dataset, validate=False,
                           progress=False)
            futures = pool.map(func, root_ids_unique)
            info = [f for f in navis.config.tqdm(futures,
                                                 desc='Fetching locations',
                                                 total=len(root_ids_unique),
                                                 disable=not progress or len(root_ids_unique) == 1,
                                                 leave=False)]
        df = pd.concat(info, axis=0, ignore_index=True)

        # Validate
        if validate:
            has_loc = ~df.x.isnull()
            if any(has_loc):
                from .segmentation import locs_to_supervoxels
                # Map each anchor location to its supervoxel
                sv = locs_to_supervoxels(df.loc[has_loc, ['x', 'y', 'z']].values)
                df['supervoxel'] = None
                # Keep supervoxel IDs as strings to avoid precision issues
                df.loc[has_loc, 'supervoxel'] = sv.astype(str)  # do not change str

                # Get/Initialize the CAVE client
                client = get_cave_client(dataset=dataset)

                # Get root timestamps
                ts = client.chunkedgraph.get_root_timestamps(df.root_id.values.tolist())

                # A location is valid if its supervoxel mapped (at the root's
                # timestamp) to the root ID we queried it for
                df['valid'] = False
                for i in navis.config.trange(len(df), desc='Validating',
                                             disable=not progress or len(df) == 1,
                                             leave=False):
                    if df.supervoxel.values[i]:
                        sv = np.int64(df.supervoxel.values[i])
                        r = client.chunkedgraph.get_root_id(sv, timestamp=ts[i])
                        df.loc[i, 'valid'] = r == df.root_id.values[i]

        # Make sure the original order is retained
        df = df.set_index('root_id').loc[root_ids].reset_index(drop=False)

        return df

    root_ids = np.int64(root_ids)

    # Get/Initialize the CAVE client
    client = get_cave_client(dataset=dataset)

    # Get this neuron's L2 chunks (with automatic retries)
    get_l2_ids = partial(retry(client.chunkedgraph.get_leaves), stop_layer=2)
    l2_ids = get_l2_ids(root_ids)

    # Fetch info (incl. sizes and representative coords) for those chunks
    get_l2data = retry(get_l2_chunk_info)
    info = get_l2data(l2_ids, progress=progress)

    if info.empty:
        # No L2 info at all -> no anchor location
        loc = [None, None, None]
    else:
        # Use the representative coordinate of the largest chunk
        info.sort_values('size_nm3', ascending=False, inplace=True)
        loc = info[['x', 'y', 'z']].values[0].tolist()

    df = pd.DataFrame([[root_ids] + loc],
                      columns=['root_id', 'x', 'y', 'z'])

    if validate:
        if not isinstance(loc[0], type(None)):
            from .segmentation import locs_to_supervoxels
            sv = locs_to_supervoxels([loc])[0]
            df['supervoxel'] = sv
            if sv:
                # Valid if the supervoxel maps back to our root ID at the
                # root's timestamp
                ts = client.chunkedgraph.get_root_timestamps(root_ids)[0]
                r = client.chunkedgraph.get_root_id(sv, timestamp=ts)
                df['valid'] = r == root_ids
            else:
                df['valid'] = False

    return df
@parse_neuroncriteria()
@inject_dataset()
def get_l2_graph(root_ids, progress=True, *, dataset=None):
    """Fetch L2 graph(s).

    Parameters
    ----------
    root_ids :  int | list of ints | NeuronCriteria
                FlyWire root ID(s) for which to fetch the L2 graphs.
    progress :  bool
                Whether to show a progress bar.
    dataset :   "public" | "production" | "sandbox" | "flat_630", optional
                Against which FlyWire dataset to query. If ``None`` will fall
                back to the default dataset (see
                :func:`~fafbseg.flywire.set_default_dataset`).

    Returns
    -------
    networkx.Graph
                The L2 graph or list thereof.

    Examples
    --------
    >>> from fafbseg import flywire
    >>> G = flywire.get_l2_graph(720575940614131061)

    """
    # For multiple root IDs: recurse and return a list of graphs
    if navis.utils.is_iterable(root_ids):
        iterator = navis.config.tqdm(root_ids,
                                     desc='Fetching',
                                     disable=not progress or len(root_ids) == 1,
                                     leave=False)
        return [get_l2_graph(rid, dataset=dataset) for rid in iterator]

    # Get/Initialize the CAVE client
    client = get_cave_client(dataset=dataset)

    # Load the L2 graph for given root ID
    # This is a (N, 2) array of edges
    edges = np.array(client.chunkedgraph.level2_chunk_graph(root_ids))

    # Generate graph
    G = nx.Graph()
    if len(edges):
        # Drop duplicate edges
        edges = np.unique(np.sort(edges, axis=1), axis=0)
        G.add_edges_from(edges)
    else:
        # If no edges, this neuron consists of a single chunk
        # Get the single chunk's ID
        chunks = client.chunkedgraph.get_leaves(root_ids, stop_layer=2)
        G.add_nodes_from(chunks)

    return G
@parse_neuroncriteria()
@inject_dataset()
def get_l2_skeleton(root_id, refine=True, drop_missing=True, l2_node_ids=False,
                    omit_failures=None, progress=True, max_threads=4, *,
                    dataset=None, **kwargs):
    """Generate skeleton from L2 graph.

    Parameters
    ----------
    root_id  :      int | list of ints | NeuronCriteria
                    Root ID(s) of the FlyWire neuron(s) you want to
                    skeletonize.
    refine :        bool
                    If True, will refine skeleton nodes by moving them in
                    the center of their corresponding chunk meshes. This
                    uses the L2 cache (see
                    :func:`fafbseg.flywire.get_l2_info`).
    drop_missing :  bool
                    Only relevant if ``refine=True``: If True, will drop
                    chunks that don't exist in the L2 cache. These are
                    typically chunks that are either very small or new. If
                    False, chunks missing from L2 cache will be kept but
                    with their unrefined, approximate position.
    l2_node_ids :   bool
                    If True, will use the L2 IDs as node IDs (instead of
                    just enumerating the nodes).
    omit_failures : bool, optional
                    Determine behaviour when skeleton generation fails
                    (e.g. if the neuron has only a single chunk):
                     - ``None`` (default) will raise an exception
                     - ``True`` will skip the offending neuron (might
                       result in an empty ``NeuronList``)
                     - ``False`` will return an empty ``TreeNeuron``
    progress :      bool
                    Whether to show a progress bar.
    max_threads :   int
                    Number of parallel requests to make when fetching the
                    L2 skeletons.
    dataset :       "public" | "production" | "sandbox" | "flat_630", optional
                    Against which FlyWire dataset to query. If ``None`` will
                    fall back to the default dataset (see
                    :func:`~fafbseg.flywire.set_default_dataset`).
    **kwargs
                    Keyword arguments are passed through to the `TreeNeuron`
                    initialization. Use to e.g. set extra properties.

    Returns
    -------
    skeleton(s) :   navis.TreeNeuron | navis.NeuronList
                    The extracted L2 skeleton.

    See Also
    --------
    :func:`fafbseg.flywire.get_l2_dotprops`
                    Create dotprops instead of skeletons (faster and
                    possibly more accurate).
    :func:`~fafbseg.flywire.get_skeletons`
                    Fetch precomputed full resolution skeletons. Only
                    available for proofread neurons and for certain
                    materialization versions.
    :func:`fafbseg.flywire.skeletonize_neuron`
                    Skeletonize the full resolution mesh.

    Examples
    --------
    >>> from fafbseg import flywire
    >>> n = flywire.get_l2_skeleton(720575940614131061)

    """
    if omit_failures not in (None, True, False):
        raise ValueError('`omit_failures` must be either None, True or False. '
                         f'Got "{omit_failures}".')

    # For multiple root IDs: recurse (threaded) and collect into a NeuronList
    if navis.utils.is_iterable(root_id):
        root_id = np.asarray(root_id, dtype=np.int64)

        get_l2_skels = partial(get_l2_skeleton, refine=refine,
                               drop_missing=drop_missing,
                               omit_failures=omit_failures, dataset=dataset,
                               **kwargs)
        if (max_threads > 1) and (len(root_id) > 1):
            with ThreadPoolExecutor(max_workers=max_threads) as pool:
                futures = pool.map(get_l2_skels, root_id)
                nl = [f for f in navis.config.tqdm(futures,
                                                   desc='Fetching L2 skeletons',
                                                   total=len(root_id),
                                                   disable=not progress or len(root_id) == 1,
                                                   leave=False)]
        else:
            nl = [get_l2_skels(r) for r in navis.config.tqdm(root_id,
                                                             desc='Fetching L2 skeletons',
                                                             total=len(root_id),
                                                             disable=not progress or len(root_id) == 1,
                                                             leave=False)]

        # Turn into neuron list
        nl = navis.NeuronList(nl)

        # Bring in original order
        if len(nl):
            # Some neurons may have been omitted -> only index by the IDs
            # that actually made it into the NeuronList
            root_id = root_id[np.isin(root_id, nl.id)]
            nl = nl.idx[root_id]

        return nl

    # Turn into integer
    root_id = np.int64(root_id)

    # Get the cloudvolume
    vol = get_cloudvolume(dataset)

    # Get/Initialize the CAVE client
    client = get_cave_client(dataset=dataset)

    # Load the L2 graph for given root ID (this is a (N, 2) array of edges)
    get_l2_edges = retry(client.chunkedgraph.level2_chunk_graph)
    l2_eg = get_l2_edges(root_id)

    # If no edges, we can't create a skeleton
    if not len(l2_eg):
        msg = (f'Unable to create L2 skeleton: root ID {root_id} '
               'consists of only a single L2 chunk.')
        if omit_failures is None:
            raise ValueError(msg)

        navis.config.logger.warning(msg)

        if omit_failures:
            # If omission simply return an empty NeuronList
            return navis.NeuronList([])
        # If no omission, return empty TreeNeuron
        else:
            return navis.TreeNeuron(None, id=root_id, units='1 nm', **kwargs)

    # Drop duplicate edges
    l2_eg = np.unique(np.sort(l2_eg, axis=1), axis=0)

    # Unique L2 IDs
    l2_ids = np.unique(l2_eg)

    # ID to index
    l2dict = {l2: ii for ii, l2 in enumerate(l2_ids)}

    # Remap edge graph to indices
    eg_arr_rm = fastremap.remap(l2_eg, l2dict)

    # Approximate (chunk-grid) positions for each L2 chunk
    coords = [np.array(vol.mesh.meta.meta.decode_chunk_position(l)) for l in l2_ids]
    coords = np.vstack(coords)

    # This turns the graph into a hierarchal tree by removing cycles and
    # ensuring all edges point towards a root
    # (skeletor changed its module layout in version 1.0)
    if sk.__version_vector__[0] < 1:
        G = sk.skeletonizers.edges_to_graph(eg_arr_rm)
        swc = sk.skeletonizers.make_swc(G, coords=coords)
    else:
        G = sk.skeletonize.utils.edges_to_graph(eg_arr_rm)
        swc = sk.skeletonize.utils.make_swc(G, coords=coords, reindex=False)

    # Set radius to 0
    swc['radius'] = 0

    # Convert to Euclidian space
    # Dimension of a single chunk
    ch_dims = chunks_to_nm([1, 1, 1], vol) - chunks_to_nm([0, 0, 0], vol)
    ch_dims = np.squeeze(ch_dims)

    # Place each node at the center of its chunk
    xyz = swc[['x', 'y', 'z']].values
    swc[['x', 'y', 'z']] = chunks_to_nm(xyz, vol) + ch_dims / 2

    if refine:
        # Get the L2 representative coordinates
        get_l2data = retry(client.l2cache.get_l2data)
        l2_info = get_l2data(l2_ids.tolist(),
                             attributes=['rep_coord_nm', 'max_dt_nm'])

        # Missing L2 chunks will be {'id': {}}
        new_co = {l2dict[np.int64(k)]: v['rep_coord_nm']
                  for k, v in l2_info.items() if v}
        new_r = {l2dict[np.int64(k)]: v.get('max_dt_nm', 0)
                 for k, v in l2_info.items() if v}

        # Map refined coordinates onto the SWC
        has_new = swc.node_id.isin(new_co)

        # Only apply if we actually have new coordinates - otherwise there
        # the datatype is changed to object for some reason...
        if any(has_new):
            swc.loc[has_new, 'x'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][0])
            swc.loc[has_new, 'y'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][1])
            swc.loc[has_new, 'z'] = swc.loc[has_new, 'node_id'].map(lambda x: new_co[x][2])
            # Use the chunks' max diameter as node radius
            swc['radius'] = swc.node_id.map(new_r)

        # Turn into a proper neuron
        tn = navis.TreeNeuron(swc, id=root_id, units='1 nm', **kwargs)

        # Drop nodes that are still at their unrefined chunk position
        if drop_missing:
            frac_refined = has_new.sum() / len(has_new)
            if not any(has_new):
                msg = (f'Unable to refine: no L2 info for root ID {root_id} '
                       'available. Set `drop_missing=False` to use unrefined '
                       'positions.')
                if omit_failures is None:
                    raise ValueError(msg)
                elif omit_failures:
                    return navis.NeuronList([])
                # If no omission, return empty TreeNeuron
                else:
                    return navis.TreeNeuron(None, id=root_id, units='1 nm', **kwargs)
            elif frac_refined < .5:
                # Warn (but proceed) if less than half the chunks refined
                msg = (f'Root ID {root_id} has only {frac_refined:.1%} of their '
                       'L2 IDs in the cache. Set `drop_missing=False` to use '
                       'unrefined positions.')
                navis.config.logger.warning(msg)

            tn = navis.remove_nodes(tn, swc.loc[~has_new, 'node_id'].values)
            tn._l2_chunks_missing = (~has_new).sum()
    else:
        tn = navis.TreeNeuron(swc, id=root_id, units='1 nm', **kwargs)

    # Optionally map enumerated node IDs back to the original L2 IDs
    if l2_node_ids:
        ixdict = {ii: l2 for ii, l2 in enumerate(l2_ids)}
        tn.nodes['node_id'] = tn.nodes.node_id.map(ixdict)
        # Roots keep parent_id -1
        tn.nodes['parent_id'] = tn.nodes.parent_id.map(lambda x: ixdict.get(x, -1))

    return tn
@parse_neuroncriteria()
@inject_dataset()
def get_l2_dotprops(root_ids, min_size=None, sample=False, omit_failures=None,
                    progress=True, max_threads=10, *, dataset=None, **kwargs):
    """Generate dotprops from L2 chunks.

    L2 chunks not present in the L2 cache or without a `pca` attribute
    (happens for very small chunks) are silently ignored.

    Parameters
    ----------
    root_ids :      int | list of ints | NeuronCriteria
                    Root ID(s) of the FlyWire neuron(s) you want to
                    dotprops for.
    min_size :      int, optional
                    Minimum size (in nm^3) for the L2 chunks. Smaller chunks
                    will be ignored. This is useful to de-emphasise the
                    finer terminal neurites which typically break into more,
                    smaller chunks and are hence overrepresented. A good
                    value appears to be around 1_000_000.
    sample :        float [0 > 1], optional
                    If float, will create Dotprops based on a fractional
                    sample of the L2 chunks. The sampling is random but
                    deterministic.
    omit_failures : bool, optional
                    Determine behaviour when dotprops generation fails
                    (i.e. if the neuron has no L2 info):
                     - ``None`` (default) will raise an exception
                     - ``True`` will skip the offending neuron (might
                       result in an empty ``NeuronList``)
                     - ``False`` will return an empty ``Dotprops``
    progress :      bool
                    Whether to show a progress bar.
    max_threads :   int
                    Number of parallel requests to make when fetching the
                    L2 IDs (but not the L2 info).
    dataset :       "public" | "production" | "sandbox" | "flat_630", optional
                    Against which FlyWire dataset to query. If ``None`` will
                    fall back to the default dataset (see
                    :func:`~fafbseg.flywire.set_default_dataset`).
    **kwargs
                    Keyword arguments are passed through to the `Dotprops`
                    initialization. Use to e.g. set extra properties.

    Returns
    -------
    dps :           navis.NeuronList
                    List of Dotprops.

    See Also
    --------
    :func:`fafbseg.flywire.get_l2_skeleton`
                    Fetch skeletons instead of dotprops using the L2 edges
                    to infer connectivity.
    :func:`fafbseg.flywire.skeletonize_neuron`
                    Skeletonize the full resolution mesh.

    Examples
    --------
    >>> from fafbseg import flywire
    >>> n = flywire.get_l2_dotprops(720575940614131061)

    """
    if omit_failures not in (None, True, False):
        raise ValueError('`omit_failures` must be either None, True or False. '
                         f'Got "{omit_failures}".')

    if not navis.utils.is_iterable(root_ids):
        root_ids = [root_ids]
    root_ids = np.asarray(root_ids, dtype=np.int64)

    if '0' in root_ids or 0 in root_ids:
        raise ValueError('Unable to produce dotprops for root ID 0.')

    # Get/Initialize the CAVE client
    client = get_cave_client(dataset=dataset)

    # Load the L2 IDs (one request per root ID, threaded)
    with ThreadPoolExecutor(max_workers=max_threads) as pool:
        get_l2_ids = partial(retry(client.chunkedgraph.get_leaves), stop_layer=2)
        futures = pool.map(get_l2_ids, root_ids)
        l2_ids = [f for f in navis.config.tqdm(futures,
                                               desc='Fetching L2 IDs',
                                               total=len(root_ids),
                                               disable=not progress or len(root_ids) == 1,
                                               leave=False)]

    # Turn IDs into strings
    l2_ids = [i.astype(str) for i in l2_ids]

    if sample:
        if sample <= 0 or sample >= 1:
            raise ValueError(f'`sample` must be between 0 and 1, got {sample}')
        for i in range(len(l2_ids)):
            # Make the sampling deterministic
            np.random.seed(1985)
            l2_ids[i] = np.random.choice(l2_ids[i],
                                         size=max(1, int(len(l2_ids[i]) * sample)),
                                         replace=False)

    # Flatten into a list of all L2 IDs
    l2_ids_all = np.unique([i for l in l2_ids for i in l])

    # Get the L2 representative coordinates, vectors and (if required) volume
    chunk_size = 2000  # no. of L2 IDs per query (doesn't seem have big impact)
    attributes = ['rep_coord_nm', 'pca']
    if min_size:
        attributes.append('size_nm3')

    l2_info = {}
    with navis.config.tqdm(desc='Fetching L2 vectors',
                           disable=not progress,
                           total=len(l2_ids_all),
                           leave=False) as pbar:
        get_l2data = retry(client.l2cache.get_l2data)
        for chunk_ix in np.arange(0, len(l2_ids_all), chunk_size):
            chunk = l2_ids_all[chunk_ix: chunk_ix + chunk_size]
            l2_info.update(get_l2data(chunk.tolist(), attributes=attributes))
            pbar.update(len(chunk))

    # L2 chunks without info will show as empty dictionaries
    # Let's drop them to make our life easier (speeds up indexing too)
    # Note that small L2 chunks won't have a `pca` entry
    l2_info = {k: v for k, v in l2_info.items() if 'pca' in v}

    # Generate dotprops
    dps = []
    for root, ids in navis.config.tqdm(zip(root_ids, l2_ids),
                                       desc='Creating dotprops',
                                       total=len(root_ids),
                                       disable=not progress or len(root_ids) <= 1,
                                       leave=False):
        # Get xyz points and the first component of the PCA as vector
        # Note that first subsetting IDs to what's actually available in
        # `l2_info` is actually slower than doing it like this
        this_info = [l2_info[i] for i in ids if i in l2_info]

        if not len(this_info):
            msg = ('Unable to create L2 dotprops: none of the L2 chunks for '
                   f'root ID {root} are present in the L2 cache.')
            if omit_failures is None:
                raise ValueError(msg)

            if not omit_failures:
                # If no omission, add empty Dotprops
                dps.append(navis.Dotprops(None, k=None, id=root,
                                          units='1 nm', **kwargs))
                dps[-1]._l2_chunks_missing = len(ids)

            continue

        pts = np.vstack([i['rep_coord_nm'] for i in this_info])
        vec = np.vstack([i['pca'][0] for i in this_info])

        # Apply min size filter if requested
        if min_size:
            sizes = np.array([i['size_nm3'] for i in this_info])
            pts = pts[sizes >= min_size]
            vec = vec[sizes >= min_size]

        # Generate the actual dotprops
        dps.append(navis.Dotprops(points=pts, vect=vec, id=root, k=None,
                                  units='1 nm', **kwargs))
        dps[-1]._l2_chunks_missing = len(ids) - len(this_info)

    return navis.NeuronList(dps)
@inject_dataset()
def get_l2_meshes(x, threads=10, progress=True, *, dataset=None):
    """Fetch L2 meshes for a given neuron.

    Parameters
    ----------
    x :         int | str
                Root ID.
    threads :   int
                Number of parallel requests.
    progress :  bool
                Whether to show a progress bar.
    dataset :   "public" | "production" | "sandbox" | "flat_630", optional
                Against which FlyWire dataset to query. If ``None`` will fall
                back to the default dataset (see
                :func:`~fafbseg.flywire.set_default_dataset`).

    Returns
    -------
    navis.NeuronList

    """
    try:
        x = np.int64(x)
    except ValueError:
        raise ValueError(f'Unable to convert root ID {x} to integer')

    # Get/Initialize the CAVE client and the cloudvolume
    client = get_cave_client(dataset=dataset)
    vol = get_cloudvolume(dataset)

    # The L2 chunks making up this neuron
    l2_ids = client.chunkedgraph.get_leaves(x, stop_layer=2)

    # Fetch each chunk's mesh in parallel (with automatic retries)
    fetch_mesh = retry(vol.mesh.get)
    with ThreadPoolExecutor(max_workers=threads) as executor:
        jobs = [executor.submit(fetch_mesh, l2,
                                allow_missing=True,
                                deduplicate_chunk_boundaries=False)
                for l2 in l2_ids]
        results = [job.result()
                   for job in navis.config.tqdm(jobs,
                                                disable=not progress,
                                                leave=False,
                                                desc='Loading meshes')]

    # Each result is a {id: mesh} dictionary; merge them into one
    meshes = {}
    for d in results:
        meshes.update(d)

    return navis.NeuronList([navis.MeshNeuron(m, id=l2)
                             for l2, m in meshes.items()])


def _get_l2_centroids(l2_ids, vol, threads=10, progress=True):
    """Fetch L2 meshes and compute centroid."""
    # Download the individual chunk meshes in parallel
    with ThreadPoolExecutor(max_workers=threads) as executor:
        jobs = [executor.submit(vol.mesh.get, l2,
                                allow_missing=True,
                                deduplicate_chunk_boundaries=False)
                for l2 in l2_ids]
        results = [job.result()
                   for job in navis.config.tqdm(jobs,
                                                disable=not progress,
                                                leave=False,
                                                desc='Loading meshes')]

    # Merge the per-request {id: mesh} dictionaries
    meshes = {}
    for d in results:
        meshes.update(d)

    # For each mesh find the center of mass and move the corresponding point
    # Do NOT use `center_mass` here -> it produces garbage for meshes that
    # are not watertight
    centroids = {}
    for l2, m in meshes.items():
        centroids[l2] = tm.Trimesh(m.vertices, m.faces).centroid

    return centroids


def chunks_to_nm(xyz_ch, vol, voxel_resolution=[4, 4, 40]):
    """Map a chunk location to Euclidean space.

    Parameters
    ----------
    xyz_ch :            array-like
                        (N, 3) array of chunk indices.
    vol :               cloudvolume.CloudVolume
                        CloudVolume object associated with the chunked space.
    voxel_resolution :  list, optional
                        Voxel resolution.

    Returns
    -------
    np.array
                        (N, 3) array of spatial points.

    """
    # Integer factor between the volume's mip-0 resolution and the
    # requested voxel resolution
    scaling = vol.mip_resolution(0) // np.array(voxel_resolution, dtype=int)

    # Chunk indices -> voxel coordinates (plus the volume's offset)
    vox = np.atleast_2d(xyz_ch) * vol.mesh.meta.meta.graph_chunk_size
    vox = vox + np.array(vol.mesh.meta.meta.voxel_offset(0))

    # Voxel coordinates -> nanometers
    return vox * voxel_resolution * scaling