Try to fix similarity and add seed for master Excel ICD
This commit is contained in:
@@ -64,7 +64,7 @@ export class HealthController {
|
||||
status: 200,
|
||||
description: 'Application is ready',
|
||||
})
|
||||
async getReady() {
|
||||
getReady() {
|
||||
return { status: 'ready' };
|
||||
}
|
||||
|
||||
@@ -77,7 +77,7 @@ export class HealthController {
|
||||
status: 200,
|
||||
description: 'Application is alive',
|
||||
})
|
||||
async getLive() {
|
||||
getLive() {
|
||||
return { status: 'alive' };
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Controller, Get, Post, Query, Logger } from '@nestjs/common';
|
||||
import { Controller, Get, Query, Logger } from '@nestjs/common';
|
||||
import {
|
||||
ApiTags,
|
||||
ApiOperation,
|
||||
@@ -8,10 +8,8 @@ import {
|
||||
ApiInternalServerErrorResponse,
|
||||
} from '@nestjs/swagger';
|
||||
import { IcdService } from './icd.service';
|
||||
import { SearchIcdDto } from './dto/search-icd.dto';
|
||||
import {
|
||||
IcdSearchResponseDto,
|
||||
IcdImportResponseDto,
|
||||
IcdStatisticsResponseDto,
|
||||
ErrorResponseDto,
|
||||
} from './dto/icd-response.dto';
|
||||
@@ -23,40 +21,6 @@ export class IcdController {
|
||||
|
||||
constructor(private readonly icdService: IcdService) {}
|
||||
|
||||
/**
 * POST `import` handler: delegates to `IcdService.importIcdData()` and wraps
 * the returned counts in a success envelope.
 *
 * Any error raised by the service is logged and rethrown unchanged.
 */
@Post('import')
@ApiOperation({
  summary: 'Import ICD data from Excel files',
  description:
    'Import ICD-9 and ICD-10 codes from Excel files located in the test directory. This operation will process both ICD files and insert/update the database with the latest codes.',
})
@ApiResponse({
  status: 200,
  description: 'ICD data imported successfully',
  type: IcdImportResponseDto,
})
@ApiBadRequestResponse({
  description: 'Bad request - Invalid file format or missing files',
  type: ErrorResponseDto,
})
@ApiInternalServerErrorResponse({
  description: 'Internal server error during import process',
  type: ErrorResponseDto,
})
async importData(): Promise<IcdImportResponseDto> {
  try {
    this.logger.log('Starting ICD data import...');

    const importCounts = await this.icdService.importIcdData();
    const response: IcdImportResponseDto = {
      success: true,
      message: 'ICD data imported successfully',
      data: importCounts,
    };
    return response;
  } catch (importError) {
    // Log at the HTTP boundary, then let the framework map the error.
    this.logger.error('Error importing ICD data:', importError);
    throw importError;
  }
}
|
||||
|
||||
@Get('search')
|
||||
@ApiOperation({
|
||||
summary: 'Search ICD codes with filters and pagination',
|
||||
|
||||
@@ -1,158 +1,11 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { PrismaClient } from '../../generated/prisma';
|
||||
import * as XLSX from 'xlsx';
|
||||
import * as path from 'path';
|
||||
import * as fs from 'fs';
|
||||
|
||||
interface IcdData {
|
||||
code: string;
|
||||
display: string;
|
||||
version: string;
|
||||
}
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
|
||||
@Injectable()
|
||||
export class IcdService {
|
||||
private readonly logger = new Logger(IcdService.name);
|
||||
private readonly prisma = new PrismaClient();
|
||||
|
||||
/**
 * Replace the stored ICD codes with the contents of the two Excel sources.
 *
 * Both files are read before anything is deleted, so a missing or unreadable
 * file aborts the import while the existing rows are still in place. The
 * delete and the inserts are separate statements (no transaction is opened
 * here).
 *
 * @returns the per-category counts reported by `bulkInsertData` and their sum
 * @throws rethrows any read or database error after logging it
 */
async importIcdData(): Promise<{
  icd9Count: number;
  icd10Count: number;
  total: number;
}> {
  try {
    this.logger.log('Starting ICD data import...');

    // Read both sources up front.
    const icd9Rows = await this.readExcelFile(
      'test/[PUBLIC] ICD-9CM e-klaim.xlsx',
      'ICD9',
    );
    const icd10Rows = await this.readExcelFile(
      'test/[PUBLIC] ICD-10 e-klaim.xlsx',
      'ICD10',
    );

    // Wipe the table, then load each category in turn.
    await this.prisma.icdCode.deleteMany({});
    this.logger.log('Cleared existing ICD data');

    const icd9Count = await this.bulkInsertData(icd9Rows, 'ICD9');
    this.logger.log(`Imported ${icd9Count} ICD-9 codes`);

    const icd10Count = await this.bulkInsertData(icd10Rows, 'ICD10');
    this.logger.log(`Imported ${icd10Count} ICD-10 codes`);

    const summary = {
      icd9Count,
      icd10Count,
      total: icd9Count + icd10Count,
    };
    this.logger.log(`Total imported: ${summary.total} ICD codes`);

    return summary;
  } catch (importError) {
    this.logger.error('Error importing ICD data:', importError);
    throw importError;
  }
}
|
||||
|
||||
/**
 * Load ICD rows from the first worksheet of an Excel file.
 *
 * The first row is treated as a header and skipped. A data row is kept only
 * when it has at least three cells and the first three (code, display,
 * version) are all non-empty after trimming.
 *
 * @param filePath path relative to `process.cwd()`
 * @param category label used in log messages only
 * @returns the valid rows in sheet order
 * @throws Error when the file does not exist; other read errors are rethrown
 */
private async readExcelFile(
  filePath: string,
  category: string,
): Promise<IcdData[]> {
  try {
    const absolutePath = path.join(process.cwd(), filePath);
    const fileIsMissing = !fs.existsSync(absolutePath);
    if (fileIsMissing) {
      throw new Error(`File not found: ${absolutePath}`);
    }

    this.logger.log(`Reading ${category} file: ${filePath}`);

    const workbook = XLSX.readFile(absolutePath);
    const firstSheet = workbook.Sheets[workbook.SheetNames[0]];

    // header: 1 yields each row as an array of cell values.
    const sheetRows = XLSX.utils.sheet_to_json(firstSheet, { header: 1 });

    const records: IcdData[] = [];
    for (const rawRow of sheetRows.slice(1)) {
      const cells = rawRow as any[];
      if (!cells || cells.length < 3) {
        continue;
      }

      const code = this.cleanString(cells[0]);
      const display = this.cleanString(cells[1]);
      const version = this.cleanString(cells[2]);
      if (!code || !display || !version) {
        continue;
      }

      records.push({ code, display, version });
    }

    this.logger.log(`Found ${records.length} valid ${category} records`);
    return records;
  } catch (readError) {
    this.logger.error(`Error reading ${category} file:`, readError);
    throw readError;
  }
}
|
||||
|
||||
/**
 * Insert ICD rows in batches of 1000 using `createMany`.
 *
 * Duplicates are skipped by the database (`skipDuplicates: true`), so the
 * number of rows actually written can be smaller than the batch size. The
 * returned total is therefore accumulated from the `count` reported by each
 * `createMany` call rather than from the batch length.
 *
 * @param data rows to insert
 * @param category value stored in the `category` column of every row
 * @returns the number of rows actually inserted
 * @throws rethrows any database error after logging it
 */
private async bulkInsertData(
  data: IcdData[],
  category: string,
): Promise<number> {
  try {
    const batchSize = 1000;
    let totalInserted = 0;

    for (let i = 0; i < data.length; i += batchSize) {
      const batch = data.slice(i, i + batchSize);

      const insertData = batch.map((item) => ({
        code: item.code,
        display: item.display,
        version: item.version,
        category,
      }));

      const { count } = await this.prisma.icdCode.createMany({
        data: insertData,
        skipDuplicates: true,
      });

      // count excludes rows skipped as duplicates.
      totalInserted += count;
      this.logger.log(
        `Inserted batch ${Math.floor(i / batchSize) + 1} for ${category}: ${count} records`,
      );
    }

    return totalInserted;
  } catch (error) {
    this.logger.error(`Error inserting ${category} data:`, error);
    throw error;
  }
}
|
||||
|
||||
/**
 * Convert a cell value to a trimmed string.
 *
 * @param value raw cell value of any type
 * @returns `''` for `null`/`undefined`, otherwise `String(value)` with
 *   surrounding whitespace removed
 */
private cleanString(value: any): string {
  // `== null` matches both null and undefined.
  return value == null ? '' : String(value).trim();
}
|
||||
|
||||
async findIcdCodes(
|
||||
category?: string,
|
||||
search?: string,
|
||||
|
||||
@@ -8,6 +8,15 @@ import {
|
||||
ValidationPipe,
|
||||
UsePipes,
|
||||
} from '@nestjs/common';
|
||||
import {
|
||||
IsString,
|
||||
IsOptional,
|
||||
IsNumber,
|
||||
IsEnum,
|
||||
Min,
|
||||
Max,
|
||||
IsNotEmpty,
|
||||
} from 'class-validator';
|
||||
import {
|
||||
ApiTags,
|
||||
ApiOperation,
|
||||
@@ -27,6 +36,8 @@ export class VectorSearchDto {
|
||||
minLength: 1,
|
||||
maxLength: 500,
|
||||
})
|
||||
@IsString()
|
||||
@IsNotEmpty()
|
||||
query: string;
|
||||
|
||||
@ApiProperty({
|
||||
@@ -37,6 +48,10 @@ export class VectorSearchDto {
|
||||
maximum: 100,
|
||||
default: 10,
|
||||
})
|
||||
@IsOptional()
|
||||
@IsNumber()
|
||||
@Min(1)
|
||||
@Max(100)
|
||||
limit?: number;
|
||||
|
||||
@ApiProperty({
|
||||
@@ -46,16 +61,22 @@ export class VectorSearchDto {
|
||||
enum: ['ICD9', 'ICD10'],
|
||||
default: undefined,
|
||||
})
|
||||
@IsOptional()
|
||||
@IsEnum(['ICD9', 'ICD10'])
|
||||
category?: string;
|
||||
|
||||
@ApiProperty({
|
||||
description: 'Similarity threshold (0.0 - 1.0) for filtering results',
|
||||
example: 0.7,
|
||||
example: 0.85,
|
||||
required: false,
|
||||
minimum: 0.0,
|
||||
maximum: 1.0,
|
||||
default: 0.7,
|
||||
default: 0.85,
|
||||
})
|
||||
@IsOptional()
|
||||
@IsNumber()
|
||||
@Min(0.0)
|
||||
@Max(1.0)
|
||||
threshold?: number;
|
||||
}
|
||||
|
||||
@@ -66,6 +87,8 @@ export class EmbeddingRequestDto {
|
||||
minLength: 1,
|
||||
maxLength: 1000,
|
||||
})
|
||||
@IsString()
|
||||
@IsNotEmpty()
|
||||
text: string;
|
||||
|
||||
@ApiProperty({
|
||||
@@ -74,9 +97,24 @@ export class EmbeddingRequestDto {
|
||||
required: false,
|
||||
default: 'text-embedding-ada-002',
|
||||
})
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
model?: string;
|
||||
}
|
||||
|
||||
/**
 * Request body carrying a new similarity threshold.
 *
 * Used as the `@ApiBody` type and `@Body()` parameter of the
 * `POST threshold` endpoint.
 */
export class ThresholdConfigDto {
  /** Threshold value; the validators require a number in the range [0, 1]. */
  @ApiProperty({
    description: 'Similarity threshold value (0.0 - 1.0)',
    example: 0.85,
    minimum: 0.0,
    maximum: 1.0,
  })
  @IsNumber()
  @Min(0.0)
  @Max(1.0)
  threshold: number;
}
|
||||
|
||||
export class VectorSearchResponseDto {
|
||||
@ApiProperty({
|
||||
description: 'Array of search results with similarity scores',
|
||||
@@ -486,6 +524,61 @@ export class PgVectorController {
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * POST `advanced-search` handler: forwards the search parameters to
 * `PgVectorService.advancedVectorSearch` and wraps the matches in a
 * response envelope.
 *
 * A missing (or otherwise falsy) `limit` is replaced by 10 before the
 * service is called; `category` and `threshold` are passed through as-is.
 */
@Post('advanced-search')
@ApiOperation({
  summary: 'Advanced vector similarity search',
  description:
    'Advanced vector search using multiple similarity metrics (cosine + euclidean) for more accurate results with higher threshold.',
  tags: ['PgVector Operations'],
})
@ApiConsumes('application/json')
@ApiProduces('application/json')
@ApiBody({
  type: VectorSearchDto,
  description: 'Search parameters for advanced vector search',
  examples: {
    highPrecision: {
      summary: 'High precision search',
      value: {
        query: 'diabetes mellitus type 2',
        limit: 10,
        category: 'ICD10',
        threshold: 0.9,
      },
    },
  },
})
@ApiResponse({
  status: HttpStatus.OK,
  description:
    'Advanced vector search results with enhanced similarity scores',
  type: VectorSearchResponseDto,
})
@ApiResponse({
  status: HttpStatus.BAD_REQUEST,
  description: 'Invalid search parameters',
})
@ApiResponse({
  status: HttpStatus.INTERNAL_SERVER_ERROR,
  description: 'Error during advanced vector search',
})
async advancedVectorSearch(
  @Body() searchDto: VectorSearchDto,
): Promise<VectorSearchResponseDto> {
  const effectiveLimit = searchDto.limit ? searchDto.limit : 10;

  const matches = await this.pgVectorService.advancedVectorSearch(
    searchDto.query,
    effectiveLimit,
    searchDto.category,
    searchDto.threshold,
  );

  const response = {
    data: matches,
    total: matches.length,
    query: searchDto.query,
  };
  return response;
}
|
||||
|
||||
@Post('generate-embedding')
|
||||
@ApiOperation({
|
||||
summary: 'Generate text embedding',
|
||||
@@ -570,6 +663,50 @@ export class PgVectorController {
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * POST `regenerate-embeddings-enhanced` handler: runs
 * `PgVectorService.regenerateEmbeddingsWithEnhancedText()` and returns its
 * result with a human-readable `message` summarising the processed and
 * error counts.
 */
@Post('regenerate-embeddings-enhanced')
@ApiOperation({
  summary: 'Regenerate embeddings with enhanced text representation',
  description:
    'Regenerate existing embeddings using enhanced text representation for better similarity scores. This improves search quality.',
  tags: ['PgVector Operations'],
})
@ApiConsumes('application/json')
@ApiProduces('application/json')
@ApiResponse({
  status: HttpStatus.OK,
  description: 'Enhanced embedding regeneration results summary',
  schema: {
    type: 'object',
    properties: {
      processed: { type: 'number', example: 100 },
      errors: { type: 'number', example: 0 },
      totalSample: { type: 'number', example: 100 },
      message: {
        type: 'string',
        example: 'Enhanced embeddings regenerated successfully',
      },
    },
  },
})
@ApiResponse({
  status: HttpStatus.INTERNAL_SERVER_ERROR,
  description: 'Error during enhanced embedding regeneration',
})
async regenerateEmbeddingsEnhanced(): Promise<{
  processed: number;
  errors: number;
  totalSample: number;
  message: string;
}> {
  const summary =
    await this.pgVectorService.regenerateEmbeddingsWithEnhancedText();

  const message = `Enhanced embeddings regenerated successfully. Processed: ${summary.processed}, Errors: ${summary.errors}`;

  return { ...summary, message };
}
|
||||
|
||||
@Get('stats')
|
||||
@ApiOperation({
|
||||
summary: 'Get embedding statistics',
|
||||
@@ -640,6 +777,234 @@ export class PgVectorController {
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * GET `threshold` handler: reports the threshold currently returned by
 * `PgVectorService.getSimilarityThreshold()` together with a fixed
 * description string.
 */
@Get('threshold')
@ApiOperation({
  summary: 'Get current similarity threshold',
  description:
    'Get the current similarity threshold configuration used for vector search filtering.',
  tags: ['PgVector Operations'],
})
@ApiProduces('application/json')
@ApiResponse({
  status: HttpStatus.OK,
  description: 'Current similarity threshold configuration',
  schema: {
    type: 'object',
    properties: {
      threshold: {
        type: 'number',
        description: 'Current similarity threshold value',
        example: 0.85,
      },
      description: {
        type: 'string',
        description: 'Description of the threshold setting',
        example: 'Minimum similarity score required for search results',
      },
    },
  },
})
@ApiResponse({
  status: HttpStatus.INTERNAL_SERVER_ERROR,
  description: 'Error retrieving threshold configuration',
})
async getSimilarityThreshold(): Promise<{
  threshold: number;
  description: string;
}> {
  const description =
    'Minimum similarity score required for search results (0.0 - 1.0)';

  return {
    threshold: this.pgVectorService.getSimilarityThreshold(),
    description,
  };
}
|
||||
|
||||
/**
 * GET `model` handler: reports the embedding model name returned by
 * `PgVectorService.getEmbeddingModel()`.
 *
 * `source` is `'Environment Variable'` whenever `OPENAI_API_MODEL` is set to
 * a non-empty value in `process.env`, otherwise `'Default'`.
 */
@Get('model')
@ApiOperation({
  summary: 'Get current embedding model',
  description:
    'Get the current OpenAI embedding model configuration used for vector generation.',
  tags: ['PgVector Operations'],
})
@ApiProduces('application/json')
@ApiResponse({
  status: HttpStatus.OK,
  description: 'Current embedding model configuration',
  schema: {
    type: 'object',
    properties: {
      model: {
        type: 'string',
        description: 'Current embedding model name',
        example: 'text-embedding-ada-002',
      },
      description: {
        type: 'string',
        description: 'Description of the model configuration',
        example: 'OpenAI embedding model for vector generation',
      },
      source: {
        type: 'string',
        description: 'Source of the model configuration',
        example: 'Environment Variable',
      },
    },
  },
})
@ApiResponse({
  status: HttpStatus.INTERNAL_SERVER_ERROR,
  description: 'Error retrieving model configuration',
})
async getEmbeddingModel(): Promise<{
  model: string;
  description: string;
  source: string;
}> {
  const model = this.pgVectorService.getEmbeddingModel();

  let source: string;
  if (process.env.OPENAI_API_MODEL) {
    source = 'Environment Variable';
  } else {
    source = 'Default';
  }

  return {
    model,
    description: 'OpenAI embedding model for vector generation',
    source,
  };
}
|
||||
|
||||
/**
 * POST `threshold` handler: stores a new similarity threshold through
 * `PgVectorService.setSimilarityThreshold` and echoes both the new and the
 * previous value.
 *
 * The previous value is captured before the update is applied.
 */
@Post('threshold')
@ApiOperation({
  summary: 'Set similarity threshold',
  description:
    'Set the similarity threshold for vector search filtering. Higher values result in more strict matching.',
  tags: ['PgVector Operations'],
})
@ApiConsumes('application/json')
@ApiProduces('application/json')
@ApiBody({
  type: ThresholdConfigDto,
  description: 'Threshold configuration parameters',
})
@ApiResponse({
  status: HttpStatus.OK,
  description: 'Similarity threshold updated successfully',
  schema: {
    type: 'object',
    properties: {
      message: {
        type: 'string',
        description: 'Success message',
        example: 'Similarity threshold updated successfully',
      },
      threshold: {
        type: 'number',
        description: 'Updated threshold value',
        example: 0.9,
      },
      previousThreshold: {
        type: 'number',
        description: 'Previous threshold value',
        example: 0.85,
      },
    },
  },
})
@ApiResponse({
  status: HttpStatus.BAD_REQUEST,
  description: 'Invalid threshold value (must be between 0.0 and 1.0)',
})
@ApiResponse({
  status: HttpStatus.INTERNAL_SERVER_ERROR,
  description: 'Error updating threshold configuration',
})
async setSimilarityThreshold(
  @Body() thresholdConfig: ThresholdConfigDto,
): Promise<{
  message: string;
  threshold: number;
  previousThreshold: number;
}> {
  const requested = thresholdConfig.threshold;

  // Read the old value first so it can be reported after the update.
  const previousThreshold = this.pgVectorService.getSimilarityThreshold();
  this.pgVectorService.setSimilarityThreshold(requested);

  return {
    message: 'Similarity threshold updated successfully',
    threshold: requested,
    previousThreshold,
  };
}
|
||||
|
||||
/**
 * POST `model` handler: switches the embedding model through
 * `PgVectorService.setEmbeddingModel` and echoes both the requested and the
 * previous model name.
 *
 * NOTE(review): the body is typed as an inline object literal rather than a
 * DTO class, so it presumably bypasses class-validator checks — confirm
 * whether a dedicated DTO is wanted here.
 */
@Post('model')
@ApiOperation({
  summary: 'Set embedding model',
  description:
    'Set the OpenAI embedding model for vector generation. This will reinitialize the embeddings service.',
  tags: ['PgVector Operations'],
})
@ApiConsumes('application/json')
@ApiProduces('application/json')
@ApiBody({
  schema: {
    type: 'object',
    properties: {
      model: {
        type: 'string',
        description: 'OpenAI embedding model name',
        example: 'text-embedding-ada-002',
      },
    },
    required: ['model'],
  },
  description: 'Model configuration parameters',
})
@ApiResponse({
  status: HttpStatus.OK,
  description: 'Embedding model updated successfully',
  schema: {
    type: 'object',
    properties: {
      message: {
        type: 'string',
        description: 'Success message',
        example: 'Embedding model updated successfully',
      },
      model: {
        type: 'string',
        description: 'Updated model name',
        example: 'text-embedding-ada-002',
      },
      previousModel: {
        type: 'string',
        description: 'Previous model name',
        example: 'text-embedding-ada-002',
      },
    },
  },
})
@ApiResponse({
  status: HttpStatus.BAD_REQUEST,
  description: 'Invalid model name',
})
@ApiResponse({
  status: HttpStatus.INTERNAL_SERVER_ERROR,
  description: 'Error updating model configuration',
})
async setEmbeddingModel(@Body() body: { model: string }): Promise<{
  message: string;
  model: string;
  previousModel: string;
}> {
  // Capture the active model before switching so it can be reported back.
  const previousModel = this.pgVectorService.getEmbeddingModel();
  await this.pgVectorService.setEmbeddingModel(body.model);

  const response = {
    message: 'Embedding model updated successfully',
    model: body.model,
    previousModel,
  };
  return response;
}
|
||||
|
||||
@Post('refresh')
|
||||
@ApiOperation({
|
||||
summary: 'Refresh pgvector store',
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { PrismaClient } from '../../generated/prisma';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import { OpenAIEmbeddings } from '@langchain/openai';
|
||||
import { PGVectorStore } from '@langchain/community/vectorstores/pgvector';
|
||||
import { Document } from 'langchain/document';
|
||||
import { Pool } from 'pg';
|
||||
|
||||
export interface VectorSearchResult {
|
||||
@@ -72,6 +71,41 @@ export class PgVectorService {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Replace the active OpenAI embeddings client with one configured for
 * `modelName`, and record the model name in `process.env.OPENAI_API_MODEL`.
 *
 * NOTE(review): only `this.embeddings` is replaced here; anything previously
 * built from the old embeddings instance is not refreshed by this method —
 * confirm whether callers need to do that.
 *
 * @param modelName OpenAI embedding model to switch to
 * @throws Error when `OPENAI_API_KEY` is not set or the client cannot be
 *   created; the original failure reason is included in the message
 */
private async reinitializeEmbeddings(modelName: string): Promise<void> {
  try {
    const apiKey = process.env.OPENAI_API_KEY;
    if (!apiKey) {
      throw new Error('OPENAI_API_KEY not found');
    }

    this.logger.log(
      `Reinitializing OpenAI embeddings with model: ${modelName}`,
    );

    // Create new embeddings instance with new model
    this.embeddings = new OpenAIEmbeddings({
      openAIApiKey: apiKey,
      modelName: modelName,
      maxConcurrency: 5,
    });

    // Keep the env var in sync: getEmbeddingModel() reads the model from it.
    process.env.OPENAI_API_MODEL = modelName;

    this.logger.log(
      `OpenAI embeddings reinitialized successfully with model: ${modelName}`,
    );
  } catch (error) {
    this.logger.error('Failed to reinitialize OpenAI embeddings:', error);

    // A caught value is not guaranteed to be an Error; narrow before reading
    // `.message` so non-Error throwables still produce a useful reason.
    const reason = error instanceof Error ? error.message : String(error);
    throw new Error(`Failed to reinitialize OpenAI embeddings: ${reason}`);
  }
}
|
||||
|
||||
/**
|
||||
* Initialize pgvector store dengan LangChain
|
||||
*/
|
||||
@@ -115,25 +149,35 @@ export class PgVectorService {
|
||||
/**
|
||||
* Generate embedding untuk text menggunakan OpenAI
|
||||
*/
|
||||
async generateEmbedding(
|
||||
text: string,
|
||||
model: string = 'text-embedding-ada-002',
|
||||
): Promise<number[]> {
|
||||
async generateEmbedding(text: string, model?: string): Promise<number[]> {
|
||||
try {
|
||||
// Get model from parameter, environment variable, or use default
|
||||
const embeddingModel =
|
||||
model || process.env.OPENAI_API_MODEL || 'text-embedding-ada-002';
|
||||
|
||||
this.logger.log(
|
||||
`Generating embedding for text: ${text.substring(0, 100)}...`,
|
||||
`Generating embedding for text: ${text.substring(0, 100)}... using model: ${embeddingModel}`,
|
||||
);
|
||||
|
||||
// Check if we need to reinitialize embeddings with new model
|
||||
const currentModel = this.getEmbeddingModel();
|
||||
if (model && model !== currentModel) {
|
||||
this.logger.log(
|
||||
`Switching embedding model from ${currentModel} to ${model}`,
|
||||
);
|
||||
await this.reinitializeEmbeddings(model);
|
||||
}
|
||||
|
||||
if (!this.embeddings) {
|
||||
throw new Error(
|
||||
'OpenAI embeddings not initialized. Please check your API configuration.',
|
||||
);
|
||||
}
|
||||
|
||||
// Use OpenAI embeddings
|
||||
// Use OpenAI embeddings with current model
|
||||
const embedding = await this.embeddings.embedQuery(text);
|
||||
this.logger.log(
|
||||
`Generated OpenAI embedding with ${embedding.length} dimensions`,
|
||||
`Generated OpenAI embedding with ${embedding.length} dimensions using model: ${this.getEmbeddingModel()}`,
|
||||
);
|
||||
return embedding;
|
||||
} catch (error) {
|
||||
@@ -191,7 +235,8 @@ export class PgVectorService {
|
||||
`UPDATE icd_codes
|
||||
SET embedding = $1::vector,
|
||||
metadata = $2::jsonb,
|
||||
content = $3
|
||||
content = $3,
|
||||
"updatedAt" = NOW()
|
||||
WHERE id = $4`,
|
||||
[
|
||||
vectorString,
|
||||
@@ -289,7 +334,8 @@ export class PgVectorService {
|
||||
`UPDATE icd_codes
|
||||
SET embedding = $1::vector,
|
||||
metadata = $2::jsonb,
|
||||
content = $3
|
||||
content = $3,
|
||||
"updatedAt" = NOW()
|
||||
WHERE id = $4`,
|
||||
[
|
||||
vectorString,
|
||||
@@ -337,16 +383,23 @@ export class PgVectorService {
|
||||
}
|
||||
|
||||
/**
|
||||
* Vector similarity search menggunakan pgvector
|
||||
* Vector similarity search menggunakan pgvector dengan threshold yang dapat dikonfigurasi
|
||||
*/
|
||||
async vectorSearch(
|
||||
query: string,
|
||||
limit: number = 10,
|
||||
category?: string,
|
||||
threshold: number = 0.7,
|
||||
threshold?: number,
|
||||
): Promise<VectorSearchResult[]> {
|
||||
// Get threshold from environment variable or use default
|
||||
const defaultThreshold = parseFloat(
|
||||
process.env.VECTOR_SIMILARITY_THRESHOLD || '0.85',
|
||||
);
|
||||
const similarityThreshold = threshold || defaultThreshold;
|
||||
try {
|
||||
this.logger.log(`Performing pgvector search for: ${query}`);
|
||||
this.logger.log(
|
||||
`Performing pgvector search for: ${query} with threshold: ${similarityThreshold}`,
|
||||
);
|
||||
|
||||
if (!this.embeddings) {
|
||||
throw new Error('OpenAI embeddings not initialized');
|
||||
@@ -358,17 +411,19 @@ export class PgVectorService {
|
||||
// Convert embedding array to proper vector format for pgvector
|
||||
const vectorString = `[${queryEmbedding.join(',')}]`;
|
||||
|
||||
// Build SQL query for vector similarity search
|
||||
// Build SQL query for vector similarity search with higher precision
|
||||
// Using cosine distance and converting to similarity score
|
||||
let sql = `
|
||||
SELECT
|
||||
id, code, display, version, category,
|
||||
1 - (embedding <=> $1::vector) as similarity
|
||||
(1 - (embedding <=> $1::vector)) as similarity
|
||||
FROM icd_codes
|
||||
WHERE embedding IS NOT NULL
|
||||
AND (1 - (embedding <=> $1::vector)) >= $2
|
||||
`;
|
||||
|
||||
const params: any[] = [vectorString];
|
||||
let paramIndex = 2;
|
||||
const params: any[] = [vectorString, similarityThreshold];
|
||||
let paramIndex = 3;
|
||||
|
||||
if (category) {
|
||||
sql += ` AND category = $${paramIndex}`;
|
||||
@@ -376,23 +431,24 @@ export class PgVectorService {
|
||||
paramIndex++;
|
||||
}
|
||||
|
||||
sql += ` ORDER BY embedding <=> $1::vector ASC LIMIT $${paramIndex}`;
|
||||
// Order by similarity descending and limit results
|
||||
sql += ` ORDER BY similarity DESC LIMIT $${paramIndex}`;
|
||||
params.push(limit);
|
||||
|
||||
// Execute raw SQL query
|
||||
const result = await this.pool.query(sql, params);
|
||||
|
||||
// Transform and filter results
|
||||
const filteredResults: VectorSearchResult[] = result.rows
|
||||
.filter((row: any) => row.similarity >= threshold)
|
||||
.map((row: any) => ({
|
||||
// Transform results (no need to filter again since SQL already filters)
|
||||
const filteredResults: VectorSearchResult[] = result.rows.map(
|
||||
(row: any) => ({
|
||||
id: row.id,
|
||||
code: row.code,
|
||||
display: row.display,
|
||||
version: row.version,
|
||||
category: row.category,
|
||||
similarity: parseFloat(row.similarity),
|
||||
}));
|
||||
}),
|
||||
);
|
||||
|
||||
this.logger.log(
|
||||
`Pgvector search returned ${filteredResults.length} results for query: "${query}"`,
|
||||
@@ -417,12 +473,12 @@ export class PgVectorService {
|
||||
try {
|
||||
this.logger.log(`Performing hybrid search for: ${query}`);
|
||||
|
||||
// Get vector search results
|
||||
// Get vector search results with higher threshold
|
||||
const vectorResults = await this.vectorSearch(
|
||||
query,
|
||||
limit * 2,
|
||||
category,
|
||||
0.5,
|
||||
parseFloat(process.env.VECTOR_SIMILARITY_THRESHOLD || '0.85'),
|
||||
);
|
||||
|
||||
// Get text search results
|
||||
@@ -475,7 +531,7 @@ export class PgVectorService {
|
||||
try {
|
||||
let sql = 'SELECT id, code, display, version, category FROM icd_codes';
|
||||
const params: any[] = [];
|
||||
let whereConditions: string[] = [];
|
||||
const whereConditions: string[] = [];
|
||||
let paramIndex = 1;
|
||||
|
||||
if (category) {
|
||||
@@ -515,6 +571,51 @@ export class PgVectorService {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Get the current similarity threshold.
 *
 * The value is read from `process.env.VECTOR_SIMILARITY_THRESHOLD`. When the
 * variable is unset, empty, or does not parse to a finite number, 0.85 is
 * returned so that a malformed setting never yields `NaN`.
 *
 * @returns the configured threshold, or 0.85 as a fallback
 */
getSimilarityThreshold(): number {
  const fallbackThreshold = 0.85;
  const parsed = parseFloat(process.env.VECTOR_SIMILARITY_THRESHOLD || '');
  return Number.isFinite(parsed) ? parsed : fallbackThreshold;
}
|
||||
|
||||
/**
 * Get the name of the embedding model currently configured.
 *
 * @returns `process.env.OPENAI_API_MODEL` when it is set to a non-empty
 *   value, otherwise `'text-embedding-ada-002'`
 */
getEmbeddingModel(): string {
  const configuredModel = process.env.OPENAI_API_MODEL;
  return configuredModel ? configuredModel : 'text-embedding-ada-002';
}
|
||||
|
||||
/**
 * Set the similarity threshold at runtime.
 *
 * The value is stored in `process.env.VECTOR_SIMILARITY_THRESHOLD`, which is
 * where `getSimilarityThreshold()` reads it from.
 *
 * @param threshold new threshold; must be a finite number in [0, 1]
 * @throws Error when `threshold` is NaN, infinite, or outside [0, 1]
 */
setSimilarityThreshold(threshold: number): void {
  // NaN fails both `< 0` and `> 1`, so it needs an explicit finiteness check.
  if (!Number.isFinite(threshold) || threshold < 0 || threshold > 1) {
    throw new Error('Similarity threshold must be between 0 and 1');
  }
  process.env.VECTOR_SIMILARITY_THRESHOLD = threshold.toString();
  this.logger.log(`Similarity threshold updated to: ${threshold}`);
}
|
||||
|
||||
/**
 * Switch the embedding model at runtime.
 *
 * Does nothing beyond logging when `modelName` equals the model reported by
 * `getEmbeddingModel()`; otherwise the embeddings client is rebuilt via
 * `reinitializeEmbeddings`.
 *
 * @param modelName name of the model to activate
 * @throws Error when `modelName` is empty or not a string, or when
 *   reinitialisation fails
 */
async setEmbeddingModel(modelName: string): Promise<void> {
  const isUsableName = Boolean(modelName) && typeof modelName === 'string';
  if (!isUsableName) {
    throw new Error('Model name must be a valid string');
  }

  const activeModel = this.getEmbeddingModel();
  if (activeModel !== modelName) {
    this.logger.log(
      `Switching embedding model from ${activeModel} to ${modelName}`,
    );
    await this.reinitializeEmbeddings(modelName);
    return;
  }

  this.logger.log(`Model ${modelName} is already active`);
}
|
||||
|
||||
/**
|
||||
* Get embedding statistics
|
||||
*/
|
||||
@@ -524,6 +625,7 @@ export class PgVectorService {
|
||||
withoutEmbeddings: number;
|
||||
percentage: number;
|
||||
vectorStoreStatus: string;
|
||||
currentThreshold: number;
|
||||
}> {
|
||||
try {
|
||||
// Use raw SQL to get embedding statistics
|
||||
@@ -548,6 +650,7 @@ export class PgVectorService {
|
||||
withoutEmbeddings,
|
||||
percentage: Math.round(percentage * 100) / 100,
|
||||
vectorStoreStatus,
|
||||
currentThreshold: this.getSimilarityThreshold(),
|
||||
};
|
||||
} catch (error) {
|
||||
this.logger.error('Error getting embedding stats:', error);
|
||||
@@ -569,6 +672,95 @@ export class PgVectorService {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Advanced vector search that scores each match with a blend of cosine and
 * Euclidean similarity.
 *
 * Rows are filtered in SQL by cosine similarity (`1 - cosine distance`)
 * against the threshold, ordered by cosine then Euclidean similarity, and
 * limited. The `similarity` returned for each row is
 * `0.7 * cosine + 0.3 * euclidean`, rounded to three decimals.
 *
 * NOTE(review): the Euclidean term is `1 - L2 distance`, which is negative
 * for distances above 1, so the blended score can be lower than the
 * threshold used for filtering — confirm this is intended.
 *
 * @param query free-text query to embed
 * @param limit maximum number of rows to return (default 10)
 * @param category optional category filter
 * @param threshold optional minimum cosine similarity; when omitted the
 *   value from `getSimilarityThreshold()` is used. An explicit `0` is
 *   honoured.
 * @returns matches in SQL order with the blended similarity score
 * @throws Error when embeddings are not initialised; other errors are logged
 *   and rethrown
 */
async advancedVectorSearch(
  query: string,
  limit: number = 10,
  category?: string,
  threshold?: number,
): Promise<VectorSearchResult[]> {
  try {
    // `??` (not `||`) so that a caller-supplied threshold of 0 is respected.
    const similarityThreshold = threshold ?? this.getSimilarityThreshold();

    this.logger.log(
      `Performing advanced vector search for: ${query} with threshold: ${similarityThreshold}`,
    );

    if (!this.embeddings) {
      throw new Error('OpenAI embeddings not initialized');
    }

    // Generate embedding for query and format it as a pgvector literal.
    const queryEmbedding = await this.generateEmbedding(query);
    const vectorString = `[${queryEmbedding.join(',')}]`;

    // <=> is cosine distance and <-> is Euclidean (L2) distance in pgvector.
    let sql = `
      SELECT
        id, code, display, version, category,
        (1 - (embedding <=> $1::vector)) as cosine_similarity,
        (1 - (embedding <-> $1::vector)) as euclidean_similarity
      FROM icd_codes
      WHERE embedding IS NOT NULL
    `;

    const params: any[] = [vectorString];
    let paramIndex = 2;

    if (category) {
      sql += ` AND category = $${paramIndex}`;
      params.push(category);
      paramIndex++;
    }

    // Filter by cosine similarity threshold
    sql += ` AND (1 - (embedding <=> $1::vector)) >= $${paramIndex}`;
    params.push(similarityThreshold);
    paramIndex++;

    // Order by similarity and limit
    sql += ` ORDER BY cosine_similarity DESC, euclidean_similarity DESC LIMIT $${paramIndex}`;
    params.push(limit);

    const result = await this.pool.query(sql, params);

    const filteredResults: VectorSearchResult[] = result.rows.map(
      (row: any) => {
        const cosineSim = parseFloat(row.cosine_similarity);
        const euclideanSim = parseFloat(row.euclidean_similarity);

        // Weighted average of the two metrics.
        const combinedSimilarity = cosineSim * 0.7 + euclideanSim * 0.3;

        return {
          id: row.id,
          code: row.code,
          display: row.display,
          version: row.version,
          category: row.category,
          similarity: Math.round(combinedSimilarity * 1000) / 1000, // Round to 3 decimal places
        };
      },
    );

    this.logger.log(
      `Advanced vector search returned ${filteredResults.length} results for query: "${query}" with threshold: ${similarityThreshold}`,
    );
    return filteredResults;
  } catch (error) {
    this.logger.error('Error in advanced vector search:', error);
    throw error;
  }
}
|
||||
|
||||
/**
|
||||
* Get vector store status
|
||||
*/
|
||||
@@ -589,9 +781,10 @@ export class PgVectorService {
|
||||
initialized: !!this.vectorStore,
|
||||
documentCount,
|
||||
embeddingModel: this.embeddings
|
||||
? `OpenAI ${process.env.OPENAI_API_MODEL || 'text-embedding-ada-002'}`
|
||||
? `OpenAI ${this.getEmbeddingModel()}`
|
||||
: 'Not Available',
|
||||
lastUpdated: new Date(),
|
||||
currentThreshold: this.getSimilarityThreshold(),
|
||||
};
|
||||
|
||||
return status;
|
||||
@@ -601,6 +794,164 @@ export class PgVectorService {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Builds the text that gets embedded for an ICD code.
 *
 * The result starts with "<code> <display>" and is followed, in order, by:
 * a category phrase (for 'ICD9' / 'ICD10'), the version when truthy, any
 * context phrases whose keywords occur in the lower-cased display text, and
 * a fixed trailing phrase. All pieces are separated by single spaces.
 *
 * @param code - Row-like object read for `code`, `display`, `category` and
 *   `version`.
 * @returns The space-joined enriched text.
 */
private createEnhancedTextRepresentation(code: any): string {
  const parts: string[] = [`${code.code} ${code.display}`];

  // Category context
  switch (code.category) {
    case 'ICD9':
      parts.push('ICD-9 CM procedure diagnosis');
      break;
    case 'ICD10':
      parts.push('ICD-10 diagnosis condition');
      break;
  }

  // Version context
  if (code.version) {
    parts.push(`${code.version}`);
  }

  // Keyword-driven context: each rule contributes its phrase when at least
  // one of its keywords appears in the display text. Rule order determines
  // the order of the phrases in the output.
  const loweredDisplay = code.display.toLowerCase();
  const contextRules: Array<[keywords: string[], phrase: string]> = [
    [
      ['procedure', 'surgery', 'operation'],
      'medical procedure surgical intervention',
    ],
    [
      ['diagnosis', 'condition', 'disease'],
      'medical diagnosis clinical condition',
    ],
    [['cranial', 'brain', 'head'], 'neurological cranial brain head'],
    [['cardiac', 'heart', 'cardiovascular'], 'cardiac heart cardiovascular'],
    [['pulmonary', 'lung', 'respiratory'], 'pulmonary lung respiratory'],
  ];

  for (const [keywords, phrase] of contextRules) {
    if (keywords.some((keyword) => loweredDisplay.includes(keyword))) {
      parts.push(phrase);
    }
  }

  // Common medical terms appended to every code
  parts.push('medical healthcare clinical');

  return parts.join(' ');
}
|
||||
|
||||
/**
 * Re-embeds ICD codes that already have an embedding, using the enhanced
 * text representation, and stores both the new vector and the text.
 *
 * Codes are handled one at a time. A failure on one code is logged and
 * counted, and processing continues with the next code.
 *
 * @param limit - Maximum number of codes to select for regeneration
 *   (default 100).
 * @returns Counts of successfully updated codes (`processed`), failed codes
 *   (`errors`) and selected codes (`totalSample`).
 * @throws Rethrows any error raised outside the per-code handling (for
 *   example by the initial SELECT) after logging it.
 */
async regenerateEmbeddingsWithEnhancedText(limit: number = 100): Promise<{
  processed: number;
  errors: number;
  totalSample: number;
}> {
  try {
    this.logger.log(
      `Starting enhanced embedding regeneration for ${limit} ICD codes...`,
    );

    // Select the codes that currently carry an embedding
    const { rows: targets } = await this.pool.query(
      'SELECT id, code, display, version, category FROM icd_codes WHERE embedding IS NOT NULL LIMIT $1',
      [limit],
    );
    const totalSample = targets.length;

    if (totalSample === 0) {
      this.logger.log('No ICD codes found with embeddings to regenerate');
      return { processed: 0, errors: 0, totalSample: 0 };
    }

    this.logger.log(
      `Found ${totalSample} codes to regenerate with enhanced text`,
    );

    const tally = { processed: 0, errors: 0 };

    for (const icd of targets) {
      try {
        // Enhanced text -> embedding -> pgvector literal
        const enhancedText = this.createEnhancedTextRepresentation(icd);
        const vector = await this.generateEmbedding(enhancedText);
        const vectorLiteral = `[${vector.join(',')}]`;

        // Persist the new embedding together with the text it was built from
        await this.pool.query(
          `UPDATE icd_codes
           SET embedding = $1::vector,
               content = $2,
               "updatedAt" = NOW()
           WHERE id = $3`,
          [vectorLiteral, enhancedText, icd.id],
        );

        tally.processed += 1;

        // Progress log every tenth successful update
        if (tally.processed % 10 === 0) {
          this.logger.log(
            `Regenerated ${tally.processed}/${totalSample} enhanced embeddings`,
          );
        }
      } catch (failure) {
        this.logger.error(
          `Error regenerating embedding for code ${icd.code}:`,
          failure,
        );
        tally.errors += 1;
      }
    }

    this.logger.log(
      `Enhanced embedding regeneration completed. Processed: ${tally.processed}, Errors: ${tally.errors}, Total: ${totalSample}`,
    );

    return {
      processed: tally.processed,
      errors: tally.errors,
      totalSample,
    };
  } catch (error) {
    this.logger.error(
      'Error in regenerateEmbeddingsWithEnhancedText:',
      error,
    );
    throw error;
  }
}
|
||||
|
||||
/**
|
||||
* Cleanup resources
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user