#!/usr/bin/env python3
"""
MSSQL to MySQL Database Dump Script
Exports the schema and a row-limited sample of key tables from the MSSQL source database as MySQL-compatible SQL scripts
"""

import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from aumentum_browser_service import DEFAULT_DB_CONFIG
import pyodbc
from datetime import datetime

# Project root = the directory containing this script's parent directory.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Every generated .sql file is written here; the directory is created at import time.
OUTPUT_DIR = os.path.join(_PROJECT_ROOT, "database_dump", "mysql")
os.makedirs(OUTPUT_DIR, exist_ok=True)

def get_connection():
    """Open a pyodbc connection to the MSSQL source database.

    Settings are read from ``DEFAULT_DB_CONFIG``: ``driver`` (default
    ``"FreeTDS"``), ``server``, ``database``, ``username``, ``password``
    (default empty; a ``None`` value is treated as empty) and, for non-FreeTDS
    drivers, ``port`` (default ``1433``).

    With the FreeTDS driver, ``server`` is passed as ``SERVERNAME``
    (presumably a freetds.conf entry -- confirm against the deployment).
    Other drivers get ``SERVER=host,port`` with ``Encrypt=no`` and
    ``TrustServerCertificate=yes``.

    Returns:
        An open ``pyodbc`` connection (``timeout=30`` is passed to connect).
    """
    driver = DEFAULT_DB_CONFIG.get("driver", "FreeTDS")
    server = DEFAULT_DB_CONFIG.get("server")
    database = DEFAULT_DB_CONFIG.get("database")
    username = DEFAULT_DB_CONFIG.get("username")
    # `or ""` so an explicit None does not become the literal text "None".
    password = DEFAULT_DB_CONFIG.get("password") or ""
    port = DEFAULT_DB_CONFIG.get("port", 1433)

    def odbc_value(value):
        # An ODBC connection-string value containing ';', '{' or '}' (or
        # leading/trailing whitespace) must be wrapped in braces with every
        # '}' doubled; otherwise e.g. a password with ';' splits the string.
        # Ordinary values are passed through unchanged.
        text = str(value)
        if any(ch in text for ch in ";{}") or text != text.strip():
            return "{" + text.replace("}", "}}") + "}"
        return text

    if driver == "FreeTDS":
        conn_str = (
            f"DRIVER={{{driver}}};"
            f"SERVERNAME={odbc_value(server)};"
            f"DATABASE={odbc_value(database)};"
            f"UID={odbc_value(username)};"
            f"PWD={odbc_value(password)};"
        )
    else:
        # Quote "host,port" as a whole: braces must enclose the entire value.
        server_addr = f"{server},{port}"
        conn_str = (
            f"DRIVER={{{driver}}};"
            f"SERVER={odbc_value(server_addr)};"
            f"DATABASE={odbc_value(database)};"
            f"UID={odbc_value(username)};"
            f"PWD={odbc_value(password)};"
            "Encrypt=no;"
            "TrustServerCertificate=yes;"
        )

    print(f"🔌 Connecting to MSSQL source: {database}...")
    return pyodbc.connect(conn_str, timeout=30)

