// Aging_MouthReplace / dlibs / dlib / conditioning_class / conditioning_class_kernel_4.h
// (Repository page residue: folder uploaded by AshanGimhana using huggingface_hub,
//  commit 9375c9a, verified; file size 16 kB.)
// Copyright (C) 2004  Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CONDITIONING_CLASS_KERNEl_4_
#define DLIB_CONDITIONING_CLASS_KERNEl_4_
#include "conditioning_class_kernel_abstract.h"
#include "../assert.h"
#include "../algs.h"
namespace dlib
{
template <
    unsigned long alphabet_size,
    unsigned long pool_size,
    typename mem_manager
    >
class conditioning_class_kernel_4
{
    /*!
        REQUIREMENTS ON alphabet_size
            1 < alphabet_size < 65536
            (enforced with a COMPILE_TIME_ASSERT in the constructor)

        REQUIREMENTS ON pool_size
            pool_size > 0
            increment_count() refuses to create a new node once
            global_state.pool.get_number_of_allocations() has reached pool_size.

        REQUIREMENTS ON mem_manager
            mem_manager is an implementation of
            memory_manager/memory_manager_kernel_abstract.h

        INITIAL VALUE
            total == 1
            escapes == 1
            next == 0

        CONVENTION
            - get_total() == total
            - get_count(alphabet_size-1) == escapes
              (alphabet_size-1 is the escape symbol; it has no list node)
            - total == escapes + the sum of node::count over every node in the list
            - if (next != 0) then
                - next points at the first node of a singly linked list whose
                  last node has node::next == 0
            - get_count(symbol) == node::count of the node with node::symbol == symbol,
              or 0 when the list holds no such node
            - if (the list holds a node for symbol) then
                - LOW_COUNT(symbol) == the sum of node::count over the nodes that
                  come before symbol's node in the list
            - the escape symbol owns the range [total-escapes, total)
            - get_memory_usage() == global_state.memory_usage
    !*/

    // One list entry: a symbol together with its current count.
    struct node
    {
        unsigned short symbol;
        unsigned short count;
        node* next;
    };

public:

    // State shared by every conditioning class constructed with the same
    // global_state_type object: the node pool and the memory usage tally.
    class global_state_type
    {
    public:
        // The tally starts at pool_size nodes plus this object itself.
        global_state_type (
        ) :
            memory_usage(pool_size*sizeof(node)+sizeof(global_state_type))
        {}

    private:
        unsigned long memory_usage;
        typename mem_manager::template rebind<node>::other pool;

        friend class conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>;
    };

    conditioning_class_kernel_4 (
        global_state_type& global_state_
    );

    ~conditioning_class_kernel_4 (
    );

    void clear(
    );

    bool increment_count (
        unsigned long symbol,
        unsigned short amount = 1
    );

    unsigned long get_count (
        unsigned long symbol
    ) const;

    inline unsigned long get_total (
    ) const;

    unsigned long get_range (
        unsigned long symbol,
        unsigned long& low_count,
        unsigned long& high_count,
        unsigned long& total_count
    ) const;

    void get_symbol (
        unsigned long target,
        unsigned long& symbol,
        unsigned long& low_count,
        unsigned long& high_count
    ) const;

    unsigned long get_memory_usage (
    ) const;

    global_state_type& get_global_state (
    );

    static unsigned long get_alphabet_size (
    );

private:

    void half_counts (
    );
    /*!
        ensures
            - every node count is divided by 2 (rounding down)
            - escapes is divided by 2 as well, except that it never drops below 1
            - total is recomputed from the new counts
    !*/

    // restricted functions: declared private so objects cannot be copied
    conditioning_class_kernel_4(conditioning_class_kernel_4&);        // copy constructor
    conditioning_class_kernel_4& operator=(conditioning_class_kernel_4&);    // assignment operator

    // data members
    unsigned short total;       // escapes plus all node counts
    unsigned short escapes;     // count of the escape symbol (alphabet_size-1)
    node* next;                 // head of the symbol list, 0 when the list is empty
    global_state_type& global_state;
};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// member function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// Constructs an empty model (total == 1, escapes == 1, no nodes) that keeps a
// reference to global_state_ and adds its own size to the shared memory tally.
template <
    unsigned long alphabet_size,
    unsigned long pool_size,
    typename mem_manager
    >
conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
conditioning_class_kernel_4 (
    global_state_type& global_state_
) :
    total(1),
    escapes(1),
    next(0),
    global_state(global_state_)
{
    // symbols are stored in node::symbol, which is an unsigned short
    COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 );

