using System.Diagnostics;
using System.Text;
using System.Text.Json;
using DatabaseSnapshotsService.Models;
using MySqlConnector;
// NOTE(review): assumed name of the event model deserialized from events_*.json lines —
// confirm against DatabaseSnapshotsService.Models (only this alias needs changing).
using StoredEvent = DatabaseSnapshotsService.Models.DatabaseEvent;

namespace DatabaseSnapshotsService.Services
{
    /// <summary>
    /// Manages named recovery points (JSON files under <c>{EventStore.Path}/recovery_points</c>)
    /// and restores the configured MySQL database by piping snapshot SQL through the
    /// <c>mysql</c> command-line client.
    /// </summary>
    public class RecoveryService
    {
        private static readonly JsonSerializerOptions s_indentedJsonOptions = new() { WriteIndented = true };

        private readonly SnapshotConfiguration _config;
        private readonly string _recoveryPointsPath;
        private readonly string _eventsPath;
        private readonly OptimizedFileService _fileService;
        private readonly EncryptionService _encryptionService;
        private int _nextPointId = 1;

        /// <summary>
        /// Creates the service, ensuring the event-store and recovery-point directories exist.
        /// </summary>
        /// <param name="config">Snapshot configuration (event store path, security settings, connection string).</param>
        public RecoveryService(SnapshotConfiguration config)
        {
            _config = config;
            _recoveryPointsPath = Path.Combine(config.EventStore.Path, "recovery_points");
            _eventsPath = config.EventStore.Path;
            _fileService = new OptimizedFileService();

            // Initialize encryption service - match SnapshotService pattern
            _encryptionService = new EncryptionService(
                config.Security.EncryptionKey,
                config.Security.Encryption
            );

            Directory.CreateDirectory(_recoveryPointsPath);
            Directory.CreateDirectory(_eventsPath);

            LoadNextPointId();
        }

        /// <summary>
        /// Creates and persists a recovery point recording the current event-store position.
        /// </summary>
        /// <param name="name">Unique name; also used as the file name (<c>{name}.json</c>).</param>
        /// <param name="description">Optional free-text description.</param>
        /// <returns>The saved recovery point.</returns>
        /// <exception cref="ArgumentException">
        /// <paramref name="name"/> is blank, is not a valid file name, or is already in use.
        /// </exception>
        public async Task<RecoveryPoint> CreateRecoveryPointAsync(string name, string? description = null)
        {
            // The name becomes the file name, so path separators would write outside the directory.
            if (string.IsNullOrWhiteSpace(name) || name.IndexOfAny(Path.GetInvalidFileNameChars()) >= 0)
            {
                throw new ArgumentException($"Recovery point name '{name}' is not a valid file name", nameof(name));
            }

            var existingPoints = await ListRecoveryPointsAsync();
            if (existingPoints.Any(p => p.Name == name))
            {
                throw new ArgumentException($"Recovery point '{name}' already exists");
            }

            // Files are named after the point, not its ID, so LoadNextPointId generally cannot
            // recover stored IDs; move past every persisted ID so IDs stay unique across restarts.
            foreach (var existing in existingPoints)
            {
                if (existing.Id >= _nextPointId)
                {
                    _nextPointId = (int)existing.Id + 1;
                }
            }

            // One clock read so Timestamp and CreatedAt describe the same instant.
            var now = DateTimeOffset.UtcNow;
            var point = new RecoveryPoint
            {
                Id = _nextPointId++,
                Name = name,
                Timestamp = now.ToUnixTimeSeconds(),
                Description = description,
                CreatedAt = now.UtcDateTime,
                EventCount = await GetTotalEventCountAsync(),
                LastEventId = await GetLastEventIdAsync()
            };

            await SaveRecoveryPointAsync(point);
            return point;
        }

        /// <summary>
        /// Loads every readable recovery point, newest <c>CreatedAt</c> first.
        /// Unreadable files are reported on the console and skipped.
        /// </summary>
        public async Task<List<RecoveryPoint>> ListRecoveryPointsAsync()
        {
            var points = new List<RecoveryPoint>();
            foreach (var file in Directory.GetFiles(_recoveryPointsPath, "*.json"))
            {
                var point = await TryLoadRecoveryPointAsync(file);
                if (point != null)
                {
                    points.Add(point);
                }
            }

            return points.OrderByDescending(p => p.CreatedAt).ToList();
        }

