# Source code for fafbseg.flywire.skeletonize

#    A collection of tools to interface with manually traced and autosegmented
#    data in FAFB.
#
#    Copyright (C) 2019 Philipp Schlegel
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.

import navis
import numbers
import os
import requests
import inspect
import pathlib

import cloudvolume as cv
import multiprocessing as mp
import networkx as nx
import pandas as pd
import numpy as np
import skeletor as sk
import trimesh as tm

from functools import partial
from concurrent.futures import ThreadPoolExecutor

from .segmentation import snap_to_id, is_latest_root
from .utils import get_cloudvolume, silence_find_mat_version, inject_dataset
from .annotations import get_somas, parse_neuroncriteria

# Base URLs of the precomputed skeletons, keyed by the FlyWire
# materialization version (as string) they were generated for
SKELETON_BASE_URL = {
    '630': "https://flyem.mrc-lmb.cam.ac.uk/flyconnectome/flywire_skeletons_630",
    '783': "https://flyem.mrc-lmb.cam.ac.uk/flyconnectome/flywire_skeletons_783",
}

# Neuroglancer "info" used when reading the precomputed skeletons: an
# identity 3x4 affine transform (row-major) plus a single float32 per-vertex
# "radius" attribute
SKELETON_INFO = {
    "@type": "neuroglancer_skeletons",
    "transform": [1, 0, 0, 0,
                  0, 1, 0, 0,
                  0, 0, 1, 0],
    "vertex_attributes": [
        {"id": "radius", "data_type": "float32", "num_components": 1},
    ],
}


# Public API of this module
__all__ = [
    'skeletonize_neuron',
    'skeletonize_neuron_parallel',
    'get_skeletons',
]


