import cv2
import numpy as np
import easyocr
import re

class OmrEngine:
    def __init__(self, reader=None):
        """Store the OCR reader, building a default one when none is given.

        Args:
            reader: Object used for OCR through its ``readtext`` method. Any
                falsy value (including the default ``None``) causes a new
                ``easyocr.Reader`` to be constructed instead.
        """
        # Injecting a shared reader is preferred for performance; the fallback
        # creates a Persian + English reader with GPU usage requested.
        self.reader = reader if reader else easyocr.Reader(['fa', 'en'], gpu=True)

    def process_exam(self, image_path, options_count=5):
        """
        Run the whole pipeline on one scanned exam sheet.

        The image is resized to a width of 1000px, the national ID is read
        from the top of the sheet (top 25% first, then top 35% as a retry),
        and the answer bubbles are read from the area below the header.

        Args:
            image_path: Path handed to ``cv2.imread``.
            options_count: Number of answer options per question; forwarded
                to ``detect_answers``.

        Returns:
            dict: On success ``status`` is ``'success'`` and the dict carries
            ``national_id``, ``answers``, ``warnings``, ``is_suspicious`` and
            ``debug_info``. If the image cannot be loaded or any exception is
            raised, ``status`` is ``'error'`` and ``message`` explains why.
        """
        try:
            sheet = cv2.imread(image_path)
            if sheet is None:
                return {'status': 'error', 'message': 'Image not found'}

            # Bring every sheet to the same width, keeping the aspect ratio,
            # so the fractional offsets below mean the same thing everywhere.
            target_width = 1000
            src_height, src_width = sheet.shape[:2]
            target_height = int(src_height * (target_width / src_width))
            sheet = cv2.resize(sheet, (target_width, target_height))

            # First attempt: look for the ID in the top quarter of the sheet.
            id_info = self.extract_national_id(sheet[0:int(target_height * 0.25), :])
            body_ratio = 0.26  # Where the answer area starts by default

            if not id_info['id']:
                # Long names can push the ID further down: retry on the top 35%.
                print("⚠️ ID not found in standard header. Trying extended ROI (35%)...")
                id_info = self.extract_national_id(sheet[0:int(target_height * 0.35), :])

                if id_info['id']:
                    print("✅ ID found in extended ROI. Shifting body start down.")
                    # A lower ID means the questions start lower as well.
                    body_ratio = 0.36

            national_id = id_info['id']

            # Everything below the header is treated as the answer area.
            answer_info = self.detect_answers(
                sheet[int(target_height * body_ratio):, :], options_count
            )
            answers = answer_info['answers']

            # Each triggered condition contributes one warning, in this order;
            # the sheet is suspicious as soon as there is at least one warning.
            checks = (
                (id_info['is_fixed'], "National ID was auto-corrected (checksum error)."),
                (not national_id, "National ID could not be read."),
                (answer_info.get('is_noisy'), "High noise detected in answer sheet."),
                (answer_info.get('skew_detected'), "Potential skew detected."),
            )
            warnings = [message for triggered, message in checks if triggered]

            return {
                'status': 'success',
                'national_id': national_id,
                'answers': answers,
                'warnings': warnings,
                'is_suspicious': bool(warnings),
                'debug_info': 'Processed successfully'
            }

        except Exception as e:
            # Top-level boundary: report the failure instead of raising.
            return {'status': 'error', 'message': str(e)}

    def extract_national_id(self, image_roi):
        """
        Read a 10-digit national ID from a header region.

        Stages, stopping at the first one that yields an ID:
        1. OCR the whole region (grayscale, sharpened).
        2. OCR overlapping horizontal strips (100px tall, every 50px).
        3. Run ``try_fix_candidate`` on the collected candidates, highest
           OCR confidence first.
        4. Fall back to the highest-confidence 10-digit candidate even though
           its checksum is invalid, so it can be corrected manually.

        Args:
            image_roi: Image region; it is converted with ``COLOR_BGR2GRAY``.

        Returns:
            dict: ``{'id': <str or None>, 'is_fixed': <bool>}``. ``is_fixed``
            is True for results of stages 3 and 4 and False otherwise,
            including when nothing was found or an exception occurred.
        """
        try:
            print("Reading header for National ID...")

            # Shared preprocessing: grayscale -> sharpen -> back to 3 channels.
            grayscale = cv2.cvtColor(image_roi, cv2.COLOR_BGR2GRAY)
            sharpen_kernel = np.array([[-1,-1,-1], [-1, 9,-1], [-1,-1,-1]])
            ocr_input = cv2.cvtColor(
                cv2.filter2D(grayscale, -1, sharpen_kernel), cv2.COLOR_GRAY2BGR
            )

            candidates = []

            # --- Stage 1: OCR on the full region ---
            print("  [Method 1] Full Header Scan")
            self._process_ocr_results(self.reader.readtext(ocr_input), candidates)
            found = self._find_valid_id(candidates)
            if found:
                return {'id': found, 'is_fixed': False}

            # --- Stage 2: OCR on overlapping strips ---
            print("  [Method 2] Line-by-Line Strip Scan")
            roi_height, _ = image_roi.shape[:2]
            strip_height = 100  # Roughly one text line plus margin
            step = 50           # Half-strip overlap between windows

            for top in range(0, roi_height - strip_height + 1, step):
                window = ocr_input[top:top + strip_height, :]
                self._process_ocr_results(self.reader.readtext(window), candidates)

                # Candidates accumulate across strips; stop as soon as one validates.
                found = self._find_valid_id(candidates)
                if found:
                    print(f"    ✓ Found ID in strip y={top}")
                    return {'id': found, 'is_fixed': False}

            # --- Stage 3: try to repair the candidates ---
            print("\n--- No valid ID found, trying smart fixes (5→0 priority) ---")

            # Highest confidence first; keep only the first occurrence of each text.
            ranked = sorted(candidates, key=lambda item: item[1], reverse=True)
            first_seen = {}
            for text, confidence in ranked:
                first_seen.setdefault(text, confidence)
            unique = list(first_seen.items())

            for text, confidence in unique:
                print(f"Trying to fix: {text} (confidence: {confidence:.2f})")
                repaired = self.try_fix_candidate(text)
                if repaired:
                    print(f"    ✓✓✓ FIXED ID: {repaired}")
                    return {'id': repaired, 'is_fixed': True}

            # --- Stage 4: best guess with an invalid checksum ---
            best_guess = next((text for text, _ in unique if len(text) == 10), None)

            if best_guess:
                print(f"Returning best guess (Invalid Checksum): {best_guess}")
                # Reported as "fixed" so the caller flags the sheet as suspicious.
                return {'id': best_guess, 'is_fixed': True}

            print("No 10-digit ID found")
            return {'id': None, 'is_fixed': False}

        except Exception as e:
            # Best effort: log the failure and report "no ID" instead of raising.
            print(f"ID Extraction Error: {e}")
            import traceback
            traceback.print_exc()
            return {'id': None, 'is_fixed': False}

    def _process_ocr_results(self, results, candidates_list):
        """Append every 10-digit window found in OCR output to ``candidates_list``.

        Args:
            results: Iterable of ``(bbox, text, prob)`` triples.
            candidates_list: List mutated in place; it receives
                ``(candidate, prob)`` tuples, one per 10-digit window.
        """
        # '|', 'l', 'I' and 'O' are dropped as noise; lowercase 'o' is read as zero.
        # NOTE(review): 'O' is removed while 'o' becomes '0' - confirm this
        # asymmetry is intended.
        noise_table = str.maketrans({'|': None, 'l': None, 'I': None, 'O': None, 'o': '0'})

        for _bbox, text, prob in results:
            # Ignore very low-confidence detections.
            if prob < 0.1:
                continue

            digits = re.sub(r'\D', '', self.normalize_digits(text).translate(noise_table))

            # Slide a 10-character window over the digits; the range is empty
            # when fewer than 10 digits were read.
            candidates_list.extend(
                (digits[start:start + 10], prob)
                for start in range(len(digits) - 9)
            )

    def _find_valid_id(self, candidates_list):
        """Return the first candidate whose checksum is valid, or ``None``.

        Args:
            candidates_list: Iterable of ``(candidate, prob)`` pairs; only the
                candidate text is inspected, in the order given.
        """
        first_valid = next(
            (text for text, _prob in candidates_list
             if self.validate_iranian_national_id(text)),
            None,
        )
        if first_valid is not None:
            print(f"    ✓✓✓ VALID ID: {first_valid}")
        return first_valid



    def try_fix_candidate(self, candidate):
        """
        Try to turn a 10-character candidate into a checksum-valid national ID.

        Substitutions are attempted in this order and the first result that
        passes ``validate_iranian_national_id`` is returned:
        1. Replace a single '5' with '0' (each position in turn).
        2. Replace every pair of '5's with '0'.
        3. Replace a leading '8' with '1'.
        4. Replace one digit using a small table of common confusions.

        The candidate itself is never checked; only modified variants are.

        Args:
            candidate: Candidate ID text.

        Returns:
            The repaired ID string, or ``None`` when ``candidate`` is empty,
            is not exactly 10 characters long, or no substitution validates.
        """
        if not candidate or len(candidate) != 10:
            return None

        is_valid = self.validate_iranian_national_id

        # PRIORITY 0: '5' read in place of '0'.
        positions_with_5 = [i for i, char in enumerate(candidate) if char == '5']

        # One '5' at a time.
        for pos in positions_with_5:
            test_id = candidate[:pos] + '0' + candidate[pos+1:]
            if is_valid(test_id):
                print(f"    ✓ Fixed position {pos}: '5' → '0'")
                return test_id

        # Two '5's at once, every pair in left-to-right order.
        for idx, first in enumerate(positions_with_5):
            for second in positions_with_5[idx+1:]:
                chars = list(candidate)
                chars[first] = '0'
                chars[second] = '0'
                test_id = ''.join(chars)
                if is_valid(test_id):
                    print(f"    ✓ Fixed positions {first},{second}: '5' → '0'")
                    return test_id

        # PRIORITY 1: leading '8' read in place of '1'.
        if candidate[0] == '8':
            test_id = '1' + candidate[1:]
            if is_valid(test_id):
                print("    ✓ Fixed position 0: '8' → '1'")
                return test_id

        # PRIORITY 2: single-digit swaps limited to common confusions.
        # '5' → '0' is not listed here: every single 5→0 swap was already
        # tried in priority 0, so repeating it could never succeed.
        common_fixes = {
            '0': ['5', '8'],
            '1': ['8', '7'],
            '5': ['6'],
            '6': ['9', '5', '0'],
            '8': ['0', '1', '6'],
            '9': ['6'],
        }

        for i, original_char in enumerate(candidate):
            for new_char in common_fixes.get(original_char, ()):
                test_id = candidate[:i] + new_char + candidate[i+1:]
                if is_valid(test_id):
                    print(f"    ✓ Fixed position {i}: '{original_char}' → '{new_char}'")
                    return test_id

        print("  ✗ No valid fix found")
        return None

    def normalize_digits(self, text):
        """Map Persian and Arabic-Indic digits in ``text`` to ASCII ``0``-``9``.

        All other characters are returned unchanged.
        """
        # U+06F0..U+06F9 are the Persian digits and U+0660..U+0669 the
        # Arabic-Indic digits; each block runs from zero to nine in order.
        digit_table = {
            block_start + value: str(value)
            for block_start in (0x06F0, 0x0660)
            for value in range(10)
        }
        return text.translate(digit_table)

    def validate_iranian_national_id(self, national_id):
        """
        Validate an Iranian National ID with its checksum.

        The first nine digits are weighted 10 down to 2 and summed; with
        ``remainder = sum % 11`` the tenth digit must equal ``remainder``
        when ``remainder < 2`` and ``11 - remainder`` otherwise.

        Args:
            national_id: Candidate ID text.

        Returns:
            bool: True only when ``national_id`` is a string of exactly ten
            decimal digits whose check digit matches. Anything else,
            including non-string input, returns False.
        """
        # fullmatch (unlike match with '$') rejects a trailing newline after
        # the ten digits.
        if not isinstance(national_id, str) or not re.fullmatch(r'\d{10}', national_id):
            return False

        digits = [int(char) for char in national_id]
        remainder = sum(digit * (10 - pos) for pos, digit in enumerate(digits[:9])) % 11
        expected_check = remainder if remainder < 2 else 11 - remainder

        return digits[9] == expected_check

    def detect_answers(self, image_roi, options_count=5):
        """
        Detect the marked option of every question in the answer area.

        Circles are located with ``cv2.HoughCircles`` (a second, looser pass
        is made when the first finds nothing), the mean gray level inside
        each circle is measured, and the circles are split into a right and a
        left column at half the region width. Within a column, circles are
        grouped into rows by their Y coordinate; the right column is numbered
        first, starting at question 1, and the left column continues after
        the highest question number found.

        For each row the darkest circle is the answer, but only if its mean
        intensity is below 80% of the row average; otherwise the question is
        recorded as ``None``. Options are numbered from the right (RTL).
        Rows with fewer than ``options_count`` circles are skipped and do not
        consume a question number; extra circles are dropped from the
        high-X end of the row.

        Side effects: writes ``static/debug_gray.jpg`` and
        ``static/debug_circles.jpg`` and prints progress information.

        Args:
            image_roi: Image of the answer area; it is converted with
                ``COLOR_BGR2GRAY``.
            options_count: Number of options expected per question.

        Returns:
            dict: ``{'is_noisy': bool, 'skew_detected': bool, 'answers': dict}``
            where ``answers`` maps question number to a 1-based option index
            or ``None``. ``is_noisy`` is set when no circles are found or when
            the circle count is above 400 or below 20; ``skew_detected`` is
            set when the Y coordinates inside one row differ by more than 10px.
        """
        print("\n=== Starting Answer Detection (HoughCircles Optimized) ===")

        result_metadata = {
            'is_noisy': False,
            'skew_detected': False,
            'answers': {}
        }

        w = image_roi.shape[1]

        # Preprocessing: grayscale, then a median blur to suppress noise.
        gray = cv2.cvtColor(image_roi, cv2.COLOR_BGR2GRAY)
        blurred = cv2.medianBlur(gray, 5)

        cv2.imwrite("static/debug_gray.jpg", blurred)

        # HoughCircles with parameters tuned for OMR
        circles = cv2.HoughCircles(
            blurred,
            cv2.HOUGH_GRADIENT,
            dp=1,
            minDist=15,      # Minimum distance between circles
            param1=50,       # Canny edge detection threshold
            param2=12,       # Accumulator threshold (lower = more sensitive)
            minRadius=10,    # Minimum radius
            maxRadius=20     # Maximum radius
        )

        if circles is None:
            print("No circles detected! Trying with looser parameters...")
            # Second attempt with looser parameters
            circles = cv2.HoughCircles(
                blurred,
                cv2.HOUGH_GRADIENT,
                dp=1,
                minDist=12,
                param1=40,
                param2=10,
                minRadius=8,
                maxRadius=22
            )

        if circles is None:
            print("Still no circles found!")
            result_metadata['is_noisy'] = True
            return result_metadata

        circles = np.round(circles[0, :]).astype("int")
        print(f"Detected {len(circles)} circles")

        # Noise detection: far too many or far too few circles is suspicious.
        if len(circles) > 400 or len(circles) < 20:
            print(f"⚠️ Suspicious number of circles: {len(circles)}")
            result_metadata['is_noisy'] = True

        # Mean gray level inside each circle (lower = darker = filled).
        bubbles = []
        for (x, y, r) in circles:
            # Plain Python ints keep the drawing calls independent of NumPy
            # scalar types.
            x, y, r = int(x), int(y), int(r)

            mask = np.zeros(gray.shape, dtype="uint8")
            cv2.circle(mask, (x, y), r, 255, -1)
            mean_intensity = cv2.mean(gray, mask=mask)[0]

            bubbles.append((x, y, r, mean_intensity))

        # Debug visualization
        debug_img = image_roi.copy()
        for (x, y, r, intensity) in bubbles:
            color = (0, 255, 0) if intensity > 150 else (0, 0, 255)  # Green if empty, Red if filled
            cv2.circle(debug_img, (x, y), r, color, 2)
            cv2.putText(debug_img, f"{int(intensity)}", (x-10, y+5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 0, 0), 1)
        # Only claim the file was saved when imwrite reports success.
        if cv2.imwrite("static/debug_circles.jpg", debug_img):
            print("Saved: static/debug_circles.jpg")

        # Split into columns at half the region width.
        mid_x = w // 2
        right_bubbles = [b for b in bubbles if b[0] > mid_x]
        left_bubbles = [b for b in bubbles if b[0] <= mid_x]

        print(f"Right column: {len(right_bubbles)} circles")
        print(f"Left column: {len(left_bubbles)} circles")

        detected_answers = {}

        def process_column(bubbles_list, start_q_num):
            """Group one column into rows and record an answer per full row."""
            if not bubbles_list:
                print("  No bubbles in this column!")
                return

            # Sort by Y so that circles of the same row are adjacent.
            bubbles_list.sort(key=lambda b: b[1])

            # A circle joins the current row when its Y is within 25px of the
            # previous circle in that row; otherwise it starts a new row.
            rows = []
            current_row = [bubbles_list[0]]
            for bubble in bubbles_list[1:]:
                if abs(bubble[1] - current_row[-1][1]) < 25:
                    current_row.append(bubble)
                else:
                    rows.append(current_row)
                    current_row = [bubble]
            rows.append(current_row)

            print(f"\n  Grouped into {len(rows)} rows")

            q_num = start_q_num
            for row_idx, row in enumerate(rows):
                # Sort by X
                row.sort(key=lambda b: b[0])

                # Simple skew check: Y spread of more than 10px inside a row.
                y_coords = [b[1] for b in row]
                if max(y_coords) - min(y_coords) > 10:
                    result_metadata['skew_detected'] = True

                print(f"  Row {row_idx+1}: {len(row)} circles")

                # Drop extras from the right end (question number box).
                while len(row) > options_count:
                    removed = row.pop()
                    print(f"    Removed extra circle at X={removed[0]}")

                # Incomplete rows are skipped without consuming a question number.
                if len(row) < options_count:
                    print("    Skipping (not enough circles)")
                    continue

                # Option 1 is the rightmost circle (RTL layout).
                row.reverse()

                intensities = [b[3] for b in row]
                min_intensity = min(intensities)
                avg_intensity = np.mean(intensities)

                filled_option = intensities.index(min_intensity) + 1

                # The darkest circle must be clearly darker than the row average.
                darkness_threshold = 0.80
                if min_intensity < avg_intensity * darkness_threshold:
                    detected_answers[q_num] = filled_option
                    print(f"    Q{q_num}: Option {filled_option} (dark: {min_intensity:.0f}, avg: {avg_intensity:.0f})")
                else:
                    detected_answers[q_num] = None
                    print(f"    Q{q_num}: No clear answer (similar intensities)")

                q_num += 1

        print("\nProcessing RIGHT column:")
        process_column(right_bubbles, 1)

        # The left column continues after the highest question number so far.
        last_q = max(detected_answers.keys()) if detected_answers else 0
        print(f"\nProcessing LEFT column (starting from Q{last_q + 1}):")
        process_column(left_bubbles, last_q + 1)

        print(f"\n=== Detected {len(detected_answers)} answers ===\n")

        result_metadata['answers'] = detected_answers
        return result_metadata