        /// <summary>
        /// Finds the recovery point with exactly the given name (ordinal comparison).
        /// </summary>
        /// <returns>The matching point, or <c>null</c> when none exists.</returns>
        public async Task<RecoveryPoint?> GetRecoveryPointAsync(string name)
        {
            foreach (var file in Directory.GetFiles(_recoveryPointsPath, "*.json"))
            {
                var point = await TryLoadRecoveryPointAsync(file);
                if (point?.Name == name)
                {
                    return point;
                }
            }

            return null;
        }

        /// <summary>
        /// Describes what a restore to <paramref name="timestamp"/> would involve: the base
        /// snapshot, events between that snapshot and the target, affected tables, and a rough
        /// duration estimate (1 ms per event).
        /// </summary>
        /// <remarks>
        /// NOTE(review): <see cref="RestoreAsync"/> only applies snapshots; it does not replay
        /// the events counted here.
        /// </remarks>
        public async Task<RestorePreview> PreviewRestoreAsync(long timestamp)
        {
            var preview = new RestorePreview
            {
                TargetTimestamp = timestamp,
                EventCount = 0,
                AffectedTables = new List<string>(),
                EstimatedDuration = TimeSpan.Zero,
                Warnings = new List<string>()
            };

            var snapshotService = new SnapshotService(_config);
            var snapshots = await snapshotService.ListSnapshotsAsync();
            var closestSnapshot = snapshots
                .Where(s => s.Timestamp <= timestamp)
                .OrderByDescending(s => s.Timestamp)
                .FirstOrDefault();

            if (closestSnapshot != null)
            {
                preview.SnapshotId = closestSnapshot.Id;
                preview.Warnings.Add($"Will use snapshot {closestSnapshot.Id} as base");
            }
            else
            {
                preview.Warnings.Add("No suitable snapshot found - will restore from scratch");
            }

            var events = await GetEventsInRangeAsync(closestSnapshot?.Timestamp ?? 0, timestamp);
            preview.EventCount = events.Count;
            preview.AffectedTables = events
                .Select(e => e.Table)
                .Distinct()
                .ToList();
            preview.EstimatedDuration = TimeSpan.FromSeconds(events.Count * 0.001); // 1ms per event

            return preview;
        }

        /// <summary>
        /// Restores the target database to the latest snapshot at or before
        /// <paramref name="timestamp"/>: the chain's full snapshot first, then each incremental
        /// in order, followed by a basic validation query.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// No snapshot precedes the timestamp, the chain does not start with a full snapshot,
        /// mysql fails, or validation finds no tables.
        /// </exception>
        public async Task RestoreAsync(long timestamp)
        {
            try
            {
                Console.WriteLine("=== PERFORMING ACTUAL RECOVERY ===");
                Console.WriteLine("This will modify the target database!");
                Console.WriteLine($"Starting restore to timestamp {timestamp}...");

                var (targetSnapshot, restoreChain) = await BuildRestoreChainAsync(timestamp);
                if (targetSnapshot == null)
                {
                    throw new InvalidOperationException($"No snapshot found for timestamp {timestamp}");
                }

                Console.WriteLine($"Target snapshot: {targetSnapshot.Id} ({targetSnapshot.Type})");
                Console.WriteLine($"Restore chain: {restoreChain.Count} snapshots");

                // Applying an incremental onto whatever the database currently holds would
                // silently produce a wrong state, so refuse broken chains outright.
                var fullSnapshot = restoreChain[0];
                if (!IsFullSnapshot(fullSnapshot))
                {
                    throw new InvalidOperationException(
                        $"Restore chain for snapshot {targetSnapshot.Id} does not start with a full snapshot " +
                        $"(first is {fullSnapshot.Id}, type {fullSnapshot.Type}); a parent snapshot is missing or the chain is cyclic");
                }

                Console.WriteLine($"Restoring full snapshot {fullSnapshot.Id}...");
                await RestoreFromSnapshotAsync(fullSnapshot);

                var incrementals = restoreChain.Skip(1).ToList();
                if (incrementals.Count > 0)
                {
                    Console.WriteLine($"Applying {incrementals.Count} incremental snapshots...");
                    foreach (var incremental in incrementals)
                    {
                        // ApplyIncrementalSnapshotAsync logs its own progress line.
                        await ApplyIncrementalSnapshotAsync(incremental);
                    }
                }

                Console.WriteLine("Validating restore...");
                await ValidateRestoreAsync();
                Console.WriteLine("Database validation passed");
                Console.WriteLine("Restore completed successfully");
            }
            catch (Exception ex)
            {
                Console.WriteLine($"Restore failed: {ex.Message}");
                throw;
            }
        }

