Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions internal/chunk/backend/dbase/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"fmt"
"iter"
"log/slog"
"os"
"path/filepath"
"runtime/trace"
"time"

Expand All @@ -38,15 +40,51 @@ import (

const preallocSz = 100 // preallocate slice size

// DefaultDBFile is the default database filename used when a directory
// is passed instead of a file path.
const DefaultDBFile = "slackdump.sqlite"

type Source struct {
conn *sqlx.DB
// canClose set to false when the connection is passed to the source
// and should not be closed by the source.
canClose bool
}

// ErrIsDirectory is returned when a directory path is passed instead of
// a database file.
var ErrIsDirectory = fmt.Errorf("path is a directory")

// validateDBPath checks if path is a directory and returns a helpful error
// with a suggestion if the expected database file exists inside.
func validateDBPath(path string) error {
// Check if path is a symlink and warn about it.
li, err := os.Lstat(path)
if err != nil && !os.IsNotExist(err) {
slog.Warn("failed to stat path, continuing", "path", path, "error", err)
} else if err == nil && li.Mode()&os.ModeSymlink != 0 {
slog.Warn("database path is a symlink, following it to the target", "path", path)
}
fi, err := os.Stat(path)
if err != nil {
// Non-existent paths are allowed (for creating new databases).
return nil
}
if fi.IsDir() {
dbFile := filepath.Join(path, DefaultDBFile)
if _, err := os.Stat(dbFile); err == nil {
return fmt.Errorf("%w: %s (did you mean %q?)", ErrIsDirectory, path, dbFile)
}
return fmt.Errorf("%w: %s (no %s found inside)", ErrIsDirectory, path, DefaultDBFile)
}
return nil
}

// Open attempts to open the database at given path for reading.
func Open(ctx context.Context, path string) (*Source, error) {
if err := validateDBPath(path); err != nil {
return nil, err
}
// migrate to the latest
if err := migrate(ctx, path); err != nil {
return nil, err
Expand All @@ -64,6 +102,9 @@ func Open(ctx context.Context, path string) (*Source, error) {
// OpenRW attempts to open the database at given path for reading and writing.
// Use [Open] when only read access is needed.
func OpenRW(ctx context.Context, path string) (*RWSource, error) {
if err := validateDBPath(path); err != nil {
return nil, err
}
if err := migrate(ctx, path); err != nil {
return nil, err
}
Expand Down
123 changes: 121 additions & 2 deletions internal/chunk/backend/dbase/source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"context"
"database/sql"
"errors"
"os"
"path/filepath"
"reflect"
"sort"
Expand All @@ -28,6 +29,7 @@ import (
"github.com/jmoiron/sqlx"
"github.com/rusq/slack"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"

"github.com/rusq/slackdump/v4/internal/fixtures"
Expand All @@ -49,6 +51,7 @@ func TestOpen(t *testing.T) {
name string
args args
checkFn utilityFunc
fn any
wantErr bool
}{
{
Expand All @@ -60,22 +63,138 @@ func TestOpen(t *testing.T) {
checkFn: checkGooseTable,
wantErr: false,
},
{
name: "rejects directory",
args: args{
ctx: context.Background(),
path: t.TempDir(),
},
wantErr: true,
},
{
name: "rejects directory with OpenRW",
args: args{
ctx: context.Background(),
path: t.TempDir(),
},
wantErr: true,
fn: OpenRW,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Open(tt.args.ctx, tt.args.path)
var got *Source
var err error
if tt.fn != nil {
switch fn := tt.fn.(type) {
case func(context.Context, string) (*Source, error):
got, err = fn(tt.args.ctx, tt.args.path)
case func(context.Context, string) (*RWSource, error):
rw, err2 := fn(tt.args.ctx, tt.args.path)
if err2 == nil {
rw.Close()
}
err = err2
default:
t.Fatalf("unsupported fn type %T", tt.fn)
}
} else {
got, err = Open(tt.args.ctx, tt.args.path)
}
if (err != nil) != tt.wantErr {
t.Errorf("Open() error = %v, wantErr %v", err, tt.wantErr)
return
}
defer got.Close()
if got != nil {
defer got.Close()
}
if tt.checkFn != nil {
tt.checkFn(t, testutil.TestDBDSN(t, tt.args.path))
}
})
}
}

func Test_validateDBPath(t *testing.T) {
dir := t.TempDir()

regularFile := filepath.Join(dir, "regular.db")
require.NoError(t, os.WriteFile(regularFile, nil, 0644))

dbDir := t.TempDir()
dbFile := filepath.Join(dbDir, "slackdump.sqlite")
require.NoError(t, os.WriteFile(dbFile, nil, 0644))

emptyDir := t.TempDir()

// Create a symlink to a regular file
symlinkFile := filepath.Join(dir, "symlink.db")
require.NoError(t, os.Symlink(regularFile, symlinkFile))

// Create a symlink to a directory
symlinkDir := filepath.Join(dir, "symlink_dir")
require.NoError(t, os.Symlink(dbDir, symlinkDir))

tests := []struct {
name string
path string
wantErr bool
wantErrIs error
errContain string
}{
{
name: "non-existent path is allowed",
path: filepath.Join(dir, "nonexistent.db"),
wantErr: false,
},
{
name: "directory without slackdump.sqlite",
path: emptyDir,
wantErr: true,
wantErrIs: ErrIsDirectory,
errContain: "no slackdump.sqlite found inside",
},
{
name: "directory with slackdump.sqlite suggests correct path",
path: dbDir,
wantErr: true,
wantErrIs: ErrIsDirectory,
errContain: "did you mean",
},
{
name: "regular file is allowed",
path: regularFile,
wantErr: false,
},
{
name: "symlink to file is allowed (logs warning)",
path: symlinkFile,
wantErr: false,
},
{
name: "symlink to directory follows target and rejects",
path: symlinkDir,
wantErr: true,
wantErrIs: ErrIsDirectory,
errContain: "did you mean",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateDBPath(tt.path)
if tt.wantErr {
require.Error(t, err)
require.ErrorIs(t, err, tt.wantErrIs)
if tt.errContain != "" {
require.Contains(t, err.Error(), tt.errContain)
}
} else {
require.NoError(t, err)
}
})
}
}

func TestSource_Close(t *testing.T) {
type fields struct {
conn *sqlx.DB
Expand Down