diff --git a/.gitignore b/.gitignore index e58b693..168d626 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .DS_Store .DS_Store? *.asv +python/bml/__pycache__/ diff --git a/python/bml/__init__.py b/python/bml/__init__.py new file mode 100644 index 0000000..d367b64 --- /dev/null +++ b/python/bml/__init__.py @@ -0,0 +1,7 @@ +"""BML - Brain Modulation Lab toolbox. + +Python translation of the MATLAB BML toolbox for ECoG/MER/LFP/Audio +data manipulation and analysis. +""" + +from bml import utils, annot, sync, stat diff --git a/python/bml/annot.py b/python/bml/annot.py new file mode 100644 index 0000000..5c481e4 --- /dev/null +++ b/python/bml/annot.py @@ -0,0 +1,772 @@ +"""Annotation table functions for BML. + +Python translations of MATLAB functions from the BML toolbox annot/ +directory. Annotation tables are represented as :class:`pandas.DataFrame` +objects with at least the columns ``id``, ``starts``, ``ends`` and +``duration``. The table description is stored in ``df.attrs['description']``. 
+""" + +import os +import re +import warnings + +import numpy as np +import pandas as pd + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _reorder_vars(df, first_cols): + """Reorder columns so *first_cols* come first, preserving others.""" + first = [c for c in first_cols if c in df.columns] + rest = [c for c in df.columns if c not in first] + return df[first + rest] + + +def _get_description(df): + """Return the description stored on a DataFrame, or ''.""" + if isinstance(df, pd.DataFrame): + return df.attrs.get("description", "") + return "" + + +def _set_description(df, description): + """Set description on a DataFrame.""" + df.attrs["description"] = description or "" + return df + + +def _conform_to(template, other): + """Conform *other* to have the same columns as *template*.""" + for col in template.columns: + if col not in other.columns: + other = other.copy() + other[col] = np.nan + return other[template.columns] + + +def _collapse_rows(rows, additive=None): + """Collapse multiple annotation rows into a single summary row.""" + additive = additive or [] + result = { + "starts": rows["starts"].min(), + "ends": rows["ends"].max(), + "cons_duration": rows["duration"].sum(), + "id_starts": rows["id"].min(), + "id_ends": rows["id"].max(), + "cons_n": len(rows), + } + skip = set(result.keys()) | set(additive) + for col in rows.columns: + if col in skip or col in ("id", "duration"): + continue + vals = rows[col].dropna().unique() + if len(vals) == 1: + result[col] = vals[0] + else: + result[col] = np.nan + for col in additive: + if col in rows.columns: + result[col] = rows[col].sum() + return result + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def annot_table(x=None, description=None): + 
"""Create or validate an annotation DataFrame. + + Translated from ``bml_annot_table.m``. + + Parameters + ---------- + x : DataFrame, dict, list, numpy.ndarray, or None + Input data. Must contain ``starts`` and ``ends`` columns (or be + coercible to a two-column table that will be renamed). + description : str or None + Optional description stored in ``df.attrs['description']``. + + Returns + ------- + pandas.DataFrame + Annotation table with columns ``id``, ``starts``, ``ends``, + ``duration`` followed by any extra columns. + """ + # Handle empty / None + if x is None or (isinstance(x, pd.DataFrame) and x.empty): + df = pd.DataFrame() + return _set_description(df, description or "") + + # Convert various types to DataFrame + if isinstance(x, np.ndarray): + if x.ndim == 1: + x = x.reshape(-1, 1) + df = pd.DataFrame(x) + elif isinstance(x, dict): + df = pd.DataFrame(x) + elif isinstance(x, list): + df = pd.DataFrame(x) + elif isinstance(x, pd.DataFrame): + df = x.copy() + else: + df = pd.DataFrame(x) + + if description is None: + description = _get_description(df) or "" + + if df.empty: + return _set_description(df, description) + + # Ensure 'starts' column + if "starts" not in df.columns: + if len(df.columns) <= 2: + cols = list(df.columns) + df = df.rename(columns={cols[0]: "starts"}) + else: + raise ValueError("x should have variable 'starts'") + + # Ensure 'ends' column + if "ends" not in df.columns: + if len(df.columns) == 1: + df["ends"] = df["starts"] + elif len(df.columns) == 2: + cols = list(df.columns) + other = [c for c in cols if c != "starts"][0] + df = df.rename(columns={other: "ends"}) + else: + raise ValueError("x should have variable 'ends'") + + # Ensure 'id' column + if "id" not in df.columns: + df = df.sort_values("starts").reset_index(drop=True) + df.insert(0, "id", range(1, len(df) + 1)) + else: + if df["id"].nunique() < len(df): + raise ValueError("inconsistent id variable") + df = df.sort_values("id").reset_index(drop=True) + + # Recalculate 
def annot_overlap(annot, timetol=1e-5):
    """Find all pairs of overlapping annotations.

    Translated from ``bml_annot_overlap.m``.

    Parameters
    ----------
    annot : pandas.DataFrame
        Annotation table with ``starts``, ``ends``, ``id`` columns.
    timetol : float, optional
        Time tolerance in seconds (default ``1e-5``); overlaps shorter
        than this are ignored.

    Returns
    -------
    pandas.DataFrame
        Table with columns ``starts``, ``ends``, ``id1``, ``id2`` for
        each overlapping pair, or an empty DataFrame if none found.

    Notes
    -----
    The previous two-pointer scan terminated as soon as the inner index
    reached the end of the table, silently dropping overlaps among the
    remaining rows (three mutually overlapping intervals reported only
    two of the three pairs).  This version sweeps every forward pair,
    pruning with the sorted start times, so all pairs are reported; it
    also no longer raises for unusually ordered inputs.
    """
    annot = annot_table(annot)
    empty = pd.DataFrame(columns=["starts", "ends", "id1", "id2"])
    if len(annot) <= 1:
        return empty

    starts = annot["starts"].to_numpy()
    ends = annot["ends"].to_numpy()
    ids = annot["id"].to_numpy()

    # Visit rows in start order; once a later row starts (within
    # tolerance) after row i ends, no subsequent row can overlap i.
    order = np.argsort(starts, kind="mergesort")
    rows = []
    for a in range(len(order) - 1):
        i = order[a]
        for b in range(a + 1, len(order)):
            j = order[b]
            if starts[j] - ends[i] >= -timetol:
                break  # sorted starts: remaining rows begin after i ends
            if ends[j] - starts[i] > timetol:
                rows.append({
                    "starts": max(starts[i], starts[j]),
                    "ends": min(ends[i], ends[j]),
                    "id1": ids[i],
                    "id2": ids[j],
                })

    return pd.DataFrame(rows) if rows else empty