        /// <summary>
        /// Selects the latest snapshot at or before <paramref name="timestamp"/> and the chain
        /// of snapshots needed to reach it, ordered oldest (ideally full) to target.
        /// </summary>
        /// <returns>
        /// The target snapshot (or <c>null</c> with an empty chain when none qualifies) and the chain.
        /// The chain does not start with a full snapshot if a parent is missing or the parent links form a cycle.
        /// </returns>
        private async Task<(SnapshotInfo? TargetSnapshot, List<SnapshotInfo> RestoreChain)> BuildRestoreChainAsync(long timestamp)
        {
            var snapshotService = new SnapshotService(_config);
            // Materialize once: the parent lookup below scans this collection repeatedly.
            var snapshots = (await snapshotService.ListSnapshotsAsync()).ToList();

            var targetSnapshot = snapshots
                .Where(s => s.Timestamp <= timestamp)
                .OrderByDescending(s => s.Timestamp)
                .FirstOrDefault();

            if (targetSnapshot == null)
            {
                return (null, new List<SnapshotInfo>());
            }

            // Walk parent links back from the target, prepending so the chain ends up ordered
            // full -> ... -> target. The step limit stops a cyclic parent chain from looping forever.
            var chain = new List<SnapshotInfo>();
            var current = targetSnapshot;
            for (var steps = 0; current != null && steps < snapshots.Count; steps++)
            {
                chain.Insert(0, current);
                if (IsFullSnapshot(current))
                {
                    break;
                }

                var parentId = current.ParentSnapshotId;
                current = snapshots.FirstOrDefault(s => s.Id == parentId);
            }

            return (targetSnapshot, chain);
        }

        /// <summary>True when the snapshot's type is "Full" (case-insensitive).</summary>
        private static bool IsFullSnapshot(SnapshotInfo snapshot) =>
            snapshot.Type.Equals("Full", StringComparison.OrdinalIgnoreCase);

        /// <summary>
        /// Decodes an incremental snapshot file and pipes its SQL into the target database.
        /// </summary>
        /// <exception cref="FileNotFoundException">The snapshot file does not exist.</exception>
        private async Task ApplyIncrementalSnapshotAsync(SnapshotInfo incremental)
        {
            Console.WriteLine($"Applying incremental snapshot {incremental.Id}...");

            if (!File.Exists(incremental.FilePath))
            {
                throw new FileNotFoundException($"Incremental snapshot file not found: {incremental.FilePath}");
            }

            var sqlContent = await ReadSnapshotFileAsync(incremental.FilePath);
            await RunMySqlAsync(sqlContent, "mysql");
        }

        /// <summary>
        /// Counts non-blank lines (one JSON event per line) across all <c>events_*.json</c> files.
        /// Unreadable files are reported and skipped.
        /// </summary>
        private async Task<long> GetTotalEventCountAsync()
        {
            long totalCount = 0;
            foreach (var file in Directory.GetFiles(_eventsPath, "events_*.json"))
            {
                try
                {
                    var lines = await File.ReadAllLinesAsync(file);
                    totalCount += lines.Count(line => !string.IsNullOrWhiteSpace(line));
                }
                catch (Exception ex)
                {
                    Console.WriteLine($"Warning: Could not read event file {file}: {ex.Message}");
                }
            }

            return totalCount;
        }

