File size: 4,073 Bytes
b599481
 
253820a
b599481
 
 
 
 
dbb98e1
b599481
 
 
 
 
 
4931dff
b599481
 
 
 
 
 
dbb98e1
4f5b924
 
 
 
b599481
 
4f5b924
1e5f915
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbb98e1
1e5f915
 
 
 
dbb98e1
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbb98e1
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5d70e5
 
253820a
b599481
 
f5d70e5
 
 
 
 
 
 
 
 
 
 
fa465f9
dbb98e1
fa465f9
f5d70e5
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""Utility functions for CRS Arena."""

import ast
import asyncio
import logging
import os
import sys
from datetime import timedelta
from typing import Dict

import openai
import pandas as pd
import streamlit as st
import yaml
from huggingface_hub import HfApi
from streamlit_gsheets.gsheets_connection import GSheetsServiceAccountClient

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.model.crs_model import CRSModel

# Initialize Hugging Face API
HF_API = HfApi(token=os.environ.get("hf_token"))
LOADING_MESSAGE = (
    "The fighters are warming up, please be patient it may take a moment... "
    ":robot_face: :punch: :gun: :boom:"
)


@st.cache_resource(show_spinner=LOADING_MESSAGE, ttl=timedelta(days=3))
def get_crs_model(model_name: str, model_config_file: str) -> CRSModel:
    """Loads and returns a CRS model (cached by Streamlit for 3 days).

    Args:
        model_name: Model name. The prefix before the first underscore
            selects the underlying CRS implementation.
        model_config_file: Path to the YAML model configuration file.

    Raises:
        FileNotFoundError: If model configuration file is not found.

    Returns:
        CRS model.
    """
    logging.debug("Loading CRS model %s.", model_name)
    if not os.path.exists(model_config_file):
        raise FileNotFoundError(
            f"Model configuration file {model_config_file} not found."
        )

    # Use a context manager so the config file handle is always closed
    # (the previous bare open() call leaked the handle).
    with open(model_config_file, "r") as config_file:
        model_args = yaml.safe_load(config_file)

    # ChatGPT-based models need the OpenAI API key set globally.
    if "chatgpt" in model_name:
        openai.api_key = os.environ.get("openai_api_key")

    # Extract crs model from name, e.g. "kbrd_redial" -> "kbrd".
    name = model_name.split("_")[0]

    return CRSModel(name, **model_args)


async def upload_conversation_logs_to_hf(
    conversation_log_file_path: str, repo_filename: str
) -> None:
    """Uploads conversation logs to Hugging Face asynchronously.

    The blocking Hugging Face upload runs in the default thread pool
    executor so the event loop is not blocked. Upload errors are logged
    and swallowed (best-effort upload).

    Args:
        conversation_log_file_path: Path to the conversation log file locally.
        repo_filename: Name of the file in the Hugging Face repository.

    Raises:
        Exception: If an error occurs during the upload (caught and logged).
    """
    logging.debug(
        "Uploading conversation logs to Hugging Face: "
        f"{conversation_log_file_path}."
    )
    try:
        # get_running_loop() is the supported way to obtain the loop from
        # inside a coroutine; get_event_loop() is deprecated in this context.
        await asyncio.get_running_loop().run_in_executor(
            None,
            lambda: HF_API.upload_file(
                path_or_fileobj=conversation_log_file_path,
                path_in_repo=repo_filename,
                repo_id=os.environ.get("dataset_repo"),
                repo_type="dataset",
            ),
        )
        logging.debug("Conversation logs uploaded to Hugging Face.")
    except Exception as e:
        logging.error(
            f"Error uploading conversation logs to Hugging Face: {e}"
        )


async def upload_feedback_to_gsheet(
    row: Dict[str, str], worksheet: str = "votes"
) -> None:
    """Uploads feedback to Google Sheets asynchronously.

    Delegates the blocking sheet update to the default thread pool
    executor. Upload errors are logged and swallowed (best-effort upload).

    Args:
        row: Row to upload to the worksheet.
        worksheet: Name of the worksheet to upload the feedback to.

    Raises:
        Exception: If an error occurs during the upload (caught and logged).
    """
    logging.debug("Uploading feedback to Google Sheets.")
    try:
        # get_running_loop() is the supported way to obtain the loop from
        # inside a coroutine; get_event_loop() is deprecated in this context.
        await asyncio.get_running_loop().run_in_executor(
            None, lambda: _upload_feedback_to_gsheet_sync(row, worksheet)
        )
    except Exception as e:
        logging.error(f"Error uploading feedback to Google Sheets: {e}")


def _upload_feedback_to_gsheet_sync(
    row: Dict[str, str], worksheet: str
) -> None:
    """Writes a feedback row to a Google Sheets worksheet synchronously.

    Appends the row when its id is not present in the sheet yet;
    otherwise only the "feedback" column of the matching row is updated.

    Args:
        row: Row to upload to the worksheet.
        worksheet: Name of the worksheet to upload the feedback to.
    """
    # Service-account credentials are stored as a Python-literal dict in
    # the "gsheet" environment variable.
    credentials = ast.literal_eval(os.environ.get("gsheet"))
    client = GSheetsServiceAccountClient(credentials)

    sheet = client.read(worksheet=worksheet)
    matches = sheet["id"] == row["id"]
    if matches.any():
        # Row already exists: only overwrite its feedback value.
        sheet.loc[matches, "feedback"] = row["feedback"]
    else:
        sheet = pd.concat([sheet, pd.DataFrame([row])], ignore_index=True)

    client.update(data=sheet, worksheet=worksheet)
    logging.debug("Feedback uploaded to Google Sheets.")