#include <cctbx/boost_python/flex_fwd.h>
#include <cctbx/geometry_restraints/shared_wrapper_pickle.hpp>

#include <boost/python.hpp>

#include <boost/python/class.hpp>
#include <boost/python/args.hpp>
#include <boost/python/return_value_policy.hpp>
#include <boost/python/return_by_value.hpp>
#include <scitbx/array_family/boost_python/shared_wrapper.h>
#include <scitbx/array_family/selections.h>
#include <scitbx/stl/map_wrapper.h>
#include <scitbx/boost_python/is_polymorphic_workaround.h>
#include <cctbx/geometry_restraints/bond_misc.h>
#include <cctbx/geometry_restraints/proxy_select.h>

#include <stdexcept>

// Applies scitbx's is_polymorphic workaround to bond_asu_proxy (a class that
// is wrapped below with two bases, bond_params and asu_mapping_index_pair).
// NOTE(review): presumably this compensates for Boost.Python's
// polymorphism detection misbehaving for this type on some compilers;
// confirm against scitbx/boost_python/is_polymorphic_workaround.h.
SCITBX_BOOST_IS_POLYMORPHIC_WORKAROUND(
  cctbx::geometry_restraints::bond_asu_proxy)

namespace scitbx { namespace af { namespace boost_python {

  using cctbx::geometry_restraints::bond_simple_proxy;

  //! Placeholder element that shared_wrapper uses for
  //! shared_bond_simple_proxy (presumably when it needs a fill value, e.g.
  //! when resizing from Python).
  //!
  //! All seven constructor arguments are spelled out in the order bound to
  //! Python for bond_simple_proxy: i_seqs, distance_ideal, weight, slack,
  //! limit, top_out, origin_id. If only six are passed, `false` would
  //! presumably convert implicitly to the double `limit` (yielding 0.0
  //! instead of the documented -1.0 default), and `0` would convert to the
  //! bool `top_out`.
  template <>
  struct shared_wrapper_default_element<bond_simple_proxy>
  {
    static bond_simple_proxy
    get()
    {
      return bond_simple_proxy(
        af::tiny<unsigned, 2>(0, 0),
        /* distance_ideal */ 0.,
        /* weight */ 0.,
        /* slack */ 0.,
        /* limit */ -1.0,
        /* top_out */ false,
        /* origin_id */ 0);
    }
  };

}}}

namespace cctbx { namespace geometry_restraints {
namespace {

  struct bond_params_wrappers : boost::python::pickle_suite
  {
    typedef bond_params w_t;

    static boost::python::tuple
      getinitargs(w_t const& self)
    {
        return boost::python::make_tuple(self.distance_ideal,
          self.weight,
          self.slack,
          self.limit,
          self.top_out,
          self.origin_id
          );
    }

    static void
    wrap()
    {
      using namespace boost::python;
      class_<w_t>("bond_params", no_init)
        .def(init<
        double,
        double,
        double,
        double,
        bool,
        unsigned char >((
        arg("distance_ideal"),
        arg("weight"),
        arg("slack") = 0,
        arg("limit") = -1.0,
        arg("top_out") = false,
        arg("origin_id") = 0)))
        .def("scale_weight", &w_t::scale_weight, (arg("factor")))
        .def_readwrite("distance_ideal", &w_t::distance_ideal)
        .def_readwrite("weight", &w_t::weight)
        .def_readwrite("slack", &w_t::slack)
        .def_readwrite("limit", &w_t::limit)
        .def_readwrite("top_out", &w_t::top_out)
        .def_readwrite("origin_id", &w_t::origin_id)
        .def_pickle(bond_params_wrappers())
        ;
    }
  };

  struct bond_params_table_wrappers
  {
    static void
    update(
      bond_params_table& self,
      unsigned i_seq,
      unsigned j_seq,
      bond_params const& params)
    {
      if (i_seq <= j_seq) self[i_seq][j_seq] = params;
      else                self[j_seq][i_seq] = params;
    }