    // account for this object in the figure reported by get_memory_usage()
    global_state.memory_usage += sizeof(*this);
}
// ----------------------------------------------------------------------------------------
// Hands every node back to the shared pool and removes this object's size
// from the shared memory tally.
template <
    unsigned long alphabet_size,
    unsigned long pool_size,
    typename mem_manager
    >
conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
~conditioning_class_kernel_4 (
)
{
    // release the list first; the pool lives in global_state, not in this object
    clear();

    // undo the accounting performed by the constructor
    global_state.memory_usage -= sizeof(*this);
}
// ----------------------------------------------------------------------------------------
// Puts the model back into its initial state: total == 1, escapes == 1 and an
// empty list.  Every node that was in the list is deallocated through the pool.
template <
    unsigned long alphabet_size,
    unsigned long pool_size,
    typename mem_manager
    >
void conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
clear(
)
{
    total = 1;
    escapes = 1;

    // Pop nodes off the front one at a time.  The head pointer is advanced
    // before the node is released so that next always refers to the part of
    // the list that has not been deallocated yet.
    for (node* head = next; head != 0; head = next)
    {
        next = head->next;
        global_state.pool.deallocate(head);
    }
}
// ----------------------------------------------------------------------------------------
// Returns the memory usage tally kept in the shared global state object.
template <
    unsigned long alphabet_size,
    unsigned long pool_size,
    typename mem_manager
    >
unsigned long conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
get_memory_usage(
) const
{
    return global_state.memory_usage;
}
// ----------------------------------------------------------------------------------------
// Returns the global state object that was given to the constructor.
template <
    unsigned long alphabet_size,
    unsigned long pool_size,
    typename mem_manager
    >
typename conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::global_state_type&
conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
get_global_state(
)
{
    return global_state;
}
// ----------------------------------------------------------------------------------------
// Adds amount to the count of symbol.
//
// - The escape symbol (alphabet_size-1) is tracked in escapes and always succeeds.
// - Any other symbol is looked up in the linked list.  If it has a node, that
//   node's count is raised; if the count then exceeds the count of the node in
//   front of it, the two nodes exchange their symbol and count so the list
//   stays roughly ordered by descending count.
// - If the symbol has no node, a new node is appended to the end of the list,
//   provided global_state.pool.get_number_of_allocations() < pool_size.
// - While searching, nodes with a zero count that are neither the sought
//   symbol nor the last node of the list are unlinked and deallocated.
// - Before any count is raised, half_counts() is called when
//   total >= 65536 - amount, since total is an unsigned short.
//
// Returns true when the count was updated and false when a new node was needed
// but the pool limit had been reached (the model is left unchanged apart from
// any zero count nodes removed during the search).
template <
    unsigned long alphabet_size,
    unsigned long pool_size,
    typename mem_manager
    >
bool conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
increment_count (
    unsigned long symbol,
    unsigned short amount
)
{
    // the escape symbol has no node of its own
    if (symbol == alphabet_size-1)
    {
        // rescale first so that total cannot wrap around
        if (total >= 65536 - amount )
            half_counts();

        escapes += amount;
        total += amount;
        return true;
    }

    const unsigned short wanted = static_cast<unsigned short>(symbol);

    // Where a brand new node has to be attached: the head pointer while the
    // list is empty, otherwise the next pointer of the last node.
    node** attach_at = &next;

    node* prior = 0;    // closest node before cur that is still in the list
    node* cur = next;
    while (cur != 0)
    {
        if (cur->symbol == wanted)
        {
            // rescale first so that total cannot wrap around
            if (total >= 65536 - amount )
                half_counts();

            total += amount;
            cur->count += amount;

            // bubble the payload one position towards the front when it has
            // overtaken the node ahead of it
            if (prior && cur->count > prior->count)
            {
                swap(cur->count,prior->count);
                swap(cur->symbol,prior->symbol);
            }
            return true;
        }

        if (cur->next == 0)
        {
            // reached the end without a match; the last node is kept even if
            // its count is zero
            attach_at = &cur->next;
            break;
        }

        node* const following = cur->next;
        if (cur->count == 0)
        {
            // drop this dead node from the list; prior stays where it is
            if (prior)
                prior->next = following;
            else
                next = following;
            global_state.pool.deallocate(cur);
        }
        else
        {
            prior = cur;
        }
        cur = following;
    }

    // The symbol is not in the list, so a new node is required.
    if (!(global_state.pool.get_number_of_allocations() < pool_size))
    {
        // no memory left
        return false;
    }

    // rescale first so that total cannot wrap around
    if (total >= 65536 - amount )
        half_counts();

    node* const fresh = global_state.pool.allocate();
    fresh->next = 0;
    fresh->symbol = wanted;
    fresh->count = amount;
    *attach_at = fresh;
    total += amount;
    return true;
}
// ----------------------------------------------------------------------------------------
// Returns the current count of symbol: escapes for the escape symbol
// (alphabet_size-1), the count stored in the symbol's node if the list has
// one, and 0 otherwise.
template <
    unsigned long alphabet_size,
    unsigned long pool_size,
    typename mem_manager
    >