        /// <summary>
        /// Returns the highest event ID among the last non-blank line of each
        /// <c>events_*.json</c> file, or 0 when there are none.
        /// </summary>
        private async Task<long> GetLastEventIdAsync()
        {
            long lastId = 0;
            foreach (var file in Directory.GetFiles(_eventsPath, "events_*.json"))
            {
                try
                {
                    var lines = await File.ReadAllLinesAsync(file);
                    // A trailing blank line must not hide the file's real last event.
                    var lastLine = lines.LastOrDefault(line => !string.IsNullOrWhiteSpace(line));
                    if (lastLine == null)
                    {
                        continue;
                    }

                    var lastEvent = JsonSerializer.Deserialize<StoredEvent>(lastLine);
                    if (lastEvent != null && lastEvent.Id > lastId)
                    {
                        lastId = lastEvent.Id;
                    }
                }
                catch (Exception ex)
                {
                    Console.WriteLine($"Warning: Could not read event file {file}: {ex.Message}");
                }
            }

            return lastId;
        }

        /// <summary>
        /// Loads events whose timestamp lies in [<paramref name="fromTimestamp"/>, <paramref name="toTimestamp"/>],
        /// sorted by timestamp. Blank lines are ignored; malformed lines are reported and skipped individually.
        /// </summary>
        private async Task<List<StoredEvent>> GetEventsInRangeAsync(long fromTimestamp, long toTimestamp)
        {
            var events = new List<StoredEvent>();
            foreach (var file in Directory.GetFiles(_eventsPath, "events_*.json"))
            {
                try
                {
                    var lines = await File.ReadAllLinesAsync(file);
                    for (var i = 0; i < lines.Length; i++)
                    {
                        if (string.IsNullOrWhiteSpace(lines[i]))
                        {
                            continue;
                        }

                        try
                        {
                            var evt = JsonSerializer.Deserialize<StoredEvent>(lines[i]);
                            if (evt != null && evt.Timestamp >= fromTimestamp && evt.Timestamp <= toTimestamp)
                            {
                                events.Add(evt);
                            }
                        }
                        catch (JsonException ex)
                        {
                            // One corrupt line should not hide every other event in the file.
                            Console.WriteLine($"Warning: Skipping malformed event at {file}:{i + 1}: {ex.Message}");
                        }
                    }
                }
                catch (Exception ex)
                {
                    Console.WriteLine($"Warning: Could not read event file {file}: {ex.Message}");
                }
            }

            return events.OrderBy(e => e.Timestamp).ToList();
        }

        /// <summary>
        /// Replaces the target database contents with the given (full) snapshot.
        /// </summary>
        /// <exception cref="FileNotFoundException">The snapshot file does not exist.</exception>
        private async Task RestoreFromSnapshotAsync(SnapshotInfo snapshot)
        {
            Console.WriteLine($"Restoring database from snapshot {snapshot.Id}...");

            if (!File.Exists(snapshot.FilePath))
            {
                throw new FileNotFoundException($"Snapshot file not found: {snapshot.FilePath}");
            }

            await RestoreProgrammaticallyAsync(snapshot);
        }

        /// <summary>
        /// Decodes the snapshot file in memory and pipes the SQL into the mysql client,
        /// echoing any stderr output (warnings) it produced.
        /// </summary>
        private async Task RestoreProgrammaticallyAsync(SnapshotInfo snapshot)
        {
            // The SQL stays in memory; it is never written back to disk in decrypted form.
            var sqlContent = await ReadSnapshotFileAsync(snapshot.FilePath);
            var stdErr = await RunMySqlAsync(sqlContent, "mysql restore");

            foreach (var rawLine in stdErr.Split('\n'))
            {
                var line = rawLine.TrimEnd('\r');
                if (!string.IsNullOrEmpty(line))
                {
                    Console.WriteLine($"[mysql restore] {line}");
                }
            }

            Console.WriteLine("Database restore completed successfully using mysql command");
        }

