Async Osiris interface by piepie62 · Pull Request #356 · Norbyte/bg3se · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Async Osiris interface #356

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions BG3Extender/Extender/Client/ClientNetworking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,81 @@ void ExtenderProtocol::ProcessExtenderMessage(net::MessageContext& context, net:
break;
}

case net::MessageWrapper::kS2COsirisQueryResponse:
{
	// Completes the client-side future for an Osiris query previously sent to the server.
	// The future is resolved either with an error string or with one row per result. Every
	// row holds exactly `num_retvals` entries; columns the server did not send stay monostate.
	ecl::LuaClientPin pin(ecl::ExtensionState::Get());
	if (pin)
	{
		const auto& response = msg.s2c_osiris_query_response();

		if (response.response_case() == net::MsgS2COsirisQueryResponse::ResponseCase::kError)
		{
			pin->ResolveOsirisFuture(response.responseid(), response.error());
		}
		else if (!response.succeeded())
		{
			pin->ResolveOsirisFuture(response.responseid(), "Unknown error occurred in query");
		}
		else
		{
			Array<Array<std::variant<std::monostate, StringView, int64_t, float>>> respdata;

			respdata.Reallocate(response.results_size());

			bool resultError = false;

			for (const auto& result : response.results())
			{
				respdata.Add({});
				auto& currResp = *(respdata.begin() + respdata.size() - 1);
				currResp.Reallocate(result.num_retvals());
				for (const auto& val : result.retvals())
				{
					// Values must arrive with strictly increasing indices below num_retvals.
					// Duplicates, reordering and out-of-range indices are all malformed.
					// Checking before padding also keeps a bogus index from growing the row unboundedly.
					if (val.index() < currResp.size() || val.index() >= result.num_retvals())
					{
						resultError = true;
						break;
					}

					// Note: for some reason assigning to/emplacing to a variant crashes the game.
					// So... don't
					while (val.index() > currResp.size())
					{
						currResp.Add({});
					}
					switch (val.val_case())
					{
					case std::remove_cvref_t<decltype(val)>::ValCase::kIntv:
						currResp.Add({val.intv()});
						break;
					case std::remove_cvref_t<decltype(val)>::ValCase::kNumv:
						currResp.Add({val.numv()});
						break;
					case std::remove_cvref_t<decltype(val)>::ValCase::kStrv:
						currResp.Add({val.strv()});
						break;
					default:
						// No value set: still occupy the slot so later indices stay aligned.
						currResp.Add({});
						break;
					}
				}

				if (resultError)
				{
					break;
				}

				// Fill trailing columns the server did not send.
				while (currResp.size() < result.num_retvals())
				{
					currResp.Add({});
				}
			}

			if (resultError)
			{
				pin->ResolveOsirisFuture(response.responseid(), "Malformed response: values not in order");
			}
			else
			{
				pin->ResolveOsirisFuture(response.responseid(), respdata);
			}
		}
	}
	break;
}

default:
OsiErrorS("Unknown extension message type received!");
}
Expand Down
170 changes: 170 additions & 0 deletions BG3Extender/Extender/Server/ServerNetworking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,25 @@

BEGIN_NS(esv)

namespace
{
	// Sends an s2c_osiris_query_response that carries only an error string to the peer
	// identified by context.UserID.
	// `error` is a printf-style format string followed by its arguments. The formatted text is
	// truncated by vsnprintf if it exceeds the local buffer. The response id is copied from the
	// originating c2s_osiris_query's msgid so the client can match the reply to its request.
	void OsirisQueryErrorResponse(net::MessageContext& context, const net::MessageWrapper& msg, const char* error, ...)
	{
		// A plain local buffer (not function-static) keeps this function reentrant and safe
		// to call from more than one thread.
		char errorformatbuf[1024];
		va_list args;
		va_start(args, error);
		vsnprintf(errorformatbuf, sizeof(errorformatbuf), error, args);
		va_end(args);

		auto message = gExtender->GetServer().GetNetworkManager().GetFreeMessage();
		auto response = message->GetMessage().mutable_s2c_osiris_query_response();
		response->set_error(errorformatbuf);
		response->set_responseid(msg.c2s_osiris_query().msgid());
		gExtender->GetServer().GetNetworkManager().Send(message, context.UserID);
	}
}