@parse_neuroncriteria()
@inject_dataset()
def skeletonize_neuron(x, shave_skeleton=True, remove_soma_hairball=False,
                       assert_id_match=False, threads=2, save_to=None,
                       progress=True, *, dataset=None, **kwargs):
    """Skeletonize FlyWire neuron.

    Note that this is optimized to be primarily fast which comes at the cost
    of (some) quality. Also note that soma detection is using the nucleus
    segmentation and falls back to a radius-based heuristic if no nucleus is
    found.

    Parameters
    ----------
    x :                 int | trimesh.TriMesh | list thereof | NeuronCriteria
                        ID(s) or trimesh of the FlyWire neuron(s) you want to
                        skeletonize.
    shave_skeleton :    bool
                        If True, we will attempt to remove any "bristles" on
                        the backbone which typically occur if the neurites are
                        very big (or badly segmented).
    remove_soma_hairball : bool
                        If True, we will try to drop the hairball that is
                        typically created inside the soma. Note that while this
                        should work just fine for 99% of neurons, it's not very
                        smart and there is always a small chance that we
                        remove stuff that should not have been removed. Only
                        has an effect if a soma was found - either via the
                        nucleus annotations (see
                        :func:`fafbseg.flywire.get_somas`) or via the
                        radius-based fallback.
    assert_id_match :   bool
                        If True, will check if skeleton nodes map to the
                        correct segment ID and if not will move them back into
                        the segment. This is potentially very slow!
    threads :           int
                        Number of parallel threads to use for downloading the
                        meshes.
    save_to :           str, optional
                        If provided will save skeleton as SWC at
                        `{save_to}/{id}.swc`. Must be an existing directory.
    progress :          bool
                        Whether to show a progress bar or not.
    dataset :           str | CloudVolume
                        Against which FlyWire dataset to query::
                          - "production" (current production dataset, fly_v31)
                          - "sandbox" (i.e. fly_v26)
                          - "public"
                          - "flat_630" or "flat_571" will use the flat
                            segmentations matching the respective
                            materialization versions. By default these use
                            `lod=2`, you can change that behaviour by passing
                            `lod` as keyword argument.
    **kwargs
                        Passed on to `skeletor.skeletonize.by_wavefront` (we
                        default to `waves=1` and `step_size=1`). The `lod`
                        keyword is consumed here and not passed on.

    Returns
    -------
    skeleton :          navis.TreeNeuron | navis.NeuronList
                        The extracted skeleton. If `x` is a list of IDs, a
                        NeuronList of skeletons is returned.

    See Also
    --------
    :func:`fafbseg.flywire.skeletonize_neuron_parallel`
                        Use this if you want to skeletonize many neurons in
                        parallel.
    :func:`fafbseg.flywire.get_l2_skeleton`
                        Generate a skeleton using the L2 cache. Much faster
                        than skeletonization from scratch but the skeleton
                        will be coarser.
    :func:`~fafbseg.flywire.get_skeletons`
                        Use this function to fetch precomputed skeletons. Only
                        available for proofread neurons and for specific
                        materialization versions.

    Examples
    --------
    >>> from fafbseg import flywire
    >>> n = flywire.skeletonize_neuron(720575940603231916)

    """
    if save_to is not None:
        save_to = pathlib.Path(save_to)
        if not save_to.exists():
            raise ValueError('`save_to` must be an existing directory')
        if not save_to.is_dir():
            raise ValueError('`save_to` must be a directory')

    # TODOs:
    # - drop single disconnected nodes?
    # - heal fragmented neurons?
    # - fix 0-radius nodes: these will be on 99.9% of the cases be leaf nodes
    # - shave only high-strahler twigs

    if int(sk.__version__.split('.')[0]) < 1:
        raise ImportError('Please update skeletor to version >= 1.0.0: '
                          'pip3 install skeletor -U')

    if navis.utils.is_iterable(x):
        # Make sure these are root IDs
        x = np.asarray(x).astype(np.int64)

        # Fetch all somas in one go (note that this will only find somas for
        # roots that actually existed at the time). For neurons without a
        # soma we'll be doing more sophisticated checks when we skeletonize
        with silence_find_mat_version():
            kwargs['_nuclei'] = get_somas(x,
                                          raise_missing=False,
                                          dataset=dataset,
                                          materialization='latest')

        return navis.NeuronList([
            skeletonize_neuron(n,
                               progress=False,
                               shave_skeleton=shave_skeleton,
                               remove_soma_hairball=remove_soma_hairball,
                               assert_id_match=assert_id_match,
                               dataset=dataset,
                               threads=threads,
                               save_to=save_to,
                               **kwargs)
            for n in navis.config.tqdm(x,
                                       desc='Skeletonizing',
                                       disable=not progress,
                                       leave=False)
        ])

    # `lod` only matters for precomputed (flat) segmentations. Always pop it
    # so that it is never forwarded to skeletor's `by_wavefront` below
    # (previously it leaked through for graphene volumes and mesh input).
    lod = kwargs.pop('lod', 2)

    if not navis.utils.is_mesh(x):
        vol = get_cloudvolume(dataset)

        # Make sure this is a valid integer
        root_id = np.int64(x)

        # Download the mesh using `threads` parallel requests; restore the
        # volume's original setting afterwards
        old_parallel = vol.parallel
        vol.parallel = threads
        try:
            if vol.path.startswith('graphene'):
                mesh = vol.mesh.get(root_id,
                                    deduplicate_chunk_boundaries=False)[root_id]
            elif vol.path.startswith('precomputed'):
                # Step down the levels of detail until one decodes
                lod_ = lod
                while lod_ >= 0:
                    try:
                        mesh = vol.mesh.get(root_id, lod=lod_)[root_id]
                        break
                    except cv.exceptions.MeshDecodeError:
                        lod_ -= 1
                if lod_ < 0:
                    raise ValueError(f'Root ID {root_id} does not appear to exist '
                                     f'in "{dataset}"')
            else:
                # Previously this fell through and failed later on with an
                # unbound `mesh` variable
                raise ValueError(f'Unable to fetch meshes from "{vol.path}": '
                                 'expected a "graphene" or "precomputed" '
                                 'source.')
        finally:
            vol.parallel = old_parallel
    else:
        mesh = x
        root_id = getattr(mesh, 'segid', getattr(mesh, 'id', 0))

    # Pop nuclei from kwargs before passing them to skeletonization
    nuc = kwargs.pop('_nuclei', pd.DataFrame())

    mesh = sk.utilities.make_trimesh(mesh, validate=True)

    # Fix things before we skeletonize: drop disconnected pieces with fewer
    # vertices than 0.01% of the mesh's total vertex count (`None` is passed
    # if that threshold rounds down to 0)
    to_remove = int(0.0001 * mesh.vertices.shape[0])
    to_remove = None if to_remove == 0 else to_remove
    mesh = sk.pre.fix_mesh(mesh, inplace=True, remove_disconnected=to_remove)

    # Skeletonize
    defaults = dict(waves=1, step_size=1)
    defaults.update(kwargs)
    s = sk.skeletonize.by_wavefront(mesh, progress=progress, **defaults)

    # Skeletor indexes node IDs at zero but to avoid potential issues we want
    # node IDs to start at 1 (roots keep their negative parent ID)
    s.swc['node_id'] += 1
    s.swc.loc[s.swc.parent_id >= 0, 'parent_id'] += 1

    # We will also round the radius and make it an integer to save some
    # memory. We could do the same with x/y/z coordinates but that could
    # potentially move nodes outside the mesh
    s.swc['radius'] = s.swc.radius.round().astype(int)

    # Turn into a neuron
    tn = navis.TreeNeuron(s.swc, units='1 nm', id=root_id, soma=None)

    if shave_skeleton:
        # Get child -> parent distances
        d = navis.morpho.mmetrics.parent_dist(tn, root_dist=0)

        # Find all nodes whose parent is at least a micron away (suspicious)
        long_edges = tn.nodes[d >= 1000].node_id.values

        # Now start shaving
        while True:
            # Find segments starting at a leaf
            leafs = tn.leafs.node_id.values
            leaf_segs = [seg for seg in tn.small_segments if seg[0] in leafs]

            # Among the leaf segments find those that are either at most 2
            # nodes or have any of the suspiciously long connections - but
            # make sure we don't drop long segments (10+ nodes)
            to_remove = [seg for seg in leaf_segs
                         if (np.isin(seg, long_edges).any() or len(seg) <= 2)
                         and len(seg) < 10]

            # Flatten into node IDs, keeping each segment's last node (the
            # node it attaches to)
            to_remove = [n for seg in to_remove for n in seg[:-1]]

            # If nothing more to remove, we can stop here
            if not len(to_remove):
                break

            navis.subset_neuron(tn, ~tn.nodes.node_id.isin(to_remove),
                                inplace=True)

        # Drop single-node twigs: leaf nodes whose parent is a branch point
        bp = tn.nodes.loc[tn.nodes.type == 'branch', 'node_id'].values
        is_end = tn.nodes.type == 'end'
        parent_is_bp = tn.nodes.parent_id.isin(bp)
        twigs = tn.nodes.loc[is_end & parent_is_bp, 'node_id'].values

        tn._nodes = tn.nodes.loc[~tn.nodes.node_id.isin(twigs)].copy()
        tn._clear_temp_attr()

    # Use nuclei if they have already been fetched for all neurons
    if not nuc.empty:
        soma = nuc[nuc.pt_root_id == root_id]
    else:
        soma = pd.DataFrame()

    if soma.empty:
        # See if we can find a soma based on the nucleus segmentation
        try:
            with silence_find_mat_version():
                soma = get_somas(root_id, raise_missing=False,
                                 dataset=dataset, materialization='auto')
        except requests.HTTPError:
            navis.config.logger.warning(f'Failed to fetch soma for {root_id} from '
                                        'nucleus table.')
            soma = pd.DataFrame()

    if not soma.empty:
        soma = tn.snap(soma.iloc[0].pt_position)[0]
    else:
        # If no nucleus, try to detect the soma from the skeleton's radii
        soma = detect_soma_skeleton(tn, min_rad=800, N=3)

    if soma:
        tn.soma = soma

        # Reroot to soma
        tn.reroot(tn.soma, inplace=True)

        if remove_soma_hairball:
            soma = tn.nodes.set_index('node_id').loc[soma]
            soma_loc = soma[['x', 'y', 'z']].values

            # Find all nodes within 2x the soma radius (at least 4 microns)
            tree = navis.neuron2KDTree(tn)
            ix = tree.query_ball_point(soma_loc, max(4000, soma.radius * 2))

            # Translate indices into node IDs
            ids = tn.nodes.iloc[ix].node_id.values

            # Find segments that contain these nodes and sort them by length
            segs = [seg for seg in tn.segments if any(np.isin(ids, seg))]
            segs = sorted(segs, key=len)

            # Keep only the longest segment in that initial list (and the
            # soma itself); drop nodes of all other segments
            to_drop = np.array([n for seg in segs[:-1] for n in seg])
            to_drop = to_drop[~np.isin(to_drop, segs[-1] + [soma.name])]

            navis.remove_nodes(tn, to_drop, inplace=True)

    if assert_id_match:
        if root_id == 0:
            raise ValueError('Segmentation ID must not be 0')
        new_locs = snap_to_id(tn.nodes[['x', 'y', 'z']].values,
                              id=root_id,
                              snap_zero=False,
                              dataset=dataset,
                              search_radius=160,
                              coordinates='nm',
                              max_workers=4,
                              verbose=True)
        tn.nodes[['x', 'y', 'z']] = new_locs

    if save_to is not None:
        navis.write_swc(tn, save_to / f'{tn.id}.swc')

    return tn
