#!/usr/bin/env python3
"""Generate a C++ state machine skeleton from a Graphviz (DOT) description.

Every node that appears in a directed edge ``A -> B`` becomes a state, and the
edge becomes an allowed transition from ``A`` to ``B``.  Three files are
written to the current working directory:

* ``state_machine.h``   -- ``StateEnum``, the ``State`` base class, one
  ``State<Name>`` subclass per state and the ``StateMachine`` class;
* ``state_machine.cpp`` -- the implementation of ``State`` and ``StateMachine``;
* ``state_impl.cpp``    -- empty ``enter()``/``leave()`` stubs per state.
"""
import argparse
from string import Template

import pydot


def get_states(edges):
    """Build the transition table from a collection of graph edges.

    Args:
        edges: iterable of pydot edges (e.g. ``graph.get_edges()``).  For
            backward compatibility, a zero-argument callable returning such an
            iterable (e.g. the bound method ``graph.get_edges``) is accepted.

    Returns:
        dict mapping each state name that appears as an edge endpoint to the
        list of state names it may transition to, in edge order.  Keys are in
        order of first appearance; that order defines the numbering of the
        generated ``StateEnum``.
    """
    if callable(edges):
        edges = edges()
    states = {}
    for edge in edges:
        source = edge.get_source()
        dest = edge.get_destination()
        states.setdefault(source, []).append(dest)
        # Register the destination too, so states with no outgoing edges exist.
        states.setdefault(dest, [])
    return states


def generate_states_stub(state_names):
    """Write ``state_impl.cpp`` with empty ``enter()``/``leave()`` bodies.

    Args:
        state_names: iterable of state names.  The dict returned by
            :func:`get_states` works, since iterating it yields its keys.

    Note: the file is overwritten unconditionally, so any code previously
    added to a generated stub is lost.
    """
    state_spe_impl = Template('''
//
// State $name
//
void State$name::enter()
{
}

void State$name::leave()
{
}
''')
    with open('state_impl.cpp', 'w') as fd:
        fd.write('#include "state_machine.h"\n')
        for state_name in state_names:
            fd.write(state_spe_impl.substitute(name=state_name))


def generate_header(states):
    """Write ``state_machine.h`` declaring the enum, state classes and machine.

    Args:
        states: transition table as returned by :func:`get_states`.
    """
    state_names = list(states)
    nb_states = len(state_names)

    with open('state_machine.h', 'w') as fd:
        fd.write('#pragma once\n'
                 '\n'
                 '#include <array>\n'
                 '#include <memory>\n'
                 '#include <string>\n')

        # Enum: one enumerator per state, numbered in dict order.
        fd.write('\nenum StateEnum {\n    ')
        fd.write(',\n    '.join(state_names))
        fd.write('\n};\n')

        # Base class.  The destructor is virtual because states are owned
        # and deleted through std::unique_ptr<State>.
        state_template = Template('''
class State {
public:
    State() {};
    virtual ~State() {};

    bool canTransitionTo(const State& state) const;
    std::string get_name() const;
    void print() const;

    StateEnum value;

    virtual void enter() {};
    virtual void leave() {};

protected:
    std::array<bool, $nb_states> transitions;
};
''')
        fd.write(state_template.substitute(nb_states=nb_states))

        # One subclass per state, carrying its row of the transition table.
        state_spe_template = Template('''
class State$name: public State {
public:
    State$name()
    {
        value = $name;
        transitions = $transitions;
    };

    void enter() override;
    void leave() override;
};
''')
        for state_name in state_names:
            targets = set(states[state_name])
            transitions = ['true' if x in targets else 'false'
                           for x in state_names]
            str_transitions = '{ ' + ', '.join(transitions) + ' }'
            fd.write(state_spe_template.substitute(
                name=state_name, transitions=str_transitions))

        # State machine.  States are held through pointers so that the
        # State<Name> overrides of enter()/leave() are actually dispatched;
        # storing them by value in std::array<State, N> would slice them
        # down to the no-op base implementations.
        state_machine_template = Template('''
class StateMachine {
public:
    StateMachine(StateEnum initial_state):currentState(initial_state) {};
    ~StateMachine() {};

    bool transitionTo(StateEnum s);
    void print() const;

private:
    StateEnum currentState;
    std::array<std::unique_ptr<State>, $nb_states> states = $states_array;
};
''')
        states_class = ['std::unique_ptr<State>(new State' + x + '())'
                        for x in state_names]
        states_array = '{ ' + ', '.join(states_class) + ' }'
        fd.write(state_machine_template.substitute(
            nb_states=nb_states, states_array=states_array))


def generate_impl(states):
    """Write ``state_machine.cpp`` implementing ``State`` and ``StateMachine``.

    Args:
        states: transition table as returned by :func:`get_states`.
    """
    state_names = list(states)
    nb_states = len(state_names)

    state_machine_impl = Template('''#include "state_machine.h"

#include <array>
#include <iostream>
#include <string>

using namespace std;

bool State::canTransitionTo(const State& state) const
{
    return this->transitions[state.value];
}

std::string State::get_name() const
{
    const array<string, $nb_states> names = $states_list;
    return names[this->value];
}

void State::print() const
{
    cout << "State " << this->get_name() << " (" << this->value << ")\\n";
}

bool StateMachine::transitionTo(StateEnum s)
{
    State& fromState = *this->states[this->currentState];
    State& toState = *this->states[s];
    const bool allowed = fromState.canTransitionTo(toState);

    cout << "Transition from " << fromState.get_name()
         << " to " << toState.get_name() << ": ";
    if (allowed) {
        cout << "allowed." << "\\n";
    } else {
        cout << "denied." << "\\n";
    }

    if (allowed) {
        fromState.leave();
        toState.enter();
        this->currentState = s;
    }
    return allowed;
}

void StateMachine::print() const
{
    cout << "States:\\n";
    for (const auto& s : this->states) {
        s->print();
    }
}
''')
    # Quote each name individually so an empty machine yields "{  }" rather
    # than the one-element "{ "" }".
    states_list = '{ ' + ', '.join('"' + n + '"' for n in state_names) + ' }'
    with open('state_machine.cpp', 'w') as fd:
        fd.write(state_machine_impl.substitute(
            nb_states=nb_states, states_list=states_list))


def generate_cpp(states):
    """Write all three generated C++ files for the transition table *states*."""
    generate_header(states)
    generate_impl(states)
    generate_states_stub(states)


def main():
    """Parse the command line and generate code for each graph in the input."""
    parser = argparse.ArgumentParser(
        description='Generate code for state machine.')
    parser.add_argument('--input', '-i', type=argparse.FileType('r'),
                        required=True,
                        help='Graphviz file representing the state machine.')
    args = parser.parse_args()

    # FileType opens the file (or stdin for "-"); close it once it is read.
    with args.input as fd:
        text = fd.read()

    graphs = pydot.graph_from_dot_data(text)
    # NOTE(review): some pydot versions return None on a parse error instead
    # of raising; report it instead of crashing on iteration.
    if not graphs:
        parser.error('could not parse a graph from ' + args.input.name)

    for graph in graphs:
        states = get_states(graph.get_edges())
        print(states)
        # NOTE: every graph writes the same output files, so with several
        # graphs in the input only the last one's code is kept.
        generate_cpp(states)


if __name__ == "__main__":
    main()