def convert_mssql_type_to_mysql(mssql_type, max_len, precision, scale):
    """Map an MSSQL column type to its closest MySQL equivalent.

    Args:
        mssql_type: Type name as reported by ``sys.types.name``
            (matched case-insensitively).
        max_len: ``sys.columns.max_length`` in bytes; ``-1`` marks ``(max)``.
            Unicode types (nchar/nvarchar/sysname) use 2 bytes per character.
        precision: ``sys.columns.precision`` (used for decimal/numeric).
        scale: ``sys.columns.scale`` (used for decimal/numeric).

    Returns:
        A MySQL column type string. Unmapped types are returned upper-cased
        unchanged, which may or may not be valid MySQL.
    """
    # True when the column declares an explicit length (i.e. not "(max)").
    sized = max_len is not None and max_len > 0

    # (max) string types become LONGTEXT and (max) binary types LONGBLOB:
    # MySQL TEXT/BLOB hold only 65,535 bytes, while SQL Server (max),
    # text, ntext and image hold up to 2 GB.
    type_map = {
        # Unicode types store 2 bytes per character, hence max_len // 2.
        'nvarchar': lambda: f'VARCHAR({max_len // 2})' if sized else 'LONGTEXT',
        'nchar': lambda: f'CHAR({max_len // 2})' if sized else 'CHAR(1)',
        'sysname': lambda: f'VARCHAR({max_len // 2})' if sized else 'VARCHAR(128)',
        'varchar': lambda: f'VARCHAR({max_len})' if sized else 'LONGTEXT',
        'char': lambda: f'CHAR({max_len})' if sized else 'CHAR(1)',
        'varbinary': lambda: f'VARBINARY({max_len})' if sized else 'LONGBLOB',
        'binary': lambda: f'BINARY({max_len})' if sized else 'BINARY(1)',
        'numeric': lambda: f'DECIMAL({precision},{scale})',
        'decimal': lambda: f'DECIMAL({precision},{scale})',
        'money': lambda: 'DECIMAL(19,4)',
        'smallmoney': lambda: 'DECIMAL(10,4)',
        # SQL Server tinyint is unsigned (0..255); signed MySQL TINYINT
        # stops at 127.
        'tinyint': lambda: 'TINYINT UNSIGNED',
        'smallint': lambda: 'SMALLINT',
        'int': lambda: 'INT',
        'bigint': lambda: 'BIGINT',
        'float': lambda: 'DOUBLE',
        'real': lambda: 'FLOAT',
        'datetime': lambda: 'DATETIME',
        'datetime2': lambda: 'DATETIME',
        'smalldatetime': lambda: 'DATETIME',
        'date': lambda: 'DATE',
        'time': lambda: 'TIME',
        'bit': lambda: 'BOOLEAN',
        'image': lambda: 'LONGBLOB',
        'text': lambda: 'LONGTEXT',
        'ntext': lambda: 'LONGTEXT',
        'xml': lambda: 'LONGTEXT',
        'uniqueidentifier': lambda: 'CHAR(36)',
        # SQL Server "timestamp" is rowversion, an 8-byte binary counter;
        # it has nothing in common with MySQL TIMESTAMP.
        'timestamp': lambda: 'BINARY(8)',
        'rowversion': lambda: 'BINARY(8)',
    }

    converter = type_map.get(mssql_type.lower())
    if converter is not None:
        return converter()
    return mssql_type.upper()

def _mysql_quote_ident(name):
    """Return *name* as a backtick-quoted MySQL identifier (embedded backticks doubled)."""
    return "`" + str(name).replace("`", "``") + "`"


def _unwrap_mssql_default(expr):
    """Strip parentheses that enclose an entire MSSQL default definition.

    SQL Server stores defaults wrapped in parentheses, e.g. ``((0))``,
    ``(getdate())`` or ``(N'abc')``. Only parentheses that wrap the whole
    expression are removed, so ``getdate()`` keeps its call parentheses
    (``str.strip('()')`` would have produced ``getdate``). Characters inside
    single-quoted literals are ignored while matching parentheses.
    """
    expr = expr.strip()
    while expr.startswith("(") and expr.endswith(")"):
        depth = 0
        in_quote = False
        for pos, ch in enumerate(expr):
            if ch == "'":
                in_quote = not in_quote
            elif in_quote:
                continue
            elif ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if depth == 0 and pos != len(expr) - 1:
                    # The leading "(" closes before the end, e.g. "(1)+(2)":
                    # it is not a wrapper around the whole expression.
                    return expr
        expr = expr[1:-1].strip()
    return expr


def _mssql_default_to_mysql(definition):
    """Translate an MSSQL default definition into a MySQL column-definition suffix.

    Returns ``" DEFAULT <expr>"``. ``GETDATE()``/``GETUTCDATE()`` become
    ``CURRENT_TIMESTAMP``. Any other function-call default is emitted as an
    inline ``/* ... */`` comment instead, because MySQL would reject it.
    """
    expr = _unwrap_mssql_default(definition)
    if not expr:
        return ""
    if expr.upper() in ('GETDATE()', 'GETUTCDATE()'):
        return " DEFAULT CURRENT_TIMESTAMP"
    head, paren, _ = expr.partition("(")
    if paren and head.strip().isidentifier():
        # e.g. newid(), suser_sname(): keep it visible, but don't break the DDL.
        note = expr.replace("*/", "* /")
        return f" /* unsupported MSSQL default: {note} */"
    return f" DEFAULT {expr}"


