"""Utility functions for CRS Arena."""

import ast
import asyncio
import logging
import os
import sys
from datetime import timedelta
from typing import Dict

import openai
import pandas as pd
import streamlit as st
import yaml
from huggingface_hub import HfApi
from streamlit_gsheets.gsheets_connection import GSheetsServiceAccountClient

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.model.crs_model import CRSModel

HF_API = HfApi(token=os.environ.get("hf_token"))
LOADING_MESSAGE = (
    "The fighters are warming up, please be patient, it may take a moment... "
    ":robot_face: :punch: :gun: :boom:"
)
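
# Environment variables read by this module (expected to be provided in the
# environment, e.g. as deployment secrets): "hf_token" (Hugging Face token
# used by HF_API), "openai_api_key" (set for ChatGPT-based models in
# get_crs_model), "dataset_repo" (Hugging Face dataset repository receiving
# conversation logs), and "gsheet" (Google service account credentials for
# the feedback spreadsheet).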

@st.cache_resource(show_spinner=LOADING_MESSAGE, ttl=timedelta(days=3))
def get_crs_model(model_name: str, model_config_file: str) -> CRSModel:
    """Returns a CRS model.

    Args:
        model_name: Model name.
        model_config_file: Model configuration file.

    Raises:
        FileNotFoundError: If the model configuration file is not found.

    Returns:
        CRS model.
    """
    logging.debug(f"Loading CRS model {model_name}.")
    if not os.path.exists(model_config_file):
        raise FileNotFoundError(
            f"Model configuration file {model_config_file} not found."
        )

    # Load the model's keyword arguments from its YAML configuration file.
    with open(model_config_file, "r") as f:
        model_args = yaml.safe_load(f)

    if "chatgpt" in model_name:
        openai.api_key = os.environ.get("openai_api_key")

    # Only the part of the name before the first underscore identifies the
    # underlying CRS implementation.
    name = model_name.split("_")[0]

    return CRSModel(name, **model_args)
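
# Example usage (illustrative sketch; the model name and config path below
# are hypothetical, not defined in this module):
#
#     model = get_crs_model("kbrd_redial", "data/models/kbrd_redial.yaml")
#
# The decorator caches the instantiated model in Streamlit's resource cache
# for three days, so repeated calls with the same arguments reuse one model.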

async def upload_conversation_logs_to_hf(
    conversation_log_file_path: str, repo_filename: str
) -> None:
    """Uploads conversation logs to Hugging Face asynchronously.

    Errors raised during the upload are caught and logged, not re-raised.

    Args:
        conversation_log_file_path: Path to the conversation log file locally.
        repo_filename: Name of the file in the Hugging Face repository.
    """
    logging.debug(
        "Uploading conversation logs to Hugging Face: "
        f"{conversation_log_file_path}."
    )
    try:
        # Run the blocking upload in a worker thread so the event loop is not
        # blocked while the file is transferred.
        await asyncio.get_running_loop().run_in_executor(
            None,
            lambda: HF_API.upload_file(
                path_or_fileobj=conversation_log_file_path,
                path_in_repo=repo_filename,
                repo_id=os.environ.get("dataset_repo"),
                repo_type="dataset",
            ),
        )
        logging.debug("Conversation logs uploaded to Hugging Face.")
    except Exception as e:
        logging.error(
            f"Error uploading conversation logs to Hugging Face: {e}"
        )
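
# Example (illustrative sketch; the file names below are hypothetical). From
# synchronous code the coroutine can be driven with asyncio.run, while code
# already inside an event loop would await it or schedule it as a task:
#
#     asyncio.run(
#         upload_conversation_logs_to_hf(
#             "data/conversation_logs/kbrd_redial.json", "kbrd_redial.json"
#         )
#     )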

async def upload_feedback_to_gsheet(
    row: Dict[str, str], worksheet: str = "votes"
) -> None:
    """Uploads feedback to Google Sheets asynchronously.

    Errors raised during the upload are caught and logged, not re-raised.

    Args:
        row: Row to upload to the worksheet.
        worksheet: Name of the worksheet to upload the feedback to.
    """
    logging.debug("Uploading feedback to Google Sheets.")
    try:
        await asyncio.get_running_loop().run_in_executor(
            None, lambda: _upload_feedback_to_gsheet_sync(row, worksheet)
        )
    except Exception as e:
        logging.error(f"Error uploading feedback to Google Sheets: {e}")
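
# Example (illustrative sketch; the id and feedback values are hypothetical,
# but the "id" and "feedback" keys are the ones used when updating the sheet):
#
#     asyncio.run(
#         upload_feedback_to_gsheet(
#             {"id": "user42_kbrd_redial", "feedback": "thumbs_up"}
#         )
#     )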

def _upload_feedback_to_gsheet_sync(
    row: Dict[str, str], worksheet: str
) -> None:
    """Uploads feedback to Google Sheets synchronously.

    Args:
        row: Row to upload to the worksheet.
        worksheet: Name of the worksheet to upload the feedback to.
    """
    gs_connection = GSheetsServiceAccountClient(
        ast.literal_eval(os.environ.get("gsheet"))
    )
    df = gs_connection.read(worksheet=worksheet)
    if df[df["id"] == row["id"]].empty:
        # No existing entry with this id: append the row as a new record.
        df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
    else:
        # An entry with this id already exists: only update its feedback.
        df.loc[df["id"] == row["id"], "feedback"] = row["feedback"]
    gs_connection.update(data=df, worksheet=worksheet)
    logging.debug("Feedback uploaded to Google Sheets.")