Skip to content

Commit 517f2c7

Browse files
authored
Wrap component function exports in functions that call __wasm_init_task (#2417)
* Add task init wrapper * Cleanup * Cleanup * Cleanup * Rename * Add test * Typo * fix comment * More comment fixes * Improve testing * Remove explicit name * Fmt
1 parent 5032b71 commit 517f2c7

File tree

6 files changed

+472
-28
lines changed

6 files changed

+472
-28
lines changed

crates/wit-component/src/encoding.rs

Lines changed: 238 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,32 @@ fn to_val_type(ty: &WasmType) -> ValType {
116116
}
117117
}
118118

119+
fn import_func_name(f: &Function) -> String {
120+
match f.kind {
121+
FunctionKind::Freestanding | FunctionKind::AsyncFreestanding => {
122+
format!("import-func-{}", f.item_name())
123+
}
124+
125+
// transform `[method]foo.bar` into `import-method-foo-bar` to
126+
// have it be a valid kebab-name which can't conflict with
127+
// anything else.
128+
//
129+
// There's probably a better and more "formal" way to do this
130+
// but quick-and-dirty string manipulation should work well
131+
// enough for now hopefully.
132+
FunctionKind::Method(_)
133+
| FunctionKind::AsyncMethod(_)
134+
| FunctionKind::Static(_)
135+
| FunctionKind::AsyncStatic(_)
136+
| FunctionKind::Constructor(_) => {
137+
format!(
138+
"import-{}",
139+
f.name.replace('[', "").replace([']', '.', ' '], "-")
140+
)
141+
}
142+
}
143+
}
144+
119145
bitflags::bitflags! {
120146
/// Options in the `canon lower` or `canon lift` required for a particular
121147
/// function.
@@ -394,6 +420,10 @@ pub struct EncodingState<'a> {
394420

395421
/// Metadata about the world inferred from the input to `ComponentEncoder`.
396422
info: &'a ComponentWorld<'a>,
423+
424+
/// Maps from original export name to task initialization wrapper function index.
425+
/// Used to wrap exports with __wasm_init_(async_)task calls.
426+
export_task_initialization_wrappers: HashMap<String, u32>,
397427
}
398428

399429
impl<'a> EncodingState<'a> {
@@ -600,6 +630,10 @@ impl<'a> EncodingState<'a> {
600630
// at the end.
601631
self.instantiate_main_module(&shims)?;
602632

633+
// Create any wrappers needed for initializing tasks if task initialization
634+
// exports are present in the main module.
635+
self.create_export_task_initialization_wrappers()?;
636+
603637
// Separate the adapters according which should be instantiated before
604638
// and after indirect lowerings are encoded.
605639
let (before, after) = self
@@ -688,7 +722,9 @@ impl<'a> EncodingState<'a> {
688722
| Export::GeneralPurposeImportRealloc
689723
| Export::Initialize
690724
| Export::ReallocForAdapter
691-
| Export::IndirectFunctionTable => continue,
725+
| Export::IndirectFunctionTable
726+
| Export::WasmInitTask
727+
| Export::WasmInitAsyncTask => continue,
692728
}
693729
}
694730

@@ -1014,32 +1050,6 @@ impl<'a> EncodingState<'a> {
10141050
name
10151051
}
10161052
}
1017-
1018-
fn import_func_name(f: &Function) -> String {
1019-
match f.kind {
1020-
FunctionKind::Freestanding | FunctionKind::AsyncFreestanding => {
1021-
format!("import-func-{}", f.item_name())
1022-
}
1023-
1024-
// transform `[method]foo.bar` into `import-method-foo-bar` to
1025-
// have it be a valid kebab-name which can't conflict with
1026-
// anything else.
1027-
//
1028-
// There's probably a better and more "formal" way to do this
1029-
// but quick-and-dirty string manipulation should work well
1030-
// enough for now hopefully.
1031-
FunctionKind::Method(_)
1032-
| FunctionKind::AsyncMethod(_)
1033-
| FunctionKind::Static(_)
1034-
| FunctionKind::AsyncStatic(_)
1035-
| FunctionKind::Constructor(_) => {
1036-
format!(
1037-
"import-{}",
1038-
f.name.replace('[', "").replace([']', '.', ' '], "-")
1039-
)
1040-
}
1041-
}
1042-
}
10431053
}
10441054