def _write_table_ddl(cursor, f, table_name, create_date, modify_date, table_id):
    """Write the header comment, CREATE TABLE and CREATE INDEX statements for one table."""
    quoted_table = _mysql_quote_ident(table_name)
    f.write("-- ==========================================\n")
    f.write(f"-- Table: {table_name}\n")
    f.write(f"-- Created: {create_date}\n")
    f.write(f"-- Modified: {modify_date}\n")
    f.write("-- ==========================================\n\n")

    # Columns, in declaration order.
    cursor.execute("""
        SELECT
            c.name AS column_name,
            t.name AS data_type,
            c.max_length,
            c.precision,
            c.scale,
            c.is_nullable,
            c.is_identity,
            ISNULL(dc.definition, '') AS default_value
        FROM sys.columns c
        JOIN sys.types t ON c.user_type_id = t.user_type_id
        LEFT JOIN sys.default_constraints dc ON c.default_object_id = dc.object_id
        WHERE c.object_id = ?
        ORDER BY c.column_id
    """, (table_id,))
    columns = cursor.fetchall()

    # Primary key columns, in key order.
    cursor.execute("""
        SELECT
            i.name AS index_name,
            c.name AS column_name
        FROM sys.indexes i
        JOIN sys.index_columns ic ON i.object_id = ic.object_id AND i.index_id = ic.index_id
        JOIN sys.columns c ON ic.object_id = c.object_id AND ic.column_id = c.column_id
        WHERE i.object_id = ?
        AND i.is_primary_key = 1
        ORDER BY ic.key_ordinal
    """, (table_id,))
    pk_rows = cursor.fetchall()
    pk_cols = [row[1] for row in pk_rows]

    definitions = []
    identity_col = None
    for col_name, data_type, max_len, precision, scale, nullable, is_identity, default_def in columns:
        mysql_type = convert_mssql_type_to_mysql(data_type, max_len, precision, scale)

        # MySQL AUTO_INCREMENT only works with integer types, so an IDENTITY
        # DECIMAL/NUMERIC with scale 0 becomes BIGINT.
        if is_identity and data_type.lower() in ('numeric', 'decimal') and scale == 0:
            mysql_type = 'BIGINT'

        col_def = f"    {_mysql_quote_ident(col_name)} {mysql_type}"
        col_def += " NULL" if nullable else " NOT NULL"
        if default_def:
            col_def += _mssql_default_to_mysql(default_def)
        if is_identity:
            # MySQL equivalent of IDENTITY.
            col_def += " AUTO_INCREMENT"
            identity_col = col_name
        definitions.append(col_def)

    # The primary key is declared inside CREATE TABLE: MySQL rejects a table
    # whose AUTO_INCREMENT column is not indexed at creation time (error 1075),
    # so a later ALTER TABLE ... ADD PRIMARY KEY is too late.
    if pk_rows:
        pk_list = ", ".join(_mysql_quote_ident(c) for c in pk_cols)
        definitions.append(
            f"    CONSTRAINT {_mysql_quote_ident(pk_rows[0][0])} PRIMARY KEY ({pk_list})"
        )
    # InnoDB needs the AUTO_INCREMENT column to be the first column of some
    # index; add one when the primary key does not start with it.
    if identity_col is not None and (not pk_cols or pk_cols[0] != identity_col):
        definitions.append(f"    KEY ({_mysql_quote_ident(identity_col)})")

    f.write(f"CREATE TABLE {quoted_table} (\n")
    f.write(",\n".join(definitions))
    f.write("\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;\n\n")

    # Secondary indexes. Included (non-key) columns are excluded: they have
    # key_ordinal 0, so they would otherwise sort first and become key columns.
    cursor.execute("""
        SELECT
            i.name AS index_name,
            i.is_unique,
            c.name AS column_name
        FROM sys.indexes i
        JOIN sys.index_columns ic ON i.object_id = ic.object_id AND i.index_id = ic.index_id
        JOIN sys.columns c ON ic.object_id = c.object_id AND ic.column_id = c.column_id
        WHERE i.object_id = ?
        AND i.is_primary_key = 0
        AND i.type > 0
        AND ic.is_included_column = 0
        ORDER BY i.name, ic.key_ordinal
    """, (table_id,))

    # index name -> (is_unique, [key columns in order]); dicts keep insertion order.
    indexes = {}
    for idx_name, idx_is_unique, col_name in cursor.fetchall():
        indexes.setdefault(idx_name, (idx_is_unique, []))[1].append(col_name)

    for idx_name, (idx_is_unique, idx_cols) in indexes.items():
        unique_str = "UNIQUE " if idx_is_unique else ""
        col_list = ", ".join(_mysql_quote_ident(c) for c in idx_cols)
        f.write(f"CREATE {unique_str}INDEX {_mysql_quote_ident(idx_name)} ON {quoted_table} ({col_list});\n")
    if indexes:
        f.write("\n")

    f.write("\n")