def detect_soma_skeleton(s, min_rad=800, N=3):
    """Try detecting the soma based on radii.

    Parameters
    ----------
    s :         navis.TreeNeuron
    min_rad :   float
                Minimum radius for a node to be considered a soma candidate.
    N :         int
                Number of consecutive nodes with radius > `min_rad` we need in
                order to consider them soma candidates.

    Returns
    -------
    node ID | None
                ID of the candidate node with the largest radius, or None if
                no segment contains a stretch of at least `N` consecutive
                above-threshold nodes.

    """
    assert isinstance(s, navis.TreeNeuron)

    # Map node ID -> radius for quick lookups
    radii = s.nodes.set_index('node_id').radius.to_dict()

    candidates = []
    for seg in s.segments:
        rad = np.array([radii[n] for n in seg])
        is_big = np.where(rad > min_rad)[0]

        # Skip if no above-threshold radii in this segment. Check the size:
        # `any(is_big)` would wrongly treat a lone hit at index 0 as "none"
        if not is_big.size:
            continue

        # Find stretches of consecutive above-threshold radii
        for stretch in np.split(is_big, np.where(np.diff(is_big) != 1)[0] + 1):
            if len(stretch) < N:
                continue
            candidates += [seg[i] for i in stretch]

    if not candidates:
        return None

    # Return largest candidate (`sorted` is stable, so on ties this is the
    # last such candidate found)
    return sorted(candidates, key=lambda x: radii[x])[-1]