        /// <summary>
        /// Runs the <c>mysql</c> client against the configured database, feeding it
        /// <paramref name="sqlContent"/> on stdin.
        /// </summary>
        /// <param name="sqlContent">SQL script to execute.</param>
        /// <param name="commandLabel">Prefix of the failure message (e.g. "mysql restore").</param>
        /// <returns>Text mysql wrote to stderr on a successful run.</returns>
        /// <exception cref="InvalidOperationException">
        /// mysql exited with a non-zero code or stopped accepting input; stdout/stderr are printed first.
        /// </exception>
        private async Task<string> RunMySqlAsync(string sqlContent, string commandLabel)
        {
            using var process = new Process { StartInfo = CreateMySqlStartInfo() };
            process.Start();

            // Drain stdout and stderr while stdin is still being written: once an unread pipe
            // buffer fills, mysql blocks, and reading only afterwards would deadlock.
            var stdOutTask = process.StandardOutput.ReadToEndAsync();
            var stdErrTask = process.StandardError.ReadToEndAsync();

            IOException? writeError = null;
            try
            {
                await process.StandardInput.WriteAsync(sqlContent);
                await process.StandardInput.FlushAsync();
                process.StandardInput.Close();
            }
            catch (IOException ex)
            {
                // mysql closed its input early (it exited); its stderr is reported below.
                writeError = ex;
            }

            await process.WaitForExitAsync();
            var stdOut = await stdOutTask;
            var stdErr = await stdErrTask;

            if (process.ExitCode != 0 || writeError != null)
            {
                Console.WriteLine($"[mysql stdout] {stdOut}");
                Console.WriteLine($"[mysql stderr] {stdErr}");
                throw new InvalidOperationException($"{commandLabel} failed with exit code {process.ExitCode}", writeError);
            }

            return stdErr;
        }

        /// <summary>
        /// Builds the mysql client invocation from the configured connection string, parsed by
        /// the same driver that <see cref="ValidateRestoreAsync"/> connects with, so key aliases
        /// ("User ID", "Password", ...) and quoted values are honoured.
        /// </summary>
        private ProcessStartInfo CreateMySqlStartInfo()
        {
            var builder = new MySqlConnectionStringBuilder(_config.ConnectionString);
            var server = string.IsNullOrEmpty(builder.Server) ? "localhost" : builder.Server;
            var database = string.IsNullOrEmpty(builder.Database) ? "trading_platform" : builder.Database;
            var userId = string.IsNullOrEmpty(builder.UserID) ? "root" : builder.UserID;

            var startInfo = new ProcessStartInfo
            {
                FileName = "mysql",
                RedirectStandardInput = true,
                RedirectStandardOutput = true,
                RedirectStandardError = true,
                UseShellExecute = false,
                CreateNoWindow = true
            };

            // ArgumentList passes each entry as one argument, so values containing spaces or
            // quotes are not split the way a hand-built Arguments string would split them.
            startInfo.ArgumentList.Add($"--host={server}");
            startInfo.ArgumentList.Add($"--port={builder.Port}");
            startInfo.ArgumentList.Add($"--user={userId}");
            if (!string.IsNullOrEmpty(builder.Password))
            {
                // NOTE(review): the password is still visible in the process list while mysql
                // starts; consider an option file passed via --defaults-extra-file instead.
                startInfo.ArgumentList.Add($"--password={builder.Password}");
            }
            startInfo.ArgumentList.Add($"--database={database}");

            return startInfo;
        }

        /// <summary>
        /// Returns the SQL text of a snapshot file, decrypting (<c>.enc</c>) and/or decompressing
        /// (<c>.lz4</c>) through temporary files next to it. Temporary files are always removed.
        /// </summary>
        /// <exception cref="InvalidOperationException">Any step fails; the cause is the inner exception.</exception>
        private async Task<string> ReadSnapshotFileAsync(string filePath)
        {
            try
            {
                if (filePath.EndsWith(".lz4.enc", StringComparison.Ordinal))
                {
                    // Encrypted and compressed: decrypt first, then decompress.
                    var decryptedPath = ReplaceSuffix(filePath, ".lz4.enc", ".lz4.tmp");
                    var decompressedPath = ReplaceSuffix(filePath, ".lz4.enc", ".sql.tmp");
                    try
                    {
                        await _encryptionService.DecryptFileAsync(filePath, decryptedPath);
                        await _fileService.DecompressFileStreamingAsync(decryptedPath, decompressedPath);
                        return await ReadUtf8TextAsync(decompressedPath);
                    }
                    finally
                    {
                        DeleteIfExists(decryptedPath);
                        DeleteIfExists(decompressedPath);
                    }
                }

                if (filePath.EndsWith(".lz4", StringComparison.Ordinal))
                {
                    var tempPath = ReplaceSuffix(filePath, ".lz4", ".tmp");
                    try
                    {
                        await _fileService.DecompressFileStreamingAsync(filePath, tempPath);
                        return await ReadUtf8TextAsync(tempPath);
                    }
                    finally
                    {
                        DeleteIfExists(tempPath);
                    }
                }

                if (filePath.EndsWith(".enc", StringComparison.Ordinal))
                {
                    var tempPath = ReplaceSuffix(filePath, ".enc", ".tmp");
                    try
                    {
                        await _encryptionService.DecryptFileAsync(filePath, tempPath);
                        return await ReadUtf8TextAsync(tempPath);
                    }
                    finally
                    {
                        DeleteIfExists(tempPath);
                    }
                }

                return await ReadUtf8TextAsync(filePath);
            }
            catch (Exception ex)
            {
                throw new InvalidOperationException($"Failed to read snapshot file {filePath}: {ex.Message}", ex);
            }
        }

