Skip to content

Commit

Permalink
add CN plotting functionality + further breakpoint modularization
Browse files Browse the repository at this point in the history
  • Loading branch information
suhas-r committed Jan 5, 2025
1 parent 0ea08fa commit a70bfc8
Show file tree
Hide file tree
Showing 6 changed files with 326 additions and 196 deletions.
225 changes: 94 additions & 131 deletions coral/breakpoint/breakpoint_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
AmpliconInterval,
BPReads,
Breakpoint,
BreakpointStats,
ChimericAlignment,
ChrPairOrientation,
CNSInterval,
Interval,
ReadInterval,
Strand,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -288,17 +290,17 @@ def alignment2bp_l(
return bp_list


def filter_low_support_breakpoints(
chr_to_cns_to_reads: dict[str, dict[int, set[str]]], min_support: int
def filter_small_breakpoint_clusters(
chr_to_cns_to_reads: dict[str, dict[int, set[str]]], min_cluster_cutoff: int
) -> dict[str, dict[int, set[str]]]:
"""
Initial filtering of potential breakpoints of insufficient support.
Fltering potential breakpoint clusters based on min # of supporting reads.
"""
chr_to_cns_to_reads_filtered: dict[str, dict[int, set[str]]] = dict()
for chr_, cns_to_reads in chr_to_cns_to_reads.items():
chr_to_cns_to_reads_filtered[chr_] = dict()
for cn, reads in cns_to_reads.items():
if len(reads) >= min_support:
if len(reads) >= min_cluster_cutoff:
chr_to_cns_to_reads_filtered[chr_][cn] = reads
return chr_to_cns_to_reads_filtered

Expand All @@ -315,7 +317,7 @@ def cluster_bp_list(
ChrPairOrientation(bp.chr1, bp.chr2, bp.strand1, bp.strand2)
].append(bpi)

bp_clusters = []
bp_clusters: list[list[Breakpoint]] = []
for chr_pair, bp_idxs in bp_dict.items():
if len(bp_idxs) >= min_cluster_size:
bp_clusters_: list[list[Breakpoint]] = []
Expand Down Expand Up @@ -387,152 +389,113 @@ def interval2bp(
)


def bpc2bp(bp_cluster, bp_distance_cutoff):
def bpc2bp(bp_cluster: list[Breakpoint], bp_distance_cutoff: float):
"""
Call exact breakpoint from a breakpoint cluster
"""
bp = bp_cluster[0][:-2]
bp[1] = 0 if bp[2] == "+" else 1_000_000_000
bp[4] = 0 if bp[5] == "+" else 1_000_000_000
bpr = []
bp_stats = [0, 0, 0, 0]
bp_stats_ = [0, 0, 0, 0, 0, 0]
bp: Breakpoint = bp_cluster[0] # [:-2]
bp.start = 0 if bp.strand1 == Strand.FORWARD else 1_000_000_000
bp.end = 0 if bp.strand2 == Strand.FORWARD else 1_000_000_000
bpr: list[BPReads] = []

# Calculate basic dist. stats (mean/std) for breakpoints
bp_stats = BreakpointStats(bp_distance_cutoff)
for bp_ in bp_cluster:
bp_stats[0] += bp_[1]
bp_stats[2] += bp_[1] * bp_[1]
bp_stats[1] += bp_[4]
bp_stats[3] += bp_[4] * bp_[4]
for i in range(4):
bp_stats[i] /= len(bp_cluster) * 1.0
try:
bp_stats[2] = max(
bp_distance_cutoff / 2.99,
np.sqrt(bp_stats[2] - bp_stats[0] * bp_stats[0]),
)
except:
bp_stats[2] = bp_distance_cutoff / 2.99
try:
bp_stats[3] = max(
bp_distance_cutoff / 2.99,
np.sqrt(bp_stats[3] - bp_stats[1] * bp_stats[1]),
)
except:
bp_stats[3] = bp_distance_cutoff / 2.99
bp1_list = []
bp4_list = []
bp_stats.observe(bp_)

bp_starts = []
bp_ends = []
for bp_ in bp_cluster:
if bp_.start > bp_stats.start.mean + 3 * bp_stats.start_window:
continue
if bp_.start < bp_stats.start.mean - 3 * bp_stats.start_window:
continue
if bp_.end > bp_stats.end.mean + 3 * bp_stats.end_window:
continue
if bp_.end < bp_stats.end.mean - 3 * bp_stats.end_window:
continue
bp_starts.append(bp_.start)
bp_ends.append(bp_.end)

if len(bp_starts) > 0:
bp_start_ctr = Counter(bp_starts)
if (
bp_[1] <= bp_stats[0] + 3 * bp_stats[2]
and bp_[1] >= bp_stats[0] - 3 * bp_stats[2]
and bp_[4] <= bp_stats[1] + 3 * bp_stats[3]
and bp_[4] >= bp_stats[1] - 3 * bp_stats[3]
):
bp1_list.append(bp_[1])
bp4_list.append(bp_[4])
# if (bp_[2] == '+' and bp_[1] > bp[1]) or (bp_[2] == '-' and bp_[1] < bp[1]):
# bp[1] = bp_[1]
# if (bp_[5] == '+' and bp_[4] > bp[4]) or (bp_[5] == '-' and bp_[4] < bp[4]):
# bp[4] = bp_[4]
if len(bp1_list) > 0:
bp1_counter = Counter(bp1_list)
if (
len(bp1_counter.most_common(2)) == 1
or bp1_counter.most_common(2)[0][1]
> bp1_counter.most_common(2)[1][1]
len(bp_start_ctr.most_common(2)) == 1
or bp_start_ctr.most_common(2)[0][1]
> bp_start_ctr.most_common(2)[1][1]
):
bp[1] = bp1_counter.most_common(2)[0][0]
elif len(bp1_list) % 2 == 1:
bp[1] = int(np.median(bp1_list))
elif bp_[2] == "+":
bp[1] = int(np.ceil(np.median(bp1_list)))
bp.start = bp_start_ctr.most_common(2)[0][0]
elif len(bp_starts) % 2 == 1:
bp.start = int(np.median(bp_starts))
elif bp_.strand1.is_forward:
bp.start = int(np.ceil(np.median(bp_starts)))
else:
bp[1] = int(np.floor(np.median(bp1_list)))
if len(bp4_list) > 0:
bp4_counter = Counter(bp4_list)
bp.start = int(np.floor(np.median(bp_starts)))
if len(bp_ends) > 0:
bp_end_ctr = Counter(bp_ends)
if (
len(bp4_counter.most_common(2)) == 1
or bp4_counter.most_common(2)[0][1]
> bp4_counter.most_common(2)[1][1]
len(bp_end_ctr.most_common(2)) == 1
or bp_end_ctr.most_common(2)[0][1] > bp_end_ctr.most_common(2)[1][1]
):
bp[4] = bp4_counter.most_common(2)[0][0]
elif len(bp4_list) % 2 == 1:
bp[4] = int(np.median(bp4_list))
elif bp_[5] == "+":
bp[4] = int(np.ceil(np.median(bp4_list)))
bp.end = bp_end_ctr.most_common(2)[0][0]
elif len(bp_ends) % 2 == 1:
bp.end = int(np.median(bp_ends))
elif bp_.strand2.is_forward:
bp.end = int(np.ceil(np.median(bp_ends)))
else:
bp[4] = int(np.floor(np.median(bp4_list)))
bp.end = int(np.floor(np.median(bp_ends)))
bp_cluster_r = []

final_bp_stats = BreakpointStats(bp_distance_cutoff)
for bp_ in bp_cluster:
if bp_match(
bp_, bp, bp_[7] * 1.2, [bp_distance_cutoff, bp_distance_cutoff]
):
bpr.append(bp_[6])
bp_stats_[0] += bp_[1]
bp_stats_[2] += bp_[1] * bp_[1]
bp_stats_[1] += bp_[4]
bp_stats_[3] += bp_[4] * bp_[4]
if bp_[-3] == 0:
bp_stats_[4] += bp_[-2]
bp_stats_[5] += bp_[-1]
else:
bp_stats_[4] += bp_[-1]
bp_stats_[5] += bp_[-2]
if bp_match(bp_, bp, bp_.gap * 1.2, bp_distance_cutoff):
bpr.append(bp_.read_info)
final_bp_stats.observe(bp)
else:
bp_cluster_r.append(bp_)
if len(bpr) == 0:
return bp, bpr, [0, 0, 0, 0, 0, 0], []
for i in range(6):
bp_stats_[i] /= len(bpr) * 1.0
# print (bp_stats_)
try:
bp_stats_[2] = np.sqrt(bp_stats_[2] - bp_stats_[0] * bp_stats_[0])
except:
bp_stats_[2] = 0
try:
bp_stats_[3] = np.sqrt(bp_stats_[3] - bp_stats_[1] * bp_stats_[1])
except:
bp_stats_[3] = 0
return bp, bpr, bp_stats_, bp_cluster_r


def bp_match(bp1, bp2, rgap, bp_distance_cutoff):
return bp, bpr, final_bp_stats, []
return bp, bpr, final_bp_stats, bp_cluster_r


def bp_match(
bp1: Breakpoint, bp2: Breakpoint, rgap: float, bp_dist_cutoff: float
):
"""
Check if two breakpoints match
A breakpoint (chr1, e1, chr2, s2) must either satisfy chr1 > chr2 or chr1 == chr2 and e1 >= s2
"""
if (
bp1[0] == bp2[0]
and bp1[3] == bp2[3]
and bp1[2] == bp2[2]
and bp1[5] == bp2[5]
):
if rgap <= 0:
return (
abs(int(bp1[1]) - int(bp2[1])) < bp_distance_cutoff[0]
and abs(int(bp1[4]) - int(bp2[4])) < bp_distance_cutoff[1]
)
rgap_ = rgap
consume_rgap = [0, 0]
if bp1[2] == "+" and int(bp1[1]) <= int(bp2[1]) - bp_distance_cutoff[0]:
rgap_ -= int(bp2[1]) - bp_distance_cutoff[0] - int(bp1[1]) + 1
consume_rgap[0] = 1
if bp1[2] == "-" and int(bp1[1]) >= int(bp2[1]) + bp_distance_cutoff[0]:
rgap_ -= int(bp1[1]) - int(bp2[1]) - bp_distance_cutoff[0] + 1
consume_rgap[0] = 1
if bp1[5] == "+" and int(bp1[4]) <= int(bp2[4]) - bp_distance_cutoff[1]:
rgap_ -= int(bp2[4]) - bp_distance_cutoff[1] - int(bp1[4]) + 1
consume_rgap[1] = 1
if bp1[5] == "-" and int(bp1[4]) >= int(bp2[4]) + bp_distance_cutoff[1]:
rgap_ -= int(bp1[4]) - int(bp2[4]) - bp_distance_cutoff[1] + 1
consume_rgap[1] = 1
if bp1.chr1 != bp2.chr1 or bp1.chr2 != bp2.chr2:
return False
if bp1.strand1 != bp2.strand2 or bp1.strand2 != bp2.strand1:
return False
if rgap <= 0:
return (
(consume_rgap[0] == 1 and rgap_ >= 0)
or (abs(int(bp1[1]) - int(bp2[1])) < bp_distance_cutoff[0])
) and (
(consume_rgap[1] == 1 and rgap_ >= 0)
or (abs(int(bp1[4]) - int(bp2[4])) < bp_distance_cutoff[1])
abs(bp1.start - bp2.start) < bp_dist_cutoff
and abs(bp1.end - bp2.end) < bp_dist_cutoff
)
return False

rgap_ = rgap
consume_rgap = [0, 0]
if bp1.strand1.is_forward and bp1.start <= bp2.start - bp_dist_cutoff:
rgap_ -= bp2.start - bp_dist_cutoff - bp1.start + 1
consume_rgap[0] = 1
if bp1.strand1.is_reverse and bp1.start >= bp2.start + bp_dist_cutoff:
rgap_ -= bp1.start - bp2.start - bp_dist_cutoff + 1
consume_rgap[0] = 1
if bp1.strand2.is_forward and bp1.end <= bp2.end - bp_dist_cutoff:
rgap_ -= bp2.end - bp_dist_cutoff - bp1.end + 1
consume_rgap[1] = 1
if bp1.strand2.is_reverse and bp1.end >= bp2.end + bp_dist_cutoff:
rgap_ -= bp1.end - bp2.end - bp_dist_cutoff + 1
consume_rgap[1] = 1
return (
(consume_rgap[0] == 1 and rgap_ >= 0)
or (abs(bp1.start - bp2.start)) < bp_dist_cutoff
) and (
(consume_rgap[1] == 1 and rgap_ >= 0)
or (abs(bp1.end - bp2.end)) < bp_dist_cutoff
)


def sort_chrom_names(chromlist: List[str]) -> List[str]:
Expand Down Expand Up @@ -777,10 +740,10 @@ def output_breakpoint_info_lr(g: BreakpointGraph, filename: str, bp_stats):
# TODO: further modularize the utilities in this file
def fetch_breakpoint_reads(
bam_file: pysam.AlignmentFile,
) -> tuple[dict[str, list[ChimericAlignment]], datatypes.EditDistanceStats]:
) -> tuple[dict[str, list[ChimericAlignment]], datatypes.BasicStatTracker]:
read_name_to_length: dict[str, int] = {}
chimeric_strings: dict[str, list[str]] = defaultdict(list)
edit_dist_stats = datatypes.EditDistanceStats()
edit_dist_stats = datatypes.BasicStatTracker()

for read in bam_file.fetch():
rn: str = read.query_name # type: ignore[assignment]
Expand Down
Loading

0 comments on commit a70bfc8

Please sign in to comment.