    static double
    mean_residual(
      af::const_ref<bond_params_dict> const& self,
      double bond_stretch_factor)
    {
      double sum = 0;
      unsigned n = 0;
      for(unsigned i_seq=0;i_seq<self.size();i_seq++) {
        for(bond_params_dict::const_iterator
              dict_i=self[i_seq].begin();
              dict_i!=self[i_seq].end();
              dict_i++) {
          bond_params const& params = dict_i->second;
          double delta = params.distance_ideal * bond_stretch_factor;
          sum += params.weight * delta * delta;
        }
        n += self[i_seq].size();
      }
      if (n == 0) return 0;
      return sum / static_cast<double>(n);
    }

    static void
    wrap()
    {
      using namespace boost::python;
      typedef return_internal_reference<> rir;
      scitbx::stl::boost_python::map_wrapper<bond_params_dict, rir>::wrap(
        "bond_params_dict");
      typedef scitbx::af::boost_python::shared_wrapper<bond_params_dict, rir> shared_w_t;
      shared_w_t::wrap(
        "bond_params_table")
        .def("update", update, (
          arg("self"),
          arg("i_seq"),
          arg("j_seq"),
          arg("params")))
        .def("mean_residual", mean_residual, (
          arg("self"),
          arg("bond_stretch_factor")))
        .def("proxy_select",
          (bond_params_table(*)(
            af::const_ref<bond_params_dict> const&,
            af::const_ref<std::size_t> const&))
              scitbx::af::array_of_map_proxy_select, (
          arg("iselection")))
        .def("proxy_remove",
          (bond_params_table(*)(
            af::const_ref<bond_params_dict> const&,
            af::const_ref<bool> const&))
              scitbx::af::array_of_map_proxy_remove, (
          arg("selection")))
        .def_pickle(shared_wrapper_pickle_suite< shared_w_t::w_t >())
      ;
    }
  };

  //! Python bindings for bond_simple_proxy (a bond_params plus two site
  //! indices and an rt_mx_ji member) and its array type
  //! shared_bond_simple_proxy.
  struct bond_simple_proxy_wrappers : boost::python::pickle_suite
  {
    typedef bond_simple_proxy w_t;

    //! Pickle support: arguments for the first (i_seqs, distance_ideal, ...)
    //! __init__ registered in wrap().
    //! NOTE(review): rt_mx_ji (exposed as a property below) is not part of
    //! this tuple, so a pickle round trip presumably drops any symmetry
    //! operation stored in it -- confirm whether that is intended.
    static boost::python::tuple
      getinitargs(w_t const& self)
    {
        return boost::python::make_tuple(self.i_seqs,
          self.distance_ideal,
          self.weight,
          self.slack,
          self.limit,
          self.top_out,
          self.origin_id
          );
    }

