AshanGimhana's picture
Upload folder using huggingface_hub
9375c9a verified
// Copyright (C) 2012 Davis E. King ([email protected])
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_MIN_CuT_Hh_
#define DLIB_MIN_CuT_Hh_
#include "min_cut_abstract.h"
#include "../matrix.h"
#include "general_flow_graph.h"
#include "../is_kind.h"
#include <iostream>
#include <fstream>
#include <deque>
// ----------------------------------------------------------------------------------------
namespace dlib
{
// Type used to store the per-node label that min_cut writes into a graph.
typedef unsigned char node_label;
// ----------------------------------------------------------------------------------------
// Label values.  min_cut first marks every node FREE_NODE, then marks the source
// node SOURCE_CUT and the sink node SINK_CUT; nodes claimed by the source-side
// search tree become SOURCE_CUT, those claimed by the sink-side tree become
// SINK_CUT, and nodes dropped from a tree are reset to FREE_NODE.
const node_label SOURCE_CUT = 0x00;
const node_label SINK_CUT   = 0xFE;
const node_label FREE_NODE  = 0xFF;
// ----------------------------------------------------------------------------------------
template <typename flow_graph>
typename disable_if<is_directed_graph<flow_graph>,typename flow_graph::edge_type>::type
graph_cut_score (
    const flow_graph& g
)
{
    /*!
        Returns the sum of g.get_flow(e) over every out-edge e that leaves a node
        labeled SOURCE_CUT and enters a node whose label is anything other than
        SOURCE_CUT.  Only participates in overload resolution for types that are
        not directed graphs (i.e. flow-graph style objects).
    !*/
    typedef typename flow_graph::edge_type edge_weight_type;
    typedef typename flow_graph::out_edge_iterator out_edge_iterator;

    edge_weight_type total = 0;
    const unsigned long num_nodes = g.number_of_nodes();
    for (unsigned long node = 0; node < num_nodes; ++node)
    {
        // Only edges that start on the source side of the cut can cross it.
        if (g.get_label(node) == SOURCE_CUT)
        {
            const out_edge_iterator stop = g.out_end(node);
            for (out_edge_iterator e = g.out_begin(node); e != stop; ++e)
            {
                if (g.get_label(g.node_id(e)) != SOURCE_CUT)
                    total += g.get_flow(e);
            }
        }
    }
    return total;
}
template <typename directed_graph>
typename enable_if<is_directed_graph<directed_graph>,typename directed_graph::edge_type>::type
graph_cut_score (
    const directed_graph& g
)
{
    /*!
        Directed-graph overload: wraps g in dlib::impl::general_flow_graph (over a
        const view of g) and returns the score computed by the flow-graph overload
        of graph_cut_score on that adapter.
    !*/
    const dlib::impl::general_flow_graph<const directed_graph> flow_view(g);
    return graph_cut_score(flow_view);
}
// ----------------------------------------------------------------------------------------
class min_cut
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            A function object that computes a source/sink minimum cut of a graph and
            records the result in the graph's node labels.  On return from
            operator(), each node is labeled SOURCE_CUT (claimed by the search tree
            rooted at the source), SINK_CUT (claimed by the tree rooted at the sink),
            or FREE_NODE (not in either tree).

            The implementation follows the grow / augment / adopt structure of the
            Boykov-Kolmogorov max-flow algorithm: two search trees are grown from the
            terminals, flow is pushed whenever the trees touch, and nodes cut off
            from their tree by saturated edges ("orphans") are either re-parented or
            freed.

            All member variables are mutable scratch space for that search.  They are
            reset at the start of every operator() call and do not contribute to the
            logical state of the object, which is why copying does not copy them.
    !*/
public:
    min_cut()
    {
    }

    min_cut( const min_cut& )
    {
        // Intentionally left empty since all the member variables
        // don't logically contribute to the state of this object.
        // This copy constructor is here to explicitly avoid the overhead
        // of copying these transient variables.
    }

