Procházet zdrojové kódy

更新 API 端点以支持腾讯云配置，添加 .gitignore 文件，调整依赖项

孔祥赫 před 6 dny
rodič
revize
c37e793679

+ 3 - 0
.gitignore

@@ -0,0 +1,3 @@
+.env*
+**/*.pyc
+**/*.code-workspace

+ 6 - 6
enpoints/get_upload_credential.py

@@ -17,18 +17,18 @@ actions = [
 
 @app.post('/get_upload_credential')
 def get_upload_credential():
-	p = f'{settings.upload_prefix}/{uuid.uuid7()}'
+	p = f'{settings.tencent_cloud.upload_prefix}/{uuid.uuid7()}'
 	config = {
 		'duration_seconds': 600,
-		'secret_id': settings.tencent.secret_id,
-		'secret_key': settings.tencent.secret_key,
-		'region': settings.tencent.region,
+		'secret_id': settings.tencent_cloud.secret_id,
+		'secret_key': settings.tencent_cloud.secret_key,
+		'region': settings.tencent_cloud.region,
 		'policy': Sts.get_policy(
 			[
 				Scope(
 					action,
-					settings.tencent.bucket,
-					settings.tencent.region,
+					settings.tencent_cloud.bucket,
+					settings.tencent_cloud.region,
 					p,
 				)
 				for action in actions

+ 3 - 3
enpoints/notify_upload_complete.py

@@ -14,7 +14,7 @@ class FileNotify(BaseModel):
 
 lock = Lock()
 
-async def upload_to_vdb(local_path: str):
+def upload_to_vdb(local_path: str):
 	collection_view.load_and_split_text(
 		local_file_path=local_path,
 		metadata={},
@@ -30,11 +30,11 @@ async def upload_to_vdb(local_path: str):
 
 async def process_uploaded_file(notify: FileNotify):
 	async with lock:
-		if not await to_thread(cos.object_exists, settings.bucket, notify.key):
+		if not await to_thread(cos.object_exists, settings.tencent_cloud.bucket, notify.key):
 			return
 
 		local_path = f'/tmp/{notify.key}'
-		await to_thread(cos.download_file, settings.bucket, notify.key, local_path)
+		await to_thread(cos.download_file, settings.tencent_cloud.bucket, notify.key, local_path)
 		await to_thread(upload_to_vdb, local_path)
 
 		Path(local_path).unlink()

+ 5 - 5
enpoints/retrieval.py

@@ -1,6 +1,5 @@
 from typing import Dict
 from pydantic import BaseModel
-from 
 
 from main import app
 from services import collection_view
@@ -24,14 +23,15 @@ def retrieval(req: Retrieval):
 		expand_chunk=[1, 1],
 		limit=req.retrieval_setting.top_k,
 	)
+	chunks = [vars(i) for i in chunks]
 	return {
 		'records': [
 			{
-				'content': chunk.data.text,
-				'score': chunk.score,
-				'title': chunk.data.documentSet.documentSetName,
+				'content': chunk['data']['text'],
+				'score': chunk['score'],
+				'title': chunk['documentSet']['documentSetName'],
 				'metadata': {
-					'document_id': str(chunk.data.documentSet.documentSetId),
+					'document_id': str(chunk['documentSet']['documentSetId']),
 				},
 			} for chunk in chunks
 		]

+ 2 - 2
requirements.txt

@@ -1,5 +1,5 @@
-fastapi==0.135.1
-pydantic-settings==2.13.1
+fastapi[standard]==0.135.1
+pydantic-settings[yaml]==2.13.1
 qcloud-python-sts==3.1.6
 tencentcloud-sdk-python-common==3.1.53
 cos-python-sdk-v5==1.9.41

+ 24 - 14
services.py

@@ -4,39 +4,48 @@ import tcvectordb
 from tcvectordb.model.ai_database import AIDatabase
 from tcvectordb.model.collection_view import CollectionView, Embedding, SplitterProcess, ParsingProcess
 from tcvectordb.model.index import Index
-from tcvectordb.exceptions import DescribeCollectionException
+from tcvectordb.exceptions import ServerInternalError
+from rich.console import Console
 
 from settings import settings
 
+console = Console()
+
 cos = CosS3Client(
 	CosConfig(
-		Region=settings.region,
-		SecretId=settings.secret_id,
-		SecretKey=settings.secret_key,
+		Region=settings.tencent_cloud.region,
+		SecretId=settings.tencent_cloud.secret_id,
+		SecretKey=settings.tencent_cloud.secret_key,
 	),
 )
 
 vdb = tcvectordb.RPCVectorDBClient(
-	url=settings.VDB_config.url,
-	username=settings.VDB_config.username,
-	key=settings.VDB_config.key,
+	url=str(settings.VDB.url),
+	username=settings.VDB.username,
+	key=settings.VDB.key,
 )
 
 
 def create_ai_database_if_not_exists(database_name: str) -> AIDatabase:
 	if vdb.exists_db(database_name):
+		console.log(f'[green]Database "{database_name}" already exists.[/green]')
 		return vdb.database(database_name)
 	else:
+		console.log(f'[green]Creating database "{database_name}"...[/green]')
 		return vdb.create_ai_database(database_name)
 
 def create_collection_view_if_not_exists(database: AIDatabase, collection_name: str) -> CollectionView:
 	try:
-		return database.collection_view(collection_name)
-	except DescribeCollectionException:
+		console.log(f'[green]Checking if collection view "{collection_name}" exists...[/green]')
+		return database.describe_collection_view(collection_name)
+	except ServerInternalError as e:
+		if not e.message.startswith(f'CollectionView not exist: {collection_name}'):
+			raise e
+		console.log(f'[green]Creating collection view "{collection_name}"...[/green]')
 		return database.create_collection_view(
 			collection_name,
 			embedding=Embedding(
-				language='MULTI',
+				language='multi',
 				enable_words_embedding=True,	
 			),
 			splitter_process=SplitterProcess(
@@ -50,9 +59,10 @@ def create_collection_view_if_not_exists(database: AIDatabase, collection_name:
 
 def create_db_collection_view_if_not_exists(database: str, collection: str) -> CollectionView:
 	db = create_ai_database_if_not_exists(database)
-	return create_collection_view_if_not_exists(db, collection)
+	return db, create_collection_view_if_not_exists(db, collection)
 
-collection_view = create_db_collection_view_if_not_exists(
-	settings.VDB_config.database,
-	settings.VDB_config.collection,
+database, collection_view = create_db_collection_view_if_not_exists(
+	settings.VDB.database,
+	settings.VDB.collection,
 )
+console.log(f'[green]Collection view "{settings.VDB.collection}" is ready.[/green]')

+ 24 - 12
settings.py

@@ -1,20 +1,32 @@
-from pydantic_settings import BaseSettings
+from pydantic_settings import BaseSettings, SettingsConfigDict, YamlConfigSettingsSource
 from pydantic import Field, BaseModel, HttpUrl
 
 class VDBSettings(BaseModel):
-	url: HttpUrl = Field(validation_alias='VDB_URL')
-	username: str = Field(validation_alias='VDB_USERNAME')
-	key: str = Field(validation_alias='VDB_KEY')
-	database: str = Field(validation_alias='VDB_DATABASE')
-	collection: str = Field(validation_alias='VDB_COLLECTION')
+	url: HttpUrl
+	username: str
+	key: str
+	database: str
+	collection: str
+
+class TencentCloudSettings(BaseModel):
+	secret_id: str
+	secret_key: str
+	region: str
+	bucket: str
+	upload_prefix: str
 
 class Settings(BaseSettings):
-	secret_id: str = Field(validation_alias='TENCENT_SECRET_ID')
-	secret_key: str = Field(validation_alias='TENCENT_SECRET_KEY')
-	region: str = Field(validation_alias='TENCENT_REGION')
-	bucket: str = Field(validation_alias='TENCENT_BUCKET')
-	upload_prefix: str = Field(validation_alias='UPLOAD_PREFIX')
+	tencent_cloud: TencentCloudSettings
+	VDB: VDBSettings
+
+	model_config = SettingsConfigDict(
+		yaml_file=('.env.yaml'),
+	)
 
-	VDB_config: VDBSettings
+	@classmethod
+	def settings_customise_sources(cls, settings_cls, init_settings, env_settings, dotenv_settings, file_secret_settings):
+		return (
+			YamlConfigSettingsSource(settings_cls),
+		)
 
 settings = Settings()