state-machine/generator.py

210 lines
4.9 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
from string import Template
import pydot
def get_states(edges):
    """Build the state adjacency map from the edges of a graph.

    Args:
        edges: an iterable of edge objects exposing ``get_source()`` and
            ``get_destination()`` (e.g. pydot edges), or a zero-argument
            callable returning such an iterable, such as the bound method
            ``graph.get_edges``.

    Returns:
        dict mapping every state name that appears as a source or destination
        to the list of destinations reachable from it in one transition, in
        edge order. Keys are inserted in order of first appearance, which
        fixes the enum/array ordering of the generated C++ code.
    """
    # Accept the bound method as well as its result: callers have passed
    # ``graph.get_edges`` without calling it.
    if callable(edges):
        edges = edges()
    states = {}
    for edge in edges:
        source = edge.get_source()
        dest = edge.get_destination()
        states.setdefault(source, []).append(dest)
        # Register pure sinks too, so every state gets an enum value and class.
        states.setdefault(dest, [])
    return states
def generate_states_stub(state_names):
    """Write ``state_impl.cpp`` containing empty enter()/leave() bodies.

    The file starts with ``#include "state_machine.h"`` followed by one
    commented section per state defining ``State<name>::enter()`` and
    ``State<name>::leave()`` with empty bodies.

    Args:
        state_names: iterable of state names; when given a dict (as
            ``generate_cpp`` does) its keys are used.
    """
    stub = Template(
        "\n"
        "//\n"
        "// State $name\n"
        "//\n"
        "void State$name::enter()\n"
        "{\n"
        "}\n"
        "void State$name::leave()\n"
        "{\n"
        "}\n"
    )
    with open('state_impl.cpp', 'w') as out:
        out.write('#include "state_machine.h"\n')
        out.writelines(stub.substitute(name=name) for name in state_names)
# C++ snippets emitted into state_machine.h by generate_header().
_HEADER_PREAMBLE = '''#pragma once
#include <array>
#include <memory>
#include <string>
'''

_STATE_BASE_TEMPLATE = Template('''
class State {
public:
State() {};
virtual ~State() {};
bool canTransitionTo(const State& state) const;
std::string get_name() const;
void print() const;
StateEnum value;
virtual void enter() {};
virtual void leave() {};
protected:
std::array<bool, $nb_states> transitions;
private:
};
''')

_STATE_SPECIALIZATION_TEMPLATE = Template('''
class State$name: public State {
public:
State$name() {
value = $name;
transitions = $transitions;
};
void enter();
void leave();
};
''')

_STATE_MACHINE_TEMPLATE = Template('''
class StateMachine {
public:
StateMachine(StateEnum initial_state):currentState(initial_state) {};
~StateMachine() {};
bool transitionTo(StateEnum s);
void print() const;
private:
StateEnum currentState;
// Held through pointers: storing the states by value as State would slice
// them, and the StateX::enter()/leave() overrides would never be called.
std::array<std::shared_ptr<State>, $nb_states> states = $states_array;
};
''')

# C++ source emitted into state_machine.cpp by generate_impl().
_STATE_MACHINE_IMPL_TEMPLATE = Template('''#include "state_machine.h"
#include <iostream>
#include <array>
#include <string>
using namespace std;
bool State::canTransitionTo(const State& state) const
{
return this->transitions[state.value];
}
std::string State::get_name() const
{
const array<string, $nb_states> names = $states_list;
return names[this->value];
}
void State::print() const
{
cout << "State " << this->get_name() << " (" << this->value << ")\\n";
}
bool StateMachine::transitionTo(StateEnum s)
{
State& current = *this->states[this->currentState];
State& next = *this->states[s];
const bool allowed = current.canTransitionTo(next);
cout << "Transition from " << current.get_name()
<< " to " << next.get_name() << ": ";
if (allowed) {
cout << "allowed." << "\\n";
} else {
cout << "denied." << "\\n";
}
if (allowed) {
current.leave();
next.enter();
this->currentState = s;
}
return allowed;
}
void StateMachine::print() const
{
cout << "States:\\n";
for (const auto& s : this->states) {
s->print();
}
}
''')