    //! Registers `bond_simple_proxy` (derived from `bond_params`) and
    //! `shared_bond_simple_proxy`.
    //! Boost.Python tries overloads of the same name newest-first, so the
    //! order of the __init__ definitions below is significant.
    static void
    wrap()
    {
      using namespace boost::python;
      typedef return_value_policy<return_by_value> rbv;
      class_<w_t, bases<bond_params> >("bond_simple_proxy", no_init)
        // Full constructor without a symmetry operation.
        .def(init<
          af::tiny<unsigned, 2> const&,
          double,
          double,
          double,
          double,
          bool,
          unsigned char >((
            arg("i_seqs"),
            arg("distance_ideal"),
            arg("weight"),
            arg("slack")=0,
            arg("limit")=-1.0,
            arg("top_out")=false,
            arg("origin_id")=0)))
        // NOTE(review): this 4-argument overload is registered after the
        // one above, so a call with four positional arguments
        // (i_seqs, x, y, z) is presumably dispatched here rather than to the
        // full constructor with slack=z. Which C++ constructor it maps to is
        // not visible in this file -- confirm the keyword names
        // (initial_eq_distance, eq_distance, weight) match it.
        .def(init<
          af::tiny<unsigned, 2> const&, double, double, double >((
            arg("i_seqs"),
            arg("initial_eq_distance"),
            arg("eq_distance"),
            arg("weight"))))
        // Full constructor including a symmetry operation rt_mx_ji.
        .def(init<
          af::tiny<unsigned, 2> const&,
          sgtbx::rt_mx const&,
          double,
          double,
          double,
          double,
          bool,
          unsigned char >((
            arg("i_seqs"),
            arg("rt_mx_ji"),
            arg("distance_ideal"),
            arg("weight"),
            arg("slack")=0,
            arg("limit")=-1.0,
            arg("top_out")=false,
            arg("origin_id")=0)))
        // Constructor taking an existing bond_params object.
        .def(init<
          af::tiny<unsigned, 2> const&,
          sgtbx::rt_mx const&,
          bond_params const&>((
            arg("i_seqs"),
            arg("rt_mx_ji"),
            arg("params"))))
        .def("sort_i_seqs", &w_t::sort_i_seqs)
        // Returned by value, so Python receives copies.
        .add_property("i_seqs", make_getter(&w_t::i_seqs, rbv()))
        .add_property("rt_mx_ji", make_getter(&w_t::rt_mx_ji, rbv()))
        .def_pickle(bond_simple_proxy_wrappers())
        ;
      {
        // Array type with element access by internal reference and a
        // proxy_select(origin_id) filter.
        typedef return_internal_reference<> rir;
        typedef scitbx::af::boost_python::shared_wrapper<bond_simple_proxy, rir> shared_w_t;
        shared_w_t::wrap(
          "shared_bond_simple_proxy")
          .def("proxy_select",
            (af::shared<w_t>(*)(
              af::const_ref<w_t> const&,
              unsigned char))
                shared_proxy_select_origin, (
            arg("origin_id")))
          .def_pickle(shared_wrapper_pickle_suite< shared_w_t::w_t >())
        ;
      }
    }
  };

  struct bond_sym_proxy_wrappers : boost::python::pickle_suite
  {
    typedef bond_sym_proxy w_t;

    static boost::python::tuple
      getinitargs(w_t const& self)
    {
        return boost::python::make_tuple(self.i_seqs,
          self.rt_mx_ji,
          self.distance_ideal,
          self.weight,
          self.slack,
          self.limit,
          self.top_out,
          self.origin_id
          );
    }

    static void
    wrap()
    {
      using namespace boost::python;
      typedef return_value_policy<return_by_value> rbv;
      class_<w_t, bases<bond_params> >("bond_sym_proxy", no_init)
        .def(init<
          af::tiny<unsigned, 2> const&,
          sgtbx::rt_mx const&,
          double,
          double,
          double,
          double,
          bool,
          unsigned char >((
            arg("i_seqs"),
            arg("rt_mx_ji"),
            arg("distance_ideal"),
            arg("weight"),
            arg("slack")=0,
            arg("limit")=-1.0,
            arg("top_out")=false,
            arg("origin_id")=0)))
        .add_property("i_seqs", make_getter(&w_t::i_seqs, rbv()))
        .def_readonly("rt_mx_ji", &w_t::rt_mx_ji)
        .def_pickle(bond_sym_proxy_wrappers())
        ;
    }
  };

  struct bond_asu_proxy_wrappers : boost::python::pickle_suite
  {
    typedef bond_asu_proxy w_t;

    static boost::python::tuple
      getinitargs(w_t const& self)
    {
        return boost::python::make_tuple(self.init_pair,
          self.distance_ideal,
          self.weight,
          self.slack,
          self.limit,
          self.top_out,
          self.origin_id
          );
    }