def __detect_soma_mesh(mesh):
    """Try detecting the soma based on vertex clusters.

    Parameters
    ----------
    mesh :      trimesh.Trimesh | navis.MeshNeuron
                Coordinates are assumed to be in nanometers. Mesh must
                not be downsampled.

    Returns
    -------
    vertex indices
                Indices of vertices within 10 microns of the densest vertex,
                or an empty array if no sufficiently dense cluster is found.

    """
    # Build a KD tree
    from scipy.spatial import cKDTree
    tree = cKDTree(mesh.vertices)

    # Find out how many neighbours each vertex has within a 4 micron radius
    # (SciPy renamed `n_jobs` to `workers`; `n_jobs` no longer exists)
    n_neighbors = tree.query_ball_point(mesh.vertices,
                                        r=4000,
                                        return_length=True,
                                        workers=3)

    # Seed for soma is the node with the most neighbors
    seed = np.argmax(n_neighbors)

    # We need to find a sensible threshold for neurons without an actual soma:
    # scale the required neighbour count by the mean face area
    res = np.mean(mesh.area_faces)
    if n_neighbors.max() < (20e4 / res):
        return np.array([])

    # Find nodes within 10 microns of the seed (misses have infinite distance)
    dist, ix = tree.query(mesh.vertices[[seed]],
                          k=mesh.vertices.shape[0],
                          distance_upper_bound=10000)
    soma_verts = ix[dist < float('inf')]

    # TODO:
    # - use along-the-mesh distances instead to avoid pulling in close-by
    #   neurites
    # - combine this with looking for a fall-off in N neighbors, i.e. when
    #   we hit the primary neurite track
    return soma_verts


def divide_local_neighbourhood(mesh, radius):
    """Divide the mesh into locally connected patches of a given size.

    All nodes will be assigned to a patch but patches will be overlapping.

    Parameters
    ----------
    mesh :      trimesh.Trimesh
    radius :    float
                Maximum along-the-mesh (edge-weighted) distance from a
                patch's center vertex.

    Returns
    -------
    list of sets
                Each set contains the vertex indices of one patch.

    """
    assert isinstance(mesh, tm.Trimesh)
    assert isinstance(radius, numbers.Number)

    # Generate a graph for mesh
    G = mesh.vertex_adjacency_graph
    # Use Euclidean distance for edge weights
    edges = np.array(G.edges)
    e1 = mesh.vertices[edges[:, 0]]
    e2 = mesh.vertices[edges[:, 1]]
    dist = np.sqrt(np.sum((e1 - e2) ** 2, axis=1))
    nx.set_edge_attributes(G, dict(zip(G.edges, dist)), name='weight')

    not_seen = set(G.nodes)
    patches = []
    while not_seen:
        center = not_seen.pop()
        sg = nx.ego_graph(G, center, distance='weight', radius=radius)
        nodes = set(sg.nodes)
        patches.append(nodes)
        not_seen -= nodes

    # Previously missing: the function computed the patches but returned None
    return patches