def annot_intersect(x, y, keep="both", groupby=None, groupby_x=None,
                    groupby_y=None, description=None, warn=True):
    """Intersection of two annotation tables.

    Translated from ``bml_annot_intersect.m``.

    Parameters
    ----------
    x, y : pandas.DataFrame
        Annotation tables.  *y* should have no overlapping annotations;
        NOTE(review): overlapping *y* rows are not rewound during the
        scan, so some x/y pairs may be missed in that case — confirm
        against the MATLAB original.
    keep : str
        Which extra variables to keep: ``'both'``, ``'none'``, ``'x'``,
        or ``'y'`` (also accepts spellings like ``'keep_x'``).
    groupby, groupby_x, groupby_y : str or None
        Column name(s) to group by before intersecting.
    description : str or None
        Description for the result.
    warn : bool
        Warn on variable name conflicts.

    Returns
    -------
    pandas.DataFrame
        Intersection annotation table, with source-row ids stored in
        ``<x_desc>_id`` / ``<y_desc>_id`` columns.
    """
    x = annot_table(x)
    y = annot_table(y)

    if x.empty:
        return x.copy()
    if y.empty:
        return y.copy()

    # Table descriptions name the id columns of the result; disambiguate
    # when both inputs carry the same description.
    x_desc = _get_description(x) or "x"
    y_desc = _get_description(y) or "y"
    if x_desc == y_desc:
        x_desc, y_desc = x_desc + "_x", y_desc + "_y"

    xidn = f"{x_desc}_id"
    yidn = f"{y_desc}_id"

    if description is None:
        description = f"intersect_{x_desc}_{y_desc}"

    if groupby_x is None:
        groupby_x = groupby
    if groupby_y is None:
        groupby_y = groupby

    # Determine groups: with no grouping, a constant helper column puts
    # every row in a single group so one code path handles both cases.
    if groupby_x is None and groupby_y is None:
        x = x.copy()
        y = y.copy()
        x["_groupby_"] = 1
        y["_groupby_"] = 1
        groupby_x = groupby_y = "_groupby_"
        groups = [1]
    elif groupby_x is not None and groupby_y is not None:
        gx = set(x[groupby_x].unique())
        gy = set(y[groupby_y].unique())
        groups = sorted(gx & gy)  # only groups present in both tables
        if not groups:
            return _set_description(pd.DataFrame(), description)
    else:
        raise ValueError("groupby_x and groupby_y must both be given or both be None")

    result_rows = []
    for g in groups:
        xg = x[x[groupby_x] == g]
        yg = y[y[groupby_y] == g]
        if yg.empty or xg.empty:
            continue

        # Two-pointer intersection over start-sorted copies.
        xg = xg.sort_values("starts").reset_index(drop=True)
        yg = yg.sort_values("starts").reset_index(drop=True)

        # When x has internal overlaps a monotone scan through y is not
        # sufficient: after finishing an x row the y pointer must rewind
        # to 0 (costlier, O(len(x)*len(y)) worst case).
        has_x_overlap = not annot_overlap(xg).empty if len(xg) > 1 else False

        i, j = 0, 0
        while i < len(xg) and j < len(yg):
            xs, xe = xg["starts"].iloc[i], xg["ends"].iloc[i]
            ys, ye = yg["starts"].iloc[j], yg["ends"].iloc[j]
            if xs < ye and xe > ys:
                # Overlap: record the clipped interval and both row ids.
                result_rows.append({
                    "starts": max(xs, ys),
                    "ends": min(xe, ye),
                    xidn: xg["id"].iloc[i],
                    yidn: yg["id"].iloc[j],
                })
                if has_x_overlap:
                    # Advance x (rewinding y) once this x row is fully
                    # consumed or y is exhausted; otherwise advance y.
                    if xe < ye or j >= len(yg) - 1:
                        i += 1
                        j = 0
                    else:
                        j += 1
                else:
                    if xe < ye:
                        i += 1
                    else:
                        j += 1
            elif xe <= ys:
                # x row ends before y row starts: move to next x row.
                i += 1
                if has_x_overlap:
                    j = 0
            elif xs >= ye:
                # y row ends before x row starts: move to next y row.
                j += 1
            else:
                raise RuntimeError("Unsupported input annotations tables")

    if not result_rows:
        return _set_description(pd.DataFrame(), description)

    result = pd.DataFrame(result_rows)
    result = annot_table(result, description)

    # Remove groupby helper column
    if "_groupby_" in result.columns:
        result = result.drop(columns=["_groupby_"])
    if "_groupby_" in x.columns:
        x = x.drop(columns=["_groupby_"])
    if "_groupby_" in y.columns:
        y = y.drop(columns=["_groupby_"])

    # Join extra variables based on keep; normalise spellings such as
    # 'keep_x' or 'Keep Both' down to 'x' / 'both'.
    keep = keep.lower().replace(" ", "").replace("_", "").replace("keep", "")
    if keep in ("both", "x"):
        x_join = x.drop(columns=["starts", "ends"], errors="ignore")
        if groupby_x and groupby_x in x_join.columns and groupby_x != "_groupby_":
            x_join = x_join.drop(columns=[groupby_x])
        x_join = x_join.rename(columns={"id": xidn})
        # Prefix columns that would collide with the result's columns.
        for col in x_join.columns:
            if col != xidn and col in result.columns:
                x_join = x_join.rename(columns={col: f"{x_desc}_{col}"})
        result = result.merge(x_join, on=xidn, how="left")

    if keep in ("both", "y"):
        y_join = y.drop(columns=["starts", "ends"], errors="ignore")
        if groupby_y and groupby_y in y_join.columns and groupby_y != "_groupby_":
            y_join = y_join.drop(columns=[groupby_y])
        y_join = y_join.rename(columns={"id": yidn})
        for col in y_join.columns:
            if col != yidn and col in result.columns:
                y_join = y_join.rename(columns={col: f"{y_desc}_{col}"})
        result = result.merge(y_join, on=yidn, how="left")

    return _set_description(result, description)
