Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions cv_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ cv::Mat imreadRGB(const std::string &filename){
return cImg;
}

cv::Mat imreadMask(const std::string &filename){
cv::Mat mask = cv::imread(filename, cv::IMREAD_GRAYSCALE);
if (mask.empty()){
std::cerr << "Cannot read mask " << filename << std::endl;
exit(1);
}
return mask;
}

void imwriteRGB(const std::string &filename, const cv::Mat &image){
cv::Mat rgb;
cv::cvtColor(image, rgb, cv::COLOR_RGB2BGR);
Expand Down Expand Up @@ -48,3 +57,21 @@ torch::Tensor imageToTensor(const cv::Mat &image){
return (img.toType(torch::kFloat32) / 255.0f);
}

torch::Tensor maskToTensor(const cv::Mat &mask){
torch::Tensor m = torch::from_blob(mask.data, { mask.rows, mask.cols, 1 }, torch::kU8);
// Binary mask: threshold at 128, output 0.0 or 1.0
return (m.toType(torch::kFloat32) / 255.0f).ge(0.5f).toType(torch::kFloat32);
}

cv::Mat tensorToMask(const torch::Tensor &t){
int h = t.sizes()[0];
int w = t.sizes()[1];

cv::Mat mask(h, w, CV_8UC1);
torch::Tensor scaledTensor = (t.squeeze() * 255.0).toType(torch::kU8);
uint8_t* dataPtr = static_cast<uint8_t*>(scaledTensor.data_ptr());
std::copy(dataPtr, dataPtr + (w * h), mask.data);

return mask;
}

3 changes: 3 additions & 0 deletions cv_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
#include <opencv2/imgproc.hpp>

cv::Mat imreadRGB(const std::string &filename);
cv::Mat imreadMask(const std::string &filename);
void imwriteRGB(const std::string &filename, const cv::Mat &image);
cv::Mat floatNxNtensorToMat(const torch::Tensor &t);
torch::Tensor floatNxNMatToTensor(const cv::Mat &m);
cv::Mat tensorToImage(const torch::Tensor &t);
torch::Tensor imageToTensor(const cv::Mat &image);
torch::Tensor maskToTensor(const cv::Mat &mask);
cv::Mat tensorToMask(const torch::Tensor &t);

#endif
33 changes: 33 additions & 0 deletions input_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,18 @@ void Camera::loadImage(float downscaleFactor){
fy = K[1][1].item<float>();
cx = K[0][2].item<float>();
cy = K[1][2].item<float>();

// Load mask if path is set
if (!maskPath.empty()){
cv::Mat cMask = imreadMask(maskPath);

// Resize mask to match image dimensions
if (cMask.rows != height || cMask.cols != width){
cv::resize(cMask, cMask, cv::Size(width, height), 0.0, 0.0, cv::INTER_NEAREST);
}

mask = maskToTensor(cMask);
}
}

torch::Tensor Camera::getImage(int downscaleFactor){
Expand All @@ -116,6 +128,27 @@ torch::Tensor Camera::getImage(int downscaleFactor){
}
}

torch::Tensor Camera::getMask(int downscaleFactor){
if (!mask.numel()) return torch::Tensor(); // No mask available

if (downscaleFactor <= 1) return mask;

if (maskPyramids.find(downscaleFactor) != maskPyramids.end()){
return maskPyramids[downscaleFactor];
}

// Rescale using nearest neighbor (preserve binary values)
cv::Mat cMask = tensorToMask(mask);
cv::resize(cMask, cMask, cv::Size(cMask.cols / downscaleFactor, cMask.rows / downscaleFactor), 0.0, 0.0, cv::INTER_NEAREST);
torch::Tensor t = maskToTensor(cMask);
maskPyramids[downscaleFactor] = t;
return t;
}

bool Camera::hasMask() const {
return mask.numel() > 0;
}

bool Camera::hasDistortionParameters(){
return k1 != 0.0f || k2 != 0.0f || k3 != 0.0f || p1 != 0.0f || p2 != 0.0f;
}
Expand Down
5 changes: 5 additions & 0 deletions input_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ struct Camera{
float p2 = 0;
torch::Tensor camToWorld;
std::string filePath = "";
std::string maskPath = ""; // Optional path to mask file
CameraType cameraType = CameraType::Perspective;

Camera(){};
Expand All @@ -37,12 +38,16 @@ struct Camera{
bool hasDistortionParameters();
std::vector<float> undistortionParameters();
torch::Tensor getImage(int downscaleFactor);
torch::Tensor getMask(int downscaleFactor);
bool hasMask() const;

void loadImage(float downscaleFactor);
torch::Tensor K;
torch::Tensor image;
torch::Tensor mask; // Optional mask tensor [H,W,1], 0=exclude, 1=include

std::unordered_map<int, torch::Tensor> imagePyramids;
std::unordered_map<int, torch::Tensor> maskPyramids; // Cached downscaled masks
};

struct Points{
Expand Down
18 changes: 13 additions & 5 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,16 @@ torch::Tensor psnr(const torch::Tensor& rendered, const torch::Tensor& gt){
return (10.f * torch::log10(1.0 / mse));
}

torch::Tensor l1(const torch::Tensor& rendered, const torch::Tensor& gt){
return torch::abs(gt - rendered).mean();
torch::Tensor l1(const torch::Tensor& rendered, const torch::Tensor& gt, const torch::Tensor& mask){
torch::Tensor diff = torch::abs(gt - rendered);
if (mask.numel() > 0){
// Expand mask from [H,W,1] to [H,W,3] for broadcasting
torch::Tensor expandedMask = mask.expand_as(diff);
// Masked mean: sum of masked values / count of masked pixels
torch::Tensor maskedDiff = diff * expandedMask;
return maskedDiff.sum() / (expandedMask.sum() + 1e-8f);
}
return diff.mean();
}

void Model::setupOptimizers(){
Expand Down Expand Up @@ -777,8 +785,8 @@ int Model::loadPly(const std::string &filename){
throw std::runtime_error("Invalid PLY file");
}

torch::Tensor Model::mainLoss(torch::Tensor &rgb, torch::Tensor &gt, float ssimWeight){
torch::Tensor ssimLoss = 1.0f - ssim.eval(rgb, gt);
torch::Tensor l1Loss = l1(rgb, gt);
torch::Tensor Model::mainLoss(torch::Tensor &rgb, torch::Tensor &gt, float ssimWeight, const torch::Tensor &mask){
torch::Tensor ssimLoss = 1.0f - ssim.eval(rgb, gt, mask);
torch::Tensor l1Loss = l1(rgb, gt, mask);
return (1.0f - ssimWeight) * l1Loss + ssimWeight * ssimLoss;
}
4 changes: 2 additions & 2 deletions model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using namespace torch::autograd;
torch::Tensor randomQuatTensor(long long n);
torch::Tensor projectionMatrix(float zNear, float zFar, float fovX, float fovY, const torch::Device &device);
torch::Tensor psnr(const torch::Tensor& rendered, const torch::Tensor& gt);
torch::Tensor l1(const torch::Tensor& rendered, const torch::Tensor& gt);
torch::Tensor l1(const torch::Tensor& rendered, const torch::Tensor& gt, const torch::Tensor& mask = torch::Tensor());

struct Model{
Model(const InputData &inputData, int numCameras,
Expand Down Expand Up @@ -74,7 +74,7 @@ struct Model{
void saveSplat(const std::string &filename);
void saveDebugPly(const std::string &filename, int step);
int loadPly(const std::string &filename);
torch::Tensor mainLoss(torch::Tensor &rgb, torch::Tensor &gt, float ssimWeight);
torch::Tensor mainLoss(torch::Tensor &rgb, torch::Tensor &gt, float ssimWeight, const torch::Tensor &mask = torch::Tensor());

void addToOptimizer(torch::optim::Adam *optimizer, const torch::Tensor &newParam, const torch::Tensor &idcs, int nSamples);
void removeFromOptimizer(torch::optim::Adam *optimizer, const torch::Tensor &newParam, const torch::Tensor &deletedMask);
Expand Down
43 changes: 41 additions & 2 deletions opensplat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ int main(int argc, char *argv[]){
("stop-screen-size-at", "Stop splitting gaussians that are larger than [split-screen-size] after these many steps", cxxopts::value<int>()->default_value("4000"))
("split-screen-size", "Split gaussians that are larger than this percentage of screen space", cxxopts::value<float>()->default_value("0.05"))
("colmap-image-path", "Override the default image path for COLMAP-based input", cxxopts::value<std::string>()->default_value(""))
("mask-dir", "Path to directory containing mask images (binary: 0=exclude, 1=include)", cxxopts::value<std::string>()->default_value(""))
#ifdef USE_VISUALIZATION
("has-visualization", "Show the visualization steps of training", cxxopts::value<bool>()->default_value("0"))
#endif
Expand Down Expand Up @@ -95,6 +96,7 @@ int main(int argc, char *argv[]){
const int stopScreenSizeAt = result["stop-screen-size-at"].as<int>();
const float splitScreenSize = result["split-screen-size"].as<float>();
const std::string colmapImageSourcePath = result["colmap-image-path"].as<std::string>();
const std::string maskDir = result["mask-dir"].as<std::string>();
#ifdef USE_VISUALIZATION
const bool hasVisualization = result["has-visualization"].as<bool>();
#endif
Expand All @@ -121,6 +123,33 @@ int main(int argc, char *argv[]){
try{
InputData inputData = inputDataFromX(projectRoot, colmapImageSourcePath);

// Set mask paths if mask directory is provided
if (!maskDir.empty()){
fs::path maskDirPath(maskDir);
if (!fs::exists(maskDirPath)){
std::cerr << "Mask directory does not exist: " << maskDir << std::endl;
exit(1);
}

for (Camera &cam : inputData.cameras){
fs::path imagePath(cam.filePath);
std::string imageName = imagePath.stem().string();

// Try common mask extensions
for (const std::string &ext : {".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG"}){
fs::path maskPath = maskDirPath / (imageName + ext);
if (fs::exists(maskPath)){
cam.maskPath = maskPath.string();
break;
}
}

if (cam.maskPath.empty()){
std::cerr << "Warning: No mask found for " << cam.filePath << std::endl;
}
}
}

parallel_for(inputData.cameras.begin(), inputData.cameras.end(), [&downScaleFactor](Camera &cam){
cam.loadImage(downScaleFactor);
});
Expand Down Expand Up @@ -157,7 +186,13 @@ int main(int argc, char *argv[]){
torch::Tensor gt = cam.getImage(model.getDownscaleFactor(step));
gt = gt.to(device);

torch::Tensor mainLoss = model.mainLoss(rgb, gt, ssimWeight);
torch::Tensor mask;
if (cam.hasMask()){
mask = cam.getMask(model.getDownscaleFactor(step));
mask = mask.to(device);
}

torch::Tensor mainLoss = model.mainLoss(rgb, gt, ssimWeight, mask);
mainLoss.backward();

if (step % displayStep == 0) {
Expand Down Expand Up @@ -203,7 +238,11 @@ int main(int argc, char *argv[]){
if (valCam != nullptr){
torch::Tensor rgb = model.forward(*valCam, numIters);
torch::Tensor gt = valCam->getImage(model.getDownscaleFactor(numIters)).to(device);
std::cout << valCam->filePath << " validation loss: " << model.mainLoss(rgb, gt, ssimWeight).item<float>() << std::endl;
torch::Tensor valMask;
if (valCam->hasMask()){
valMask = valCam->getMask(model.getDownscaleFactor(numIters)).to(device);
}
std::cout << valCam->filePath << " validation loss: " << model.mainLoss(rgb, gt, ssimWeight, valMask).item<float>() << std::endl;
}
}catch(const std::exception &e){
std::cerr << e.what() << std::endl;
Expand Down
15 changes: 12 additions & 3 deletions ssim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

using namespace torch::indexing;

torch::Tensor SSIM::eval(const torch::Tensor& rendered, const torch::Tensor& gt) {
torch::Tensor SSIM::eval(const torch::Tensor& rendered, const torch::Tensor& gt, const torch::Tensor& mask) {
torch::Tensor img1 = gt.permute({2, 0, 1}).index({None, "..."});
torch::Tensor img2 = rendered.permute({2, 0, 1}).index({None, "..."});

if (img1.device() != window.device()){
window = window.to(img1.device());
}
Expand All @@ -22,12 +22,21 @@ torch::Tensor SSIM::eval(const torch::Tensor& rendered, const torch::Tensor& gt)
torch::Tensor sigma1Sq = torch::nn::functional::conv2d(img1 * img1, window, torch::nn::functional::Conv2dFuncOptions().padding(windowSize / 2).groups(channel)) - mu1Sq;
torch::Tensor sigma2Sq = torch::nn::functional::conv2d(img2 * img2, window, torch::nn::functional::Conv2dFuncOptions().padding(windowSize / 2).groups(channel)) - mu2Sq;
torch::Tensor sigma12 = torch::nn::functional::conv2d(img1 * img2, window, torch::nn::functional::Conv2dFuncOptions().padding(windowSize / 2).groups(channel)) - mu1mu2;

const float C1 = 0.01 * 0.01;
const float C2 = 0.03 * 0.03;

torch::Tensor ssimMap = ((2.0f * mu1mu2 + C1) * (2.0f * sigma12 + C2)) / ((mu1Sq + mu2Sq + C1) * (sigma1Sq + sigma2Sq + C2));

if (mask.numel() > 0){
// ssimMap is [1, C, H, W], mask is [H, W, 1]
// Permute mask to [1, 1, H, W] and expand to match ssimMap channels
torch::Tensor ssimMask = mask.permute({2, 0, 1}).index({None, "..."}); // [1, 1, H, W]
ssimMask = ssimMask.expand_as(ssimMap); // [1, C, H, W]
torch::Tensor maskedSsim = ssimMap * ssimMask;
return maskedSsim.sum() / (ssimMask.sum() + 1e-8f);
}

return ssimMap.mean();
}

Expand Down
2 changes: 1 addition & 1 deletion ssim.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class SSIM{
window = createWindow();
};

torch::Tensor eval(const torch::Tensor& rendered, const torch::Tensor& gt);
torch::Tensor eval(const torch::Tensor& rendered, const torch::Tensor& gt, const torch::Tensor& mask = torch::Tensor());
private:
torch::Tensor createWindow();
torch::Tensor gaussian(float sigma);
Expand Down
Loading