def dump_schema_mysql(conn, output_file, schema='LRSAdmin'):
    """Write the tables of one MSSQL schema as a MySQL DDL script.

    The script creates and USEs the database named by
    ``DEFAULT_DB_CONFIG['database']``. Then, for every table in *schema*, it
    emits a ``CREATE TABLE`` with converted column types, NULL-ability,
    defaults, AUTO_INCREMENT for IDENTITY columns and the primary key. Each
    table is followed by one ``CREATE [UNIQUE] INDEX`` per secondary index.
    Foreign keys are not exported.

    Args:
        conn: Open pyodbc connection to the MSSQL source.
        output_file: Path of the .sql file to (over)write.
        schema: MSSQL schema whose tables are exported (default ``'LRSAdmin'``).
    """
    db_name = DEFAULT_DB_CONFIG.get('database')
    cursor = conn.cursor()
    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("-- MySQL Database Schema Dump\n")
            f.write(f"-- Source: {db_name} (MSSQL)\n")
            f.write(f"-- Generated: {datetime.now().isoformat()}\n")
            f.write("-- ==========================================\n\n")

            f.write(f"CREATE DATABASE IF NOT EXISTS {_mysql_quote_ident(db_name)} ")
            f.write("CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;\n\n")
            f.write(f"USE {_mysql_quote_ident(db_name)};\n\n")
            f.write("SET FOREIGN_KEY_CHECKS=0;\n\n")

            # Tables are identified by object_id afterwards, so names that need
            # quoting cannot break an OBJECT_ID('schema.table') lookup.
            cursor.execute("""
                SELECT
                    t.name AS table_name,
                    t.create_date,
                    t.modify_date,
                    t.object_id
                FROM sys.tables t
                JOIN sys.schemas s ON t.schema_id = s.schema_id
                WHERE s.name = ?
                ORDER BY t.name
            """, (schema,))
            tables = cursor.fetchall()
            print(f"📊 Found {len(tables)} tables")

            f.write(f"-- Total Tables: {len(tables)}\n\n")

            for table_name, create_date, modify_date, table_id in tables:
                _write_table_ddl(cursor, f, table_name, create_date, modify_date, table_id)

            f.write("SET FOREIGN_KEY_CHECKS=1;\n")
    finally:
        cursor.close()

    print(f"✅ Schema dumped to: {output_file}")

# Escape map for MySQL single-quoted string literals, which interpret backslash
# escapes unless the NO_BACKSLASH_ESCAPES SQL mode is enabled. NUL, \n, \r and
# Ctrl-Z (\x1a) are escaped as mysql_real_escape_string() does; this also keeps
# each INSERT on a single line.
_MYSQL_STRING_ESCAPES = str.maketrans({
    "\\": "\\\\",
    "'": "\\'",
    "\0": "\\0",
    "\n": "\\n",
    "\r": "\\r",
    "\x1a": "\\Z",
})


def _backtick(name):
    """Quote *name* as a MySQL identifier, doubling any embedded backtick."""
    return "`" + str(name).replace("`", "``") + "`"


def _mysql_literal(val):
    """Render one fetched column value as a MySQL SQL literal."""
    if val is None:
        return "NULL"
    if isinstance(val, bool):  # checked before int: bool is an int subclass
        return "1" if val else "0"
    if isinstance(val, (int, float)):
        return str(val)
    if isinstance(val, (bytes, bytearray)):
        # Binary data as a hex literal; str(bytes) would write the Python
        # repr "b'...'" into the column instead of the actual bytes.
        return f"X'{bytes(val).hex()}'"
    if isinstance(val, datetime):
        # MySQL datetime format (fractional seconds are dropped).
        return f"'{val.strftime('%Y-%m-%d %H:%M:%S')}'"
    text = val if isinstance(val, str) else str(val)
    return "'" + text.translate(_MYSQL_STRING_ESCAPES) + "'"


