# A collection of tools to interface with manually traced and autosegmented
# data in FAFB.
#
# Copyright (C) 2019 Philipp Schlegel
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
import pymaid
import navis
import requests
import textwrap
import time
import copy
import cloudvolume as cv
import datetime as dt
import numpy as np
import pandas as pd
import networkx as nx
from concurrent import futures
from diskcache import Cache
from requests_futures.sessions import FuturesSession
from scipy import ndimage
from tqdm.auto import tqdm
from .. import spine
from .. import xform
from ..utils import make_iterable, GSPointLoader
from .utils import (
get_cloudvolume,
FLYWIRE_DATASETS,
get_chunkedgraph_secret,
retry,
get_cave_client,
parse_bounds,
package_timestamp,
inject_dataset,
)
from .annotations import parse_neuroncriteria
__all__ = [
"get_edit_history",
"get_leaderboard",
"locs_to_segments",
"locs_to_supervoxels",
"skid_to_id",
"update_ids",
"roots_to_supervoxels",
"supervoxels_to_roots",
"neuron_to_segments",
"is_latest_root",
"is_valid_root",
"is_valid_supervoxel",
"get_voxels",
"get_lineage_graph",
"find_common_time",
"get_segmentation_cutout",
]
@inject_dataset()
def get_lineage_graph(
x,
size=False,
user=False,
synapses=False,
proofreading_status=False,
progress=True,
*,
dataset=None,
):
"""Get lineage graph for given neuron.
This piggy-backs on the CAVEclient but, importantly, remaps user and
operation IDs such that each node is labelled with the operation that
created it.
Parameters
----------
x : int
A single root ID.
size : bool
If True, will add `size` and `survivors` node attributes. The
former indicates the number of supervoxels, the latter how many
of these supervoxels made it into `x`.
user : bool
If True, will add a `user` node attribute indicating which user
performed the operation.
synapses : bool
If True, will add `pre|post|synapses` node attributes which
indicate how many of the synapses in `x` came from this fragment.
Note that this doesn't tell you e.g. how many false-positive
synapses were removed via a split. This works only if the root
ID is up-to-date.
proofreading_status : bool
If True, will add a `proofread_by` node attribute indicating if
a user has set a given root ID to proofread.
progress : bool
If True, show progress bar.
dataset : "public" | "production" | "sandbox", optional
Against which FlyWire dataset to query. If ``None`` will fall
back to the default dataset (see
:func:`~fafbseg.flywire.set_default_dataset`).
Returns
-------
networkx.DiGraph
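Examples
--------
A minimal sketch - the root ID below is the one used in other examples
in this module and may well be outdated by the time you run this:
>>> from fafbseg import flywire
>>> G = flywire.get_lineage_graph(720575940621039145, user=True) #doctest: +SKIP
>>> # Nodes are root IDs annotated with the operation that created them
>>> list(G.nodes(data=True))[:2] #doctest: +SKIP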
"""
x = np.int64(x)
client = get_cave_client(dataset=dataset)
G = client.chunkedgraph.get_lineage_graph(x, as_nx_graph=True)
# Remap operation ID
op_remapped = {}
for n in G:
pred = list(G.predecessors(n))
if pred:
op_remapped[n] = G.nodes[pred[0]]["operation_id"]
# Remove existing operation IDs
for n in G.nodes:
G.nodes[n].pop("operation_id", None)
# Apply new IDs
nx.set_node_attributes(G, op_remapped, name="operation_id")
if user:
op_ids = nx.get_node_attributes(G, "operation_id")
details = client.chunkedgraph.get_operation_details(list(op_ids.values()))
users = {n: details[str(o)]["user"] for n, o in op_ids.items()}
nx.set_node_attributes(G, users, name="user")
if size:
sv = roots_to_supervoxels(list(G.nodes), dataset=dataset, progress=progress)
sizes = {n: len(sv[n]) for n in G.nodes}
nx.set_node_attributes(G, sizes, name="size")
survivors = {n: int(np.isin(sv[n], sv[x]).sum()) for n in G.nodes}
nx.set_node_attributes(G, survivors, name="survivors")
else:
sv = None
if synapses:
pre = client.materialize.live_query(
table=client.materialize.synapse_table,
filter_equal_dict=dict(pre_pt_root_id=x),
timestamp=dt.datetime.now(),
select_columns=["pre_pt_supervoxel_id", "post_pt_supervoxel_id"],
)
post = client.materialize.live_query(
table=client.materialize.synapse_table,
filter_equal_dict=dict(post_pt_root_id=x),
timestamp=dt.datetime.now(),
select_columns=["pre_pt_supervoxel_id", "post_pt_supervoxel_id"],
)
if isinstance(sv, type(None)):
sv = roots_to_supervoxels(list(G.nodes), dataset=dataset, progress=progress)
n_pre = {n: int(pre.pre_pt_supervoxel_id.isin(sv[n]).sum()) for n in G.nodes}
n_post = {n: int(post.post_pt_supervoxel_id.isin(sv[n]).sum()) for n in G.nodes}
n_syn = {n: n_pre[n] + n_post[n] for n in G.nodes}
nx.set_node_attributes(G, n_pre, name="presynapses")
nx.set_node_attributes(G, n_post, name="postsynapses")
nx.set_node_attributes(G, n_syn, name="synapses")
if proofreading_status:
from .annotations import get_cave_table
nodes = np.array(list(G.nodes), dtype=np.int64)
pr = get_cave_table(
"proofreading_status_public_v1", filter_in_dict=dict(valid_id=nodes)
)
if len(pr):
user = pr.groupby("valid_id").user_id.apply(list).to_dict()
nx.set_node_attributes(
G, {n: user[n] for n in pr.valid_id}, name="proofread_by"
)
return G
def get_leaderboard(days=7, by_day=False, progress=True, max_threads=4):
"""Fetch leader board (# of edits).
Parameters
----------
days : int
Number of days to go back.
by_day : bool
If True, will provide a day-by-day breakdown of # edits.
progress : bool
If True, show progress bar.
max_threads : int
Max number of parallel requests to server.
Returns
-------
pandas.DataFrame
Examples
--------
>>> from fafbseg import flywire
>>> # Fetch leaderboard with edits per day
>>> hist = flywire.get_leaderboard(by_day=True) #doctest: +SKIP
>>> # Plot user actions over time
>>> hist.T.plot() #doctest: +SKIP
"""
assert isinstance(days, (int, np.integer))
assert days >= 0
session = requests.Session()
if not by_day:
url = f"https://pyrdev.eyewire.org/flywire-leaderboard?days={days-1}"
resp = session.get(url, params=None)
resp.raise_for_status()
return pd.DataFrame.from_records(resp.json()["entries"]).set_index("name")
future_session = FuturesSession(session=session, max_workers=max_threads)
futures = []
for i in range(0, days):
url = f"https://pyrdev.eyewire.org/flywire-leaderboard?days={i}"
futures.append(future_session.get(url, params=None))
# Get the responses
resp = [
f.result()
for f in navis.config.tqdm(
futures,
desc="Fetching",
disable=not progress or len(futures) == 1,
leave=False,
)
]
df = None
for i, r in enumerate(resp):
date = dt.date.today() - dt.timedelta(days=i)
r.raise_for_status()
this_df = pd.DataFrame.from_records(r.json()["entries"]).set_index("name")
this_df.columns = [date]
if isinstance(df, type(None)):
df = this_df
else:
df = pd.merge(df, this_df, how="outer", left_index=True, right_index=True)
# Make sure we don't have NAs
df = df.fillna(0).astype(int)
# This breaks it down into days
if df.shape[1] > 1:
df.iloc[:, 1:] = df.iloc[:, 1:].values - df.iloc[:, :-1].values
# Reverse such that the right-most entry is the current date
df = df.iloc[:, ::-1]
return df.loc[df.sum(axis=1).sort_values(ascending=False).index]
@parse_neuroncriteria()
@inject_dataset()
def get_edit_history(x, progress=True, errors="raise", max_threads=4, *, dataset=None):
"""Fetch edit history for given neuron(s).
Note that neurons that haven't seen any edits will simply not show up
in the returned table.
Parameters
----------
x : int | list of int | NeuronCriteria
Segmentation (root) ID(s).
progress : bool
If True, show progress bar.
errors : str
Either "raise" (default) to raise if fetching the edit history
for a given root ID fails, or anything else to just print a
message and skip that ID.
max_threads : int
Max number of parallel requests to server.
dataset : "public" | "production" | "sandbox" | "flat_630", optional
Against which FlyWire dataset to query. If ``None`` will fall
back to the default dataset (see
:func:`~fafbseg.flywire.set_default_dataset`).
Returns
-------
pandas.DataFrame
Examples
--------
>>> from fafbseg import flywire
>>> # Fetch edits
>>> edits = flywire.get_edit_history(720575940621039145)
>>> # Group by user
>>> edits.groupby('user_name').size()
user_name
Claire McKellar 47
Jay Gager 4
Sandeep Kumar 1
Sarah Morejohn 6
dtype: int64
"""
if not isinstance(x, (list, set, np.ndarray)):
x = [x]
session = requests.Session()
future_session = FuturesSession(session=session, max_workers=max_threads)
token = get_chunkedgraph_secret()
session.headers["Authorization"] = f"Bearer {token}"
futures = []
for id in x:
dataset = FLYWIRE_DATASETS.get(dataset, dataset)
url = f"https://prod.flywire-daf.com/segmentation/api/v1/table/{dataset}/root/{id}/tabular_change_log"
f = future_session.get(url, params=None)
futures.append(f)
# Get the responses
resp = [
f.result()
for f in navis.config.tqdm(
futures,
desc="Fetching",
disable=not progress or len(futures) == 1,
leave=False,
)
]
df = []
for r, i in zip(resp, x):
# Code 500 means server error
if r.status_code == 500:
# If the server responds with a time-out, the root ID has not
# seen any edits since the base segmentation
if "Read timed out" in r.json().get("message", ""):
continue
try:
r.raise_for_status()
except BaseException:
if errors == "raise":
raise
else:
print(f"Error fetching logs for {i}")
continue
this_df = pd.DataFrame(r.json())
this_df["segment"] = i
df.append(this_df)
# Concat if any edits at all
if any([not f.empty for f in df]):
# Drop neurons without edits
df = [f for f in df if not f.empty]
df = pd.concat(df, axis=0, sort=True)
df["timestamp"] = pd.to_datetime(df.timestamp, unit="ms")
else:
# Return the first empty data frame
df = df[0]
return df
@parse_neuroncriteria()
@inject_dataset(disallowed=["flat_630", "flat_571"])
def roots_to_supervoxels(x, use_cache=True, progress=True, *, dataset=None):
"""Get supervoxels making up given neurons.
Parameters
----------
x : int | list of int | NeuronCriteria
Segmentation (root) ID(s).
use_cache : bool
Whether to use disk cache to avoid repeated queries for the
same root. Cache is stored in `~/.fafbseg/`.
progress : bool
If True, show progress bar.
dataset : "public" | "production" | "sandbox", optional
Against which FlyWire dataset to query. If ``None`` will fall
back to the default dataset (see
:func:`~fafbseg.flywire.set_default_dataset`).
Returns
-------
dict
``{root_id: [supervoxel_id1, supervoxel_id2, ...], ...}``
Examples
--------
>>> from fafbseg import flywire
>>> flywire.roots_to_supervoxels(720575940619164912)[720575940619164912]
array([78251074787604983, 78251074787607484, 78251074787605192, ...,
78673699569883003, 78673699569870455, 78673699569887289],
dtype=uint64)
"""
# Make sure we are working with an array of integers
x = make_iterable(x, force_type=np.int64)
# Make sure we're not getting bogged down with duplicates
x = np.unique(x)
if len(x) <= 1:
progress = False
# Get the volume
vol = get_cloudvolume(dataset)
svoxels = {}
# See what we can get from cache
if use_cache:
# Cache for root -> supervoxels
# Grows to max 1Gb by default and persists across sessions
with Cache(directory="~/.fafbseg/svoxel_cache/") as sv_cache:
# See if we have any of these roots cached
with sv_cache.transact():
is_cached = np.isin(x, sv_cache)
# Add supervoxels from cache if we have any
if np.any(is_cached):
# Get values from cache
with sv_cache.transact():
svoxels.update({i: sv_cache[i] for i in x[is_cached]})
# Get the supervoxels for the roots that are still missing
# We need to convert the keys to an integer array because otherwise
# there is a type mismatch (int vs np.int64) which would cause all root
# IDs to end up in `miss` -> presumably due to how diskcache stores keys
miss = x[~np.isin(x, np.array(list(svoxels.keys()), dtype=np.int64))]
get_leaves = retry(vol.get_leaves)
with navis.config.tqdm(
desc="Querying", total=len(x), disable=not progress, leave=False
) as pbar:
# Update for those for which we had cached data
pbar.update(len(svoxels))
for i in miss:
svoxels[i] = get_leaves(i, bbox=vol.meta.bounds(0), mip=0)
pbar.update()
# Update cache
if use_cache:
with Cache(directory="~/.fafbseg/svoxel_cache/") as sv_cache:
with sv_cache.transact():
for i in miss:
sv_cache[i] = svoxels[i]
return svoxels
@inject_dataset(disallowed=["flat_630", "flat_571"])
def supervoxels_to_roots(
x,
timestamp=None,
batch_size=10_000,
stop_layer=10,
retry=True,
progress=True,
*,
dataset=None,
):
"""Get root(s) for given supervoxel(s).
Parameters
----------
x : int | list of int
Supervoxel ID(s) to find the root(s) for. Also works for
e.g. L2 IDs.
timestamp : int | str | datetime | "mat", optional
Get roots at given date (and time). Int must be unix
timestamp. String must be ISO 8601 - e.g. '2021-11-15'.
"mat" will use the timestamp of the most recent
materialization. You can also use e.g. "mat_438" to get the
root ID at a specific materialization.
batch_size : int
Max number of supervoxel IDs per query. Reduce batch size if
you experience time outs.
stop_layer : int
Set e.g. to ``2`` to get L2 IDs instead of root IDs.
retry : bool
Whether to retry if a batched query fails.
progress : bool
If True, show progress bar.
dataset : "public" | "production" | "sandbox", optional
Against which FlyWire dataset to query. If ``None`` will fall
back to the default dataset (see
:func:`~fafbseg.flywire.set_default_dataset`).
Returns
-------
roots : numpy array
Roots corresponding to supervoxels in `x`.
Examples
--------
>>> from fafbseg import flywire
>>> flywire.supervoxels_to_roots(78321855915861142)
array([720575940594028562])
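>>> # Or map to the root at the most recent materialization instead
>>> flywire.supervoxels_to_roots(78321855915861142,
...                              timestamp='mat') #doctest: +SKIP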
"""
# Make sure we are working with an array of integers
x = make_iterable(x, force_type=np.int64)
# Check if IDs are valid (zeros are fine because we filter for them later on)
# is_valid_supervoxel(x[(x != 0) & (x != '0')], raise_exc=True)
# Parse the volume
vol = get_cloudvolume(dataset)
# Prepare results array
roots = np.zeros(x.shape, dtype=np.int64)
if isinstance(timestamp, str) and timestamp.startswith("mat"):
client = get_cave_client(dataset=dataset)
if timestamp == "mat" or timestamp == "mat_latest":
timestamp = client.materialize.get_timestamp()
else:
# Split e.g. 'mat_432' to extract version and query timestamp
version = int(timestamp.split("_")[1])
timestamp = client.materialize.get_timestamp(version)
if isinstance(timestamp, np.datetime64):
timestamp = str(timestamp)
with tqdm(
desc="Fetching roots",
leave=False,
total=len(x),
disable=not progress or len(x) < batch_size,
) as pbar:
for i in range(0, len(x), int(batch_size)):
# This batch
batch = x[i : i + batch_size]
# get_roots() doesn't like to be asked for zeros - causes server error
not_zero = batch != 0
try:
roots[i : i + batch_size][not_zero] = vol.get_roots(
batch[not_zero], stop_layer=stop_layer, timestamp=timestamp
)
except KeyboardInterrupt:
raise
except BaseException:
if not retry:
raise
time.sleep(1)
roots[i : i + batch_size][not_zero] = vol.get_roots(
batch[not_zero], stop_layer=stop_layer, timestamp=timestamp
)
pbar.update(len(batch))
return roots
def locs_to_supervoxels(locs, mip=2, coordinates="voxel", backend="spine"):
"""Retrieve FlyWire supervoxel IDs at given location(s).
Parameters
----------
locs : list-like | pandas.DataFrame
Array of x/y/z coordinates. If DataFrame must contain
'x', 'y', 'z' or 'fw.x', 'fw.y', 'fw.z' columns. If both
present, 'fw.' columns take precedence!
mip : int [2-8]
Scale to query. Lower mip = more precise but slower;
higher mip = faster but less precise (small supervoxels
might not show at all).
coordinates : "voxel" | "nm"
Units in which your coordinates are in. "voxel" is assumed
to be 4x4x40 (x/y/z) nanometers.
backend : "spine" | "cloudvolume"
Which backend to use. Use "cloudvolume" only when spine
doesn't work.
Returns
-------
numpy.array
List of supervoxel IDs in the same order as ``locs``. Invalid
locations will be returned with ID 0.
See Also
--------
:func:`~fafbseg.flywire.locs_to_segments`
Takes locations and returns root IDs. Can also map to a specific
time or materialization.
Examples
--------
>>> from fafbseg import flywire
>>> # Fetch supervoxel at two locations
>>> locs = [[133131, 55615, 3289], [132802, 55661, 3289]]
>>> flywire.locs_to_supervoxels(locs)
array([79801454835332154, 79731086091150780], dtype=uint64)
"""
if backend not in ("spine", "cloudvolume"):
raise ValueError(f"`backend` not recognised: {backend}")
if isinstance(locs, pd.DataFrame):
if np.all(np.isin(["fw.x", "fw.y", "fw.z"], locs.columns)):
locs = locs[["fw.x", "fw.y", "fw.z"]].values
elif np.all(np.isin(["x", "y", "z"], locs.columns)):
locs = locs[["x", "y", "z"]].values
else:
raise ValueError(
"`locs` as pandas.DataFrame must have either [fw.x"
", fw.y, fw.z] or [x, y, z] columns."
)
# Make sure we are working with numbers
if not np.issubdtype(locs.dtype, np.number):
locs = locs.astype(np.float64)
if backend == "spine":
return spine.transform.get_segids(
locs, segmentation="flywire_190410", coordinates=coordinates, mip=mip
)
else:
vol = copy.deepcopy(get_cloudvolume("production"))
# Lower mips appear to cause inconsistencies despite spine also only
# using mip 2 (IIRC?)
# vol.mip = 2
pl = GSPointLoader(vol)
if coordinates in ("voxel", "voxels"):
locs = locs * [4, 4, 40]
pl.add_points(locs)
points, data = pl.load_all(max_workers=4, progress=True, return_sorted=True)
return data
@inject_dataset()
def neuron_to_segments(x, short=False, coordinates="voxel", *, dataset=None):
"""Get root IDs overlapping with a given neuron.
Parameters
----------
x : Neuron/List
Neurons for which to return root IDs. Neurons must be
in FlyWire (FAFB14.1) space.
short : bool
If True will only return the top hit for each neuron
(including a confidence score).
coordinates : "voxel" | "nm"
Units the neuron(s) are in. "voxel" is assumed to be
4x4x40 (x/y/z) nanometers.
dataset : "public" | "production" | "sandbox" | "flat_630", optional
Against which FlyWire dataset to query. If ``None`` will fall
back to the default dataset (see
:func:`~fafbseg.flywire.set_default_dataset`).
Returns
-------
overlap_matrix : pandas.DataFrame
DataFrame of root IDs (rows) and neuron IDs (columns) with
the overlap in number of nodes as values::
id id1 id2
root_id
10336680915 5 0
10336682132 0 1
summary : pandas.DataFrame
If ``short=True``: DataFrame of top hits only::
id match confidence
12345 103366809155 0.87665
412314 103366821325 0.65233
See Also
--------
:func:`~fafbseg.flywire.skid_to_id`
Takes a CATMAID (FAFB) skeleton ID or annotations and returns
corresponding FlyWire root IDs.
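Examples
--------
A minimal sketch, assuming ``n`` is a ``navis.TreeNeuron`` that has
already been transformed to FlyWire (FAFB14.1) space:
>>> from fafbseg import flywire
>>> summary = flywire.neuron_to_segments(n, short=True) #doctest: +SKIP
>>> # The best-matching root ID for the first neuron
>>> summary['match'].values[0] #doctest: +SKIP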
"""
if isinstance(x, navis.TreeNeuron):
x = navis.NeuronList(x)
assert isinstance(x, navis.NeuronList)
# Grab the nodes table once - `x.nodes` is generated on-the-fly, so
# modifying it directly would not stick
nodes = x.nodes
# Get segmentation IDs
nodes["root_id"] = locs_to_segments(
nodes[["x", "y", "z"]].values, coordinates=coordinates, dataset=dataset
)
# Count segment IDs
seg_counts = nodes.groupby(["neuron", "root_id"], as_index=False).node_id.count()
seg_counts.columns = ["id", "root_id", "counts"]
# Remove seg IDs 0
seg_counts = seg_counts[seg_counts.root_id != 0]
# Turn into matrix where columns are skeleton IDs, segment IDs are rows
# and values are the overlap counts
matrix = seg_counts.pivot(index="root_id", columns="id", values="counts")
if not short:
return matrix
# Extract top IDs and scores
top_id = matrix.index[np.argmax(matrix.fillna(0).values, axis=0)]
# Confidence is the difference between the top and the second-highest
# score, normalised to the total number of overlapping nodes
top_score = matrix.max(axis=0).values
sec_score = np.sort(matrix.fillna(0).values, axis=0)[-2, :]
conf = (top_score - sec_score) / matrix.sum(axis=0).values
summary = pd.DataFrame([])
summary["id"] = matrix.columns
summary["match"] = top_id
summary["confidence"] = conf
return summary
@inject_dataset(disallowed=["flat_630", "flat_571"])
def locs_to_segments(
locs, timestamp=None, backend="spine", coordinates="voxel", *, dataset=None
):
"""Retrieve FlyWire segment (i.e. root) IDs at given location(s).
Parameters
----------
locs : list-like | pandas.DataFrame
Array of x/y/z coordinates. If DataFrame must contain
'x', 'y', 'z' or 'fw.x', 'fw.y', 'fw.z' columns. If both
present, 'fw.' columns take precedence!
timestamp : int | str | datetime | "mat", optional
Get roots at given date (and time). Int must be unix
timestamp. String must be ISO 8601 - e.g. '2021-11-15'.
"mat" will use the timestamp of the most recent
materialization. You can also use e.g. "mat_438" to get the
root ID at a specific materialization.
backend : "spine" | "cloudvolume"
Which backend to use. Use "cloudvolume" only when spine
doesn't work because it's terribly slow.
coordinates : "voxel" | "nm"
Units in which your coordinates are in. "voxel" is assumed
to be 4x4x40 (x/y/z) nanometers.
dataset : "public" | "production" | "sandbox" | "flat_630", optional
Against which FlyWire dataset to query. If ``None`` will fall
back to the default dataset (see
:func:`~fafbseg.flywire.set_default_dataset`).
Returns
-------
numpy.array
List of segmentation IDs in the same order as ``locs``.
See Also
--------
:func:`~fafbseg.flywire.locs_to_supervoxels`
Takes locations and returns supervoxel IDs.
Examples
--------
>>> from fafbseg import flywire
>>> # Fetch root IDs at two locations
>>> locs = [[133131, 55615, 3289], [132802, 55661, 3289]]
>>> flywire.locs_to_segments(locs)
array([720575940631693610, 720575940631693610])
"""
svoxels = locs_to_supervoxels(locs, coordinates=coordinates, backend=backend)
return supervoxels_to_roots(svoxels, timestamp=timestamp, dataset=dataset)
@inject_dataset()
def skid_to_id(x, sample=None, catmaid_instance=None, progress=True, *, dataset=None):
"""Find the FlyWire root ID for a given (FAFB) CATMAID neuron.
This function works by:
1. Fetch the skeleton for given CATMAID neuron.
2. Transform the skeleton to FlyWire space.
3. Map the x/y/z location of the skeleton nodes to root IDs.
4. Return the root ID that was seen the most often.
Parameters
----------
x : int | list-like | str | TreeNeuron/List
Anything that's not a TreeNeuron/List will be passed
directly to ``pymaid.get_neuron``.
sample : int | float, optional
Number (>= 1) or fraction (< 1) of skeleton nodes to sample
to find FlyWire root IDs. If ``None`` (default), will use
all nodes.
catmaid_instance : pymaid.CatmaidInstance, optional
Connection to a CATMAID server. If ``None``, will use the
current global connection. See pymaid docs for details.
progress : bool
If True, shows progress bar.
dataset : "public" | "production" | "sandbox" | "flat_630", optional
Against which FlyWire dataset to query. If ``None`` will fall
back to the default dataset (see
:func:`~fafbseg.flywire.set_default_dataset`).
Returns
-------
pandas.DataFrame
Mapping of skeleton IDs to FlyWire root IDs. Confidence
is the difference between the frequency of the root ID that
was seen most often and the second most seen ID.
See Also
--------
:func:`~fafbseg.flywire.neuron_to_segments`
Takes already downloaded and transformed neuron(s) and
returns corresponding FlyWire root IDs.
Examples
--------
>>> from fafbseg import flywire
>>> import pymaid
>>> # Connect to the VFB's CATMAID
>>> rm = pymaid.CatmaidInstance('https://fafb.catmaid.virtualflybrain.org/',
... project_id=1, api_token=None)
>>> roots = flywire.skid_to_id([6762, 2379517])
>>> roots
skeleton_id flywire_id confidence
0 6762 720575940608544011 0.80
1 2379517 720575940617229632 0.42
"""
if not isinstance(x, (navis.TreeNeuron, navis.NeuronList)):
x = pymaid.get_neuron(x, remote_instance=catmaid_instance)
if isinstance(x, navis.NeuronList) and len(x) == 1:
x = x[0]
if isinstance(x, navis.NeuronList):
res = []
for n in navis.config.tqdm(
x, desc="Searching", disable=not progress, leave=False
):
res.append(skid_to_id(n, dataset=dataset, sample=sample))
return pd.concat(res, axis=0).reset_index(drop=True)
elif isinstance(x, navis.TreeNeuron):
nodes = x.nodes[["x", "y", "z"]]
if sample:
if sample < 1:
nodes = nodes.sample(frac=sample, random_state=1985)
else:
nodes = nodes.sample(n=sample, random_state=1985)
else:
raise TypeError(f'Unable to use data of type "{type(x)}"')
# XForm coordinates from FAFB14 to FAFB14.1
xformed = xform.fafb14_to_flywire(nodes[["x", "y", "z"]].values, coordinates="nm")
# Get the root IDs for each of these locations
roots = locs_to_segments(xformed, coordinates="nm", dataset=dataset)
# Drop zeros
roots = roots[roots != 0]
# Find unique Ids and count them
unique, counts = np.unique(roots, return_counts=True)
# Get sorted indices
sort_ix = np.argsort(counts)
# The "correct" ID is assumed to be the most frequent ID
new_id = unique[sort_ix[-1]]
# Confidence is the difference between the top and the 2nd most frequent ID
if len(unique) > 1:
diff_1st_2nd = counts[sort_ix[-1]] - counts[sort_ix[-2]]
conf = round(diff_1st_2nd / roots.shape[0], 2)
else:
conf = 1
return pd.DataFrame(
[[x.id, new_id, conf]], columns=["skeleton_id", "flywire_id", "confidence"]
)
@inject_dataset(disallowed=["flat_630", "flat_571"])
@retry
def is_latest_root(id, timestamp=None, progress=True, *, dataset=None, **kwargs):
"""Check if root is the current one.
Parameters
----------
id : int | list-like
Single ID or list of FlyWire (root) IDs.
timestamp : int | str | datetime | "mat", optional
Checks if roots existed at given date (and time). Int must
be unix timestamp. String must be ISO 8601 - e.g. '2021-11-15'.
"mat" will use the timestamp of the most recent
materialization. You can also use e.g. "mat_438" to get the
root ID at a specific materialization.
progress : bool
Whether to show progress bar.
dataset : "public" | "production" | "sandbox", optional
Against which FlyWire dataset to query. If ``None`` will fall
back to the default dataset (see
:func:`~fafbseg.flywire.set_default_dataset`).
Returns
-------
numpy array
Array of booleans
See Also
--------
:func:`~fafbseg.flywire.update_ids`
If you want the new ID. Also allows mapping to a specific
time or materialization.
Examples
--------
>>> from fafbseg import flywire
>>> flywire.is_latest_root(720575940631693610)
array([ True])
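>>> # Check whether the root already existed at a given point in time
>>> flywire.is_latest_root(720575940631693610,
...                        timestamp='2022-01-01') #doctest: +SKIP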
"""
id = make_iterable(id, force_type=str)
# The server doesn't like being asked for zeros
not_zero = id != "0"
# Check if all other IDs are valid
is_valid_root(id[not_zero], raise_exc=True, dataset=dataset)
is_latest = np.ones(len(id)).astype(bool)
client = get_cave_client(dataset=dataset)
session = requests.Session()
token = get_chunkedgraph_secret()
session.headers["Authorization"] = f"Bearer {token}"
url = (
client.chunkedgraph._endpoints["is_latest_roots"].format_map(client.chunkedgraph.default_url_mapping)
)
if isinstance(timestamp, str) and timestamp.startswith("mat"):
if timestamp == "mat" or timestamp == "mat_latest":
timestamp = client.materialize.get_timestamp()
else:
# Split e.g. 'mat_432' to extract version and query timestamp
version = int(timestamp.split("_")[1])
timestamp = client.materialize.get_timestamp(version)
if isinstance(timestamp, np.datetime64):
timestamp = str(timestamp)
if isinstance(timestamp, str):
timestamp = dt.datetime.fromisoformat(timestamp)
if timestamp is not None:
params = package_timestamp(timestamp)
else:
params = None
batch_size = 100_000
with navis.config.tqdm(
desc="Checking",
total=not_zero.sum(),
disable=(not_zero.sum() <= batch_size) or not progress,
leave=False,
) as pbar:
for i in range(0, not_zero.sum(), batch_size):
batch = id[not_zero][i : i + batch_size]
post = {"node_ids": batch.tolist()}
# Update progress bar
pbar.update(len(batch))
r = session.post(url, json=post, params=params)
r.raise_for_status()
is_latest[np.where(not_zero)[0][i : i + batch_size]] = np.array(
r.json()["is_latest"]
)
return is_latest
@parse_neuroncriteria()
@inject_dataset(disallowed=["flat_630", "flat_571"])
def find_common_time(root_ids, progress=True, *, dataset=None):
"""Find a time at which given root IDs co-existed.
Parameters
----------
root_ids : list | np.ndarray | NeuronCriteria
Root IDs to check.
progress : bool
If True, shows progress bar.
dataset : "public" | "production" | "sandbox", optional
Against which FlyWire dataset to query. If ``None`` will fall
back to the default dataset (see
:func:`~fafbseg.flywire.set_default_dataset`).
Returns
-------
datetime.datetime
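Examples
--------
A minimal sketch - the root IDs are illustrative and taken from other
examples in this module:
>>> from fafbseg import flywire
>>> flywire.find_common_time([720575940621039145,
...                           720575940631693610]) #doctest: +SKIP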
"""
root_ids = np.asarray(root_ids, dtype=np.int64)
client = get_cave_client(dataset=dataset)
# Get timestamps when roots were created
creations = client.chunkedgraph.get_root_timestamps(root_ids)
# Find out which IDs are still current
is_latest = client.chunkedgraph.is_latest_roots(root_ids)
# Prepare array with death times
deaths = np.array([dt.datetime.now(tz=dt.timezone.utc) for r in root_ids])
# Get lineage graph for outdated root IDs
G = client.chunkedgraph.get_lineage_graph(
root_ids[~is_latest], timestamp_past=min(creations), as_nx_graph=True
)
# Get the immediate successors
succ = np.array([next(G.successors(r)) for r in root_ids[~is_latest]])
# Add time of death
deaths[~is_latest] = client.chunkedgraph.get_root_timestamps(succ)
# Find the latest creation
latest_birth = max(creations)
# Find the earliest death
earliest_death = min(deaths)
if latest_birth > earliest_death:
raise ValueError("Given root IDs never existed at the same time.")
return latest_birth + (earliest_death - latest_birth) / 2
@parse_neuroncriteria()
@inject_dataset(disallowed=["flat_630", "flat_571"])
def update_ids(
id,
stop_layer=2,
supervoxels=None,
timestamp=None,
progress=True,
*,
dataset=None,
**kwargs,
):
"""Retrieve the most recent version of given FlyWire (root) ID(s).
This function works by:
1. Check if ID is outdated (see :func:`fafbseg.flywire.is_latest_root`)
2. If supervoxel provided, use it to update ID. Else try 3.
3. See if we can map outdated IDs to a single up-to-date root (works
if neuron has only seen merges). Else try 4.
4. For uncertain IDs, fetch L2 IDs for the old root ID and the new
candidates. Pick the candidate containing most of the original L2
IDs.
Parameters
----------
id : int | list-like | DataFrame | NeuronCriteria
Single ID or list of FlyWire (root) IDs. If DataFrame must
contain either a `root_id` or `root` column and optionally
a `supervoxel_id` or `supervoxel` column.
stop_layer : int
In case of root IDs that have been split, we need to
determine the most likely successor. By default we do that
using L2 IDs but you can speed this up by increasing the
stop layer.
supervoxels : int | list-like, optional
If provided will use these supervoxels to update ``id``
instead of sampling using the L2 IDs.
timestamp : int | str | datetime | "mat", optional
Find root ID(s) at given date (and time). Int must be unix
timestamp. String must be ISO 8601 - e.g. '2021-11-15'. "mat"
or e.g. "mat_438" will use the timestamp of the (respective)
materialization. Asking for a specific time will slow things
down considerably.
progress : bool
If True, shows progress bar.
dataset : "public" | "production" | "sandbox", optional
Against which FlyWire dataset to query. If ``None`` will fall
back to the default dataset (see
:func:`~fafbseg.flywire.set_default_dataset`).
Returns
-------
pandas.DataFrame
Mapping of old -> new root IDs with confidence::
old_id  new_id  confidence  changed
0  ...  ...  ...  ...
1  ...  ...  ...  ...
See Also
--------
:func:`~fafbseg.flywire.is_latest_root`
If all you want is to know whether a (root) ID is up-to-date.
:func:`~fafbseg.flywire.supervoxels_to_roots`
Maps supervoxels to roots. If you have supervoxel IDs for
your neurons this function will be significantly faster for
updating/mapping root IDs.
Examples
--------
>>> from fafbseg import flywire
>>> flywire.update_ids(720575940621039145)
old_id new_id confidence changed
0 720575940621039145 720575940631693610 1.0 True
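>>> # If you know a supervoxel for the neuron the update is much faster
>>> # (the supervoxel ID below is illustrative)
>>> flywire.update_ids(720575940621039145,
...                    supervoxels=79801454835332154) #doctest: +SKIP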
"""
assert stop_layer > 0, "`stop_layer` must be > 0"
# See if we have already checked whether this is the latest root
is_latest = kwargs.pop("is_latest", None)
vol = get_cloudvolume(dataset, **kwargs)
if isinstance(id, pd.DataFrame):
if isinstance(supervoxels, type(None)):
if "supervoxel_id" in id.columns:
supervoxels = id["supervoxel_id"].values
elif "supervoxel" in id.columns:
supervoxels = id["supervoxel"].values
if "root_id" in id.columns:
id = id["root_id"].values
elif "root" in id.columns:
id = id["root"].values
else:
raise ValueError("DataFrame must contain either a `root_id` or a `root` column.")
elif isinstance(id, pd.Series):
id = id.values
elif isinstance(id, pd.core.arrays.string_.StringArray):
id = np.asarray(id)
if isinstance(timestamp, str) and timestamp.startswith("mat"):
client = get_cave_client(dataset=dataset)
if timestamp == "mat" or timestamp == "mat_latest":
timestamp = client.materialize.get_timestamp()
else:
# Split e.g. 'mat_432' to extract version and query timestamp
version = int(timestamp.split("_")[1])
timestamp = client.materialize.get_timestamp(version)
if isinstance(id, (list, set, np.ndarray)):
# Run is_latest once for all roots
is_latest = is_latest_root(
id, dataset=dataset, timestamp=timestamp, progress=progress
)
if isinstance(supervoxels, type(None)):
res = [
update_ids(
x,
dataset=dataset,
is_latest=il,
supervoxels=None,
timestamp=timestamp,
stop_layer=stop_layer,
)
for x, il, in navis.config.tqdm(
zip(id, is_latest),
desc="Updating",
leave=False,
total=len(id),
disable=not progress or len(id) == 1,
)
]
res = pd.concat(res, axis=0, sort=False, ignore_index=True)
else:
supervoxels = np.asarray(supervoxels)
if len(supervoxels) != len(id):
raise ValueError(
f"Number of supervoxels ({len(supervoxels)}) does "
f"not match number of root IDs ({len(id)})"
)
elif any(pd.isnull(supervoxels)):
raise ValueError("`supervoxels` must not contain `None`")
elif any(pd.isnull(id)):
raise ValueError("`id` must not contain `None`")
id = np.array(id, dtype=np.int64)
res = pd.DataFrame()
res["old_id"] = id
res["new_id"] = id
res.loc[~is_latest, "new_id"] = supervoxels_to_roots(
supervoxels[~is_latest], timestamp=timestamp, dataset=dataset
)
res["conf"] = 1
res["changed"] = res["new_id"] != res["old_id"]
return res
try:
id = np.int64(id)
except ValueError:
raise ValueError(f'"{id} does not look like a valid root ID.')
if id == 0 or pd.isnull(id):
navis.config.logger.warning(f'Unable to update ID "{id}" - returning unchanged.')
return id
# Check if outdated
if isinstance(is_latest, type(None)):
is_latest = is_latest_root(
id, dataset=dataset, timestamp=timestamp, progress=progress
)[0]
if isinstance(timestamp, np.datetime64):
timestamp = str(timestamp)
if not is_latest:
if timestamp:
client = get_cave_client(dataset=dataset)
get_leaves = retry(client.chunkedgraph.get_leaves)
l2_ids_orig = get_leaves(id, stop_layer=stop_layer)
get_roots = retry(vol.get_roots)
roots = get_roots(l2_ids_orig, timestamp=timestamp)
# Drop zeros
roots = roots[roots != 0]
if not len(roots):
new_id = 0
conf = 0
else:
uni, cnt = np.unique(roots, return_counts=True)
new_id = uni[np.argmax(cnt)]
conf = cnt[np.argmax(cnt)] / len(roots)
else:
client = get_cave_client(dataset=dataset)
get_latest_roots = retry(client.chunkedgraph.get_latest_roots)
# This endpoint in caveclient seems to require uint64
pot_roots = get_latest_roots(np.uint64(id))
# Note that we check whether the suggested new ID is different from the
# old ID. That's because I came across a few examples where the lineage
# graph appears disconnected (e.g. 720575940613297192), perhaps due to
# an issue in the operations log. The result is that despite the root
# ID being outdated, the latest node in the graph is still not the most
# up-to-date ID.
if len(pot_roots) == 1 and pot_roots[0] != id:
new_id = pot_roots[0]
conf = 1
elif supervoxels:
try:
supervoxels = np.int64(supervoxels)
except ValueError:
raise ValueError(f'"{supervoxels}" does not look like a valid supervoxel ID.')
get_root_id = retry(client.chunkedgraph.get_root_id)
# Map the provided supervoxel to its current root
new_id = get_root_id(supervoxels)
conf = 1
else:
# Get L2 IDs for the original ID
# Note: we could also use higher-level IDs
# (stop layer 3 or 4) which would be even faster
get_leaves = retry(client.chunkedgraph.get_leaves)
l2_ids_orig = get_leaves(id, stop_layer=stop_layer)
# Get new roots for these L2 IDs
get_roots = retry(client.chunkedgraph.get_roots)
new_roots = get_roots(l2_ids_orig)
# Find the most frequent new root
roots, counts = np.unique(new_roots, return_counts=True)
srt = np.argsort(counts)[::-1]
roots = roots[srt]
counts = counts[srt]
# New ID is the most frequent ID
new_id = roots[0]
# Confidence is the fraction of original L2 IDs in the new ID
conf = round(counts[0] / sum(counts), 2)
else:
new_id = id
conf = 1
return pd.DataFrame(
[[id, new_id, conf, id != new_id]],
columns=["old_id", "new_id", "confidence", "changed"],
).astype({"old_id": np.int64, "new_id": np.int64})
@inject_dataset()
def snap_to_id(
locs,
id,
snap_zero=False,
search_radius=160,
coordinates="nm",
max_workers=4,
verbose=True,
*,
dataset=None,
):
"""Snap locations to the correct segmentation ID.
Works by:
1. Fetch segmentation ID for each location and for those with the wrong ID:
2. Fetch cube around each loc and snap to the closest voxel with correct ID
Parameters
----------
locs : (N, 3) array
Array of x/y/z coordinates.
id : int
Expected ID at each location.
snap_zero : bool
If False (default), we will not snap locations that map to
segment ID 0 (i.e. no segmentation).
search_radius : int
Radius [nm] around a location to search for a position with
the correct ID. Lower values will be faster.
coordinates : "voxel" | "nm"
Coordinate system of `locs`. If "voxel" it is assumed to be
4 x 4 x 40 nm.
max_workers : int
verbose : bool
If True, will print a summary at the end.
dataset : "public" | "production" | "sandbox" | "flat_630", optional
Against which FlyWire dataset to query. If ``None`` will fall
back to the default dataset (see
:func:`~fafbseg.flywire.set_default_dataset`).
Returns
-------
(N, 3) array
x/y/z locations that are guaranteed to map to the correct ID.
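Examples
--------
A minimal sketch - the locations and root ID are illustrative:
>>> from fafbseg import flywire
>>> locs = [[133131, 55615, 3289], [132802, 55661, 3289]]
>>> snapped = flywire.snap_to_id(locs,
...                              id=720575940631693610,
...                              coordinates='voxel') #doctest: +SKIP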
"""
assert coordinates in ["nm", "nanometer", "nanometers", "voxel", "voxels"]
if isinstance(locs, navis.TreeNeuron):
locs = locs.nodes[["x", "y", "z"]].values
# This also makes sure we work on a copy
locs = np.array(locs, copy=True)
assert locs.ndim == 2 and locs.shape[1] == 3
# From hereon out we are working with nanometers
if coordinates in ("voxel", "voxels"):
locs *= [4, 4, 40]
root_ids = locs_to_segments(locs, dataset=dataset, coordinates="nm")
id_wrong = root_ids != id
not_zero = root_ids != 0
to_fix = id_wrong
if not snap_zero:
to_fix = to_fix & not_zero
# Use parallel processes to go over the to-fix nodes
with navis.config.tqdm(desc="Snapping", total=to_fix.sum(), leave=False) as pbar:
with futures.ProcessPoolExecutor(max_workers=max_workers) as ex:
loc_futures = [
ex.submit(
_process_cutout,
id=id,
loc=locs[ix],
dataset=dataset,
radius=search_radius,
)
for ix in np.where(to_fix)[0]
]
for f in futures.as_completed(loc_futures):
pbar.update(1)
# Get results
results = [f.result() for f in loc_futures]
# Stack locations
new_locs = np.vstack(results)
# If no new location found, array will be [0, 0, 0]
not_snapped = new_locs.max(axis=1) == 0
# Update location
to_update = np.where(to_fix)[0][~not_snapped]
locs[to_update, :] = new_locs[~not_snapped]
if verbose:
msg = f"""\
{to_fix.sum()} of {to_fix.shape[0]} locations needed to be snapped.
Of these {not_snapped.sum()} locations could not be snapped - consider
increasing `search_radius`.
"""
print(textwrap.dedent(msg))
return locs
def _process_cutout(loc, id, radius=160, dataset="production"):
"""Process single cutout for snap_to_id."""
# Get this location
loc = loc.round()
# Generating bounding box around this location
mn = loc - radius
mx = loc + radius
# Make sure it's a multiple of 4 and 40
mn = mn - mn % [4, 4, 40]
mx = mx - mx % [4, 4, 40]
# Generate bounding box
bbox = np.vstack((mn, mx))
# Get the cutout, the resolution and offset
cutout, res, offset_nm = get_segmentation_cutout(
bbox, dataset=dataset, root_ids=True, coordinates="nm"
)
# Generate a mask
mask = (cutout == id).astype(int, copy=False)
# Erode so we move our point slightly more inside the segmentation
mask = ndimage.binary_erosion(mask).astype(mask.dtype)
# Find positions of the ID we are looking for
our_id = np.vstack(np.where(mask)).T
# Return [0, 0, 0] if unable to snap (i.e. if id not within radius)
if not our_id.size:
return np.array([0, 0, 0])
# Get the closest one to the center of the cutout
center = np.divide(cutout.shape, 2).round()
dist = np.abs(our_id - center).sum(axis=1)
closest = our_id[np.argmin(dist)]
# Convert the position within the cutout to absolute nanometer coordinates
snapped = closest * res + offset_nm
return snapped
@inject_dataset()
def get_segmentation_cutout(
bbox, root_ids=True, mip=0, coordinates="voxel", *, dataset=None
):
"""Fetch cutout of segmentation.
Parameters
----------
bbox : array-like
Bounding box for the cutout::
[[xmin, xmax], [ymin, ymax], [zmin, zmax]]
root_ids : bool
If True, will return root IDs. If False, will return
supervoxel IDs. Ignored if dataset is "flat_630".
coordinates : "voxel" | "nm"
Units in which your coordinates are in. "voxel" is assumed
to be 4x4x40 (x/y/z) nanometers.
dataset : "public" | "production" | "sandbox" | "flat_630", optional
Against which FlyWire dataset to query. If ``None`` will fall
back to the default dataset (see
:func:`~fafbseg.flywire.set_default_dataset`).
Returns
-------
cutout : np.ndarray
(N, M, P) array of segmentation (root or supervoxel) IDs.
resolution : (3, ) numpy array
[x, y, z] resolution of the voxels in the cutout.
nm_offset : (3, ) numpy array
[x, y, z] offset in nanometers of the cutout with respect
to the absolute coordinates.
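Examples
--------
Fetch a small cutout (the bounds are illustrative and in 4x4x40nm
voxel coordinates):
>>> from fafbseg import flywire
>>> bbox = [[133000, 133100], [55600, 55700], [3280, 3290]]
>>> cutout, res, offset = flywire.get_segmentation_cutout(bbox) #doctest: +SKIP
>>> cutout.shape #doctest: +SKIP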
"""
assert coordinates in ["nm", "nanometer", "nanometers", "voxel", "voxels"]
bbox = np.asarray(bbox)
assert bbox.ndim == 2
if bbox.shape == (2, 3):
pass
elif bbox.shape == (3, 2):
bbox = bbox.T
else:
raise ValueError(f"`bbox` must have shape (2, 3) or (3, 2), got {bbox.shape}")
vol = get_cloudvolume(dataset)
vol.mip = mip
# First convert to nanometers
if coordinates in ("voxel", "voxels"):
bbox = bbox * np.array([4, 4, 40])
# Now convert back to voxels at the volume's current (mip) resolution
bbox = (bbox / vol.scale["resolution"]).round().astype(int)
offset_nm = bbox[0] * vol.scale["resolution"]
# Get cutout
cutout = vol[
bbox[0][0] : bbox[1][0], bbox[0][1] : bbox[1][1], bbox[0][2] : bbox[1][2]
]
if root_ids and ("flat" not in dataset):
svoxels = np.unique(cutout.flatten())
roots = supervoxels_to_roots(svoxels, dataset=vol)
sv2r = dict(zip(svoxels[svoxels != 0], roots[svoxels != 0]))
for k, v in sv2r.items():
cutout[cutout == k] = v
return cutout[:, :, :, 0], np.asarray(vol.scale["resolution"]), offset_nm
@inject_dataset(disallowed=["flat_630", "flat_571"])
def is_valid_root(x, raise_exc=False, *, dataset=None):
"""Check if ID is (potentially) valid root ID.
Parameters
----------
x : int | str | iterable
ID(s) to check.
raise_exc : bool
If True and any IDs are invalid will raise an error.
Mostly for internal use.
dataset : "public" | "production" | "sandbox" | "flat_630", optional
Against which FlyWire dataset to query. If ``None`` will fall
back to the default dataset (see
:func:`~fafbseg.flywire.set_default_dataset`).
Returns
-------
bool
If ``x`` is a single ID.
array
If ``x`` is iterable.
See Also
--------
:func:`~fafbseg.flywire.is_valid_supervoxel`
Use this function to check if a supervoxel ID is valid.
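Examples
--------
>>> from fafbseg import flywire
>>> # Root IDs live on the highest layer of the chunkedgraph
>>> flywire.is_valid_root(720575940621039145) #doctest: +SKIP
True
>>> # A supervoxel ID is not a valid root ID
>>> flywire.is_valid_root(79801454835332154) #doctest: +SKIP
False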
"""
client = get_cave_client(dataset=dataset)
vol = get_cloudvolume(client.chunkedgraph.cloudvolume_path)
def _is_valid(x, raise_exc):
try:
is_valid = vol.get_chunk_layer(x) == vol.info["graph"]["n_layers"]
except ValueError:
is_valid = False
if raise_exc and not is_valid:
raise ValueError(f"{x} is not a valid root ID")
return is_valid
if navis.utils.is_iterable(x):
is_valid = np.array([_is_valid(r, raise_exc=False) for r in x])
if raise_exc and not all(is_valid):
invalid = set(np.asarray(x)[~is_valid].tolist())
raise ValueError(f"Invalid root IDs found: {invalid}")
return is_valid
else:
return _is_valid(x, raise_exc=raise_exc)
@inject_dataset(disallowed=["flat_630", "flat_571"])
def is_valid_supervoxel(x, raise_exc=False, *, dataset=None):
"""Check if ID is (potentially) valid supervoxel ID.
Parameters
----------
x : int | str | iterable
ID(s) to check.
raise_exc : bool
If True and any IDs are invalid will raise an error.
Mostly for internal use.
dataset : "public" | "production" | "sandbox" | "flat_630", optional
Against which FlyWire dataset to query. If ``None`` will fall
back to the default dataset (see
:func:`~fafbseg.flywire.set_default_dataset`).
Returns
-------
bool
If ``x`` is a single ID.
array
If ``x`` is iterable.
See Also
--------
:func:`~fafbseg.flywire.is_valid_root`
Use this function to check if a root ID is valid.
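Examples
--------
>>> from fafbseg import flywire
>>> # Supervoxel IDs live on layer 1 of the chunkedgraph
>>> flywire.is_valid_supervoxel(79801454835332154) #doctest: +SKIP
True
>>> # A root ID is not a valid supervoxel ID
>>> flywire.is_valid_supervoxel(720575940621039145) #doctest: +SKIP
False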
"""
client = get_cave_client(dataset=dataset)
vol = get_cloudvolume(client.chunkedgraph.cloudvolume_path)
def _is_valid(x, raise_exc):
try:
is_valid = vol.get_chunk_layer(x) == 1
except ValueError:
is_valid = False
if raise_exc and not is_valid:
raise ValueError(f"{x} is not a valid supervoxel ID")
return is_valid
if navis.utils.is_iterable(x):
is_valid = np.array([_is_valid(r, raise_exc=False) for r in x])
if raise_exc and not all(is_valid):
invalid = set(np.asarray(x)[~is_valid].tolist())
raise ValueError(f"Invalid supervoxel IDs found: {invalid}")
return is_valid
else:
return _is_valid(x, raise_exc=raise_exc)
@inject_dataset(disallowed=["flat_630", "flat_571"])
def get_voxels(
x,
mip=0,
sv_map=False,
bounds=None,
thin=False,
progress=True,
use_mirror=True,
threads=4,
*,
dataset=None,
):
"""Fetch voxels making a up given root ID.
Parameters
----------
x : int
A single root ID.
mip : int
Scale at which to fetch voxels. For example, `mip=0` is
at 16 x 16 x 40nm resolution. Every subsequent `mip` halves
the resolution.
sv_map : bool
If True, additionally return a map with the L2 ID for each
voxel.
bounds : (3, 2) or (2, 3) array, optional
Bounding box to return voxels in. Expected to be in 4, 4, 40
voxel space.
thin : bool
If True, will remove voxels at the interface of adjacent
supervoxels that are not supposed to be connected according
to the L2 graph. This is rather expensive but can help in
situations where a neuron self-touches.
use_mirror : bool
If True (default), will use a mirror of the base segmentation
for supervoxel look-up. Possibly slightly slower than the
production dataset but doesn't incur egress charges for
Princeton.
threads : int
Number of parallel threads to use for fetching the data.
progress : bool
Whether to show a progress bar or not.
dataset : "public" | "production" | "sandbox" | "flat_630", optional
Against which FlyWire dataset to query. If ``None`` will fall
back to the default dataset (see
:func:`~fafbseg.flywire.set_default_dataset`).
Returns
-------
voxels : (N, 3) np.ndarray
In voxel space according to `mip`.
sv_map : (N, ) np.ndarray
Supervoxel ID for each voxel. Only if `sv_map=True`.
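Examples
--------
A minimal sketch - the root ID is illustrative and fetching all voxels
of a large neuron can take a while:
>>> from fafbseg import flywire
>>> voxels, svids = flywire.get_voxels(720575940621039145,
...                                    mip=2, sv_map=True) #doctest: +SKIP
>>> voxels.shape #doctest: +SKIP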
"""
# IDEA:
# 1. Find surface voxels for each L2 chunk
# 2. Get L2 graph and determine which L2 chunks are supposed to be connected
# 3. Remove surface voxel between adjacent but not connected L2 chunks
from .l2 import chunks_to_nm
# This is a mirror for base segmentation
vol = get_cloudvolume(dataset)
client = get_cave_client(dataset=dataset)
if use_mirror:
sv_vol = cv.CloudVolume(
"precomputed://https://seungdata.princeton.edu/"
"sseung-archive/fafbv14-ws/"
"ws_190410_FAFB_v02_ws_size_threshold_200",
use_https=True,
progress=False,
fill_missing=True,
)
else:
sv_vol = vol
is_valid_root(x, raise_exc=True, dataset=dataset)
# Get L2 chunks making up this neuron
l2_ids = client.chunkedgraph.get_leaves(x, stop_layer=2)
# Get supervoxels for this neuron
sv = roots_to_supervoxels(x, dataset=dataset)[x]
# Turn l2_ids into chunk indices
l2_ix = [np.array(vol.mesh.meta.meta.decode_chunk_position(l)) for l in l2_ids]
l2_ix = np.unique(l2_ix, axis=0)
# Convert to nm
l2_nm = np.asarray(chunks_to_nm(l2_ix, vol=vol))
# Convert to voxel space
l2_vxl = l2_nm // vol.meta.scales[mip]["resolution"]
# Apply bounds
bounds = parse_bounds(bounds)
if not isinstance(bounds, type(None)):
base_to_mip = np.array(vol.meta.scales[mip]["resolution"]) / [4, 4, 40]
bounds = bounds // base_to_mip.reshape(-1, 1)
l2_vxl = l2_vxl[np.all(l2_vxl >= bounds[:, 0], axis=1)]
l2_vxl = l2_vxl[np.all(l2_vxl <= bounds[:, 1], axis=1)]
voxels = []
svids = []
ch_size = np.array(vol.mesh.meta.meta.graph_chunk_size)
ch_size = ch_size // (vol.mip_resolution(mip) / vol.mip_resolution(0))
old_mip = sv_vol.mip
old_parallel = sv_vol.parallel
try:
sv_vol.mip = mip
sv_vol.parallel = threads
for ch in tqdm(
l2_vxl, disable=not progress, leave=False, desc="Fetching voxels"
):
ct = sv_vol[
ch[0] : ch[0] + ch_size[0],
ch[1] : ch[1] + ch_size[1],
ch[2] : ch[2] + ch_size[2],
][:, :, :, 0]
is_root = np.isin(ct, sv)
this_vxl = np.dstack(np.where(is_root))[0]
this_vxl = this_vxl + ch
voxels.append(this_vxl)
if sv_map or thin:
svids.append(ct[is_root])
except BaseException:
raise
finally:
sv_vol.mip = old_mip
sv_vol.parallel = old_parallel
# uint16 should be sufficient because even at mip 0 the volume has
# shape (54100, 28160, 7046) -> doesn't exceed 65_535
voxels = np.vstack(voxels).astype("uint16")
if len(svids):
svids = np.concatenate(svids)
if thin:
from .l2 import get_l2_graph
try:
from pykdtree.kdtree import KDTree
except ImportError:
from scipy.spatial import cKDTree as KDTree
# Get the l2 ID for each supervoxel
l2_ids = vol.get_roots(svids, stop_layer=2)
l2_dict = dict(zip(svids, l2_ids))
# Get the l2 graph
G = get_l2_graph(x)
# Create KD tree for all voxels
tree = KDTree(voxels)
# Create a mask for invalidated voxels
invalid = np.zeros(len(voxels), dtype=bool)
# Now go over each supervoxel
for sv in tqdm(
np.unique(svids), disable=not progress, desc="Thinning", leave=False
):
# Get the voxels for this supervoxel
is_this_sv = svids == sv
# If supervoxel has no voxels just continue
if not np.any(is_this_sv):
continue
# Get all supervoxels that could be connected to this supervoxel
is_this_l2 = l2_ids == l2_dict[sv]
is_connected_l2 = np.isin(l2_ids, list(G.neighbors(l2_dict[sv])))
# The mask needs to exclude anything that:
# Isn't this supervoxel OR is supposed to be connected OR
# has already been invalidated in a prior run
mask = is_this_l2 | is_connected_l2 | invalid
# Find "other" voxels that touch voxels for this supervoxel
dist, ix = tree.query(
voxels[is_this_sv], mask=mask, distance_upper_bound=1.75
)
is_touching = dist < np.inf
if not np.any(is_touching):
continue
invalid[np.where(is_this_sv)[0][is_touching]] = True
voxels = voxels[~invalid]
svids = svids[~invalid]
if not sv_map:
return voxels
else:
return voxels, svids