    min_cut& operator= ( const min_cut& )
    {
        // Counterpart of the copy constructor above: the member variables are
        // transient scratch buffers that operator() resets before use, so
        // assignment deliberately copies none of them.
        return *this;
    }

    template <
        typename directed_graph
        >
    typename enable_if<is_directed_graph<directed_graph> >::type operator() (
        directed_graph& g,
        const unsigned long source_node,
        const unsigned long sink_node
    ) const
    {
        /*!
            Directed-graph entry point.  Asserts (in debug builds) that g has no
            self loops and has symmetric edges, then wraps g in
            dlib::impl::general_flow_graph and runs the flow-graph overload on it.
        !*/
        DLIB_ASSERT(graph_contains_length_one_cycle(g) == false,
            "\t void min_cut::operator()"
            << "\n\t Invalid arguments were given to this function."
        );
        DLIB_ASSERT(graph_has_symmetric_edges(g) == true,
            "\t void min_cut::operator()"
            << "\n\t Invalid arguments were given to this function."
        );

        dlib::impl::general_flow_graph<directed_graph> temp(g);
        (*this)(temp, source_node, sink_node);
    }

    template <
        typename flow_graph
        >
    typename disable_if<is_directed_graph<flow_graph> >::type operator() (
        flow_graph& g,
        const unsigned long source_node,
        const unsigned long sink_node
    ) const
    {
        /*!
            Flow-graph entry point.  Requires source_node != sink_node, both indices
            in range, non-negative capacities (get_flow(i,j) >= 0) on every edge, and
            a signed edge_type.  Modifies the flows stored in g while pushing
            augmenting flow, and writes the final cut into g's node labels.
        !*/
#ifdef ENABLE_ASSERTS
        DLIB_ASSERT(source_node != sink_node &&
                    source_node < g.number_of_nodes() &&
                    sink_node < g.number_of_nodes(),
            "\t void min_cut::operator()"
            << "\n\t Invalid arguments were given to this function."
            << "\n\t g.number_of_nodes(): " << g.number_of_nodes()
            << "\n\t source_node: " << source_node
            << "\n\t sink_node: " << sink_node
            << "\n\t this: " << this
        );

        // Every edge must start out with a non-negative capacity.
        for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
        {
            typename flow_graph::out_edge_iterator j, end = g.out_end(i);
            for (j = g.out_begin(i); j != end; ++j)
            {
                const unsigned long jj = g.node_id(j);
                DLIB_ASSERT(g.get_flow(i,jj) >= 0,
                    "\t void min_cut::operator()"
                    << "\n\t Invalid inputs were given to this function."
                    << "\n\t i: "<< i
                    << "\n\t jj: "<< jj
                    << "\n\t g.get_flow(i,jj): "<< g.get_flow(i,jj)
                    << "\n\t this: "<< this
                );
            }
        }
#endif

        parent.clear();
        active.clear();
        orphans.clear();

        typedef typename flow_graph::edge_type edge_type;
        COMPILE_TIME_ASSERT(is_signed_type<edge_type>::value);

        typedef typename flow_graph::out_edge_iterator out_edge_iterator;
        typedef typename flow_graph::in_edge_iterator in_edge_iterator;

        // initialize labels
        for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
            g.set_label(i, FREE_NODE);

        g.set_label(source_node, SOURCE_CUT);
        g.set_label(sink_node, SINK_CUT);

        // used to indicate "no parent"
        const unsigned long no_parent = g.number_of_nodes();

        parent.assign(g.number_of_nodes(), no_parent);

        time = 1;
        dist.assign(g.number_of_nodes(), 0);
        ts.assign(g.number_of_nodes(), time);

        // Both terminals start out active so both trees grow.
        active.push_back(source_node);
        active.push_back(sink_node);

        // Edge-scan cursors for the node at the front of the active queue.  They
        // are passed by reference to grow() so that a scan interrupted by finding
        // an augmenting path resumes where it stopped on the next call.
        in_edge_iterator in_begin = g.in_begin(active[0]);
        out_edge_iterator out_begin = g.out_begin(active[0]);

        unsigned long source_side, sink_side;
        while (grow(g,source_side,sink_side, in_begin, out_begin))
        {
            // New timestamp for this round; only the terminals are known to have
            // valid distances at this point.
            ++time;
            ts[source_node] = time;
            ts[sink_node] = time;

            augment(g, source_node, sink_node, source_side, sink_side);
            adopt(g, source_node, sink_node);
        }
    }

private:

    unsigned long distance_to_origin (
        const unsigned long no_parent,
        unsigned long p,
        unsigned long
    ) const
    {
        /*!
            Follows parent pointers upward from p.  If the walk reaches a node whose
            timestamp equals the current time (its dist value is treated as
            current), returns the number of steps taken plus that node's dist, and
            refreshes ts/dist for every node visited before it.  If the walk instead
            runs off the tree (reaches no_parent), returns
            std::numeric_limits<unsigned long>::max() to signal "no usable path".
            The third (unnamed) argument, the tree root, is not used.
        !*/
        unsigned long start = p;
        unsigned long count = 0;
        while (p != no_parent)
        {
            if (ts[p] == time)
            {
                count += dist[p];

                unsigned long count_down = count;
                // adjust the dist and ts for the nodes on this path.
                while (start != p)
                {
                    ts[start] = time;
                    dist[start] = count_down;
                    --count_down;
                    start = parent[start];
                }

                return count;
            }
            p = parent[p];
            ++count;
        }

        return std::numeric_limits<unsigned long>::max();
    }

    template <typename flow_graph>
    void adopt (
        flow_graph& g,
        const unsigned long source,
        const unsigned long sink
    ) const
    {
        /*!
            Processes the orphan list until it is empty.  For each orphan p, it
            looks for a new parent among same-tree neighbors connected to p by an
            edge with remaining capacity toward the tree (in-edges of p for the
            source tree, out-edges for the sink tree).  Among those with a usable
            path, it picks the one with the smallest distance_to_origin.

            If no such parent exists, p is labeled FREE_NODE, those same-tree
            neighbors with remaining capacity are re-queued as active, and p's
            children are turned into orphans themselves.
        !*/
        typedef typename flow_graph::out_edge_iterator out_edge_iterator;
        typedef typename flow_graph::in_edge_iterator in_edge_iterator;

        // used to indicate "no parent"
        const unsigned long no_parent = g.number_of_nodes();

        while (orphans.size() > 0)
        {
            const unsigned long p = orphans.back();
            orphans.pop_back();

            const unsigned char label_p = g.get_label(p);

            // Try to find a valid parent for p.
            if (label_p == SOURCE_CUT)
            {
                const in_edge_iterator begin(g.in_begin(p));
                const in_edge_iterator end(g.in_end(p));
                unsigned long best_dist = std::numeric_limits<unsigned long>::max();
                unsigned long best_node = 0;
                for(in_edge_iterator q = begin; q != end; ++q)
                {
                    const unsigned long id = g.node_id(q);

                    if (g.get_label(id) != label_p || g.get_flow(q) <= 0 )
                        continue;

                    unsigned long temp = distance_to_origin(no_parent, id,source);
                    if (temp < best_dist)
                    {
                        best_dist = temp;
                        best_node = id;
                    }
                }
                if (best_dist != std::numeric_limits<unsigned long>::max())
                {
                    parent[p] = best_node;
                    dist[p] = dist[best_node] + 1;
                    ts[p] = time;
                }

                // if we didn't find a parent for p
                if (parent[p] == no_parent)
                {
                    for(in_edge_iterator q = begin; q != end; ++q)
                    {
                        const unsigned long id = g.node_id(q);

                        if (g.get_label(id) != SOURCE_CUT)
                            continue;

                        // This neighbor could grow back into p, so rescan it.
                        if (g.get_flow(q) > 0)
                            active.push_back(id);

                        // Children of p lose their parent.
                        if (parent[id] == p)
                        {
                            parent[id] = no_parent;
                            orphans.push_back(id);
                        }
                    }
                    g.set_label(p, FREE_NODE);
                }
            }
            else
            {
                unsigned long best_node = 0;
                unsigned long best_dist = std::numeric_limits<unsigned long>::max();
                const out_edge_iterator begin(g.out_begin(p));
                const out_edge_iterator end(g.out_end(p));
                for(out_edge_iterator q = begin; q != end; ++q)
                {
                    const unsigned long id = g.node_id(q);

                    if (g.get_label(id) != label_p || g.get_flow(q) <= 0)
                        continue;

                    unsigned long temp = distance_to_origin(no_parent, id,sink);
                    if (temp < best_dist)
                    {
                        best_dist = temp;
                        best_node = id;
                    }
                }

                if (best_dist != std::numeric_limits<unsigned long>::max())
                {
                    parent[p] = best_node;
                    dist[p] = dist[best_node] + 1;
                    ts[p] = time;
                }

                // if we didn't find a parent for p
                if (parent[p] == no_parent)
                {
                    for(out_edge_iterator q = begin; q != end; ++q)
                    {
                        const unsigned long id = g.node_id(q);

                        if (g.get_label(id) != SINK_CUT)
                            continue;

                        // This neighbor could grow back into p, so rescan it.
                        if (g.get_flow(q) > 0)
                            active.push_back(id);

                        // Children of p lose their parent.
                        if (parent[id] == p)
                        {
                            parent[id] = no_parent;
                            orphans.push_back(id);
                        }
                    }
                    g.set_label(p, FREE_NODE);
                }
            }
        }
    }