10451055
fn encode_lift(
@@ -1053,8 +1063,14 @@ impl<'a> EncodingState<'a> {
10531063
let resolve = &self.info.encoder.metadata.resolve;
10541064
let metadata = self.info.module_metadata_for(module);
10551065
let instance_index = self.instance_for(module);
1066+
// If we generated an init task wrapper for this export, use that,
1067+
// otherwise alias the original export.
10561068
let core_func_index =
1057-
self.core_alias_export(Some(core_name), instance_index, core_name, ExportKind::Func);
1069+
if let Some(&wrapper_idx) = self.export_task_initialization_wrappers.get(core_name) {
1070+
wrapper_idx
1071+
} else {
1072+
self.core_alias_export(Some(core_name), instance_index, core_name, ExportKind::Func)
1073+
};
10581074
let exports = self.info.exports_for(module);
10591075

10601076
let options = RequiredOptions::for_export(
@@ -2189,6 +2205,199 @@ impl<'a> EncodingState<'a> {
21892205
.core_alias_export(debug_name, instance, name, kind)
21902206
})
21912207
}
2208+
2209+
/// Modules may define `__wasm_init_(async_)task` functions that must be called
2210+
/// at the start of every exported function to set up the stack pointer and
2211+
/// thread-local storage. To achieve this, we create a wrapper module called
2212+
/// `task-init-wrappers` that imports the original exports and the
2213+
/// task initialization functions, and defines wrapper functions that call
2214+
/// the relevant task initialization function before delegating to the original export.
2215+
/// We then instantiate this wrapper module and use its exports as the final
2216+
/// exports of the component. If we don't find a `__wasm_init_task` export,
2217+
/// we elide the wrapper module entirely.
2218+
fn create_export_task_initialization_wrappers(&mut self) -> Result<()> {
2219+
let instance_index = self.instance_index.unwrap();
2220+
let resolve = &self.info.encoder.metadata.resolve;
2221+
let world = &resolve.worlds[self.info.encoder.metadata.world];
2222+
let exports = self.info.exports_for(CustomModule::Main);
2223+
2224+
let wasm_init_task_export = exports.wasm_init_task();
2225+
let wasm_init_async_task_export = exports.wasm_init_async_task();
2226+
if wasm_init_task_export.is_none() || wasm_init_async_task_export.is_none() {
2227+
// __wasm_init_(async_)task was not exported by the main module,
2228+
// so no wrappers are needed.
2229+
return Ok(());
2230+
}
2231+
let wasm_init_task = wasm_init_task_export.unwrap();
2232+
let wasm_init_async_task = wasm_init_async_task_export.unwrap();
2233+
2234+
// Collect the exports that we will need to wrap, alongside information
2235+
// that we'll need to build the wrappers.
2236+
let funcs_to_wrap: Vec<_> = exports
2237+
.iter()
2238+
.flat_map(|(core_name, export)| match export {
2239+
Export::WorldFunc(key, _, abi) => match &world.exports[key] {
2240+
WorldItem::Function(f) => Some((core_name, f, abi)),
2241+
_ => None,
2242+
},
2243+
Export::InterfaceFunc(_, id, func_name, abi) => {
2244+
let func = &resolve.interfaces[*id].functions[func_name.as_str()];
2245+
Some((core_name, func, abi))
2246+
}
2247+
_ => None,
2248+
})
2249+
.collect();
2250+
2251+
if funcs_to_wrap.is_empty() {
2252+
// No exports, so no wrappers are needed.
2253+
return Ok(());
2254+
}
2255+
2256+
// Now we build the wrapper module
2257+
let mut types = TypeSection::new();
2258+
let mut imports = ImportSection::new();
2259+
let mut functions = FunctionSection::new();
2260+
let mut exports_section = ExportSection::new();
2261+
let mut code = CodeSection::new();
2262+
2263+
// Type for __wasm_init_(async_)task: () -> ()
2264+
types.ty().function([], []);
2265+
let wasm_init_task_type_idx = 0;
2266+
2267+
// Import __wasm_init_task and __wasm_init_async_task into the wrapper module
2268+
imports.import(
2269+
"",
2270+
wasm_init_task,
2271+
EntityType::Function(wasm_init_task_type_idx),
2272+
);
2273+
imports.import(
2274+
"",
2275+
wasm_init_async_task,
2276+
EntityType::Function(wasm_init_task_type_idx),
2277+
);
2278+
let wasm_init_task_func_idx = 0u32;
2279+
let wasm_init_async_task_func_idx = 1u32;
2280+
2281+
let mut type_indices = HashMap::new();
2282+
let mut next_type_idx = 1u32;
2283+
let mut next_func_idx = 2u32;
2284+
2285+
// First pass: create all types and import all original functions
2286+
struct FuncInfo<'a> {
2287+
name: &'a str,
2288+
type_idx: u32,
2289+
orig_func_idx: u32,
2290+
is_async: bool,
2291+
n_params: usize,
2292+
}
2293+
let mut func_info = Vec::new();
2294+
for &(name, func, abi) in funcs_to_wrap.iter() {
2295+
let sig = resolve.wasm_signature(*abi, func);
2296+
let type_idx = *type_indices.entry(sig.clone()).or_insert_with(|| {
2297+
let idx = next_type_idx;
2298+
types.ty().function(
2299+
sig.params.iter().map(to_val_type),
2300+
sig.results.iter().map(to_val_type),
2301+
);
2302+
next_type_idx += 1;
2303+
idx
2304+
});
2305+
2306+
imports.import("", &import_func_name(func), EntityType::Function(type_idx));
2307+
let orig_func_idx = next_func_idx;
2308+
next_func_idx += 1;
2309+
2310+
func_info.push(FuncInfo {
2311+
name,
2312+
type_idx,
2313+
orig_func_idx,
2314+
is_async: abi.is_async(),
2315+
n_params: sig.params.len(),
2316+
});
2317+
}
2318+
2319+
// Second pass: define wrapper functions
2320+
for info in func_info.iter() {
2321+
let wrapper_func_idx = next_func_idx;
2322+
functions.function(info.type_idx);
2323+
2324+
let mut func = wasm_encoder::Function::new([]);
2325+
if info.is_async {
2326+
func.instruction(&Instruction::Call(wasm_init_async_task_func_idx));
2327+
} else {
2328+
func.instruction(&Instruction::Call(wasm_init_task_func_idx));
2329+
}
2330+
for i in 0..info.n_params as u32 {
2331+
func.instruction(&Instruction::LocalGet(i));
2332+
}
2333+
func.instruction(&Instruction::Call(info.orig_func_idx));
2334+
func.instruction(&Instruction::End);
2335+
code.function(&func);
2336+
2337+
exports_section.export(info.name, ExportKind::Func, wrapper_func_idx);
2338+
next_func_idx += 1;
2339+
}
2340+
2341+
let mut wrapper_module = Module::new();
2342+
wrapper_module.section(&types);
2343+
wrapper_module.section(&imports);
2344+
wrapper_module.section(&functions);
2345+
wrapper_module.section(&exports_section);
2346+
wrapper_module.section(&code);
2347+
2348+
let wrapper_module_idx = self
2349+
.component
2350+
.core_module(Some("init-task-wrappers"), &wrapper_module);
2351+
2352+
// Prepare imports for instantiating the wrapper module
2353+
let mut wrapper_imports = Vec::new();
2354+
let init_idx = self.core_alias_export(
2355+
Some(wasm_init_task),
2356+
instance_index,
2357+
wasm_init_task,
2358+
ExportKind::Func,
2359+
);
2360+
let init_async_idx = self.core_alias_export(
2361+
Some(wasm_init_async_task),
2362+
instance_index,
2363+
wasm_init_async_task,
2364+
ExportKind::Func,
2365+
);
2366+
wrapper_imports.push((wasm_init_task.into(), ExportKind::Func, init_idx));
2367+
wrapper_imports.push((
2368+
wasm_init_async_task.into(),
2369+
ExportKind::Func,
2370+
init_async_idx,
2371+
));
2372+
2373+
// Import all original exports to be wrapped
2374+
for (name, func, _) in &funcs_to_wrap {
2375+
let orig_idx =
2376+
self.core_alias_export(Some(name), instance_index, name, ExportKind::Func);
2377+
wrapper_imports.push((import_func_name(func), ExportKind::Func, orig_idx));
2378+
}
2379+
2380+
let wrapper_args_idx = self.component.core_instantiate_exports(
2381+
Some("init-task-wrappers-args"),
2382+
wrapper_imports.iter().map(|(n, k, i)| (n.as_str(), *k, *i)),
2383+
);
2384+
2385+
let wrapper_instance = self.component.core_instantiate(
2386+
Some("init-task-wrappers-instance"),
2387+
wrapper_module_idx,
2388+
[("", ModuleArg::Instance(wrapper_args_idx))],
2389+
);
2390+
2391+
// Map original names to wrapper indices
2392+
for (name, _, _) in funcs_to_wrap {
2393+
let wrapper_idx =
2394+
self.core_alias_export(Some(&name), wrapper_instance, &name, ExportKind::Func);
2395+
self.export_task_initialization_wrappers
2396+
.insert(name.into(), wrapper_idx);
2397+
}
2398+
2399+
Ok(())
2400+
}
21922401
}
21932402

21942403
/// A list of "shims" which start out during the component instantiation process
@@ -3123,6 +3332,7 @@ impl ComponentEncoder {
31233332
exported_instances: Default::default(),
31243333
aliased_core_items: Default::default(),
31253334
info: &world,
3335+
export_task_initialization_wrappers: HashMap::new(),
31263336
};
31273337
state.encode_imports(&self.import_name_map)?;
31283338
state.encode_core_modules();

0 commit comments

Comments
 (0)