def annot_filter(annot, filter_annot, overlap=0, description=None):
    """Filter annotations by intersection with *filter_annot*.

    Translated from ``bml_annot_filter.m``.

    Parameters
    ----------
    annot : pandas.DataFrame
        Annotations to filter.
    filter_annot : pandas.DataFrame
        Filter annotations.
    overlap : float, optional
        Minimum fraction of each annotation that must be covered by the
        filter (default ``0``, meaning any touch is enough).
    description : str or None
        Description for the result.

    Returns
    -------
    pandas.DataFrame
        The rows of *annot* that pass the filter.
    """
    annot = annot_table(annot)
    filter_annot = annot_table(filter_annot)

    if annot.empty:
        return annot.copy()

    # Fast path: a single filter interval with touch semantics needs no
    # intersection machinery.
    if overlap == 0 and len(filter_annot) == 1:
        lo = filter_annot["starts"].iloc[0]
        hi = filter_annot["ends"].iloc[0]
        touches = (annot["starts"] < hi) & (annot["ends"] > lo)
        return annot[touches].copy()

    inter = annot_intersect(
        annot,
        filter_annot[["id", "starts", "ends", "duration"]],
        keep="none",
    )
    if inter.empty:
        empty = pd.DataFrame(columns=annot.columns)
        return _set_description(empty, description)

    # The intersection names its id column after annot's description;
    # fall back to the first '*_id' column if the expected name is absent.
    id_col = f"{_get_description(annot) or 'x'}_id"
    if id_col not in inter.columns:
        id_col = next(
            (c for c in inter.columns if c.endswith("_id") and c != "id"),
            id_col,
        )

    if overlap <= 0:
        # Touch semantics: keep every annotation that intersects at all.
        return annot[annot["id"].isin(inter[id_col])].copy()

    # Coverage semantics: total intersected duration per annotation must
    # reach the requested fraction of its own duration.
    per_id = inter.groupby(id_col)["duration"].sum().reset_index()
    per_id.columns = [id_col, "intersect_dur"]
    candidates = annot[annot["id"].isin(per_id[id_col])]
    merged = candidates.merge(
        per_id, left_on="id", right_on=id_col, how="left"
    )
    ratio = merged["intersect_dur"] / merged["duration"]
    keep_ids = merged.loc[(ratio >= overlap) | ratio.isna(), "id"]
    return annot[annot["id"].isin(keep_ids)].copy()
def annot_consolidate(annot, criterion=None, additive=None, groupby=None,
                      description=None):
    """Consolidate (merge) overlapping or contiguous annotations.

    Translated from ``bml_annot_consolidate.m``.

    Parameters
    ----------
    annot : pandas.DataFrame
        Annotation table.
    criterion : callable or None
        Function accepting a DataFrame of candidate rows and returning
        ``True`` if they should be merged.  Default: merge if the last
        row's ``starts`` is <= the max ``ends`` of previous rows.
    additive : list of str or None
        Column names whose values should be summed during collapse.
    groupby : str or None
        Column name to group by before consolidating.
    description : str or None
        Description for the result.

    Returns
    -------
    pandas.DataFrame
        Consolidated annotation table (summary columns per
        ``_collapse_rows``: ``cons_duration``, ``cons_n``,
        ``id_starts``, ``id_ends``).
    """
    annot = annot_table(annot)
    if annot.empty:
        return annot.copy()

    if description is None:
        description = "cons_" + (_get_description(annot) or "annot")
    additive = additive or []

    if criterion is None:
        # Default: merge while the candidate row starts no later than the
        # running maximum end of the rows already in the run.
        def criterion(rows):
            return rows["starts"].iloc[-1] <= rows["ends"].iloc[:-1].max()

    if groupby is None:
        groups = [None]
    else:
        groups = sorted(annot[groupby].unique())

    all_cons = []
    for g in groups:
        if g is None:
            ag = annot
        else:
            ag = annot[annot[groupby] == g]

        # Runs are grown over start-sorted rows.
        ag = ag.sort_values("starts").reset_index(drop=True)

        if len(ag) <= 1:
            row = _collapse_rows(ag, additive)
            all_cons.append(row)
            continue

        # i = index of the first row of the current run, j = run length.
        i = 0
        j = 1
        while i < len(ag):
            if j == 1:
                # Fresh run: seed it with the single row at i.
                curr_rows = ag.iloc[i:i + 1]

            if i + j >= len(ag):
                # No candidate row left to test: flush the run and stop.
                all_cons.append(_collapse_rows(curr_rows, additive))
                break

            # Tentatively extend the run by one row and ask the criterion.
            merge_rows = ag.iloc[i:i + j + 1]
            if criterion(merge_rows):
                curr_rows = merge_rows
                j += 1
                if i + j > len(ag):
                    all_cons.append(_collapse_rows(curr_rows, additive))
                    break
            else:
                # Criterion rejected the extension: flush the current run
                # and start a new one at the rejected row.
                all_cons.append(_collapse_rows(curr_rows, additive))
                i = i + j
                j = 1
                if i == len(ag) - 1:
                    # The new run is the final row: flush it and stop.
                    all_cons.append(_collapse_rows(ag.iloc[i:i + 1], additive))
                    break

    if not all_cons:
        return _set_description(pd.DataFrame(), description)

    result = pd.DataFrame(all_cons)
    result = annot_table(result, description)
    return result
def annot_rename(annot, *args, **kwargs):
    """Rename columns of an annotation table.

    Translated from ``bml_annot_rename.m``.

    Can be called as::

        annot_rename(df, 'old1', 'new1', 'old2', 'new2')
        annot_rename(df, old1='new1', old2='new2')

    Parameters
    ----------
    annot : pandas.DataFrame
        Annotation table.
    *args : str
        Alternating old/new column name pairs.
    **kwargs : str
        Old=new column name mappings.

    Returns
    -------
    pandas.DataFrame
        Renamed annotation table with the description preserved.

    Raises
    ------
    ValueError
        If positional arguments are unpaired or an old name is missing.
    """
    if args and len(args) % 2 != 0:
        raise ValueError("Column rename arguments must come in pairs")

    # Positional pairs first, keyword mappings layered on top.
    mapping = dict(zip(args[0::2], args[1::2]))
    mapping.update(kwargs)

    for old_name in mapping:
        if old_name not in annot.columns:
            raise ValueError(f"variable {old_name} not present in annotation table")

    desc = _get_description(annot)
    return _set_description(annot.rename(columns=mapping), desc)


