# Hugging Face Spaces status at capture time: Runtime error
import logging
import os
from typing import Any

import gradio as gr
import pandas as pd
from huggingface_hub import InferenceClient
logging.basicConfig(level=logging.INFO)


class FinanceSummary:
    """
    Generate a detailed summary of financial CSV data with the Mixtral model.

    The CSV is loaded with pandas, embedded into a Mixtral instruction prompt,
    and sent to ``mistralai/Mixtral-8x7B-Instruct-v0.1`` through
    ``huggingface_hub.InferenceClient``.
    """

    def __init__(self):
        # All generation requests go through this client bound to the Mixtral model.
        self.client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
        self.logger = logging.getLogger(__name__)

    def format_prompt(self, data: pd.DataFrame) -> str:
        """
        Build the Mixtral instruction prompt for ``data``.

        Args:
            data (pd.DataFrame): Financial data in a DataFrame.

        Returns:
            str: ``<s>[INST] ... [/INST]`` prompt containing the instruction,
            a newline, and the text rendering of ``data``.
        """
        instruction = (
            "analyze given csv sheet and give me a detailed summary mention "
            "the range of money transactions,describe each column"
        )
        # NOTE(review): the frame is embedded via str(data), which follows pandas
        # display options (display.max_rows / display.max_columns), so large
        # sheets are likely truncated before the model sees them. Confirm whether
        # the full data (e.g. data.to_csv()) should be sent instead.
        # The newline keeps the last instruction word from running into the
        # first line of the table.
        return f"<s>[INST] {instruction}\n{data}[/INST]"

    def to_dataframe(self, filepath: str) -> pd.DataFrame:
        """
        Read financial data from a CSV source and return it as a DataFrame.

        Args:
            filepath (str): Path to the CSV file (anything ``pd.read_csv`` accepts).

        Returns:
            pd.DataFrame: DataFrame containing the financial data.
        """
        return pd.read_csv(filepath)

    @staticmethod
    def _resolve_path(filepath: Any) -> Any:
        """
        Turn the caller's file argument into something ``pd.read_csv`` can open.

        Plain paths (``str`` / ``os.PathLike``) are returned unchanged. Other
        objects exposing ``.name`` (e.g. an uploaded temp-file wrapper) are
        replaced by that name. Anything else (e.g. an in-memory buffer) is
        passed through as-is.
        """
        # Check paths first: pathlib.Path also has a ``.name`` attribute, but it
        # is only the final path component, not a usable full path.
        if isinstance(filepath, (str, os.PathLike)):
            return filepath
        return getattr(filepath, "name", filepath)

    def generate(self, filepath: Any, temperature: float = 0.9, max_new_tokens: int = 5000,
                 top_p: float = 0.95, repetition_penalty: float = 1.0) -> str:
        """
        Generate a detailed summary of the financial data in a CSV file.

        Args:
            filepath: CSV path (``str`` / ``os.PathLike``) or an uploaded-file
                object exposing ``.name``, such as a Gradio upload. ``None``
                (nothing uploaded) yields an empty summary.
            temperature (float): Controls the randomness of the predictions;
                values below 0.01 are raised to 0.01. Defaults to 0.9.
            max_new_tokens (int): Maximum number of tokens to generate. Defaults to 5000.
            top_p (float): The cumulative probability for sampling from the logits.
                Defaults to 0.95.
            repetition_penalty (float): Penalty for repetition. Defaults to 1.0.

        Returns:
            str: Generated summary with ``</s>`` markers removed, or ``""`` when
            no file was given or any step failed (the failure is logged).
        """
        if filepath is None:
            self.logger.warning("No CSV file provided; nothing to summarize.")
            return ""
        try:
            generate_kwargs = dict(
                # Sampling needs a strictly positive temperature.
                temperature=max(float(temperature), 1e-2),
                max_new_tokens=int(max_new_tokens),
                top_p=float(top_p),
                repetition_penalty=float(repetition_penalty),
                do_sample=True,
                seed=42,
            )
            data = self.to_dataframe(self._resolve_path(filepath))
            formatted_prompt = self.format_prompt(data)
            stream = self.client.text_generation(
                formatted_prompt, **generate_kwargs, stream=True, details=True,
                return_full_text=False,
            )
            # Join the streamed token texts once instead of repeated concatenation.
            output = "".join(response.token.text for response in stream)
            return output.replace("</s>", "")
        except Exception as err:
            # UI boundary: keep the "empty string on failure" contract, but record
            # the full traceback rather than only the message.
            self.logger.exception("An error occurred: %s", err)
            return ""
if __name__ == "__main__":
    summarizer = FinanceSummary()
    # Three stacked rows: file upload, submit button, generated summary.
    with gr.Blocks(css="style.css", theme=gr.themes.Soft()) as demo:
        with gr.Row():
            csv_upload = gr.File(label="Upload CSV File", elem_classes="upload-file")
        with gr.Row():
            submit_btn = gr.Button(value="Submit")
        with gr.Row():
            summary_box = gr.Textbox(label="Detailed Summary", lines=20)
        # Submit passes the uploaded file to generate() and displays its return value.
        submit_btn.click(fn=summarizer.generate, inputs=csv_upload, outputs=summary_box)
    demo.launch()