diff --git a/cmd/start/start.go b/cmd/start/start.go index 8780965..aaa1b5a 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -5,7 +5,6 @@ import ( "os" "sysmonitord/internal/config" "sysmonitord/internal/scanner/file" - "sysmonitord/internal/scanner/hash" "sysmonitord/internal/scanner/process" "sysmonitord/internal/storage" "sysmonitord/pkg/logger" @@ -32,12 +31,6 @@ var StartCmd = &cobra.Command{ zap.String("审计服务器地址", fmt.Sprintf("%s:%d", cfg.Audit.Server, cfg.Audit.Port)), ) - hashCfg := &hash.Config{ - UseFastHash: cfg.Scanner.File.FastHash, - Threshold: cfg.Scanner.File.FastHashSize, - ChunkSize: cfg.Scanner.File.FastHashChunk, - } - storageCfg := &storage.Storage{ DataDir: cfg.Storage.DataDir, ProcessSystemFile: cfg.Storage.ProcessSystemFile, @@ -46,7 +39,8 @@ var StartCmd = &cobra.Command{ // ====== 进程扫描和存储 ====== - procs, err := process.ScanAllProcesses(hashCfg) + startTime := time.Now() + procs, err := process.ScanAllProcesses(cfg) if err != nil { logger.Log.Error("扫描进程失败", zap.Error(err)) os.Exit(1) @@ -75,7 +69,7 @@ var StartCmd = &cobra.Command{ // ====== 文件扫描和存储 ====== logger.Log.Info("正在扫描文件系统...") - startTime := time.Now() + startTime = time.Now() fileScanner := file.NewScanner(cfg) files, err := fileScanner.Scan() diff --git a/internal/scanner/file/scanner.go b/internal/scanner/file/scanner.go index 229447d..722dd6d 100644 --- a/internal/scanner/file/scanner.go +++ b/internal/scanner/file/scanner.go @@ -85,7 +85,7 @@ func (s *Scanner) WalkFunc(result *[]FileInfo) fs.WalkDirFunc { return nil } - hash, err := hash.CalculateHash(path, hashCfg) + hash, err := hash.Calculate(path, info.Size(), hashCfg) if err != nil { logger.Log.Debug("[scan]无法计算文件哈希", zap.String("path", path), zap.Error(err)) return nil diff --git a/internal/scanner/hash/hash.go b/internal/scanner/hash/hash.go index 01cef52..a937569 100644 --- a/internal/scanner/hash/hash.go +++ b/internal/scanner/hash/hash.go @@ -56,34 +56,33 @@ type Config struct { } // ==== 计算文件哈希 ==== - -func CalculateHash(filePath string, cfg *Config) (string, error) { - info, err := os.Stat(filePath) - if err != nil { - logger.Log.Warn("[hash]获取文件信息失败", zap.String("path", filePath), zap.Error(err)) - return "", err +func Calculate(filePath string, fileSize int64, cfg *Config) (string, error) { + if cfg == nil { + cfg = &Config{ + Algorithm: &SHA256Algorithm{}, + } } - fileSize := info.Size() + if fileSize == 0 { + info, err := os.Stat(filePath) + if err != nil { + logger.Log.Warn("[scanner]获取文件信息失败", zap.String("path", filePath), zap.Error(err)) + return "", err + } + fileSize = info.Size() + } if cfg.Algorithm == nil { cfg.Algorithm = &SHA256Algorithm{} } - logger.Log.Debug("[hash]计算文件哈希", - zap.String("path", filePath), - zap.Int64("fileSize", fileSize), - zap.String("Algorithm", cfg.Algorithm.Name())) + logger.Log.Debug("[scanner]计算文件哈希", zap.String("path", filePath), zap.Int64("size", fileSize), zap.String("algorithm", cfg.Algorithm.Name())) if cfg.UseFastHash && fileSize > cfg.Threshold { - logger.Log.Debug("[hash] 分层哈希...", - zap.String("path", filePath), - zap.Int64("fileSize", fileSize), - ) return calculateFast(filePath, fileSize, cfg) + } else { + return calculateFull(filePath, cfg) } - - return calculateFull(filePath, cfg) } func calculateFull(filePath string, cfg *Config) (string, error) { diff --git a/internal/scanner/process/process.go b/internal/scanner/process/process.go index 11287bb..6918ff8 100644 --- a/internal/scanner/process/process.go +++ b/internal/scanner/process/process.go @@ -3,6 +3,7 @@ package process import ( "fmt" "os" + "sysmonitord/internal/config" "sysmonitord/internal/scanner/hash" "sysmonitord/pkg/logger" @@ -18,7 +19,7 @@ type ProcessInfo struct { FileHash string `json:"file_hash"` } -func ScanAllProcesses(hashCfg *hash.Config) ([]ProcessInfo, error) { +func ScanAllProcesses(cfg *config.Config) ([]ProcessInfo, error) { logger.Log.Info("[scan]正在扫描系统中的所有进程...") pids, err := process.Pids() @@ -28,6 +29,13 @@ func ScanAllProcesses(hashCfg *hash.Config) ([]ProcessInfo, error) { } var processList []ProcessInfo + + hashCfg, err := cfg.GetHashConfig() + if err != nil { + logger.Log.Error("[scan]获取哈希配置失败", zap.Error(err)) + return nil, err + } + for _, pid := range pids { p, err := process.NewProcess(pid) if err != nil { @@ -58,7 +66,7 @@ func ScanAllProcesses(hashCfg *hash.Config) ([]ProcessInfo, error) { if exePath != "" { if _, err := os.Stat(exePath); err == nil { - fileHash, err := hash.CalculateHash(exePath, hashCfg) + fileHash, err := hash.Calculate(exePath, 0, hashCfg) if err == nil { info.FileHash = fileHash } else {