def annot_read(filename, **kwargs):
    """Read an annotation table from a tab-delimited file.

    Translated from ``bml_annot_read.m``.

    Parameters
    ----------
    filename : str
        Path to the file.
    **kwargs
        Additional keyword arguments passed to :func:`pandas.read_csv`
        (defaults: tab separator, ``"NA"`` as missing marker).

    Returns
    -------
    pandas.DataFrame
        Annotation table described by the file's base name.
    """
    kwargs.setdefault("sep", "\t")
    kwargs.setdefault("na_values", ["NA"])

    table = pd.read_csv(filename, **kwargs)
    name = os.path.splitext(os.path.basename(filename))[0]

    # Files with onset/duration columns are converted to starts/ends and
    # given fresh sequential ids.
    if "onset" in table.columns and "duration" in table.columns:
        table = table.rename(columns={"onset": "starts"})
        table["ends"] = table["starts"] + table["duration"]
        table["id"] = range(1, len(table) + 1)

    return annot_table(table, name)
def annot_read_tsv(filename, append_cols_from_filename=False, **kwargs):
    """Read a BIDS-style TSV annotation table.

    Translated from ``bml_annot_read_tsv.m``.

    Parameters
    ----------
    filename : str
        Path to the ``.tsv`` file.
    append_cols_from_filename : bool, optional
        If ``True``, extract ``subject_id``, ``session_id``, ``task_id``
        from BIDS-style filename patterns (``sub-``, ``ses-``, ``task-``).
    **kwargs
        Additional keyword arguments passed to :func:`pandas.read_csv`
        (defaults: tab separator, BIDS ``"n/a"`` as missing marker).

    Returns
    -------
    pandas.DataFrame
        Annotation table described by the file's base name.
    """
    kwargs.setdefault("sep", "\t")
    kwargs.setdefault("na_values", ["n/a"])

    df = pd.read_csv(filename, **kwargs)
    name = os.path.splitext(os.path.basename(filename))[0]

    # BIDS events files use onset/duration; convert to starts/ends and
    # assign fresh sequential ids.  (The original normalized via
    # annot_table here *and* again on return; annot_table is idempotent,
    # so the single pass at the end suffices.)
    if "onset" in df.columns and "duration" in df.columns:
        df = df.rename(columns={"onset": "starts"})
        df["ends"] = df["starts"] + df["duration"]
        df["id"] = range(1, len(df) + 1)

    if append_cols_from_filename:
        bids_keys = {"sub": "subject_id", "ses": "session_id", "task": "task_id"}
        for key, col in bids_keys.items():
            # Lookbehind keeps only the value following e.g. "sub-".
            pattern = rf"(?<={key}-)[a-zA-Z0-9-]+"
            matches = re.findall(pattern, filename)
            if matches and col not in df.columns:
                df[col] = matches[0]

    return annot_table(df, name)


def annot_write(annot, filename):
    """Write an annotation table to a tab-delimited file.

    Translated from ``bml_annot_write.m``.

    Parameters
    ----------
    annot : pandas.DataFrame
        Annotation table.
    filename : str
        Output file path.  A ``.tsv`` extension drops the ``id`` and
        ``ends`` columns.
        NOTE(review): unlike :func:`annot_write_tsv`, the ``.tsv`` path
        here does NOT rename ``starts`` to ``onset`` — confirm whether
        this asymmetry is intended.
    """
    annot = annot_table(annot)
    df = annot.copy()

    _, ext = os.path.splitext(filename)
    if ext == ".tsv":
        df = df.drop(columns=["id", "ends"], errors="ignore")

    df.to_csv(filename, sep="\t", index=False)


def annot_write_tsv(annot, filename):
    """Write an annotation table in BIDS TSV format.

    Translated from ``bml_annot_write_tsv.m``.

    Renames ``starts`` to ``onset`` and drops ``id`` and ``ends`` columns.

    Parameters
    ----------
    annot : pandas.DataFrame
        Annotation table.
    filename : str
        Output file path.
    """
    annot = annot_table(annot)
    df = annot.rename(columns={"starts": "onset"})
    df = df.drop(columns=["id", "ends"], errors="ignore")
    df.to_csv(filename, sep="\t", index=False)
def annot_rowbind(*args):
    """Row-bind multiple annotation DataFrames.

    Translated from ``bml_annot_rowbind.m``.

    Parameters
    ----------
    *args : pandas.DataFrame
        Annotation tables to concatenate.  ``None`` and empty frames are
        skipped; the remaining frames are conformed to the first one's
        columns (extra columns dropped, missing ones filled with NaN).

    Returns
    -------
    pandas.DataFrame
        Combined annotation table with freshly assigned ids.
    """
    frames = [df for df in args if df is not None and not df.empty]
    if not frames:
        return pd.DataFrame()

    # The first non-empty frame defines the column layout.
    template = frames[0]
    conformed = [template] + [_conform_to(template, f) for f in frames[1:]]

    combined = pd.concat(conformed, ignore_index=True)

    # Old ids are meaningless after concatenation; drop them so
    # annot_table assigns a fresh 1..n sequence.
    if "id" in combined.columns:
        combined = combined.drop(columns=["id"])
    if {"starts", "ends"}.issubset(combined.columns):
        combined = annot_table(combined)
    return combined
