Spaces:
Running
Running
import os | |
from dotenv import load_dotenv | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_groq import ChatGroq | |
from langchain_core.prompts import ChatPromptTemplate | |
from typing import List | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from langchain_chroma import Chroma | |
from typing_extensions import TypedDict | |
from typing import Annotated | |
from langgraph.graph.message import AnyMessage, add_messages | |
from langchain_core.messages import HumanMessage, AIMessage | |
from langgraph.graph import END, StateGraph, START | |
from langgraph.checkpoint.memory import MemorySaver | |
from langchain_community.document_loaders import DirectoryLoader | |
from langchain_text_splitters import CharacterTextSplitter | |
from fastapi import FastAPI | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel | |
app = FastAPI() | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=["*"], | |
allow_credentials=True, | |
allow_methods=["*"], | |
allow_headers=["*"], | |
) | |
class Request(BaseModel): | |
query : str | |
id : str | |
load_dotenv() | |
os.environ["GROQ_API_KEY"] = os.getenv('GROQ_API_KEY') | |
os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
persist_directory = 'db' | |
embedding = HuggingFaceEmbeddings(model_name="OrdalieTech/Solon-embeddings-large-0.1") | |
llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.5) | |
memory = MemorySaver() | |
if os.path.exists(persist_directory) : | |
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding) | |
else : | |
glob_pattern="./*.md" | |
directory_path = "./documents" | |
loader = DirectoryLoader(directory_path, glob=glob_pattern) | |
documents = loader.load() | |
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) | |
texts = text_splitter.split_documents(documents) | |
vectordb = Chroma.from_documents(documents=texts, embedding=embedding, persist_directory=persist_directory) | |
retriever = vectordb.as_retriever() | |
system = """ | |
Tu es un assistant spécialisé dans l'enseignement de la spécialité Numérique et sciences informatiques en classe de première et de terminal | |
Tu as un bon niveau en langage Python | |
Ton interlocuteur est un élève qui suit la spécialité nsi en première et en terminale | |
Tu dois uniquement répondre aux questions qui concernent la spécialité numérique et sciences informatiques | |
Tu ne dois pas faire d'erreur, répond à la question uniquement si tu es sûr de ta réponse | |
si tu ne trouves pas la réponse à une question, tu réponds que tu ne connais pas la réponse et que l'élève doit s'adresser à son professeur pour obtenir cette réponse | |
Tu dois uniquement aborder des notions qui sont aux programmes de la spécialité numérique et sciences informatiques (première et terminale), tu ne dois jamais aborder une notion qui n'est pas au programme | |
si l'élève n'arrive pas à trouver la réponse à un exercice, tu ne dois pas lui donner tout de suite la réponse, mais seulement lui donner des indications pour lui permettre de trouver la réponse par lui même | |
Quand tu donnes un exercice Python, dans les indications que tu donnes aux élèves, tu ne dois pas dire aux élèves d'utiliser les fonctions Python : min, max, sum... pour résoudre l'exercice | |
Pour des exercices sur les requêtes SQL, tu ne doir pas utiliser LIKE, GROUP BY, INNER LEFT et INNER RIGHT car ces notions ne sont pas au programme de NSI | |
Tu peux lui donner la réponse à un exercice uniquement si l'élève te demande explicitement cette réponse | |
Tu dois uniquement répondre en langue française | |
Tu trouveras ci-dessous les programmes de la spécialité NSI en première et terminale, tu devras veiller à ce que tes réponses ne sortent pas du cadre de ces programmes | |
Si la question posée ne rentre pas dans le cadre du programme de NSI tu peux tout de même répondre en précisant bien que cette notion est hors programme | |
si tu proposes un exercice, tu dois bien vérifier que toutes les notions nécessaires à la résolution de l'exercice sont explicitement au programme de NSI | |
""" | |
prompt = ChatPromptTemplate.from_messages( | |
[ | |
("system", system), | |
("human", "Extraits des programmes de NSI : \n {document} \n\n Historique conversation entre l'assistant et l'élève : \n {historical} \n\n Intervention de l'élève : {question}"), | |
] | |
) | |
chain = prompt | llm | StrOutputParser() | |
def format_docs(docs): | |
return "\n".join(doc.page_content for doc in docs) | |
def format_historical(hist): | |
historical = [] | |
for i in range(0,len(hist)-2,2): | |
historical.append("Elève : "+hist[i].content) | |
historical.append("Assistant : "+hist[i+1].content) | |
return "\n".join(historical[-10:]) | |
class GraphState(TypedDict): | |
messages: Annotated[list[AnyMessage], add_messages] | |
documents : str | |
def retrieve(state : GraphState): | |
documents = format_docs(retriever.invoke(state['messages'][-1].content)) | |
return {'documents' : documents} | |
def chatbot(state : GraphState): | |
response = chain.invoke({'document': state['documents'], 'historical': format_historical(state['messages']), 'question' : state['messages'][-1].content}) | |
return {"messages": [AIMessage(content=response)]} | |
workflow = StateGraph(GraphState) | |
workflow.add_node('retrieve', retrieve) | |
workflow.add_node('chatbot', chatbot) | |
workflow.add_edge(START, 'retrieve') | |
workflow.add_edge('retrieve','chatbot') | |
workflow.add_edge('chatbot', END) | |
app_chatbot = workflow.compile(checkpointer=memory) | |
def request(req: Request): | |
config = {"configurable": {"thread_id": req.id}} | |
rep = app_chatbot.invoke({"messages": [HumanMessage(content=req.query)]},config, stream_mode="values") | |
return {"response":rep['messages'][-1].content} | |