def dump_table_data_mysql(conn, table_name, output_file, limit=1000, schema='LRSAdmin'):
    """Dump up to *limit* rows of ``schema.table_name`` as MySQL INSERT statements.

    Rows come from ``SELECT TOP <limit> *`` without ORDER BY, so SQL Server
    decides which rows are returned. An empty table produces no file; a
    warning is printed instead.

    Args:
        conn: Open pyodbc connection to the MSSQL source.
        table_name: Table to export; also used as the MySQL target table name.
        output_file: Path of the .sql file to (over)write.
        limit: Maximum number of rows to export.
        schema: MSSQL schema containing the table (default ``'LRSAdmin'``).
    """
    # Bracket-quote the T-SQL name (']' doubled) so it stays one identifier.
    source_table = "[{}].[{}]".format(schema.replace("]", "]]"), table_name.replace("]", "]]"))

    cursor = conn.cursor()
    try:
        # Get row count
        cursor.execute(f"SELECT COUNT(*) FROM {source_table}")
        total_rows = cursor.fetchone()[0]

        if total_rows == 0:
            print(f"   ⚠️  {table_name}: No data")
            return

        # Get data
        cursor.execute(f"SELECT TOP {int(limit)} * FROM {source_table}")
        columns = [desc[0] for desc in cursor.description]
        rows = cursor.fetchall()
    finally:
        # Closed on every path, including the empty-table early return.
        cursor.close()

    # Identical for every row, so build them once.
    target_table = _backtick(table_name)
    col_list = ", ".join(_backtick(col) for col in columns)

    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(f"-- Table: {table_name}\n")
        f.write(f"-- Total Rows: {total_rows:,}\n")
        f.write(f"-- Dumped Rows: {len(rows):,}\n")
        f.write(f"-- Generated: {datetime.now().isoformat()}\n")
        f.write("-- ==========================================\n\n")

        f.write("SET FOREIGN_KEY_CHECKS=0;\n\n")

        for row in rows:
            val_list = ", ".join(_mysql_literal(val) for val in row)
            f.write(f"INSERT INTO {target_table} ({col_list}) VALUES ({val_list});\n")

        f.write("\nSET FOREIGN_KEY_CHECKS=1;\n")

    print(f"   ✅ {table_name}: {len(rows):,} rows dumped (out of {total_rows:,} total)")

def main():
    """Run the full MSSQL -> MySQL export.

    Steps:
    1. Connect to the source.
    2. Write a timestamped MySQL schema file.
    3. Dump up to ``row_limit`` rows of each table in ``key_tables`` to its
       own file (a failure on one table is reported and skipped).
    4. Print import instructions.

    The connection is closed on both the success and the error path.

    Returns:
        0 on success, 1 if any step outside the per-table loop raised
        (the traceback is printed).
    """
    print("="*80)
    print("📦 MSSQL to MySQL Database Export")
    print("="*80)
    print()

    # Single source for the per-table row cap used in the message and the call.
    row_limit = 1000
    # The schema file creates/USEs DEFAULT_DB_CONFIG['database'], so the import
    # instructions must target the same database.
    target_db = DEFAULT_DB_CONFIG.get("database") or "LRS43"
    conn = None

    try:
        conn = get_connection()
        print(f"✅ Connected to database\n")

        # 1. Export schema
        print("1️⃣ Exporting database schema (MySQL format)...")
        schema_file = os.path.join(OUTPUT_DIR, f"schema_mysql_{datetime.now().strftime('%Y%m%d_%H%M%S')}.sql")
        dump_schema_mysql(conn, schema_file)
        print()

        # 2. Export key table data (limited rows)
        print(f"2️⃣ Exporting table data (MySQL format, limited to {row_limit} rows per table)...")

        key_tables = [
            'lr_source_document',
            'lr_transaction',
            'lr_transaction_document',
            'lr_party',
            'lr_dictionary',
            'alf_node',
            'alf_node_properties',
            'alf_content_url',
            'alf_content_data',
            'alf_qname'
        ]

        for table in key_tables:
            table_data_file = os.path.join(OUTPUT_DIR, f"data_mysql_{table}.sql")
            try:
                dump_table_data_mysql(conn, table, table_data_file, limit=row_limit)
            except Exception as e:
                # Best effort: one failing table must not abort the others.
                print(f"   ❌ {table}: Error - {e}")

        print()

        print("="*80)
        print("✅ Export Complete!")
        print("="*80)
        print()
        print(f"📁 Output directory: {OUTPUT_DIR}")
        print()
        print("Next steps:")
        print("1. Create MySQL database:")
        print(f"   mysql -u root -p -e \"CREATE DATABASE {target_db} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;\"")
        print("2. Import schema:")
        print(f"   mysql -u root -p {target_db} < {schema_file}")
        print("3. Import data:")
        print(f"   for f in {OUTPUT_DIR}/data_mysql_*.sql; do mysql -u root -p {target_db} < \"$f\"; done")
        print()

    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        return 1
    finally:
        # Previously the connection leaked whenever an export step raised.
        if conn is not None:
            conn.close()

    return 0

if __name__ == "__main__":
    # Use main()'s return value (0 = success, 1 = failure) as the process exit status.
    raise SystemExit(main())