+ """ + x = annot_table(x) + y = annot_table(y) + + if groupby_x is None: + groups = [None] + else: + groups = sorted(x[groupby_x].unique()) + + result_rows = [] + for g in groups: + if g is None: + xg = x + else: + xg = x[x[groupby_x] == g].sort_values("starts") + + yg = y if groupby_y is None else y[y[groupby_y] == g] + + for _, yrow in yg.iterrows(): + if xg.empty: + cvg = 0.0 + else: + t = yrow["starts"] + cvg = 0.0 + for _, xrow in xg.iterrows(): + if t >= yrow["ends"]: + break + if xrow["starts"] < yrow["ends"] and xrow["ends"] > t: + s = max(xrow["starts"], t) + e = min(xrow["ends"], yrow["ends"]) + cvg += e - s + t = e + cvg = cvg / yrow["duration"] if yrow["duration"] > 0 else 0.0 + + row_data = yrow.to_dict() + if groupby_x and g is not None: + row_data[groupby_x] = g + row_data[colname] = cvg + result_rows.append(row_data) + + if not result_rows: + return pd.DataFrame() + + result = pd.DataFrame(result_rows) + if "id" in result.columns: + result = result.drop(columns=["id"]) + return annot_table(result) diff --git a/python/bml/stat.py b/python/bml/stat.py new file mode 100644 index 0000000..21eb042 --- /dev/null +++ b/python/bml/stat.py @@ -0,0 +1,342 @@ +"""Statistical functions for BML. + +Python translations of MATLAB functions from the BML toolbox stat/ +directory. +""" + +import warnings +from itertools import combinations + +import numpy as np +from scipy.stats import norm + + +# -- robust_std --------------------------------------------------------------- + +def robust_std(data, center=None): + """Row-wise robust estimation of standard deviation. + + Translated from ``bml_robust_std.m``. + + The estimator works by finding the quantile of absolute deviations + from *center* and scaling by the corresponding normal quantile. + It iterates over increasing quantile levels starting at 0.5 until + the estimate is numerically distinguishable from zero, or returns 0. + + Parameters + ---------- + data : numpy.ndarray + 1-D or 2-D array. 
def robust_std(data, center=None):
    """Row-wise robust estimation of standard deviation.

    Translated from ``bml_robust_std.m``.

    The estimator divides a quantile of the absolute deviations from
    *center* by the matching normal quantile.  Starting at the median
    (p = 0.5), the quantile level is raised in steps of 0.05 until the
    estimate is numerically distinguishable from zero; if it never is,
    the estimate is 0.

    Parameters
    ----------
    data : numpy.ndarray
        1-D or 2-D array.  If 2-D, the robust standard deviation is
        computed independently for each row.
    center : numpy.ndarray or None, optional
        Center of the distribution for each row.  Must be a 1-D array
        with one element per row when *data* is 2-D.  Defaults to the
        row-wise ``nanmedian``.

    Returns
    -------
    numpy.ndarray
        1-D array of length ``data.shape[0]`` (or length 1 for a 1-D
        input) containing the robust standard deviation estimates.
    """
    arr = np.atleast_2d(np.asarray(data, dtype=float))

    if center is None:
        centers = np.nanmedian(arr, axis=1)
    else:
        centers = np.asarray(center, dtype=float).ravel()

    out = np.zeros(arr.shape[0])
    for r in range(arr.shape[0]):
        row = arr[r, :]
        # "Distinguishable from zero" means above 1e4 times the float
        # spacing (MATLAB eps) at the 95th percentile of |row|.
        floor = 1e4 * np.spacing(np.nanquantile(np.abs(row), 0.95))
        deviations = np.abs(row - centers[r])
        p = 0.5
        while out[r] < floor and p < 1:
            out[r] = np.nanquantile(deviations, p) / norm.ppf((1 + p) / 2)
            p += 0.05
        if out[r] < floor:
            out[r] = 0.0

    return out
+ """ + p_list = np.asarray(p_list, dtype=float).ravel() + n_vals = len(p_list) + num_tests = n_vals + + # Sort descending + p_sorted = np.sort(p_list)[::-1] + + # Build comparison vector (descending rank / num_tests * alpha) + ranks_desc = np.arange(num_tests, 0, -1) # num_tests, ..., 1 + if corrected: + correction = np.sum(np.arange(1, num_tests + 1) / num_tests) + comp = ranks_desc / num_tests * alpha / correction + else: + comp = ranks_desc / num_tests * alpha + + # comp((end-n_vals+1):end) – since n_vals == num_tests this is a no-op, + # but kept for fidelity. + comp = comp[-n_vals:] + + # Find first (in descending-sorted order) p-value that passes + indices = np.where(p_sorted <= comp)[0] + if len(indices) == 0: + thres = 0.0 + else: + thres = p_sorted[indices[0]] + + ind = np.where(p_list <= thres)[0] + return ind, thres + + +# -- fdr_bh ------------------------------------------------------------------ + +def fdr_bh(pvals, q=0.05, method='pdep', report=False): + """Benjamini-Hochberg / Benjamini-Yekutieli FDR procedure. + + Translated from ``bml_fdr_bh.m`` by David M. Groppe. + + Executes the Benjamini & Hochberg (1995) or Benjamini & Yekutieli + (2001) procedure for controlling the false discovery rate of a + family of hypothesis tests, and returns FCR-adjusted confidence + interval coverage. + + Parameters + ---------- + pvals : array_like + Vector or matrix of p-values. + q : float, optional + Desired false discovery rate (default ``0.05``). + method : str, optional + ``'pdep'`` for the original BH procedure (valid under + independence or positive dependence) or ``'dep'`` for the BY + procedure (valid under arbitrary dependence). Default + ``'pdep'``. + report : bool, optional + If ``True``, print a summary to stdout (default ``False``). + + Returns + ------- + h : numpy.ndarray + Boolean array of the same shape as *pvals*; ``True`` where the + null hypothesis is rejected. + crit_p : float + Critical p-value threshold. 0 if nothing is significant. 
def fdr_bh(pvals, q=0.05, method='pdep', report=False):
    """Benjamini-Hochberg / Benjamini-Yekutieli FDR procedure.

    Translated from ``bml_fdr_bh.m`` by David M. Groppe.

    Controls the false discovery rate of a family of hypothesis tests
    via Benjamini & Hochberg (1995) or Benjamini & Yekutieli (2001),
    and reports the FCR-adjusted confidence-interval coverage.

    Parameters
    ----------
    pvals : array_like
        Vector or matrix of p-values.
    q : float, optional
        Desired false discovery rate (default ``0.05``).
    method : str, optional
        ``'pdep'`` (original BH, valid under independence or positive
        dependence) or ``'dep'`` (BY, valid under arbitrary dependence).
        Default ``'pdep'``.
    report : bool, optional
        If ``True``, print a summary to stdout (default ``False``).

    Returns
    -------
    h : numpy.ndarray
        Boolean array, same shape as *pvals*; ``True`` where rejected.
    crit_p : float
        Critical p-value threshold (0 if nothing is significant).
    adj_ci_cvrg : float
        FCR-adjusted confidence-interval coverage, or ``NaN`` if no
        p-values are significant.
    adj_p : numpy.ndarray
        Adjusted p-values (same shape as *pvals*); may exceed 1.

    Raises
    ------
    ValueError
        If *pvals* lies outside [0, 1] or *method* is unrecognised.
    """
    pvals = np.asarray(pvals, dtype=float)
    shape = pvals.shape

    if np.any(pvals < 0):
        raise ValueError("Some p-values are less than 0.")
    if np.any(pvals > 1):
        raise ValueError("Some p-values are greater than 1.")

    method = method.lower()
    if method not in ('pdep', 'dep'):
        raise ValueError("method must be 'pdep' or 'dep'.")

    # Work on a flat, ascending-sorted copy; remember how to undo the
    # sort so adjusted p-values map back to the caller's layout.
    flat = pvals.ravel()
    order = np.argsort(flat, kind='mergesort')
    p_sorted = flat[order]
    inverse = np.argsort(order, kind='mergesort')
    m = len(p_sorted)

    ranks = np.arange(1, m + 1, dtype=float)
    if method == 'pdep':
        thresh = ranks * q / m
        wtd_p = m * p_sorted / ranks
    else:  # 'dep': BY divides by the harmonic sum of ranks
        denom = m * np.sum(1.0 / ranks)
        thresh = ranks * q / denom
        wtd_p = denom * p_sorted / ranks

    # Adjusted p-values via D.H.J. Poot's sorted-fill trick: this
    # enforces monotonicity (running minimum from the top) in one pass.
    adj_sorted = np.full(m, np.nan)
    widx = np.argsort(wtd_p, kind='mergesort')
    wsorted = wtd_p[widx]
    fill = 0
    for k in range(m):
        if widx[k] >= fill:
            adj_sorted[fill:widx[k] + 1] = wsorted[k]
            fill = widx[k] + 1
            if fill >= m:
                break
    adj_p = adj_sorted[inverse].reshape(shape)

    # Step-up decision: the largest rank whose p-value clears its
    # boundary fixes the critical p.
    sig = np.where(p_sorted <= thresh)[0]
    if sig.size == 0:
        crit_p = 0.0
        h = np.zeros(shape, dtype=bool)
        adj_ci_cvrg = np.nan
    else:
        top = sig[-1]
        crit_p = p_sorted[top]
        h = pvals <= crit_p
        adj_ci_cvrg = 1.0 - thresh[top]

    if report:
        n_sig = int(np.sum(p_sorted <= crit_p))
        word = "is" if n_sig == 1 else "are"
        print(
            f"Out of {m} tests, {n_sig} {word} significant using a "
            f"false discovery rate of {q}."
        )
        if method == 'pdep':
            print(
                "FDR/FCR procedure used is guaranteed valid for "
                "independent or positively dependent tests."
            )
        else:
            print(
                "FDR/FCR procedure used is guaranteed valid for "
                "independent or dependent tests."
            )

    return h, crit_p, adj_ci_cvrg, adj_p
