From 557c48d6566a26330b3751e61f47fdcc64322971 Mon Sep 17 00:00:00 2001 From: Corentin Barreau Date: Fri, 17 Jan 2025 19:16:16 +0100 Subject: [PATCH] add: OOM protection for high concurrency usage of spooledtempfile --- client.go | 9 +- dialer.go | 3 +- pkg/spooledtempfile/spooled.go | 308 ++++++++++++++++++ .../spooledtempfile/spooled_test.go | 68 +++- read.go | 4 +- spooled.go | 208 ------------ spooledmanager.go | 131 -------- utils.go | 5 +- write.go | 3 +- 9 files changed, 383 insertions(+), 356 deletions(-) create mode 100644 pkg/spooledtempfile/spooled.go rename spooled_test.go => pkg/spooledtempfile/spooled_test.go (80%) delete mode 100644 spooled.go delete mode 100644 spooledmanager.go diff --git a/client.go b/client.go index 72134e0..a5e31ea 100644 --- a/client.go +++ b/client.go @@ -30,6 +30,7 @@ type HTTPClientSettings struct { DecompressBody bool FollowRedirects bool FullOnDisk bool + MaxRAMUsageFraction float64 VerifyCerts bool RandomLocalIP bool DisableIPv4 bool @@ -53,7 +54,10 @@ type CustomHTTPClient struct { MaxReadBeforeTruncate int verifyCerts bool FullOnDisk bool - randomLocalIP bool + // MaxRAMUsageFraction is the fraction of system RAM above which we'll force spooling to disk. For example, 0.5 = 50%. + // If set to <= 0, the default value is DefaultMaxRAMUsageFraction. + MaxRAMUsageFraction float64 + randomLocalIP bool } func (c *CustomHTTPClient) Close() error { @@ -125,6 +129,9 @@ func NewWARCWritingHTTPClient(HTTPClientSettings HTTPClientSettings) (httpClient // Configure if we are only storing responses only on disk or in memory and on disk. 
httpClient.FullOnDisk = HTTPClientSettings.FullOnDisk + // Configure the maximum RAM usage fraction + httpClient.MaxRAMUsageFraction = HTTPClientSettings.MaxRAMUsageFraction + // Configure our max read before we start truncating records if HTTPClientSettings.MaxReadBeforeTruncate == 0 { httpClient.MaxReadBeforeTruncate = 1000000000 diff --git a/dialer.go b/dialer.go index 3405655..cdfe66f 100644 --- a/dialer.go +++ b/dialer.go @@ -15,6 +15,7 @@ import ( "sync" "time" + "github.com/CorentinB/warc/pkg/spooledtempfile" "github.com/google/uuid" "github.com/miekg/dns" tls "github.com/refraction-networking/utls" @@ -487,7 +488,7 @@ func (d *customDialer) readResponse(respPipe *io.PipeReader, warcTargetURIChanne } // Write the data up until the end of the headers to a temporary buffer - tempBuffer := NewSpooledTempFile("warc", d.client.TempDir, -1, d.client.FullOnDisk) + tempBuffer := spooledtempfile.NewSpooledTempFile("warc", d.client.TempDir, -1, d.client.FullOnDisk, d.client.MaxRAMUsageFraction) block = make([]byte, 1) wrote := 0 responseRecord.Content.Seek(0, 0) diff --git a/pkg/spooledtempfile/spooled.go b/pkg/spooledtempfile/spooled.go new file mode 100644 index 0000000..6870df9 --- /dev/null +++ b/pkg/spooledtempfile/spooled.go @@ -0,0 +1,308 @@ +package spooledtempfile + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "os" + "strconv" + "strings" + "sync" +) + +// MaxInMemorySize is the max number of bytes (currently 1MB) +// to hold in memory before starting to write to disk +const MaxInMemorySize = 1024 * 1024 + +// DefaultMaxRAMUsageFraction is the default fraction of system RAM above which +// we'll force spooling to disk. For example, 0.5 = 50%. +const DefaultMaxRAMUsageFraction = 0.50 + +var spooledPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(nil) + }, +} + +// ReaderAt is the interface for ReadAt - read at position, without moving pointer. 
+type ReaderAt interface { + ReadAt(p []byte, off int64) (n int, err error) +} + +// ReadSeekCloser is an io.Reader + ReaderAt + io.Seeker + io.Closer + Stat +type ReadSeekCloser interface { + io.Reader + io.Seeker + ReaderAt + io.Closer + FileName() string + Len() int +} + +// spooledTempFile writes to memory (or to disk if +// over MaxInMemorySize) and deletes the file on Close +type spooledTempFile struct { + buf *bytes.Buffer + mem *bytes.Reader + file *os.File + filePrefix string + tempDir string + maxInMemorySize int + fullOnDisk bool + reading bool // transitions at most once from false -> true + closed bool + maxRAMUsageFraction float64 // fraction above which we skip in-memory buffering +} + +// ReadWriteSeekCloser is an io.Writer + io.Reader + io.Seeker + io.Closer. +type ReadWriteSeekCloser interface { + ReadSeekCloser + io.Writer +} + +// NewSpooledTempFile returns an ReadWriteSeekCloser, +// with some important constraints: +// - You can Write into it, but whenever you call Read or Seek on it, +// subsequent Write calls will panic. +// - If threshold is -1, then the default MaxInMemorySize is used. +// - If maxRAMUsageFraction <= 0, we default to DefaultMaxRAMUsageFraction. E.g. 0.5 = 50%. +// +// If the system memory usage is above maxRAMUsageFraction, we skip writing +// to memory and spool directly on disk to avoid OOM scenarios in high concurrency. 
+func NewSpooledTempFile( + filePrefix string, + tempDir string, + threshold int, + fullOnDisk bool, + maxRAMUsageFraction float64, +) ReadWriteSeekCloser { + if threshold < 0 { + threshold = MaxInMemorySize + } + if maxRAMUsageFraction <= 0 { + maxRAMUsageFraction = DefaultMaxRAMUsageFraction + } + + return &spooledTempFile{ + filePrefix: filePrefix, + tempDir: tempDir, + buf: spooledPool.Get().(*bytes.Buffer), + maxInMemorySize: threshold, + fullOnDisk: fullOnDisk, + maxRAMUsageFraction: maxRAMUsageFraction, + } +} + +func (s *spooledTempFile) prepareRead() error { + if s.closed { + return io.EOF + } + + if s.reading && (s.file != nil || s.buf == nil || s.mem != nil) { + return nil + } + + s.reading = true + if s.file != nil { + if _, err := s.file.Seek(0, 0); err != nil { + return fmt.Errorf("file=%v: %w", s.file, err) + } + return nil + } + + s.mem = bytes.NewReader(s.buf.Bytes()) + return nil +} + +func (s *spooledTempFile) Len() int { + if s.file != nil { + fi, err := s.file.Stat() + if err != nil { + return -1 + } + return int(fi.Size()) + } + return s.buf.Len() +} + +func (s *spooledTempFile) Read(p []byte) (n int, err error) { + if err := s.prepareRead(); err != nil { + return 0, err + } + + if s.file != nil { + return s.file.Read(p) + } + return s.mem.Read(p) +} + +func (s *spooledTempFile) ReadAt(p []byte, off int64) (n int, err error) { + if err := s.prepareRead(); err != nil { + return 0, err + } + + if s.file != nil { + return s.file.ReadAt(p, off) + } + return s.mem.ReadAt(p, off) +} + +func (s *spooledTempFile) Seek(offset int64, whence int) (int64, error) { + if err := s.prepareRead(); err != nil { + return 0, err + } + + if s.file != nil { + return s.file.Seek(offset, whence) + } + return s.mem.Seek(offset, whence) +} + +func (s *spooledTempFile) Write(p []byte) (n int, err error) { + if s.closed { + return 0, io.EOF + } + if s.reading { + panic("write after read") + } + + // If we already have a file open, we always write to disk. 
+ if s.file != nil { + return s.file.Write(p) + } + + // Otherwise, check if system memory usage is above threshold + // or if we've exceeded our own in-memory limit, or if user forced on-disk. + aboveRAMThreshold := s.isSystemMemoryUsageHigh() + if aboveRAMThreshold || s.fullOnDisk || (s.buf.Len()+len(p) > s.maxInMemorySize) { + // Switch to file if we haven't already + s.file, err = os.CreateTemp(s.tempDir, s.filePrefix+"-") + if err != nil { + return 0, err + } + + // Copy what we already had in the buffer + _, err = io.Copy(s.file, s.buf) + if err != nil { + s.file.Close() + s.file = nil + return 0, err + } + + // Release the buffer + s.buf.Reset() + spooledPool.Put(s.buf) + s.buf = nil + + // Write incoming bytes directly to file + n, err = s.file.Write(p) + if err != nil { + s.file.Close() + s.file = nil + return n, err + } + return n, nil + } + + // Otherwise, stay in memory. + return s.buf.Write(p) +} + +func (s *spooledTempFile) Close() error { + s.closed = true + s.mem = nil + + if s.buf != nil { + s.buf.Reset() + spooledPool.Put(s.buf) + s.buf = nil + } + + if s.file == nil { + return nil + } + + s.file.Close() + + if err := os.Remove(s.file.Name()); err != nil && !errors.Is(err, os.ErrNotExist) { + return err + } + s.file = nil + return nil +} + +func (s *spooledTempFile) FileName() string { + if s.file != nil { + return s.file.Name() + } + return "" +} + +// isSystemMemoryUsageHigh returns true if current memory usage +// exceeds s.maxRAMUsageFraction of total system memory. +// This implementation is Linux-specific via /proc/meminfo. +func (s *spooledTempFile) isSystemMemoryUsageHigh() bool { + usedFraction, err := getSystemMemoryUsedFraction() + if err != nil { + // If we fail to get memory usage info, we conservatively return false, + // or you may choose to return true to avoid in-memory usage. 
+ return false + } + return usedFraction >= s.maxRAMUsageFraction +} + +// getSystemMemoryUsedFraction parses /proc/meminfo on Linux to figure out +// how much memory is used vs total. Returns fraction = used / total +// This is a Linux-specific implementation. +// This function is defined as a variable so it can be overridden in tests. +var getSystemMemoryUsedFraction = func() (float64, error) { + f, err := os.Open("/proc/meminfo") + if err != nil { + return 0, err + } + defer f.Close() + + // We look for MemTotal, MemAvailable (or MemFree if MemAvailable is missing) + var memTotal, memAvailable, memFree, buffers, cached uint64 + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + key := strings.TrimRight(fields[0], ":") + value, _ := strconv.ParseUint(fields[1], 10, 64) + // value is typically in kB + switch key { + case "MemTotal": + memTotal = value + case "MemAvailable": + memAvailable = value + case "MemFree": + memFree = value + case "Buffers": + buffers = value + case "Cached": + cached = value + } + } + + if memTotal == 0 { + return 0, fmt.Errorf("could not find MemTotal in /proc/meminfo") + } + + // If MemAvailable is present (Linux 3.14+), we can directly use it: + if memAvailable > 0 { + used := memTotal - memAvailable + return float64(used) / float64(memTotal), nil + } + + // Otherwise, approximate "available" as free+buffers+cached + approxAvailable := memFree + buffers + cached + used := memTotal - approxAvailable + return float64(used) / float64(memTotal), nil +} diff --git a/spooled_test.go b/pkg/spooledtempfile/spooled_test.go similarity index 80% rename from spooled_test.go rename to pkg/spooledtempfile/spooled_test.go index 0e9ee37..888ed4f 100644 --- a/spooled_test.go +++ b/pkg/spooledtempfile/spooled_test.go @@ -1,4 +1,4 @@ -package warc +package spooledtempfile import ( "bytes" @@ -13,7 +13,7 @@ import ( // TestInMemoryBasic writes data below 
threshold and verifies it remains in memory. func TestInMemoryBasic(t *testing.T) { - spool := NewSpooledTempFile("test", os.TempDir(), 100, false) + spool := NewSpooledTempFile("test", os.TempDir(), 100, false, -1) defer spool.Close() // Write data smaller than threshold @@ -62,7 +62,7 @@ func TestInMemoryBasic(t *testing.T) { // TestThresholdCrossing writes enough data to switch from in-memory to disk. func TestThresholdCrossing(t *testing.T) { - spool := NewSpooledTempFile("test", os.TempDir(), 10, false) + spool := NewSpooledTempFile("test", os.TempDir(), 10, false, -1) defer spool.Close() data1 := []byte("12345") @@ -105,7 +105,7 @@ func TestThresholdCrossing(t *testing.T) { // TestForceOnDisk checks the fullOnDisk parameter. func TestForceOnDisk(t *testing.T) { - spool := NewSpooledTempFile("test", os.TempDir(), 1000000, true) + spool := NewSpooledTempFile("test", os.TempDir(), 1000000, true, -1) defer spool.Close() input := []byte("force to disk") @@ -129,7 +129,7 @@ func TestForceOnDisk(t *testing.T) { // TestReadAtAndSeekInMemory tests seeking and ReadAt on an in-memory spool. func TestReadAtAndSeekInMemory(t *testing.T) { - spool := NewSpooledTempFile("test", "", 100, false) + spool := NewSpooledTempFile("test", "", 100, false, -1) defer spool.Close() data := []byte("HelloWorld123") @@ -173,7 +173,7 @@ func TestReadAtAndSeekInMemory(t *testing.T) { // TestReadAtAndSeekOnDisk tests seeking and ReadAt on a spool that has switched to disk. func TestReadAtAndSeekOnDisk(t *testing.T) { - spool := NewSpooledTempFile("test", "", 10, false) + spool := NewSpooledTempFile("test", "", 10, false, -1) defer spool.Close() data := []byte("HelloWorld123") @@ -204,7 +204,7 @@ func TestReadAtAndSeekOnDisk(t *testing.T) { // TestWriteAfterReadPanic ensures writing after reading panics per your design. 
func TestWriteAfterReadPanic(t *testing.T) { - spool := NewSpooledTempFile("test", "", 100, false) + spool := NewSpooledTempFile("test", "", 100, false, -1) defer spool.Close() _, err := spool.Write([]byte("ABCDEFG")) @@ -237,7 +237,7 @@ func TestWriteAfterReadPanic(t *testing.T) { // TestCloseInMemory checks closing while still in-memory. func TestCloseInMemory(t *testing.T) { - spool := NewSpooledTempFile("test", "", 100, false) + spool := NewSpooledTempFile("test", "", 100, false, -1) _, err := spool.Write([]byte("Small data")) if err != nil { @@ -267,7 +267,7 @@ func TestCloseInMemory(t *testing.T) { // TestCloseOnDisk checks closing after spool has switched to disk. func TestCloseOnDisk(t *testing.T) { - spool := NewSpooledTempFile("test", "", 10, false) + spool := NewSpooledTempFile("test", "", 10, false, -1) _, err := spool.Write([]byte("1234567890ABC")) if err != nil { @@ -305,7 +305,7 @@ func TestCloseOnDisk(t *testing.T) { // TestLen verifies Len() for both in-memory and on-disk states. func TestLen(t *testing.T) { - spool := NewSpooledTempFile("test", "", 5, false) + spool := NewSpooledTempFile("test", "", 5, false, -1) defer spool.Close() data := []byte("1234") @@ -329,7 +329,7 @@ func TestLen(t *testing.T) { // TestFileName checks correctness of FileName in both modes. func TestFileName(t *testing.T) { - spool := NewSpooledTempFile("testprefix", os.TempDir(), 5, false) + spool := NewSpooledTempFile("testprefix", os.TempDir(), 5, false, -1) defer spool.Close() if spool.FileName() != "" { @@ -353,3 +353,49 @@ func TestFileName(t *testing.T) { t.Errorf("Expected file name prefix 'testprefix', got %s", base) } } + +// TestSkipInMemoryAboveRAMUsage verifies that if `isSystemMemoryUsageHigh()` +// returns true, the spool goes directly to disk even for small writes. 
+func TestSkipInMemoryAboveRAMUsage(t *testing.T) {
+	// getSystemMemoryUsedFraction is a package-level variable; keep the real one for restoring.
+	oldGetSystemMemoryUsedFraction := getSystemMemoryUsedFraction
+	// Force system memory usage to appear above 50%
+	getSystemMemoryUsedFraction = func() (float64, error) {
+		return 0.60, nil // 60% used => above the 50% threshold
+	}
+	// Restore on exit. The swap is package-wide, so this test must not use t.Parallel().
+	defer func() {
+		getSystemMemoryUsedFraction = oldGetSystemMemoryUsedFraction
+	}()
+
+	// Even though the size threshold is large (1 MiB), our mocked usage of 60% reaches
+	// the 0.50 fraction passed below, so the spool should skip memory and go straight to disk.
+	spool := NewSpooledTempFile("testram", os.TempDir(), 1024*1024, false, 0.50)
+	defer spool.Close()
+
+	// Write a small amount of data (far below the size threshold)
+	data := []byte("This is a small test")
+	n, err := spool.Write(data)
+	if err != nil {
+		t.Fatalf("Write error: %v", err)
+	}
+	if n != len(data) {
+		t.Errorf("Write count mismatch: got %d, want %d", n, len(data))
+	}
+
+	// Because memory usage was deemed "too high" from the start,
+	// we should already be on disk: FileName() is only non-empty for a disk-backed spool.
+	fn := spool.FileName()
+	if fn == "" {
+		t.Fatalf("Expected spool to be on disk, but FileName() was empty")
+	}
+
+	// Verify data can be read back
+	out, err := io.ReadAll(spool)
+	if err != nil {
+		t.Fatalf("ReadAll error: %v", err)
+	}
+	if string(out) != string(data) {
+		t.Errorf("Data mismatch. Got %q, want %q", out, data)
+	}
+}
diff --git a/read.go b/read.go
index 6b8193d..2c85537 100644
--- a/read.go
+++ b/read.go
@@ -6,6 +6,8 @@ import (
 	"fmt"
 	"io"
 	"strconv"
+
+	"github.com/CorentinB/warc/pkg/spooledtempfile"
 )
 
 // Reader store the bufio.Reader and gzip.Reader for a WARC file
@@ -92,7 +94,7 @@ func (r *Reader) ReadRecord() (*Record, bool, error) {
 	}
 
 	// reading doesn't really need to be in TempDir, nor can we access it as it's on the client.
- buf := NewSpooledTempFile("warc", "", -1, false) + buf := spooledtempfile.NewSpooledTempFile("warc", "", -1, false, -1) _, err = io.CopyN(buf, tempReader, length) if err != nil { return nil, false, fmt.Errorf("copying record content: %w", err) diff --git a/spooled.go b/spooled.go deleted file mode 100644 index 8b8c2fc..0000000 --- a/spooled.go +++ /dev/null @@ -1,208 +0,0 @@ -package warc - -import ( - "bytes" - "fmt" - "io" - "os" - "sync" -) - -// MaxInMemorySize is the max number of bytes (currently 1MB) -// to hold in memory before starting to write to disk -const MaxInMemorySize = 1000000 - -var spooledPool = sync.Pool{ - New: func() interface{} { - return bytes.NewBuffer(nil) - }, -} - -// ReaderAt is the interface for ReadAt - read at position, without moving pointer. -type ReaderAt interface { - ReadAt(p []byte, off int64) (n int, err error) -} - -// ReadSeekCloser is an io.Reader + ReaderAt + io.Seeker + io.Closer + Stat -type ReadSeekCloser interface { - io.Reader - io.Seeker - ReaderAt - io.Closer - FileName() string - Len() int -} - -// spooledTempFile writes to memory (or to disk if -// over MaxInMemorySize) and deletes the file on Close -type spooledTempFile struct { - buf *bytes.Buffer - mem *bytes.Reader - file *os.File - filePrefix string - tempDir string - maxInMemorySize int - fullOnDisk bool - reading bool - closed bool - manager *SpoolManager -} - -// ReadWriteSeekCloser is an io.Writer + io.Reader + io.Seeker + io.Closer. -type ReadWriteSeekCloser interface { - ReadSeekCloser - io.Writer -} - -// NewSpooledTempFile returns an ReadWriteSeekCloser. -// If threshold is -1, then the default MaxInMemorySize is used. 
-func NewSpooledTempFile(filePrefix string, tempDir string, threshold int, fullOnDisk bool) ReadWriteSeekCloser { - if threshold < 0 { - threshold = MaxInMemorySize - } - - s := &spooledTempFile{ - filePrefix: filePrefix, - tempDir: tempDir, - buf: spooledPool.Get().(*bytes.Buffer), - maxInMemorySize: threshold, - fullOnDisk: fullOnDisk, - manager: DefaultSpoolManager, - } - - s.manager.RegisterSpool(s) - - return s -} - -func (s *spooledTempFile) prepareRead() error { - if s.closed { - return io.EOF - } - if s.reading && (s.file != nil || s.buf == nil || s.mem != nil) { - return nil - } - s.reading = true - if s.file != nil { - if _, err := s.file.Seek(0, 0); err != nil { - return fmt.Errorf("file=%v: %w", s.file, err) - } - return nil - } - s.mem = bytes.NewReader(s.buf.Bytes()) - return nil -} - -func (s *spooledTempFile) Len() int { - if s.file != nil { - fi, err := s.file.Stat() - if err != nil { - return -1 - } - return int(fi.Size()) - } - return s.buf.Len() -} - -func (s *spooledTempFile) Read(p []byte) (n int, err error) { - if err := s.prepareRead(); err != nil { - return 0, err - } - if s.file != nil { - return s.file.Read(p) - } - return s.mem.Read(p) -} - -func (s *spooledTempFile) ReadAt(p []byte, off int64) (n int, err error) { - if err := s.prepareRead(); err != nil { - return 0, err - } - if s.file != nil { - return s.file.ReadAt(p, off) - } - return s.mem.ReadAt(p, off) -} - -func (s *spooledTempFile) Seek(offset int64, whence int) (int64, error) { - if err := s.prepareRead(); err != nil { - return 0, err - } - if s.file != nil { - return s.file.Seek(offset, whence) - } - return s.mem.Seek(offset, whence) -} - -func (s *spooledTempFile) Write(p []byte) (n int, err error) { - if s.closed { - return 0, io.EOF - } - if s.reading { - panic("write after read") - } - if s.file != nil { - return s.file.Write(p) - } - proposedSize := s.buf.Len() + len(p) - if s.fullOnDisk || - proposedSize > s.maxInMemorySize { - if err := s.switchToFile(); err != nil { - 
return 0, err - } - return s.file.Write(p) - } - s.manager.AddBytes(len(p)) - return s.buf.Write(p) -} - -func (s *spooledTempFile) switchToFile() error { - f, err := os.CreateTemp(s.tempDir, s.filePrefix+"-") - if err != nil { - return err - } - if _, err = io.Copy(f, s.buf); err != nil { - f.Close() - return err - } - s.manager.SubBytes(s.buf.Len()) - s.buf.Reset() - spooledPool.Put(s.buf) - s.buf = nil - s.file = f - return nil -} - -func (s *spooledTempFile) forceToDiskIfInMemory() { - if s.file == nil && !s.closed { - _ = s.switchToFile() - } -} - -func (s *spooledTempFile) Close() error { - if s.closed { - return nil - } - s.closed = true - s.mem = nil - if s.buf != nil { - s.manager.SubBytes(s.buf.Len()) - s.buf.Reset() - spooledPool.Put(s.buf) - s.buf = nil - } - if s.file != nil { - s.file.Close() - os.Remove(s.file.Name()) - s.file = nil - } - s.manager.UnregisterSpool(s) - return nil -} - -func (s *spooledTempFile) FileName() string { - if s.file != nil { - return s.file.Name() - } - return "" -} diff --git a/spooledmanager.go b/spooledmanager.go deleted file mode 100644 index 470545a..0000000 --- a/spooledmanager.go +++ /dev/null @@ -1,131 +0,0 @@ -package warc - -import ( - "container/heap" - "sync" - "time" - - "golang.org/x/sys/unix" -) - -// SpoolManager enforces a global memory limit, tracks spoolers, and handles eviction. 
-type SpoolManager struct { - mu sync.Mutex - spoolers spoolHeap - spoolerIndex map[*spooledTempFile]*spoolItem - currentMemUsage int64 - GlobalMemoryLimit int64 -} - -type spoolItem struct { - s *spooledTempFile - priority time.Time // used to determine which spooler is oldest (min-heap) - index int // heap interface requirement -} - -type spoolHeap []*spoolItem - -func (h spoolHeap) Len() int { return len(h) } - -func (h spoolHeap) Less(i, j int) bool { - return h[i].priority.Before(h[j].priority) -} - -func (h spoolHeap) Swap(i, j int) { - h[i], h[j] = h[j], h[i] - h[i].index = i - h[j].index = j -} - -func (h *spoolHeap) Push(x interface{}) { - item := x.(*spoolItem) - item.index = len(*h) - *h = append(*h, item) -} - -func (h *spoolHeap) Pop() interface{} { - old := *h - n := len(old) - item := old[n-1] - *h = old[0 : n-1] - return item -} - -// DefaultSpoolManager is the global manager. Adjust limit as desired. -var DefaultSpoolManager = NewSpoolManager(getHalfOfAvailableRAM()) - -func NewSpoolManager(limit int64) *SpoolManager { - m := &SpoolManager{ - GlobalMemoryLimit: limit, - spoolerIndex: make(map[*spooledTempFile]*spoolItem), - } - heap.Init(&m.spoolers) - return m -} - -func (m *SpoolManager) RegisterSpool(s *spooledTempFile) { - m.mu.Lock() - defer m.mu.Unlock() - item := &spoolItem{ - s: s, - priority: time.Now(), - } - m.spoolerIndex[s] = item - heap.Push(&m.spoolers, item) -} - -func (m *SpoolManager) UnregisterSpool(s *spooledTempFile) { - m.mu.Lock() - defer m.mu.Unlock() - item, ok := m.spoolerIndex[s] - if !ok { - return - } - delete(m.spoolerIndex, s) - heap.Remove(&m.spoolers, item.index) -} - -func (m *SpoolManager) CanAddBytes(n int) bool { - m.mu.Lock() - defer m.mu.Unlock() - return m.currentMemUsage+int64(n) <= m.GlobalMemoryLimit -} - -func (m *SpoolManager) AddBytes(n int) { - m.mu.Lock() - m.currentMemUsage += int64(n) - m.mu.Unlock() -} - -func (m *SpoolManager) SubBytes(n int) { - m.mu.Lock() - m.currentMemUsage -= int64(n) - if 
m.currentMemUsage < 0 { - m.currentMemUsage = 0 - } - m.mu.Unlock() -} - -func (m *SpoolManager) EvictIfNeeded() { - m.mu.Lock() - defer m.mu.Unlock() - - for m.currentMemUsage > m.GlobalMemoryLimit && len(m.spoolers) > 0 { - item := m.spoolers[0] - if item.s.file == nil && !item.s.closed { - item.s.forceToDiskIfInMemory() - } else { - // If it's already on disk or closed, pop it to avoid looping - heap.Remove(&m.spoolers, item.index) - delete(m.spoolerIndex, item.s) - } - } -} - -func getHalfOfAvailableRAM() int64 { - var info unix.Sysinfo_t - if err := unix.Sysinfo(&info); err != nil { - panic(err) - } - return int64(info.Totalram) / 2 -} diff --git a/utils.go b/utils.go index 10924e2..433eeca 100644 --- a/utils.go +++ b/utils.go @@ -17,6 +17,7 @@ import ( "sync/atomic" "time" + "github.com/CorentinB/warc/pkg/spooledtempfile" gzip "github.com/klauspost/compress/gzip" "github.com/klauspost/compress/zstd" @@ -193,7 +194,7 @@ func NewWriter(writer io.Writer, fileName string, compression string, contentLen func NewRecord(tempDir string, fullOnDisk bool) *Record { return &Record{ Header: NewHeader(), - Content: NewSpooledTempFile("warc", tempDir, -1, fullOnDisk), + Content: spooledtempfile.NewSpooledTempFile("warc", tempDir, -1, fullOnDisk, -1), } } @@ -337,7 +338,7 @@ func GenerateWarcFileName(prefix string, compression string, atomicSerial *int64 return prefix + "-" + date + "-" + formattedSerial + "-" + hostName + ".warc.open" } -func getContentLength(rwsc ReadWriteSeekCloser) int { +func getContentLength(rwsc spooledtempfile.ReadWriteSeekCloser) int { // If the FileName leads to no existing file, it means that the SpooledTempFile // never had the chance to buffer to disk instead of memory, in which case we can // just read the buffer (which should be <= 2MB) and return the length diff --git a/write.go b/write.go index 95862d5..86c86d6 100644 --- a/write.go +++ b/write.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/CorentinB/warc/pkg/spooledtempfile" 
"github.com/klauspost/compress/gzip" "github.com/google/uuid" @@ -36,7 +37,7 @@ type RecordBatch struct { // Record represents a WARC record. type Record struct { Header Header - Content ReadWriteSeekCloser + Content spooledtempfile.ReadWriteSeekCloser Version string // WARC/1.0, WARC/1.1 ... }