def generate_header(states):
    """Write ``state_machine.h`` declaring the state enum and classes.

    The header contains ``StateEnum`` (one enumerator per state, in
    ``states`` order), the ``State`` base class with virtual
    ``enter``/``leave``, one ``State<name>`` subclass per state whose
    constructor sets its enum value and its row of the transition table, and
    ``StateMachine``, which owns one heap-allocated instance of every
    concrete state.

    Args:
        states: dict mapping each state name to the list of states it may
            transition to. Names are emitted verbatim as C++ enumerators and
            class-name suffixes, so they must be valid C++ identifiers.
    """
    state_names = list(states)
    nb_states = len(state_names)
    with open("state_machine.h", "w") as fd:
        fd.write(_HEADER_PREAMBLE)
        # Enumerators take the implicit values 0..n-1; the generated code
        # uses them as indices into `transitions`, `names` and `states`.
        fd.write('enum StateEnum {\n ')
        fd.write(',\n '.join(state_names))
        fd.write('\n};\n')
        fd.write(_STATE_BASE_TEMPLATE.substitute(nb_states=nb_states))
        for state_name in state_names:
            targets = set(states[state_name])
            # Entry i is true when state_names[i] is a valid destination.
            transitions = ['true' if x in targets else 'false' for x in state_names]
            fd.write(_STATE_SPECIALIZATION_TEMPLATE.substitute(
                name=state_name,
                transitions='{ ' + ', '.join(transitions) + ' }',
            ))
        states_array = '{ ' + ', '.join(
            'std::make_shared<State' + x + '>()' for x in state_names) + ' }'
        fd.write(_STATE_MACHINE_TEMPLATE.substitute(
            nb_states=nb_states, states_array=states_array))


def generate_impl(states):
    """Write ``state_machine.cpp`` implementing State and StateMachine.

    Args:
        states: dict keyed by state name; only the keys (and their order,
            which must match the one used for the header) are used here.
    """
    state_names = list(states)
    # Built per element so an empty state list yields "{  }" rather than the
    # invalid single-element initializer '{ "" }'.
    states_list = '{ ' + ', '.join('"' + name + '"' for name in state_names) + ' }'
    with open('state_machine.cpp', 'w') as fd:
        fd.write(_STATE_MACHINE_IMPL_TEMPLATE.substitute(
            nb_states=len(state_names), states_list=states_list))
def generate_cpp(states):
    """Emit every generated C++ file for the given state map.

    Runs, in order, the header generator, the implementation generator and
    the per-state stub generator, each receiving the same ``states`` dict.
    """
    for emit in (generate_header, generate_impl, generate_states_stub):
        emit(states)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Generate code for state machine.')
    # Required: without an input file there is nothing to generate (the
    # script previously crashed with an AttributeError on None).
    parser.add_argument('--input', '-i', required=True,
                        help='Graphviz file representing the state machine.')
    args = parser.parse_args()
    # Read inside a context manager so the handle is closed, and report an
    # unreadable path as a usage error instead of a traceback.
    try:
        with open(args.input) as fd:
            text = fd.read()
    except OSError as err:
        parser.error("cannot read %s: %s" % (args.input, err))
    graphs = pydot.graph_from_dot_data(text)
    # NOTE(review): pydot may return None (or an empty list) when the DOT
    # text cannot be parsed — guard instead of failing on iteration.
    if not graphs:
        parser.error("no graph could be parsed from %s" % args.input)
    # Every graph writes the same output files, so with several graphs in
    # one input only the code for the last graph is kept.
    for graph in graphs:
        states = get_states(graph.get_edges())
        print(states)
        generate_cpp(states)