def permutation_test(sample1, sample2, permutations, sidedness='both',
                     exact=False):
    """Permutation test for a difference in means.

    Translated from ``permutationTest.m`` by Laurens R Krol.

    Parameters
    ----------
    sample1 : array_like
        Measurements from the first (experimental) sample.
    sample2 : array_like
        Measurements from the second (control) sample.
    permutations : int
        Number of random permutations.  Ignored when *exact* is
        ``True``.
    sidedness : str, optional
        ``'both'`` (default) for a two-sided test, ``'smaller'`` to
        test that ``mean(sample1) < mean(sample2)``, or ``'larger'``
        to test that ``mean(sample1) > mean(sample2)``.
    exact : bool, optional
        If ``True``, enumerate all possible combinations instead of
        using random permutations (default ``False``).  Only feasible
        for small sample sizes.

    Returns
    -------
    p : float
        The p-value.
    observed_difference : float
        ``nanmean(sample1) - nanmean(sample2)``.
    effect_size : float
        Hedges' *g* effect size (NaN when the pooled std is zero).
    """
    a = np.asarray(sample1, dtype=float).ravel()
    b = np.asarray(sample2, dtype=float).ravel()

    pooled = np.concatenate([a, b])
    observed_difference = np.nanmean(a) - np.nanmean(b)

    n1, n2 = len(a), len(b)
    total = n1 + n2

    # Hedges' g: pooled standard deviation with Bessel's correction.
    pooled_sd = np.sqrt(
        ((n1 - 1) * np.nanstd(a, ddof=1) ** 2
         + (n2 - 1) * np.nanstd(b, ddof=1) ** 2)
        / (total - 2)
    )
    effect_size = (
        observed_difference / pooled_sd if pooled_sd != 0 else np.nan
    )

    if exact:
        # Enumerate every way of assigning n1 observations to group 1.
        combos = list(combinations(range(total), n1))
        permutations = len(combos)

    diffs = np.empty(permutations)
    everyone = np.arange(total)
    for k in range(permutations):
        if exact:
            grp1 = np.array(combos[k])
            grp2 = np.setdiff1d(everyone, grp1)
        else:
            shuffled = np.random.permutation(total)
            grp1, grp2 = shuffled[:n1], shuffled[n1:]
        diffs[k] = np.nanmean(pooled[grp1]) - np.nanmean(pooled[grp2])

    # p-value with the +1 correction so it is never exactly zero.
    if sidedness == 'both':
        extreme = np.sum(np.abs(diffs) > np.abs(observed_difference))
    elif sidedness == 'smaller':
        extreme = np.sum(diffs < observed_difference)
    elif sidedness == 'larger':
        extreme = np.sum(diffs > observed_difference)
    else:
        raise ValueError(
            f"sidedness must be 'both', 'smaller', or 'larger', "
            f"got {sidedness!r}"
        )
    p = (extreme + 1) / (permutations + 1)

    return p, observed_difference, effect_size
"""Synchronization utilities.

Translated from the MATLAB ``sync/`` directory of the BML toolbox.
"""

import math

import numpy as np
import pandas as pd

from bml.utils import getopt

# Precision for time tolerance: timetol = 10 ** -_PTT seconds.
_PTT = 9


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

def _matlab_round(x):
    """Round to nearest integer, with ties away from zero (MATLAB convention).

    ``numpy.round`` uses banker's rounding (ties to even) which differs from
    MATLAB's ``round`` which rounds ties away from zero.
    """
    x = np.asarray(x, dtype=float)
    return (np.sign(x) * np.floor(np.abs(x) + 0.5)).astype(int)


def _round_significant(x, n):
    """Round *x* to *n* significant digits.

    Equivalent to MATLAB ``round(x, n, 'significant')``.

    Parameters
    ----------
    x : float
        Value to round.
    n : int
        Number of significant digits.

    Returns
    -------
    float
    """
    if x == 0:
        return 0.0
    if not np.isfinite(x):
        # log10 of inf/nan would fail below; pass the value through
        # unchanged instead of raising a cryptic error.
        return float(x)
    magnitude = int(np.floor(np.log10(abs(x))))
    return round(x, -magnitude + (n - 1))