// Replies to the Osiris query currently being handled with a formatted error message.
// Expands to a call that uses the names `context` and `msg`, so both must be in scope at the
// expansion site. Arguments are a printf-style format string followed by its values.
#define QUERY_ERROR(...) OsirisQueryErrorResponse(context, msg, __VA_ARGS__)

net::ProtocolResult ExtenderProtocol::ProcessMsg(void* unused, net::MessageContext* context, net::Message* msg)
{
auto base = ExtenderProtocolBase::ProcessMsg(unused, context, msg);
Expand Down Expand Up @@ -41,11 +60,162 @@ void ExtenderProtocol::ProcessExtenderMessage(net::MessageContext& context, net:
break;
}

case net::MessageWrapper::kC2SOsirisQuery:
{
	// A client asked the server to run an Osiris query on its behalf. Only OSIRIS_GET is
	// implemented (see ProcessOsirisGet). Every other path answers with QUERY_ERROR, so the
	// client's pending request (keyed by query.msgid()) always receives a reply.
	auto const& query = msg.c2s_osiris_query();

	if (!gExtender->GetCurrentExtensionState()->GetLua() || gExtender->GetCurrentExtensionState()->GetLua()->RestrictionFlags & lua::State::RestrictOsiris) {
		QUERY_ERROR("Attempted to read Osiris database in restricted context");
		// Stop here: continuing would execute the query despite the restriction and send a
		// second response for the same msgid.
		break;
	}

	switch (query.type())
	{
	case net::OsirisQueryType::OSIRIS_CALL:
		QUERY_ERROR("Remote Osiris calls not (yet) supported!");
		OsiErrorS("Remote Osiris calls not (yet) supported!");
		break;
	case net::OsirisQueryType::OSIRIS_DEFER:
		QUERY_ERROR("Remote Osiris deferrals not (yet) supported!");
		OsiErrorS("Remote Osiris deferrals not (yet) supported!");
		break;
	case net::OsirisQueryType::OSIRIS_DELETE:
		QUERY_ERROR("Remote Osiris deletions not (yet) supported!");
		OsiErrorS("Remote Osiris deletions not (yet) supported!");
		break;
	case net::OsirisQueryType::OSIRIS_GET:
		ProcessOsirisGet(context, msg);
		break;
	default:
		// Enum values this build does not know about still get an answer.
		QUERY_ERROR("Unknown Osiris query type %d", static_cast<int>(query.type()));
		OsiErrorS("Unknown Osiris query type received!");
		break;
	}
	break;
}

default:
OsiErrorS("Unknown extension message type received!");
}
}

