Spaces:
Sleeping
Sleeping
lucas-wa
commited on
Commit
·
4db208a
1
Parent(s):
a44c9f8
Parsing code to OOP
Browse files- server/app.py +5 -6
- server/data/load_data.py +0 -75
- server/data/retriever.py +63 -0
- server/inference.py +0 -43
- server/llm/gemini.py +69 -64
- server/services/generate_questions_service.py +49 -0
- web/index.html +1 -1
server/app.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
from fastapi import FastAPI
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
3 |
-
from inference import rag_chain
|
4 |
from pydantic import BaseModel
|
5 |
from fastapi.staticfiles import StaticFiles
|
6 |
-
from
|
|
|
7 |
|
|
|
8 |
|
9 |
class Body(BaseModel):
|
10 |
subject: str
|
@@ -26,10 +27,8 @@ async def generate_questions(body: Body):
|
|
26 |
subject = body.subject
|
27 |
difficultie = body.difficultie
|
28 |
query = f"Quero que você gere questões de biologia, sendo do assunto: {subject} e sendo da dificuldade: {difficultie}."
|
29 |
-
res =
|
30 |
-
return
|
31 |
-
"res": res,
|
32 |
-
}
|
33 |
|
34 |
|
35 |
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
|
|
1 |
from fastapi import FastAPI
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
3 |
from pydantic import BaseModel
|
4 |
from fastapi.staticfiles import StaticFiles
|
5 |
+
from services.generate_questions_service import GenerateQuestionsService
|
6 |
+
# from data.load_data import retriever_pre
|
7 |
|
8 |
+
generate_questions_service = GenerateQuestionsService()
|
9 |
|
10 |
class Body(BaseModel):
|
11 |
subject: str
|
|
|
27 |
subject = body.subject
|
28 |
difficultie = body.difficultie
|
29 |
query = f"Quero que você gere questões de biologia, sendo do assunto: {subject} e sendo da dificuldade: {difficultie}."
|
30 |
+
res = generate_questions_service.handle(f"""{query}""")
|
31 |
+
return res
|
|
|
|
|
32 |
|
33 |
|
34 |
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
server/data/load_data.py
DELETED
@@ -1,75 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from langchain_community.document_loaders import TextLoader
|
3 |
-
from langchain.vectorstores import Chroma
|
4 |
-
from langchain.chains.query_constructor.base import AttributeInfo
|
5 |
-
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
6 |
-
from llm.gemini import gemini_embeddings, llm
|
7 |
-
from utils.questions_parser import parse_question
|
8 |
-
|
9 |
-
|
10 |
-
try:
|
11 |
-
vectorstore = Chroma(
|
12 |
-
persist_directory="./chroma_db", embedding_function=gemini_embeddings
|
13 |
-
)
|
14 |
-
|
15 |
-
except Exception as e:
|
16 |
-
|
17 |
-
print(e)
|
18 |
-
|
19 |
-
if "DATA_PATH" not in os.environ:
|
20 |
-
raise ValueError("DATA_PATH environment variable is not set")
|
21 |
-
|
22 |
-
DATA_PATH = os.environ["DATA_PATH"]
|
23 |
-
|
24 |
-
data_loader = TextLoader(DATA_PATH, encoding="UTF-8").load()
|
25 |
-
|
26 |
-
questions = list(
|
27 |
-
map(lambda x: "##Questão" + x, data_loader[0].page_content.split("##Questão"))
|
28 |
-
)
|
29 |
-
|
30 |
-
docs = []
|
31 |
-
|
32 |
-
for question in questions:
|
33 |
-
try:
|
34 |
-
docs.append(parse_question(question))
|
35 |
-
except Exception as e:
|
36 |
-
print(e, question)
|
37 |
-
|
38 |
-
db = Chroma.from_documents(docs, gemini_embeddings)
|
39 |
-
vectorstore = Chroma.from_documents(
|
40 |
-
documents=docs, embedding=gemini_embeddings, persist_directory="./chroma_db"
|
41 |
-
)
|
42 |
-
|
43 |
-
vectorstore_disk = Chroma(
|
44 |
-
persist_directory="./chroma_db", embedding_function=gemini_embeddings
|
45 |
-
)
|
46 |
-
|
47 |
-
|
48 |
-
metadata_field_info = [
|
49 |
-
AttributeInfo(
|
50 |
-
name="topico",
|
51 |
-
description="A materia escolar da qual a questão pertence.",
|
52 |
-
type="string",
|
53 |
-
),
|
54 |
-
AttributeInfo(
|
55 |
-
name="assunto",
|
56 |
-
description="O assunto da materia fornecida anteriormente.",
|
57 |
-
type="string",
|
58 |
-
),
|
59 |
-
AttributeInfo(
|
60 |
-
name="dificuldade",
|
61 |
-
description="O nivel de dificuldade para resolver a questao.",
|
62 |
-
type="string",
|
63 |
-
),
|
64 |
-
AttributeInfo(
|
65 |
-
name="tipo",
|
66 |
-
description="O tipo da questao. Pode ser ou Multipla Escolha ou Justificativa",
|
67 |
-
type="string",
|
68 |
-
),
|
69 |
-
]
|
70 |
-
|
71 |
-
document_content_description = "Questões de matérias do ensino médio."
|
72 |
-
|
73 |
-
retriever = SelfQueryRetriever.from_llm(
|
74 |
-
llm, vectorstore, document_content_description, metadata_field_info, verbose=True
|
75 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
server/data/retriever.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from langchain_community.document_loaders import TextLoader
|
3 |
+
from langchain.vectorstores import Chroma
|
4 |
+
from langchain.chains.query_constructor.base import AttributeInfo
|
5 |
+
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
6 |
+
from llm.gemini import Gemini
|
7 |
+
from utils.questions_parser import parse_question
|
8 |
+
|
9 |
+
class Retriever:
|
10 |
+
|
11 |
+
_model = Gemini()
|
12 |
+
|
13 |
+
def __init__(self):
|
14 |
+
|
15 |
+
if "DATA_PATH" not in os.environ:
|
16 |
+
raise ValueError("DATA_PATH environment variable is not set")
|
17 |
+
|
18 |
+
DATA_PATH = os.environ["DATA_PATH"]
|
19 |
+
|
20 |
+
self.data_loader = TextLoader(DATA_PATH, encoding="UTF-8").load()
|
21 |
+
|
22 |
+
self.questions = list(
|
23 |
+
map(lambda x: "##Questão" + x, self.data_loader[0].page_content.split("##Questão"))
|
24 |
+
)
|
25 |
+
|
26 |
+
self.docs = []
|
27 |
+
|
28 |
+
for question in self.questions:
|
29 |
+
try:
|
30 |
+
self.docs.append(parse_question(question))
|
31 |
+
except Exception as e:
|
32 |
+
print(e, question)
|
33 |
+
|
34 |
+
self.vectorstore = Chroma.from_documents(self.docs, self._model.embeddings, persist_directory="./chroma_db")
|
35 |
+
|
36 |
+
self.metadata_field_info = [
|
37 |
+
AttributeInfo(
|
38 |
+
name="topico",
|
39 |
+
description="A materia escolar da qual a questão pertence.",
|
40 |
+
type="string",
|
41 |
+
),
|
42 |
+
AttributeInfo(
|
43 |
+
name="assunto",
|
44 |
+
description="O assunto da materia fornecida anteriormente.",
|
45 |
+
type="string",
|
46 |
+
),
|
47 |
+
AttributeInfo(
|
48 |
+
name="dificuldade",
|
49 |
+
description="O nivel de dificuldade para resolver a questao.",
|
50 |
+
type="string",
|
51 |
+
),
|
52 |
+
AttributeInfo(
|
53 |
+
name="tipo",
|
54 |
+
description="O tipo da questao. Pode ser ou Multipla Escolha ou Justificativa",
|
55 |
+
type="string",
|
56 |
+
),
|
57 |
+
]
|
58 |
+
|
59 |
+
document_content_description = "Questões de matérias do ensino médio."
|
60 |
+
|
61 |
+
self.retriever = SelfQueryRetriever.from_llm(
|
62 |
+
self._model.llm, self.vectorstore, document_content_description, self.metadata_field_info, verbose=True
|
63 |
+
)
|
server/inference.py
DELETED
@@ -1,43 +0,0 @@
|
|
1 |
-
from langchain.schema.runnable import RunnablePassthrough
|
2 |
-
from langchain_google_genai import ChatGoogleGenerativeAI
|
3 |
-
from langchain_core.runnables import RunnableLambda
|
4 |
-
|
5 |
-
from llm.gemini import (
|
6 |
-
questions_template,
|
7 |
-
format_questions_instructions,
|
8 |
-
questions_parser,
|
9 |
-
)
|
10 |
-
from data.load_data import retriever
|
11 |
-
|
12 |
-
|
13 |
-
def get_questions(_dict):
|
14 |
-
question = _dict["question"]
|
15 |
-
context = _dict["context"]
|
16 |
-
messages = questions_template.format_messages(
|
17 |
-
context=context,
|
18 |
-
question=question,
|
19 |
-
format_questions_instructions=format_questions_instructions,
|
20 |
-
)
|
21 |
-
|
22 |
-
tries = 0
|
23 |
-
|
24 |
-
while tries < 3:
|
25 |
-
try:
|
26 |
-
chat = ChatGoogleGenerativeAI(model="gemini-pro")
|
27 |
-
response = chat.invoke(messages)
|
28 |
-
return questions_parser.parse(response.content)
|
29 |
-
except Exception as e:
|
30 |
-
print(e)
|
31 |
-
tries += 1
|
32 |
-
|
33 |
-
return "Não foi possível gerar as questões."
|
34 |
-
|
35 |
-
|
36 |
-
def format_docs(docs):
|
37 |
-
return "\n\n".join(doc.page_content for doc in docs)
|
38 |
-
|
39 |
-
|
40 |
-
rag_chain = {
|
41 |
-
"context": retriever | RunnableLambda(format_docs),
|
42 |
-
"question": RunnablePassthrough(),
|
43 |
-
} | RunnableLambda(get_questions)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
server/llm/gemini.py
CHANGED
@@ -5,78 +5,83 @@ from langchain import PromptTemplate, LLMChain
|
|
5 |
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
|
6 |
from langchain.prompts import ChatPromptTemplate
|
7 |
|
8 |
-
if "GOOGLE_API_KEY" not in os.environ:
|
9 |
-
raise ValueError("GOOGLE_API_KEY environment variable is not set")
|
10 |
|
11 |
-
|
12 |
-
Instruções para cada questão:
|
13 |
-
- Crie uma questão clara e relevante para o tema.
|
14 |
-
- Forneça cinco opções de resposta, rotuladas de A) a E).
|
15 |
-
- Apenas uma das opções de resposta deve ser correta.
|
16 |
-
- Indique a resposta correta ao final de cada questão.
|
17 |
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
-
|
22 |
-
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
B) Hemoglobina
|
27 |
-
C) Mioglobina
|
28 |
-
D) Citocromo
|
29 |
-
E) Queratina
|
30 |
|
31 |
-
Resposta
|
32 |
-
A) Clorofila
|
|
|
|
|
|
|
|
|
33 |
|
34 |
-
|
35 |
-
|
36 |
-
Answer:
|
37 |
|
|
|
|
|
|
|
38 |
|
39 |
-
{format_questions_instructions}
|
40 |
-
"""
|
41 |
|
42 |
-
|
|
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
questions_template = ChatPromptTemplate.from_template(template=llm_prompt_template)
|
49 |
-
|
50 |
-
questions_chain = LLMChain(llm=llm, prompt=questions_template)
|
51 |
-
|
52 |
-
questions_schema = ResponseSchema(
|
53 |
-
name="questions",
|
54 |
-
description="""Give the questions in json as an array""",
|
55 |
-
)
|
56 |
-
|
57 |
-
questions_schemas = [questions_schema]
|
58 |
-
|
59 |
-
questions_parser = StructuredOutputParser.from_response_schemas(questions_schemas)
|
60 |
-
format_questions_instructions = questions_parser.get_format_instructions()
|
61 |
-
format_questions_instructions = """
|
62 |
-
The output should be a markdown code snippet formatted in the following schema, including the leading and trailing "```json" and "```":
|
63 |
-
```json
|
64 |
-
{
|
65 |
-
"questions": [
|
66 |
{
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
|
6 |
from langchain.prompts import ChatPromptTemplate
|
7 |
|
|
|
|
|
8 |
|
9 |
+
class Gemini:
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
+
_llm_prompt_template = """
|
12 |
+
Olá, sou uma IA treinada para gerar conteúdo educacional. Por favor, gere cinco questões de múltipla escolha sobre o seguinte tema:
|
13 |
+
Instruções para cada questão:
|
14 |
+
- Crie uma questão clara e relevante para o tema.
|
15 |
+
- Forneça cinco opções de resposta, rotuladas de A) a E).
|
16 |
+
- Apenas uma das opções de resposta deve ser correta.
|
17 |
+
- Indique a resposta correta ao final de cada questão.
|
18 |
|
19 |
+
Exemplo de uma questão:
|
20 |
+
Tema: Fotossíntese
|
21 |
|
22 |
+
Questão:
|
23 |
+
Qual é o pigmento primário responsável pela fotossíntese nas plantas?
|
|
|
|
|
|
|
|
|
24 |
|
25 |
+
Opções de Resposta:
|
26 |
+
A) Clorofila
|
27 |
+
B) Hemoglobina
|
28 |
+
C) Mioglobina
|
29 |
+
D) Citocromo
|
30 |
+
E) Queratina
|
31 |
|
32 |
+
Resposta Correta:
|
33 |
+
A) Clorofila
|
|
|
34 |
|
35 |
+
Context: {context}
|
36 |
+
Question: {question}
|
37 |
+
Answer:
|
38 |
|
|
|
|
|
39 |
|
40 |
+
{format_questions_instructions}
|
41 |
+
"""
|
42 |
|
43 |
+
_format_questions_instructions = """
|
44 |
+
The output should be a markdown code snippet formatted in the following schema, including the leading and trailing "```json" and "```":
|
45 |
+
```json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
{
|
47 |
+
"questions": [
|
48 |
+
{
|
49 |
+
question: "Qual é o pigmento primário responsável pela fotossíntese nas plantas?",
|
50 |
+
options: ["A) Clorofila",
|
51 |
+
"B) Hemoglobina",
|
52 |
+
"C) Mioglobina",
|
53 |
+
"D) Citocromo",
|
54 |
+
"E) Queratina"],
|
55 |
+
answer: "A"
|
56 |
+
}
|
57 |
+
]
|
58 |
+
```
|
59 |
+
}"""
|
60 |
+
|
61 |
+
def __init__(self):
|
62 |
+
|
63 |
+
if "GOOGLE_API_KEY" not in os.environ:
|
64 |
+
raise ValueError("GOOGLE_API_KEY environment variable is not set")
|
65 |
+
|
66 |
+
self.llm_prompt = PromptTemplate.from_template(self._llm_prompt_template)
|
67 |
+
|
68 |
+
self.embeddings_model = "models/embedding-001"
|
69 |
+
self.model = "gemini-pro"
|
70 |
+
|
71 |
+
self.embeddings = GoogleGenerativeAIEmbeddings(model=self.embeddings_model)
|
72 |
+
self.llm = ChatGoogleGenerativeAI(model=self.model, temperature=0.7, top_p=1)
|
73 |
+
|
74 |
+
self.template = ChatPromptTemplate.from_template(
|
75 |
+
template=self._llm_prompt_template
|
76 |
+
)
|
77 |
+
|
78 |
+
self.chain = LLMChain(llm=self.llm, prompt=self.template)
|
79 |
+
|
80 |
+
self.schemas = [
|
81 |
+
ResponseSchema(
|
82 |
+
name="questions",
|
83 |
+
description="""Give the questions in json as an array""",
|
84 |
+
)
|
85 |
+
]
|
86 |
+
|
87 |
+
self.parser = StructuredOutputParser.from_response_schemas(self.schemas)
|
server/services/generate_questions_service.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_core.runnables import RunnableLambda
|
2 |
+
from langchain.schema.runnable import RunnablePassthrough
|
3 |
+
from data.retriever import Retriever
|
4 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
5 |
+
from llm.gemini import Gemini
|
6 |
+
|
7 |
+
|
8 |
+
class GenerateQuestionsService:
|
9 |
+
|
10 |
+
_retrieve = Retriever()
|
11 |
+
_model = Gemini()
|
12 |
+
|
13 |
+
def handle(self, query: str):
|
14 |
+
|
15 |
+
rag_chain = {
|
16 |
+
"context": self._retrieve.retriever | RunnableLambda(self._format_docs),
|
17 |
+
"question": RunnablePassthrough(),
|
18 |
+
} | RunnableLambda(self._get_questions)
|
19 |
+
|
20 |
+
return rag_chain.invoke(query)
|
21 |
+
|
22 |
+
|
23 |
+
def _get_questions(self, _dict):
|
24 |
+
|
25 |
+
question = _dict["question"]
|
26 |
+
context = _dict["context"]
|
27 |
+
messages = self._model.template.format_messages(
|
28 |
+
context=context,
|
29 |
+
question=question,
|
30 |
+
format_questions_instructions=self._model._format_questions_instructions,
|
31 |
+
)
|
32 |
+
|
33 |
+
tries = 0
|
34 |
+
|
35 |
+
while tries < 3:
|
36 |
+
try:
|
37 |
+
chat = ChatGoogleGenerativeAI(model="gemini-pro")
|
38 |
+
response = chat.invoke(messages)
|
39 |
+
return self._model.parser.parse(response.content)
|
40 |
+
except Exception as e:
|
41 |
+
print(e)
|
42 |
+
tries += 1
|
43 |
+
|
44 |
+
return "Não foi possível gerar as questões."
|
45 |
+
|
46 |
+
|
47 |
+
def _format_docs(self, docs):
|
48 |
+
return "\n\n".join(doc.page_content for doc in docs)
|
49 |
+
|
web/index.html
CHANGED
@@ -5,7 +5,7 @@
|
|
5 |
<meta charset="UTF-8" />
|
6 |
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
7 |
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
8 |
-
<title>
|
9 |
<link rel="preconnect" href="https://fonts.googleapis.com">
|
10 |
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
11 |
<link href="https://fonts.googleapis.com/css2?family=Inter:[email protected]&display=swap" rel="stylesheet">
|
|
|
5 |
<meta charset="UTF-8" />
|
6 |
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
7 |
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
8 |
+
<title>Perguntaí</title>
|
9 |
<link rel="preconnect" href="https://fonts.googleapis.com">
|
10 |
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
11 |
<link href="https://fonts.googleapis.com/css2?family=Inter:[email protected]&display=swap" rel="stylesheet">
|