#!/usr/bin/env python3
"""
IEEE 802.1 Document Processor
Extracts text from documents and generates summaries using Claude
"""

import os
import sqlite3
import re
import json
from pathlib import Path
import pdfplumber
from anthropic import Anthropic
from tqdm import tqdm
import time

# Load configuration
def load_config(path='.env'):
    """Parse a simple ``KEY=VALUE`` dotenv file into a dict.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    skipped. Only the first ``=`` splits key from value, so values may
    contain ``=``. A leading ``export `` on the key is ignored, and a value
    wrapped in a matching pair of single or double quotes is unquoted, so
    ``KEY="abc"`` yields ``abc`` rather than ``"abc"``.

    Args:
        path: File to read. Defaults to ``.env`` in the current working
            directory (the location the script has always used).

    Returns:
        dict mapping keys to string values; empty if the file does not exist.
    """
    config = {}
    if not os.path.exists(path):
        return config
    # utf-8-sig transparently drops a BOM (e.g. written by Windows editors),
    # which would otherwise become part of the first key.
    with open(path, 'r', encoding='utf-8-sig') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith('#') or '=' not in line:
                continue
            key, value = line.split('=', 1)
            key = key.strip()
            if key.startswith('export '):
                # Shell-style `export KEY=value` lines are common in .env files.
                key = key[len('export '):].strip()
            value = value.strip()
            if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
                value = value[1:-1]
            config[key] = value
    return config

config = load_config()
# Prefer the .env file, but fall back to the process environment so the
# script also works when the key is exported in the shell.
ANTHROPIC_API_KEY = (
    config.get('ANTHROPIC_API_KEY')
    or os.environ.get('ANTHROPIC_API_KEY', '').strip()
    or None
)

if not ANTHROPIC_API_KEY or ANTHROPIC_API_KEY == 'your_api_key_here':
    # SystemExit is always available, unlike the site-injected exit() helper
    # (absent under `python -S` and in frozen builds). A string argument is
    # printed to stderr and the process exits with status 1.
    raise SystemExit(
        "Error: Please set ANTHROPIC_API_KEY in .env file or the environment"
    )

client = Anthropic(api_key=ANTHROPIC_API_KEY)

def extract_text_from_pdf(filepath, max_pages=10):
    """Extract text from the first ``max_pages`` pages of a PDF.

    Only the leading pages are read, which keeps the text later sent to the
    API (and therefore its cost) small. Pages for which pdfplumber returns
    no text are skipped.

    Args:
        filepath: Path to the PDF file.
        max_pages: Number of leading pages to read (used as a slice bound,
            so None reads every page).

    Returns:
        The page texts joined by newlines and stripped ('' if no page had
        text), or None if opening or parsing raised (the error is printed).
    """
    try:
        with pdfplumber.open(filepath) as pdf:
            # Collect per-page strings and join once rather than growing a
            # string with += inside the loop.
            page_texts = []
            for page in pdf.pages[:max_pages]:
                page_text = page.extract_text()
                if page_text:
                    page_texts.append(page_text)
        return "\n".join(page_texts).strip()
    except Exception as e:
        print(f"Error extracting PDF {filepath}: {e}")
        return None

def extract_text_from_txt(filepath, max_chars=10000):
    """Return up to ``max_chars`` characters of a text file, whitespace-stripped.

    The file is decoded as UTF-8 and undecodable bytes are dropped
    (``errors='ignore'``). Any error while opening or reading is printed and
    None is returned.
    """
    try:
        with open(filepath, 'r', encoding='utf-8', errors='ignore') as handle:
            head = handle.read(max_chars)
    except Exception as exc:
        print(f"Error reading TXT {filepath}: {exc}")
        return None
    return head.strip()

