24 #include "pddl_robot_memory_thread.h"
26 #include <utils/misc/string_conversions.h>
28 #include <bsoncxx/exception/exception.hpp>
32 using namespace mongocxx;
33 using namespace bsoncxx;
34 using namespace bsoncxx::builder;
// Constructor (fragment): initializes the Thread base with the name
// "PddlRobotMemoryThread" and Thread::OPMODE_WAITFORWAKEUP.
// NOTE(review): the rest of the initializer list and the constructor body are
// not visible in this excerpt.
61 PddlRobotMemoryThread::PddlRobotMemoryThread()
62 :
Thread(
"PddlRobotMemoryThread",
Thread::OPMODE_WAITFORWAKEUP),
// Setup fragment (the enclosing function header is not visible in this
// excerpt). Resolves two paths from plugin config, both relative to
// "@BASEDIR@/src/agents/": the template input path and the output path for
// the generated problem description.
72 input_path = StringConversions::resolve_path(
73 "@BASEDIR@/src/agents/"
74 +
config->
get_string(
"plugins/pddl-robot-memory/input-problem-description"));
75 output_path = StringConversions::resolve_path(
76 "@BASEDIR@/src/agents/"
77 +
config->
get_string(
"plugins/pddl-robot-memory/output-problem-description"));
// Reset the generation interface: message id 0, final flag false.
84 gen_if->set_msg_id(0);
85 gen_if->set_final(
false);
// Optionally generate right away when "generate-on-init" is set
// (the branch body is not visible here).
92 if (
config->
get_bool(
"plugins/pddl-robot-memory/generate-on-init")) {
// Problem-generation fragment. The enclosing function header and several
// interior lines are not visible in this excerpt; presumably this is the
// generate routine (TODO confirm).
// Reads the whole template file at input_path into a string.
105 std::ifstream istream(input_path);
106 if (istream.is_open()) {
108 std::string((std::istreambuf_iterator<char>(istream)), std::istreambuf_iterator<char>());
// Prefix ctemplate's set-delimiter marker so the template uses << >> instead
// of {{ }} as marker delimiters (hence the "<<#" / ">>" searches below).
114 input =
"{{=<< >>=}}" + input;
// Dictionary that is filled below and expanded into the output text.
117 ctemplate::TemplateDictionary dict(
"pddl-rm");
// facets: one $facet sub-pipeline per template section name.
// templates: section name -> query string seen so far.
119 basic::document facets;
123 std::map<std::string, std::string> templates;
// Visit every section opener "<<#". A section written as <<#name|query>>
// carries a JSON match query after the '|'.
124 while (input.find(
"<<#", cur_pos) != std::string::npos) {
125 cur_pos = input.find(
"<<#", cur_pos) + 3;
126 size_t tpl_end_pos = input.find(
">>", cur_pos);
// A '|' that is missing or lies past this marker's ">>" means the marker has
// no query. That case is handled on lines not visible here.
128 size_t q_del_pos = input.find(
"|", cur_pos);
129 if (q_del_pos == std::string::npos || q_del_pos > tpl_end_pos)
132 std::string template_name = input.substr(cur_pos, q_del_pos - cur_pos);
133 std::string query_str = input.substr(q_del_pos + 1, tpl_end_pos - (q_del_pos + 1));
// A name seen before must come with the same query; a mismatch is logged.
// In both visible paths the "|query" part (from '|' up to, not including,
// ">>") is erased, leaving a plain "<<#name>>" section marker for ctemplate.
134 if (templates.find(template_name) != templates.end()) {
135 if (templates[template_name] != query_str) {
137 "Template with same name '%s' but different query '%s' vs '%s'!",
138 template_name.c_str(),
140 templates[template_name].c_str());
142 input.erase(q_del_pos, tpl_end_pos - q_del_pos);
146 templates[template_name] = query_str;
148 input.erase(q_del_pos, tpl_end_pos - q_del_pos);
// Register a facet named after the template. Its sub-pipeline is a single
// { $match: <query> } stage, parsed from query_str with from_json.
// NOTE(review): the sub_array callback presumably runs during append, so a
// from_json parse error reaches the bsoncxx::exception handler below. Confirm
// that the enclosing try (not visible here) wraps this append.
162 facets.append(basic::kvp(template_name, [query_str](basic::sub_array array) {
163 basic::document query;
164 query.append(basic::kvp(
"$match", from_json(query_str)));
165 array.append(query.view());
167 }
catch (bsoncxx::exception &e) {
169 "Template query failed: %s\n%s",
// Build a one-stage aggregation pipeline: { $facet: { <name>: [ {$match: ...} ], ... } }.
// aggregate_pipeline holds views into aggregate_query, which must outlive it.
175 basic::document aggregate_query;
176 aggregate_query.append(basic::kvp(
"$facet", facets.view()));
177 std::vector<document::view> aggregate_pipeline{aggregate_query.view()};
// Read the facet output from the command reply (the command invocation
// itself is not visible here).
// NOTE(review): assumes the reply has a "result" array whose element "0" is a
// document. get_document() throws bsoncxx::exception if that element is
// missing or has another type.
179 auto result = res.view()[
"result"][
"0"].get_document().view();
// For every facet (template name) and every document it matched, add one
// section dictionary named after the facet and fill it from that document.
181 for (
auto e : result) {
182 for (
auto f : e.get_document().view()) {
183 ctemplate::TemplateDictionary *entry_dict = dict.AddSectionDictionary(std::string(e.key()));
184 fill_dict_from_document(entry_dict, f.get_document().view());
// Expose the goal as the GOAL template variable.
189 dict.SetValue(
"GOAL", goal);
// Cache the rewritten template text under the key "tpl-cache" (no whitespace
// stripping), then check template syntax.
// NOTE(review): the return value of StringToTemplateCache is ignored. If
// ctemplate does not replace an existing "tpl-cache" entry on a later run,
// edits to the input file would not take effect. Confirm against the ctemplate
// docs or the hidden lines.
// NOTE(review): TemplateNamelist checks the templates registered in ctemplate's
// namelist. Confirm it actually covers this string-cached template.
192 ctemplate::StringToTemplateCache(
"tpl-cache", input, ctemplate::DO_NOT_STRIP);
193 if (!ctemplate::TemplateNamelist::IsAllSyntaxOkay(ctemplate::DO_NOT_STRIP)) {
195 std::vector<std::string> error_list =
196 ctemplate::TemplateNamelist::GetBadSyntaxList(
false, ctemplate::DO_NOT_STRIP);
// NOTE(review): this copies each error string; `const std::string &` would avoid that.
197 for (std::string error : error_list) {
// Expand the cached template with the filled dictionary into `output`, then
// write it to output_path.
203 ctemplate::ExpandTemplate(
"tpl-cache", ctemplate::DO_NOT_STRIP, &dict, &output);
207 std::ofstream ostream(output_path);
208 if (ostream.is_open()) {
// NOTE(review): streaming output.c_str() stops at an embedded NUL. Otherwise
// it is equivalent to `ostream << output`.
209 ostream << output.c_str();
// Set the interface's final flag to true.
216 gen_if->set_final(
true);
// Blackboard message handler (fragment: the return type, the remaining
// parameters and the closing lines are not visible in this excerpt).
// For a GenerateMessage it copies the message id into gen_if and clears the
// final flag. A non-empty goal triggers further handling on lines not visible
// here. Any other message type is logged as an error and ignored.
227 PddlRobotMemoryThread::bb_interface_message_received(
Interface * interface,
230 if (message->is_of_type<PddlGenInterface::GenerateMessage>()) {
// NOTE(review): C-style cast. static_cast states the intent more clearly now
// that the type has been checked above.
231 PddlGenInterface::GenerateMessage *msg = (PddlGenInterface::GenerateMessage *)message;
232 gen_if->set_msg_id(msg->id());
233 gen_if->set_final(
false);
// NOTE(review): builds a temporary std::string only to test for an empty goal.
235 if (std::string(msg->goal()) !=
"")
239 logger->
log_error(name(),
"Received unknown message of type %s, ignoring", message->type());
// Recursively copies the fields of a BSON document into a ctemplate
// dictionary. Each variable is named prefix + field key; the `prefix`
// parameter is declared on a line not visible here. Nested documents recurse
// with prefix + key + "_". Unsupported element types are set to the literal
// "INVALID_VALUE_TYPE".
// NOTE(review): the case labels for most branches and the end of the function
// are not visible in this excerpt.
251 PddlRobotMemoryThread::fill_dict_from_document(ctemplate::TemplateDictionary *dict,
252 const bsoncxx::document::view &doc,
255 for (
auto elem : doc) {
256 switch (elem.type()) {
// Double: std::to_string uses fixed notation with 6 decimals (e.g. "1.500000").
258 dict->SetValue(prefix + std::string(elem.key()), std::to_string(elem.get_double()));
// String: stored verbatim.
261 dict->SetValue(prefix + std::string(elem.key()), elem.get_utf8().value.to_string());
// NOTE(review): std::to_string(bool) yields "1"/"0", not "true"/"false".
// Confirm that this is what the templates expect.
264 dict->SetValue(prefix + std::string(elem.key()), std::to_string(elem.get_bool()));
// Integers are stored as int dictionary values.
// NOTE(review): ctemplate's SetIntValue takes a long, so int64 values may
// narrow on platforms where long is 32 bits.
267 dict->SetIntValue(prefix + std::string(elem.key()), elem.get_int32());
270 dict->SetIntValue(prefix + std::string(elem.key()), elem.get_int64());
// Nested document: recurse with prefix "<prefix><key>_".
272 case type::k_document:
273 fill_dict_from_document(dict,
274 elem.get_document().view(),
275 prefix + std::string(elem.key()) +
"_");
// ObjectId: stored as its string form.
278 dict->SetValue(prefix + std::string(elem.key()), elem.get_oid().value.to_string());
// Arrays take one of two paths; the condition choosing between them is not
// visible here. One path re-keys the elements as "0", "1", ... into a
// temporary document and recurses with prefix "<prefix><key>_", so fields end
// up as <key>_<index>_<field>. The other path joins the elements as strings,
// each preceded by a space, into a single variable.
// NOTE(review): get_document() / get_utf8() throw on elements of any other
// type, so a mixed-type array raises inside these loops.
280 case type::k_array: {
284 array::view array = elem.get_array();
286 for (
auto e : array) {
287 b.append(basic::kvp(std::to_string(i++), e.get_document().view()));
289 fill_dict_from_document(dict, b.view(), prefix + std::string(elem.key()) +
"_");
291 std::string array_string;
292 for (
auto e : array) {
294 array_string +=
" " + e.get_utf8().value.to_string();
296 dict->SetValue(prefix + std::string(elem.key()), array_string);
// Any other BSON type: mark the variable as unsupported.
299 default: dict->SetValue(prefix + std::string(elem.key()),
"INVALID_VALUE_TYPE");