SHOGUN  3.2.1
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
GraphCut.cpp
浏览该文件的文档.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2014 Jiaolong Xu
8  * Copyright (C) 2014 Jiaolong Xu
9  */
10 
12 #include <shogun/io/SGIO.h>
13 
14 using namespace shogun;
15 
17  : CMAPInferImpl()
18 {
19  SG_UNSTABLE("CGraphCut::CGraphCut()", "\n");
20 
21  m_nodes = NULL;
22  m_edges = NULL;
23 }
24 
26  : CMAPInferImpl(fg)
27 {
28  ASSERT(m_fg != NULL);
29 
30  m_nodes = NULL;
31  m_edges = NULL;
32 
33  init();
34 }
35 
36 CGraphCut::CGraphCut(int32_t num_nodes, int32_t num_edges)
37  : CMAPInferImpl()
38 {
39  m_nodes = NULL;
40  m_edges = NULL;
41 
42  m_num_nodes = num_nodes;
43  // build s-t graph
44  build_st_graph(m_num_nodes, num_edges);
45 }
46 
48 {
49  if (m_nodes!=NULL)
50  SG_FREE(m_nodes);
51 
52  if (m_edges!=NULL)
53  SG_FREE(m_edges);
54 }
55 
56 void CGraphCut::init()
57 {
59 
61 
62  for (int32_t i = 0; i < cards.size(); i++)
63  {
64  if (cards[i] != 2)
65  {
66  SG_ERROR("This implementation of the graph cut optimizer supports only binary variables.");
67  }
68  }
69 
70  m_num_factors_at_order = SGVector<int32_t> (4);
71  m_num_factors_at_order.zero();
72 
73  for (int32_t i = 0; i < facs->get_num_elements(); i++)
74  {
75  CFactor* fac = dynamic_cast<CFactor*>(facs->get_element(i));
76 
77  int32_t num_vars = fac->get_num_vars();
78 
79  SG_UNREF(fac);
80 
81  if (num_vars > 3)
82  {
83  SG_ERROR("This implementation of the graph cut optimizer supports only factors of order <= 3.");
84  }
85 
86  ++m_num_factors_at_order[num_vars];
87 
88  }
89 
90  m_num_variables = m_fg->get_num_vars();
91  int32_t max_num_edges = m_num_factors_at_order[2] + 3 * m_num_factors_at_order[3];
92  m_num_nodes = m_num_variables + m_num_factors_at_order[3];
93 
94  // build s-t graph
95  build_st_graph(m_num_nodes, max_num_edges);
96 
97  for (int32_t j = 0; j < m_fg->get_num_factors(); j++)
98  {
99  CFactor* fac = dynamic_cast<CFactor*>(facs->get_element(j));
100  add_factor(fac);
101  SG_UNREF(fac);
102  }
103 
104  SG_UNREF(facs);
105 }
106 
107 void CGraphCut::build_st_graph(int32_t num_nodes, int32_t num_edges)
108 {
109  m_num_nodes = num_nodes;
110 
111  // allocate s-t graph
112  m_nodes = SG_MALLOC(Node, m_num_nodes);
113  m_edges = SG_MALLOC(Edge, 2 * num_edges);
114  m_edges_last = m_edges;
115 
116  for (int32_t i = 0; i < m_num_nodes; i++)
117  {
118  m_nodes[i].id = i;
119  m_nodes[i].tree_cap = 0;
120  m_nodes[i].first = NULL;
121  }
122 
123  m_num_edges = 0; // m_num_edges will be counted in add_edge()
124  m_flow = 0;
125 }
126 
128 {
129  Node* node_i;
130 
131  m_active_first[0] = NULL;
132  m_active_last[0] = NULL;
133  m_active_first[1] = NULL;
134  m_active_last[1] = NULL;
135  m_orphan_first = NULL;
136  m_orphan_last = NULL;
137 
138  m_timestamp = 0;
139 
140  for (int32_t i = 0; i < m_num_nodes; i++)
141  {
142  node_i = m_nodes + i;
143  node_i->next = NULL;
144  node_i->timestamp = m_timestamp;
145 
146  if (node_i->tree_cap > 0)
147  {
148  // i is connected to the source
149  node_i->type_tree = SOURCE;
150  node_i->parent = TERMINAL_EDGE;
151  set_active(node_i);
152  node_i->dist_terminal = 1;
153  }
154  else if (node_i->tree_cap < 0)
155  {
156  // i is connected to the sink
157  node_i->type_tree = SINK;
158  node_i->parent = TERMINAL_EDGE;
159  set_active(node_i);
160  node_i->dist_terminal = 1;
161  }
162  else
163  {
164  node_i->parent = NULL;
165  }
166  }
167 }
168 
170 {
171  REQUIRE(assignment.size() == m_fg->get_cardinalities().size(),
172  "%s::inference(): the output assignment should be prepared as"
173  "the same size as variables!\n", get_name());
174 
175  // compute max flow
176  init_maxflow();
177  compute_maxflow();
178 
179  for (int32_t vi = 0; vi < assignment.size(); vi++)
180  {
181  assignment[vi] = get_assignment(vi) == SOURCE ? 0 : 1;
182  }
183 
184  m_map_energy = m_fg->evaluate_energy(assignment);
185  SG_DEBUG("fg.evaluate_energy(assignment) = %f\n", m_fg->evaluate_energy(assignment));
186  SG_DEBUG("minimized energy = %f\n", m_map_energy);
187 
188  return m_map_energy;
189 }
190 
191 void CGraphCut::add_factor(CFactor* factor)
192 {
193  SGVector<int32_t> fcards = factor->get_cardinalities();
194 
195  for (int32_t i = 0; i < fcards.size(); i++)
196  {
197  ASSERT(fcards[i] == 2);
198  }
199 
200  int32_t f_order = factor->get_num_vars();
201 
202  switch (f_order)
203  {
204  case 0:
205  break;
206  case 1:
207  {
208  SGVector<int32_t> fvars = factor->get_variables();
209  SGVector<float64_t> fenrgs = factor->get_energies();
210  ASSERT(fenrgs.size() == 2);
211  int32_t var = fvars[0];
212  float64_t v0 = fenrgs[0];
213  float64_t v1 = fenrgs[1];
214 
215  if (v0 < v1)
216  {
217  add_tweights(var, v1 - v0, 0);
218  }
219  else
220  {
221  add_tweights(var, 0, v0 - v1);
222  }
223  }
224  break;
225  case 2:
226  {
227  SGVector<int32_t> fvars = factor->get_variables();
228  SGVector<float64_t> fenrgs = factor->get_energies();
229  int32_t var0 = fvars[0];
230  int32_t var1 = fvars[1];
231  float64_t A = fenrgs[0]; //E{0,0} = {y_var0, y_var1}
232  float64_t B = fenrgs[2]; //E{0,1}
233  float64_t C = fenrgs[1]; //E{1,0}
234  float64_t D = fenrgs[3]; //E{1,1}
235 
236  // Added "truncation" code below to ensure regularity / submodularity
237  if (A + D > C + B)
238  {
239  SG_DEBUG("Truncation is applied to ensure regularity / submodularity.");
240 
241  float64_t delta = A + D - C - B;
242  float64_t subtrA = delta / 3;
243  A = A - subtrA;
244  C = C + subtrA;
245  B = B + (delta - subtrA * 2) + 0.0001; // for numeric issue
246  }
247 
248  // first variabe
249  if (C > A)
250  {
251  add_tweights(var0, C - A, 0);
252  }
253  else
254  {
255  add_tweights(var0, 0, A - C);
256  }
257  // second varibale
258  if (D > C)
259  {
260  add_tweights(var1, D - C, 0);
261  }
262  else
263  {
264  add_tweights(var1, 0, C - D);
265  }
266 
267  // submodular term
268  float64_t term = B + C - A - D;
269 
270  // term >= 0 is the regularity condition.
271  // It is the sufficient and necessary condition for any function to be graph-representable
272  if (term < 0)
273  {
274  SG_ERROR("\nRegularity condition is not satisfied\n");
275  }
276 
277  add_edge(var0, var1, term, 0);
278  }
279  break;
280  case 3:
281  {
282  SGVector<int32_t> fvars = factor->get_variables();
283  SGVector<float64_t> fenrgs = factor->get_energies();
284  int32_t var0 = fvars[0];
285  int32_t var1 = fvars[1];
286  int32_t var2 = fvars[2];
287  float64_t A = fenrgs[0]; //{0,0,0}
288  float64_t E = fenrgs[1]; //{1,0,0}
289  float64_t C = fenrgs[2]; //{0,1,0}
290  float64_t G = fenrgs[3]; //{1,1,0}
291  float64_t B = fenrgs[4]; //{0,0,1}
292  float64_t F = fenrgs[5]; //{1,0,1}
293  float64_t D = fenrgs[6]; //{0,1,1}
294  float64_t H = fenrgs[7]; //{1,1,1}
295 
296  int32_t id = get_tripleId(fvars);
297  float64_t P = (A + D + F + G) - (B + C + E + H);
298 
299  if (P >= 0.0)
300  {
301  if (F - B >= 0)
302  {
303  add_tweights(var0, F - B, 0);
304  }
305  else
306  {
307  add_tweights(var0, 0, B - F);
308  }
309 
310  if (G - E >= 0)
311  {
312  add_tweights(var1, G - E, 0);
313  }
314  else
315  {
316  add_tweights(var1, 0, E - G);
317  }
318 
319  if (D - C >= 0)
320  {
321  add_tweights(var2, D - C, 0);
322  }
323  else
324  {
325  add_tweights(var2, 0, C - D);
326  }
327 
328  add_edge(var1, var2, B + C - A - D, 0);
329  add_edge(var2, var0, B + E - A - F, 0);
330  add_edge(var0, var1, C + E - A - G, 0);
331 
332  add_edge(var0, id, P, 0);
333  add_edge(var1, id, P, 0);
334  add_edge(var2, id, P, 0);
335  add_edge(id, 1, P, 0);
336  }
337  else
338  {
339  if (C - G >= 0)
340  {
341  add_tweights(var0, 0, C - G);
342  }
343  else
344  {
345  add_tweights(var0, G - C, 0);
346  }
347 
348  if (B - D >= 0)
349  {
350  add_tweights(var1, 0, B - D);
351  }
352  else
353  {
354  add_tweights(var1, D - B, 0);
355  }
356 
357  if (E - F >= 0)
358  {
359  add_tweights(var2, 0, E - F);
360  }
361  else
362  {
363  add_tweights(var2, F - E, 0);
364  }
365 
366  add_edge(var2, var1, F + G - E - H, 0);
367  add_edge(var0, var2, D + G - C - H, 0);
368  add_edge(var1, var0, D + F - B - H, 0);
369 
370  add_edge(id, var0, -P, 0);
371  add_edge(id, var1, -P, 0);
372  add_edge(id, var2, -P, 0);
373  add_tweights(id, -P, 0);
374  }
375  }
376  break;
377  default:
378  SG_ERROR("This implementation of the graph cut optimizer does not support factors of order > 3.");
379  break;
380  }
381 }
382 
383 int32_t CGraphCut::get_tripleId(SGVector<int32_t> triple)
384 {
385  // search for triple in list
386  int32_t counter = m_num_variables;
387 
388  for (int32_t i = 0; i < m_triple_list.get_num_elements(); i++)
389  {
390  SGVector<int32_t> vec = m_triple_list[i];
391 
392  if (triple[0] == vec[0] && triple[1] == vec[1] && triple[2] == vec[2])
393  {
394  return counter;
395  }
396 
397  m_num_variables++;
398  }
399  // add triple to list
400  m_triple_list.push_back(triple);
401 
402  ASSERT(counter - m_num_variables < m_num_factors_at_order[3]);
403 
404  return counter;
405 }
406 
407 void CGraphCut::add_tweights(int32_t i, float64_t cap_source, float64_t cap_sink)
408 {
409  ASSERT(i >= 0 && i < m_num_nodes);
410 
411  float64_t delta = m_nodes[i].tree_cap;
412 
413  if (delta > 0)
414  {
415  cap_source += delta;
416  }
417  else
418  {
419  cap_sink -= delta;
420  }
421 
422  m_flow += (cap_source < cap_sink) ? cap_source : cap_sink;
423 
424  m_nodes[i].tree_cap = cap_source - cap_sink;
425 }
426 
427 void CGraphCut::add_edge(int32_t i, int32_t j, float64_t capacity, float64_t reverse_capacity)
428 {
429  ASSERT(i >= 0 && i < m_num_nodes);
430  ASSERT(j >= 0 && j < m_num_nodes);
431  ASSERT(i != j);
432  ASSERT(capacity >= 0);
433  ASSERT(reverse_capacity >= 0);
434 
435  Edge* e = m_edges_last++;
436  e->id = m_num_edges++;
437  Edge* e_rev = m_edges_last++;
438  e_rev->id = m_num_edges++;
439 
440  Node* node_i = m_nodes + i;
441  Node* node_j = m_nodes + j;
442 
443  e->reverse = e_rev;
444  e_rev->reverse = e;
445  e->next = node_i->first;
446  node_i->first = e;
447  e_rev->next = node_j->first;
448  node_j->first = e_rev;
449  e->head = node_j;
450  e_rev->head = node_i;
451  e->residual_capacity = capacity;
452  e_rev->residual_capacity = reverse_capacity;
453 }
454 
455 void CGraphCut::set_active(Node* node_i)
456 {
457  if (node_i->next == NULL)
458  {
459  // it's not in the list yet
460  if (m_active_last[1])
461  {
462  m_active_last[1]->next = node_i;
463  }
464  else
465  {
466  m_active_first[1] = node_i;
467  }
468 
469  m_active_last[1] = node_i;
470  node_i->next = node_i;
471  }
472 }
473 
474 Node* CGraphCut::next_active()
475 {
476  // Returns the next active node. If it is connected to the sink,
477  // it stays in the list, otherwise it is removed from the list.
478  Node* node_i;
479 
480  while (true)
481  {
482  if ((node_i = m_active_first[0]) == NULL)
483  {
484  m_active_first[0] = node_i = m_active_first[1];
485  m_active_last[0] = m_active_last[1];
486  m_active_first[1] = NULL;
487  m_active_last[1] = NULL;
488 
489  if (node_i == NULL)
490  {
491  return NULL;
492  }
493  }
494 
495  // remove it from the active list
496  if (node_i->next == node_i)
497  {
498  m_active_first[0] = NULL;
499  m_active_last[0] = NULL;
500  }
501  else
502  {
503  m_active_first[0] = node_i->next;
504  }
505 
506  node_i->next = NULL;
507 
508  // a node in the list is active iff it has a parent
509  if (node_i->parent != NULL)
510  {
511  return node_i;
512  }
513  }
514 }
515 
517 {
518  Node* current_node = NULL;
519  bool active_set_found = true;
520 
521  // start the main loop
522  while (true)
523  {
524  if (sg_io->get_loglevel() == MSG_DEBUG)
525  test_consistency(current_node);
526 
527  Edge* connecting_edge;
528 
529  // find a path from source to sink
530  active_set_found = grow(connecting_edge, current_node);
531 
532  if (!active_set_found)
533  {
534  break;
535  }
536 
537  if (connecting_edge == NULL)
538  {
539  continue;
540  }
541 
542  m_timestamp++;
543 
544  // augment that path
545  augment_path(connecting_edge);
546 
547  // adopt orphans, rebuild the search tree structure
548  adopt();
549  }
550 
551  if (sg_io->get_loglevel() == MSG_DEBUG)
552  test_consistency();
553 
554  return m_flow;
555 }
556 
557 bool CGraphCut::grow(Edge* &edge, Node* &current_node)
558 {
559  Node* node_i, *node_j;
560 
561  if ((node_i = current_node) != NULL)
562  {
563  node_i->next = NULL; // remove active flag
564 
565  if (node_i->parent == NULL)
566  {
567  node_i = NULL;
568  }
569  }
570 
571  if (node_i == NULL && (node_i = next_active()) == NULL)
572  {
573  return false;
574  }
575 
576  if (node_i->type_tree == SOURCE)
577  {
578  // grow source tree
579  for (edge = node_i->first; edge != NULL; edge = edge->next)
580  {
581  if (edge->residual_capacity)
582  {
583  node_j = edge->head;
584 
585  if (node_j->parent == NULL)
586  {
587  node_j->type_tree = SOURCE;
588  node_j->parent = edge->reverse;
589  node_j->timestamp = node_i->timestamp;
590  node_j->dist_terminal = node_i->dist_terminal + 1;
591  set_active(node_j);
592  }
593  else if (node_j->type_tree == SINK)
594  {
595  break;
596  }
597  else if (node_j->timestamp <= node_i->timestamp && node_j->dist_terminal > node_i->dist_terminal)
598  {
599  // heuristic - trying to make the distance from j to the source shorter
600  node_j->parent = edge->reverse;
601  node_j->timestamp = node_i->timestamp;
602  node_j->dist_terminal = node_i->dist_terminal + 1;
603  }
604  }
605  }
606  }
607  else
608  {
609  // grow sink tree
610  for (edge = node_i->first; edge != NULL; edge = edge->next)
611  {
612  if (edge->reverse->residual_capacity)
613  {
614  node_j = edge->head;
615 
616  if (node_j->parent == NULL)
617  {
618  node_j->type_tree = SINK;
619  node_j->parent = edge->reverse;
620  node_j->timestamp = node_i->timestamp;
621  node_j->dist_terminal = node_i->dist_terminal + 1;
622  set_active(node_j);
623  }
624  else if (node_j->type_tree == SOURCE)
625  {
626  edge = edge->reverse;
627  break;
628  }
629  else if (node_j->timestamp <= node_i->timestamp && node_j->dist_terminal > node_i->dist_terminal)
630  {
631  // heuristic - trying to make the distance from j to the sink shorter
632  node_j->parent = edge->reverse;
633  node_j->timestamp = node_i->timestamp;
634  node_j->dist_terminal = node_i->dist_terminal + 1;
635  }
636  }
637  }
638  } // grow sink tree
639 
640  if (edge != NULL)
641  {
642  node_i->next = node_i; // set active flag
643  current_node = node_i;
644  }
645  else
646  {
647  current_node = NULL;
648  }
649 
650  return true;
651 }
652 
653 void CGraphCut::augment_path(Edge* connecting_edge)
654 {
655  Node* node_i;
656  Edge* edge;
657  float64_t bottleneck;
658 
659  // 1. Finding bottleneck capacity
660  // 1a the source tree
661  bottleneck = connecting_edge->residual_capacity;
662 
663  for (node_i = connecting_edge->reverse->head; ; node_i = edge->head)
664  {
665  edge = node_i->parent;
666 
667  if (edge == TERMINAL_EDGE)
668  {
669  break;
670  }
671 
672  if (bottleneck > edge->reverse->residual_capacity)
673  {
674  bottleneck = edge->reverse->residual_capacity;
675  }
676  }
677 
678  if (bottleneck > node_i->tree_cap)
679  {
680  bottleneck = node_i->tree_cap;
681  }
682 
683  // 1b the sink tree
684  for (node_i = connecting_edge->head; ; node_i = edge->head)
685  {
686  edge = node_i->parent;
687 
688  if (edge == TERMINAL_EDGE)
689  {
690  break;
691  }
692 
693  if (bottleneck > edge->residual_capacity)
694  {
695  bottleneck = edge->residual_capacity;
696  }
697  }
698 
699  if (bottleneck > - node_i->tree_cap)
700  {
701  bottleneck = - node_i->tree_cap;
702  }
703 
704 
705  // 2. Augmenting
706  // 2a the source tree
707  connecting_edge->reverse->residual_capacity += bottleneck;
708  connecting_edge->residual_capacity -= bottleneck;
709 
710  for (node_i = connecting_edge->reverse->head; ; node_i = edge->head)
711  {
712  edge = node_i->parent;
713 
714  if (edge == TERMINAL_EDGE)
715  {
716  break;
717  }
718 
719  edge->residual_capacity += bottleneck;
720  edge->reverse->residual_capacity -= bottleneck;
721 
722  if (edge->reverse->residual_capacity == 0)
723  {
724  set_orphan_front(node_i); // add node_i to the beginning of the adoptation list
725  }
726  }
727 
728  node_i->tree_cap -= bottleneck;
729 
730  if (node_i->tree_cap == 0)
731  {
732  set_orphan_front(node_i); // add node_i to the beginning of the adoptation list
733  }
734 
735  // 2b the sink tree
736  for (node_i = connecting_edge->head; ; node_i = edge->head)
737  {
738  edge = node_i->parent;
739 
740  if (edge == TERMINAL_EDGE)
741  {
742  break;
743  }
744 
745  edge->reverse->residual_capacity += bottleneck;
746  edge->residual_capacity -= bottleneck;
747 
748  if (edge->residual_capacity == 0)
749  {
750  set_orphan_front(node_i);
751  }
752  }
753 
754  node_i->tree_cap += bottleneck;
755 
756  if (node_i->tree_cap == 0)
757  {
758  set_orphan_front(node_i);
759  }
760 
761  m_flow += bottleneck;
762 }
763 
764 void CGraphCut::adopt()
765 {
766  NodePtr* np, *np_next;
767  Node* node_i;
768 
769  while ((np = m_orphan_first) != NULL)
770  {
771  np_next = np->next;
772  np->next = NULL;
773 
774  while ((np = m_orphan_first) != NULL)
775  {
776  m_orphan_first = np->next;
777  node_i = np->ptr;
778  SG_FREE(np);
779 
780  if (m_orphan_first == NULL)
781  {
782  m_orphan_last = NULL;
783  }
784 
785  process_orphan(node_i, node_i->type_tree);
786  }
787 
788  m_orphan_first = np_next;
789  }
790 }
791 
792 void CGraphCut::set_orphan_front(Node* node_i)
793 {
794  NodePtr* np;
795  node_i->parent = ORPHAN_EDGE;
796  np = SG_MALLOC(NodePtr, 1);
797  np->ptr = node_i;
798  np->next = m_orphan_first;
799  m_orphan_first = np;
800 }
801 
802 void CGraphCut::set_orphan_rear(Node* node_i)
803 {
804  NodePtr* np;
805  node_i->parent = ORPHAN_EDGE;
806  np = SG_MALLOC(NodePtr, 1);
807  np->ptr = node_i;
808 
809  if (m_orphan_last != NULL)
810  {
811  m_orphan_last->next = np;
812  }
813  else
814  {
815  m_orphan_first = np;
816  }
817 
818  m_orphan_last = np;
819  np->next = NULL;
820 }
821 
822 void CGraphCut::process_orphan(Node* node_i, ETerminalType terminalType_tree)
823 {
824  Node* node_j;
825  Edge* edge0;
826  Edge* edge0_min = NULL;
827  Edge* edge;
828  int32_t d;
829  int32_t d_min = INFINITE_D;
830 
831  // trying to find a new parent
832  for (edge0 = node_i->first; edge0 != NULL; edge0 = edge0->next)
833  {
834  if ((terminalType_tree == SOURCE && edge0->reverse->residual_capacity) ||
835  (terminalType_tree == SINK && edge0->residual_capacity))
836  {
837  node_j = edge0->head;
838 
839  if (node_j->type_tree == terminalType_tree && (edge = node_j->parent) != NULL)
840  {
841  // check the origin of node_j
842  d = 0;
843  while (1)
844  {
845  if (node_j->timestamp == m_timestamp)
846  {
847  d += node_j->dist_terminal;
848  break;
849  }
850 
851  edge = node_j->parent;
852  d++;
853 
854  if (edge == TERMINAL_EDGE)
855  {
856  node_j->timestamp = m_timestamp;
857  node_j->dist_terminal = 1;
858  break;
859  }
860 
861  if (edge == ORPHAN_EDGE)
862  {
863  d = INFINITE_D;
864  break;
865  }
866 
867  node_j = edge->head;
868  } // while
869 
870  if (d < INFINITE_D) // node_j originates from the source, done
871  {
872  if (d < d_min)
873  {
874  edge0_min = edge0;
875  d_min = d;
876  }
877  // set marks along the path
878  for (node_j = edge0->head; node_j->timestamp != m_timestamp; node_j = node_j->parent->head)
879  {
880  node_j->timestamp = m_timestamp;
881  node_j->dist_terminal = d--;
882  }
883  }
884 
885  } // if node_j->type_tree
886  } // if(edge0->reverse->residual_capacity)
887  } // for edge0 = node_i->first
888 
889  if ((node_i->parent = edge0_min) != NULL)
890  {
891  node_i->timestamp = m_timestamp;
892  node_i->dist_terminal = d_min + 1;
893  }
894  else
895  {
896  // no parent is found, process neighbors
897  for (edge0 = node_i->first; edge0 != NULL; edge0 = edge0->next)
898  {
899  node_j = edge0->head;
900 
901  if (node_j->type_tree == terminalType_tree && (edge = node_j->parent) != NULL)
902  {
903  bool is_active_source = (terminalType_tree == SOURCE && edge0->reverse->residual_capacity);
904  bool is_active_sink = (terminalType_tree == SINK && edge0->residual_capacity);
905 
906  if (is_active_source || is_active_sink)
907  {
908  set_active(node_j);
909  }
910 
911  if (edge != TERMINAL_EDGE && edge != ORPHAN_EDGE && edge->head == node_i)
912  {
913  set_orphan_rear(node_j); // add node_j to the end of the adoptation list
914  }
915  }
916  } // for edge0 = node_i->first
917  }
918 }
919 
921 {
922  if (m_nodes[i].parent != NULL)
923  {
924  return m_nodes[i].type_tree;
925  }
926  else
927  {
928  return default_terminal;
929  }
930 }
931 
933 {
934  // print SOURCE-node_i and node_i->SINK edges
935  for (int32_t i = 0; i < m_num_nodes; i++)
936  {
937  Node* node_i = m_nodes + i;
938  if (node_i->parent == TERMINAL_EDGE)
939  {
940  if (node_i->type_tree == SOURCE)
941  {
942  SG_SPRINT("\n s -> %d, cost = %f", node_i->id, node_i->tree_cap);
943  }
944  else
945  {
946  SG_SPRINT("\n %d -> t, cost = %f", node_i->id, node_i->tree_cap);
947  }
948  }
949  }
950 
951  // print node_i->node_j edges
952  for (int32_t i = 0; i < m_num_edges; i++)
953  {
954  Edge* edge = m_edges + i;
955  SG_SPRINT("\n %d -> %d, cost = %f", edge->reverse->head->id, edge->head->id, edge->residual_capacity);
956  }
957 
958 }
959 
961 {
962  for (int32_t i = 0; i < m_num_nodes; i++)
963  {
964  Node* node_i = m_nodes + i;
965 
966  if (get_assignment(i) == SOURCE)
967  {
968  SG_SPRINT("\nNode %2d: S", node_i->id);
969  }
970  else
971  {
972  SG_SPRINT("\nNode %2d: T", node_i->id);
973  }
974  }
975 }
976 
977 void CGraphCut::test_consistency(Node* current_node)
978 {
979  Node* node_i;
980  Edge* edge;
981  int32_t num1 = 0;
982  int32_t num2 = 0;
983 
984  // test whether all nodes i with i->next!=NULL are indeed in the queue
985  for (int32_t i = 0; i < m_num_nodes; i++)
986  {
987  node_i = m_nodes + i;
988  if (node_i->next || node_i == current_node)
989  {
990  num1++;
991  }
992  }
993 
994  for (int32_t r = 0; r < 3; r++)
995  {
996  node_i = (r == 2) ? current_node : m_active_first[r];
997 
998  if (node_i)
999  {
1000  for (; ; node_i = node_i->next)
1001  {
1002  num2++;
1003 
1004  if (node_i->next == node_i)
1005  {
1006  if (r < 2)
1007  ASSERT(node_i == m_active_last[r])
1008  else
1009  ASSERT(node_i == current_node)
1010 
1011  break;
1012  }
1013  }
1014  }
1015  }
1016 
1017  ASSERT(num1 == num2);
1018 
1019  for (int32_t i = 0; i < m_num_nodes; i++)
1020  {
1021  node_i = m_nodes + i;
1022 
1023  // test whether all edges in seach trees are non-saturated
1024  if (node_i->parent == NULL) {}
1025  else if (node_i->parent == ORPHAN_EDGE) {}
1026  else if (node_i->parent == TERMINAL_EDGE)
1027  {
1028  if (node_i->type_tree == SOURCE)
1029  ASSERT(node_i->tree_cap > 0)
1030  else
1031  ASSERT(node_i->tree_cap < 0)
1032  }
1033  else
1034  {
1035  if (node_i->type_tree == SOURCE)
1036  ASSERT(node_i->parent->reverse->residual_capacity > 0)
1037  else
1038  ASSERT(node_i->parent->residual_capacity > 0)
1039  }
1040 
1041  // test whether passive nodes in search trees have neighbors in
1042  // a different tree through non-saturated edges
1043  if (node_i->parent && !node_i->next)
1044  {
1045  if (node_i->type_tree == SOURCE)
1046  {
1047  ASSERT(node_i->tree_cap >= 0);
1048 
1049  for (edge = node_i->first; edge; edge = edge->next)
1050  {
1051  if (edge->residual_capacity > 0)
1052  {
1053  ASSERT(edge->head->parent && edge->head->type_tree == SOURCE);
1054  }
1055  }
1056  }
1057  else
1058  {
1059  ASSERT(node_i->tree_cap <= 0);
1060 
1061  for (edge = node_i->first; edge; edge = edge->next)
1062  {
1063  if (edge->reverse->residual_capacity > 0)
1064  {
1065  ASSERT(edge->head->parent && (edge->head->type_tree == SINK));
1066  }
1067  }
1068  }
1069  }
1070  // test marking invariants
1071  if (node_i->parent && node_i->parent != ORPHAN_EDGE && node_i->parent != TERMINAL_EDGE)
1072  {
1073  ASSERT(node_i->timestamp <= node_i->parent->head->timestamp);
1074 
1075  if (node_i->timestamp == node_i->parent->head->timestamp)
1076  {
1077  ASSERT(node_i->dist_terminal > node_i->parent->head->dist_terminal);
1078  }
1079  }
1080  }
1081 }
virtual const char * get_name() const
Definition: GraphCut.h:103
float64_t residual_capacity
Definition: GraphCut.h:44
float64_t evaluate_energy(const SGVector< int32_t > state) const
static float64_t * H
Definition: libbmrm.cpp:26
#define INFINITE_D
Definition: GraphCut.h:27
#define ORPHAN_EDGE
Definition: GraphCut.h:25
const SGVector< int32_t > get_variables() const
Definition: Factor.cpp:107
int32_t id
Definition: GraphCut.h:49
int32_t get_num_factors() const
void print_assignment()
Definition: GraphCut.cpp:960
#define SG_UNREF(x)
Definition: SGRefObject.h:35
ETerminalType
Definition: GraphCut.h:31
Node * next
Definition: GraphCut.h:52
float64_t m_map_energy
Definition: GraphCut.h:276
CDynamicObjectArray * get_factors() const
#define SG_ERROR(...)
Definition: SGIO.h:130
#define REQUIRE(x,...)
Definition: SGIO.h:207
void add_edge(int32_t i, int32_t j, float64_t capacity, float64_t reverse_capacity)
Definition: GraphCut.cpp:427
Class CMAPInferImpl abstract class of MAP inference implementation.
Definition: MAPInference.h:97
void build_st_graph(int32_t num_nodes, int32_t num_edges)
Definition: GraphCut.cpp:107
SGVector< float64_t > get_energies() const
Definition: Factor.cpp:169
float64_t compute_maxflow()
Definition: GraphCut.cpp:516
CFactorGraph * m_fg
Definition: MAPInference.h:127
#define SG_SPRINT(...)
Definition: SGIO.h:181
#define ASSERT(x)
Definition: SGIO.h:202
void add_tweights(int32_t i, float64_t cap_source, float64_t cap_sink)
Definition: GraphCut.cpp:407
SGIO * sg_io
Definition: init.cpp:36
int32_t timestamp
Definition: GraphCut.h:53
int32_t size() const
Definition: SGVector.h:55
Edge * parent
Definition: GraphCut.h:51
int32_t get_num_vars() const
double float64_t
Definition: common.h:50
ETerminalType type_tree
Definition: GraphCut.h:55
ETerminalType get_assignment(int32_t i, ETerminalType default_termainl=SOURCE)
Definition: GraphCut.cpp:920
Edge * first
Definition: GraphCut.h:50
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an a...
EMessageType get_loglevel() const
Definition: SGIO.cpp:286
#define TERMINAL_EDGE
Definition: GraphCut.h:24
const int32_t get_num_vars() const
Definition: Factor.cpp:112
#define SG_DEBUG(...)
Definition: SGIO.h:108
Class CFactorGraph a factor graph is a structured input in general.
Definition: FactorGraph.h:27
float64_t tree_cap
Definition: GraphCut.h:56
Node * head
Definition: GraphCut.h:41
virtual ~CGraphCut()
Definition: GraphCut.cpp:47
Node * ptr
Definition: GraphCut.h:62
Edge * reverse
Definition: GraphCut.h:43
Edge * next
Definition: GraphCut.h:42
CSGObject * get_element(int32_t index) const
SGVector< int32_t > get_cardinalities() const
NodePtr * next
Definition: GraphCut.h:63
int32_t dist_terminal
Definition: GraphCut.h:54
int32_t id
Definition: GraphCut.h:40
#define SG_UNSTABLE(func,...)
Definition: SGIO.h:133
#define delta
Definition: sfa.cpp:23
Class CFactor A factor is defined on a clique in the factor graph. Each factor can have its own data...
Definition: Factor.h:89
const SGVector< int32_t > get_cardinalities() const
Definition: Factor.cpp:122
virtual float64_t inference(SGVector< int32_t > assignment)
Definition: GraphCut.cpp:169

SHOGUN 机器学习工具包 - 项目文档