File size: 3,020 Bytes
164a5da
 
 
 
 
 
db90043
164a5da
 
5963e52
db90043
50e9e95
db90043
164a5da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38eba8d
164a5da
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from fastapi import FastAPI, File, UploadFile, Form
import pandas as pd
import matplotlib.pyplot as plt
import io
from fastapi.responses import StreamingResponse
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
app = FastAPI()

# Table question-answering pipeline.
# NOTE(review): "google/tapas-base" is the *pretrained* TAPAS checkpoint, not a
# QA-fine-tuned one (e.g. "google/tapas-base-finetuned-wtq") — answers from the
# base model will be essentially random; confirm the intended checkpoint.
# Also: `table_analyzer` is never used anywhere in this file.
table_analyzer = pipeline("table-question-answering", model="google/tapas-base")

# LLM used in /visualize/ to interpret the free-text chart `description`.
# NOTE(review): float16 without device_map/device — on a CPU-only host this 7B
# model load is very slow and memory-hungry; verify deployment target.
user_input_processor = pipeline("text-generation", model="tiiuae/falcon-7b-instruct",torch_dtype=torch.float16)

# ✅ Load T5 Model (ensure correct architecture)
# NOTE(review): "google/t5-small" is likely not a valid Hugging Face Hub id —
# the canonical ids are "t5-small" / "google-t5/t5-small"; from_pretrained will
# raise at startup if the repo does not exist. Also: `tokenizer` and `model`
# are loaded but never used in this file — confirm whether they can be removed.
model_name = "google/t5-small"  # Change to the correct T5 model if needed
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

@app.post("/visualize/")
async def visualize(
    file: UploadFile = File(...),
    description: str = Form(None),
    chart_type: str = Form(None),
    x_column: str = Form(None),
    y_column: str = Form(None)
):
    """Render a chart from an uploaded Excel file and stream it back as a PNG.

    Form fields:
        file: Excel workbook read with ``pd.read_excel``.
        description: optional free text; currently only run through the LLM
            for logging (structured extraction is still a TODO).
        chart_type: one of "bar", "line", "scatter", "pie", "histogram".
        x_column / y_column: column names, matched case-insensitively against
            the stripped/lowercased DataFrame columns.

    Returns:
        StreamingResponse with ``image/png`` on success, or a ``{"error": ...}``
        dict on any validation/read failure.
    """
    print("🔵 Début du traitement...")

    contents = await file.read()
    excel_data = io.BytesIO(contents)
    print("✅ Fichier reçu et converti en mémoire.")

    try:
        df = pd.read_excel(excel_data)
        print("✅ Lecture du fichier Excel réussie.")
    except Exception as e:
        print(f"❌ Erreur lors de la lecture du fichier Excel : {e}")
        return {"error": "Impossible de lire le fichier Excel."}

    # Normalize headers so user-supplied column names match case-insensitively.
    df.columns = df.columns.str.strip().str.lower()
    print("📌 Colonnes après nettoyage :", df.columns.tolist())

    # If no specific chart details are given, infer from description
    if description:
        print("📝 Analyse de la description utilisateur...")
        response = user_input_processor(description, max_length=50)
        inferred_data = response[0]['generated_text']
        print("🔍 Inference AI:", inferred_data)
        # TODO: Extract structured data from response (chart_type, x_column, y_column)

    # Bug fix: the original called .lower() on these Optional form fields,
    # raising AttributeError (HTTP 500) when either was omitted.
    if not x_column or not y_column:
        print("❌ Erreur: 'x_column' et 'y_column' sont requises.")
        return {"error": "Les colonnes 'x_column' et 'y_column' sont requises."}

    x_col = x_column.lower()
    y_col = y_column.lower()

    # Ensure x_column and y_column exist
    if x_col not in df.columns or y_col not in df.columns:
        print(f"❌ Erreur: '{x_column}' ou '{y_column}' non trouvées.")
        return {"error": f"Les colonnes '{x_column}' ou '{y_column}' n'existent pas."}

    print("✅ Colonnes valides, préparation du graphique...")

    # Bug fix: plt.figure() alone was wasted — df.plot without ax= creates its
    # own figure, so the (20, 12) size was ignored and an empty figure leaked.
    # Create the axes explicitly and hand it to pandas.
    fig, ax = plt.subplots(figsize=(20, 12))
    try:
        if chart_type == "bar":
            df.plot(kind="bar", x=x_col, y=y_col, ax=ax)
        elif chart_type == "line":
            df.plot(kind="line", x=x_col, y=y_col, ax=ax)
        elif chart_type == "scatter":
            df.plot(kind="scatter", x=x_col, y=y_col, ax=ax)
        elif chart_type == "pie":
            df.set_index(x_col)[y_col].plot(kind="pie", autopct="%1.1f%%", ax=ax)
        elif chart_type == "histogram":
            df[y_col].plot(kind="hist", bins=10, ax=ax)
        else:
            return {"error": "Invalid chart type"}

        img_stream = io.BytesIO()
        fig.savefig(img_stream, format="png")
        img_stream.seek(0)
    finally:
        # Bug fix: the invalid-chart-type path previously returned without
        # closing the figure, leaking matplotlib state on every bad request.
        plt.close(fig)

    return StreamingResponse(img_stream, media_type="image/png")