def extract_metadata(text, filename):
    """Pull an 802.x designation and a document number out of document text.

    Each list of patterns is tried in priority order, and the first pattern
    that matches anywhere in ``text`` (case-insensitively) wins.

    Args:
        text: Extracted document text. Falsy values yield an all-None result.
        filename: Currently unused; kept for interface compatibility.

    Returns:
        dict with keys 'subgroup' (upper-cased, e.g. '802.1AE') and
        'doc_number' (as it appears in the text, e.g. 'P802.1AE/D5.1'),
        each None when nothing matched.
    """
    metadata = {
        'subgroup': None,
        'doc_number': None
    }

    if not text:
        return metadata

    # 802.1 designations can carry several letters (802.1AE, 802.1Qbv,
    # 802.1CBdb). Matching only one letter truncated them to e.g. '802.1A'.
    # The negative lookahead keeps a longer glued-on word from being cut
    # short and mis-reported.
    subgroup_patterns = (
        r'802\.1[a-z]{1,4}(?![a-z])',
        r'802\.10',
        r'802\.[0-9]+[a-z]?',
    )

    for pattern in subgroup_patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            metadata['subgroup'] = match.group(0).upper()
            break

    # Document numbers: project drafts such as 'P802.1Qbv/D3.1' (multi-letter
    # project plus an optional minor draft number), 'IEEE 802.1AE' and
    # 'Draft 2.0'.
    doc_patterns = (
        r'P802\.1[a-z]{1,4}/D[0-9]+(?:\.[0-9]+)?',
        r'IEEE\s+[0-9]+\.[0-9]+[a-z]{0,4}(?![a-z])',
        r'Draft\s+[0-9]+\.[0-9]+',
    )

    for pattern in doc_patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            metadata['doc_number'] = match.group(0)
            break

    return metadata

def generate_summary(text, filename, year):
    """Ask Claude for a one-to-two sentence summary of a document.

    Texts that are missing or shorter than 50 characters are not sent; a
    fixed notice is returned instead. Only the first 4000 characters are
    included in the prompt to bound token usage.

    Returns:
        The model's reply text, stripped, or the string
        "Error generating summary." if the request (or reading its first
        content block) raises. The error is printed in that case.
    """
    too_short = not text or len(text) < 50
    if too_short:
        return "Document too short or unreadable to summarize."

    excerpt = text[:4000]  # cap the excerpt so each request stays cheap

    prompt = f"""This is an IEEE 802.1 networking document from {year} named "{filename}".

Please provide a 1-2 sentence summary that captures:
1. The main topic or purpose of the document
2. Key technical focus (e.g., VLANs, bridging, security, etc.)

Document excerpt:
{excerpt}

Summary:"""

    request = {
        "model": "claude-sonnet-4-5-20250929",
        "max_tokens": 150,
        "messages": [{"role": "user", "content": prompt}],
    }

    try:
        response = client.messages.create(**request)
        first_block = response.content[0]
        return first_block.text.strip()
    except Exception as err:
        print(f"Error generating summary: {err}")
        return "Error generating summary."

# Value generate_summary() returns when its API call fails.
_SUMMARY_FAILED = "Error generating summary."


def _extract_text(filepath, extension):
    """Run the text extractor that matches ``extension``.

    The comparison is case-insensitive, so '.PDF' / '.TXT' rows are handled
    too. Returns the extractor's result, or None when the extension is
    unsupported or None.
    """
    ext = (extension or '').lower()
    if ext == '.pdf':
        return extract_text_from_pdf(filepath)
    if ext == '.txt':
        return extract_text_from_txt(filepath)
    return None