    static void
    wrap()
    {
      using namespace boost::python;
      class_<w_t, bases<bond_params, asu_mapping_index_pair> >(
            "bond_asu_proxy", no_init)
        .def(init<
          asu_mapping_index_pair const&,
          double,
          double,
          double,
          double,
          bool,
          unsigned char >((
            arg("pair"),
            arg("distance_ideal"),
            arg("weight"),
            arg("slack")=0,
            arg("limit")=-1.0,
            arg("top_out")=false,
            arg("origin_id")=0)))
        .def(init<asu_mapping_index_pair const&, bond_params const&>(
          (arg("pair"), arg("params"))))
        .def("as_simple_proxy", &w_t::as_simple_proxy)
        .def_pickle(bond_asu_proxy_wrappers())
        ;
      {
        typedef return_internal_reference<> rir;
        typedef scitbx::af::boost_python::shared_wrapper<bond_asu_proxy, rir> shared_w_t;
        shared_w_t::wrap(
          "shared_bond_asu_proxy")
          .def("proxy_select",
            (af::shared<w_t>(*)(
              af::const_ref<w_t> const&,
              unsigned char))
                shared_proxy_select_origin, (
            arg("origin_id")))
          .def_pickle(shared_wrapper_pickle_suite< shared_w_t::w_t >())
        ;
      }
    }
  };

  //! Python bindings for `bond`, which combines bond_params with
  //! cctbx::geometry::distance<double> and exposes the model distance,
  //! deltas, residual and gradients.
  struct bond_wrappers : boost::python::pickle_suite
  {
    typedef bond w_t;

    //! Pickle support: arguments for the first (sites-based) __init__.
    //! The derived quantities (distance_model, delta, delta_slack) are not
    //! stored; presumably the constructor recomputes them -- confirm.
    static boost::python::tuple
      getinitargs(w_t const& self)
    {
        return boost::python::make_tuple(self.sites,
          self.distance_ideal,
          self.weight,
          self.slack,
          self.limit,
          self.top_out,
          self.origin_id
          );
    }

    //! Registers the Python class `bond`.
    //! Boost.Python tries same-name overloads newest-first, so the order of
    //! the three-argument __init__ overloads below is significant.
    static void
    wrap()
    {
      using namespace boost::python;
      typedef return_value_policy<return_by_value> rbv;
      class_<w_t, bases<bond_params, cctbx::geometry::distance<double> > >(
        "bond", no_init)
        // From two explicit sites plus the bond_params fields.
        .def(init<
          af::tiny<scitbx::vec3<double>, 2> const&,
          double,
          double,
          double,
          double,
          bool,
          unsigned char >((
            arg("sites"),
            arg("distance_ideal"),
            arg("weight"),
            arg("slack")=0,
            arg("limit")=-1.0,
            arg("top_out")=false,
            arg("origin_id")=0)))
        // From a sites_cart array and a simple proxy.
        .def(init<af::const_ref<scitbx::vec3<double> > const&,
                  bond_simple_proxy const&>(
          (arg("sites_cart"), arg("proxy"))))
        // Same, with a unit cell.
        .def(init<uctbx::unit_cell const&,
                  af::const_ref<scitbx::vec3<double> > const&,
                  bond_simple_proxy const&>(
          (arg("unit_cell"), arg("sites_cart"), arg("proxy"))))
        // With a unit cell and a symmetry proxy.
        .def(init<uctbx::unit_cell const&,
                  af::const_ref<scitbx::vec3<double> > const&,
                  bond_sym_proxy const&>(
          (arg("unit_cell"), arg("sites_cart"), arg("proxy"))))
        // With asu_mappings and an asu proxy.
        .def(init<af::const_ref<scitbx::vec3<double> > const&,
                  asu_mappings const&,
                  bond_asu_proxy const&>(
          (arg("sites_cart"), arg("asu_mappings"), arg("proxy"))))
        .add_property("sites", make_getter(&w_t::sites, rbv()))
        .def_readonly("distance_model", &w_t::distance_model)
        .def_readonly("delta", &w_t::delta)
        .def_readonly("delta_slack", &w_t::delta_slack)
        // Re-exposes origin_id read-only on `bond`, shadowing the read-write
        // attribute inherited from the bond_params wrapper.
        .def_readonly("origin_id", &w_t::origin_id)
        .def("residual", &w_t::residual)
        .def("gradients", &w_t::gradients)
        .def_pickle(bond_wrappers())
        ;
    }
  };