    template <typename flow_graph>
    void augment (
        flow_graph& g,
        const unsigned long& source,
        const unsigned long& sink,
        const unsigned long& source_side,
        const unsigned long& sink_side
    ) const
    {
        /*!
            Pushes flow along the path source -> ... -> source_side -> sink_side ->
            ... -> sink, which is formed by the two tree paths plus the bridging
            edge.  The amount pushed is the smallest get_flow value found on the
            path.  Every tree edge whose get_flow drops to <= 0 is cut, and the
            child endpoint of that edge is appended to the orphan list.
        !*/
        typedef typename flow_graph::edge_type edge_type;

        // used to indicate "no parent"
        const unsigned long no_parent = g.number_of_nodes();

        unsigned long s = source_side;
        unsigned long t = sink_side;
        edge_type min_cap = g.get_flow(s,t);

        // find the bottleneck capacity on the current path.

        // check from source_side back to the source for the min capacity link.
        t = s;
        while (t != source)
        {
            s = parent[t];
            const edge_type temp = g.get_flow(s, t);
            if (temp < min_cap)
            {
                min_cap = temp;
            }
            t = s;
        }

        // check from sink_side back to the sink for the min capacity link
        s = sink_side;
        while (s != sink)
        {
            t = parent[s];
            const edge_type temp = g.get_flow(s, t);
            if (temp < min_cap)
            {
                min_cap = temp;
            }
            s = t;
        }

        // now push the max possible amount of flow though the path
        s = source_side;
        t = sink_side;
        g.adjust_flow(t,s, min_cap);

        // trace back towards the source
        t = s;
        while (t != source)
        {
            s = parent[t];
            g.adjust_flow(t,s, min_cap);
            if (g.get_flow(s,t) <= 0)
            {
                parent[t] = no_parent;
                orphans.push_back(t);
            }

            t = s;
        }

        // trace back towards the sink
        s = sink_side;
        while (s != sink)
        {
            t = parent[s];
            g.adjust_flow(t,s, min_cap);
            if (g.get_flow(s,t) <= 0)
            {
                parent[s] = no_parent;
                orphans.push_back(s);
            }
            s = t;
        }
    }

