|
import logging
from pathlib import Path

import yaml
from easydict import EasyDict as edict
from langchain.chains import LLMChain
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import AzureChatOpenAI, ChatOpenAI
|
|
|
# Load API credentials and endpoints for the supported LLM providers.
with open('config/llm_env.yml', 'r') as llm_env_file:
    LLM_ENV = yaml.safe_load(llm_env_file)
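# Illustrative sketch of the expected config/llm_env.yml layout, inferred from the keys
# this module reads below (all values are placeholders, not real credentials):
#
#   openai:
#     OPENAI_API_KEY: "sk-..."
#     OPENAI_API_BASE: "https://api.openai.com/v1"
#     OPENAI_ORGANIZATION: ""
#   azure:
#     AZURE_OPENAI_API_KEY: "..."
#     AZURE_OPENAI_ENDPOINT: "https://<resource>.openai.azure.com/"
#     OPENAI_API_VERSION: "2024-02-01"
#   google:
#     GOOGLE_API_KEY: "..."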
|
|
|
|
|
class Color:
    """ANSI escape codes for colored terminal output."""
|
RED = '\033[91m' |
|
GREEN = '\033[92m' |
|
YELLOW = '\033[93m' |
|
BLUE = '\033[94m' |
|
END = '\033[0m' |
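    # Illustrative usage: print(f"{Color.GREEN}Done{Color.END}") prints "Done" in green.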
|
|
|
|
|
def get_llm(config: dict): |
|
""" |
|
Returns the LLM model |
|
:param config: dictionary with the configuration |
|
:return: The llm model |
|
""" |
|
    temperature = config.get('temperature', 0)
    model_kwargs = config.get('model_kwargs', {})
|
    if config['type'] == 'OpenAI':
        # Pass an organization header only when one is configured in llm_env.yml.
        if LLM_ENV['openai']['OPENAI_ORGANIZATION'] == '':
|
return ChatOpenAI(temperature=temperature, model_name=config['name'], |
|
openai_api_key=config.get('openai_api_key', LLM_ENV['openai']['OPENAI_API_KEY']), |
|
openai_api_base=config.get('openai_api_base', LLM_ENV['openai']['OPENAI_API_BASE']), |
|
model_kwargs=model_kwargs) |
|
else: |
|
return ChatOpenAI(temperature=temperature, model_name=config['name'], |
|
openai_api_key=config.get('openai_api_key', LLM_ENV['openai']['OPENAI_API_KEY']), |
|
openai_api_base=config.get('openai_api_base', 'https://api.openai.com/v1'), |
|
openai_organization=config.get('openai_organization', LLM_ENV['openai']['OPENAI_ORGANIZATION']), |
|
model_kwargs=model_kwargs) |
|
elif config['type'] == 'Azure': |
|
return AzureChatOpenAI(temperature=temperature, azure_deployment=config['name'], |
|
openai_api_key=config.get('openai_api_key', LLM_ENV['azure']['AZURE_OPENAI_API_KEY']), |
|
azure_endpoint=config.get('azure_endpoint', LLM_ENV['azure']['AZURE_OPENAI_ENDPOINT']), |
|
openai_api_version=config.get('openai_api_version', LLM_ENV['azure']['OPENAI_API_VERSION'])) |
|
|
|
    elif config['type'] == 'Google':
        # Imported lazily so the Google dependency is only required when this type is used.
        from langchain_google_genai import ChatGoogleGenerativeAI
|
return ChatGoogleGenerativeAI(temperature=temperature, model=config['name'], |
|
google_api_key=LLM_ENV['google']['GOOGLE_API_KEY'], |
|
model_kwargs=model_kwargs) |
|
|
|
|
|
elif config['type'] == 'HuggingFacePipeline': |
|
        # device=-1 runs on the CPU; set 'gpu_device' and/or 'device_map' in the config to use a GPU.
        device = config.get('gpu_device', -1)
        device_map = config.get('device_map', None)
|
|
|
return HuggingFacePipeline.from_model_id( |
|
model_id=config['name'], |
|
task="text-generation", |
|
pipeline_kwargs={"max_new_tokens": config['max_new_tokens']}, |
|
device=device, |
|
device_map=device_map |
|
) |
|
else: |
|
        raise NotImplementedError(f"LLM type '{config['type']}' is not implemented")
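# Illustrative get_llm usage (the config keys are the ones get_llm reads; model names are placeholders):
#   llm = get_llm({'type': 'OpenAI', 'name': 'gpt-4o-mini', 'temperature': 0})
#   llm = get_llm({'type': 'HuggingFacePipeline', 'name': 'gpt2', 'max_new_tokens': 50})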
|
|
|
|
|
def load_yaml(yaml_path: str, as_edict: bool = True) -> edict: |
|
""" |
|
Reads the yaml file and enrich it with more vales. |
|
:param yaml_path: The path to the yaml file |
|
:param as_edict: If True, returns an EasyDict configuration |
|
:return: An EasyDict configuration |
|
""" |
|
with open(yaml_path, 'r') as file: |
|
yaml_data = yaml.safe_load(file) |
|
if 'meta_prompts' in yaml_data.keys() and 'folder' in yaml_data['meta_prompts']: |
|
yaml_data['meta_prompts']['folder'] = Path(yaml_data['meta_prompts']['folder']) |
|
if as_edict: |
|
yaml_data = edict(yaml_data) |
|
return yaml_data |
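# Illustrative load_yaml usage ('config/config_default.yml' is the project's default path;
# the dataset.label_schema field is assumed, as other helpers in this module expect it):
#   config = load_yaml('config/config_default.yml')
#   print(config.dataset.label_schema)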
|
|
|
|
|
def load_prompt(prompt_path: str) -> PromptTemplate: |
|
""" |
|
Reads and returns the contents of a prompt file. |
|
:param prompt_path: The path to the prompt file |
|
""" |
|
with open(prompt_path, 'r') as file: |
|
prompt = file.read().rstrip() |
|
return PromptTemplate.from_template(prompt) |
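# Illustrative load_prompt usage (the path and template variable are placeholders):
#   template = load_prompt('prompts/modifiers/example.prompt')
#   text = template.format(task_description='Classify reviews by sentiment')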
|
|
|
|
|
def validate_generation_config(base_config, generation_config):
    """
    Validates that the generation configuration is consistent with the base configuration.
    """
    if "annotator" not in generation_config:
        raise Exception("Generation config must contain an 'annotator' field (expected to be empty).")
    if "label_schema" not in generation_config.dataset or \
            base_config.dataset.label_schema != generation_config.dataset.label_schema:
        raise Exception("Generation label schema must match the base config.")
|
|
|
|
|
def modify_input_for_ranker(config, task_description, initial_prompt):
    """
    Adapts the task description and initial prompt for the ranker, using the configured LLM.
    """
    with open('prompts/modifiers/modifiers.yml', 'r') as modifiers_file:
        modifiers_config = yaml.safe_load(modifiers_file)
    task_desc_setup = load_prompt(modifiers_config['ranker']['task_desc_mod'])
    init_prompt_setup = load_prompt(modifiers_config['ranker']['prompt_mod'])
|
|
|
llm = get_llm(config.llm) |
|
task_llm_chain = LLMChain(llm=llm, prompt=task_desc_setup) |
|
task_result = task_llm_chain( |
|
{"task_description": task_description}) |
|
mod_task_desc = task_result['text'] |
|
logging.info(f"Task description modified for ranking to: \n{mod_task_desc}") |
|
|
|
prompt_llm_chain = LLMChain(llm=llm, prompt=init_prompt_setup) |
|
prompt_result = prompt_llm_chain({"prompt": initial_prompt, 'label_schema': config.dataset.label_schema}) |
|
mod_prompt = prompt_result['text'] |
|
logging.info(f"Initial prompt modified for ranking to: \n{mod_prompt}") |
|
|
|
return mod_prompt, mod_task_desc |
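# Illustrative modify_input_for_ranker usage (config is assumed to be a loaded EasyDict
# exposing 'llm' and 'dataset.label_schema', as used above; the strings are placeholders):
#   mod_prompt, mod_task_desc = modify_input_for_ranker(
#       config, 'Classify reviews by sentiment', 'Is the following review positive?')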
|
|
|
|
|
def override_config(override_config_file, config_file='config/config_default.yml'): |
|
""" |
|
Override the default configuration file with the override configuration file |
|
:param config_file: The default configuration file |
|
:param override_config_file: The override configuration file |
|
""" |
|
|
|
    def override_dict(config_dict, override_config_dict):
        # Recursively merge the override values into the base dictionary.
|
for key, value in override_config_dict.items(): |
|
if isinstance(value, dict): |
|
if key not in config_dict: |
|
config_dict[key] = value |
|
else: |
|
override_dict(config_dict[key], value) |
|
else: |
|
config_dict[key] = value |
|
return config_dict |
|
|
|
default_config_dict = load_yaml(config_file, as_edict=False) |
|
override_config_dict = load_yaml(override_config_file, as_edict=False) |
|
config_dict = override_dict(default_config_dict, override_config_dict) |
|
return edict(config_dict) |
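# Illustrative override_config usage (the override path is a placeholder):
#   config = override_config('config/my_experiment.yml')
#   llm = get_llm(config.llm)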
|
|