# GraphQL resolvers for Q&A search ("answers" / "questions" queries) and
# streamed summary generation (subscription), backed by Vespa and OpenAI.
import asyncio
import json
from typing import Optional

import structlog
from annotated_types import Len
from ariadne import ObjectType, SubscriptionType
from graphql import GraphQLResolveInfo
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, Field, field_validator, validator
from typing_extensions import Annotated, AsyncGenerator, Iterable, Unpack
from vespa.application import Vespa
from vespa.io import VespaQueryResponse

from .cache import cache_questions, cache_generate_summary
from .data import questions as data_questions, shot_user, shot_assistant
from .generated.schema_types import (
    Answer,
    AnswersParams,
    AnswersQueryResult,
    GenerateSummaryParams,
    GenerateSummarySubscriptionResult,
    Question,
    QuestionsParams,
    QuestionsQueryResult,
)
from .settings import VESPA_APP_URL, OPENAI_API_KEY
clientVespa = Vespa(url=VESPA_APP_URL) | |
clientOpenAI = AsyncOpenAI( | |
api_key=str(OPENAI_API_KEY), | |
) | |
logger = structlog.get_logger("qa") | |
query = ObjectType("Query") | |
class QaFieldModel(BaseModel): | |
sddocname: str | |
documentid: str | |
doc_id: str | |
category_major: Optional[str] = None | |
category_medium: Optional[str] = None | |
category_minor: Optional[str] = None | |
question: str | |
answer: str | |
class QaModel(BaseModel): | |
id: str | |
relevance: float | |
source: str | |
fields: QaFieldModel | |
class AnswersParamsModel(BaseModel): | |
query: Optional[str] = Field(strict=True, max_length=1024) | |
async def resolve_answer( | |
_, info: GraphQLResolveInfo, **params: Unpack[AnswersParams] | |
) -> AnswersQueryResult: | |
assert info is not None, "Prevent type check error" | |
params_parsed = AnswersParamsModel.model_validate(params, strict=True) | |
answers: list[Answer] = [] | |
query = params_parsed.query | |
if not query: | |
logger.warning("Query is empty", params=params_parsed) | |
return {"answers": answers} | |
query_parsed = ( | |
query.replace("\\", "\\\\") | |
.replace('"', '\\"') | |
.replace(":", "\\:") | |
.replace(")", "\\)") | |
) | |
base = "select * from qa where" | |
anno = "{targetHits:100,approximate:false}" | |
cond01 = f"({anno}nearestNeighbor(answer_embedding_me5s, q))" | |
cond02 = f"({anno}nearestNeighbor(question_embedding_me5s, q))" | |
async with clientVespa.asyncio() as sess: | |
res: VespaQueryResponse = await sess.query( | |
yql=f"{base} {cond01} or {cond02}", | |
lang="ja", | |
hits=20, | |
ranking="semantic", | |
body={ | |
"input.query(q)": f'embed(multilingual-e5-small, "query: {query_parsed}")', | |
}, | |
) | |
if not res.is_successful(): | |
logger.warning("Vespa query failed", json=res.json, status=res.status_code) | |
return {"answers": answers} | |
hits = [QaModel.model_validate(hit, strict=True) for hit in res.hits] | |
answers = [ | |
Answer( | |
id=hit.fields.doc_id, | |
docId=hit.fields.doc_id, | |
categoryMajor=hit.fields.category_major, | |
categoryMedium=hit.fields.category_medium, | |
categoryMinor=hit.fields.category_minor, | |
question=hit.fields.question, | |
answer=hit.fields.answer, | |
score=hit.relevance, | |
) | |
for hit in hits | |
] | |
return {"answers": answers} | |
class QuestionsParamsModel(BaseModel): | |
query: Optional[str] = Field(strict=True, max_length=1024) | |
async def resolve_question( | |
_, info: GraphQLResolveInfo, **params: Unpack[QuestionsParams] | |
) -> QuestionsQueryResult: | |
assert info is not None, "Prevent type check error" | |
params_parsed = QuestionsParamsModel.model_validate(params, strict=True) | |
questions: list[Question] = data_questions | |
query = params_parsed.query | |
if not query: | |
logger.warning("Query is empty", params=params_parsed) | |
return {"questions": questions} | |
query_parsed = ( | |
query.replace("\\", "\\\\") | |
.replace('"', '\\"') | |
.replace(":", "\\:") | |
.replace(")", "\\)") | |
) | |
cached_questions = await cache_questions.get(query) | |
if isinstance(cached_questions, list): | |
return {"questions": cached_questions} | |
base = "select * from qa where" | |
anno = "{targetHits:100,approximate:false}" | |
cond01 = "({targetHits:100}userInput(@condQuery))" | |
cond02 = f"({anno}nearestNeighbor(question_embedding_me5s, q))" | |
async with clientVespa.asyncio() as sess: | |
res: VespaQueryResponse = await sess.query( | |
yql=f"{base} {cond01} or {cond02}", | |
lang="ja", | |
hits=20, | |
ranking="question_semantic", | |
body={ | |
"condQuery": query, | |
"input.query(q)": f'embed(multilingual-e5-small, "query: {query_parsed}")', | |
}, | |
) | |
if not res.is_successful(): | |
logger.warning("Vespa query failed", json=res.json, status=res.status_code) | |
return {"questions": questions} | |
hits = [QaModel.model_validate(hit, strict=True) for hit in res.hits] | |
questions = [ | |
Question( | |
id=hit.fields.doc_id, | |
docId=hit.fields.doc_id, | |
categoryMajor=hit.fields.category_major, | |
categoryMedium=hit.fields.category_medium, | |
categoryMinor=hit.fields.category_minor, | |
question=hit.fields.question, | |
) | |
for hit in hits | |
] | |
await cache_questions.set(query, questions) | |
return {"questions": questions} | |
subscription = SubscriptionType() | |
class GenerateSummaryParamsModel(BaseModel): | |
query: str = Field(strict=True, max_length=1024) | |
docIds: Annotated[list[str], Len(max_length=10)] | |
def check_max_length(cls, v): | |
if len(v) > 1024: | |
raise ValueError("string length exceeds maximum of 1024") | |
return v | |
async def generate_generate_summary( | |
_, info: GraphQLResolveInfo, **params: Unpack[GenerateSummaryParams] | |
) -> AsyncGenerator[str, str]: | |
assert info is not None, "Prevent type check error" | |
params_parsed = GenerateSummaryParamsModel.model_validate(params, strict=True) | |
if not params_parsed.query: | |
logger.warning("No query found", params=params_parsed) | |
return | |
doc_ids = params_parsed.docIds or [] | |
if not doc_ids: | |
logger.warning("No docIds found", params=params_parsed) | |
return | |
key = params_parsed.query + "|" + "|".join(sorted(doc_ids)) | |
cached_summary = await cache_generate_summary.get(key) | |
if isinstance(cached_summary, str): | |
for char in cached_summary: | |
yield char | |
await asyncio.sleep(0.05) | |
return | |
query_in = ", ".join( | |
['"' + x.replace("\\", "\\\\").replace('"', '\\"') + '"' for x in doc_ids] | |
) | |
answers = [] | |
async with clientVespa.asyncio() as sess: | |
res: VespaQueryResponse = await sess.query( | |
yql=f"select * from qa where doc_id in ({query_in})", | |
lang="ja", | |
hits=5, | |
) | |
if not res.is_successful(): | |
logger.warning("Vespa query failed", json=res.json, status=res.status_code) | |
return | |
hits = [QaModel.model_validate(hit, strict=True) for hit in res.hits] | |
answers = [ | |
{ | |
"docId": hit.fields.doc_id, | |
"answer": hit.fields.answer, | |
"score": hit.relevance, | |
} | |
for hit in hits | |
] | |
if not answers: | |
logger.warning("No answers found", params=params_parsed) | |
return | |
system = """ใใชใใซใฏ่ณชๅ(question)ใจๅ่่ณๆ(references)ใไธใใใใพใใ | |
ใใชใใฎไปไบใฏไปฅไธใฎ2ใคใงใใ | |
- ไธใใใใๅ่่ณๆใซใใใใฆใใๆ ๅ ฑใฎใฟใไฝฟใฃใฆ่ณชๅใซๅ็ญใใใ | |
- ๅ่่ณๆใฎ่ฆ็ดใใใใใใใใพใจใใใ | |
ไปฅไธใฎใซใผใซใซๅพใฃใฆใใ ใใ: | |
- ๅ็ญใซใฏๅ่่ณๆใซๆธใใใฆใใๆญฃ็ขบใชๆ ๅ ฑใฎใฟใๅๆ ใใฆใใ ใใใ | |
- ๅ็ญใ่ฆ็ดใซใฏๅค้จใฎๆ ๅ ฑใๆ้ปใฎ็ฅ่ญใฏๅๆ ใใชใใงใใ ใใใ | |
ไปฅไธใฎใใฉใผใใใใงๅบๅใใฆใใ ใใ: | |
``` | |
### ๅ็ญ | |
ใใใซๅ็ญใๆธใใฆใใ ใใใ | |
### ่ฆ็ด | |
ใใใซ่ฆ็ดใๆธใใฆใใ ใใใ | |
``` | |
""" | |
user = f""" | |
## ่ณชๅ | |
{params_parsed.query} | |
## ๅ่่ณๆ | |
{json.dumps(answers, ensure_ascii=False, indent=2)} | |
""" | |
messages: Iterable[ChatCompletionMessageParam] = [ | |
{"role": "system", "name": "instruction", "content": system}, | |
{"role": "user", "name": "info", "content": shot_user}, | |
{"role": "assistant", "name": "summary", "content": shot_assistant}, | |
{"role": "user", "name": "info", "content": user}, | |
] | |
print("OpenAI chat completions", f"messages={messages}") | |
stream = await clientOpenAI.chat.completions.create( | |
messages=messages, | |
model="gpt-4-turbo-2024-04-09", | |
stream=True, | |
) | |
summary = "" | |
async for chunk in stream: | |
content = chunk.choices[0].delta.content or "" | |
summary += content | |
# FIXME: sanitize to return only elements that are not dangerous as markdown | |
yield content | |
await cache_generate_summary.set(key, summary) | |
return | |
def resolve_generate_summary( | |
summary: str, info: GraphQLResolveInfo, **params: Unpack[GenerateSummaryParams] | |
) -> GenerateSummarySubscriptionResult: | |
assert info and params, "Prevent type check error" | |
return {"summary": summary} | |