// Answers an OSIRIS_GET query from a client.
//
// Looks up the Osiris function `query.name()` with arity `query.num_args()`. It must have no
// out params and be a database (FunctionType::Database). Each argument sent by the client
// binds one column, identified by its `index`. Columns without an argument act as wildcards.
// Every fact whose bound columns all match is appended as one OsirisResult that carries only
// the wildcard columns. Lookup/validation failures are reported to the client via QUERY_ERROR.
void ExtenderProtocol::ProcessOsirisGet(net::MessageContext& context, net::MessageWrapper& msg)
{
	auto const& query = msg.c2s_osiris_query();

	auto func = gExtender->GetServer().Osiris().LookupFunction(query.name().c_str(), query.num_args());
	if (func == nullptr || func->Signature->OutParamList.numOutParams() != 0)
	{
		// The argument count is passed as an int; passing the repeated `args()` field to a
		// `%d` varargs slot is undefined behavior.
		QUERY_ERROR("No database named '%s(%d)' exists", query.name().c_str(), static_cast<int>(query.num_args()));
		return;
	}

	if (func->Type != FunctionType::Database)
	{
		QUERY_ERROR("Function '%s(%d)' is not a database", query.name().c_str(), static_cast<int>(query.num_args()));
		return;
	}

	// Bound arguments indexed by column; nullptr marks a wildcard column to be returned.
	std::vector<std::remove_reference_t<decltype(query.args()[0])>*> args(query.num_args(), nullptr);
	for (auto const& arg : query.args())
	{
		// `index` comes straight from the client: validate it before writing into `args`.
		if (arg.index() >= args.size())
		{
			QUERY_ERROR("Argument index %u is out of range for '%s(%d)'",
				static_cast<unsigned>(arg.index()), query.name().c_str(), static_cast<int>(query.num_args()));
			return;
		}
		args[arg.index()] = &arg;
	}

	// Returns true if every bound column of `row` equals the client-supplied value.
	const auto check_match = [&args](TupleVec const& row) {
		for (std::size_t i = 0; i < args.size() && i < static_cast<std::size_t>(row.Size); i++)
		{
			auto const* arg = args[i];
			if (arg == nullptr)
			{
				continue; // wildcard column
			}

			auto const& value = row.Values[i];
			switch (gExtender->GetServer().Osiris().GetBaseType((ValueType)value.TypeId))
			{
			case ValueType::Integer:
				if (!(arg->val_case() == arg->kIntv && arg->intv() == value.Value.Int32))
				{
					return false;
				}
				break;

			case ValueType::Integer64:
				if (!(arg->val_case() == arg->kIntv && arg->intv() == value.Value.Int64))
				{
					return false;
				}
				break;

			case ValueType::Real:
				// Floats are compared with a small absolute tolerance.
				if (!(arg->val_case() == arg->kNumv && abs(value.Value.Float - arg->numv()) <= 0.00001f))
				{
					return false;
				}
				break;

			case ValueType::String:
			case ValueType::GuidString:
				if (!(arg->val_case() == arg->kStrv && arg->strv() == value.Value.String))
				{
					return false;
				}
				break;

			default:
				break;
			}
		}
		return true;
	};

	auto message = gExtender->GetServer().GetNetworkManager().GetFreeMessage();
	auto response = message->GetMessage().mutable_s2c_osiris_query_response();
	response->set_responseid(query.msgid());
	response->set_succeeded(true);

	// Appends the wildcard columns of `row` to the response as a new result.
	// Returns false if a column has a type that cannot be serialized.
	const auto append_result = [&args, response](TupleVec const& row) {
		auto result = response->add_results();
		result->set_num_retvals(static_cast<uint32_t>(args.size()));
		for (std::size_t i = 0; i < args.size() && i < static_cast<std::size_t>(row.Size); i++)
		{
			if (args[i] != nullptr)
			{
				continue; // bound column: the client already has its value
			}

			auto const& value = row.Values[i];
			auto v = result->add_retvals();
			v->set_index(static_cast<uint32_t>(i));
			switch (gExtender->GetServer().Osiris().GetBaseType((ValueType)value.TypeId))
			{
			case ValueType::Integer:
				v->set_intv(value.Value.Int32);
				break;

			case ValueType::Integer64:
				v->set_intv(value.Value.Int64);
				break;

			case ValueType::Real:
				v->set_numv(value.Value.Float);
				break;

			case ValueType::String:
			case ValueType::GuidString:
				v->set_strv(value.Value.String);
				break;

			default:
				return false;
			}
		}
		return true;
	};

	auto db = func->Node.Get()->Database.Get();
	auto head = db->Facts.Head;

	for (auto current = head->Next; current != head; current = current->Next)
	{
		if (check_match(current->Item) && !append_result(current->Item))
		{
			OsiErrorS("Unhandled Osi TypedValue");
			response->set_succeeded(false);
			response->set_error("Unhandled Osi TypedValue");
			break;
		}
	}

	gExtender->GetServer().GetNetworkManager().Send(message, context.UserID);
}

void NetworkManager::Reset()
{
peerVersions_.clear();
Expand Down
4 changes: 3 additions & 1 deletion BG3Extender/Extender/Server/ServerNetworking.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class ExtenderProtocol : public net::ExtenderProtocolBase

protected:
void ProcessExtenderMessage(net::MessageContext& context, net::MessageWrapper& msg) override;

void ProcessOsirisGet(net::MessageContext& context, net::MessageWrapper& msg);
};

class NetworkManager
Expand Down Expand Up @@ -41,4 +43,4 @@ class NetworkManager
std::unordered_map<PeerId, uint32_t> peerVersions_;
};

END_NS()
END_NS()
3 changes: 2 additions & 1 deletion BG3Extender/Extender/Shared/ExtenderNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ class ExtenderMessage : public Message
static constexpr uint32_t MaxPayloadLength = 0xfffff;

static constexpr uint32_t VerInitial = 1;
static constexpr uint32_t VerClientOsirisQuery = VerInitial + 1;
// Version of protocol, increment each time the protobuf changes
static constexpr uint32_t ProtoVersion = VerInitial;
static constexpr uint32_t ProtoVersion = VerClientOsirisQuery;

