File size: 3,543 Bytes
26a187b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deb9b81
26a187b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee6d3aa
26a187b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from huggingface_hub import InferenceClient
import pandas as pd
import logging
import gradio as gr

logging.basicConfig(level=logging.INFO)

class FinanceSummary:
    """
    Class for generating a detailed summary of financial data using the Mixtral model.

    The CSV file is loaded with pandas, embedded in an instruction prompt, and
    the completion is streamed from the Hugging Face Inference API.
    """

    def __init__(self):
        # Client bound to the hosted Mixtral instruct model.
        self.client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
        self.logger = logging.getLogger(__name__)

    def format_prompt(self, data: pd.DataFrame) -> str:
        """
        Format prompt for Mixtral model.

        Args:
            data (pd.DataFrame): Financial data in a DataFrame.

        Returns:
            str: Formatted prompt for the model, wrapped in the
                ``<s>[INST] ... [/INST]`` instruction markers.
        """
        # NOTE(review): the frame is interpolated through its default string
        # form; pandas may abbreviate large frames in that form, so the model
        # may not see every row -- confirm this is acceptable.
        return (
            "<s>"
            f"[INST] analyze given csv sheet and give me a detailed summary mention the range of money transactions,describe each column{data}[/INST]"
        )

    def to_dataframe(self, filepath: str) -> pd.DataFrame:
        """
        Read financial data from a CSV file and return it as a DataFrame.

        Args:
            filepath (str): Path to the CSV file containing financial data.

        Returns:
            pd.DataFrame: DataFrame containing the financial data.
        """
        return pd.read_csv(filepath)

    @staticmethod
    def _resolve_path(filepath):
        """
        Return something ``pd.read_csv`` can open for the given input.

        Args:
            filepath: A string path, a path-like object, or an uploaded-file
                object that exposes its location through a ``name`` attribute.

        Returns:
            The string / path-like object itself, or the object's ``name``.

        Raises:
            TypeError: If no path can be derived from ``filepath``.
        """
        # Strings and path-like objects are already usable as-is. This check
        # must come first: pathlib paths also have a ``name`` attribute, but
        # it only holds the final path component.
        if isinstance(filepath, str) or hasattr(filepath, "__fspath__"):
            return filepath
        name = getattr(filepath, "name", None)
        if isinstance(name, str):
            return name
        raise TypeError(f"Cannot determine a file path from {type(filepath).__name__!r}")

    def generate(self, filepath: str, temperature: float = 0.9, max_new_tokens: int = 5000,
                 top_p: float = 0.95, repetition_penalty: float = 1.0) -> str:
        """
        Generate a detailed summary of financial data.

        Args:
            filepath (str): Path to the CSV file, or an uploaded-file object
                whose ``name`` attribute holds that path.
            temperature (float): Controls the randomness of the predictions. Defaults to 0.9.
                Values below 0.01 are raised to 0.01.
            max_new_tokens (int): Maximum number of tokens to generate. Defaults to 5000.
            top_p (float): The cumulative probability for sampling from the logits. Defaults to 0.95.
            repetition_penalty (float): Penalty for repetition. Defaults to 1.0.

        Returns:
            str: Generated summary with any ``</s>`` marker removed, or an
                empty string when no file was given or any step failed
                (the failure is logged).
        """
        if filepath is None:
            # Nothing was uploaded; keep the "empty summary" contract but say why.
            self.logger.error("An error occurred: no file was provided")
            return ""

        try:
            generate_kwargs = dict(
                # Clamp so the sampling temperature never drops below 0.01.
                temperature=max(float(temperature), 1e-2),
                max_new_tokens=max_new_tokens,
                top_p=float(top_p),
                repetition_penalty=repetition_penalty,
                do_sample=True,
                seed=42,
            )
            data = self.to_dataframe(self._resolve_path(filepath))
            formatted_prompt = self.format_prompt(data)
            stream = self.client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True,
                                                  return_full_text=False)
            # Each streamed item carries one generated token; stitch them together.
            output = "".join(response.token.text for response in stream)
            return output.replace("</s>", "")
        except Exception as e:
            # UI boundary: never propagate, but keep the traceback in the log.
            self.logger.exception("An error occurred: %s", e)
            return ""

if __name__ == "__main__":
    # Build the summarizer once; the click handler below is bound to this instance.
    summarizer = FinanceSummary()

    # Three stacked rows: CSV upload, submit button, and the output box.
    with gr.Blocks(css="style.css", theme=gr.themes.Soft()) as demo:
        with gr.Row():
            csv_upload = gr.File(label="Upload CSV File", elem_classes="upload-file")
        with gr.Row():
            run_button = gr.Button(value="Submit")
        with gr.Row():
            summary_box = gr.Textbox(label="Detailed Summary", lines=20)

        # Clicking the button passes the uploaded file to generate() and
        # writes the returned text into the summary box.
        run_button.click(
            fn=summarizer.generate,
            inputs=csv_upload,
            outputs=summary_box,
        )

    demo.launch()