Skip to content
Merged

Next #93

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
4da98e5
add sync batchnorm
harisreedhar Jul 5, 2025
f6b69df
replace random.choice with hash
harisreedhar Jul 5, 2025
252a2db
fifty percent reduction
harisreedhar Jul 5, 2025
8ddc095
fix discriminator input
harisreedhar Jul 5, 2025
26427c7
restore dataset.py
harisreedhar Jul 5, 2025
bdefe92
Remove duplicates
henryruhs Jul 5, 2025
d3334ee
add discriminator_ratio to config
harisreedhar Jul 7, 2025
ef3560d
fix onnx export bug: replace round() with int()
harisreedhar Jul 8, 2025
fbb50c3
Fix embedding naming
henryruhs Jul 8, 2025
33c5e0b
Introduce ModelWithConfigCheckpoint callback (#86)
henryruhs Jul 25, 2025
f26961e
Fix dist ini
henryruhs Jul 27, 2025
04a509c
Style: Refactor typing and improve code clarity in training.py (#88)
NeuroDonu Jul 29, 2025
f228f5a
Add type casting for trainer params
henryruhs Jul 29, 2025
a1fb0e3
Add type casting for trainer params
henryruhs Jul 29, 2025
7e36ea6
Add type casting for trainer params
henryruhs Jul 29, 2025
6eb0d16
Remove inplace activations for torch.compile compatibility (#89)
NeuroDonu Jul 31, 2025
1659ebf
Fix README
henryruhs Aug 7, 2025
d9212ea
improvise with norm layers & weighted average
harisreedhar Aug 19, 2025
f51e1ca
add skip layer
harisreedhar Aug 21, 2025
9ca46bc
use gelu instead of leaky_relu
harisreedhar Aug 21, 2025
6d08153
cleanup
harisreedhar Aug 21, 2025
833683b
cleanup
harisreedhar Aug 21, 2025
db59423
Merge pull request #92 from facefusion/improvements/crossface_norm
harisreedhar Aug 21, 2025
8edbb02
Update dependencies
henryruhs Sep 5, 2025
e74ab0e
Different defaults and enable validation
henryruhs Sep 6, 2025
13cdcd1
Different defaults and enable validation
henryruhs Sep 6, 2025
d78ad69
Revert to higher batch size
henryruhs Sep 6, 2025
3f0f00e
Just use copy over copy2
henryruhs Sep 6, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crossface/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ file_pattern = .datasets/megaface/**/*.jpg

```
[training.loader]
batch_size = 256
batch_size = 128
num_workers = 8
split_ratio = 0.95
```
Expand Down Expand Up @@ -90,7 +90,7 @@ python train.py
Launch TensorBoard to monitor the training.

```
tensorboard --logdir=.logs
tensorboard --logdir .logs
```


Expand Down
2 changes: 1 addition & 1 deletion crossface/src/exporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def export() -> None:
config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version')

os.makedirs(config_directory_path, exist_ok = True)
model = CrossFaceTrainer.load_from_checkpoint(config_source_path, map_location ='cpu').eval()
model = CrossFaceTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval()
model.ir_version = torch.tensor(config_ir_version)
input_tensor = torch.randn(1, 512)
torch.onnx.export(model, input_tensor, config_target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = config_opset_version)
33 changes: 21 additions & 12 deletions crossface/src/models/crossface.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,37 @@
import torch
from torch import Tensor, nn


class CrossFace(nn.Module):
def __init__(self) -> None:
super().__init__()
self.layers = self.create_layers()
self.leaky_relu = nn.LeakyReLU()
self.sequence = self.create_sequence()
self.linear = nn.Linear(512, 512)
self.apply(init_weight)

@staticmethod
def create_layers() -> nn.ModuleList:
return nn.ModuleList(
[
def create_sequence() -> nn.Sequential:
return nn.Sequential(
nn.Linear(512, 1024),
nn.LayerNorm(1024),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(1024, 2048),
nn.LayerNorm(2048),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(2048, 1024),
nn.LayerNorm(1024),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(1024, 512)
])
)

def forward(self, input_tensor : Tensor) -> Tensor:
output_tensor = input_tensor / torch.norm(input_tensor)
temp_tensor = nn.functional.normalize(input_tensor, p = 2, dim = -1)
return self.sequence(temp_tensor) + 0.2 * self.linear(temp_tensor)

for layer in self.layers[:-1]:
output_tensor = self.leaky_relu(layer(output_tensor))

output_tensor = self.layers[-1](output_tensor)
return output_tensor
def init_weight(module : nn.Module) -> None:
    """
    Initialise the parameters of a linear layer in place.

    Written to be handed to ``nn.Module.apply``, which calls it once per
    submodule. Anything that is not an ``nn.Linear`` is left untouched.

    :param module: the submodule to initialise
    """
    if not isinstance(module, nn.Linear):
        return

    nn.init.xavier_normal_(module.weight)

    # a layer created with bias = False exposes its bias as None, which constant_ cannot fill
    if module.bias is not None:
        nn.init.constant_(module.bias, 0.01)
29 changes: 20 additions & 9 deletions crossface/src/training.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import os
import shutil
from configparser import ConfigParser
from typing import Tuple
from pathlib import Path
from typing import Tuple, cast

import torch
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import ModelCheckpoint, StochasticWeightAveraging
from lightning.pytorch.loggers import TensorBoardLogger
from torch import Tensor, nn
from torch.utils.data import Dataset, random_split
from torchdata.stateful_dataloader import StatefulDataLoader

from .dataset import StaticDataset
from .models.crossface import CrossFace
from .types import Batch, Embedding, OptimizerSet
from .types import Batch, Embedding, OptimizerSet, TrainerPrecision, TrainerStrategy

CONFIG_PARSER = ConfigParser()
CONFIG_PARSER.read('config.ini')
Expand Down Expand Up @@ -67,6 +69,13 @@ def configure_optimizers(self) -> OptimizerSet:
return optimizer_set


class ModelWithConfigCheckpoint(ModelCheckpoint):
    """
    Checkpoint callback that stores a copy of config.ini next to each checkpoint it writes.
    """

    def _save_checkpoint(self, trainer : Trainer, checkpoint_path : str) -> None:
        """
        Save the checkpoint, then copy config.ini to the checkpoint path with an .ini suffix.

        :param trainer: the trainer performing the save
        :param checkpoint_path: the path of the checkpoint file being written
        """
        super()._save_checkpoint(trainer, checkpoint_path)

        # Lightning runs checkpoint callbacks on every rank while the checkpoint file itself
        # is written by the global zero rank only; copying from that rank alone avoids
        # concurrent writes and a missing target directory on the other ranks
        if trainer.is_global_zero:
            config_path = Path(checkpoint_path).with_suffix('.ini')
            shutil.copy('config.ini', config_path)


def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size')
config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers')
Expand All @@ -89,8 +98,8 @@ def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[T

def create_trainer() -> Trainer:
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy')
config_precision = CONFIG_PARSER.get('training.trainer', 'precision')
config_strategy = cast(TrainerStrategy, CONFIG_PARSER.get('training.trainer', 'strategy'))
config_precision = cast(TrainerPrecision, CONFIG_PARSER.get('training.trainer', 'precision'))
config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path')
config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name')
config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
Expand All @@ -105,15 +114,17 @@ def create_trainer() -> Trainer:
precision = config_precision,
callbacks =
[
ModelCheckpoint(
ModelWithConfigCheckpoint(
monitor = 'training_loss',
dirpath = config_directory_path,
filename = config_file_pattern,
every_n_epochs = 1,
every_n_epochs = 1000,
save_top_k = 5,
save_last = True
)
]
),
StochasticWeightAveraging(swa_lrs = 1e-2)
],
val_check_interval = 1000
)


Expand Down
5 changes: 4 additions & 1 deletion crossface/src/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Any, TypeAlias
from typing import Any, Literal, TypeAlias

from torch import Tensor

Batch : TypeAlias = Tensor
Embedding : TypeAlias = Tensor

OptimizerSet : TypeAlias = Any

TrainerStrategy = Literal['auto', 'ddp', 'ddp_spawn', 'ddp_find_unused_parameters_true']
TrainerPrecision = Literal['64-true', '32-true', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', 'transformer-engine', 'transformer-engine-float16']
5 changes: 3 additions & 2 deletions hyperswap/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ face_masker_path = .models/face_masker.pt
```
[training.model.generator]
source_channels = 512
output_channels = 4096
output_size = 256
num_blocks = 2
```
Expand Down Expand Up @@ -89,10 +88,12 @@ mask_weight = 5.0
```
[training.trainer]
accumulate_size = 4
discriminator_ratio = 0.4
gradient_clip = 20.0
max_epochs = 50
strategy = auto
precision = 16-mixed
sync_batchnorm = false
preview_frequency = 100
```

Expand Down Expand Up @@ -164,7 +165,7 @@ python train.py
Launch TensorBoard to monitor the training.

```
tensorboard --logdir=.logs
tensorboard --logdir .logs
```


Expand Down
5 changes: 3 additions & 2 deletions hyperswap/config.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[training.dataset]
file_pattern =
convert_template =
multiplier
multiplier =
transform_size =
usage_mode =
batch_mode =
Expand All @@ -20,7 +20,6 @@ face_masker_path =

[training.model.generator]
source_channels =
output_channels =
output_size =
num_blocks =

Expand All @@ -47,10 +46,12 @@ mask_weight =

[training.trainer]
accumulate_size =
discriminator_ratio =
gradient_clip =
max_epochs =
strategy =
precision =
sync_batchnorm =
preview_frequency =

[training.modifier]
Expand Down
68 changes: 24 additions & 44 deletions hyperswap/src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,73 +43,44 @@ def __getitem__(self, index : int) -> Batch:
def __len__(self) -> int:
return len(resolve_static_file_pattern(self.config_file_pattern))

def compose_transforms(self) -> transforms:
return transforms.Compose(
[
AugmentTransform(),
transforms.ToPILImage(),
transforms.Resize((self.config_transform_size, self.config_transform_size), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

def prepare_equal_batch(self, source_path : str) -> Batch:
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
return source_tensor, source_tensor
return self.create_batch(source_path, source_path, self.config_convert_template, self.config_convert_template)

def prepare_same_batch(self, source_path : str) -> Batch:
target_directory_path = os.path.dirname(source_path)
target_file_name_and_extension = random.choice(os.listdir(target_directory_path))
target_path = os.path.join(target_directory_path, target_file_name_and_extension)
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template)
return source_tensor, target_tensor
return self.create_batch(source_path, target_path, self.config_convert_template, self.config_convert_template)

def prepare_source_batch(self, source_path : str) -> Batch:
config_parser = self.filter_config_by_usage_mode('both')
config_section = random.choice(config_parser.sections())
config_file_pattern = config_parser.get(config_section, 'file_pattern')
config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template'))

target_path = random.choice(resolve_static_file_pattern(config_file_pattern))
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
target_tensor = self.conditional_convert_tensor(target_tensor, config_convert_template)
return source_tensor, target_tensor
return self.create_batch(source_path, target_path, self.config_convert_template, config_convert_template)

def prepare_target_batch(self, target_path : str) -> Batch:
config_parser = self.filter_config_by_usage_mode('both')
config_section = random.choice(config_parser.sections())
config_file_pattern = config_parser.get(config_section, 'file_pattern')
config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template'))

source_path = random.choice(resolve_static_file_pattern(config_file_pattern))
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, config_convert_template)
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template)
return source_tensor, target_tensor
return self.create_batch(source_path, target_path, config_convert_template, self.config_convert_template)

def prepare_different_batch(self, source_path : str) -> Batch:
target_path = random.choice(resolve_static_file_pattern(self.config_file_pattern))
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template)
return source_tensor, target_tensor
return self.create_batch(source_path, target_path, self.config_convert_template, self.config_convert_template)

def compose_transforms(self) -> transforms.Compose:
    """
    Build the preprocessing pipeline applied to every loaded image.

    The steps run in order: augmentation, conversion to a PIL image, bicubic
    resize to a square of config_transform_size, conversion back to a tensor
    and normalisation with a mean and deviation of 0.5 per channel.

    :return: the composed transform
    """
    transform_size = (self.config_transform_size, self.config_transform_size)
    pipeline =\
    [
        AugmentTransform(),
        transforms.ToPILImage(),
        transforms.Resize(transform_size, interpolation = transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    return transforms.Compose(pipeline)

def filter_config_by_usage_mode(self, usage_mode : UsageMode) -> ConfigParser:
config_parser = ConfigParser()
Expand All @@ -126,6 +97,15 @@ def filter_config_by_usage_mode(self, usage_mode : UsageMode) -> ConfigParser:

return config_parser

def create_batch(self, source_path : str, target_path : str, source_convert_template : ConvertTemplate, target_convert_template : ConvertTemplate) -> Batch:
    """
    Read the source and the target image and return them as a pair of prepared tensors.

    Each image is read from disk, passed through self.transforms and then through
    conditional_convert_tensor with its own convert template.

    NOTE(review): when source_path equals target_path the file is still read and
    transformed twice; if self.transforms is randomised the two tensors may differ.
    Confirm this is intended for the equal batch mode.

    :param source_path: path of the source image
    :param target_path: path of the target image
    :param source_convert_template: convert template applied to the source tensor
    :param target_convert_template: convert template applied to the target tensor
    :return: the source tensor followed by the target tensor
    """
    def load_tensor(file_path : str, convert_template : ConvertTemplate) -> Tensor:
        image_tensor = self.transforms(io.read_image(file_path))
        return self.conditional_convert_tensor(image_tensor, convert_template)

    # tuple elements are evaluated left to right, so the source is processed before the target
    return load_tensor(source_path, source_convert_template), load_tensor(target_path, target_convert_template)

@staticmethod
def conditional_convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) -> Tensor:
if convert_template:
Expand Down
2 changes: 1 addition & 1 deletion hyperswap/src/exporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def export() -> None:
config_precision = CONFIG_PARSER.get('exporting', 'precision')

os.makedirs(config_directory_path, exist_ok = True)
model = HyperSwapTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location ='cpu').eval()
model = HyperSwapTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval()

if config_precision == 'half':
model = HalfPrecision(model).eval()
Expand Down
12 changes: 6 additions & 6 deletions hyperswap/src/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,17 @@ def convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) ->
return output_tensor


def calculate_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
def calculate_face_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
crop_tensor = convert_tensor(input_tensor, 'arcface_128_to_arcface_112_v2')
crop_tensor = nn.functional.interpolate(crop_tensor, size = 112, mode = 'area')
crop_tensor[:, :, :padding[0], :] = 0
crop_tensor[:, :, 112 - padding[1]:, :] = 0
crop_tensor[:, :, :, :padding[2]] = 0
crop_tensor[:, :, :, 112 - padding[3]:] = 0

embedding = embedder(crop_tensor)
embedding = nn.functional.normalize(embedding, p = 2)
return embedding
face_embedding = embedder(crop_tensor)
face_embedding = nn.functional.normalize(face_embedding, p = 2)
return face_embedding


def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor:
Expand All @@ -56,15 +56,15 @@ def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor:


def dilate_mask(input_tensor : Tensor, factor : float) -> Tensor:
padding = round(input_tensor.shape[2] * factor)
padding = int(input_tensor.shape[2] * factor + 0.5)
kernel_size = 1 + 2 * padding
temp_tensor = nn.functional.pad(input_tensor, (padding, padding, padding, padding), mode = 'replicate')
output_tensor = nn.functional.max_pool2d(temp_tensor, kernel_size = kernel_size, stride = 1, padding = 0)
return output_tensor


def erode_mask(input_tensor : Tensor, factor : float) -> Tensor:
padding = round(input_tensor.shape[2] * factor)
padding = int(input_tensor.shape[2] * factor + 0.5)
kernel_size = 1 + 2 * padding
temp_tensor = 1 - nn.functional.pad(input_tensor, (padding, padding, padding, padding), mode = 'replicate')
output_tensor = 1 - nn.functional.max_pool2d(temp_tensor, kernel_size = kernel_size, stride = 1, padding = 0)
Expand Down
4 changes: 2 additions & 2 deletions hyperswap/src/inferencing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torchvision import io

from .helper import calculate_embedding
from .helper import calculate_face_embedding
from .training import HyperSwapTrainer

CONFIG_PARSER = configparser.ConfigParser()
Expand All @@ -22,6 +22,6 @@ def infer() -> None:

source_tensor = io.read_image(config_source_path)
target_tensor = io.read_image(config_target_path)
source_embedding = calculate_embedding(embedder, source_tensor, (0, 0, 0, 0))
source_embedding = calculate_face_embedding(embedder, source_tensor, (0, 0, 0, 0))
output_tensor, _ = generator(source_embedding, target_tensor)
io.write_jpeg(output_tensor, config_output_path)
Loading