def _segment_coords(t1, s1, t2, s2, skip_factor):
    """Normalize one sync segment's coordinates.

    Rounds the times to the time tolerance, derives the sampling rate and
    applies *skip_factor* (downsampling shifts sample midpoints by
    ``(skip_factor - 1) / 2`` samples).

    Returns
    -------
    tuple of float
        ``(t1, s1, t2, s2, Fs)`` after adjustment.
    """
    t1 = np.round(float(t1), _PTT)
    t2 = np.round(float(t2), _PTT)
    s1 = float(s1)
    s2 = float(s2)
    Fs = _round_significant((s2 - s1) / np.round(t2 - t1, _PTT), _PTT)
    if skip_factor > 1:
        s1 = math.ceil(s1 / skip_factor)
        s2 = math.floor(s2 / skip_factor)
        t1 = t1 + (skip_factor - 1) * 0.5 / Fs
        t2 = t2 - (skip_factor - 1) * 0.5 / Fs
        Fs = _round_significant((s2 - s1) / np.round(t2 - t1, _PTT), _PTT)
    return t1, s1, t2, s2, Fs


def _coords_to_time(idx, t1, s1, t2, s2, Fs):
    """Map sample indices to sample-midpoint times for one sync segment."""
    return idx / Fs - 0.5 / Fs + (s2 * t1 - t2 * s1) / (s2 - s1)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def time2idx(cfg, time, skip_factor=1):
    """Calculate sample indices from times and file coordinates.

    Translated from ``bml_time2idx.m``.

    Parameters
    ----------
    cfg : dict
        Configuration with keys ``t1``, ``s1``, ``t2``, ``s2`` and
        optionally ``nSamples``.
    time : array_like
        Numeric vector of times.
    skip_factor : int, optional
        Integer downsample factor (default ``1``).

    Returns
    -------
    numpy.ndarray
        Integer sample indices corresponding to each time.

    Raises
    ------
    ValueError
        If a required coordinate is missing from *cfg*, or if any
        computed index exceeds *nSamples*.
    """
    skip_factor = int(round(skip_factor))
    time = np.asarray(time, dtype=float)

    t1 = getopt(cfg, 't1')
    s1 = getopt(cfg, 's1')
    t2 = getopt(cfg, 't2')
    s2 = getopt(cfg, 's2')
    if any(v is None for v in (t1, s1, t2, s2)):
        # Fail early with a clear message instead of the cryptic
        # TypeError np.round(None) would raise below.
        raise ValueError("cfg must define t1, s1, t2 and s2")

    t1 = np.round(t1, _PTT)
    s1 = math.ceil(s1 / skip_factor)
    t2 = np.round(t2, _PTT)
    s2 = math.floor(s2 / skip_factor)
    n_samples = getopt(cfg, 'nSamples')

    # Linear interpolation between the (t1, s1) and (t2, s2) anchors.
    idx = _matlab_round(
        (t2 * s1 - s2 * t1 + (s2 - s1) * np.round(time, _PTT))
        / (t2 - t1)
    )

    if n_samples is not None and np.any(idx > n_samples):
        raise ValueError("index exceeds number of samples in file")

    return idx


def idx2time(cfg, idx, skip_factor=1):
    """Calculate sample midpoint times from indices and file coordinates.

    Translated from ``bml_idx2time.m``.

    Parameters
    ----------
    cfg : dict or pandas.DataFrame
        Configuration with keys/columns ``t1``, ``s1``, ``t2``, ``s2``.
        When a :class:`~pandas.DataFrame` with more than one row the
        sample ranges ``(s1, s2)`` must not overlap.
    idx : array_like
        Integer sample indices.
    skip_factor : int, optional
        Integer downsample factor (default ``1``).

    Returns
    -------
    numpy.ndarray
        Times corresponding to each index.  In the multi-row case,
        indices not covered by any ``(s1, s2)`` range are left at
        ``0.0``.  NOTE(review): this sentinel mirrors the original
        translation — confirm against ``bml_idx2time.m``.

    Raises
    ------
    ValueError
        If sample ranges in a multi-row DataFrame overlap, or if a
        required coordinate is missing.
    """
    skip_factor = int(round(skip_factor))
    idx = np.asarray(idx, dtype=float)

    # --- Multi-row DataFrame (split sync) --------------------------------
    if isinstance(cfg, pd.DataFrame) and len(cfg) > 1:
        # Vectorized overlap check: once sorted by s1, every range must
        # end strictly before the next one starts (inclusive indices).
        by_s1 = cfg.sort_values('s1')
        if np.any(by_s1['s2'].to_numpy()[:-1] >= by_s1['s1'].to_numpy()[1:]):
            raise ValueError(
                "sample ranges (s1, s2) must not overlap"
            )

        time = np.zeros(len(idx))
        for _, row in cfg.iterrows():
            t1, s1, t2, s2, Fs = _segment_coords(
                row['t1'], row['s1'], row['t2'], row['s2'], skip_factor
            )
            mask = (idx >= s1) & (idx <= s2)
            time[mask] = _coords_to_time(idx[mask], t1, s1, t2, s2, Fs)
        return time

    # --- Single-row dict or single-row DataFrame -------------------------
    if isinstance(cfg, pd.DataFrame):
        cfg = cfg.iloc[0].to_dict()

    coords = [getopt(cfg, k) for k in ('t1', 's1', 't2', 's2')]
    if any(v is None for v in coords):
        raise ValueError("cfg must define t1, s1, t2 and s2")

    t1, s1, t2, s2, Fs = _segment_coords(*coords, skip_factor)
    return _coords_to_time(idx, t1, s1, t2, s2, Fs)
