How can I copy the parameters of one model to another in LibTorch?

How can I copy the parameters of one model to another in LibTorch? I know how to do it in Torch (Python).
net2.load_state_dict(net.state_dict())
I tried the C++ code below, which took quite a bit of work, but it did not copy the parameters from one model to the other.
I don't see a built-in option to copy the parameters of one model into another model of the same type.
#include <torch/torch.h>
#include <iostream>
using namespace torch::indexing;
torch::Device device(torch::kCUDA);
void loadstatedict(torch::nn::Module& model, torch::nn::Module& target_model) {
torch::autograd::GradMode::set_enabled(false); // make parameters copying possible
auto new_params = target_model.named_parameters(); // implement this
auto params = model.named_parameters(true /*recurse*/);
auto buffers = model.named_buffers(true /*recurse*/);
for (auto& val : new_params) {
auto name = val.key();
auto* t = params.find(name);
if (t != nullptr) {
t->copy_(val.value());
} else {
t = buffers.find(name);
if (t != nullptr) {
t->copy_(val.value());
}
}
}
}
struct Critic_Net : torch::nn::Module {
torch::Tensor next_state_batch__sampled_action;
public:
Critic_Net() {
lin1 = torch::nn::Linear(3, 3);
lin2 = torch::nn::Linear(3, 1);
lin1->to(device);
lin2->to(device);
}
torch::Tensor forward(torch::Tensor next_state_batch__sampled_action) {
auto h = next_state_batch__sampled_action;
h = torch::relu(lin1->forward(h));
h = lin2->forward(h);
return h;
}
torch::nn::Linear lin1{nullptr}, lin2{nullptr};
};
auto net = Critic_Net();
auto net2 = Critic_Net();
auto the_ones = torch::ones({3, 3}).to(device);
int main() {
std::cout << net.forward(the_ones);
std::cout << net2.forward(the_ones);
loadstatedict(net, net2);
std::cout << net.forward(the_ones);
std::cout << net2.forward(the_ones);
}

Your solution with load_state_dict should work, if I understand correctly. The problem here is the same as in your previous question: nothing is registered as a parameter, buffer, or submodule, so named_parameters() and named_buffers() have nothing to return. Add the register_module calls and it should work fine.
This is how the class should look:
struct Critic_Net : torch::nn::Module {
public:
Critic_Net() {
lin1 = register_module("lin1", torch::nn::Linear(427, 42));
lin2 = register_module("lin2", torch::nn::Linear(42, 286));
lin3 = register_module("lin3", torch::nn::Linear(286, 1));
}
torch::Tensor forward(torch::Tensor next_state_batch__sampled_action) {
// unchanged
}
torch::nn::Linear lin1{nullptr}, lin2{nullptr}, lin3{nullptr};
};
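Why registration matters can be shown with a toy model that uses no LibTorch at all. The sketch below is only an analogy: ToyModule, register_parameter and copy_parameters are hypothetical names written for this illustration, and a std::vector<float> stands in for a tensor. A copy-by-name loop can only see what was registered, so with nothing registered it silently copies nothing.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for a module (NOT the LibTorch API): only values added
// through register_parameter() are visible to named_parameters().
class ToyModule {
public:
    using Tensor = std::vector<float>;

    void register_parameter(const std::string& name, Tensor value) {
        params_[name] = std::move(value);
    }
    std::map<std::string, Tensor>& named_parameters() { return params_; }

private:
    std::map<std::string, Tensor> params_;
};

// Copy each parameter of src into the same-named parameter of dst.
// Returns the number of parameters that were copied.
int copy_parameters(ToyModule& src, ToyModule& dst) {
    int copied = 0;
    for (const auto& entry : src.named_parameters()) {
        auto it = dst.named_parameters().find(entry.first);
        if (it != dst.named_parameters().end()) {
            it->second = entry.second;
            ++copied;
        }
    }
    return copied;
}
```

Calling copy_parameters on two modules that never registered anything returns 0, which mirrors the symptom in the question: the loop runs, finds no names, and both networks keep their own weights.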

Related

How to write a local branch predictor?

I am trying to use runspec to test my local branch predictor, but the result is disappointing.
So far I have tried a 64-entry LHT (local history table); when the LHT is full, I replace an entry using a FIFO policy. I don't know whether the LHT is too small or my replacement policy is poor, but the accuracy is only 60.9095.
for (int i = 0; i < 1 << HL; i++)
{
if (tag_lht[i] == (addr&(1-(1<<HL))))
{
addr = addr ^ LHT[i].getVal();
goto here;
break;
}
}
index_lht = index_lht%(1<<HL);
tag_lht[index_lht] = (addr&(1-(1<<HL)));
LHT[index_lht] = ShiftReg<2>();
addr = addr ^ LHT[index_lht].getVal();
index_lht++;
here:
for (int i = 0; i < 1 << L; i++)
{
if (tag[i] == (addr))
{
return bhist[i].isTaken();
}
}
index = index % (1 << L);
tag[index] = (addr);
bhist[index].reset();
return bhist[index++].isTaken();
Some explanation of the code: bhist is a table that stores a 2-bit state for each branch instruction; when the table is full, entries are replaced in FIFO order. tag stores the address of each instruction held in bhist. Likewise, tag_lht stores the address of each instruction held in LHT. The function isTaken() returns the prediction.
Thank you all. I found my mistake: the code above is correct, though it may not work perfectly. The mistake was in the update code below:
for (int i = 0; i < (1 << L); i++)
{
if (tag[i] == (addr))
{
if (takenActually)
{
LHT[j].shiftIn(1);
bhist[i].increase();
}
else
{
LHT[j].shiftIn(0);
bhist[i].decrease();
}
}
break;
}
But it should be like this:
for (int i = 0; i < (1 << L); i++)
{
if (tag[i] == (addr))
{
if (takenActually)
{
LHT[j].shiftIn(1);
bhist[i].increase();
}
else
{
LHT[j].shiftIn(0);
bhist[i].decrease();
}
break;
}
}
I am sorry for wasting your time. I spent a long time trying to figure out why it didn't work; at first I thought I had used a wrong variable or argument, but it turns out I was just careless: the misplaced break ended the loop after the first iteration.
Again, thank you all. I will answer the question with my full code.
PS. I hope my English has not confused anyone. :)
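For readers unfamiliar with the pieces described above, here is a minimal sketch of a 2-bit saturating counter and a per-branch history register working together. It is not the poster's code: it assumes a direct-mapped history table (no tags, no FIFO replacement) and a counter table indexed by the local history alone, and the names TwoBitCounter and LocalPredictor are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// 2-bit saturating counter: states 0 and 1 predict "not taken",
// states 2 and 3 predict "taken".
class TwoBitCounter {
public:
    bool isTaken() const { return state_ >= 2; }
    void increase() { if (state_ < 3) ++state_; }
    void decrease() { if (state_ > 0) --state_; }

private:
    std::uint8_t state_ = 1;  // assumption: start at "weakly not taken"
};

// Hypothetical direct-mapped local predictor.
// HistBits:  history bits kept per branch (the role HL plays in the question).
// IndexBits: log2 of the number of history-table slots.
template <unsigned HistBits, unsigned IndexBits>
class LocalPredictor {
    static_assert(HistBits > 0 && HistBits < 16, "unsupported history length");
    static_assert(IndexBits > 0 && IndexBits < 16, "unsupported table size");

public:
    bool predict(std::uint32_t addr) const {
        return counters_[history_[slot(addr)]].isTaken();
    }

    void update(std::uint32_t addr, bool taken) {
        std::uint32_t& hist = history_[slot(addr)];
        // Train the counter selected by the OLD history first ...
        TwoBitCounter& c = counters_[hist];
        if (taken) c.increase(); else c.decrease();
        // ... then shift the outcome into this branch's local history.
        hist = ((hist << 1) | (taken ? 1u : 0u)) & ((1u << HistBits) - 1u);
    }

private:
    static std::size_t slot(std::uint32_t addr) {
        return addr & ((1u << IndexBits) - 1u);  // low address bits pick the slot
    }
    std::array<std::uint32_t, (1u << IndexBits)> history_{};
    std::array<TwoBitCounter, (1u << HistBits)> counters_{};
};
```

With LocalPredictor<2, 6> (64 history slots, 2 bits of history), a branch that strictly alternates taken / not taken is predicted correctly after a few warm-up updates, because each history pattern trains its own counter. The order inside update matters in the same way the misplaced break did: the counter has to be trained with the history that produced the prediction, before that history is shifted.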

Schedule an asynchronous event that will complete when stdin has waiting data in boost::asio?

I'm using boost::asio with ncurses for a command-line game. The game needs to draw on the screen at a fixed time interval, and other operations (e.g. networking or file operations) are also executed whenever necessary. All these things can be done with async_read()/async_write() or equivalent on boost::asio.
However, I also need to read keyboard input, which (I think) comes from stdin. The usual way to read input in ncurses is to call getch(), which can be configured to be either blocking (wait until there is a character available for consumption) or non-blocking (return a sentinel value if there are no characters available).
Using blocking mode would necessitate running getch() on a separate thread, which doesn't play well with ncurses. Using non-blocking mode, however, would cause my application to consume CPU time spinning in a loop until the user presses their keyboard. I've read this answer, which suggests that we can add stdin to the list of file descriptors in a select() call, which would block until one of the file descriptors has new data.
Since I'm using boost::asio, I can't directly use select(). I can't call async_read, because that would consume the character, leaving getch() with nothing to read. Is there something in boost::asio like async_read that merely checks for the existence of input without consuming it?
I think you should be able to use the posix stream descriptor to watch for input on file descriptor 0:
ba::posix::stream_descriptor d(io, 0);
input_loop = [&](error_code ec) {
if (!ec) {
program.on_input();
d.async_wait(ba::posix::descriptor::wait_type::wait_read, input_loop);
}
};
There, Program::on_input() would call getch(), made non-blocking by timeout(0), until it returns ERR:
struct Program {
Program() {
initscr();
ESCDELAY = 0;
timeout(0);
cbreak();
noecho();
keypad(stdscr, TRUE); // receive special keys
clock = newwin(2, 40, 0, 0);
monitor = newwin(10, 40, 2, 0);
syncok(clock, true); // automatic updating
syncok(monitor, true);
scrollok(monitor, true); // scroll the input monitor window
}
~Program() {
delwin(monitor);
delwin(clock);
endwin();
}
void on_clock() {
wclear(clock);
char buf[32];
time_t t = time(NULL);
if (auto tmp = localtime(&t)) {
if (strftime(buf, sizeof(buf), "%T", tmp) == 0) {
strncpy(buf, "[error formatting time]", sizeof(buf));
}
} else {
strncpy(buf, "[error getting time]", sizeof(buf));
}
wprintw(clock, "Async: %s", buf);
wrefresh(clock);
}
void on_input() {
for (auto ch = getch(); ch != ERR; ch = getch()) {
wprintw(monitor, "received key %d ('%c')\n", ch, ch);
}
wrefresh(monitor);
}
WINDOW *monitor = nullptr;
WINDOW *clock = nullptr;
};
With the following main program you'd run it for 10 seconds (because Program doesn't yet know how to exit):
int main() {
Program program;
namespace ba = boost::asio;
using boost::system::error_code;
using namespace std::literals;
ba::io_service io;
std::function<void(error_code)> input_loop, clock_loop;
// Reading input when ready on stdin
ba::posix::stream_descriptor d(io, 0);
input_loop = [&](error_code ec) {
if (!ec) {
program.on_input();
d.async_wait(ba::posix::descriptor::wait_type::wait_read, input_loop);
}
};
// For fun, let's also update the time
ba::high_resolution_timer tim(io);
clock_loop = [&](error_code ec) {
if (!ec) {
program.on_clock();
tim.expires_from_now(100ms);
tim.async_wait(clock_loop);
}
};
input_loop(error_code{});
clock_loop(error_code{});
io.run_for(10s);
}
This works.
Full Listing
#include <boost/asio.hpp>
#include <boost/asio/posix/descriptor.hpp>
#include <chrono>
#include <cstring>
#include <ctime>
#include <functional>
#include <iostream>
#include "ncurses.h"
#define CTRL_R 18
#define CTRL_C 3
#define TAB 9
#define NEWLINE 10
#define RETURN 13
#define ESCAPE 27
#define BACKSPACE 127
#define UP 72
#define LEFT 75
#define RIGHT 77
#define DOWN 80
struct Program {
Program() {
initscr();
ESCDELAY = 0;
timeout(0);
cbreak();
noecho();
keypad(stdscr, TRUE); // receive special keys
clock = newwin(2, 40, 0, 0);
monitor = newwin(10, 40, 2, 0);
syncok(clock, true); // automatic updating
syncok(monitor, true);
scrollok(monitor, true); // scroll the input monitor window
}
~Program() {
delwin(monitor);
delwin(clock);
endwin();
}
void on_clock() {
wclear(clock);
char buf[32];
time_t t = time(NULL);
if (auto tmp = localtime(&t)) {
if (strftime(buf, sizeof(buf), "%T", tmp) == 0) {
strncpy(buf, "[error formatting time]", sizeof(buf));
}
} else {
strncpy(buf, "[error getting time]", sizeof(buf));
}
wprintw(clock, "Async: %s", buf);
wrefresh(clock);
}
void on_input() {
for (auto ch = getch(); ch != ERR; ch = getch()) {
wprintw(monitor, "received key %d ('%c')\n", ch, ch);
}
wrefresh(monitor);
}
WINDOW *monitor = nullptr;
WINDOW *clock = nullptr;
};
int main() {
Program program;
namespace ba = boost::asio;
using boost::system::error_code;
using namespace std::literals;
ba::io_service io;
std::function<void(error_code)> input_loop, clock_loop;
// Reading input when ready on stdin
ba::posix::stream_descriptor d(io, 0);
input_loop = [&](error_code ec) {
if (!ec) {
program.on_input();
d.async_wait(ba::posix::descriptor::wait_type::wait_read, input_loop);
}
};
// For fun, let's also update the time
ba::high_resolution_timer tim(io);
clock_loop = [&](error_code ec) {
if (!ec) {
program.on_clock();
tim.expires_from_now(100ms);
tim.async_wait(clock_loop);
}
};
input_loop(error_code{});
clock_loop(error_code{});
io.run_for(10s);
}
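The idea behind waiting with wait_read, namely being told that a descriptor is readable without reading from it, can also be illustrated without Boost. The following is a minimal sketch using POSIX poll(), assuming a POSIX system; the helper name readable is hypothetical and not part of any library.

```cpp
#include <cassert>
#include <poll.h>
#include <unistd.h>

// Returns true if fd has data waiting within timeout_ms milliseconds
// (0 = return immediately). Nothing is consumed: a later read() or
// getch() on the same descriptor still sees the pending bytes.
bool readable(int fd, int timeout_ms) {
    assert(fd >= 0);
    pollfd p{};
    p.fd = fd;
    p.events = POLLIN;
    int rc = poll(&p, 1, timeout_ms);
    return rc > 0 && (p.revents & POLLIN) != 0;
}
```

Calling readable(STDIN_FILENO, 0) before getch() would avoid a busy loop only if combined with a real timeout; the asio version above is preferable in an event loop because the wait is scheduled alongside the timers and sockets instead of blocking the thread.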

Fully-transparent exposure of Eigen::Vector/Matrix types using pybind11

I have a simple class definition:
class State {
private:
Eigen::Vector3f m_data;
public:
State(const Eigen::Vector3f& state) : m_data(state) { }
Eigen::Vector3f get() const { return m_data; }
void set(const Eigen::Vector3f& _state) { m_data = _state; }
std::string repr() const {
return "state data: [x=" + std::to_string(m_data[0]) + ", y=" + std::to_string(m_data[1]) + ", theta=" + std::to_string(m_data[2]) + "]";
}
};
I then expose the above in python with pybind11:
namespace py = pybind11;
PYBIND11_MODULE(bound_state, m) {
m.doc() = "python bindings for State";
py::class_<State>(m, "State")
.def(py::init<Eigen::Vector3f>())
.def("get", &State::get)
.def("set", &State::set)
.def("__repr__", &State::repr);
}
And everything works fine; I am able to import this module into python and construct a State instance with a numpy array. This isn't exactly what I want though. I want to be able to access this object as if it were a numpy array; I want to be able to do something like the following in python:
import numpy as np
import bound_state as bs
arr = np.array([1, 2, 3])
a = bs.State(arr)
print(a[0])
(the above throws a TypeError: 'bound_state.State' object does not support indexing)
In the past, I've used boost::python to expose lists by using add_property, and this allowed indexing of the underlying data in C++. Does pybind11 have something similar that can work with Eigen? Could someone provide an example showing how to expose a State instance that is indexable?
Per the API Docs, this can be done easily with the def_property method.
Turn this bit:
namespace py = pybind11;
PYBIND11_MODULE(bound_state, m) {
m.doc() = "python bindings for State";
py::class_<State>(m, "State")
.def(py::init<Eigen::Vector3f>())
.def("get", &State::get)
.def("set", &State::set)
.def("__repr__", &State::repr);
}
Into this:
namespace py = pybind11;
PYBIND11_MODULE(bound_state, m) {
m.doc() = "python bindings for State";
py::class_<State>(m, "State")
.def(py::init<Eigen::Vector3f>())
.def_property("m_data", &State::get, &State::set)
.def("__repr__", &State::repr);
}
Now, from the python-side, I can do:
import numpy as np
import bound_state as bs
arr = np.array([1, 2, 3])
a = bs.State(arr)
print(a.m_data[0])
This is not exactly what I want, but is a step in the right direction.

pybind11 equivalent of boost::python::extract?

I am considering porting a complex codebase from boost::python to pybind11, but I am puzzled by the absence of something like boost::python::extract<...>().check(). I read that pybind11::cast<T> can be used to extract a C++ object from a py::object, but the only way to check whether the cast is possible seems to be calling it and catching the exception when the cast fails. Is there something I am overlooking?
py::isinstance will do the job:
namespace py = pybind11;
py::object obj = ...
if (py::isinstance<py::array_t<double>>(obj))
{
....
}
else if (py::isinstance<py::str>(obj))
{
std::string val = obj.cast<std::string>();
std::cout << val << std::endl;
}
else if (py::isinstance<py::list>(obj))
{
...
}

Access non static function from static function

Here is some insight: I am working with UnityScript in Unity 4.6.3. I have one script called Pause.js and it contains this function:
function fadeMusicOut () {
while (audio.volume >= 0.005) {
yield WaitForSeconds(0.1);
Debug.Log("Loop Entered: " + audio.volume);
audio.volume = (audio.volume - 0.015);
}
}
Another script GameManager.js has this function:
static function Score (wallName : String) {
if (wallName == "rightWall") {
playerScore01 += 1;
}
else {
playerScore02 += 1;
}
if (playerScore01 == SettingsBack.scoreLimit || playerScore02 == SettingsBack.scoreLimit)
{
startParticles = 1;
SettingsBack.gameOver = 1;
BallControl.fadeSound = 1;
yield WaitForSeconds(4);
Camera.main.SendMessage("fadeOut");
Pause.fadeMusic = 1;
SettingsBack.soundVolume = 0;
yield WaitForSeconds(2);
playerScore01 = 0;
playerScore02 = 0;
SettingsBack.soundVolume = oldSoundVol;
Application.LoadLevel("_Menu");
}
}
So pretty much I want to call the fadeMusicOut() function from static function Score, but it will not let me because it says it needs an instance of that object.
The Pause.js script is not attached to any game objects, but it is attached to 2 buttons that call their specific functions. The GameManager.js script is attached to an object called GM. So how can I go about calling fadeMusicOut() from the Score function?
I have tried setting new vars that import the game object but still no luck. I tried making fadeMusicOut() a static function, but it creates many errors.
Any help at all is appreciated.