def skeletonize_neuron_parallel(ids, n_cores=None, progress=True, **kwargs):
    """Skeletonization on parallel cores.

    Parameters
    ----------
    ids :       iterable
                Root IDs of neurons you want to skeletonize.
    n_cores :   int, optional
                Number of cores to use. If None (default), uses half the
                available cores but at least 2. Don't go too crazy on this as
                the downloading of meshes becomes a bottle neck if you try to
                do too many at the same time. Keep your internet speed in
                mind. For reference: with 100Mbps internet, I can comfortably
                run on 8 cores with some room to spare.
    progress :  bool
                Whether to show a progress bar.
    **kwargs
                Keyword arguments are passed on to `skeletonize_neuron`.

    Returns
    -------
    navis.NeuronList
                Successfully skeletonized neurons. IDs that failed are
                printed and omitted.

    Raises
    ------
    ValueError
                If `n_cores` is outside [2, number of cores] or `kwargs`
                contains an argument `skeletonize_neuron` does not accept.

    See Also
    --------
    :func:`fafbseg.flywire.skeletonize_neuron`
                The function called for individual neurons.

    """
    # `os.cpu_count()` may return None if the count can't be determined
    n_available = os.cpu_count()
    if n_cores is None:
        # The old default (`cpu_count() // 2`) was 1 on 2-3 core machines and
        # hence always failed the check below
        n_cores = max(2, (n_available or 2) // 2)

    if n_cores < 2 or (n_available is not None and n_cores > n_available):
        raise ValueError('`n_cores` must be between 2 and max number of cores.')

    sig = inspect.signature(skeletonize_neuron)
    for k in kwargs:
        if k not in sig.parameters and k not in ('lod', ):
            raise ValueError('unexpected keyword argument for '
                             f'`skeletonize_neuron`: {k}')

    # Make sure IDs are all integers
    ids = np.asarray(ids, dtype=np.int64)

    # Nothing to do - avoid querying somas / spinning up a pool
    if not ids.size:
        return navis.NeuronList([])

    # Prepare the calls and parameters: workers run single-threaded without
    # progress bars and share the nuclei fetched here in one go
    kwargs['progress'] = False
    kwargs['threads'] = 1
    kwargs['_nuclei'] = get_somas(ids, raise_missing=False,
                                  dataset=kwargs.get('dataset', 'production'))

    combinations = [(skeletonize_neuron, [i], kwargs) for i in ids]

    # Run the actual skeletonization
    with mp.Pool(n_cores) as pool:
        res = list(navis.config.tqdm(pool.imap(_worker_wrapper,
                                               combinations,
                                               chunksize=1),
                                     total=len(combinations),
                                     desc='Skeletonizing',
                                     disable=not progress,
                                     leave=True))

    # Workers return the root ID instead of a neuron if skeletonization failed
    failed = [str(r) for r in res if not isinstance(r, navis.TreeNeuron)]
    if failed:
        print(f'{len(failed)} neurons failed to skeletonize: '
              f'{", ".join(failed)}')

    return navis.NeuronList([r for r in res if isinstance(r, navis.TreeNeuron)])
def _worker_wrapper(x):
    """Run one ``(func, args, kwargs)`` job inside a worker process.

    KeyboardInterrupts on the first attempt are re-raised. HTTP errors on the
    first attempt trigger exactly one retry. Any other failure (including any
    failure during the retry) makes the job return ``args[0]`` - the root ID -
    instead of a result.
    """
    func, args, kwargs = x

    try:
        return func(*args, **kwargs)
    except KeyboardInterrupt:
        raise
    except requests.HTTPError:
        pass  # fall through to the single retry below
    except BaseException:
        # In case of failure return the root ID
        return args[0]

    # Single retry after an HTTP error
    try:
        return func(*args, **kwargs)
    except BaseException:
        # In case of failure return the root ID
        return args[0]
@parse_neuroncriteria()
def get_skeletons(root_id, threads=2, omit_failures=None, max_threads=6,
                  dataset=783, progress=True):
    """Fetch precomputed skeletons.

    Currently this only works for proofread (!) 630 and 783 root IDs (i.e. the
    first two public releases of FlyWire).

    Parameters
    ----------
    root_id  :          int | list of ints | NeuronCriteria
                        Root ID(s) of the FlyWire neuron(s) you want to fetch
                        skeletons for. Must be root IDs that existed at
                        materialization 630 or 783 (see `dataset` parameter).
    threads :           int
                        Currently unused (see `max_threads`); kept for
                        backwards compatibility.
    omit_failures :     bool, optional
                        Determine behaviour when fetching a skeleton fails
                        (e.g. if the neuron has only a single chunk):
                            - ``None`` (default) will raise an exception
                            - ``True`` will skip the offending neuron (might
                              result in an empty ``NeuronList``)
                            - ``False`` will return an empty ``TreeNeuron``
    max_threads :       int
                        Number of parallel requests to make when fetching the
                        skeletons.
    dataset :           630 | 783
                        Which dataset to query.
    progress :          bool
                        Whether to show a progress bar.

    Returns
    -------
    skeletons :         navis.NeuronList | navis.TreeNeuron
                        Either a single neuron or a list thereof.

    See Also
    --------
    :func:`~fafbseg.flywire.skeletonize_neuron`
                        Use this function to skeletonize neurons from scratch,
                        e.g. if there aren't any precomputed skeletons
                        available.

    Examples
    --------
    >>> from fafbseg import flywire
    >>> n = flywire.get_skeletons(720575940603231916)
    >>> n                                               #doctest: +SKIP
    type                            navis.TreeNeuron
    name                                    skeleton
    id                            720575940603231916
    n_nodes                                     3588
    n_connectors                                None
    n_branches                                   586
    n_leafs                                      645
    cable_length                           2050971.75
    soma                                        None
    units                                1 nanometer
    dtype: object

    """
    if str(dataset) not in SKELETON_BASE_URL:
        raise ValueError(
            "Currently we only provide precomputed skeletons for the "
            "630 and 783 data releases."
        )

    if omit_failures not in (None, True, False):
        raise ValueError(
            "`omit_failures` must be either None, True or False. "
            f'Got "{omit_failures}".'
        )

    if navis.utils.is_iterable(root_id):
        root_id = np.asarray(root_id, dtype=np.int64)

        # Check that the IDs existed at the requested materialization
        il = is_latest_root(root_id, timestamp=f"mat_{dataset}")
        if np.any(~il):
            msg = (
                f"{(~il).sum()} root ID(s) did not exist at materialization {dataset}"
            )
            if omit_failures is None:
                raise ValueError(msg)
            navis.config.logger.warning(msg)

        get_skels = partial(get_skeletons, omit_failures=omit_failures, dataset=dataset)
        tqdm_kwargs = dict(desc="Fetching skeletons",
                           total=len(root_id),
                           disable=not progress or len(root_id) == 1,
                           leave=False)

        if (max_threads > 1) and (len(root_id) > 1):
            with ThreadPoolExecutor(max_workers=max_threads) as pool:
                nl = list(navis.config.tqdm(pool.map(get_skels, root_id),
                                            **tqdm_kwargs))
        else:
            nl = [get_skels(r) for r in navis.config.tqdm(root_id, **tqdm_kwargs)]

        # With `omit_failures=True`, failed fetches come back as empty
        # NeuronLists - drop those placeholders explicitly
        nl = [n for n in nl
              if not (isinstance(n, navis.NeuronList) and not len(n))]

        # Turn into neuron list
        nl = navis.NeuronList(nl)

        # Bring in original order
        if len(nl):
            root_id = root_id[np.isin(root_id, nl.id)]
            nl = nl.idx[root_id]

        return nl

    # Turn into integer
    root_id = np.int64(root_id)

    try:
        tn = navis.read_precomputed(f'{SKELETON_BASE_URL[str(dataset)]}/{root_id}',
                                    datatype='skeleton',
                                    info=SKELETON_INFO)
    except KeyboardInterrupt:
        # Never swallow Ctrl-C, regardless of `omit_failures`
        raise
    except Exception:
        if omit_failures is None:
            raise
        elif omit_failures:
            return navis.NeuronList([])
        return navis.TreeNeuron(None, id=root_id, units='1 nm')

    # Force integer (navis.read_precomputed will turn ID into string)
    tn.id = root_id
    tn.units = '1nm'
    return tn