    template <typename flow_graph>
    bool grow (
        flow_graph& g,
        unsigned long& source_side,
        unsigned long& sink_side,
        typename flow_graph::in_edge_iterator& in_begin,
        typename flow_graph::out_edge_iterator& out_begin
    ) const
    /*!
        ensures
            - if (an augmenting path was found) then
                - returns true
                - (#source_side, #sink_side) == the point where the two trees meet.
                  #source_side is part of the source tree and #sink_side is part of
                  the sink tree.
            - else
                - returns false
            - in_begin / out_begin are scan cursors over the edges of active[0].
              Because the loops below iterate through references to them, a scan
              that returns early resumes from the same edge on the next call.
    !*/
    {
        typedef typename flow_graph::out_edge_iterator out_edge_iterator;
        typedef typename flow_graph::in_edge_iterator in_edge_iterator;

        while (active.size() != 0)
        {
            // pick an active node
            const unsigned long A = active[0];

            const unsigned char label_A = g.get_label(A);

            // process its neighbors
            if (label_A == SOURCE_CUT)
            {
                const out_edge_iterator out_end = g.out_end(A);
                for(out_edge_iterator& i = out_begin; i != out_end; ++i)
                {
                    if (g.get_flow(i) > 0)
                    {
                        const unsigned long id = g.node_id(i);
                        const unsigned char label_i = g.get_label(id);
                        if (label_i == FREE_NODE)
                        {
                            // Claim the free neighbor for the source tree.
                            active.push_back(id);
                            g.set_label(id,SOURCE_CUT);
                            parent[id] = A;
                            ts[id] = ts[A];
                            dist[id] = dist[A] + 1;
                        }
                        else if (label_A != label_i)
                        {
                            // The trees touch: augmenting path found.
                            source_side = A;
                            sink_side = id;
                            return true;
                        }
                        else if (is_closer(A, id))
                        {
                            // Same tree, but A offers a shorter route to the root.
                            parent[id] = A;
                            ts[id] = ts[A];
                            dist[id] = dist[A] + 1;
                        }
                    }
                }
            }
            else if (label_A == SINK_CUT)
            {
                const in_edge_iterator in_end = g.in_end(A);
                for(in_edge_iterator& i = in_begin; i != in_end; ++i)
                {
                    if (g.get_flow(i) > 0)
                    {
                        const unsigned long id = g.node_id(i);
                        const unsigned char label_i = g.get_label(id);
                        if (label_i == FREE_NODE)
                        {
                            // Claim the free neighbor for the sink tree.
                            active.push_back(id);
                            g.set_label(id,SINK_CUT);
                            parent[id] = A;
                            ts[id] = ts[A];
                            dist[id] = dist[A] + 1;
                        }
                        else if (label_A != label_i)
                        {
                            // The trees touch: augmenting path found.
                            sink_side = A;
                            source_side = id;
                            return true;
                        }
                        else if (is_closer(A, id))
                        {
                            // Same tree, but A offers a shorter route to the root.
                            parent[id] = A;
                            ts[id] = ts[A];
                            dist[id] = dist[A] + 1;
                        }
                    }
                }
            }
            // A FREE_NODE at the front (freed by adopt) is simply dropped.

            active.pop_front();
            if (active.size() != 0)
            {
                in_begin = g.in_begin(active[0]);
                out_begin = g.out_begin(active[0]);
            }
        }

        return false;
    }

    inline bool is_closer (
        unsigned long p,
        unsigned long q
    ) const
    {
        // return true if p is closer to a terminal than q
        return ts[q] <= ts[p] && dist[q] > dist[p];
    }

    // dist[i]: recorded number of tree edges between node i and its tree's root;
    //          only trusted when ts[i] == time.
    mutable std::vector<uint32> dist;
    // ts[i]: the value of time when dist[i] was last set.
    mutable std::vector<uint32> ts;
    // Current timestamp; starts at 1 and is incremented before each augmentation.
    mutable uint32 time;
    // parent[i]: parent of node i in its search tree, or number_of_nodes() for none.
    mutable std::vector<unsigned long> parent;
    // FIFO of nodes whose edges grow() still has to scan.
    mutable std::deque<unsigned long> active;
    // Nodes whose connection to their parent was cut; consumed by adopt().
    mutable std::vector<unsigned long> orphans;

};
// ----------------------------------------------------------------------------------------
}
// ----------------------------------------------------------------------------------------
#endif // DLIB_MIN_CuT_Hh_