def process_documents(batch_size=10, start_from=0):
    """Extract text, metadata and a summary for the next unprocessed documents.

    Selects up to ``batch_size`` rows with ``processed = 0`` from the
    ``documents`` table in ``ieee_docs.db`` (ordered by id, skipping
    ``start_from`` rows) and updates each row in place. It commits after
    every document, so progress survives an interruption.

    Rows whose text cannot be extracted, including unsupported file types,
    are marked ``processed = 1, text_extracted = 0``. Previously unsupported
    types were skipped without being marked, so they accumulated at the head
    of the ``ORDER BY id LIMIT`` queue and eventually starved every batch.

    Rows whose summary request fails are left unprocessed, so a later run
    retries them instead of storing the error text as their summary.
    """
    conn = sqlite3.connect('ieee_docs.db')
    try:
        c = conn.cursor()

        # Get unprocessed documents
        c.execute('''SELECT id, filepath, filename, year, extension 
                     FROM documents 
                     WHERE processed = 0 
                     ORDER BY id
                     LIMIT ? OFFSET ?''', (batch_size, start_from))
        documents = c.fetchall()

        if not documents:
            print("No unprocessed documents found!")
            return

        print(f"Processing {len(documents)} documents...")

        failed_summaries = 0
        for doc_id, filepath, filename, year, extension in tqdm(documents):
            text = _extract_text(filepath, extension)

            if not text:
                # Extraction failed or the type is unsupported: mark the row
                # done so it does not block the queue.
                c.execute('''UPDATE documents 
                            SET processed = 1, text_extracted = 0 
                            WHERE id = ?''', (doc_id,))
                conn.commit()
                continue

            metadata = extract_metadata(text, filename)
            summary = generate_summary(text, filename, year)

            if summary == _SUMMARY_FAILED:
                # Leave the row unprocessed so it is retried on a later run.
                failed_summaries += 1
            else:
                c.execute('''UPDATE documents 
                            SET processed = 1, 
                                text_extracted = 1,
                                summary = ?,
                                subgroup = ?,
                                doc_number = ?,
                                metadata = ?
                            WHERE id = ?''',
                          (summary, metadata['subgroup'], metadata['doc_number'],
                           json.dumps(metadata), doc_id))
                conn.commit()

            # Small delay between API calls to avoid rate limits
            time.sleep(0.5)

        if failed_summaries:
            print(f"\n{failed_summaries} document(s) could not be summarized "
                  "and were left unprocessed for retry.")
        print("\nProcessing complete!")
    finally:
        # Close even if extraction, the API or SQLite raises mid-batch.
        conn.close()

def show_stats():
    """Print row counts from the ``documents`` table of ieee_docs.db.

    Reports the total number of rows, how many are marked processed, how
    many had text extracted, and how many remain unprocessed.
    """
    conn = sqlite3.connect('ieee_docs.db')
    cur = conn.cursor()

    queries = (
        'SELECT COUNT(*) FROM documents',
        'SELECT COUNT(*) FROM documents WHERE processed = 1',
        'SELECT COUNT(*) FROM documents WHERE text_extracted = 1',
    )
    # Unpacking the generator runs the queries in order; each result is
    # fetched before the next query reuses the cursor.
    total, processed, extracted = (cur.execute(q).fetchone()[0] for q in queries)

    report_lines = (
        "\nStatistics:",
        f"Total documents: {total}",
        f"Processed: {processed}",
        f"Successfully extracted: {extracted}",
        f"Remaining: {total - processed}",
    )
    for report_line in report_lines:
        print(report_line)

    conn.close()

def main():
    """Interactive entry point: show stats, run the chosen batch, show stats again."""
    print("IEEE 802.1 Document Processor")
    print("=" * 50)

    show_stats()

    print("\nOptions:")
    print("1. Process next 10 documents")
    print("2. Process next 50 documents")
    print("3. Process ALL remaining documents")
    print("4. Show statistics only")

    choice = input("\nEnter choice (1-4): ").strip()

    if choice == '1':
        process_documents(batch_size=10)
    elif choice == '2':
        process_documents(batch_size=50)
    elif choice == '3':
        conn = sqlite3.connect('ieee_docs.db')
        try:
            remaining = conn.execute(
                'SELECT COUNT(*) FROM documents WHERE processed = 0'
            ).fetchone()[0]
        finally:
            conn.close()

        if remaining == 0:
            # Nothing to do; don't ask the user to confirm "Process 0 documents?".
            print("\nNo unprocessed documents found!")
        else:
            confirm = input(f"\nProcess {remaining} documents? This will cost API tokens. (yes/no): ")
            # Tolerate surrounding whitespace, e.g. " yes".
            if confirm.strip().lower() == 'yes':
                process_documents(batch_size=remaining)
    elif choice == '4':
        pass
    else:
        print(f"\nUnknown choice: {choice!r}")

    show_stats()

if __name__ == "__main__":
    # Start the interactive menu only when run as a script, not on import.
    main()
