File size: 3,665 Bytes
861919a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9fb5b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from datetime import datetime, timedelta
import functools
import json
import os
import pandas as pd
from prophet import Prophet
from pathlib import Path
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

# MODEL
# Name of the Mistral chat model that forecast() passes to CLIENT.chat().
MODEL = "mistral-large-latest"
# Read at import time: importing this module raises KeyError when the
# MISTRAL_API_KEY environment variable is not set.
API_KEY=os.environ["MISTRAL_API_KEY"]
# Shared Mistral client used for every chat request issued by this module.
CLIENT = MistralClient(api_key=API_KEY)

# PATH
# Absolute path of this source file.
FILE = Path(__file__).resolve()
# Parent of the directory that contains this file; the data folder is looked up under it.
BASE_PATH = FILE.parents[1]

# Historical price table, loaded once when the module is imported.
HISTORY = pd.read_csv(os.path.join(BASE_PATH, "data/cereal_price.csv"), encoding="latin-1")
# Keep only the rows whose member state is France.
HISTORY = HISTORY[HISTORY["memberStateName"]=="France"]
# Prices are read as strings with a decimal comma; swap it for a dot before converting to float.
HISTORY['price'] = HISTORY['price'].str.replace(",", ".").astype('float64')


def model_predict(week=26):
    """
    Predict future prices using the Prophet model.

    A fresh Prophet model is fitted on the 'endDate' and 'price' columns of
    HISTORY on every call, then evaluated on the weekly dates (pandas 'W'
    frequency) that fall between today and `week` weeks from today.

    Parameters:
    - week (int): Number of weeks to predict into the future (default is 26).

    Returns:
    - dict: Equal-length lists under the keys 'ds' (forecast dates),
      'yhat' (predicted values), 'yhat_lower' and 'yhat_upper' (interval bounds).

    Raises:
    - ValueError: If no weekly date falls inside the requested horizon
      (for example when `week` is negative).
    """

    # Build the forecast dates first so an unusable horizon fails fast,
    # before the comparatively expensive model fit.
    today_date = datetime.now().date()
    end_date = today_date + timedelta(weeks=week)

    # Weekly dates starting from today and ending `week` weeks from today
    future_dates = pd.date_range(start=today_date, end=end_date, freq='W')
    if future_dates.empty:
        raise ValueError(
            f"No weekly forecast date falls between {today_date} and {end_date} "
            f"(week={week!r}); request a longer horizon."
        )
    future_df = pd.DataFrame({'ds': future_dates})

    # Prepare the historical data with the column names Prophet expects.
    # rename() returns a new frame, so HISTORY itself is never modified.
    data = HISTORY[['endDate', 'price']].rename(columns={'endDate': 'ds', 'price': 'y'})

    # Instantiate a Prophet object and fit it with the historical data
    model = Prophet()
    model.fit(data)

    # Make predictions on the future DataFrame
    forecast = model.predict(future_df)

    # Return relevant columns from the forecast DataFrame as a dictionary
    return {
        'ds': forecast['ds'].tolist(),
        'yhat_lower': forecast['yhat_lower'].tolist(),
        'yhat_upper': forecast['yhat_upper'].tolist(),
        'yhat': forecast['yhat'].tolist(),
    }

# Tool specification sent to the chat model so that it can request a call to
# model_predict(). It is assembled from its parts; key order matches the
# nesting "type" -> "function" -> "parameters".

# Schema of the single `week` argument.
_week_parameter = {
    "type": "integer",
    "description": "Number of periods to predict into the future (default is 26).",
}

# Description of the callable itself: its name, purpose and argument schema.
_model_predict_spec = {
    "name": "model_predict",
    "description": "Predict future prices using the Prophet model.",
    "parameters": {
        "type": "object",
        "properties": {"week": _week_parameter},
        "required": ["week"],
    },
}

model_predict_tool = [{"type": "function", "function": _model_predict_spec}]

# Dispatch table: maps the function name requested in a tool call to the
# Python callable that implements it. The callable is referenced directly
# because no arguments need to be pre-bound.
names_to_functions = {
    'model_predict': model_predict,
}

# Example `messages` argument for forecast():
# messages = [
#     ChatMessage(role="user", content="Predict future prices using the Prophet model for 4 weeks in the future")
# ]

def forecast(messages):
    """
    Send `messages` to the chat model and run the price forecast if it asks for it.

    The model is offered the tool described by `model_predict_tool`. When it
    answers with a tool call, the matching function from `names_to_functions`
    is executed with the arguments supplied by the model and the last point of
    the returned forecast is summarised.

    Parameters:
    - messages: Conversation passed unchanged to CLIENT.chat().

    Returns:
    - dict: {"date", "prix_minimum", "prix_maximum", "prix_estimé"} built from
      the last forecast point when a tool call was requested and succeeded.
    - otherwise: the content of the model's message (no tool call requested,
      or the tool call could not be executed).
    """
    import logging

    response = CLIENT.chat(
        model=MODEL,
        messages=messages,
        tools=model_predict_tool,
        tool_choice="auto"
    )
    message = response.choices[0].message

    # No tool call requested: the model answered in plain text.
    tool_calls = getattr(message, "tool_calls", None)
    if not tool_calls:
        return message.content

    try:
        tool_call = tool_calls[0]
        function_name = tool_call.function.name
        function_params = json.loads(tool_call.function.arguments)
        function_result = names_to_functions[function_name](**function_params)
        # Only the final forecast point (end of the horizon) is reported.
        summary = {
            "date": str(function_result["ds"][-1]),
            "prix_minimum": function_result["yhat_lower"][-1],
            "prix_maximum": function_result["yhat_upper"][-1],
            "prix_estimé": function_result["yhat"][-1],
        }
    except Exception:
        # Best-effort fallback kept from the original behaviour, but the failure
        # is logged instead of vanishing, and KeyboardInterrupt/SystemExit are
        # no longer swallowed.
        logging.getLogger(__name__).exception(
            "Tool call could not be executed; returning the model's text answer instead"
        )
        return message.content

    return summary