#!/usr/bin/env python3
"""
detect_invisible_unicode_improved.py

Improvements in this version:
 - Always shows the list of scanned files (even if no detections).
 - Adds --show-scanned-limit to cap how many scanned paths to print (default 200).
 - Adds --show-all-scanned to force printing all scanned files (careful with huge repos).
 - Keeps previous features:
    * highlighted snippet, context window, snippet cap
    * per-line / per-file / grand totals
    * --fix (creates .bak), --json output
Usage examples:
  ./detect_invisible_unicode_improved.py .                      # run scan
  ./detect_invisible_unicode_improved.py --context 20 .        # larger context
  ./detect_invisible_unicode_improved.py --show-all-scanned .  # print every scanned file
  ./detect_invisible_unicode_improved.py --ext ".js,.ts" .     # limit extensions
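  ./detect_invisible_unicode_improved.py --json . > findings.jsonl   # machine-readable output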
"""

import sys
import os
import re
import argparse
import unicodedata
import json
from collections import defaultdict

# ANSI color helpers (work in macOS Terminal / iTerm and any ANSI-capable terminal)
ANSI_BOLD = '\033[1m'
ANSI_RESET = '\033[0m'
ANSI_RED = '\033[31m'
ANSI_YELLOW_BG = '\033[43m'
ANSI_BLACK = '\033[30m'
ANSI_CYAN = '\033[36m'

def color_marker(text):
    # yellow background with black text for high visibility
    return f"{ANSI_YELLOW_BG}{ANSI_BLACK}{text}{ANSI_RESET}"

# common invisible codepoints to highlight (set of ints)
COMMON_INVIS = {
    0x200B, # ZERO WIDTH SPACE
    0x200C, # ZERO WIDTH NON-JOINER
    0x200D, # ZERO WIDTH JOINER
    0xFEFF, # ZERO WIDTH NO-BREAK SPACE (BOM)
}
COMMON_INVIS.update(range(0xFE00, 0xFE10))  # variation selectors VS1-VS16: category Mn, so the 'Cf' test below misses them

DEFAULT_SKIP_EXTS = {
    '.png', '.jpg', '.jpeg', '.gif', '.ico', '.pdf', '.zip', '.tar', '.gz',
    '.woff', '.woff2', '.ttf', '.otf', '.exe', '.dll', '.so', '.dylib', '.class',
}

def is_invisible_char(ch):
    # Returns True when ch is considered an invisible/formatting character of interest
    try:
        cp = ord(ch)
    except TypeError:
        return False
    cat = unicodedata.category(ch)
    if cat == 'Cf':
        return True
    if cp in COMMON_INVIS:
        return True
    return False
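
# Illustrative spot-checks of the rule above (names per the Unicode database):
#   is_invisible_char('\u200b') -> True   ZERO WIDTH SPACE (category Cf)
#   is_invisible_char('\u00ad') -> True   SOFT HYPHEN (category Cf)
#   is_invisible_char('\ufe0f') -> True   VARIATION SELECTOR-16 (Mn, via COMMON_INVIS)
#   is_invisible_char(' ')      -> False  ordinary space (category Zs)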

def make_snippet(line, idx, context, max_snippet_len):
    """
    Build a short snippet around a matched invisible character.
    Returns (snippet with ANSI marker, 1-based column where the snippet
    starts, Unicode character name, codepoint as 'U+XXXX').
    """
    start = max(0, idx - context)
    end = min(len(line), idx + context + 1)
    left = line[start:idx]
    ch = line[idx]
    right = line[idx+1:end]
    cp = ord(ch)
    try:
        name = unicodedata.name(ch)
    except ValueError:
        name = '<unassigned>'
    marker_text = f"U+{cp:04X}"
    marker = color_marker(f"⟪{marker_text}⟫")
    # escape tabs/newlines for display
    def esc(s):
        return s.replace('\t', '\\t').replace('\r', '').replace('\n', '')
    # Trim the left/right context *before* inserting the marker so that
    # truncation can never slice through an ANSI escape sequence.
    half = max(1, (max_snippet_len - len(marker_text) - 2) // 2)  # 2 for the brackets
    if len(left) > half:
        left = "…" + left[-half:]
    if len(right) > half:
        right = right[:half] + "…"
    snippet = f"{esc(left)}{marker}{esc(right)}"
    col_display = start + 1
    return snippet, col_display, name, marker_text
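
# Example (illustrative): for a line containing a zero-width space,
#   make_snippet('foo\u200bbar', 3, context=10, max_snippet_len=120)
# returns roughly ('foo⟪U+200B⟫bar' with ANSI colors around the marker,
#                  1, 'ZERO WIDTH SPACE', 'U+200B').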

def find_in_file(path, relpath, context=10, max_snippet_len=120, fix=False):
    """
    Scan a single file and return (findings_list, invisible_count_in_file).
    If fix=True, remove invisible chars from the file (write backup .bak).
    """
    findings = []
    try:
        # newline='' preserves the file's original line endings (matters for --fix).
        # errors='replace' maps undecodable bytes to U+FFFD, so --fix should not
        # be pointed at files that are not valid UTF-8.
        with open(path, 'r', encoding='utf-8', errors='replace', newline='') as f:
            lines = f.readlines()
    except Exception:
        return findings, 0  # unreadable/binary -> skip

    total_invisible_in_file = 0
    new_lines = None

    for i, raw_line in enumerate(lines, start=1):
        line = raw_line.rstrip('\n')
        invisible_positions = [j for j, ch in enumerate(line) if is_invisible_char(ch)]
        if not invisible_positions:
            continue
        total_invisible_in_file += len(invisible_positions)
        for pos in invisible_positions:
            snippet, col_display, name, codepoint = make_snippet(line, pos, context, max_snippet_len)
            findings.append({
                'file': relpath,
                'line': i,
                'col': pos + 1,
                'codepoint': codepoint,
                'name': name,
                'snippet': snippet,
            })
        if fix:
            if new_lines is None:
                new_lines = list(lines)
            cleaned = ''.join(ch for ch in line if not is_invisible_char(ch))
            # reattach the original terminator ('' if the final line had no newline)
            new_lines[i-1] = cleaned + raw_line[len(line):]

    if fix and new_lines is not None:
        backup = path + '.bak'
        try:
            os.rename(path, backup)
            with open(path, 'w', encoding='utf-8', newline='') as wf:
                wf.writelines(new_lines)
            print(f"{ANSI_CYAN}[FIX]{ANSI_RESET} sanitized file written: {path} (original backed up as {backup})")
        except Exception as e:
            print(f"[ERROR] failed to write sanitized file {path}: {e}", file=sys.stderr)

    return findings, total_invisible_in_file
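
# A findings entry looks like (illustrative values):
#   {'file': 'src/app.js', 'line': 12, 'col': 7, 'codepoint': 'U+200B',
#    'name': 'ZERO WIDTH SPACE', 'snippet': 'const x⟪U+200B⟫ = 1;'}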

def iter_files(root, allowed_exts=None, skip_exts=None):
    scanned = []
    for dirpath, dirnames, filenames in os.walk(root):
        # prune VCS/dependency directories in place so os.walk never descends into them
        dirnames[:] = [d for d in dirnames if d not in ('.git', 'node_modules')]
        for fn in filenames:
            _, ext = os.path.splitext(fn.lower())
            if skip_exts and ext in skip_exts:
                continue
            if allowed_exts:
                if ext not in allowed_exts:
                    continue
            path = os.path.join(dirpath, fn)
            scanned.append(path)
    return scanned
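
# Example (illustrative): iter_files('.', allowed_exts={'.js', '.ts'},
# skip_exts=DEFAULT_SKIP_EXTS) might return ['./src/app.js', './lib/util.ts', ...].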

def print_scanned_list(scanned_paths, root, limit=200, show_all=False):
    rels = [os.path.relpath(p, root) for p in scanned_paths]
    total = len(rels)
    print(ANSI_BOLD + f"\nScanned files: {total}" + ANSI_RESET)
    if total == 0:
        print(" (no files found under the given path / extension filters)")
        return
    if show_all or total <= limit:
        for r in rels:
            print("  " + r)
    else:
        for r in rels[:limit]:
            print("  " + r)
        print(f"  ... (only showing first {limit} files). Use --show-all-scanned to list all.)")

def main():
    p = argparse.ArgumentParser(description='Detect invisible Unicode characters (improved, show scanned files).')
    p.add_argument('path', nargs='?', default='.', help='path to repository (default: .)')
    p.add_argument('--context', type=int, default=10, help='chars before/after to show (default 10)')
    p.add_argument('--max-snippet', type=int, default=120, help='cap snippet length (default 120)')
    p.add_argument('--ext', help='comma separated list of extensions to INCLUDE (e.g. .js,.ts,.json). If omitted, all text files are scanned except common binary exts.')
    p.add_argument('--skip', help='comma separated list of extensions to SKIP (overrides defaults)', default=','.join(sorted(DEFAULT_SKIP_EXTS)))
    p.add_argument('--fix', action='store_true', help='attempt to remove detected invisible characters (creates .bak of each changed file). Use cautiously.')
    p.add_argument('--json', action='store_true', help='output JSON lines instead of human readable')
    p.add_argument('--show-scanned-limit', type=int, default=200, help='how many scanned file paths to print (default 200)')
    p.add_argument('--show-all-scanned', action='store_true', help='print every scanned file path (can be very large)')
    args = p.parse_args()

    root = args.path

    def parse_ext_list(s):
        # normalize e.g. ".js, TS ,json" -> {'.js', '.ts', '.json'}
        exts = set()
        for part in s.split(','):
            part = part.strip().lower()
            if part:
                exts.add(part if part.startswith('.') else '.' + part)
        return exts

    allowed_exts = parse_ext_list(args.ext) if args.ext else None
    skip_exts = parse_ext_list(args.skip)

    # get scanned file list first (so we can always display it)
    scanned_paths = iter_files(root, allowed_exts=allowed_exts, skip_exts=skip_exts)
    print_scanned_list(scanned_paths, root, limit=args.show_scanned_limit, show_all=args.show_all_scanned)

    total_counts_by_file = defaultdict(int)
    total_findings = []
    grand_total = 0

    # Now actually scan each file
    for path in scanned_paths:
        rel = os.path.relpath(path, root)
        findings, count_in_file = find_in_file(path, rel, context=args.context, max_snippet_len=args.max_snippet, fix=args.fix)
        if findings:
            total_findings.extend(findings)
            total_counts_by_file[rel] += count_in_file
            grand_total += count_in_file

    # Print findings (grouped per file)
    if not args.json:
        if grand_total == 0:
            print("\n" + ANSI_BOLD + "No invisible Unicode characters found." + ANSI_RESET)
        else:
            print("\n" + ANSI_BOLD + "Detections" + ANSI_RESET)
            for rel in sorted(total_counts_by_file.keys()):
                cnt = total_counts_by_file[rel]
                print(f"\n{ANSI_BOLD}-- {rel} --{ANSI_RESET}  (invisible chars: {cnt})")
                # print per-line grouping
                by_line = defaultdict(list)
                for f in total_findings:
                    if f['file'] == rel:
                        by_line[f['line']].append(f)
                for line_no in sorted(by_line.keys()):
                    group = by_line[line_no]
                    print(f" Line {line_no}  (count: {len(group)})")
                    for f in group:
                        print(f"  Col {f['col']:>3} {f['codepoint']} {f['name']}")
                        print(f"    {f['snippet']}")
    else:
        # JSON output: emit one object per finding, then a summary object
        ansi_escape = re.compile(r'\x1B[@-_][0-?]*[ -/]*[@-~]')  # matches ANSI escape sequences
        for f in total_findings:
            out = dict(f)
            # the snippet carries ANSI coloring; strip it for JSON consumers
            out['snippet'] = ansi_escape.sub('', f['snippet'])
            print(json.dumps(out, ensure_ascii=False))
        summary = {
            'total_files_scanned': len(scanned_paths),
            'total_files_with_findings': len(total_counts_by_file),
            'per_file_counts': dict(total_counts_by_file),
            'grand_total': grand_total,
        }
        print(json.dumps({'summary': summary}, ensure_ascii=False))

    # final compact summary
    if not args.json:
        print("\n" + ANSI_BOLD + "Summary" + ANSI_RESET)
        print(f" Scanned files: {len(scanned_paths)}")
        print(f" Files with invisible chars: {len(total_counts_by_file)}")
        print(f" Total invisible characters found: {grand_total}")
    return 0

if __name__ == '__main__':
    sys.exit(main())