        /// <summary>Reads a whole file through the file service and decodes it as UTF-8.</summary>
        private async Task<string> ReadUtf8TextAsync(string path)
        {
            var content = await _fileService.ReadFileOptimizedAsync(path);
            return Encoding.UTF8.GetString(content);
        }

        /// <summary>
        /// Swaps the trailing <paramref name="oldSuffix"/> for <paramref name="newSuffix"/>,
        /// leaving earlier occurrences (e.g. in directory names) untouched.
        /// </summary>
        private static string ReplaceSuffix(string path, string oldSuffix, string newSuffix) =>
            path[..^oldSuffix.Length] + newSuffix;

        private static void DeleteIfExists(string path)
        {
            if (File.Exists(path))
            {
                File.Delete(path);
            }
        }

        /// <summary>
        /// Sanity-checks the restored database: it must contain at least one table.
        /// </summary>
        /// <exception cref="InvalidOperationException">The current schema has no tables.</exception>
        private async Task ValidateRestoreAsync()
        {
            await using var connection = new MySqlConnection(_config.ConnectionString);
            await connection.OpenAsync();

            await using var command = new MySqlCommand(
                "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE()",
                connection);
            var tableCount = await command.ExecuteScalarAsync();

            if (Convert.ToInt32(tableCount) == 0)
            {
                throw new InvalidOperationException("Database validation failed: No tables found after restore");
            }
        }

        /// <summary>Writes the recovery point as indented JSON to <c>{Name}.json</c>.</summary>
        private async Task SaveRecoveryPointAsync(RecoveryPoint point)
        {
            var pointFile = Path.Combine(_recoveryPointsPath, $"{point.Name}.json");
            var json = JsonSerializer.Serialize(point, s_indentedJsonOptions);
            await _fileService.WriteFileOptimizedAsync(pointFile, Encoding.UTF8.GetBytes(json));
        }

        /// <summary>
        /// Reads and deserializes one recovery point file; reports and returns <c>null</c> on failure.
        /// </summary>
        private async Task<RecoveryPoint?> TryLoadRecoveryPointAsync(string file)
        {
            try
            {
                var jsonBytes = await _fileService.ReadFileOptimizedAsync(file);
                return JsonSerializer.Deserialize<RecoveryPoint>(Encoding.UTF8.GetString(jsonBytes));
            }
            catch (Exception ex)
            {
                Console.WriteLine($"Warning: Could not load recovery point from {file}: {ex.Message}");
                return null;
            }
        }

        /// <summary>
        /// Seeds the next recovery point ID from files whose name parses as an integer.
        /// </summary>
        /// <remarks>
        /// Points are saved as <c>{Name}.json</c>, so this only sees IDs of numerically named
        /// points; <see cref="CreateRecoveryPointAsync"/> reconciles against the stored IDs.
        /// </remarks>
        private void LoadNextPointId()
        {
            var maxId = 0;
            foreach (var file in Directory.GetFiles(_recoveryPointsPath, "*.json"))
            {
                if (int.TryParse(Path.GetFileNameWithoutExtension(file), out var id) && id > maxId)
                {
                    maxId = id;
                }
            }

            _nextPointId = maxId + 1;
        }
    }
}