  //! Registers all bond-related classes and free functions with Python.
  //!
  //! Several Python names (bond_distances_model, bond_deltas,
  //! bond_residuals, bond_residual_sum) are registered more than once. Each
  //! registration casts to a specific function-pointer type to select one
  //! C++ overload, and Boost.Python dispatches among them at call time from
  //! the Python argument types and count.
  void
  wrap_all()
  {
    using namespace boost::python;
    // Classes first; base classes (bond_params) must be registered before
    // classes declaring them in bases<>.
    bond_params_wrappers::wrap();
    bond_params_table_wrappers::wrap();
    bond_simple_proxy_wrappers::wrap();
    bond_sym_proxy_wrappers::wrap();
    bond_asu_proxy_wrappers::wrap();
    bond_wrappers::wrap();
    def("extract_bond_params", extract_bond_params, (
      (arg("n_seq"), arg("bond_simple_proxies"))));
    // --- sites_cart + array of simple proxies ---
    // Overloads with a trailing origin_id presumably restrict the
    // computation to proxies with that origin_id -- confirm in bond.h.
    def("bond_distances_model",
      (af::shared<double>(*)(
        af::const_ref<scitbx::vec3<double> > const&,
        af::const_ref<bond_simple_proxy> const&))
      bond_distances_model,
      (arg("sites_cart"), arg("proxies")));
    def("bond_deltas",
      (af::shared<double>(*)(
        af::const_ref<scitbx::vec3<double> > const&,
        af::const_ref<bond_simple_proxy> const&))
      bond_deltas,
      (arg("sites_cart"), arg("proxies")));
    def("bond_deltas",
      (af::shared<double>(*)(
        af::const_ref<scitbx::vec3<double> > const&,
        af::const_ref<bond_simple_proxy> const&,
        unsigned char))
      bond_deltas,
      (arg("sites_cart"), arg("proxies"), arg("origin_id")));
    def("bond_residuals",
      (af::shared<double>(*)(
        af::const_ref<scitbx::vec3<double> > const&,
        af::const_ref<bond_simple_proxy> const&))
      bond_residuals,
      (arg("sites_cart"), arg("proxies")));
    // (The parameter name inside this function-pointer type is ignored by
    // the compiler; the type is the same as with a bare `unsigned char`.)
    def("bond_residuals",
      (af::shared<double>(*)(
        af::const_ref<scitbx::vec3<double> > const&,
        af::const_ref<bond_simple_proxy> const&,
        unsigned char origin_id))
      bond_residuals,
      (arg("sites_cart"), arg("proxies"), arg("origin_id")));
    def("bond_residual_sum",
      (double(*)(
        af::const_ref<scitbx::vec3<double> > const&,
        af::const_ref<bond_simple_proxy> const&,
        af::ref<scitbx::vec3<double> > const&))
      bond_residual_sum,
      (arg("sites_cart"), arg("proxies"), arg("gradient_array")));
    // --- unit_cell + sites_cart + array of simple proxies ---
    def("bond_residual_sum",
      (double(*)(
        uctbx::unit_cell const&,
        af::const_ref<scitbx::vec3<double> > const&,
        af::const_ref<bond_simple_proxy> const&,
        af::ref<scitbx::vec3<double> > const&))
      bond_residual_sum, (
        arg("unit_cell"),
        arg("sites_cart"),
        arg("proxies"),
        arg("gradient_array")));
    def("bond_deltas",
      (af::shared<double>(*)(
        uctbx::unit_cell const&,
        af::const_ref<scitbx::vec3<double> > const&,
        af::const_ref<bond_simple_proxy> const&))
      bond_deltas,
      (arg("unit_cell"), arg("sites_cart"), arg("proxies")));
    def("bond_deltas",
      (af::shared<double>(*)(
        uctbx::unit_cell const&,
        af::const_ref<scitbx::vec3<double> > const&,
        af::const_ref<bond_simple_proxy> const&,
        unsigned char))
      bond_deltas,
      (arg("unit_cell"), arg("sites_cart"), arg("proxies"), arg("origin_id")));
    def("bond_residuals",
      (af::shared<double>(*)(
        uctbx::unit_cell const&,
        af::const_ref<scitbx::vec3<double> > const&,
        af::const_ref<bond_simple_proxy> const&))
      bond_residuals,
      (arg("unit_cell"), arg("sites_cart"), arg("proxies")));
    def("bond_residuals",
      (af::shared<double>(*)(
        uctbx::unit_cell const&,
        af::const_ref<scitbx::vec3<double> > const&,
        af::const_ref<bond_simple_proxy> const&,
        unsigned char))
      bond_residuals,
      (arg("unit_cell"), arg("sites_cart"), arg("proxies"), arg("origin_id")));
    // --- sites_cart + sorted asu proxies ---
    def("bond_distances_model",
      (af::shared<double>(*)(
        af::const_ref<scitbx::vec3<double> > const&,
        bond_sorted_asu_proxies_base const&))
      bond_distances_model,
      (arg("sites_cart"), arg("sorted_asu_proxies")));
    def("bond_deltas",
      (af::shared<double>(*)(
        af::const_ref<scitbx::vec3<double> > const&,
        bond_sorted_asu_proxies_base const&))
      bond_deltas,
      (arg("sites_cart"), arg("sorted_asu_proxies")));
    def("bond_deltas",
      (af::shared<double>(*)(
        af::const_ref<scitbx::vec3<double> > const&,
        bond_sorted_asu_proxies_base const&,
        unsigned char))
      bond_deltas,
      (arg("sites_cart"), arg("sorted_asu_proxies"), arg("origin_id")));
    def("bond_residuals",
      (af::shared<double>(*)(
        af::const_ref<scitbx::vec3<double> > const&,
        bond_sorted_asu_proxies_base const&))
      bond_residuals,
      (arg("sites_cart"), arg("sorted_asu_proxies")));
    def("bond_residuals",
      (af::shared<double>(*)(
        af::const_ref<scitbx::vec3<double> > const&,
        bond_sorted_asu_proxies_base const&,
        unsigned char))
      bond_residuals,
      (arg("sites_cart"), arg("sorted_asu_proxies"), arg("origin_id")));
    // disable_cache defaults to false on the Python side.
    def("bond_residual_sum",
      (double(*)(
        af::const_ref<scitbx::vec3<double> > const&,
        bond_sorted_asu_proxies_base const&,
        af::ref<scitbx::vec3<double> > const&,
        bool)) bond_residual_sum, (
          arg("sites_cart"),
          arg("sorted_asu_proxies"),
          arg("gradient_array"),
          arg("disable_cache")=false));

    // Home-site restraint summation; argument semantics are defined by the
    // C++ declaration outside this file.
    def("home_restraints_summation_skip_special_positions",
      home_restraints_summation_skip_special_positions, (
        arg("sites_cart"),
        arg("gradients"),
        arg("site_symmetry_table_indices"),
        arg("home_sites_cart"),
        arg("iselection"),
        arg("weight"),
        arg("slack")));
  }

} // namespace <anonymous>

namespace boost_python {

  //! Public entry point that registers every bond-related Python binding in
  //! this file by delegating to the file-local wrap_all().
  void
  wrap_bond()
  {
    wrap_all();
  }

}}} // namespace cctbx::geometry_restraints::boost_python