ExtenderMessage();
~ExtenderMessage() override;
Expand Down
41 changes: 41 additions & 0 deletions BG3Extender/Extender/Shared/ExtenderProtocol.proto
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,45 @@ message MsgUserVars {
repeated UserVar vars = 1;
}

// A single Osiris value tagged with the column it belongs to.
// `index` is the zero-based column position within the argument list / result row.
// At most one member of the `val` oneof is set.
message OsirisVal {
uint32 index = 1;
oneof val {
string strv = 2; // String and GuidString columns
int64 intv = 3; // Integer and Integer64 columns
float numv = 4; // Real columns
}
}

// Kind of Osiris operation requested by the client.
// The server handler in this change only implements OSIRIS_GET; the other
// types are answered with an error response.
enum OsirisQueryType {
OSIRIS_GET = 0;
OSIRIS_CALL = 1;
OSIRIS_DELETE = 2;
OSIRIS_DEFER = 3;
}

// Client -> server request to run an Osiris query.
// `msgid` is echoed back as MsgS2COsirisQueryResponse.responseid so the client can
// match the reply to its pending request. `name` + `num_args` identify the Osiris
// function. `args` holds only the bound arguments, each tagged with its column index;
// columns without an entry are treated as wildcards by OSIRIS_GET.
message MsgC2SOsirisQuery {
uint32 msgid = 1;
string name = 2;
OsirisQueryType type = 3;
uint32 num_args = 4; // Full number of arguments, including null ones
repeated OsirisVal args = 5;
}

// One matching row of a query result.
// `num_retvals` is the full column count of the row; `retvals` contains only the
// columns that were not bound in the request, each tagged with its column index.
message OsirisResult {
uint32 num_retvals = 1;
repeated OsirisVal retvals = 2;
}

// Server -> client reply to MsgC2SOsirisQuery.
// `responseid` equals the request's `msgid`. `succeeded` and `error` share a oneof,
// so at most one of them is set; the client treats a set `error`, or an unset /
// false `succeeded`, as failure.
message MsgS2COsirisQueryResponse {
uint32 responseid = 1;
repeated OsirisResult results = 2;
oneof response {
bool succeeded = 3;
string error = 4;
}
}

message MessageWrapper {
oneof msg {
MsgPostLuaMessage post_lua = 1;
Expand All @@ -100,5 +139,7 @@ message MessageWrapper {
MsgS2CSyncStat s2c_sync_stat = 6;
MsgS2CKick s2c_kick = 7;
MsgUserVars user_vars = 8;
MsgC2SOsirisQuery c2s_osiris_query = 9;
MsgS2COsirisQueryResponse s2c_osiris_query_response = 10;
}
}
2 changes: 1 addition & 1 deletion BG3Extender/Extender/Shared/ThreadedExtenderState.inl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ void ThreadedExtenderState::RemoveThread(DWORD threadId)

// Adds `fun` to threadTasks_ exactly once. The callable is moved in, so its captured
// state is not copied. (As rendered, the body pushed the same task twice.)
void ThreadedExtenderState::EnqueueTask(std::function<void()> fun)
{
	threadTasks_.push(std::move(fun));
}

void ThreadedExtenderState::SubmitTaskAndWait(std::function<void()> fun)
Expand Down
6 changes: 3 additions & 3 deletions BG3Extender/Extender/Shared/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ BEGIN_SE()
#if !defined(OSI_NO_DEBUG_LOG)
// Debug-log helpers. Each one writes "<calling function>(): " followed by `msg` into a
// std::stringstream and forwards the string to the matching logger.
// `msg` is spliced in verbatim after `<<`, so callers may pass a stream chain such as
// `"value: " << x`.
// __FUNCTION__ is streamed rather than concatenated as a string literal. MSVC treats it as
// a literal, but GCC/Clang expose it as a variable, where literal concatenation fails to compile.
#define LuaError(msg) { \
	std::stringstream ss; \
	ss << __FUNCTION__ << "(): " << msg; \
	LogLuaError(ss.str()); \
}

#define OsiError(msg) { \
	std::stringstream ss; \
	ss << __FUNCTION__ << "(): " << msg; \
	LogOsirisError(ss.str()); \
}

#define OsiWarn(msg) { \
	std::stringstream ss; \
	ss << __FUNCTION__ << "(): " << msg; \
	LogOsirisWarning(ss.str()); \
}

Expand Down
Loading
0