unsigned long conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
get_count (
    unsigned long symbol
) const
{
    if (symbol == alphabet_size-1)
        return escapes;

    // linear search through the list
    for (const node* cur = next; cur != 0; cur = cur->next)
    {
        if (cur->symbol == symbol)
            return cur->count;
    }

    // no node means the symbol has not been counted
    return 0;
}
// ----------------------------------------------------------------------------------------
// Returns the alphabet_size template argument.
template <
    unsigned long alphabet_size,
    unsigned long pool_size,
    typename mem_manager
    >
unsigned long conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
get_alphabet_size (
)
{
    return alphabet_size;
}
// ----------------------------------------------------------------------------------------
// Returns total, i.e. the escape count plus the counts of all nodes.
template <
    unsigned long alphabet_size,
    unsigned long pool_size,
    typename mem_manager
    >
unsigned long conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
get_total (
) const
{
    return total;
}
// ----------------------------------------------------------------------------------------
// Reports the cumulative count range of symbol.
//
// When the symbol has a count, low_count and high_count receive the half open
// range [low_count, high_count) assigned to it, total_count receives total and
// the symbol's count is returned.  The escape symbol (alphabet_size-1) always
// has the range [total-escapes, total).  A symbol without a node yields a
// return value of 0 and leaves the three output arguments untouched.
template <
    unsigned long alphabet_size,
    unsigned long pool_size,
    typename mem_manager
    >
unsigned long conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
get_range (
    unsigned long symbol,
    unsigned long& low_count,
    unsigned long& high_count,
    unsigned long& total_count
) const
{
    // the escape symbol sits at the very top of the range
    if (symbol == alphabet_size-1)
    {
        total_count = total;
        high_count = total;
        low_count = total-escapes;
        return escapes;
    }

    const unsigned short wanted = static_cast<unsigned short>(symbol);

    // sum of the counts of all nodes visited so far
    unsigned long below = 0;
    for (const node* cur = next; cur != 0; cur = cur->next)
    {
        if (cur->symbol == wanted)
        {
            high_count = cur->count + below;
            low_count = below;
            total_count = total;
            return cur->count;
        }
        below += cur->count;
    }

    // the symbol is not in the list
    return 0;
}
// ----------------------------------------------------------------------------------------
// Finds the symbol whose cumulative count range contains target.
//
// The list is walked from the front while accumulating counts; the first node
// whose range [low_count, high_count) satisfies target < high_count is
// reported.  If target lies beyond the ranges of all nodes, the escape symbol
// (alphabet_size-1) with the range [total-escapes, total) is reported.
template <
    unsigned long alphabet_size,
    unsigned long pool_size,
    typename mem_manager
    >
void conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
get_symbol (
    unsigned long target,
    unsigned long& symbol,
    unsigned long& low_count,
    unsigned long& high_count
) const
{
    // lower bound of the range belonging to the node being examined
    unsigned long lower = 0;
    for (const node* cur = next; cur != 0; cur = cur->next)
    {
        const unsigned long upper = lower + cur->count;
        if (target < upper)
        {
            symbol = cur->symbol;
            high_count = upper;
            low_count = lower;
            return;
        }
        lower = upper;
    }

    // none of the nodes cover target, so this must be the escape symbol
    symbol = alphabet_size-1;
    low_count = total-escapes;
    high_count = total;
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// private member function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// Rescales the model: every node count is shifted right by one bit, escapes is
// shifted right by one bit unless it is already 1 (so it never reaches 0), and
// total is rebuilt as the sum of the resulting counts.  Nodes whose count
// becomes 0 stay in the list.
template <
    unsigned long alphabet_size,
    unsigned long pool_size,
    typename mem_manager
    >
void conditioning_class_kernel_4<alphabet_size,pool_size,mem_manager>::
half_counts (
)
{
    // the escape count must stay at 1 or more
    if (escapes > 1)
        escapes >>= 1;

    // start the new total with the escape count, then add each halved node count
    unsigned short sum = escapes;
    for (node* cur = next; cur != 0; cur = cur->next)
    {
        cur->count >>= 1;
        sum += cur->count;
    }

    total = sum;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_CONDITIONING_CLASS_KERNEl_4_