// File: orca/ext/wasm3/platforms/cpp/wasm3_cpp/include/wasm3_cpp.h
// (413 lines, 14 KiB, C++)

#pragma once
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <memory>
#include <new>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>
#include "wasm3.h"
namespace wasm3 {
/** @cond */
namespace detail {
typedef uint64_t *stack_type;
typedef void *mem_type;
template<typename T, typename...> struct first_type { typedef T type; };
typedef const void *(*m3_api_raw_fn)(IM3Runtime, uint64_t *, void *);
/**
 * Pop one scalar argument of type T from the wasm3 stack into `dest`.
 *
 * The m3ApiGetArg* macros rely on the stack pointer being named `_sp` (and, for the
 * pointer overloads, the memory pointer being named `_mem`) and advance `_sp`.
 * Scalars never touch wasm memory, so the memory parameter is unnamed here.
 */
template<typename T>
void arg_from_stack(T &dest, stack_type &_sp, mem_type /*unused*/) {
    m3ApiGetArg(T, tmp);
    dest = tmp;
}
/**
 * Pop a pointer argument: m3ApiGetArgMem reads the stack slot and resolves it against
 * `_mem` to produce a host pointer of type T*.
 */
template<typename T>
void arg_from_stack(T* &dest, stack_type &_sp, mem_type _mem) {
    m3ApiGetArgMem(T*, tmp);
    dest = tmp;
}
/** Same as the T* overload, for pointers to const. */
template<typename T>
void arg_from_stack(const T* &dest, stack_type &_sp, mem_type _mem) {
    m3ApiGetArgMem(const T*, tmp);
    dest = tmp;
}
/** Compile-time holder for a single wasm3 signature character. */
template<char c>
struct m3_sig {
    static constexpr char value = c;
};
/**
 * Maps a C++ parameter/return type to its wasm3 signature character:
 * 'i' = 32-bit int, 'I' = 64-bit int, 'f' = float, 'F' = double, 'v' = void,
 * '*' = pointer (passed through the wasm stack and resolved by arg_from_stack's
 * T* / const T* overloads).
 *
 * Unspecialized types are deliberately left undefined so unsupported host-function
 * signatures fail at compile time.
 */
template<typename T> struct m3_type_to_sig;
template<> struct m3_type_to_sig<int32_t>  : m3_sig<'i'> {};
template<> struct m3_type_to_sig<uint32_t> : m3_sig<'i'> {};
template<> struct m3_type_to_sig<int64_t>  : m3_sig<'I'> {};
template<> struct m3_type_to_sig<uint64_t> : m3_sig<'I'> {};
template<> struct m3_type_to_sig<float>    : m3_sig<'f'> {};
template<> struct m3_type_to_sig<double>   : m3_sig<'F'> {};
template<> struct m3_type_to_sig<void>     : m3_sig<'v'> {};
// Every pointer type (void*, const void*, int32_t*, const char*, ...) maps to '*';
// arg_from_stack already accepts any T* / const T*.
template<typename T> struct m3_type_to_sig<T *> : m3_sig<'*'> {};
/**
 * NUL-terminated wasm3 signature string for a host function, e.g.
 * m3_signature<int32_t, int32_t, float>::value is "i(if)".
 */
template<typename Ret, typename ... Args>
struct m3_signature {
    constexpr static size_t n_args = sizeof...(Args);
    // return char + '(' + one char per argument + ')' + terminating NUL
    constexpr static const char value[n_args + 4] = {
        m3_type_to_sig<Ret>::value,
        '(',
        m3_type_to_sig<Args>::value...,
        ')',
        0
    };
};
/**
 * Populate each element of `tuple`, first to last, with the next argument taken from
 * the wasm3 stack.
 *
 * Each arg_from_stack call advances `sp`, so the left-to-right evaluation order
 * guaranteed by the comma fold is what keeps arguments in the right slots.
 */
template <typename ...Args>
static void get_args_from_stack(stack_type &sp, mem_type mem, std::tuple<Args...> &tuple) {
    const auto pop_each = [&](auto &... slot) {
        (arg_from_stack(slot, sp, mem), ...);
    };
    std::apply(pop_each, tuple);
}
/** Adapts a plain function type to wasm3's raw-call convention; see the specializations. */
template<typename Func>
struct wrap_helper;
/**
 * Raw-call trampoline for host functions that return a value.
 *
 * The host function pointer comes from the import context's userdata (m3_wrapper::link
 * registers it there). Arguments are popped from the stack, the function is invoked, and
 * its result is written through the slot reserved by m3ApiReturnType.
 */
template <typename Ret, typename ...Args>
struct wrap_helper<Ret(Args...)> {
    using Func = Ret(Args...);
    static const void *wrap_fn(IM3Runtime rt, IM3ImportContext _ctx, stack_type _sp, mem_type mem) {
        // Must come before get_args_from_stack: the return slot precedes the arguments
        // on the stack, and both this macro and the argument reader advance `_sp`.
        m3ApiReturnType(Ret);
        std::tuple<Args...> call_args;
        get_args_from_stack(_sp, mem, call_args);
        Func *const target = reinterpret_cast<Func *>(_ctx->userdata);
        Ret result = std::apply(target, call_args);
        m3ApiReturn(result);
    }
};
/**
 * Raw-call trampoline for host functions returning void.
 *
 * No return slot is reserved; arguments are popped, the host function (stored as the
 * import's userdata) is invoked, and success is reported.
 */
template <typename ...Args>
struct wrap_helper<void(Args...)> {
    using Func = void(Args...);
    static const void *wrap_fn(IM3Runtime rt, IM3ImportContext _ctx, stack_type _sp, mem_type mem) {
        std::tuple<Args...> call_args;
        get_args_from_stack(_sp, mem, call_args);
        std::apply(reinterpret_cast<Func *>(_ctx->userdata), call_args);
        m3ApiSuccess();
    }
};
/** Registers a typed C++ function pointer as a raw wasm3 import via wrap_helper. */
template<typename Func>
class m3_wrapper;
template<typename Ret, typename ... Args>
class m3_wrapper<Ret(Args...)> {
public:
    /**
     * @param io_module       module whose import is being resolved
     * @param i_moduleName    import module name ("*" links to any module)
     * @param i_functionName  import function name as referenced by the module
     * @param function        host function; stored as the import's userdata
     * @return the M3Result reported by m3_LinkRawFunctionEx
     */
    static M3Result link(IM3Module io_module,
                         const char *const i_moduleName,
                         const char *const i_functionName,
                         Ret (*function)(Args...)) {
        using trampoline = wrap_helper<Ret(Args...)>;
        using signature = m3_signature<Ret, Args...>;
        void *const userdata = reinterpret_cast<void *>(function);
        return m3_LinkRawFunctionEx(io_module, i_moduleName, i_functionName,
                                    signature::value, &trampoline::wrap_fn, userdata);
    }
};
} // namespace detail
/** @endcond */
// Forward declarations: environment and runtime declare member functions that return
// these types before the classes themselves are defined.
class function;
class module;
class runtime;
/**
* Exception thrown for wasm3 errors.
*
* Use error:what() to get the reason as a string.
*/
class error : public std::runtime_error {
public:
explicit error(M3Result err) : std::runtime_error(err) {}
};
/** @cond */
namespace detail {
/** Throws wasm3::error carrying `err` unless it equals m3Err_none (success). */
static inline void check_error(M3Result err) {
    if (err == m3Err_none) {
        return;
    }
    throw error(err);
}
} // namespace detail
/** @endcond */
/**
 * Wrapper for a wasm3 environment.
 *
 * Runtimes (new_runtime) and modules (parse_module) are created from an environment.
 */
class environment {
public:
    /** Creates a wasm3 environment; throws std::bad_alloc if wasm3 returns null. */
    environment()
        : m_env(m3_NewEnvironment(), m3_FreeEnvironment) {
        if (!m_env) {
            throw std::bad_alloc();
        }
    }
    /**
     * Create a new runtime bound to this environment.
     *
     * @param stack_size_bytes size of the WASM stack for the runtime
     * @return runtime object
     */
    runtime new_runtime(size_t stack_size_bytes);
    /**
     * Parse a WASM module read from a stream.
     *
     * Parsing does not load the module into any runtime; call runtime::load afterwards.
     *
     * @param in stream holding the WASM binary
     * @return module object
     */
    module parse_module(std::istream &in);
    /**
     * Parse a WASM module from an in-memory binary.
     *
     * @param data pointer to the first byte of the binary
     * @param size number of bytes in the binary
     * @return module object
     */
    module parse_module(const uint8_t *data, size_t size);
protected:
    std::shared_ptr<struct M3Environment> m_env;
};
/**
* Wrapper for the runtime, where modules are loaded and executed.
*/
class runtime {
public:
/**
* Load the module into runtime
* @param mod module parsed by environment::parse_module
*/
void load(module &mod);
/**
* Get a function handle by name
*
* If the function is not found, throws an exception.
* @param name name of a function, c-string
* @return function object
*/
function find_function(const char *name);
protected:
friend class environment;
runtime(const std::shared_ptr<M3Environment> &env, size_t stack_size_bytes)
: m_env(env) {
m_runtime.reset(m3_NewRuntime(env.get(), stack_size_bytes, nullptr), &m3_FreeRuntime);
if (m_runtime == nullptr) {
throw std::bad_alloc();
}
}
/* runtime extends the lifetime of the environment */
std::shared_ptr<M3Environment> m_env;
std::shared_ptr<M3Runtime> m_runtime;
};
/**
* Module object holds a webassembly module
*
* It can be constructed by parsing a WASM binary using environment::parse_module.
* Functions can be linked to the loaded module.
* Once constructed, modules can be loaded into the runtime.
*/
class module {
public:
/**
* Link an external function.
*
* Throws an exception if the module doesn't reference a function with the given name.
*
* @tparam Func Function type (signature)
* @param module Name of the module to link the function to, or "*" to link to any module
* @param function_name Name of the function (as referenced by the module)
* @param function Function to link (a function pointer)
*/
template<typename Func>
void link(const char *module, const char *function_name, Func *function);
/**
* Same as module::link, but doesn't throw an exception if the function is not referenced.
*/
template<typename Func>
void link_optional(const char *module, const char *function_name, Func *function);
protected:
friend class environment;
friend class runtime;
module(const std::shared_ptr<M3Environment> &env, std::istream &in_wasm) {
in_wasm.unsetf(std::ios::skipws);
std::copy(std::istream_iterator<uint8_t>(in_wasm),
std::istream_iterator<uint8_t>(),
std::back_inserter(m_moduleRawData));
parse(env.get(), m_moduleRawData.data(), m_moduleRawData.size());
}
module(const std::shared_ptr<M3Environment> &env, const uint8_t *data, size_t size) : m_env(env) {
m_moduleRawData = std::vector<uint8_t>{data, data + size};
parse(env.get(), m_moduleRawData.data(), m_moduleRawData.size());
}
void parse(IM3Environment env, const uint8_t *data, size_t size) {
IM3Module p;
M3Result err = m3_ParseModule(env, &p, data, size);
detail::check_error(err);
m_module.reset(p, [this](IM3Module module) {
if (!m_loaded) {
m3_FreeModule(module);
}
});
}
void load_into(IM3Runtime runtime) {
M3Result err = m3_LoadModule(runtime, m_module.get());
detail::check_error(err);
m_loaded = true;
}
std::shared_ptr<M3Environment> m_env;
std::shared_ptr<M3Module> m_module;
bool m_loaded = false;
std::vector<uint8_t> m_moduleRawData {};
};
/**
 * Handle of a function. Can be obtained from runtime::find_function method by name.
 *
 * The handle keeps the underlying M3Runtime alive through a shared_ptr.
 */
class function {
public:
    /**
     * Call the function with the provided arguments, expressed as strings.
     *
     * Arguments are passed as strings. WASM3 automatically converts them into the types expected
     * by the function being called.
     *
     * Note that the type of the return value must be explicitly specified as a template argument.
     *
     * @return the return value of the function.
     * @throws wasm3::error if the call fails or the result cannot be retrieved.
     */
    template<typename Ret, typename ... Args>
    typename detail::first_type<Ret,
        typename std::enable_if<std::is_convertible<Args, const char*>::value>::type...>::type
    call_argv(Args... args) {
        /* std::enable_if above checks that all argument types are convertible const char* */
        // std::array rather than a C array: a zero-length C array (no arguments) is ill-formed.
        std::array<const char*, sizeof...(Args)> argv{ args... };
        detail::check_error(m3_CallArgv(m_func, sizeof...(args), argv.data()));
        Ret ret{};
        // m3_GetResults takes an array of pointers to result storage, not the storage itself.
        const void *ret_ptrs[] = { &ret };
        detail::check_error(m3_GetResults(m_func, 1, ret_ptrs));
        return ret;
    }
    /**
     * Same as call_argv<Ret>, for functions whose result is not needed.
     *
     * @throws wasm3::error if the call fails.
     */
    template<typename ... Args>
    typename detail::first_type<void,
        typename std::enable_if<std::is_convertible<Args, const char*>::value>::type...>::type
    call_argv(Args... args) {
        /* std::enable_if above checks that all argument types are convertible const char* */
        std::array<const char*, sizeof...(Args)> argv{ args... };
        detail::check_error(m3_CallArgv(m_func, sizeof...(args), argv.data()));
    }
    /**
     * Call the function with the provided arguments (int/float types).
     *
     * Note that the type of the return value must be explicitly specified as a template argument.
     * Each argument is passed to wasm3 by address.
     * NOTE(review): presumably wasm3 reads each one as the wasm parameter type, so argument
     * types should match exactly (e.g. int64_t for an i64 parameter) — confirm.
     *
     * @return the return value of the function or void.
     * @throws wasm3::error if the call fails or the result cannot be retrieved.
     */
    template<typename Ret = void, typename ... Args>
    Ret call(Args... args) {
        // std::array so that calling a zero-argument function stays well-formed.
        std::array<const void*, sizeof...(Args)> arg_ptrs{ static_cast<const void*>(&args)... };
        detail::check_error(m3_Call(m_func, sizeof...(args), arg_ptrs.data()));
        if constexpr (!std::is_void<Ret>::value) {
            Ret ret{};
            const void *ret_ptrs[] = { &ret };
            detail::check_error(m3_GetResults(m_func, 1, ret_ptrs));
            return ret;
        }
    }
protected:
    friend class runtime;
    /** Looks up `name` in `runtime`; throws wasm3::error if it cannot be found. */
    function(const std::shared_ptr<M3Runtime> &runtime, const char *name) : m_runtime(runtime) {
        M3Result err = m3_FindFunction(&m_func, runtime.get(), name);
        detail::check_error(err);
        assert(m_func != nullptr);
    }
    std::shared_ptr<M3Runtime> m_runtime;
    M3Function *m_func = nullptr;
};
// Out-of-line member definitions: each needs the complete runtime/module/function types.
inline void runtime::load(module &mod) {
    mod.load_into(m_runtime.get());
}
inline function runtime::find_function(const char *name) {
    return function{m_runtime, name};
}
inline runtime environment::new_runtime(size_t stack_size_bytes) {
    return runtime{m_env, stack_size_bytes};
}
inline module environment::parse_module(std::istream &in) {
    return module{m_env, in};
}
inline module environment::parse_module(const uint8_t *data, size_t size) {
    return module{m_env, data, size};
}
template<typename Func>
void module::link(const char *module, const char *function_name, Func *function) {
M3Result ret = detail::m3_wrapper<Func>::link(m_module.get(), module, function_name, function);
detail::check_error(ret);
}
template<typename Func>
void module::link_optional(const char *module, const char *function_name, Func *function) {
M3Result ret = detail::m3_wrapper<Func>::link(m_module.get(), module, function_name, function);
if (ret == m3Err_functionLookupFailed) {
return;
}
detail::check_error(ret);
}
} // namespace wasm3