+"""Utility functions for BML. + +Python translations of MATLAB functions from the BML toolbox utils/ and +signal/ directories. +""" + +import json +import warnings + +import numpy as np + + +def getopt(cfg, key, default=None, emptymeaningful=False): + """Get a value from a configuration dict. + + Translated from ``bml_getopt.m``. + + Parameters + ---------- + cfg : dict or None + Configuration mapping. ``None`` or an empty dict is treated as + an absent configuration and *default* is returned. + key : str + The key to look up. + default : object, optional + Value returned when *key* is not present or (unless + *emptymeaningful* is ``True``) when the stored value is ``None``. + emptymeaningful : bool, optional + When ``False`` (the default) a ``None`` value is replaced by + *default*. Set to ``True`` to allow ``None`` through. + + Returns + ------- + object + The looked-up value, or *default*. + """ + if cfg is None or (isinstance(cfg, dict) and len(cfg) == 0): + val = default + elif isinstance(cfg, dict): + val = cfg.get(key, default) + else: + raise TypeError( + f"cfg must be a dict or None, got {type(cfg).__name__}" + ) + + if val is None and default is not None and not emptymeaningful: + val = default + + return val + + +def map_values(element, domain, codomain, non_domain=None): + """Map elements from *domain* to *codomain*. + + Translated from ``bml_map.m``. + + For each item in *element*, its position in *domain* is found and the + corresponding *codomain* value is returned. + + Parameters + ---------- + element : list or numpy.ndarray + Values to map. + domain : list or numpy.ndarray + Known input values. + codomain : list or numpy.ndarray + Corresponding output values (same length as *domain*). + non_domain : object, optional + Value used for elements that are not found in *domain*. If + ``None`` (the default) a ``ValueError`` is raised for missing + elements. + + Returns + ------- + list or numpy.ndarray + Mapped values. 
A ``numpy.ndarray`` is returned when *element* + is a ``numpy.ndarray`` and *codomain* is also a + ``numpy.ndarray``; otherwise a ``list``. + + Raises + ------ + ValueError + If *domain* and *codomain* have different lengths, or if an + element is not found in *domain* and *non_domain* is ``None``. + """ + if len(domain) != len(codomain): + raise ValueError( + "domain and codomain must have the same length" + ) + + # Build a lookup: domain value -> first codomain value + # (mirrors MATLAB find(..., 1) which returns the first match) + lookup = {} + for d, c in zip(domain, codomain): + if d not in lookup: + lookup[d] = c + + use_array = isinstance(element, np.ndarray) and isinstance( + codomain, np.ndarray + ) + + mapped = [] + for e in element: + if e in lookup: + mapped.append(lookup[e]) + elif non_domain is not None: + mapped.append(non_domain) + else: + raise ValueError( + f"element {e!r} not found in domain and no non_domain default given" + ) + + if use_array: + return np.array(mapped) + return mapped + + +def getidx(element, collection): + """Get first indices of *element* values in *collection*. + + Translated from ``bml_getidx.m``. + + Parameters + ---------- + element : list or numpy.ndarray + Values to locate. + collection : list or numpy.ndarray + The collection to search in. + + Returns + ------- + list of int + For each item in *element*, the 0-based index of its first + occurrence in *collection*, or ``-1`` if not found. + + Notes + ----- + The MATLAB version returns 1-based indices with 0 for "not found". + This Python version uses the conventional 0-based indexing with + ``-1`` for "not found". + """ + # Convert to list for uniform handling + col_list = list(collection) + + indices = [] + for e in element: + try: + indices.append(col_list.index(e)) + except ValueError: + indices.append(-1) + + return indices + + +def readjson(filename): + """Read a JSON file and return its parsed contents. + + Translated from ``readjson.m``. 
+ + Parameters + ---------- + filename : str or path-like + Path to the JSON file. + + Returns + ------- + object + The decoded JSON data (typically a ``dict`` or ``list``). + """ + with open(filename, "r") as fid: + return json.load(fid) + + +def _round_sigfigs(x, sigfigs): + """Round *x* to *sigfigs* significant figures. + + Equivalent to MATLAB ``round(x, sigfigs, 'signif')``. + """ + if x == 0: + return 0.0 + magnitude = int(np.floor(np.log10(abs(x)))) + return round(x, -magnitude + (sigfigs - 1)) + + +def getFs(raw, cfg=None): + """Return the sampling frequency of a raw data structure. + + Translated from ``bml_getFs.m``. + + Parameters + ---------- + raw : dict + Raw data structure with a ``'time'`` key whose value is a list + (or other iterable) of 1-D array-like time vectors, one per + trial. + cfg : dict or None, optional + Configuration dict. Recognised keys: + + * ``timetol`` – absolute time tolerance in seconds + (default ``1e-9``). + * ``reltimetol`` – relative time tolerance + (default ``1e-4``). + * ``freqsignif`` – number of significant figures for rounding + the sampling frequency (default ``4``). + + Returns + ------- + float + Estimated sampling frequency in Hz rounded to *freqsignif* + significant figures. 
+ """ + timetol = getopt(cfg, "timetol", 1e-9) + reltimetol = getopt(cfg, "reltimetol", 1e-4) + freqsignif = getopt(cfg, "freqsignif", 4) + + trials = raw["time"] + n_trials = len(trials) + + median_dt = np.full(n_trials, np.nan) + timetol_offenders = [] + reltimetol_offenders = [] + + for t in range(n_trials): + dts = np.diff(np.asarray(trials[t], dtype=float)) + median_dt[t] = np.median(dts) + dt_range = np.ptp(dts) # equivalent to MATLAB range() + if dt_range > timetol: + timetol_offenders.append(t) + if median_dt[t] != 0 and dt_range / median_dt[t] > reltimetol: + reltimetol_offenders.append(t) + + if timetol_offenders: + warnings.warn( + f"trials {timetol_offenders} don't comply with timetol of {timetol}" + ) + if reltimetol_offenders: + warnings.warn( + f"trials {reltimetol_offenders} don't comply with reltimetol of {reltimetol}" + ) + + mean_median_dt = np.mean(median_dt) + + # Check across-trial consistency + if n_trials > 1: + cross_range = np.ptp(median_dt) + if cross_range > timetol: + warnings.warn("timetol violated across trials") + if mean_median_dt != 0 and cross_range / mean_median_dt > reltimetol: + warnings.warn("reltimetol violated across trials") + + return _round_sigfigs(1.0 / mean_median_dt, freqsignif) diff --git a/python/pyproject.toml b/python/pyproject.toml new file mode 100644 index 0000000..8d94708 --- /dev/null +++ b/python/pyproject.toml @@ -0,0 +1,20 @@ +[build-system] +requires = ["setuptools>=78.1.1", "wheel"] +build-backend = "setuptools.backends._legacy:_Backend" + +[project] +name = "bml" +version = "0.1.0" +description = "Brain Modulation Lab toolbox - Python translation of MATLAB BML toolbox" +readme = "README.md" +requires-python = ">=3.9" +dependencies = [ + "numpy>=1.21", + "pandas>=1.3", + "scipy>=1.8", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0", +] diff --git a/python/tests/__init__.py b/python/tests/__init__.py new file mode 100644 index 0000000..e69de29